FastAPI, async SQLAlchemy, pytest, and Alembic (all using asyncpg)

Introduction

In this article, we'll explore the integration of FastAPI with the new asynchronous SQLAlchemy 2.0. We'll also configure pytest to run asynchronous tests in a way that stays compatible with pytest-xdist, and we'll cover using Alembic for database migrations with an asynchronous database driver.

The inspiration for this article sparked while I was reading the insightful book "Architecture Patterns with Python" by Harry Percival & Bob Gregory.

I also came across the concept of the "Stairway test," which is detailed in the repository at https://github.com/alvassin/alembic-quickstart/tree/master. It resonated with me and led to the ideas presented in this post.

Requirements

I run this project using Python 3.9; you can probably adapt it to earlier versions without much trouble.

I use poetry to manage project requirements.

Source code can be found here

Install dependencies

$ poetry add fastapi uvicorn uvloop asyncpg alembic pydantic-settings
$ poetry add sqlalchemy --extras asyncio

Install dev dependencies

$ poetry add --group=dev httpx sqlalchemy-utils pytest yarl mypy black isort

Setting up the database

We're going to use FastAPI to create a straightforward API for creating users and retrieving them from a database. Our primary objective is to illustrate how SQLAlchemy 2.0 and FastAPI work together, so the API itself won't be the focus here.

Let's start by creating a configuration file that will hold our database connection string. I prefer to use pydantic-settings for this, but feel free to use any other method, such as os.getenv, if that fits your workflow better.

For clarity, I have put the entire database URI into a single parameter. In a real-world project, such settings would likely be split into separate fields like db_host, db_port, db_user, db_password, and so on (see the sketch after the code below).

# app/settings.py
from pathlib import Path

from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
    app_name: str = "Example API"
    app_host: str = "0.0.0.0"
    app_port: int = 3000

    database_url: str = "postgresql+asyncpg://blog_example_user:password@localhost:5432/blog_example_base"

    project_root: Path = Path(__file__).parent.parent.resolve()

    model_config = SettingsConfigDict(env_file=".env")


settings = Settings()
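
As mentioned above, in a real project the connection settings would probably be split into separate fields and assembled into a URI. A minimal sketch of what that might look like (this variant is hypothetical and not part of the example project):

# app/settings.py (hypothetical alternative)
from pydantic_settings import BaseSettings, SettingsConfigDict


class DatabaseSettings(BaseSettings):
    db_host: str = "localhost"
    db_port: int = 5432
    db_user: str = "blog_example_user"
    db_password: str = "password"
    db_name: str = "blog_example_base"

    model_config = SettingsConfigDict(env_file=".env")

    @property
    def database_url(self) -> str:
        # Assemble the asyncpg DSN from the individual parts
        return (
            f"postgresql+asyncpg://{self.db_user}:{self.db_password}"
            f"@{self.db_host}:{self.db_port}/{self.db_name}"
        )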

I'm not going to cover how to set up a local PostgreSQL server in this post, but you can, for example, use the official Docker image to start one, as shown below.
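
A possible invocation (the container name and image tag are just examples):

$ docker run --name blog-example-postgres \
    -e POSTGRES_USER=blog_example_user \
    -e POSTGRES_PASSWORD=password \
    -e POSTGRES_DB=blog_example_base \
    -p 5432:5432 -d postgres:15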

To be able to run the tests, your database user must have the CREATEDB privilege.

Example SQL commands to create a new user and a new database:

CREATE USER "blog_example_user" WITH PASSWORD 'password';
CREATE DATABASE "blog_example_base" OWNER "blog_example_user";
ALTER USER "blog_example_user" CREATEDB;

Let's create our ORM model

I prefer to put all ORM-related code into an orm module and expose it through its __init__.py.

So in code I use it like `import orm` and then `query = select(orm.User)`.

This way it's much easier to distinguish between SQLAlchemy models and my business models.

So, first let's add the base class

# orm/base_model.py
from sqlalchemy import MetaData
from sqlalchemy.orm import DeclarativeBase

# Default naming convention for all indexes and constraints
# See why this is important and how it would save your time:
# https://alembic.sqlalchemy.org/en/latest/naming.html
convention = {
    "all_column_names": lambda constraint, table: "_".join(
        [column.name for column in constraint.columns.values()]
    ),
    "ix": "ix__%(table_name)s__%(all_column_names)s",
    "uq": "uq__%(table_name)s__%(all_column_names)s",
    "ck": "ck__%(table_name)s__%(constraint_name)s",
    "fk": "fk__%(table_name)s__%(all_column_names)s__%(referred_table_name)s",
    "pk": "pk__%(table_name)s",
}


class OrmBase(DeclarativeBase):
    metadata = MetaData(naming_convention=convention)  # type: ignore


Then, let's create a session manager for our database. This class will be used as a singleton and will be responsible for abstracting the database connection and session handling:

# orm/session_manager.py
import contextlib
from typing import AsyncIterator, Optional

from sqlalchemy.ext.asyncio import (
    AsyncConnection,
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)


class DatabaseSessionManager:
    def __init__(self) -> None:
        self._engine: Optional[AsyncEngine] = None
        self._sessionmaker: Optional[async_sessionmaker[AsyncSession]] = None

    def init(self, db_url: str) -> None:
        # Just additional example of customization.
        # you can add parameters to init and so on
        if "postgresql" in db_url:
            # These settings are needed to work with pgbouncer in transaction mode
            # because you can't use prepared statements in such case
            connect_args = {
                "statement_cache_size": 0,
                "prepared_statement_cache_size": 0,
            }
        else:
            connect_args = {}
        self._engine = create_async_engine(
            url=db_url,
            pool_pre_ping=True,
            connect_args=connect_args,
        )
        self._sessionmaker = async_sessionmaker(
            bind=self._engine,
            expire_on_commit=False,
        )

    async def close(self) -> None:
        if self._engine is None:
            return
        await self._engine.dispose()
        self._engine = None
        self._sessionmaker = None

    @contextlib.asynccontextmanager
    async def session(self) -> AsyncIterator[AsyncSession]:
        if self._sessionmaker is None:
            raise IOError("DatabaseSessionManager is not initialized")
        async with self._sessionmaker() as session:
            try:
                yield session
            except Exception:
                await session.rollback()
                raise

    @contextlib.asynccontextmanager
    async def connect(self) -> AsyncIterator[AsyncConnection]:
        if self._engine is None:
            raise IOError("DatabaseSessionManager is not initialized")
        async with self._engine.begin() as connection:
            try:
                yield connection
            except Exception:
                await connection.rollback()
                raise


db_manager = DatabaseSessionManager()

Notice that we're using the async version of the engine factory, create_async_engine, which returns an AsyncEngine object. We also use the async version of sessionmaker, async_sessionmaker, which produces AsyncSession objects for committing and rolling back transactions.

We are going to call the init and close methods in FastAPI's lifespan event, so they run during startup and shutdown of our application.

The benefits of this approach are:

  • You can connect to as many databases as needed, which was a problem for me with the middleware approach: just create a separate DatabaseSessionManager for each database (see the sketch after this list).
  • Your DB connections are released at application shutdown instead of at garbage collection, which means you won't run into issues if you use uvicorn --reload.
  • Your DB session is automatically closed when the route using the session dependency finishes, so any uncommitted operations are rolled back.
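
A minimal sketch of the multi-database idea from the first point (the second database URL and the startup/shutdown helpers are hypothetical, not part of the example project; in practice you would call them from the lifespan shown later):

# Hypothetical: one DatabaseSessionManager per database
from app.settings import settings
from orm.session_manager import DatabaseSessionManager

users_db = DatabaseSessionManager()
analytics_db = DatabaseSessionManager()


async def startup() -> None:
    users_db.init(settings.database_url)
    analytics_db.init("postgresql+asyncpg://user:password@localhost:5432/analytics")


async def shutdown() -> None:
    await users_db.close()
    await analytics_db.close()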

Then, we need to create a FastAPI dependency that will be used to get the database session. This dependency will be used in the API views:

# orm/session_manager.py
async def get_session() -> AsyncIterator[AsyncSession]:
    async with db_manager.session() as session:
        yield session

And we're done with the database configuration. Now we can create the database models (I just used the model from the SQLAlchemy tutorial):

# orm/user_model.py

from typing import Optional

from sqlalchemy import String
from sqlalchemy.orm import Mapped, mapped_column

from .base_model import OrmBase


class User(OrmBase):
    __tablename__ = "user_account"

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(30))
    fullname: Mapped[Optional[str]]

    def __repr__(self) -> str:
        return f"User(id={self.id!r}, name={self.name!r}, fullname={self.fullname!r})"

This is the most modern form of Declarative mapping, driven by PEP 484 type annotations using the special Mapped type, which indicates the attributes to be mapped and their particular types.

Read more about declaring mapped classes in Working with Database Metadata — SQLAlchemy 2.0 Documentation.

# orm/__init__.py
"""
Data structures, used in project.

Add your new models here so Alembic could pick them up.

You may do changes in tables, then execute
`alembic revision --message="Your text" --autogenerate`
and alembic would generate new migration for you
in alembic/versions folder.
"""
from .base_model import OrmBase
from .session_manager import db_manager, get_session
from .user_model import User

__all__ = ["OrmBase", "get_session", "db_manager", "User"]

I use the __init__.py file in the orm module so I can write `import orm` and then reference objects like orm.db_manager, orm.User, and so on. This makes it much easier to distinguish your SQLAlchemy models from your business models.

Creating the API views

Now that we have the database configuration and models set up, we can create the API views.

For simplicity, I'm going to put all API-related models and functions into one file so you can review it easily. In a real project you should probably split them into separate modules for clarity.

Let's start by creating the models for validating the incoming API request and shaping the response:

# api/user.py
from pydantic import BaseModel, ConfigDict, Field

class UserCreateRequest(BaseModel):
    name: str = Field(max_length=30)
    fullname: str


class UserResponse(BaseModel):
    id: int
    name: str
    fullname: str

    model_config = ConfigDict(from_attributes=True)

When you need to return a list of users, many people opt for a simplistic approach and respond with something like list[User]. Some time later they need to add extra information to that endpoint, and it can't be done easily.

So it's better to use a more flexible response structure from the beginning, like:

# api/user.py
class APIUserResponse(BaseModel):
    status: Literal['ok'] = 'ok'
    data: UserResponse


class APIUserListResponse(BaseModel):
    status: Literal['ok'] = 'ok'
    data: list[UserResponse]

And, finally, our API views:

A few extra notes:

  • Every field of our UserResponse model is already JSON-compatible (strings and ints), which is why we can use .model_dump(). If your model has types like UUID, datetime, or other classes, you can use .model_dump(mode='json') and pydantic will automatically convert the output values to JSON-supported types (see the sketch after this list).
  • I prefer to return a Response directly rather than use FastAPI's response_model conversion. For me it's more convenient, plus it's actually faster. (You can check https://github.com/falkben/fastapi_experiments/ -> orjson_response.py)
  • For the sake of simplicity I do ORM queries right in the API views. In a bigger project it's better to create an additional service layer and keep all your ORM/SQL queries in one module; if such queries are spread throughout your code base, you are going to regret it later.
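
To illustrate the first note, here is a tiny standalone sketch of the two serialization modes (the Event model is hypothetical, not part of the project):

import datetime
import uuid

from pydantic import BaseModel


class Event(BaseModel):
    id: uuid.UUID
    created_at: datetime.datetime


event = Event(id=uuid.uuid4(), created_at=datetime.datetime.utcnow())
print(event.model_dump())             # keeps the UUID and datetime objects as-is
print(event.model_dump(mode="json"))  # converts them to JSON-friendly strings

With that in mind, here is the complete api/user.py module:
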
# api/user.py
from typing import Literal

from fastapi import APIRouter, Depends, status
from fastapi.responses import JSONResponse
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

import orm


class UserCreateRequest(BaseModel):
    name: str = Field(max_length=30)
    fullname: str


class UserResponse(BaseModel):
    id: int
    name: str
    fullname: str

    model_config = ConfigDict(from_attributes=True)


class APIUserResponse(BaseModel):
    status: Literal["ok"] = "ok"
    data: UserResponse


class APIUserListResponse(BaseModel):
    status: Literal["ok"] = "ok"
    data: list[UserResponse]


router = APIRouter()


@router.get("/{user_id}/", response_model=APIUserResponse)
async def get_user(
    user_id: int, session: AsyncSession = Depends(orm.get_session)
) -> JSONResponse:
    user = await session.get(orm.User, user_id)
    if not user:
        return JSONResponse(
            content={"status": "error", "message": "User not found"},
            status_code=status.HTTP_404_NOT_FOUND,
        )
    response_model = UserResponse.model_validate(user)
    return JSONResponse(
        content={
            "status": "ok",
            "data": response_model.model_dump(),
        }
    )


@router.get("/", response_model=APIUserListResponse)
async def get_users(session: AsyncSession = Depends(orm.get_session)) -> JSONResponse:
    users_results = await session.scalars(select(orm.User))
    response_data = [
        UserResponse.model_validate(u).model_dump() for u in users_results.all()
    ]
    return JSONResponse(
        content={
            "status": "ok",
            "data": response_data,
        }
    )


@router.post("/", response_model=APIUserResponse, status_code=status.HTTP_201_CREATED)
async def create_user(
    user_data: UserCreateRequest, session: AsyncSession = Depends(orm.get_session)
) -> JSONResponse:
    user_candidate = orm.User(**user_data.model_dump())
    session.add(user_candidate)
    # I skip error handling
    await session.commit()
    await session.refresh(user_candidate)
    response_model = UserResponse.model_validate(user_candidate)
    return JSONResponse(
        content={
            "status": "ok",
            "data": response_model.model_dump(),
        },
        status_code=status.HTTP_201_CREATED,
    )

Here we have a simple FastAPI router with three API views: get_user, get_users, and create_user. Notice that we use Depends to inject the async database session into the views; this is how the views get access to the database session.

Setting up FastAPI

Now that we have the API views set up, we can create the FastAPI application.

# main.py
import contextlib
from typing import AsyncIterator

import uvicorn
from fastapi import FastAPI

import orm
from api import user
from app.settings import settings


@contextlib.asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    orm.db_manager.init(settings.database_url)
    yield
    await orm.db_manager.close()


app = FastAPI(title="Very simple example", lifespan=lifespan)
app.include_router(user.router, prefix="/api/users", tags=["users"])

if __name__ == "__main__":
    # There are a lot of parameters for uvicorn, you should check the docs
    uvicorn.run(
        app,
        host=settings.app_host,
        port=settings.app_port,
    )

In order for us to run our application, first we'll need to create our database tables. Let's see how we can do that using Alembic.

Migrations with Alembic

To start with alembic, we can use the alembic init command to create the alembic configuration. We'll use the async template for this:

alembic init -t async alembic

This will create the alembic directory with the alembic configuration. We'll need to make a few changes to the configuration.

alembic.ini

Uncomment the file_template line so the migration file names are more user-friendly and include dates, which lets us sort them.

Remove the sqlalchemy.url line, because we are going to set this parameter from alembic/env.py.
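
After those two changes, the relevant part of alembic.ini should look roughly like this (an excerpt based on the default template generated by alembic init):

# alembic.ini (excerpt)
[alembic]
script_location = alembic

# uncommented so migration files get a sortable, dated name
file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s

# the sqlalchemy.url line is removed; it is set from alembic/env.py instead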

alembic/env.py

First, we'll need to import our database models so that they're added to the Base.metadata object. This happens automatically when a model inherits from OrmBase, but the models have to be imported before Alembic reads the metadata. Because we put all models into orm/__init__.py, a single import orm loads them all.

Then, we need to set the sqlalchemy.url configuration to use our database connection string.

Important note: we are going to build the Alembic configuration in tests as well, so we need to be careful not to overwrite sqlalchemy.url if it's already set.

And finally, we'll point the target metadata to our Base.metadata object.

Below I'll show the changes we need to make to the alembic/env.py file:

# alembic/env.py

import orm
from app.settings import settings
current_url = config.get_main_option('sqlalchemy.url', None)
if not current_url:
    config.set_main_option("sqlalchemy.url", settings.database_url)

target_metadata = orm.OrmBase.metadata 

Then, we're able to run the alembic revision command to create a new revision:

alembic revision --autogenerate -m "Add user model"

This will create a new revision file in the alembic/versions directory. We can then run the alembic upgrade head command to apply the migration to the database:

alembic upgrade head

To revert the last migration you could use

alembic downgrade -1

You can read more about Alembic commands in the Tutorial — Alembic documentation.

Starting the server

To start the server, run python main.py. This will start the server on the port defined in our settings (3000 in this example). The docs will be available at http://localhost:3000/docs. You should be able to see and run any of the API views that we've created.

This should be enough to start using FastAPI with SQLAlchemy 2.0. However, one important component of software development is testing, so let's see how we can test our API views.

Testing the API views

For this section, my focus is primarily on demonstrating the mechanics of integration testing with FastAPI and SQLAlchemy 2.0. This means that our tests will call the API views and check the responses. While we won't be testing the database models, it's worth noting that a similar setup can be applied for such scenarios as well.

We'll start with the helper functions we are going to need:

The sqlalchemy_utils package has two very useful functions - create_database and drop_database.

Regrettably, these functions are synchronous and incompatible with the asyncpg driver. This typically leads tutorials to recommend installing psycopg2 and using a separate synchronous engine for database creation. However, in the spirit of experimentation, we can just slightly modify these functions so they use create_async_engine.

You can see them in the GitHub repo.

The next utilities are:

# tests/db_utils.py
import contextlib
import os
import uuid
from argparse import Namespace
from pathlib import Path
from typing import AsyncIterator, Optional, Union

import sqlalchemy as sa
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy_utils.functions.database import (
    _set_url_database,
    _sqlite_file_exists,
    make_url,
)
from sqlalchemy_utils.functions.orm import quote
from yarl import URL

from alembic.config import Config as AlembicConfig
from app.settings import settings


def make_alembic_config(
    cmd_opts: Namespace, base_path: Union[str, Path] = settings.project_root
) -> AlembicConfig:
    # Replace path to alembic.ini file to absolute
    base_path = Path(base_path)
    if not Path(cmd_opts.config).is_absolute():
        cmd_opts.config = str(base_path.joinpath(cmd_opts.config).absolute())
    config = AlembicConfig(
        file_=cmd_opts.config,
        ini_section=cmd_opts.name,
        cmd_opts=cmd_opts,
    )
    # Replace path to alembic folder to absolute
    alembic_location = config.get_main_option("script_location")
    if not Path(alembic_location).is_absolute():
        config.set_main_option(
            "script_location", str(base_path.joinpath(alembic_location).absolute())
        )
    if cmd_opts.pg_url:
        config.set_main_option("sqlalchemy.url", cmd_opts.pg_url)
    return config


def alembic_config_from_url(pg_url: Optional[str] = None) -> AlembicConfig:
    """Provides python object, representing alembic.ini file."""
    cmd_options = Namespace(
        config="alembic.ini",  # Config file name
        name="alembic",  # Name of section in .ini file to use for Alembic config
        pg_url=pg_url,  # DB URI
        raiseerr=True,  # Raise a full stack trace on error
        x=None,  # Additional arguments consumed by custom env.py scripts
    )
    return make_alembic_config(cmd_opts=cmd_options)


@contextlib.asynccontextmanager
async def tmp_database(db_url: URL, suffix: str = "", **kwargs) -> AsyncIterator[str]:
    """Context manager for creating new database and deleting it on exit."""
    tmp_db_name = ".".join([uuid.uuid4().hex, "tests-base", suffix])
    tmp_db_url = str(db_url.with_path(tmp_db_name))
    await create_database_async(tmp_db_url, **kwargs)
    try:
        yield tmp_db_url
    finally:
        await drop_database_async(tmp_db_url)


# The following functions are copied from `sqlalchemy_utils` and slightly
# modified to work with the async engine.
async def create_database_async(
    url: str, encoding: str = "utf8", template: Optional[str] = None
) -> None:
    url = make_url(url)
    database = url.database
    dialect_name = url.get_dialect().name
    dialect_driver = url.get_dialect().driver

    if dialect_name == "postgresql":
        url = _set_url_database(url, database="postgres")
    elif dialect_name == "mssql":
        url = _set_url_database(url, database="master")
    elif dialect_name == "cockroachdb":
        url = _set_url_database(url, database="defaultdb")
    elif not dialect_name == "sqlite":
        url = _set_url_database(url, database=None)

    if (dialect_name == "mssql" and dialect_driver in {"pymssql", "pyodbc"}) or (
        dialect_name == "postgresql"
        and dialect_driver in {"asyncpg", "pg8000", "psycopg2", "psycopg2cffi"}
    ):
        engine = create_async_engine(url, isolation_level="AUTOCOMMIT")
    else:
        engine = create_async_engine(url)

    if dialect_name == "postgresql":
        if not template:
            template = "template1"

        async with engine.begin() as conn:
            text = "CREATE DATABASE {} ENCODING '{}' TEMPLATE {}".format(
                quote(conn, database), encoding, quote(conn, template)
            )
            await conn.execute(sa.text(text))

    elif dialect_name == "mysql":
        async with engine.begin() as conn:
            text = "CREATE DATABASE {} CHARACTER SET = '{}'".format(
                quote(conn, database), encoding
            )
            await conn.execute(sa.text(text))

    elif dialect_name == "sqlite" and database != ":memory:":
        if database:
            async with engine.begin() as conn:
                await conn.execute(sa.text("CREATE TABLE DB(id int)"))
                await conn.execute(sa.text("DROP TABLE DB"))

    else:
        async with engine.begin() as conn:
            text = f"CREATE DATABASE {quote(conn, database)}"
            await conn.execute(sa.text(text))

    await engine.dispose()


async def drop_database_async(url: str) -> None:
    url = make_url(url)
    database = url.database
    dialect_name = url.get_dialect().name
    dialect_driver = url.get_dialect().driver

    if dialect_name == "postgresql":
        url = _set_url_database(url, database="postgres")
    elif dialect_name == "mssql":
        url = _set_url_database(url, database="master")
    elif dialect_name == "cockroachdb":
        url = _set_url_database(url, database="defaultdb")
    elif not dialect_name == "sqlite":
        url = _set_url_database(url, database=None)

    if dialect_name == "mssql" and dialect_driver in {"pymssql", "pyodbc"}:
        engine = create_async_engine(url, connect_args={"autocommit": True})
    elif dialect_name == "postgresql" and dialect_driver in {
        "asyncpg",
        "pg8000",
        "psycopg2",
        "psycopg2cffi",
    }:
        engine = create_async_engine(url, isolation_level="AUTOCOMMIT")
    else:
        engine = create_async_engine(url)

    if dialect_name == "sqlite" and database != ":memory:":
        if database:
            os.remove(database)
    elif dialect_name == "postgresql":
        async with engine.begin() as conn:
            # Disconnect all users from the database we are dropping.
            version = conn.dialect.server_version_info
            pid_column = "pid" if (version >= (9, 2)) else "procpid"
            text = """
            SELECT pg_terminate_backend(pg_stat_activity.{pid_column})
            FROM pg_stat_activity
            WHERE pg_stat_activity.datname = '{database}'
            AND {pid_column} <> pg_backend_pid();
            """.format(
                pid_column=pid_column, database=database
            )
            await conn.execute(sa.text(text))

            # Drop the database.
            text = f"DROP DATABASE {quote(conn, database)}"
            await conn.execute(sa.text(text))
    else:
        async with engine.begin() as conn:
            text = f"DROP DATABASE {quote(conn, database)}"
            await conn.execute(sa.text(text))

    await engine.dispose()

Let's start by creating a conftest.py file in the root of our tests directory. This file will be responsible for setting up the test database. Since this is an intricate setup, let's break it down into smaller pieces. We'll start with the imports:

# tests/conftest.py
from typing import Optional

import pytest
from httpx import AsyncClient
from yarl import URL

import orm
from alembic.command import upgrade
from app.settings import settings
from orm.session_manager import db_manager
from tests.db_utils import alembic_config_from_url, tmp_database

There isn't much going on here; we're just importing the necessary packages. Let's move on and create our app and client fixtures, used to provide the FastAPI test application and the test client:

# tests/conftest.py
@pytest.fixture()
def app():
    from main import app

    yield app


@pytest.fixture()
async def client(session, app):
    async with AsyncClient(app=app, base_url="http://test") as client:
        yield client

Because we use FastAPI, which is built on AnyIO, we can use AnyIO in our tests as well; many people use pytest-asyncio instead. To skip running the tests on the trio event loop we need to create a new fixture. With it we can also just write async def test_... functions without marking them additionally.

# tests/conftest.py
@pytest.fixture(scope="session", autouse=True)
def anyio_backend():
    return "asyncio", {"use_uvloop": True}

And now we're ready to create the database connection and session fixtures. Our test connection setup is scoped to the pytest session, so the same engine is reused for all the tests; it's best practice to avoid creating a new connection for each test, or even each request.

# tests/conftest.py

@pytest.fixture(scope="session")
def pg_url():
    """Provides base PostgreSQL URL for creating temporary databases."""
    return URL(settings.database_url)


@pytest.fixture(scope="session")
async def migrated_postgres_template(pg_url):
    """
    Creates temporary database and applies migrations.

    Has "session" scope, so is called only once per tests run.
    """
    async with tmp_database(pg_url, "pytest") as tmp_url:
        alembic_config = alembic_config_from_url(tmp_url)
        # Sometimes we have so-called data migrations:
        # they can call various db-related functions, so we point
        # our settings at the temporary database.
        settings.database_url = tmp_url

        # It is important to always close the connections at the end of such migrations,
        # or we will get errors like `source database is being accessed by other users`
        upgrade(alembic_config, "head")

        yield tmp_url


@pytest.fixture(scope="session")
async def sessionmanager_for_tests(migrated_postgres_template):
    db_manager.init(db_url=migrated_postgres_template)
    # can add another init (redis, etc...)
    yield db_manager
    await db_manager.close()


@pytest.fixture()
async def session(sessionmanager_for_tests):
    async with db_manager.session() as session:
        yield session

    # Clean tables after each test. I tried:
    # 1. Create new database using an empty `migrated_postgres_template` as template
    # (postgres could copy whole db structure)
    # 2. Do TRUNCATE after each test.
    # 3. Do DELETE after each test.
    # DELETE FROM is the fastest
    # https://www.lob.com/blog/truncate-vs-delete-efficiently-clearing-data-from-a-postgres-table
    # BUT DELETE FROM query does not reset any AUTO_INCREMENT counter
    async with db_manager.connect() as conn:
        for table in reversed(orm.OrmBase.metadata.sorted_tables):
            # Clean tables in such order that tables which depend on another go first
            await conn.execute(table.delete())
        await conn.commit()

DELETE FROM does not reset the sequence behind an auto-incrementing column, so our user.id values keep increasing during a single test run. Consider whether that's a problem for you; for me it isn't, so I don't want to switch to TRUNCATE. If it does matter to you, a possible alternative is sketched below.
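
If the non-resetting sequences do bother you, one option (a sketch assuming PostgreSQL; it is not what the repository uses) is to replace the DELETE loop with a single TRUNCATE that also resets identity columns:

# Hypothetical helper: truncate all mapped tables and reset their sequences
import sqlalchemy as sa

import orm
from orm.session_manager import db_manager


async def truncate_all_tables() -> None:
    table_names = ", ".join(
        table.name for table in reversed(orm.OrmBase.metadata.sorted_tables)
    )
    async with db_manager.connect() as conn:
        # CASCADE takes care of foreign keys, RESTART IDENTITY resets the sequences
        await conn.execute(sa.text(f"TRUNCATE {table_names} RESTART IDENTITY CASCADE"))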

Now we can write our first, simple test:

# tests/test_orm_works.py

from sqlalchemy import text

import orm


async def test_orm_session(session):
    user = orm.User(
        name="Michael",
        fullname="Michael Test Jr.",
    )
    session.add(user)
    await session.commit()

    rows = await session.execute(text('SELECT id, name, fullname FROM "user_account"'))
    result = list(rows)[0]
    assert isinstance(result[0], int)
    assert result[1] == "Michael"
    assert result[2] == "Michael Test Jr."

You can run pytest now and it works.

But we're not done yet. We need to add the very useful Stairway test, and in doing so we will face new challenges with Alembic.

Stairway test

A simple and efficient method to check that a migration has no typos and rolls back all schema changes. It does not require maintenance - you can add this test to your project once and forget about it.

In particular, the test detects data types that were created by the upgrade() method and not removed by downgrade(): when creating a table or column, Alembic automatically creates the custom data types specified in the columns (e.g. enum), but does not delete them when the table or column is deleted - the developer has to do it manually.

How it works

The test retrieves the full list of migrations and, for each migration, executes the upgrade, downgrade, upgrade Alembic commands.

Let's add a new test package, migrations, and create the fixtures:

# tests/migrations/conftest.py

import pytest
from sqlalchemy.ext.asyncio import create_async_engine

from tests.db_utils import alembic_config_from_url, tmp_database


@pytest.fixture()
async def postgres(pg_url):
    """
    Creates empty temporary database.
    """
    async with tmp_database(pg_url, "pytest") as tmp_url:
        yield tmp_url


@pytest.fixture()
async def postgres_engine(postgres):
    """
    SQLAlchemy engine, bound to temporary database.
    """
    engine = create_async_engine(
        url=postgres,
        pool_pre_ping=True,
    )
    try:
        yield engine
    finally:
        await engine.dispose()


@pytest.fixture()
def alembic_config(postgres):
    """
    Alembic configuration object, bound to temporary database.
    """
    return alembic_config_from_url(postgres)

And the test itself:

# tests/migrations/test_stairway.py

"""
Test can find forgotten downgrade methods, undeleted data types in downgrade
methods, typos and many other errors.

Does not require any maintenance - you just add it once to check 80% of typos
and mistakes in migrations forever.
"""
import pytest

from alembic.command import downgrade, upgrade
from alembic.config import Config
from alembic.script import Script, ScriptDirectory
from tests.db_utils import alembic_config_from_url


def get_revisions():
    # Create Alembic configuration object
    # (we don't need database for getting revisions list)
    config = alembic_config_from_url()

    # Get directory object with Alembic migrations
    revisions_dir = ScriptDirectory.from_config(config)

    # Get & sort migrations, from first to last
    revisions = list(revisions_dir.walk_revisions("base", "heads"))
    revisions.reverse()
    return revisions


@pytest.mark.parametrize("revision", get_revisions())
def test_migrations_stairway(alembic_config: Config, revision: Script):
    upgrade(alembic_config, revision.revision)

    # We need -1 for downgrading first migration (its down_revision is None)
    downgrade(alembic_config, revision.down_revision or "-1")
    upgrade(alembic_config, revision.revision)

Running pytest again we will get an error:

E RuntimeError: asyncio.run() cannot be called from a running event loop

That's because inside the upgrade command Alembic uses asyncio.run to run migrations via the asyncpg driver. That works just fine when we run migration commands from the command line, but during a test run an active asyncio event loop is already in place, so we can't call asyncio.run.

We definitely don't want to rewrite Alembic internals, but we need some way to run an async function from the synchronous run_migrations_online while an event loop is already running.

I decided to do the following:

  • check for a running event loop; if there is none, we can run migrations the standard Alembic way
  • if there is an event loop, we can use asyncio.create_task to wrap our migration coroutine
  • the problem is that we then need to await this task somehow inside our pytest fixture, even though it is created during the alembic upgrade command
  • to solve this, I decided to add a new variable to conftest.py and set it from Alembic. Yep, it's kind of a global variable, but I failed to find a more elegant solution.

# tests/conftest.py
# ... add to the imports
from asyncio import Task

# ... and add to the end of the file
MIGRATION_TASK: Optional[Task] = None

@pytest.fixture(scope="session")
async def migrated_postgres_template(pg_url):
    """
    Creates temporary database and applies migrations.

    Has "session" scope, so is called only once per tests run.
    """
    async with tmp_database(pg_url, "pytest") as tmp_url:
        alembic_config = alembic_config_from_url(tmp_url)
        upgrade(alembic_config, "head")

        await MIGRATION_TASK # added line

        yield tmp_url

# alembic/env.py

def run_migrations_online() -> None:
    """Run migrations in 'online' mode."""
    try:
        current_loop = asyncio.get_running_loop()
    except RuntimeError:
        # there is no loop, can use asyncio.run
        asyncio.run(run_async_migrations())
        return

    from tests import conftest
    conftest.MIGRATION_TASK = asyncio.create_task(run_async_migrations())

Everything should be fine, right?! Right?!

Well, not quite. We've got ourselves a new error:

E       AttributeError: 'NoneType' object has no attribute 'configure'

../alembic/env.py:61: AttributeError

What is going on? After reading the Alembic sources it becomes clear:

  1. When we run the upgrade command, Alembic loads some data into the context object using a special context manager:
with EnvironmentContext(...):
    script.run_env()
  2. The run_env method loads our alembic/env.py and invokes run_migrations_online().
  3. We create an asyncio Task and return from run_migrations_online. This exits the context manager and clears the context object. So when the code inside the task actually runs, some parameters are no longer available: context is None, and that's why we get the error shown above.

So we somehow need to pass data into our async task. To do that I decided to use contextvars. If we set a context variable before creating the asyncio.Task, the task gets a copy of the current context and will be able to read it.
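
A tiny standalone demonstration of that mechanism (not project code): a task created with asyncio.create_task runs in a copy of the context captured at creation time, so it sees values set before the task was created.

import asyncio
from contextvars import ContextVar

var: ContextVar[str] = ContextVar("var")


async def child() -> None:
    # The task runs in a copy of the context captured when it was created
    print(var.get())  # -> "hello"


async def main() -> None:
    var.set("hello")
    task = asyncio.create_task(child())
    await task


asyncio.run(main())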

Let's start with the import and the code that sets the context variable.

# alembic/env.py
from contextvars import ContextVar
from typing import Any

ctx_var: ContextVar[dict[str, Any]] = ContextVar("ctx_var")


def run_migrations_online() -> None:
    """Run migrations in 'online' mode."""

    try:
        current_loop = asyncio.get_running_loop()
    except RuntimeError:
        # there is no loop, can use asyncio.run
        asyncio.run(run_async_migrations())
        return
    from tests import conftest
    ctx_var.set({
        "config": context.config,
        "script": context.script,
        "opts": context._proxy.context_opts,  # type: ignore
    })
    conftest.MIGRATION_TASK = asyncio.create_task(run_async_migrations())

The next step is using this context variable:

# alembic/env.py
from alembic.runtime.environment import EnvironmentContext


def do_run_migrations(connection: Connection) -> None:
    try:
        context.configure(connection=connection, target_metadata=target_metadata)

        with context.begin_transaction():
            context.run_migrations()
    except AttributeError:
        context_data = ctx_var.get()
        with EnvironmentContext(
                config=context_data["config"],
                script=context_data["script"],
                **context_data["opts"],
        ):
            context.configure(connection=connection, target_metadata=target_metadata)
            with context.begin_transaction():
                context.run_migrations()

That's it. Now you can run pytest and see that everything is ok.

This setup allows you to use pytest-xdist. This package splits your test suite into chunks and runs each chunk in a different process, which can speed up your tests if you have a lot of them. And because each process creates its own unique test database, it works without any problems.
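
If you want to try it, something along these lines should work (pytest-xdist is not in the dependency list above, so install it first):

$ poetry add --group=dev pytest-xdist
$ pytest . -n auto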

Finally, a simple integration test to show that our API is working:

# tests/test_api.py

from fastapi import status


async def test_my_api(client, app):
    # Test to show that api is working
    response = await client.get("/api/users/")
    assert response.status_code == status.HTTP_200_OK
    assert response.json() == {"status": "ok", "data": []}

    response = await client.post(
        "/api/users/",
        json={"email": "test@example.com", "full_name": "Full Name Test"},
    )
    assert response.status_code == status.HTTP_201_CREATED
    new_user_id = response.json()["data"]["id"]

    response = await client.get(f"/api/users/{new_user_id}/")
    assert response.status_code == status.HTTP_200_OK
    assert response.json() == {
        "status": "ok",
        "data": {
            "id": new_user_id,
            "email": "test@example.com",
            "full_name": "Full Name Test",
        },
    }

    response = await client.get("/api/users/")
    assert response.status_code == status.HTTP_200_OK
    assert len(response.json()["data"]) == 1

$ pytest . -vv -x
===================================================================================== test session starts =====================================================================================
platform darwin -- Python 3.9.15, pytest-7.4.0, pluggy-1.2.0 -- /venv-8ZwWMPCX-py3.9/bin/python
cachedir: .pytest_cache
rootdir: /Users/something/blog_article_v2
plugins: anyio-3.7.1
collected 3 items                                                                                                                                                                             

tests/test_api.py::test_my_api PASSED                                                                                                                                                   [ 33%]
tests/test_orm_works.py::test_orm_session PASSED                                                                                                                                        [ 66%]
tests/migrations/test_stairway.py::test_migrations_stairway[revision0] PASSED                                                                                                           [100%]

====================================================================================== 3 passed in 1.48s ======================================================================================

I hope this will be useful to someone and that Google will index this article someday. ^.^ Have a nice day!