Created
May 2, 2023 02:43
-
-
Save Zomatree/f22af7af93caab9176b673eae8acc189 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, Self, TypeVar, Annotated, TypeVarTuple, get_args, overload, get_origin, reveal_type, cast | |
import asyncpg | |
# Value type held by a Column (returned when the column is read from a row).
T = TypeVar("T")
# Row type produced by a query builder; bound to Table subclasses.
T_T = TypeVar("T_T", bound="Table", covariant=True)
# The "other" table type introduced by a join.
T_OT = TypeVar("T_OT", bound="Table", covariant=True)
# Variadic list of joined table types tracked by JoinSelectQueryBuilder.
T_Ts = TypeVarTuple("T_Ts")
if TYPE_CHECKING:
    # Only referenced from annotations, which `from __future__ import annotations`
    # keeps as strings, so this alias is not needed at runtime by default.
    Connection = asyncpg.Connection[asyncpg.Record]
def eval_annotation(annot: Any, locals: dict[str, Any] | None = None, globals: dict[str, Any] | None = None) -> Any:
    """Resolve a possibly-stringified annotation to a runtime object.

    With ``from __future__ import annotations`` every annotation is stored as a
    string, so it must be evaluated before ``typing.get_origin``/``get_args``
    can inspect it. Non-string annotations are returned unchanged.

    Args:
        annot: The annotation, either already an object or its source string.
        locals: Optional mapping consulted first for name lookups.
        globals: Optional namespace dict consulted after ``locals``. When both
            are ``None``, this module's globals (plus this function's own
            locals) are used, as with a bare ``eval``.

    Returns:
        The evaluated annotation object.

    Note:
        This uses ``eval``; it is meant for annotations written in source code,
        never for strings coming from external input.
    """
    if not isinstance(annot, str):
        return annot
    # eval's positional order is (source, globals, locals). The arguments used
    # to be passed swapped, making a caller-supplied ``locals`` act as globals.
    return eval(annot, globals, locals)
class _Missing:
    """Sentinel type meaning "no value supplied" (e.g. a column without a default)."""

    def __repr__(self) -> str:
        return "<Missing>"

    def __eq__(self, _: Any) -> Literal[False]:
        # Deliberately never equal to anything, not even to itself.
        return False


# Single shared sentinel instance; compare with ``is``.
Missing = _Missing()
class TableMetadata:
    """Schema information for one table: SQL name, its columns, and row values."""

    def __init__(self, name: str, columns: list[Column[Any]]) -> None:
        self.name, self.columns = name, columns
        # Starts empty; Table.__init__ assigns the row's column values here.
        self.values: dict[str, Any] = dict()
class Column(Generic[T]):
    """Descriptor for one table column that doubles as a query-expression node.

    Read from the table class (``Customer.id``) it returns itself, so it can be
    compared to build ``WhereQuery`` conditions (``Customer.id == 3``). Read
    from a table instance it returns that row's stored value for the column.

    Args:
        table: The ``Table`` subclass the column belongs to.
        name: Column name used in generated SQL and as the key in row values.
        datatype: Python type of the column's values.
        default: Default recorded by ``ColumnBuilder``; not read by the query
            builders in this file.
    """

    def __init__(self, table: type[Table], name: str, datatype: type, default: Any):
        self.table = table
        self.name = name
        self.datatype = datatype
        self.default = default

    @overload
    def __get__(self, instance: None, _: type[Table]) -> Self:
        ...

    @overload
    def __get__(self, instance: Table, _: type[Table]) -> T:
        ...

    def __get__(self, instance: Table | None, _: type[Table]) -> T | Self:
        # Class access yields the Column itself (for building queries); instance
        # access yields the row's value (KeyError if it was never provided).
        if instance is None:
            return self
        return instance._metadata.values[self.name]

    # Comparison operators build WhereQuery conditions instead of booleans.
    # Because ``==`` returns a (truthy) WhereQuery, ``column in some_list`` is
    # not a meaningful membership test.
    def __eq__(self, value: T | Self) -> WhereQuery:  # type: ignore
        return WhereQuery(self, value, "=")

    def __ne__(self, value: T | Self) -> WhereQuery:  # type: ignore
        return WhereQuery(self, value, "!=")

    def __lt__(self, value: T | Self) -> WhereQuery:
        return WhereQuery(self, value, "<")

    def __le__(self, value: T | Self) -> WhereQuery:
        return WhereQuery(self, value, "<=")

    # __gt__/__ge__ were missing, so ``col > x`` and reflected forms like
    # ``x < col`` raised TypeError instead of producing a condition.
    def __gt__(self, value: T | Self) -> WhereQuery:
        return WhereQuery(self, value, ">")

    def __ge__(self, value: T | Self) -> WhereQuery:
        return WhereQuery(self, value, ">=")

    # Defining __eq__ implicitly sets __hash__ to None; restore identity
    # hashing so columns remain usable in sets and as dict keys.
    __hash__ = object.__hash__
class ColumnBuilder:
    """Fluent builder collecting the settings for one ``Column``.

    Usually placed in ``Annotated[...]`` metadata on a ``Table`` subclass, where
    ``Table.__init_subclass__`` fills in name, type and table before calling
    ``build``. Every setter returns the builder itself so calls can be chained.
    """

    def __init__(self) -> None:
        self._name: str | None = None
        self._type: type | None = None
        # ``Missing`` marks "no default given" (distinct from a default of None).
        self._default: Any = Missing
        self._primary: bool = False
        self._foreign: Column[Any] | None = None
        self._table: type[Table] | None = None

    def _with(self, attr: str, value: Any) -> Self:
        """Store *value* on private attribute *attr* and return the builder."""
        setattr(self, attr, value)
        return self

    def name(self, name: str) -> Self:
        """Set the column name."""
        return self._with("_name", name)

    def type(self, type: type) -> Self:
        """Set the Python type of the column's values."""
        return self._with("_type", type)

    def default(self, default: Any) -> Self:
        """Set the column's default value."""
        return self._with("_default", default)

    def primary(self) -> Self:
        """Flag the column as a primary key."""
        return self._with("_primary", True)

    def foreign(self, column: Column[Any]) -> Self:
        """Record the column this one references."""
        return self._with("_foreign", column)

    def table(self, table: type[Table]) -> Self:
        """Set the ``Table`` subclass owning the column."""
        return self._with("_table", table)

    def build(self) -> Column[Any]:
        """Create the ``Column`` from the collected settings.

        Raises:
            Exception: if name, type or table is unset (or falsy).

        NOTE(review): ``_primary`` and ``_foreign`` are recorded but not passed
        on to ``Column``, so they currently have no effect.
        """
        name, datatype, table = self._name, self._type, self._table
        if not name:
            raise Exception("No name")
        if not datatype:
            raise Exception("No type")
        if not table:
            raise Exception("No table")
        return Column(table, name, datatype, self._default)
class QueryBuilder(Generic[T_T]):
    """Base class for SQL builders bound to one ``Table`` subclass.

    Subclasses implement ``build`` to return ``(sql, parameters)``; the
    coroutines here run that SQL on an asyncpg connection.
    """

    def __init__(self, table: type[T_T]) -> None:
        self.table = table

    def build(self) -> tuple[str, list[Any]]:
        """Return the SQL text and its positional ``$n`` parameters."""
        raise NotImplementedError

    async def execute(self, conn: Connection) -> int:
        """Run the statement and return the affected-row count.

        asyncpg returns the PostgreSQL command tag, e.g. ``"UPDATE 5"`` or
        ``"INSERT 0 1"``. The count is always the *last* token; for ``INSERT``
        the middle token is an OID (always 0), which the old ``split()[1]``
        returned. Tags without a count (e.g. ``"CREATE TABLE"``) yield 0.
        """
        query, parameters = self.build()
        status = await conn.execute(query, *parameters)
        count = status.rpartition(" ")[2]
        return int(count) if count.isdigit() else 0

    async def fetch(self, conn: Connection) -> list[T_T]:
        """Run the query and build one table instance per returned record."""
        query, parameters = self.build()
        records = await conn.fetch(query, *parameters)
        return [self.table(**record) for record in records]

    async def fetchone(self, conn: Connection) -> T_T | None:
        """Run the query and build a table instance from the first record, if any."""
        query, parameters = self.build()
        record = await conn.fetchrow(query, *parameters)
        # fetchrow returns None for "no row"; test identity, not truthiness,
        # so a record that happens to be falsy is not mistaken for no row.
        if record is None:
            return None
        return self.table(**record)
class SelectQueryBuilder(QueryBuilder[T_T]):
    """Builder for ``select <all columns> from <table> [where ...]``."""

    def __init__(self, table: type[T_T]) -> None:
        super().__init__(table)
        self._wheres: list[WhereQuery] = []

    def where(self, query: WhereQuery) -> Self:
        """Add a condition (AND-ed with earlier ones) and return ``self``."""
        self._wheres.append(query)
        return self

    def build(self) -> tuple[str, list[Any]]:
        """Render the select statement and its bound parameters.

        Placeholders are numbered from ``$1``, as PostgreSQL requires (the old
        code emitted ``$0``). A condition whose right-hand side is another
        ``Column`` is rendered as a qualified column reference instead of being
        sent as a parameter.
        """
        columns = ", ".join(column.name for column in self.table._metadata.columns)
        query_parts = [f"select {columns} from {self.table._metadata.name}"]
        values: list[Any] = []
        conditions: list[str] = []
        for where in self._wheres:
            if isinstance(where.value, Column):
                rhs = f"{where.value.table._metadata.name}.{where.value.name}"
            else:
                values.append(where.value)
                rhs = f"${len(values)}"
            conditions.append(f"{where.column.name} {where.op} {rhs}")
        if conditions:
            query_parts.append(f"where {' and '.join(conditions)}")
        return " ".join(query_parts), values

    def join(self, query: SelectQueryBuilder[T_OT]) -> JoinSelectQueryBuilder[T_T, T_OT]:
        """Inner-join *query*'s table; its conditions become the ``on`` clause."""
        return JoinSelectQueryBuilder(self, query)
class JoinSelectQueryBuilder(SelectQueryBuilder[T_T], Generic[T_T, *T_Ts]): | |
def __init__(self, select_query: SelectQueryBuilder[T_T], join: SelectQueryBuilder[Any]): | |
self._wheres = select_query._wheres | |
self.table = select_query.table | |
self.joins: list[SelectQueryBuilder[Table]] = [join] | |
def join(self, query: SelectQueryBuilder[T_OT]) -> JoinSelectQueryBuilder[T_T, *T_Ts, T_OT]: | |
self.joins.append(query) | |
return cast(JoinSelectQueryBuilder[T_T, *T_Ts, T_OT], self) | |
def build(self) -> tuple[str, list[str]]: | |
columns: list[str] = [] | |
values: list[Any] = [] | |
for table in [self.table] + [join.table for join in self.joins]: | |
for column in table._metadata.columns: | |
columns.append(f"{table._metadata.name}.{column.name} as {table._metadata.name}_{column.name}") | |
joins: list[str] = [] | |
for join in self.joins: | |
wheres: list[str] = [] | |
for where in join._wheres: | |
if isinstance(where.value, Column): | |
value = f"{where.value.table._metadata.name}.{where.value.name}" | |
else: | |
value = f"${len(values) + 1}" | |
values.append(where.value) | |
wheres.append(f"{where.column.table._metadata.name}.{where.column.name} {where.op} {value}") | |
joins.append(f"inner join {join.table._metadata.name} on {' and '.join(wheres)}") | |
wheres = [] | |
for where in self._wheres: | |
if isinstance(where.value, Column): | |
value = f"{where.value.table._metadata.name}.{where.value.name}" | |
else: | |
value = f"${len(values) + 1}" | |
values.append(where.value) | |
wheres.append(f"{where.column.table._metadata.name}.{where.column.name} {where.op} {value}") | |
where_clause = f"where {' and '.join(wheres)}" if wheres else "" | |
query = f"select {','.join(columns)} from {self.table._metadata.name} {' '.join(joins)} {where_clause}" | |
return query, values | |
async def fetchone(self, conn: Connection) -> tuple[T_T, *T_Ts] | None: | |
query, parameters = self.build() | |
row = await conn.fetchrow(query, *parameters) | |
if row: | |
collections: dict[str, dict[str, Any]] = {} | |
for column, value in row.items(): | |
table_name, *rest = column.split("_") | |
collections.setdefault(table_name, {})["_".join(rest)] = value | |
return cast(tuple[T_T, *T_Ts], [join.table(**collections[join.table._metadata.name]) for join in self.joins]) | |
async def fetch(self, conn: Connection) -> list[tuple[T_T, *T_Ts]]: | |
query, parameters = self.build() | |
rows = await conn.fetch(query, *parameters) | |
output: list[tuple[Table, ...]] = [] | |
for row in rows: | |
collections: dict[str, dict[str, Any]] = {} | |
for column, value in row.items(): | |
table_name, *rest = column.split("_") | |
collections.setdefault(table_name, {})["_".join(rest)] = value | |
output.append(tuple(join.table(**collections[join.table._metadata.name]) for join in self.joins)) | |
return cast(list[tuple[T_T, *T_Ts]], output) | |
class InsertQueryBuilder(QueryBuilder[T_T]):
    """Builder for ``insert into <table> (<all columns>) values (...)``.

    Column values come from the optional *values* mapping. Any column missing
    from it falls back to an attribute of the same name set on the builder
    (the original lookup, kept for compatibility).
    """

    def __init__(self, table: type[T_T], values: dict[str, Any] | None = None) -> None:
        super().__init__(table)
        self._values: dict[str, Any] = dict(values) if values is not None else {}

    def build(self) -> tuple[str, list[Any]]:
        """Render the insert statement and its parameters, one per column.

        Raises:
            AttributeError: if a column has neither a mapped value nor an
                attribute on the builder.
        """
        columns = self.table._metadata.columns
        names = ", ".join(column.name for column in columns)
        # PostgreSQL placeholders are 1-based; the old code started at $0.
        placeholders = ", ".join(f"${i}" for i in range(1, len(columns) + 1))
        parameters = [
            self._values[column.name] if column.name in self._values else getattr(self, column.name)
            for column in columns
        ]
        return f"insert into {self.table._metadata.name} ({names}) values ({placeholders})", parameters
class WhereQuery:
    """One comparison ``<column> <op> <value>`` used in where/join conditions.

    ``value`` is either a plain Python value (bound as a query parameter) or
    another ``Column``, which the join builder renders as a column reference.
    """

    def __init__(self, column: Column[Any], value: Any, op: str):
        self.column, self.value, self.op = column, value, op
class Table:
    """Base class for ORM tables.

    Subclasses declare columns as class annotations, either ``name: Column[X]``
    or ``name: Annotated[Column[X], ColumnBuilder()...]``. The SQL table name
    can be given with ``table_name=`` in the class statement and defaults to
    the class name.
    """

    _metadata: ClassVar[TableMetadata]

    def __init_subclass__(cls, *, table_name: str | None = None) -> None:
        import copy
        import sys

        # Evaluate string annotations in the namespace of the module defining
        # the subclass, so tables can be declared outside this module.
        module_globals = getattr(sys.modules.get(cls.__module__), "__dict__", None)

        columns: list[Column[Any]] = []
        for key, ann in cls.__annotations__.items():
            ann = eval_annotation(ann, globals=module_globals)
            origin = get_origin(ann)
            if origin is ClassVar:
                continue  # class-level attributes are not columns

            builder: ColumnBuilder | None = None
            if origin is Annotated:
                column_ann, *extras = get_args(ann)
                # Use the first ColumnBuilder in the metadata; ignore others.
                builder = next((item for item in extras if isinstance(item, ColumnBuilder)), None)
                ann = eval_annotation(column_ann, globals=module_globals)

            # Column[X] -> X for every column. Annotated columns used to get
            # Column[X] itself as their datatype, unlike plain ones.
            datatype, = get_args(ann)

            # Copy so a builder shared between annotations is never mutated;
            # a name or type the user already set on it takes precedence.
            builder = copy.copy(builder) if builder is not None else ColumnBuilder()
            if builder._name is None:
                builder.name(key)
            if builder._type is None:
                builder.type(datatype)
            column = builder.table(cls).build()
            columns.append(column)
            setattr(cls, key, column)
        cls._metadata = TableMetadata(table_name or cls.__name__, columns)

    def __init__(self, **kwargs: Any):
        """Create a row; keyword arguments are column values keyed by column name."""
        shared = type(self)._metadata
        # Give each row its own metadata object. Assigning to the class-level
        # one (as before) made every instance expose the last row's values.
        self._metadata = TableMetadata(shared.name, shared.columns)  # type: ignore[misc]
        self._metadata.values = kwargs

    @classmethod
    def select(cls) -> SelectQueryBuilder[Self]:
        """Start a ``select`` of every column of this table."""
        return SelectQueryBuilder(cls)

    @classmethod
    def where(cls, where: WhereQuery) -> SelectQueryBuilder[Self]:
        """Shorthand for ``cls.select().where(where)``."""
        return cls.select().where(where)
# Shorthand column annotations for table declarations: ``name: Text`` declares a
# column whose per-row values are ``str``; ``Int`` likewise for ``int``.
Text, Int = Column[str], Column[int]
class Customer(Table, table_name="customers"):
    """Example table mapped to the ``customers`` SQL table."""

    # Annotation order defines the column order used in generated selects.
    # NOTE(review): ColumnBuilder.build() does not pass the primary flag on to
    # Column, so .primary() currently has no effect.
    id: Annotated[Int, ColumnBuilder().primary()]
    name: Text
class Item(Table, table_name="items"):
    """Example table mapped to the ``items`` SQL table."""

    # NOTE(review): .primary() is recorded on the builder but not passed on by
    # ColumnBuilder.build(), so it currently has no effect.
    id: Annotated[Int, ColumnBuilder().primary()]
    name: Text
class Order(Table, table_name="orders"):
    """Example table mapped to ``orders``; ``customer`` and ``item`` are declared
    as references to ``Customer.id`` and ``Item.id``."""

    id: Annotated[Int, ColumnBuilder().primary()]
    # NOTE(review): .foreign(...) is recorded on the builder, but
    # ColumnBuilder.build() does not pass it to Column, so no constraint or
    # join is derived from it; joins must be written explicitly in queries.
    customer: Annotated[Int, ColumnBuilder().foreign(Customer.id)]
    item: Annotated[Int, ColumnBuilder().foreign(Item.id)]
async def main(db: asyncpg.Connection[asyncpg.Record]):
    """Demo: orders joined with their customer and item, filtered to chairs.

    Shows the result types of ``fetchone`` and ``fetch`` via ``reveal_type``.
    """
    query = (
        Order.select()
        .join(Customer.where(Order.customer == Customer.id))
        # Use the Item class itself; ``Item()`` built a throwaway empty row
        # just to reach the ``where`` classmethod.
        .join(Item.where(Order.item == Item.id).where(Item.name == "Chair"))
    )
    reveal_type(await query.fetchone(db))
    reveal_type(await query.fetch(db))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment