@bmritz
Created December 30, 2024 21:31
This shows how to use a Postgres template database to run each test against its own isolated instance of a seeded test database, with good performance.
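In brief: a module-scoped fixture seeds one database with migrations and sample data, and each test then clones that database via Postgres's `CREATE DATABASE ... WITH TEMPLATE ...`, which is a fast file-level copy, so every test gets full isolation without re-running the seeding. A minimal sketch of the per-test clone step follows (names here are illustrative; the real fixtures below add pooling, monkeypatching, and cleanup):

import asyncpg

async def clone_from_template(host: str, port: int, user: str, password: str) -> str:
    """Create a throwaway copy of the seeded `postgres` database and return its name."""
    # Connect to a maintenance database (template1), not to the template itself.
    conn = await asyncpg.connect(
        host=host, port=port, user=user, password=password, database="template1"
    )
    try:
        clone_name = "test_db_example"  # the fixtures below randomize this per test
        await conn.execute(f"CREATE DATABASE {clone_name} WITH TEMPLATE postgres;")
        return clone_name
    finally:
        await conn.close()
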
import datetime
import os
import random
import string
import time
import uuid

import asyncpg
import pytest
import pytest_asyncio
from testcontainers.postgres import PostgresContainer

from your_python_module.database import core as db_core
from your_python_module.database import queries, run_migrations
from your_python_module.logging_ import get_logger

# I used this SO answer as guidance: https://stackoverflow.com/a/52710018/23145816
logger = get_logger()

# Add this to disable Ryuk, which is causing the container cleanup issues
os.environ["TESTCONTAINERS_RYUK_DISABLED"] = "true"


@pytest.fixture(scope="module")
def pg_container():
    """Start up a new Postgres container with `testcontainers`."""
    with PostgresContainer(
        "postgres:16",
        username="postgres",
        password="123",
        dbname="postgres",
        port=5432,
    ) as postgres:
        # unfortunately it's flaky without this sleep
        time.sleep(2)
        yield postgres


@pytest.fixture(scope="module")
def monkeysession():
    """Provide a module-scoped MonkeyPatch (the built-in `monkeypatch` fixture is function-scoped)."""
    with pytest.MonkeyPatch.context() as mp:
        yield mp


@pytest_asyncio.fixture(scope="module", loop_scope="module")
async def seeded_database(
    pg_container: PostgresContainer,
    monkeysession: pytest.MonkeyPatch,
):
    """Seed the Postgres database with migrations & sample data."""
    pool = await asyncpg.create_pool(
        host=pg_container.get_container_host_ip(),
        port=pg_container.get_exposed_port(5432),
        user=pg_container.username,
        password=pg_container.password,
        database=pg_container.dbname,
        min_size=1,
        max_size=3,
    )
    # we need to set the connection pool in the core module
    # in order to run the migrations properly
    monkeysession.setattr(db_core, "database_connection_pool", pool)
    # First create required roles and schema
    async with pool.acquire() as connection:
        await connection.execute(
            """
            CREATE ROLE pgadmin;
            CREATE SCHEMA migrations;
            """
        )
    # Then run migrations which will use the pgadmin role
    await run_migrations()
    # Finally insert sample data
    async with pool.acquire() as connection:
        await connection.execute(
            """
            -- Insert sample data into web_content_requests
            INSERT INTO web_content_requests (id, url, content_type, content_location)
            VALUES
                ('00000000-0000-0000-0000-000000000000', 'https://example.com/page1', 'HTML', 'gs://bucket/page1.html'),
                ('00000000-0000-0000-0000-000000000001', 'https://example.com/page2', 'PDF', 'gs://bucket/page2.pdf'),
                ('00000000-0000-0000-0000-000000000002', 'https://example.com/page3', 'HTML', 'gs://bucket/page3.html');

            -- Insert sample data into web_content
            INSERT INTO web_content (id, web_content_request_id, title, byline, datetime_published, markdown)
            VALUES
                ('00000000-0000-0000-0000-000000000004', (SELECT id FROM web_content_requests WHERE url = 'https://example.com/page1'), 'Example Page 1', 'John Doe', '2023-10-01 10:00:00', '# Example Page 1\n\nThis is a markdown content for page 1.'),
                ('00000000-0000-0000-0000-000000000005', (SELECT id FROM web_content_requests WHERE url = 'https://example.com/page2'), 'Example Page 2', 'Jane Smith', '2023-10-02 11:00:00', '# Example Page 2\n\nThis is a markdown content for page 2.'),
                ('00000000-0000-0000-0000-000000000006', (SELECT id FROM web_content_requests WHERE url = 'https://example.com/page3'), 'Example Page 3', 'Alice Johnson', '2023-10-03 12:00:00', '# Example Page 3\n\nThis is a markdown content for page 3.');
            """
        )
    monkeysession.setattr(db_core, "database_connection_pool", None)
    await pool.close()
    yield "DB Seeded"


@pytest_asyncio.fixture(scope="function", loop_scope="module")
async def clone_db_creator_connection(
    pg_container: PostgresContainer,
):
    """Yield a connection that is used to create new databases."""
    connection = await asyncpg.connect(
        host=pg_container.get_container_host_ip(),
        port=pg_container.get_exposed_port(5432),
        user=pg_container.username,
        password=pg_container.password,
        # connect to template1 rather than the seeded database -- cloning from a template
        # fails if this connection were attached to the template itself
        database="template1",
    )
    yield connection
    await connection.close()


@pytest_asyncio.fixture(scope="function", loop_scope="module")
async def cloned_db(
    pg_container: PostgresContainer,
    monkeysession: pytest.MonkeyPatch,
    clone_db_creator_connection: asyncpg.Connection,
    seeded_database: str,
):
    """Monkeypatch the asyncpg db connection pool objects.

    This mocks the db_core.database_connection_pool object in our code
    to use the postgres database we spun up for testing.
    The pool is configured to connect to a new database created from the template
    of the seeded database.
    """
    db_name = f"test_db_{''.join(random.sample(string.ascii_lowercase, 12))}"
    create_stmt = f"CREATE DATABASE {db_name} WITH TEMPLATE {pg_container.dbname};"
    _ = await clone_db_creator_connection.execute(create_stmt)
    pool = await asyncpg.create_pool(
        host=pg_container.get_container_host_ip(),
        port=pg_container.get_exposed_port(5432),
        user=pg_container.username,
        password=pg_container.password,
        database=db_name,
        min_size=1,
        max_size=1,
    )
    monkeysession.setattr(db_core, "database_connection_pool", pool)
    yield pool
    await pool.close()
    await clone_db_creator_connection.execute(f"DROP DATABASE {db_name};")
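
# NOTE: the pool must be closed before the DROP DATABASE above -- Postgres will not drop a
# database that still has open connections. The clone itself is cheap because Postgres copies
# the template's files rather than replaying the seed SQL.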


@pytest.mark.asyncio(loop_scope="module")
async def test_insert_web_content_request(
    cloned_db: asyncpg.pool.Pool,
):
    """Test inserting a web content request record."""
    request_id = await queries.web_content.insert_web_content_request(
        "https://example.com",
        datetime.datetime.strptime("2022-01-01 00:00:00", "%Y-%m-%d %H:%M:%S"),
        "diffbot/json",
        "jlaksjdf",
    )
    # Verify the record was inserted
    async with cloned_db.acquire() as conn:
        row = await conn.fetchrow(
            """
            SELECT url, content_type, content_location
            FROM web_content_requests
            WHERE id = $1
            """,
            request_id,
        )
        assert row is not None
        assert row["url"] == "https://example.com"
        assert row["content_type"] == "diffbot/json"
        assert row["content_location"] == "jlaksjdf"


@pytest.mark.asyncio(loop_scope="module")
async def test_insert_web_content(
    cloned_db: asyncpg.pool.Pool,
):
    """Test inserting web content records from a list of tuples."""
    ids = await queries.web_content.insert_web_content(
        [
            (
                uuid.UUID("00000000-0000-0000-0000-000000000002"),
                "Example Page",
                "John Doe",
                datetime.datetime.strptime("2022-01-01 00:00:00", "%Y-%m-%d %H:%M:%S"),
                "# Example Page\n\nThis is a markdown content.",
            )
        ]
    )
    assert len(ids)
    # Verify the record was inserted
    async with cloned_db.acquire() as conn:
        row = await conn.fetchrow(
            """
            SELECT title, byline, markdown
            FROM web_content
            WHERE id = $1
            """,
            ids[0]["id"],
        )
        assert row is not None
        assert row["title"] == "Example Page"
        assert row["byline"] == "John Doe"
        assert row["markdown"] == "# Example Page\n\nThis is a markdown content."


@pytest.mark.asyncio(loop_scope="module")
async def test_insert_web_content_single(
    cloned_db: asyncpg.pool.Pool,
):
    """Test inserting a single web content record."""
    web_content_id = await queries.web_content.insert_web_content_single(
        uuid.UUID("00000000-0000-0000-0000-000000000002"),
        "Single Test Page",
        "Test Author",
        datetime.datetime.strptime("2022-01-01 00:00:00", "%Y-%m-%d %H:%M:%S"),
        "# Test Page\n\nThis is test content.",
    )
    assert isinstance(web_content_id, uuid.UUID)
    # Verify the record was inserted
    async with cloned_db.acquire() as conn:
        row = await conn.fetchrow(
            """
            SELECT title, byline, markdown
            FROM web_content
            WHERE id = $1
            """,
            web_content_id,
        )
        assert row is not None
        assert row["title"] == "Single Test Page"
        assert row["byline"] == "Test Author"
        assert row["markdown"] == "# Test Page\n\nThis is test content."


@pytest.mark.asyncio(loop_scope="module")
async def test_insert_web_content_css_selector(
    cloned_db: asyncpg.pool.Pool,
):
    """Test inserting a CSS selector rule."""
    selector_id = await queries.web_content.insert_web_content_css_selector(
        r"^https://example\.com/blog/.*$", "article.blog-content"
    )
    assert isinstance(selector_id, uuid.UUID)
    # Verify the record was inserted
    async with cloned_db.acquire() as conn:
        row = await conn.fetchrow(
            """
            SELECT url_regex, css_selector
            FROM web_content_css_selectors
            WHERE id = $1
            """,
            selector_id,
        )
        assert row is not None
        assert row["url_regex"] == r"^https://example\.com/blog/.*$"
        assert row["css_selector"] == "article.blog-content"


@pytest.mark.asyncio(loop_scope="module")
async def test_select_css_selectors_for_url(
    cloned_db: asyncpg.pool.Pool,
):
    """Test selecting CSS selectors for a URL."""
    # First insert a test selector
    await queries.web_content.insert_web_content_css_selector(
        r"^https://example\.com/blog/.*$", "article.blog-content"
    )
    # Test matching URL
    selectors = await queries.web_content.select_css_selectors_for_url(
        "https://example.com/blog/test-post"
    )
    assert isinstance(selectors, list)
    assert "article.blog-content" in selectors
    # Test non-matching URL
    selectors = await queries.web_content.select_css_selectors_for_url(
        "https://different-site.com/page"
    )
    assert selectors is None