Skip to content

Instantly share code, notes, and snippets.

@SamWarden
Last active July 15, 2025 07:59
Show Gist options
  • Save SamWarden/d1bbd79672203e9eddc23b3c42622cdc to your computer and use it in GitHub Desktop.
This is an example of fixtures for tests that depend on a Postgres database
import uuid
from collections.abc import AsyncGenerator, Generator
from typing import cast
import alembic.command
import pytest
from alembic.config import Config as AlembicConfig
from sqlalchemy import URL, Connection, text
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
from testcontainers.postgres import PostgresContainer
class PostgresDbManager:
    """Creates and drops throwaway Postgres databases through an async engine.

    The engine is expected to run with AUTOCOMMIT isolation, because
    CREATE/DROP DATABASE cannot execute inside a transaction block.
    """

    def __init__(self, engine: AsyncEngine) -> None:
        self._engine = engine

    @staticmethod
    def _quote_ident(identifier: str) -> str:
        """Return *identifier* as a safely double-quoted SQL identifier.

        Doubles any embedded double quotes per the SQL standard, so a name
        containing a quote character cannot break out of the identifier.
        """
        escaped = identifier.replace('"', '""')
        return f'"{escaped}"'

    async def create_database(self, database: str, template: str = "template1") -> None:
        """Create *database* cloned from *template*.

        "template1" is the default template name for the `CREATE DATABASE` statement.
        """
        statement = (
            f"CREATE DATABASE {self._quote_ident(database)} "
            f"ENCODING 'utf8' TEMPLATE {self._quote_ident(template)}"
        )
        async with self._engine.connect() as connection:
            await connection.execute(text(statement))

    async def drop_database(self, database: str) -> None:
        """Drop *database*; Postgres raises if it is missing or still in use."""
        async with self._engine.connect() as connection:
            await connection.execute(text(f"DROP DATABASE {self._quote_ident(database)}"))
@pytest.fixture(scope="session")
def postgres() -> Generator[PostgresContainer, None, None]:
    """Session-wide Postgres 17.5 container; its database serves as the migration template."""
    container = PostgresContainer("postgres:17.5", dbname="template-db", driver="asyncpg")
    with container as postgres:
        yield postgres
def get_alembic_config() -> AlembicConfig:
    """Load the Alembic configuration from the project's alembic.ini."""
    config = AlembicConfig(file_="alembic.ini")
    return config
def run_migrations(connection: Connection) -> None:
    """Upgrade the database behind *connection* to the latest Alembic revision."""
    config = get_alembic_config()
    # env.py picks this connection up instead of creating its own engine.
    config.attributes["connection"] = connection
    alembic.command.upgrade(config=config, revision="head")
@pytest.fixture(scope="session")
async def template_db_engine(postgres: PostgresContainer) -> AsyncGenerator[AsyncEngine, None]:
    """Run migrations once against the container's template database.

    Yields an engine whose URL points at the migrated template database.
    The engine is kept disposed so Postgres can clone the database.
    """
    engine = create_async_engine(postgres.get_connection_url())
    async with engine.connect() as connection:
        await connection.run_sync(run_migrations)
    # Postgres refuses to use a database as a template while connections remain.
    await engine.dispose()
    yield engine
    await engine.dispose()
@pytest.fixture()
async def test_postgres_url(
    request: pytest.FixtureRequest, template_db_engine: AsyncEngine
) -> AsyncGenerator[URL, None]:
    """Provision a unique database per test and yield its connection URL.

    By default the database is cloned from the migrated template; pass
    {"empty_database": True} via indirect parametrization to get a bare one.
    """
    # AUTOCOMMIT because CREATE/DROP DATABASE cannot run inside a transaction.
    admin_engine = create_async_engine(
        template_db_engine.url.set(database="postgres"), isolation_level="AUTOCOMMIT"
    )
    db_manager = PostgresDbManager(admin_engine)
    database_name = f"test-db-{uuid.uuid4().hex}"
    try:
        empty_database = request.param.get("empty_database", False)
    except AttributeError:
        # Fixture was not parametrized indirectly: default to a template clone.
        empty_database = False
    if empty_database:
        await db_manager.create_database(database_name)
    else:
        template = cast(str, template_db_engine.url.database)
        await db_manager.create_database(database_name, template=template)
    yield template_db_engine.url.set(database=database_name)
    await db_manager.drop_database(database_name)
    await admin_engine.dispose()
@pytest.fixture()
async def engine(test_postgres_url: URL) -> AsyncGenerator[AsyncEngine, None]:
    """Async engine bound to the per-test database; disposed on teardown.

    Note: the parameter was previously annotated as ``str``, but the
    ``test_postgres_url`` fixture yields a SQLAlchemy ``URL`` (the sync
    counterpart already annotates it as ``URL``).
    """
    engine = create_async_engine(test_postgres_url)
    yield engine
    await engine.dispose()
@pytest.fixture()
async def session_factory(engine: AsyncEngine) -> async_sessionmaker[AsyncSession]:
    """Factory producing AsyncSession objects bound to the per-test engine."""
    factory = async_sessionmaker(
        engine,
        class_=AsyncSession,
        autoflush=False,
        autocommit=False,
        expire_on_commit=False,
    )
    return factory
@pytest.fixture()
async def session(session_factory: async_sessionmaker[AsyncSession]) -> AsyncGenerator[AsyncSession, None]:
    """Per-test AsyncSession, closed automatically when the test finishes."""
    new_session = session_factory()
    async with new_session:
        yield new_session
from typing import cast
import uuid
from collections.abc import Generator
import alembic.command
import pytest
from alembic.config import Config as AlembicConfig
from sqlalchemy import URL, Connection, create_engine, text
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session, sessionmaker
from testcontainers.postgres import PostgresContainer
class PostgresDbManager:
    """Creates and drops throwaway Postgres databases through a sync engine.

    The engine is expected to run with AUTOCOMMIT isolation, because
    CREATE/DROP DATABASE cannot execute inside a transaction block.
    """

    def __init__(self, engine: Engine) -> None:
        self._engine = engine

    @staticmethod
    def _quote_ident(identifier: str) -> str:
        """Return *identifier* as a safely double-quoted SQL identifier.

        Doubles any embedded double quotes per the SQL standard, so a name
        containing a quote character cannot break out of the identifier.
        """
        escaped = identifier.replace('"', '""')
        return f'"{escaped}"'

    def create_database(self, database: str, template: str = "template1") -> None:
        """Create *database* cloned from *template*.

        "template1" is the default template name for the `CREATE DATABASE` statement.
        """
        statement = (
            f"CREATE DATABASE {self._quote_ident(database)} "
            f"ENCODING 'utf8' TEMPLATE {self._quote_ident(template)}"
        )
        with self._engine.connect() as connection:
            connection.execute(text(statement))

    def drop_database(self, database: str) -> None:
        """Drop *database*; Postgres raises if it is missing or still in use."""
        with self._engine.connect() as connection:
            connection.execute(text(f"DROP DATABASE {self._quote_ident(database)}"))
@pytest.fixture(scope="session")
def postgres() -> Generator[PostgresContainer, None, None]:
    """Session-wide Postgres 17.5 container using the sync psycopg driver."""
    container = PostgresContainer("postgres:17.5", dbname="template-db", driver="psycopg")
    with container as postgres:
        yield postgres
def get_alembic_config() -> AlembicConfig:
    """Load the Alembic configuration from the project's alembic.ini."""
    config = AlembicConfig(file_="alembic.ini")
    return config
def run_migrations(connection: Connection) -> None:
    """Upgrade the database behind *connection* to the latest Alembic revision."""
    config = get_alembic_config()
    # env.py picks this connection up instead of creating its own engine.
    config.attributes["connection"] = connection
    alembic.command.upgrade(config=config, revision="head")
@pytest.fixture(scope="session")
def template_db_engine(postgres: PostgresContainer) -> Generator[Engine]:
    """Run migrations once against the container's template database.

    Yields an engine whose URL points at the migrated template database.
    The engine is kept disposed so Postgres can clone the database.
    """
    engine = create_engine(url=postgres.get_connection_url())
    with engine.connect() as connection:
        run_migrations(connection)
    # Postgres refuses to use a database as a template while connections remain.
    engine.dispose()
    yield engine
    engine.dispose()
@pytest.fixture()
def test_postgres_url(request: pytest.FixtureRequest, template_db_engine: Engine) -> Generator[URL, None, None]:
    """Provision a unique database per test and yield its connection URL.

    By default the database is cloned from the migrated template; pass
    {"empty_database": True} via indirect parametrization to get a bare one.
    """
    # AUTOCOMMIT because CREATE/DROP DATABASE cannot run inside a transaction.
    admin_engine = create_engine(
        url=template_db_engine.url.set(database="postgres"), isolation_level="AUTOCOMMIT"
    )
    db_manager = PostgresDbManager(admin_engine)
    database_name = f"test-db-{uuid.uuid4().hex}"
    try:
        empty_database = request.param.get("empty_database", False)
    except AttributeError:
        # Fixture was not parametrized indirectly: default to a template clone.
        empty_database = False
    if empty_database:
        db_manager.create_database(database_name)
    else:
        template = cast(str, template_db_engine.url.database)
        db_manager.create_database(database_name, template=template)
    yield template_db_engine.url.set(database=database_name)
    db_manager.drop_database(database_name)
    admin_engine.dispose()
@pytest.fixture()
def engine(test_postgres_url: URL) -> Generator[Engine, None, None]:
    """Engine bound to the per-test database; disposed on teardown."""
    test_engine = create_engine(url=test_postgres_url)
    yield test_engine
    test_engine.dispose()
@pytest.fixture()
def session_factory(engine: Engine) -> sessionmaker[Session]:
    """Factory producing Session objects bound to the per-test engine."""
    factory = sessionmaker(
        bind=engine,
        autoflush=False,
        autocommit=False,
        expire_on_commit=False,
    )
    return factory
@pytest.fixture()
def session(session_factory: sessionmaker[Session]) -> Generator[Session, None, None]:
    """Per-test Session, closed automatically when the test finishes."""
    new_session = session_factory()
    with new_session:
        yield new_session
import asyncio
from logging.config import fileConfig
from alembic import context
from sqlalchemy import Engine, engine_from_config, pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import AsyncEngine
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config

# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# Database URL taken from alembic.ini; test fixtures may still inject a
# live connection through config.attributes instead.
FULL_URL = config.get_main_option("sqlalchemy.url")

# No autogenerate support here: target metadata is intentionally left unset.
TARGET_METADATA = None
def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    Configures the context with just a URL and not an Engine, though an
    Engine is acceptable here as well. By skipping the Engine creation we
    don't even need a DBAPI to be available. Calls to context.execute()
    here emit the given string to the script output.
    """
    dialect_options = {"paramstyle": "named"}
    context.configure(
        url=FULL_URL,
        target_metadata=TARGET_METADATA,
        literal_binds=True,
        dialect_opts=dialect_options,
    )
    with context.begin_transaction():
        context.run_migrations()
def do_run_migrations(connection: Connection) -> None:
    """Configure the Alembic context on an existing connection and run migrations."""
    context.configure(connection=connection, target_metadata=TARGET_METADATA)
    with context.begin_transaction():
        context.run_migrations()
def run_migrations(engine: Engine) -> None:
    """Run migrations synchronously on a fresh connection from *engine*.

    The engine is disposed afterwards so no pooled connections linger.
    """
    with engine.connect() as connection:
        do_run_migrations(connection)
    engine.dispose()
async def run_async_migrations(engine: AsyncEngine) -> None:
    """Run migrations through an async engine by bridging into sync Alembic code.

    The engine is disposed afterwards so no pooled connections linger.
    """
    async with engine.connect() as connection:
        # Alembic's runner is sync; run_sync executes it on the async
        # connection's underlying sync connection.
        await connection.run_sync(do_run_migrations)
    await engine.dispose()
def setup_engine() -> Engine:
    """Build an Engine from alembic.ini settings (NullPool, explicit URL override)."""
    ini_section = config.get_section(config.config_ini_section) or {}
    engine = engine_from_config(
        ini_section,
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
        future=True,
        url=FULL_URL,
    )
    return engine
def run_migrations_online() -> None:
    """Run migrations in 'online' mode.

    In this scenario we need to create an Engine or receive a connection
    and associate the connection with the context. A caller-supplied
    connection in config.attributes (e.g. from test fixtures) is reused;
    otherwise an engine is built here, wrapped in AsyncEngine when the
    configured driver is asyncpg.
    """
    connection: Connection | None = config.attributes.get("connection", None)
    if connection is None:
        engine = setup_engine()
        if engine.driver == "asyncpg":
            # asyncpg is an async-only driver: wrap and drive it via asyncio.
            async_engine = AsyncEngine(engine)
            asyncio.run(run_async_migrations(async_engine))
        else:
            run_migrations(engine)
    elif isinstance(connection, Connection):
        do_run_migrations(connection)
    else:
        raise TypeError(f"Unexpected connection type: {type(connection)}. Expected Connection")
def main() -> None:
    """Dispatch to offline or online migration mode based on the Alembic context."""
    if context.is_offline_mode():
        run_migrations_offline()
        return
    run_migrations_online()


main()
import pytest
from sqlalchemy import select, sql
from sqlalchemy.ext.asyncio import AsyncSession
async def test_db(session: AsyncSession) -> None:
    """Smoke test: a session against the template-cloned database can run SQL."""
    result = await session.execute(select(sql.true()))
    value = result.scalar()
    assert value is True
@pytest.mark.parametrize("test_postgres_url", [{"empty_database": True}], indirect=True)
async def test_empty_db(session: AsyncSession) -> None:
    """Smoke test: a session against a freshly created empty database can run SQL."""
    result = await session.execute(select(sql.true()))
    value = result.scalar()
    assert value is True
import pytest
from sqlalchemy import select, sql
from sqlalchemy.orm import Session
def test_db(session: Session) -> None:
    """Smoke test: a session against the template-cloned database can run SQL."""
    value = session.execute(select(sql.true())).scalar()
    assert value is True
@pytest.mark.parametrize("test_postgres_url", [{"empty_database": True}], indirect=True)
def test_empty_db(session: Session) -> None:
    """Smoke test: a session against a freshly created empty database can run SQL."""
    value = session.execute(select(sql.true())).scalar()
    assert value is True
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment