Created
February 5, 2025 11:12
-
-
Save ZipFile/7b9911c3f0c57c7d112725618a596464 to your computer and use it in GitHub Desktop.
python-dependency-injection + pytest + sqlalchemy
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# pip install dependency-injector pytest pytest-asyncio pytest-postgresql SQLAlchemy pydantic-settings asyncpg psycopg[binary]
import os | |
from contextlib import ExitStack | |
from typing import Any, AsyncIterator, Iterator, NoReturn | |
import pytest_asyncio | |
from dependency_injector.providers import Provider | |
from pytest import fixture | |
from pytest_postgresql.janitor import DatabaseJanitor | |
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine | |
from sqlalchemy.future import Engine, create_engine | |
from myproject.containers import MainContainer | |
from myproject.settings import Settings | |
from myproject.sqlalchemy.models import Base | |
@fixture(scope="session")
def in_docker() -> bool:
    """Return True when the ``/.dockerenv`` marker file exists.

    The presence of that file is taken as the sign that the test session
    is running inside a Docker container.
    """
    marker = "/.dockerenv"
    return os.path.exists(marker)
@fixture(scope="session")
def database(worker_id: str, in_docker: bool) -> Iterator[str]:
    """Create a temporary PostgreSQL database and yield its asyncpg DSN.

    Connection parameters come from the ``POSTGRES_USER``,
    ``POSTGRES_PASSWORD``, ``POSTGRES_HOST`` and ``POSTGRES_PORT``
    environment variables. The host falls back to ``"db"`` when
    ``in_docker`` is true and to ``"localhost"`` otherwise; the other
    values fall back to ``"postgres"`` / ``5432``.

    Args:
        worker_id: Suffix for the database name (``test_<worker_id>``), so
            each worker gets its own database. Presumably the pytest-xdist
            ``worker_id`` fixture -- confirm against the project's plugins.
        in_docker: Whether the session runs inside Docker.

    Yields:
        A ``postgresql+asyncpg://`` URL pointing at the temporary database.
        The database lives for as long as the janitor context is open.
    """
    # Local import keeps the fix self-contained in this fixture.
    from urllib.parse import quote_plus

    db_name = f"test_{worker_id}"
    user = os.environ.get("POSTGRES_USER", "postgres")
    password = os.environ.get("POSTGRES_PASSWORD", "postgres")
    host = os.environ.get("POSTGRES_HOST", "db" if in_docker else "localhost")
    port = int(os.environ.get("POSTGRES_PORT", "5432"))

    # The janitor receives the raw (unescaped) credentials.
    janitor = DatabaseJanitor(
        user=user,
        password=password,
        host=host,
        port=port,
        dbname=db_name,
        # NOTE(review): hard-coded server version -- confirm it matches the
        # PostgreSQL server the tests actually run against.
        version="16.6",
    )
    with janitor:
        # Credentials must be percent-encoded inside a URL; otherwise a
        # password containing "@", "/", ":" or "%" produces a DSN that is
        # parsed incorrectly.
        yield (
            f"postgresql+asyncpg://{quote_plus(user)}:{quote_plus(password)}"
            f"@{host}:{port}/{db_name}"
        )
@fixture(scope="session")
def engine(database: str) -> AsyncEngine:
    """Build the session-wide async SQLAlchemy engine for the test database.

    Args:
        database: DSN yielded by the ``database`` fixture.
    """
    async_engine = create_async_engine(database)
    return async_engine
@fixture(scope="session")
def sync_engine(database: str) -> Engine:
    """Build a synchronous engine for the same test database.

    The async DSN is reused with its ``postgresql+asyncpg`` driver prefix
    swapped for ``postgresql+psycopg``.

    Args:
        database: DSN yielded by the ``database`` fixture.
    """
    sync_dsn = database.replace("postgresql+asyncpg", "postgresql+psycopg")
    return create_engine(sync_dsn)
@fixture(scope="session")
def tables(sync_engine: Engine) -> None:
    """Create all tables registered on ``Base.metadata`` once per session.

    Args:
        sync_engine: Synchronous engine used to issue the DDL.
    """
    metadata = Base.metadata
    metadata.create_all(sync_engine)
@pytest_asyncio.fixture
async def db_session(
    engine: AsyncEngine,
    tables: None,
    container: MainContainer,
) -> AsyncIterator[AsyncSession]:
    """Yield an ``AsyncSession`` whose work is rolled back after the test.

    The session is bound to a connection that already has an outer
    transaction open, and joins it with
    ``join_transaction_mode="create_savepoint"``. While the fixture is
    active, the container's ``sa_session_factory`` provider is overridden
    with a callable that returns this same session.

    Args:
        engine: Async engine for the test database.
        tables: Requested only so the schema exists before the test runs.
        container: Per-test DI container whose session factory is overridden.

    Yields:
        The session bound to the outer transaction.
    """
    # https://docs.sqlalchemy.org/en/20/orm/session_transaction.html#joining-a-session-into-an-external-transaction-such-as-for-test-suites
    async with engine.connect() as conn:
        await conn.begin()
        # ``async with`` closes the session on teardown; the original never
        # closed it. Closing before the outer rollback follows the order of
        # the SQLAlchemy recipe linked above.
        async with AsyncSession(
            bind=conn,
            expire_on_commit=False,
            join_transaction_mode="create_savepoint",
        ) as async_session:
            with container.sa_session_factory.override(lambda: async_session):
                yield async_session
        # Discard everything the test wrote inside the outer transaction.
        await conn.rollback()
    # NOTE(review): the session-scoped engine is disposed after every test,
    # presumably so pooled connections are not reused across per-test event
    # loops -- confirm.
    await engine.dispose()
@fixture(scope="session")
def _container(
    database: str,
    engine: AsyncEngine,
) -> Iterator[MainContainer]:
    """Build the session-wide ``MainContainer`` with test overrides applied.

    Configuration is loaded from ``Settings()``. Three providers are then
    overridden for the lifetime of the fixture:

    * ``pg_dsn`` -> the temporary database DSN,
    * ``sa_engine`` -> the test engine,
    * ``sa_session_factory`` -> a stand-in that raises ``AssertionError``
      when called, so tests must obtain sessions through ``db_session``.

    Args:
        database: DSN yielded by the ``database`` fixture.
        engine: Async engine bound to that DSN.

    Yields:
        The configured container; overrides are undone when the fixture exits.
    """

    def _refuse_session() -> NoReturn:
        raise AssertionError("use db_session fixture!")

    di = MainContainer()
    di.config.from_dict(Settings().model_dump())

    # (provider, replacement) pairs, applied in this order.
    overrides = (
        (di.pg_dsn, database),
        (di.sa_engine, engine),
        (di.sa_session_factory, _refuse_session),
    )
    with ExitStack() as stack:
        for provider, replacement in overrides:
            stack.enter_context(provider.override(replacement))
        yield di
@fixture
def container(_container: MainContainer) -> Iterator[MainContainer]:
    """Provide the shared container to a single test.

    Inside the ``reset_singletons()`` context, resources are initialised
    before the test and shut down after it.

    Args:
        _container: The session-scoped container with test overrides.

    Yields:
        The same container instance, ready for use by the test.
    """
    shared = _container
    with shared.reset_singletons():
        shared.init_resources()
        yield shared
        shared.shutdown_resources()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dependency_injector.containers import DeclarativeContainer, WiringConfiguration | |
from dependency_injector.providers import Configuration, Factory, Singleton | |
class MainContainer(DeclarativeContainer):
    """Dependency-injection container for the PostgreSQL/SQLAlchemy stack."""

    # Configuration root; the providers below read ``config.postgres.*``.
    config: Configuration = Configuration()

    # DSN built by ``pg_dsn_factory`` from the postgres configuration values.
    pg_dsn: Factory[str] = Factory(
        pg_dsn_factory,
        user=config.postgres.user,
        password=config.postgres.password,
        host=config.postgres.host,
        port=config.postgres.port,
        db=config.postgres.db,
    )

    # Single async engine created from the DSN provider.
    sa_engine: Singleton[AsyncEngine] = Singleton(create_async_engine, pg_dsn)

    # Single ``sessionmaker`` bound to the engine and producing ``AsyncSession``
    # objects.
    sa_session_factory: Singleton[AsyncSessionFactory] = Singleton(
        sessionmaker,
        bind=sa_engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
        autocommit=False,
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment