Introduction
In this article, we'll explore integrating FastAPI with the new asynchronous SQLAlchemy 2.0. We'll also look at configuring pytest to run asynchronous tests in a way that stays compatible with pytest-xdist, and we'll cover using Alembic for database migrations with an asynchronous database driver.
The inspiration for this post sparked as I delved into the insightful book "Architecture Patterns with Python" by Harry Percival & Bob Gregory.
I also encountered the captivating concept of the "Stairway test," detailed in https://github.com/alvassin/alembic-quickstart/tree/master. It resonated with me and led to the ideas presented in this post.
Requirements
I run this project on Python 3.9; you can probably adapt it to earlier versions without much effort.
I use poetry to manage project requirements.
Source code can be found here
Install dependencies
$ poetry add fastapi uvicorn uvloop asyncpg alembic pydantic-settings
$ poetry add sqlalchemy --extras asyncio
Install dev dependencies
$ poetry add --group=dev httpx sqlalchemy-utils pytest yarl mypy black isort
Setting up the database
We're going to use FastAPI to create a straightforward API designed for user creation and retrieval from a database. Our primary objective is to illustrate the synergy between SQLAlchemy 2.0 and FastAPI, thus the API intricacies won't be our focal point in this context.
Let's start by creating a configuration file that will hold our database connection string. I prefer pydantic-settings in scenarios like this, but feel free to use any other method, such as os.getenv, if that fits your workflow better.
For the sake of clarity, I have encapsulated the entire database URI within a single parameter. It's important to note that in a real-world scenario, such configuration settings would likely be segregated into discrete entities like db_host, db_port, db_user, db_password, and more.
# app/settings.py
from pathlib import Path
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
app_name: str = "Example API"
app_host: str = "0.0.0.0"
app_port: int = 3000
database_url: str = "postgresql+asyncpg://blog_example_user:password@localhost:5432/blog_example_base"
project_root: Path = Path(__file__).parent.parent.resolve()
model_config = SettingsConfigDict(env_file=".env")
settings = Settings()
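Since we set env_file=".env", pydantic-settings will also pick up values from a local .env file (environment variable names are matched to field names case-insensitively). For example, with illustrative values:

# .env (example values)
DATABASE_URL=postgresql+asyncpg://blog_example_user:password@localhost:5432/blog_example_base
APP_PORT=3000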
I'm not going to cover how to start a local PostgreSQL database in this post, but you can, for example, use the official Docker image.
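For example, something along these lines (the image tag and credentials are illustrative):

$ docker run --name blog-example-pg \
    -e POSTGRES_USER=blog_example_user \
    -e POSTGRES_PASSWORD=password \
    -e POSTGRES_DB=blog_example_base \
    -p 5432:5432 -d postgres:15

Note that the user created this way is a superuser inside the container, so it already has the CREATEDB privilege that the tests below require.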
To be able to run the tests, your database user needs the CREATEDB privilege.
Here are example SQL commands that create a new user with a new database:
CREATE USER "blog_example_user" WITH PASSWORD 'password';
CREATE DATABASE "blog_example_base" OWNER "blog_example_user";
ALTER USER "blog_example_user" CREATEDB;
Let's create our ORM model
I prefer to put all ORM-related code into an orm module and re-export everything from its __init__.py. In code I then use it like:
import orm
query = select(orm.User)...
This way it's much easier to distinguish between SQLAlchemy models and my business models.
So, first let's add the base class
# orm/base_model.py
from sqlalchemy import MetaData
from sqlalchemy.orm import DeclarativeBase
# Default naming convention for all indexes and constraints
# See why this is important and how it would save your time:
# https://alembic.sqlalchemy.org/en/latest/naming.html
convention = {
"all_column_names": lambda constraint, table: "_".join(
[column.name for column in constraint.columns.values()]
),
"ix": "ix__%(table_name)s__%(all_column_names)s",
"uq": "uq__%(table_name)s__%(all_column_names)s",
"ck": "ck__%(table_name)s__%(constraint_name)s",
"fk": "fk__%(table_name)s__%(all_column_names)s__%(referred_table_name)s",
"pk": "pk__%(table_name)s",
}
class OrmBase(DeclarativeBase):
metadata = MetaData(naming_convention=convention) # type: ignore
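With this convention in place, constraint names become deterministic across environments. A hypothetical model (not part of the project) to illustrate:

# Hypothetical example, for illustration only
from sqlalchemy import String
from sqlalchemy.orm import Mapped, mapped_column

class Tag(OrmBase):
    __tablename__ = "tag"

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(30), unique=True)

# The convention above produces predictable names:
#   primary key       -> "pk__tag"
#   unique constraint -> "uq__tag__name"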
Then, let's create a session manager for our database. This class will be used as a singleton and will be responsible for abstracting the database connection and session handling:
# orm/session_manager.py
import contextlib
from typing import AsyncIterator, Optional
from sqlalchemy.ext.asyncio import (
AsyncConnection,
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
class DatabaseSessionManager:
def __init__(self) -> None:
self._engine: Optional[AsyncEngine] = None
self._sessionmaker: Optional[async_sessionmaker[AsyncSession]] = None
def init(self, db_url: str) -> None:
        # Just an additional example of customization:
        # you could add parameters to init() and so on
if "postgresql" in db_url:
# These settings are needed to work with pgbouncer in transaction mode
# because you can't use prepared statements in such case
connect_args = {
"statement_cache_size": 0,
"prepared_statement_cache_size": 0,
}
else:
connect_args = {}
self._engine = create_async_engine(
url=db_url,
pool_pre_ping=True,
connect_args=connect_args,
)
self._sessionmaker = async_sessionmaker(
bind=self._engine,
expire_on_commit=False,
)
async def close(self) -> None:
if self._engine is None:
return
await self._engine.dispose()
self._engine = None
self._sessionmaker = None
@contextlib.asynccontextmanager
async def session(self) -> AsyncIterator[AsyncSession]:
if self._sessionmaker is None:
raise IOError("DatabaseSessionManager is not initialized")
async with self._sessionmaker() as session:
try:
yield session
except Exception:
await session.rollback()
raise
@contextlib.asynccontextmanager
async def connect(self) -> AsyncIterator[AsyncConnection]:
if self._engine is None:
raise IOError("DatabaseSessionManager is not initialized")
async with self._engine.begin() as connection:
try:
yield connection
except Exception:
await connection.rollback()
raise
db_manager = DatabaseSessionManager()
Notice that we're using the async version of create_engine, which returns an AsyncEngine object, and the async version of the sessionmaker, which produces AsyncSession objects for committing and rolling back transactions.
We are going to call the init and close methods in FastAPI's lifespan event, so they run during startup and shutdown of our application.
The benefits of this approach are:
- You can connect to as many databases as needed: just create a separate DatabaseSessionManager for each database (see the sketch after this list). Multiple databases were a problem for me with the middleware approach.
- Your DB connections are released at application shutdown instead of at garbage collection, which means you won't run into issues if you use uvicorn --reload.
- Your DB sessions are automatically closed when the route using the session dependency finishes, so any uncommitted operations are rolled back.
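For instance, wiring up a second database is just another manager instance. A minimal sketch (the analytics database and its setting are made-up examples):

# orm/session_manager.py (hypothetical extension)
db_manager = DatabaseSessionManager()            # main application database
analytics_db_manager = DatabaseSessionManager()  # hypothetical second database

# In FastAPI's lifespan you would then init/close both:
#   db_manager.init(settings.database_url)
#   analytics_db_manager.init(settings.analytics_database_url)  # hypothetical setting
#   ...
#   await db_manager.close()
#   await analytics_db_manager.close()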
Then, we need to create a FastAPI dependency that will be used to get the database session. This dependency will be used in the API views:
# orm/session_manager.py
async def get_session() -> AsyncIterator[AsyncSession]:
    async with db_manager.session() as session:
        yield session
And we're done with the database configuration. Now we can create the database models (I just used the model from SQLAlchemy tutorial):
# orm/user_model.py
from typing import Optional
from sqlalchemy import String
from sqlalchemy.orm import Mapped, mapped_column
from .base_model import OrmBase
class User(OrmBase):
__tablename__ = "user_account"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String(30))
fullname: Mapped[Optional[str]]
def __repr__(self) -> str:
return f"User(id={self.id!r}, name={self.name!r}, fullname={self.fullname!r})"
This is the most modern form of Declarative mapping, driven by PEP 484 type annotations via the special Mapped type, which indicates the attributes to be mapped and their particular types.
# orm/__init__.py
"""
Data structures, used in project.
Add your new models here so Alembic could pick them up.
You may do changes in tables, then execute
`alembic revision --message="Your text" --autogenerate`
and alembic would generate new migration for you
in alembic/versions folder.
"""
from .base_model import OrmBase
from .session_manager import db_manager, get_session
from .user_model import User
__all__ = ["OrmBase", "get_session", "db_manager", "User"]
I use the dunder init file in the orm module so that I can write import orm and then reference objects like orm.db_manager, orm.User, etc. This approach substantially simplifies the distinction between SQLAlchemy models and business-oriented models.
Creating the API views
Now that we have the database configuration and models set up, we can create the API views.
For simplicity I'm going to put all API-related models and functions in one file so you can review it easily. In a real project you should probably split them up for organizational clarity.
Let's start by creating models for validating the incoming API request and shaping the response:
# api/user.py
from pydantic import BaseModel, ConfigDict, Field
class UserCreateRequest(BaseModel):
name: str = Field(max_length=30)
fullname: str
class UserResponse(BaseModel):
id: int
name: str
fullname: str
model_config = ConfigDict(from_attributes=True)
When you need to return a list of users, many people opt for a simplistic approach such as responding with list[User]. Some time later they need to add extra information to such an endpoint, and it can't be done easily without breaking the schema. So it's better to use a more flexible response structure from the beginning, like:
# api/user.py
from typing import Literal

class APIUserResponse(BaseModel):
    status: Literal["ok"] = "ok"
    data: UserResponse

class APIUserListResponse(BaseModel):
    status: Literal["ok"] = "ok"
    data: list[UserResponse]
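For example, you could later extend the list response with pagination metadata without breaking existing clients (the total field below is a hypothetical addition, not part of the project):

# api/user.py (hypothetical future version)
class APIUserListResponse(BaseModel):
    status: Literal["ok"] = "ok"
    data: list[UserResponse]
    total: int = 0  # new metadata field; old clients simply ignore it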
And, finally, our API views:
A few extra notes:
- Our UserResponse model contains only fields that are already JSON-compatible (strings and ints). That's why we can use .model_dump(). If your model has types like UUID, datetime, or other classes, you can use .model_dump(mode='json') and pydantic will automatically convert output values to JSON-supported types (see the short demo after this list).
- I prefer to return a Response directly rather than use FastAPI's response_model conversion. For me it's more convenient, plus it's actually faster (check https://github.com/falkben/fastapi_experiments/ -> orjson_response.py).
- For the sake of simplicity I do ORM queries right in the API views. In a bigger project it's better to create an additional service layer and keep all your ORM/SQL queries in one module; if such queries spread through your code base, you are going to regret it later.
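A quick illustration of the .model_dump(mode='json') point, using a toy model that is not part of this project:

import datetime
import uuid
from pydantic import BaseModel

class Event(BaseModel):  # toy model, for demonstration only
    id: uuid.UUID
    created_at: datetime.datetime

event = Event(id=uuid.uuid4(), created_at=datetime.datetime.now())
event.model_dump()              # values stay as UUID/datetime objects - not JSON-safe
event.model_dump(mode="json")   # values become plain strings - ready for JSONResponse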
# api/user.py
from typing import Literal
from fastapi import APIRouter, Depends, status
from fastapi.responses import JSONResponse
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
import orm
class UserCreateRequest(BaseModel):
name: str = Field(max_length=30)
fullname: str
class UserResponse(BaseModel):
id: int
name: str
fullname: str
model_config = ConfigDict(from_attributes=True)
class APIUserResponse(BaseModel):
status: Literal["ok"] = "ok"
data: UserResponse
class APIUserListResponse(BaseModel):
status: Literal["ok"] = "ok"
data: list[UserResponse]
router = APIRouter()
@router.get("/{user_id}/", response_model=APIUserResponse)
async def get_user(
user_id: int, session: AsyncSession = Depends(orm.get_session)
) -> JSONResponse:
user = await session.get(orm.User, user_id)
if not user:
return JSONResponse(
content={"status": "error", "message": "User not found"},
status_code=status.HTTP_404_NOT_FOUND,
)
response_model = UserResponse.model_validate(user)
return JSONResponse(
content={
"status": "ok",
"data": response_model.model_dump(),
}
)
@router.get("/", response_model=APIUserListResponse)
async def get_users(session: AsyncSession = Depends(orm.get_session)) -> JSONResponse:
users_results = await session.scalars(select(orm.User))
response_data = [
UserResponse.model_validate(u).model_dump() for u in users_results.all()
]
return JSONResponse(
content={
"status": "ok",
"data": response_data,
}
)
@router.post("/", response_model=APIUserResponse, status_code=status.HTTP_201_CREATED)
async def create_user(
user_data: UserCreateRequest, session: AsyncSession = Depends(orm.get_session)
) -> JSONResponse:
user_candidate = orm.User(**user_data.model_dump())
session.add(user_candidate)
# I skip error handling
await session.commit()
await session.refresh(user_candidate)
response_model = UserResponse.model_validate(user_candidate)
return JSONResponse(
content={
"status": "ok",
"data": response_model.model_dump(),
},
status_code=status.HTTP_201_CREATED,
)
Here we have a simple FastAPI router with three API views: get_user, get_users, and create_user. Notice that we use Depends to inject the async database session into each view.
Setting up FastAPI
Now that we have the API views set up, we can create the FastAPI application.
# main.py
import contextlib
from typing import AsyncIterator
import uvicorn
from fastapi import FastAPI
import orm
from api import user
from app.settings import settings
@contextlib.asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
orm.db_manager.init(settings.database_url)
yield
await orm.db_manager.close()
app = FastAPI(title="Very simple example", lifespan=lifespan)
app.include_router(user.router, prefix="/api/users", tags=["users"])
if __name__ == "__main__":
# There are a lot of parameters for uvicorn, you should check the docs
uvicorn.run(
app,
host=settings.app_host,
port=settings.app_port,
)
In order for us to run our application, first we'll need to create our database tables. Let's see how we can do that using Alembic.
Migrations with Alembic
To start with Alembic, we use the alembic init command to create the Alembic configuration, choosing the async template:
alembic init -t async alembic
This creates the alembic directory with the Alembic configuration. We'll need to make a few changes to it.
alembic.ini
Uncomment the file_template line so migration file names are user-friendly and include dates, which lets us sort them.
Remove the sqlalchemy.url line, because we are going to set this parameter from alembic/env.py.
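For example, one commonly used date-based variant (adjust to taste):

# alembic.ini
file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s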
alembic/env.py
First, we need to import our database models so that they're registered on OrmBase.metadata. Registration happens automatically when a model inherits from OrmBase, but the models have to be imported before the Alembic configuration is loaded. Because we re-export all models in orm/__init__.py, a single import orm is enough.
Then we need to set the sqlalchemy.url option to our database connection string.
Important note: we are going to reuse this Alembic configuration for tests, so we must be careful not to overwrite sqlalchemy.url if it's already set.
And finally, we'll point target_metadata to our OrmBase.metadata object.
Below are the changes we need to make to the alembic/env.py file:
# alembic/env.py
import orm
from app.settings import settings
current_url = config.get_main_option('sqlalchemy.url', None)
if not current_url:
config.set_main_option("sqlalchemy.url", settings.database_url)
target_metadata = orm.OrmBase.metadata
Now we can run the alembic revision command to create a new revision:
alembic revision --autogenerate -m "Add user model"
This creates a new revision file in the alembic/versions directory. We can then apply the migration to the database:
alembic upgrade head
To revert the last migration, use:
alembic downgrade -1
Starting the server
To start the server, run python main.py. This starts the server on the host and port from our settings (port 3000 in this example). The docs will be available at http://localhost:3000/docs. You should be able to see and run any of the API views that we've created.
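For a quick smoke test from the command line (example payload):

$ curl -X POST http://localhost:3000/api/users/ \
    -H "Content-Type: application/json" \
    -d '{"name": "Michael", "fullname": "Michael Test Jr."}'
$ curl http://localhost:3000/api/users/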
This should be enough to start using FastAPI with SQLAlchemy 2.0. However, one important component of software development is testing, so let's see how we can test our API views.
Testing the API views
For this section, my focus is primarily on demonstrating the mechanics of integration testing with FastAPI and SQLAlchemy 2.0. This means that our tests will call the API views and check the responses. While we won't be testing the database models, it's worth noting that a similar setup can be applied for such scenarios as well.
We'll start with the helper functions we are going to need:
The sqlalchemy_utils package has two very useful functions: create_database and drop_database. Regrettably, they are synchronous and incompatible with the asyncpg driver. This typically leads tutorials to recommend installing psycopg2 and using a separate synchronous engine for database creation. However, in the spirit of experimentation, we can slightly modify these functions so that they use create_async_engine.
You can see them in the GitHub repo.
The remaining utilities are:
# tests/db_utils.py
import contextlib
import os
import uuid
from argparse import Namespace
from pathlib import Path
from typing import AsyncIterator, Optional, Union
import sqlalchemy as sa
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy_utils.functions.database import (
_set_url_database,
_sqlite_file_exists,
make_url,
)
from sqlalchemy_utils.functions.orm import quote
from yarl import URL
from alembic.config import Config as AlembicConfig
from app.settings import settings
def make_alembic_config(
cmd_opts: Namespace, base_path: Union[str, Path] = settings.project_root
) -> AlembicConfig:
# Replace path to alembic.ini file to absolute
base_path = Path(base_path)
if not Path(cmd_opts.config).is_absolute():
cmd_opts.config = str(base_path.joinpath(cmd_opts.config).absolute())
config = AlembicConfig(
file_=cmd_opts.config,
ini_section=cmd_opts.name,
cmd_opts=cmd_opts,
)
# Replace path to alembic folder to absolute
alembic_location = config.get_main_option("script_location")
if not Path(alembic_location).is_absolute():
config.set_main_option(
"script_location", str(base_path.joinpath(alembic_location).absolute())
)
if cmd_opts.pg_url:
config.set_main_option("sqlalchemy.url", cmd_opts.pg_url)
return config
def alembic_config_from_url(pg_url: Optional[str] = None) -> AlembicConfig:
"""Provides python object, representing alembic.ini file."""
cmd_options = Namespace(
config="alembic.ini", # Config file name
name="alembic", # Name of section in .ini file to use for Alembic config
pg_url=pg_url, # DB URI
raiseerr=True, # Raise a full stack trace on error
x=None, # Additional arguments consumed by custom env.py scripts
)
return make_alembic_config(cmd_opts=cmd_options)
@contextlib.asynccontextmanager
async def tmp_database(db_url: URL, suffix: str = "", **kwargs) -> AsyncIterator[str]:
"""Context manager for creating new database and deleting it on exit."""
tmp_db_name = ".".join([uuid.uuid4().hex, "tests-base", suffix])
tmp_db_url = str(db_url.with_path(tmp_db_name))
await create_database_async(tmp_db_url, **kwargs)
try:
yield tmp_db_url
finally:
await drop_database_async(tmp_db_url)
# Next functions are copied from `sqlalchemy_utils` and slightly
# modified to support async.
async def create_database_async(
url: str, encoding: str = "utf8", template: Optional[str] = None
) -> None:
url = make_url(url)
database = url.database
dialect_name = url.get_dialect().name
dialect_driver = url.get_dialect().driver
if dialect_name == "postgresql":
url = _set_url_database(url, database="postgres")
elif dialect_name == "mssql":
url = _set_url_database(url, database="master")
elif dialect_name == "cockroachdb":
url = _set_url_database(url, database="defaultdb")
elif not dialect_name == "sqlite":
url = _set_url_database(url, database=None)
if (dialect_name == "mssql" and dialect_driver in {"pymssql", "pyodbc"}) or (
dialect_name == "postgresql"
and dialect_driver in {"asyncpg", "pg8000", "psycopg2", "psycopg2cffi"}
):
engine = create_async_engine(url, isolation_level="AUTOCOMMIT")
else:
engine = create_async_engine(url)
if dialect_name == "postgresql":
if not template:
template = "template1"
async with engine.begin() as conn:
text = "CREATE DATABASE {} ENCODING '{}' TEMPLATE {}".format(
quote(conn, database), encoding, quote(conn, template)
)
await conn.execute(sa.text(text))
elif dialect_name == "mysql":
async with engine.begin() as conn:
text = "CREATE DATABASE {} CHARACTER SET = '{}'".format(
quote(conn, database), encoding
)
await conn.execute(sa.text(text))
elif dialect_name == "sqlite" and database != ":memory:":
if database:
async with engine.begin() as conn:
await conn.execute(sa.text("CREATE TABLE DB(id int)"))
await conn.execute(sa.text("DROP TABLE DB"))
else:
async with engine.begin() as conn:
text = f"CREATE DATABASE {quote(conn, database)}"
await conn.execute(sa.text(text))
await engine.dispose()
async def drop_database_async(url: str) -> None:
url = make_url(url)
database = url.database
dialect_name = url.get_dialect().name
dialect_driver = url.get_dialect().driver
if dialect_name == "postgresql":
url = _set_url_database(url, database="postgres")
elif dialect_name == "mssql":
url = _set_url_database(url, database="master")
elif dialect_name == "cockroachdb":
url = _set_url_database(url, database="defaultdb")
elif not dialect_name == "sqlite":
url = _set_url_database(url, database=None)
if dialect_name == "mssql" and dialect_driver in {"pymssql", "pyodbc"}:
engine = create_async_engine(url, connect_args={"autocommit": True})
elif dialect_name == "postgresql" and dialect_driver in {
"asyncpg",
"pg8000",
"psycopg2",
"psycopg2cffi",
}:
engine = create_async_engine(url, isolation_level="AUTOCOMMIT")
else:
engine = create_async_engine(url)
if dialect_name == "sqlite" and database != ":memory:":
if database:
os.remove(database)
elif dialect_name == "postgresql":
async with engine.begin() as conn:
# Disconnect all users from the database we are dropping.
version = conn.dialect.server_version_info
pid_column = "pid" if (version >= (9, 2)) else "procpid"
text = """
SELECT pg_terminate_backend(pg_stat_activity.{pid_column})
FROM pg_stat_activity
WHERE pg_stat_activity.datname = '{database}'
AND {pid_column} <> pg_backend_pid();
""".format(
pid_column=pid_column, database=database
)
await conn.execute(sa.text(text))
# Drop the database.
text = f"DROP DATABASE {quote(conn, database)}"
await conn.execute(sa.text(text))
else:
async with engine.begin() as conn:
text = f"DROP DATABASE {quote(conn, database)}"
await conn.execute(sa.text(text))
await engine.dispose()
Let's create a conftest.py file in the root of our tests directory. This file will be responsible for setting up the test database. Since this is an intricate setup, let's break it down into smaller pieces, starting with the imports:
# tests/conftest.py
from typing import Optional
import pytest
from httpx import AsyncClient
from yarl import URL
import orm
from alembic.command import upgrade
from app.settings import settings
from orm.session_manager import db_manager
from tests.db_utils import alembic_config_from_url, tmp_database
There isn't much going on here; we're just importing the necessary packages. Let's move on and create our app and client fixtures, which provide the FastAPI test application and the test client:
# tests/conftest.py
@pytest.fixture()
def app():
from main import app
yield app
@pytest.fixture()
async def client(session, app):
async with AsyncClient(app=app, base_url="http://test") as client:
yield client
Because FastAPI is built on anyio, we can use anyio in our tests instead of pytest-asyncio, which many people use. To skip running the tests on the trio event loop, we need to define the anyio_backend fixture. With it in place (and autouse=True) we can also write plain async def test_... functions without marking them additionally.
# tests/conftest.py
@pytest.fixture(scope="session", autouse=True)
def anyio_backend():
return "asyncio", {"use_uvloop": True}
And now we're ready to create the database connection and session. Our test database and session manager are scoped to the pytest session, so the same setup is reused for all tests; it's best practice to avoid creating a new connection for each test, let alone each request.
# tests/conftest.py
@pytest.fixture(scope="session")
def pg_url():
"""Provides base PostgreSQL URL for creating temporary databases."""
return URL(settings.database_url)
@pytest.fixture(scope="session")
async def migrated_postgres_template(pg_url):
"""
Creates temporary database and applies migrations.
Has "session" scope, so is called only once per tests run.
"""
async with tmp_database(pg_url, "pytest") as tmp_url:
alembic_config = alembic_config_from_url(tmp_url)
        # Sometimes we have so-called data migrations, which can call
        # various db-related functions, so we point our settings at the
        # temporary database as well.
        settings.database_url = tmp_url
        # It is important to always close the connections at the end of such
        # migrations, or we will get errors like
        # `source database is being accessed by other users`
upgrade(alembic_config, "head")
yield tmp_url
@pytest.fixture(scope="session")
async def sessionmanager_for_tests(migrated_postgres_template):
db_manager.init(db_url=migrated_postgres_template)
# can add another init (redis, etc...)
yield db_manager
await db_manager.close()
@pytest.fixture()
async def session(sessionmanager_for_tests):
async with db_manager.session() as session:
yield session
# Clean tables after each test. I tried:
# 1. Create new database using an empty `migrated_postgres_template` as template
# (postgres could copy whole db structure)
# 2. Do TRUNCATE after each test.
# 3. Do DELETE after each test.
# DELETE FROM is the fastest
# https://www.lob.com/blog/truncate-vs-delete-efficiently-clearing-data-from-a-postgres-table
# BUT DELETE FROM query does not reset any AUTO_INCREMENT counter
async with db_manager.connect() as conn:
for table in reversed(orm.OrmBase.metadata.sorted_tables):
# Clean tables in such order that tables which depend on another go first
await conn.execute(table.delete())
await conn.commit()
DELETE FROM does not reset any AUTO_INCREMENT counter, so our user.id values keep increasing over a single test run. Consider whether that matters for your tests; for me it's not a problem, so I don't switch to TRUNCATE.
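If you ever do need the counters reset, a possible alternative is a single TRUNCATE ... RESTART IDENTITY instead of the DELETE loop. A sketch (assuming `import sqlalchemy as sa` is added to conftest.py):

# tests/conftest.py - alternative cleanup for the `session` fixture (sketch)
# Slower than DELETE, but resets sequences
async with db_manager.connect() as conn:
    table_names = ", ".join(
        f'"{table.name}"' for table in orm.OrmBase.metadata.sorted_tables
    )
    # db_manager.connect() uses engine.begin(), which commits on exit
    await conn.execute(sa.text(f"TRUNCATE TABLE {table_names} RESTART IDENTITY CASCADE"))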
Now we can write our first, simple test:
# tests/test_orm_works.py
from sqlalchemy import text
import orm
async def test_orm_session(session):
user = orm.User(
name="Michael",
fullname="Michael Test Jr.",
)
session.add(user)
await session.commit()
rows = await session.execute(text('SELECT id, name, fullname FROM "user_account"'))
result = list(rows)[0]
assert isinstance(result[0], int)
assert result[1] == "Michael"
assert result[2] == "Michael Test Jr."
You can run pytest now, and it works.
But we're not done yet. We still need to add the very useful Stairway test, and in doing so we will face new challenges with Alembic.
Stairway test
A simple and efficient method to check that a migration has no typos and rolls back all schema changes. It does not require maintenance: you can add this test to your project once and forget about it.
In particular, the test detects data types that were created by the upgrade() method and not removed by downgrade(): when creating a table or column, Alembic automatically creates any custom data types specified in the columns (e.g. enums), but it does not delete them when the table or column is dropped; the developer has to do that manually.
How it works
The test retrieves the list of all migrations and, for each one, executes the upgrade, downgrade, upgrade Alembic commands.
Let's add a new test package, migrations, and create its fixtures:
# tests/migrations/conftest.py
import pytest
from sqlalchemy.ext.asyncio import create_async_engine
from tests.db_utils import alembic_config_from_url, tmp_database
@pytest.fixture()
async def postgres(pg_url):
"""
Creates empty temporary database.
"""
async with tmp_database(pg_url, "pytest") as tmp_url:
yield tmp_url
@pytest.fixture()
async def postgres_engine(postgres):
"""
SQLAlchemy engine, bound to temporary database.
"""
engine = create_async_engine(
url=postgres,
pool_pre_ping=True,
)
try:
yield engine
finally:
await engine.dispose()
@pytest.fixture()
def alembic_config(postgres):
"""
Alembic configuration object, bound to temporary database.
"""
return alembic_config_from_url(postgres)
And the test itself:
# tests/migrations/test_stairway.py
"""
Test can find forgotten downgrade methods, undeleted data types in downgrade
methods, typos and many other errors.
Does not require any maintenance - you just add it once to check 80% of typos
and mistakes in migrations forever.
"""
import pytest
from alembic.command import downgrade, upgrade
from alembic.config import Config
from alembic.script import Script, ScriptDirectory
from tests.db_utils import alembic_config_from_url
def get_revisions():
# Create Alembic configuration object
# (we don't need database for getting revisions list)
config = alembic_config_from_url()
# Get directory object with Alembic migrations
revisions_dir = ScriptDirectory.from_config(config)
# Get & sort migrations, from first to last
revisions = list(revisions_dir.walk_revisions("base", "heads"))
revisions.reverse()
return revisions
@pytest.mark.parametrize("revision", get_revisions())
def test_migrations_stairway(alembic_config: Config, revision: Script):
upgrade(alembic_config, revision.revision)
# We need -1 for downgrading first migration (its down_revision is None)
downgrade(alembic_config, revision.down_revision or "-1")
upgrade(alembic_config, revision.revision)
Running pytest again, we get an error:
E RuntimeError: asyncio.run() cannot be called from a running event loop
That's because inside the upgrade command Alembic uses asyncio.run to run migrations via the asyncpg driver. That works just fine when we run migration commands from the command line, but during a test run an active asyncio event loop is already in place, so we can't use asyncio.run.
We definitely don't want to rewrite Alembic internals, but we need some way to run an async function from the synchronous run_migrations_online while the event loop is already running.
I decided to do the following:
- Check for a running event loop; if there is none, we can run things the standard Alembic way.
- If there is an event loop, we can use asyncio.create_task to wrap our migration command.
- The problem is that we then need to somehow await this task inside our pytest fixture, while creating it during the Alembic upgrade command.
- To solve this, I added a new variable to conftest.py and set it from Alembic. Yep, it's kind of a global variable, but I failed to find a more elegant solution.
# tests/conftest.py
# ... add to the end (and add `from asyncio import Task` to the imports)
MIGRATION_TASK: Optional[Task] = None
@pytest.fixture(scope="session")
async def migrated_postgres_template(pg_url):
"""
Creates temporary database and applies migrations.
Has "session" scope, so is called only once per tests run.
"""
async with tmp_database(pg_url, "pytest") as tmp_url:
alembic_config = alembic_config_from_url(tmp_url)
upgrade(alembic_config, "head")
await MIGRATION_TASK # added line
yield tmp_url
# alembic/env.py
def run_migrations_online() -> None:
"""Run migrations in 'online' mode."""
try:
current_loop = asyncio.get_running_loop()
except RuntimeError:
# there is no loop, can use asyncio.run
asyncio.run(run_async_migrations())
return
from tests import conftest
conftest.MIGRATION_TASK = asyncio.create_task(run_async_migrations())
Everything should be fine, right?! Right?!
Well, not quite. We've got ourselves a new error:
E AttributeError: 'NoneType' object has no attribute 'configure'
../alembic/env.py:61: AttributeError
What is going on? After reading the Alembic sources, it becomes clear:
- When we run the upgrade command, Alembic loads some data into the context object using a special context manager:
with EnvironmentContext(...):
    script.run_env()
- The run_env method loads our alembic/env.py and invokes run_migrations_online().
- We create an asyncio Task and return from run_migrations_online. That exits the context manager and clears the context object. So by the time the code inside our task actually runs, some parameters are already gone: context is None, which is why we got the error shown above.
So we somehow need to pass data into our async task. To do that I decided to use contextvars: if we set a context variable before creating the asyncio.Task, the task gets a copy of the current context and can read it.
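A tiny standalone demo of that copy behavior (illustrative, not related to the project code):

import asyncio
from contextvars import ContextVar

var: ContextVar[str] = ContextVar("var")

async def child() -> None:
    # The task sees the value captured when the task was created
    print(var.get())  # prints "set before create_task"

async def main() -> None:
    var.set("set before create_task")
    task = asyncio.create_task(child())
    var.set("set after create_task")  # the task's copy is unaffected
    await task

asyncio.run(main())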
Back in our env.py, let's start with the import and the setting of the context variable:
# alembic/env.py
from contextvars import ContextVar
from typing import Any

ctx_var: ContextVar[dict[str, Any]] = ContextVar("ctx_var")
def run_migrations_online() -> None:
"""Run migrations in 'online' mode."""
try:
current_loop = asyncio.get_running_loop()
except RuntimeError:
# there is no loop, can use asyncio.run
asyncio.run(run_async_migrations())
return
from tests import conftest
ctx_var.set({
"config": context.config,
"script": context.script,
"opts": context._proxy.context_opts, # type: ignore
})
conftest.MIGRATION_TASK = asyncio.create_task(run_async_migrations())
The next step is using this context variable:
# alembic/env.py
from alembic.runtime.environment import EnvironmentContext

def do_run_migrations(connection: Connection) -> None:
try:
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
except AttributeError:
context_data = ctx_var.get()
with EnvironmentContext(
config=context_data["config"],
script=context_data["script"],
**context_data["opts"],
):
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
That's it. Now you can run pytest and see that everything is OK.
This setup also lets you use pytest-xdist, which splits your test suite into chunks and runs each chunk in a different process. It can speed up your tests if you have a lot of them, and because each process creates its own unique test database, it all works without any problems.
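For example (the -n auto flag lets pytest-xdist choose the number of workers automatically):

$ poetry add --group=dev pytest-xdist
$ pytest -n auto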
Finally, here is a somewhat clumsy integration test to show that our API works:
# tests/test_api.py
from fastapi import status
async def test_my_api(client, app):
# Test to show that api is working
response = await client.get("/api/users/")
assert response.status_code == status.HTTP_200_OK
assert response.json() == {"status": "ok", "data": []}
    response = await client.post(
        "/api/users/",
        json={"name": "Michael", "fullname": "Michael Test Jr."},
    )
    assert response.status_code == status.HTTP_201_CREATED
    new_user_id = response.json()["data"]["id"]
    response = await client.get(f"/api/users/{new_user_id}/")
    assert response.status_code == status.HTTP_200_OK
    assert response.json() == {
        "status": "ok",
        "data": {
            "id": new_user_id,
            "name": "Michael",
            "fullname": "Michael Test Jr.",
        },
    }
response = await client.get("/api/users/")
assert response.status_code == status.HTTP_200_OK
assert len(response.json()["data"]) == 1
$ pytest . --vv -x
===================================================================================== test session starts =====================================================================================
platform darwin -- Python 3.9.15, pytest-7.4.0, pluggy-1.2.0 -- /venv-8ZwWMPCX-py3.9/bin/python
cachedir: .pytest_cache
rootdir: /Users/something/blog_article_v2
plugins: anyio-3.7.1
collected 3 items
tests/test_api.py::test_my_api PASSED [ 33%]
tests/test_orm_works.py::test_orm_session PASSED [ 66%]
tests/migrations/test_stairway.py::test_migrations_stairway[revision0] PASSED [100%]
====================================================================================== 3 passed in 1.48s ======================================================================================
I hope this ends up being useful to someone, and that Google will index this article someday. ^.^
Have a nice day!