Last active
July 15, 2025 07:59
-
-
Save SamWarden/d1bbd79672203e9eddc23b3c42622cdc to your computer and use it in GitHub Desktop.
This is an example of pytest fixtures (sync and async variants) for tests that depend on a Postgres database, using testcontainers and Alembic migrations.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import uuid | |
from collections.abc import AsyncGenerator, Generator | |
from typing import cast | |
import alembic.command | |
import pytest | |
from alembic.config import Config as AlembicConfig | |
from sqlalchemy import URL, Connection, text | |
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine | |
from testcontainers.postgres import PostgresContainer | |
class PostgresDbManager:
    """Create and drop PostgreSQL databases through an administrative async engine.

    PostgreSQL refuses to run ``CREATE DATABASE``/``DROP DATABASE`` inside a
    transaction block, so the engine given here should use
    ``isolation_level="AUTOCOMMIT"`` and be connected to a different database
    than the one being created or dropped.
    """

    def __init__(self, engine: AsyncEngine) -> None:
        self._engine = engine

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return ``name`` as a double-quoted SQL identifier.

        Embedded double quotes are doubled, so a name containing ``"`` can neither
        terminate the identifier early nor inject extra SQL.
        """
        escaped = name.replace('"', '""')
        return f'"{escaped}"'

    async def create_database(self, database: str, template: str = "template1") -> None:
        """Create ``database`` with UTF-8 encoding as a copy of ``template``.

        "template1" is the default template name for the `CREATE DATABASE` statement.
        """
        statement = (
            f"CREATE DATABASE {self._quote_identifier(database)} "
            f"ENCODING 'utf8' TEMPLATE {self._quote_identifier(template)}"
        )
        async with self._engine.connect() as connection:
            await connection.execute(text(statement))

    async def drop_database(self, database: str) -> None:
        """Drop ``database``.

        PostgreSQL rejects this while other sessions are still connected to the
        database, so engines bound to it should be disposed first.
        """
        statement = f"DROP DATABASE {self._quote_identifier(database)}"
        async with self._engine.connect() as connection:
            await connection.execute(text(statement))
@pytest.fixture(scope="session")
def postgres() -> Generator[PostgresContainer, None, None]:
    """Run a single ``postgres:17.5`` container for the whole test session.

    Its initial database is named ``template-db`` and ``driver="asyncpg"`` selects
    the driver part of the URLs returned by ``get_connection_url()``. The container
    is stopped when the ``with`` block exits at session teardown.
    """
    container = PostgresContainer("postgres:17.5", dbname="template-db", driver="asyncpg")
    with container as running_container:
        yield running_container
def get_alembic_config() -> AlembicConfig:
    """Build an Alembic ``Config`` from ``alembic.ini``.

    The path is relative, so it is resolved against the current working directory
    (presumably the project root when pytest runs — verify in CI).
    """
    ini_path = "alembic.ini"
    return AlembicConfig(file_=ini_path)
def run_migrations(connection: Connection) -> None:
    """Upgrade the database behind ``connection`` to the ``head`` Alembic revision.

    The open connection is passed to ``env.py`` through ``Config.attributes`` under
    the ``"connection"`` key, so migrations run on it rather than on a new engine.
    """
    cfg = get_alembic_config()
    cfg.attributes.update(connection=connection)
    alembic.command.upgrade(config=cfg, revision="head")
@pytest.fixture(scope="session")
async def template_db_engine(postgres: PostgresContainer) -> AsyncGenerator[AsyncEngine, None]:
    """Apply migrations to the container's database once and yield its async engine.

    The migrated database is the template that per-test databases are cloned from.
    PostgreSQL will not copy a template database while other sessions are connected
    to it, so the pool is emptied right after migrating. It is disposed again at
    session teardown.
    """
    template_engine = create_async_engine(postgres.get_connection_url())
    async with template_engine.connect() as conn:
        await conn.run_sync(run_migrations)
    # Connections have to be disposed to allow to use the database as a template
    await template_engine.dispose()
    yield template_engine
    await template_engine.dispose()
@pytest.fixture()
async def test_postgres_url(
    request: pytest.FixtureRequest, template_db_engine: AsyncEngine
) -> AsyncGenerator[URL, None]:
    """Create a uniquely named database for one test and yield its URL.

    By default the database is cloned from the migrated template database. A test
    can ask for an empty database (created from PostgreSQL's default "template1")
    via indirect parametrization::

        @pytest.mark.parametrize("test_postgres_url", [{"empty_database": True}], indirect=True)

    The database is dropped after the test. The administrative engine is always
    disposed, even when creating or dropping the database fails.
    """
    # CREATE/DROP DATABASE cannot run inside a transaction, hence AUTOCOMMIT on the
    # maintenance database "postgres".
    postgres_engine = create_async_engine(
        template_db_engine.url.set(database="postgres"), isolation_level="AUTOCOMMIT"
    )
    try:
        postgres_db_manager = PostgresDbManager(postgres_engine)
        database_name = f"test-db-{uuid.uuid4().hex}"

        # `request.param` only exists when the fixture is parametrized indirectly;
        # a missing param means "use the migrated template".
        params = getattr(request, "param", None) or {}
        empty_database = params.get("empty_database", False)

        if empty_database:
            await postgres_db_manager.create_database(database_name)
        else:
            template_name = cast(str, template_db_engine.url.database)
            await postgres_db_manager.create_database(database_name, template=template_name)

        yield template_db_engine.url.set(database=database_name)

        await postgres_db_manager.drop_database(database_name)
    finally:
        await postgres_engine.dispose()
@pytest.fixture()
async def engine(test_postgres_url: URL) -> AsyncGenerator[AsyncEngine, None]:
    """Yield an async engine bound to the per-test database.

    The engine is disposed after the test, closing its pooled connections. This
    runs before the database itself is dropped, because this fixture depends on
    ``test_postgres_url``.
    """
    engine = create_async_engine(test_postgres_url)
    yield engine
    await engine.dispose()
@pytest.fixture()
async def session_factory(engine: AsyncEngine) -> async_sessionmaker[AsyncSession]:
    """Return an ``async_sessionmaker`` bound to the per-test engine.

    Sessions do not autoflush before queries, and do not expire loaded attributes
    on commit.
    """
    # `autocommit` is intentionally not passed: SQLAlchemy 2.0 keeps it only for
    # backward compatibility and it must stay at its default (False).
    return async_sessionmaker(
        engine,
        autoflush=False,
        expire_on_commit=False,
        class_=AsyncSession,
    )
@pytest.fixture()
async def session(session_factory: async_sessionmaker[AsyncSession]) -> AsyncGenerator[AsyncSession, None]:
    """Yield a fresh ``AsyncSession``; its ``async with`` block closes it after the test."""
    db_session = session_factory()
    async with db_session:
        yield db_session
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import cast | |
import uuid | |
from collections.abc import Generator | |
import alembic.command | |
import pytest | |
from alembic.config import Config as AlembicConfig | |
from sqlalchemy import URL, Connection, create_engine, text | |
from sqlalchemy.engine import Engine | |
from sqlalchemy.orm import Session, sessionmaker | |
from testcontainers.postgres import PostgresContainer | |
class PostgresDbManager:
    """Create and drop PostgreSQL databases through an administrative engine.

    PostgreSQL refuses to run ``CREATE DATABASE``/``DROP DATABASE`` inside a
    transaction block, so the engine given here should use
    ``isolation_level="AUTOCOMMIT"`` and be connected to a different database
    than the one being created or dropped.
    """

    def __init__(self, engine: Engine) -> None:
        self._engine = engine

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return ``name`` as a double-quoted SQL identifier.

        Embedded double quotes are doubled, so a name containing ``"`` can neither
        terminate the identifier early nor inject extra SQL.
        """
        escaped = name.replace('"', '""')
        return f'"{escaped}"'

    def create_database(self, database: str, template: str = "template1") -> None:
        """Create ``database`` with UTF-8 encoding as a copy of ``template``.

        "template1" is the default template name for the `CREATE DATABASE` statement.
        """
        statement = (
            f"CREATE DATABASE {self._quote_identifier(database)} "
            f"ENCODING 'utf8' TEMPLATE {self._quote_identifier(template)}"
        )
        with self._engine.connect() as connection:
            connection.execute(text(statement))

    def drop_database(self, database: str) -> None:
        """Drop ``database``.

        PostgreSQL rejects this while other sessions are still connected to the
        database, so engines bound to it should be disposed first.
        """
        statement = f"DROP DATABASE {self._quote_identifier(database)}"
        with self._engine.connect() as connection:
            connection.execute(text(statement))
@pytest.fixture(scope="session")
def postgres() -> Generator[PostgresContainer, None, None]:
    """Run a single ``postgres:17.5`` container for the whole test session.

    Its initial database is named ``template-db`` and ``driver="psycopg"`` selects
    the driver part of the URLs returned by ``get_connection_url()``. The container
    is stopped when the ``with`` block exits at session teardown.
    """
    container = PostgresContainer("postgres:17.5", dbname="template-db", driver="psycopg")
    with container as running_container:
        yield running_container
def get_alembic_config() -> AlembicConfig:
    """Build an Alembic ``Config`` from ``alembic.ini``.

    The path is relative, so it is resolved against the current working directory
    (presumably the project root when pytest runs — verify in CI).
    """
    ini_path = "alembic.ini"
    return AlembicConfig(file_=ini_path)
def run_migrations(connection: Connection) -> None:
    """Upgrade the database behind ``connection`` to the ``head`` Alembic revision.

    The open connection is passed to ``env.py`` through ``Config.attributes`` under
    the ``"connection"`` key, so migrations run on it rather than on a new engine.
    """
    cfg = get_alembic_config()
    cfg.attributes.update(connection=connection)
    alembic.command.upgrade(config=cfg, revision="head")
@pytest.fixture(scope="session")
def template_db_engine(postgres: PostgresContainer) -> Generator[Engine, None, None]:
    """Apply migrations to the container's database once and yield its engine.

    The migrated database is the template that per-test databases are cloned from.
    The return type spells out all three ``Generator`` parameters, like the other
    fixtures here, instead of relying on 3.13-only type-parameter defaults.
    """
    postgres_url = postgres.get_connection_url()
    engine = create_engine(url=postgres_url)
    with engine.connect() as connection:
        run_migrations(connection)
    # PostgreSQL will not copy a template database while other sessions are
    # connected to it, so empty the pool before any test clones it.
    engine.dispose()
    yield engine
    engine.dispose()
@pytest.fixture()
def test_postgres_url(request: pytest.FixtureRequest, template_db_engine: Engine) -> Generator[URL, None, None]:
    """Create a uniquely named database for one test and yield its URL.

    By default the database is cloned from the migrated template database. A test
    can ask for an empty database (created from PostgreSQL's default "template1")
    via indirect parametrization::

        @pytest.mark.parametrize("test_postgres_url", [{"empty_database": True}], indirect=True)

    The database is dropped after the test. The administrative engine is always
    disposed, even when creating or dropping the database fails.
    """
    # CREATE/DROP DATABASE cannot run inside a transaction, hence AUTOCOMMIT on the
    # maintenance database "postgres".
    postgres_engine = create_engine(
        url=template_db_engine.url.set(database="postgres"), isolation_level="AUTOCOMMIT"
    )
    try:
        postgres_db_manager = PostgresDbManager(postgres_engine)
        database_name = f"test-db-{uuid.uuid4().hex}"

        # `request.param` only exists when the fixture is parametrized indirectly;
        # a missing param means "use the migrated template".
        params = getattr(request, "param", None) or {}
        empty_database = params.get("empty_database", False)

        if empty_database:
            postgres_db_manager.create_database(database_name)
        else:
            template_name = cast(str, template_db_engine.url.database)
            postgres_db_manager.create_database(database_name, template=template_name)

        yield template_db_engine.url.set(database=database_name)

        postgres_db_manager.drop_database(database_name)
    finally:
        postgres_engine.dispose()
@pytest.fixture()
def engine(test_postgres_url: URL) -> Generator[Engine, None, None]:
    """Yield an engine bound to the per-test database.

    The engine is disposed after the test, closing its pooled connections. This
    runs before the database itself is dropped, because this fixture depends on
    ``test_postgres_url``.
    """
    test_engine = create_engine(url=test_postgres_url)
    yield test_engine
    test_engine.dispose()
@pytest.fixture()
def session_factory(engine: Engine) -> sessionmaker[Session]:
    """Return a ``sessionmaker`` bound to the per-test engine.

    Sessions do not autoflush before queries, and do not expire loaded attributes
    on commit.
    """
    # `autocommit` is intentionally not passed: SQLAlchemy 2.0 keeps it only for
    # backward compatibility and it must stay at its default (False).
    return sessionmaker(
        engine,
        autoflush=False,
        expire_on_commit=False,
    )
@pytest.fixture()
def session(session_factory: sessionmaker[Session]) -> Generator[Session, None, None]:
    """Yield a fresh ``Session``; its ``with`` block closes it after the test."""
    db_session = session_factory()
    with db_session:
        yield db_session
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
from logging.config import fileConfig | |
from alembic import context | |
from sqlalchemy import Engine, engine_from_config, pool | |
from sqlalchemy.engine import Connection | |
from sqlalchemy.ext.asyncio import AsyncEngine | |
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
# Skipped when Alembic was configured without an ini file (config_file_name is None).
if config.config_file_name is not None:
    fileConfig(config.config_file_name)
# Database URL taken from the `sqlalchemy.url` option of the main ini section; may be
# None if the option is absent.
FULL_URL = config.get_main_option("sqlalchemy.url")
# No SQLAlchemy MetaData is supplied to Alembic here.
# NOTE(review): presumably autogenerate is not used; set this to the models' MetaData if it is.
TARGET_METADATA = None
def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode, emitting SQL instead of executing it.

    Alembic is configured with the URL alone, so no Engine is created and no
    DBAPI needs to be installed. ``literal_binds=True`` renders bound values
    inline, so statements passed to ``context.execute()`` appear as plain SQL in
    the generated script.
    """
    offline_options = dict(
        url=FULL_URL,
        target_metadata=TARGET_METADATA,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )
    context.configure(**offline_options)
    with context.begin_transaction():
        context.run_migrations()
def do_run_migrations(connection: Connection) -> None:
    """Configure Alembic on the open synchronous ``connection`` and run the migrations."""
    context.configure(target_metadata=TARGET_METADATA, connection=connection)
    migration_transaction = context.begin_transaction()
    with migration_transaction:
        context.run_migrations()
def run_migrations(engine: Engine) -> None:
    """Migrate over one connection from the synchronous ``engine``, then dispose the engine."""
    with engine.connect() as sync_connection:
        do_run_migrations(sync_connection)
    engine.dispose()
async def run_async_migrations(engine: AsyncEngine) -> None:
    """Migrate over one connection from the async ``engine``, then dispose the engine.

    Alembic's migration code is synchronous, so it runs via ``run_sync`` on the
    async connection.
    """
    async with engine.connect() as async_connection:
        await async_connection.run_sync(do_run_migrations)
    await engine.dispose()
def setup_engine() -> Engine:
    """Create a synchronous Engine from the ``sqlalchemy.*`` options of the Alembic ini section.

    The explicit ``url`` keyword overrides any ``sqlalchemy.url`` in that section.
    ``NullPool`` means connections are not kept open between uses.
    """
    ini_section = config.get_section(config.config_ini_section) or {}
    return engine_from_config(
        ini_section,
        prefix="sqlalchemy.",
        url=FULL_URL,
        future=True,
        poolclass=pool.NullPool,
    )
def run_migrations_online() -> None: | |
"""Run migrations in 'online' mode. | |
In this scenario we need to create an Engine or receive a connection | |
and associate the connection with the context. | |
""" | |
connection: Connection | None = config.attributes.get("connection", None) | |
match connection: | |
case None: | |
engine = setup_engine() | |
if engine.driver == "asyncpg": | |
async_engine = AsyncEngine(engine) | |
asyncio.run(run_async_migrations(async_engine)) | |
else: | |
run_migrations(engine) | |
case Connection(): | |
do_run_migrations(connection) | |
case _: | |
raise TypeError(f"Unexpected connection type: {type(connection)}. Expected Connection") | |
def main() -> None:
    """Run offline (SQL script output) or online (live database) migrations, per Alembic's mode."""
    runner = run_migrations_offline if context.is_offline_mode() else run_migrations_online
    runner()
# Entry point, invoked unconditionally when Alembic loads env.py.
# NOTE(review): Alembic does not load env.py as `__main__`, so an
# `if __name__ == "__main__":` guard would likely stop migrations from running — keep unguarded.
main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pytest | |
from sqlalchemy import select, sql | |
from sqlalchemy.ext.asyncio import AsyncSession | |
async def test_db(session: AsyncSession) -> None:
    """Smoke test: the database cloned from the migrated template answers ``SELECT true``."""
    query = select(sql.true())
    result = await session.execute(query)
    selected = result.scalar()
    assert selected is True
@pytest.mark.parametrize("test_postgres_url", [{"empty_database": True}], indirect=True)
async def test_empty_db(session: AsyncSession) -> None:
    """Smoke test: a database created from "template1" (no migrations) answers ``SELECT true``."""
    query = select(sql.true())
    result = await session.execute(query)
    selected = result.scalar()
    assert selected is True
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pytest | |
from sqlalchemy import select, sql | |
from sqlalchemy.orm import Session | |
def test_db(session: Session) -> None:
    """Smoke test: the database cloned from the migrated template answers ``SELECT true``."""
    query = select(sql.true())
    result = session.execute(query)
    selected = result.scalar()
    assert selected is True
@pytest.mark.parametrize("test_postgres_url", [{"empty_database": True}], indirect=True)
def test_empty_db(session: Session) -> None:
    """Smoke test: a database created from "template1" (no migrations) answers ``SELECT true``."""
    query = select(sql.true())
    result = session.execute(query)
    selected = result.scalar()
    assert selected is True
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment