Skip to content

Instantly share code, notes, and snippets.

@ZipFile
Created February 5, 2025 11:12
Show Gist options
  • Save ZipFile/7b9911c3f0c57c7d112725618a596464 to your computer and use it in GitHub Desktop.
Save ZipFile/7b9911c3f0c57c7d112725618a596464 to your computer and use it in GitHub Desktop.
python-dependency-injector + pytest + SQLAlchemy
# pip install dependency-injector pytest pytest-asyncio pytest-postgresql SQLAlchemy pydantic-settings asyncpg psycopg[binary]
import os
from contextlib import ExitStack
from typing import Any, AsyncIterator, Iterator, NoReturn
import pytest_asyncio
from dependency_injector.providers import Provider
from pytest import fixture
from pytest_postgresql.janitor import DatabaseJanitor
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from sqlalchemy.future import Engine, create_engine
from myproject.containers import MainContainer
from myproject.settings import Settings
from myproject.sqlalchemy.models import Base
@fixture(scope="session")
def in_docker() -> bool:
    """Report whether the test session is running inside a Docker container.

    The check is the presence of the ``/.dockerenv`` marker file, which Docker
    conventionally creates at the container root.
    """
    marker = "/.dockerenv"
    return os.path.exists(marker)
@fixture(scope="session")
def database(worker_id: str, in_docker: bool) -> Iterator[str]:
    """Create a temporary Postgres database for the session and drop it afterwards.

    Connection parameters come from ``POSTGRES_USER``, ``POSTGRES_PASSWORD``,
    ``POSTGRES_HOST`` and ``POSTGRES_PORT`` (defaults: ``postgres`` /
    ``postgres`` / ``db`` inside Docker, else ``localhost`` / ``5432``).
    The database is named ``test_<worker_id>`` so that parallel workers each
    get their own database.

    NOTE(review): ``worker_id`` is presumably the pytest-xdist fixture, and
    pytest-xdist is not in the ``pip install`` line at the top of the file.
    Confirm it is installed.

    Yields:
        An async SQLAlchemy URL (``postgresql+asyncpg://...``) pointing at the
        temporary database.
    """
    from urllib.parse import quote

    db_name = f"test_{worker_id}"
    user = os.environ.get("POSTGRES_USER", "postgres")
    password = os.environ.get("POSTGRES_PASSWORD", "postgres")
    host = os.environ.get("POSTGRES_HOST", "db" if in_docker else "localhost")
    port = int(os.environ.get("POSTGRES_PORT", "5432"))
    janitor = DatabaseJanitor(
        user=user,
        password=password,
        host=host,
        port=port,
        dbname=db_name,
        version="16.6",
    )
    # Percent-encode the credentials for the URL only (the janitor above gets
    # the raw values). Characters such as "@", ":", "/", "%" or spaces in a
    # password would otherwise corrupt the DSN. SQLAlchemy unquotes them when
    # it parses the URL. ``quote`` (not ``quote_plus``) is used so that a
    # space becomes "%20", which unquotes correctly, rather than "+".
    dsn = (
        f"postgresql+asyncpg://{quote(user, safe='')}:{quote(password, safe='')}"
        f"@{host}:{port}/{db_name}"
    )
    with janitor:
        yield dsn
@fixture(scope="session")
def engine(database: str) -> AsyncEngine:
    """Build the session-wide async engine for the temporary test database.

    ``database`` is the ``postgresql+asyncpg://`` URL yielded by the
    ``database`` fixture.
    """
    async_engine = create_async_engine(database)
    return async_engine
@fixture(scope="session")
def sync_engine(database: str) -> Iterator[Engine]:
    """Yield a synchronous engine for the same test database.

    The async URL's ``postgresql+asyncpg`` scheme is swapped for
    ``postgresql+psycopg``, so the engine uses the psycopg driver. This lets
    synchronous code such as ``Base.metadata.create_all`` run against the
    database.

    The engine is disposed at session teardown, closing its pooled
    connections. That teardown runs before the ``database`` fixture (which
    this one depends on) drops the database.
    """
    # Replace only the first occurrence, i.e. the scheme at the start of the
    # URL, never text that happens to appear later in the DSN.
    eng = create_engine(
        database.replace("postgresql+asyncpg", "postgresql+psycopg", 1),
    )
    yield eng
    eng.dispose()
@fixture(scope="session")
def tables(sync_engine: Engine) -> None:
    """Create every table registered on ``Base.metadata`` once per session.

    DDL goes through the synchronous engine. Nothing is dropped here, so the
    tables live until the test database itself is removed.
    """
    metadata = Base.metadata
    metadata.create_all(bind=sync_engine)
@pytest_asyncio.fixture
async def db_session(
    engine: AsyncEngine,
    tables: None,
    container: MainContainer,
) -> AsyncIterator[AsyncSession]:
    """Yield an ``AsyncSession`` whose changes are rolled back after the test.

    This follows SQLAlchemy's "joining a Session into an external
    transaction" recipe (link below):

    - An outer transaction is begun on a dedicated connection.
    - The session is bound to that connection with
      ``join_transaction_mode="create_savepoint"``, so the session's own
      transactions run as savepoints inside the outer one.
    - After the test, the session is closed and the outer transaction is
      rolled back, discarding everything the test wrote.

    While the test runs, ``container.sa_session_factory`` is overridden so
    that calling the factory returns this same session. Code resolved through
    the container therefore shares the test's transaction.

    ``tables`` is requested only for its side effect (schema creation); its
    value is unused.
    """
    # https://docs.sqlalchemy.org/en/20/orm/session_transaction.html#joining-a-session-into-an-external-transaction-such-as-for-test-suites
    async with engine.connect() as conn:
        outer_tx = await conn.begin()
        # ``async with`` closes the session on exit, matching the recipe's
        # ``session.close()`` before the outer rollback. Previously the
        # session was never closed.
        async with AsyncSession(
            bind=conn,
            expire_on_commit=False,
            join_transaction_mode="create_savepoint",
        ) as async_session:
            with container.sa_session_factory.override(lambda: async_session):
                yield async_session
        await outer_tx.rollback()
    # NOTE(review): disposing the session-scoped engine after every test
    # presumably keeps pooled asyncpg connections from being reused on a
    # different event loop in the next test. Confirm before removing.
    await engine.dispose()
@fixture(scope="session")
def _container(
    database: str,
    engine: AsyncEngine,
) -> Iterator[MainContainer]:
    """Build one ``MainContainer`` for the whole session, wired to the test database.

    Configuration is loaded from ``Settings()``. Three providers are then
    overridden for the lifetime of the session:

    - ``pg_dsn`` returns the temporary database URL.
    - ``sa_engine`` returns the session-wide async engine.
    - ``sa_session_factory`` returns a callable that raises
      ``AssertionError``, so tests must obtain sessions via the
      ``db_session`` fixture, which overrides this provider again.
    """
    main = MainContainer()
    main.config.from_dict(Settings().model_dump())

    def forbidden_session_factory() -> NoReturn:
        raise AssertionError("use db_session fixture!")

    overrides = (
        (main.pg_dsn, database),
        (main.sa_engine, engine),
        (main.sa_session_factory, forbidden_session_factory),
    )
    # All overrides are undone, in reverse order, when the stack closes at
    # session teardown.
    with ExitStack() as stack:
        for provider, replacement in overrides:
            stack.enter_context(provider.override(replacement))
        yield main
@fixture
def container(_container: MainContainer) -> Iterator[MainContainer]:
    """Per-test view of the session container.

    The body runs inside ``reset_singletons()`` so singleton providers are
    reset around each test. Resources are initialized before the test and
    shut down after it.

    NOTE(review): if the container ever gains async ``Resource`` providers,
    ``init_resources()`` / ``shutdown_resources()`` may return awaitables
    that this sync fixture never awaits. Confirm.
    """
    shared = _container
    with shared.reset_singletons():
        shared.init_resources()
        yield shared
        shared.shutdown_resources()
from dependency_injector.containers import DeclarativeContainer, WiringConfiguration
from dependency_injector.providers import Configuration, Factory, Singleton
class MainContainer(DeclarativeContainer):
    """Application dependency container: configuration, DSN, async engine, session factory.

    NOTE(review): this excerpt references ``pg_dsn_factory``,
    ``AsyncSessionFactory``, ``sessionmaker``, ``create_async_engine``,
    ``AsyncSession`` and ``AsyncEngine``, none of which appear in the
    module's visible imports. Confirm they are imported in the real module.
    """

    # Populated at runtime, e.g. via ``config.from_dict(...)``. The
    # ``postgres`` section feeds the DSN provider below.
    config: Configuration = Configuration()

    # Factory provider: ``pg_dsn_factory`` is called with the postgres
    # settings each time the DSN is requested.
    pg_dsn: Factory[str] = Factory(
        pg_dsn_factory,
        user=config.postgres.user,
        password=config.postgres.password,
        host=config.postgres.host,
        port=config.postgres.port,
        db=config.postgres.db,
    )

    # Singleton provider: one async engine per container, built from the DSN.
    sa_engine: Singleton[AsyncEngine] = Singleton(create_async_engine, pg_dsn)

    # Singleton provider: one ``sessionmaker`` producing ``AsyncSession``
    # objects bound to ``sa_engine``. The legacy ``autocommit=False`` keyword
    # was dropped. SQLAlchemy 2.0 keeps it only for backwards compatibility,
    # where it must stay at its default of False, so passing it was a no-op.
    sa_session_factory: Singleton[AsyncSessionFactory] = Singleton(
        sessionmaker,
        bind=sa_engine,
        class_=AsyncSession,
        autoflush=False,
        expire_on_commit=False,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment