Last active
September 26, 2020 10:16
-
-
Save exhuma/c135551cbb0ad4bb993efea69d1084c9 to your computer and use it in GitHub Desktop.
Test-Harness for SQLAlchemy unit-tests
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Helper functions for unit-testing with SQLAlchemy.

This module provides the context manager ``rb_session``, which creates a new
session whose ``.commit()`` calls are effectively ignored: all work happens
inside an outer transaction that is rolled back when the context exits.
This might not work with all databases. It has been tested with PostgreSQL;
if you use any other database, verify that commits really are ignored.
"""
from contextlib import contextmanager
from os.path import dirname, join, relpath
from typing import Any, Dict, Iterator, List, Optional, Tuple

from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session
# A list of seed specifications: each entry pairs a seed file name with the
# bind parameters used when executing that file.
_SeedParams = Dict[str, Any]
TSeedFiles = List[Tuple[str, _SeedParams]]
class SeedException(Exception):
    """
    Raised by :py:func:`load_seed_files` when a seed file cannot be executed
    against the database.
    """
def load_seed_files(session: Session, seed_files: TSeedFiles) -> None:
    """
    Load a list of SQL seed files into the database.

    Each seed file is looked up in the ``data/seeds`` directory next to this
    module. Its whole content is passed to
    :py:meth:`sqlalchemy.orm.Session.execute`, and the session is committed
    after each file.

    Each seed can be supplied with a dictionary of variables for templated
    seeds. The variables are passed directly to
    :py:meth:`sqlalchemy.orm.Session.execute` as its ``params`` argument.

    So, assuming the SQL file contains the following:

    .. code-block:: sql

        INSERT INTO mytable (foo, bar) VALUES ('hello', :inserted);

    The following snippet can be used to load it, including any variables.

    >>> session = get_session()
    >>> load_seed_files(session, [
    ...     ('myseed.sql', {'inserted': datetime(2019, 1, 1)}),
    ... ])

    :param session: Session used to execute and commit the seeds.
    :param seed_files: ``(filename, params)`` pairs, processed in order.
    :raises SeedException: If executing a seed fails. The message contains
        the seed path, the bound engine and the first two lines of the
        original error.
    """
    seeds_dir = join(dirname(__file__), 'data', 'seeds')
    for seed_file, variables in seed_files:
        seed_path = join(seeds_dir, seed_file)
        # Open the path directly: wrapping it in ``relpath`` gained nothing
        # and raises ValueError on Windows when the cwd is on another drive.
        with open(seed_path, encoding='utf8') as fptr:
            data = fptr.read()
        try:
            # ``text()`` makes the ":name" bind-parameter parsing explicit.
            # Plain SQL strings are deprecated in SQLAlchemy 1.4 and rejected
            # by 2.0.
            session.execute(text(data), params=variables)
        except Exception as exc:
            # Prevent the whole seed from being printed as the error (the
            # first two lines are sufficient).
            lines = str(exc).splitlines()
            simplified_error = ' '.join(lines[:2])
            engine = session.bind.engine
            raise SeedException(
                'Unable to import seed file %r into %r: %s'
                % (seed_path, engine, simplified_error)
            ) from None
        session.commit()  # type: ignore
@contextmanager
def rb_session(
    dsn: str, seed_files: Optional[TSeedFiles] = None
) -> Iterator[Session]:
    """
    A simple context manager that wraps a database session that will never
    commit.

    A new engine and connection are created and an outer transaction is
    started on the connection; the session is bound to that connection.
    Seed files (if any) are loaded with :py:func:`load_seed_files` and the
    session is yielded.

    On exit, even if seeding or the ``with`` body raised, cleanup runs in
    this order:

    1. The session is closed.
    2. The outer transaction is rolled back.
    3. The connection is closed.
    4. The engine's connection pool is disposed.

    Any "commit" calls on the yielded session are therefore discarded.

    NOTE(review): this relies on the session joining the externally started
    transaction. SQLAlchemy 1.4/2.0 changed how that works (see the "Joining
    a Session into an External Transaction" recipe), so verify this against
    the installed version.

    :param dsn: Database URL passed to ``create_engine``.
    :param seed_files: Optional ``(filename, params)`` pairs loaded before
        the session is yielded; see :py:func:`load_seed_files`.
    """
    engine = create_engine(dsn)
    connection = engine.connect()
    trans = connection.begin()
    session = Session(bind=connection)
    try:
        # Seeding happens inside ``try`` so that a failing seed still rolls
        # back the transaction and releases the connection.
        load_seed_files(session, seed_files or [])
        yield session
    finally:
        # Close the session first, then discard everything (including any
        # "committed" work) by rolling back the outer transaction.
        session.close()  # type: ignore
        trans.rollback()  # type: ignore
        connection.close()
        # Every call creates its own engine; release its pool too.
        engine.dispose()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
This module contains pytest definitions (fixtures and helpers) that are
shared across all tests in the project.
'''
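# pylint: disable=redefined-outer-name
#
# Needed because pytest injects fixtures as test arguments that share the
# fixture function's name, which pylint reports as redefining an outer name.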
import logging | |
from pytest import fixture | |
from sqlalchemy import create_engine | |
from sqlalchemy.orm import Session | |
def nuke_all_tables(session: Session) -> None:
    """
    Truncate all tables in the ``public`` and ``history`` schemas, i.e. nuke
    all of their records, and commit.

    The ``alembic_version`` table is skipped. ``CASCADE`` also truncates
    tables that reference the truncated ones through foreign keys. Uses the
    PostgreSQL catalog view ``pg_tables`` and ``quote_ident``, so this is
    PostgreSQL-only.

    :param session: Session used to issue the statements and commit.
    """
    # Local import: ``text`` is not imported at the top of this module.
    from sqlalchemy import text

    # ``quote_ident`` quotes names that need it (mixed case, reserved words,
    # special characters); plain lowercase names come back unchanged.
    result = session.execute(text(
        "SELECT quote_ident(schemaname) || '.' || quote_ident(tablename) "
        "FROM pg_tables "
        "WHERE schemaname IN ('public', 'history') "
        "AND tablename != 'alembic_version'"
    ))
    table_names = [row[0] for row in result.fetchall()]
    # An empty list would produce "TRUNCATE  CASCADE;", which is a syntax
    # error, so only truncate when there is something to truncate.
    if table_names:
        session.execute(
            text("TRUNCATE %s CASCADE;" % ", ".join(table_names))
        )
    session.commit()
@fixture
def rb_session():
    """
    Yield a session for a single test and wipe the database afterwards.

    On teardown, any uncommitted work is rolled back and every table is
    truncated with ``nuke_all_tables``, so data from accidental commits does
    not leak into other tests. The session is closed and the engine disposed
    even if the truncation fails.
    """
    # NOTE(review): the result of get_configs() is unused here; the call is
    # kept in case it has side effects (presumably DSN is derived from it).
    # Confirm, and remove if it is not needed.
    get_configs()
    engine = create_engine(DSN)
    session = Session(bind=engine)
    try:
        yield session
    finally:
        try:
            session.rollback()  # type: ignore
            nuke_all_tables(session)
        finally:
            session.close()  # type: ignore
            # Release the per-test engine's connection pool.
            engine.dispose()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment