Last active
September 16, 2024 09:09
-
-
Save nrbnlulu/29737080b82b1abb4a1bec0df4cc10be to your computer and use it in GitHub Desktop.
Strawberry-GraphQL Node type with support for foreign-node fields and custom ID types.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
import base64 | |
from dataclasses import dataclass | |
from functools import cached_property | |
from typing import TYPE_CHECKING, Annotated, Any, Self | |
import strawberry | |
from strawberry.annotation import StrawberryAnnotation | |
from strawberry.types.field import StrawberryField | |
from strawberry.types.private import StrawberryPrivate | |
from gql_.context import Info | |
@dataclass(slots=True) | |
class LazyRef[T]: | |
store: dict[str, T] | |
id_: str | |
def resolve(self) -> T: | |
return self.store[self.id_] | |
# Strawberry interfaces cannot be generic
# (see https://github.com/strawberry-graphql/strawberry/issues/3602), and it is
# unclear whether a generic private field is possible at all. This is the reason
# for the various workarounds in this file.
@strawberry.interface
class NodeV2:
    """Node interface whose GraphQL ``id`` is built from the private ``id_p``.

    Every subclass is registered by class name in ``node_registry``. Each
    annotated attribute of a subclass whose name ends in ``_p`` (other than
    ``id_p`` itself) is treated as a foreign-node id: ``_resolve_field_type``
    parses its annotation and a ``LazyIdField`` named without the ``_p``
    suffix is attached to the class.
    """

    # Raw (unencoded) id of this node; concrete subclasses decide its type.
    id_p: strawberry.Private[Any]

    def __init_subclass__(cls, **kwargs: Any) -> None:
        # Cooperate with other bases that customize subclass creation.
        super().__init_subclass__(**kwargs)
        node_registry[cls.__name__] = cls
        # Snapshot the annotations; fields are attached with setattr below.
        for attr_name, annotation in list(cls.__annotations__.items()):
            # ``id_p`` is the node's own id, not a foreign reference: turning it
            # into a LazyIdField named "id" would shadow the ``id`` resolver.
            if attr_name == "id_p" or not attr_name.endswith("_p"):
                continue
            field_type, node_ref = _resolve_field_type(annotation)
            lazy_field = LazyIdField(
                attr_name,
                name=attr_name.removesuffix("_p"),
                annotation=field_type,
                node_ref=node_ref,
            )
            setattr(cls, lazy_field.name, lazy_field)

    @classmethod
    def _create_id(cls, id_) -> GlobalID:
        """Wrap a raw ``id_p`` value in a GlobalID; concrete subclasses must override."""
        raise NotImplementedError

    # Encoded global id derived from ``id_p`` via the subclass's ``_create_id``.
    @strawberry.field
    def id(self) -> strawberry.ID:
        return strawberry.ID(self._create_id(self.id_p).to_base64())
def _resolve_field_type(
    annotation: str | Any,
) -> tuple[StrawberryAnnotation, LazyRef[type[NodeV2]]]:
    """Translate a ``LazyID``/``LazyIDOpt`` annotation into GraphQL field metadata.

    ``annotation`` must be the *string* form stored in ``__annotations__``
    (produced by ``from __future__ import annotations`` or by quoting the
    annotation), e.g. ``"LazyID[int, Apple]"`` or ``"LazyIDOpt[int, Apple]"``.
    A module-qualified head (``"node.LazyID[...]"``), a quoted node type and a
    ``"Apple | None"`` node argument are accepted as well.

    Returns:
        ``(field_annotation, node_ref)``. ``field_annotation`` is
        ``strawberry.ID`` for ``LazyID`` and ``strawberry.ID | None`` for
        ``LazyIDOpt``. ``node_ref`` is a ``LazyRef`` into ``node_registry``; it
        is resolved later, so the referenced node type may be declared after
        the referencing one.

    Raises:
        TypeError: if the string is not ``LazyID[<id>, <node>]`` or
            ``LazyIDOpt[<id>, <node>]`` naming exactly one node type.
        NotImplementedError: if ``annotation`` has already been evaluated.
    """
    if not isinstance(annotation, str):
        raise NotImplementedError(
            f"Cannot resolve evaluated annotation {annotation!r}; use "
            "`from __future__ import annotations` or quote the LazyID annotation."
        )

    head, _, params = annotation.partition("[")
    # Compare literal names rather than ``LazyIDOpt.__name__``: when LazyIDOpt
    # is a runtime alias of LazyID its __name__ is "LazyID", which would make
    # every field optional.
    kind = head.strip().rpartition(".")[2]
    close = params.rfind("]")
    if kind not in ("LazyID", "LazyIDOpt") or close == -1:
        raise TypeError(f"Annotation {annotation} is not a LazyID")
    is_optional = kind == "LazyIDOpt"

    # The node type is the last type argument. Splitting on the last comma keeps
    # id types that contain commas (e.g. ``tuple[int, str]``) intact.
    _, comma, node_arg = params[:close].rpartition(",")
    node_names = [
        part.strip().strip("'\"")
        for part in node_arg.split("|")
        if part.strip() != "None"
    ]
    if not comma or len(node_names) != 1 or not node_names[0]:
        raise TypeError(
            f"Annotation {annotation} must be LazyID[<id type>, <node type>]"
        )

    field_type = strawberry.ID | None if is_optional else strawberry.ID
    return StrawberryAnnotation(field_type), LazyRef(node_registry, node_names[0])
# Node classes keyed by class name. NodeV2.__init_subclass__ fills it; GlobalID
# uses it to validate ``type_name`` and to map an id back to its class.
node_registry: dict[str, type[NodeV2]] = dict()
@strawberry.interface
class IntNodeV2(NodeV2):
    """Node interface for types whose raw ``id_p`` is an ``int``."""

    if TYPE_CHECKING:
        # Narrow ``id_p`` for type checkers only; a runtime annotation would make
        # strawberry treat it as a regular (non-Private) field of the subclass.
        id_p: int

    @classmethod
    def _create_id(cls, id_: int) -> IntGlobalID:
        """Build this node's IntGlobalID, namespaced by the concrete class name."""
        type_name = cls.__name__
        return IntGlobalID(type_name=type_name, id_=id_)
class GlobalID[T]: | |
__slots__ = ("type_name", "id_") | |
type_name: str | |
id_: T | |
def __init__(self, type_name: str, id_: T) -> None: | |
assert type_name in node_registry, f"{type_name} is not in the node registry" | |
self.type_name = type_name | |
self.id_ = id_ | |
def create_id(self) -> str: | |
return f"{self.type_name}:{self.id_}" | |
@classmethod | |
def id_from_str(cls, id_str: str) -> T: | |
raise NotImplementedError | |
def to_base64(self) -> str: | |
return base64.b64encode(f"{self.type_name}:{self.id_}".encode()).decode() | |
@classmethod | |
def from_base64(cls, base64_id: str) -> Self: | |
type_name, id_ = base64.b64decode(base64_id).decode().split(":") | |
return cls( | |
type_name=type_name, | |
id_=cls.id_from_str(id_), | |
) | |
def get_type(self) -> type[NodeV2]: | |
return node_registry[self.type_name] | |
class IntGlobalID(GlobalID[int]):
    """GlobalID whose raw id is an ``int``."""

    @classmethod
    def id_from_str(cls, id_str: str) -> int:
        # The raw id was serialized with an f-string, so parse it back as an int.
        parsed = int(id_str)
        return parsed
if TYPE_CHECKING: | |
type LazyID[T, R: NodeV2] = Annotated[T, StrawberryPrivate(), R] | |
type LazyIDOpt[T, R: NodeV2] = Annotated[T | None, StrawberryPrivate(), R | None] | |
else: | |
class LazyID[T, R](StrawberryPrivate): | |
def __class_getitem__(cls, item): | |
return Annotated[Any, StrawberryPrivate(), item] | |
LazyIDOpt = LazyID | |
class LazyIdField(StrawberryField):
    """Field exposing the encoded global id of a node referenced by a private attribute.

    Instances are created by ``NodeV2.__init_subclass__`` for each ``*_p``
    attribute. On read, the raw foreign id stored in ``private_field_name`` is
    wrapped by the referenced node type's ``_create_id`` and base64-encoded.
    """

    def __init__(
        self,
        private_field_name: str,
        name: str,
        annotation: StrawberryAnnotation,
        node_ref: LazyRef[type[NodeV2]],
    ) -> None:
        super().__init__()
        self.private_field_name = private_field_name  # attribute holding the raw id
        self.name = name
        self.type_annotation = annotation
        self.node_ref = node_ref

        def foo() -> strawberry.ID: ...

        # set the resolver in order not to create a dataclass field
        self(foo)

    @cached_property
    def node_type(self) -> type[NodeV2]:
        """The referenced node class, looked up in the registry on first access."""
        return self.node_ref.resolve()

    def get_result(
        self, source: strawberry.auto, info: Info | None, args, kwargs
    ) -> strawberry.ID | None:
        """Return the referenced node's encoded global id, or ``None`` if the raw id is ``None``."""
        raw_id = getattr(source, self.private_field_name)
        # Compare with None explicitly: falsy raw ids such as 0 or "" are valid.
        if raw_id is None:
            return None
        return strawberry.ID(self.node_type._create_id(raw_id).to_base64())
# Public API of this module.
__all__ = [
    "NodeV2",
    "IntNodeV2",
    "LazyID",
    "LazyIDOpt",
    "IntGlobalID",
]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import strawberry | |
from gql_.node import IntNodeV2, LazyID | |
@strawberry.type
class Apple(IntNodeV2):
    """Example node type with an ``int`` id (``id_p``) and a color."""

    color: str  # plain GraphQL field
@strawberry.type
class Worm(IntNodeV2):
    """Example node that references an ``Apple`` through a lazy foreign id."""

    length: int
    # The ``_p`` suffix is what makes NodeV2 generate a public ``apple_id`` field
    # holding the encoded Apple global id. The annotation is quoted because this
    # module has no ``from __future__ import annotations`` and only string
    # annotations are parsed by ``_resolve_field_type``.
    apple_id_p: "LazyID[int, Apple]"


@strawberry.type
class Query:
    """Example root query type."""

    @strawberry.field
    def worm(self) -> Worm:
        """Return a fixed example worm that references the apple with id 2."""
        return Worm(id_p=1, length=1, apple_id_p=2)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment