Last active
March 17, 2024 07:04
-
-
Save HacKanCuBa/bfee44fb8f3e81289c36c7bf5a579dfa to your computer and use it in GitHub Desktop.
SQLAlchemy handy helper functions
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools | |
from contextlib import asynccontextmanager, contextmanager | |
from time import monotonic | |
from typing import Annotated, Any, AsyncGenerator, Generator, Hashable, Iterable, Literal, Optional, Sized, Union, overload | |
from sqlalchemy import event | |
from sqlalchemy.dialects.mysql.asyncmy import AsyncAdapt_asyncmy_cursor | |
from sqlalchemy.engine import URL, Connection, Engine, Row, create_engine | |
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine | |
from sqlalchemy.orm import Session, sessionmaker | |
from sqlalchemy.pool import AsyncAdaptedQueuePool, QueuePool | |
# Alias for values that can safely be passed to a ``functools.cache``-wrapped function:
# they become part of the cache key, so they have to be hashable.
AnyCacheable = Annotated[
    Hashable,
    "Any type that works well with functools.cache, meaning hashable (i.e., not dicts!)",
]
@event.listens_for(Engine, "before_cursor_execute")
def _before_cursor_execute(conn: Connection, *_: Any) -> None:
    """Remember when the statement about to run on ``conn`` was started.

    Start times are pushed onto a list stored under ``conn.info["query_start_time"]``;
    the ``after_cursor_execute`` listener pops the most recent one to compute the
    elapsed time.
    """
    start_times = conn.info.setdefault("query_start_time", [])
    start_times.append(monotonic())
# noinspection PyUnusedLocal
@event.listens_for(Engine, "after_cursor_execute")
def _after_cursor_execute(
    conn: Connection,
    cursor: AsyncAdapt_asyncmy_cursor,
    statement: str,
    parameters: tuple[dict[str, Any], ...] | dict[str, Any] | None,
    *_: Any,
) -> None:
    """Measure how long the statement that just finished on ``conn`` took.

    Pops the most recent start time pushed by the ``before_cursor_execute`` listener
    from ``conn.info["query_start_time"]``. The elapsed time is not reported anywhere
    by default; hook your own logger in below if you want query timings.
    """
    finished_at = monotonic()
    started_at = conn.info["query_start_time"].pop()
    elapsed = finished_at - started_at
    # Example of what to do with the measurement (needs a module-level ``logger``):
    # logger.debug('DB query\n\t%s\n\tparams: %s\n\tfinished in %f seconds', statement.replace("\n", ""), parameters, elapsed)
@overload | |
def _get_db_engine(db_url: Union[str, URL], *, sync: Literal[True], **kwargs: AnyCacheable) -> Engine: | |
... | |
@overload | |
def _get_db_engine(db_url: Union[str, URL], *, sync: Literal[False], **kwargs: AnyCacheable) -> AsyncEngine: | |
... | |
@functools.cache | |
def _get_db_engine(db_url: Union[str, URL], *, sync: bool, **kwargs: AnyCacheable) -> Union[Engine, AsyncEngine]: | |
if "connect_args" in kwargs: | |
connect_args_raw = kwargs.pop("connect_args") | |
assert isinstance(connect_args_raw, Iterable) and all( | |
isinstance(arg, Sized) and len(arg) == 2 for arg in connect_args_raw | |
) | |
connect_args = dict(connect_args_raw) | |
else: | |
connect_args = {"connect_timeout": 5} # Some dialects use "timeout" | |
poolclass = QueuePool if sync else AsyncAdaptedQueuePool | |
# You may want to move some of this to some sort of global constant | |
params = { | |
"isolation_level": "READ COMMITTED", # See https://docs.sqlalchemy.org/en/20/core/connections.html#dbapi-autocommit | |
"echo": kwargs.pop("echo", False), # Don't be so verbose unless this is true | |
"future": True, | |
"connect_args": connect_args, | |
"poolclass": poolclass, | |
} | |
params.update(kwargs) | |
if sync: | |
return create_engine(db_url, **params) | |
return create_async_engine(db_url, **params) | |
@asynccontextmanager
async def async_db_engine(db_url: Union[str, URL], **kwargs: AnyCacheable) -> AsyncGenerator[AsyncEngine, None]:
    """Yield an async pooled engine for ``db_url`` and dispose of it when the block exits.

    The engine is obtained from the cached ``_get_db_engine`` factory, so calls with
    equal arguments receive the same engine object. Keyword arguments are forwarded
    to that factory and therefore have to be hashable.
    """
    pooled_engine = _get_db_engine(db_url, sync=False, **kwargs)
    try:
        yield pooled_engine
    finally:
        # Runs on normal exit and on error alike.
        await pooled_engine.dispose()
@asynccontextmanager
async def async_db_session(engine: AsyncEngine, **kwargs: Any) -> AsyncGenerator[AsyncSession, None]:
    """Yield an async ORM session bound to ``engine``, closed when the block exits.

    ``expire_on_commit`` defaults to ``False``; any keyword argument given here is
    passed to ``async_sessionmaker`` and overrides that default.
    """
    # You may want to move some of this to some sort of global constant
    session_options = {"expire_on_commit": False, **kwargs}
    make_session = async_sessionmaker(engine, **session_options)  # type: ignore[call-overload]
    async with make_session() as session:
        yield session
@contextmanager
def db_engine(db_url: Union[str, URL], **kwargs: AnyCacheable) -> Generator[Engine, None, None]:
    """Yield a pooled engine for ``db_url`` and dispose of it when the block exits.

    The engine is obtained from the cached ``_get_db_engine`` factory, so calls with
    equal arguments receive the same engine object. Keyword arguments are forwarded
    to that factory and end up in its cache key, which is why they are annotated as
    ``AnyCacheable`` (hashable), matching ``async_db_engine``.
    """
    engine = _get_db_engine(db_url, sync=True, **kwargs)
    try:
        yield engine
    finally:
        # Runs on normal exit and on error alike.
        engine.dispose()
@contextmanager
def db_session(engine: Engine, **kwargs: Any) -> Generator[Session, None, None]:
    """Yield an ORM session bound to ``engine``, closed when the block exits.

    ``expire_on_commit`` defaults to ``False``; any keyword argument given here is
    passed to ``sessionmaker`` and overrides that default.
    """
    # You may want to move some of this to some sort of global constant
    params: dict[str, Any] = {
        "expire_on_commit": False,
    }
    params.update(kwargs)

    # Keep the factory and the session under distinct names; reusing one name for
    # both rebinds it from ``sessionmaker`` to ``Session`` inside the ``with``.
    session_factory = sessionmaker(engine, **params)  # type: ignore[call-overload]
    with session_factory() as session:
        yield session
def asdict(row: Row) -> dict[str, Any]:
    """Return the contents of ``row`` as a plain dict whose keys are ``str``."""
    # ``Row._asdict`` has a leading underscore but is the documented way to do this.
    # See: https://docs.sqlalchemy.org/en/14/core/connections.html#sqlalchemy.engine.Row._asdict
    # noinspection PyProtectedMember
    raw_mapping = row._asdict()

    # Keys may be `sqlalchemy.sql.elements.quoted_name` instead of str, so normalize them.
    plain: dict[str, Any] = {}
    for name, value in raw_mapping.items():
        plain[str(name)] = value
    return plain
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment