Last active
February 20, 2026 10:59
-
-
Save Tishka17/453d40ef8d81b44e82225c43bf24999e to your computer and use it in GitHub Desktop.
Dishka async-sync container. `aget` and `get` methods
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from __future__ import annotations | |
| import warnings | |
| from asyncio import Lock | |
| from collections.abc import Callable, MutableMapping | |
| from contextlib import AbstractAsyncContextManager | |
| from types import TracebackType | |
| from typing import Any, TypeVar, cast, overload | |
| from dishka.entities.component import DEFAULT_COMPONENT, Component | |
| from dishka.entities.factory_type import FactoryType | |
| from dishka.entities.key import DependencyKey | |
| from dishka.entities.marker import Has, HasContext | |
| from dishka.entities.scope import BaseScope, Scope | |
| from dishka.provider import Provider, activate | |
| from .container_objects import Exit | |
| from .context_proxy import ContextProxy | |
| from .dependency_source import Factory | |
| from .entities.validation_settings import ( | |
| DEFAULT_VALIDATION, | |
| ValidationSettings, | |
| ) | |
| from .exception_base import DishkaError | |
| from .exceptions import ( | |
| ChildScopeNotFoundError, | |
| ExitError, | |
| NoActiveFactoryError, | |
| NoChildScopesError, | |
| NoFactoryError, | |
| NoNonSkippedScopesError, | |
| ) | |
| from .graph_builder.builder import GraphBuilder | |
| from .provider import BaseProvider, make_root_context_provider | |
| from .registry import Registry | |
| T = TypeVar("T") | |
| class HybridContainer: | |
| __slots__ = ( | |
| "_cache", | |
| "_context", | |
| "_exits", | |
| "child_registries", | |
| "close_parent", | |
| "lock", | |
| "parent_container", | |
| "registry", | |
| ) | |
| def __init__( | |
| self, | |
| registry: Registry, | |
| *child_registries: Registry, | |
| parent_container: HybridContainer | None = None, | |
| context: dict[Any, Any] | None = None, | |
| lock_factory: Callable[ | |
| [], AbstractAsyncContextManager[Any], | |
| ] | None = None, | |
| close_parent: bool = False, | |
| ): | |
| self.registry = registry | |
| self.child_registries = child_registries | |
| self._context = {CONTAINER_KEY: self} | |
| if context: | |
| self._context.update(context) | |
| self._cache = {CONTAINER_KEY: self} | |
| self.parent_container = parent_container | |
| self.lock: AbstractAsyncContextManager[Any] | None | |
| if lock_factory: | |
| self.lock = lock_factory() | |
| else: | |
| self.lock = None | |
| self._exits: list[Exit] = [] | |
| self.close_parent = close_parent | |
| @property | |
| def scope(self) -> BaseScope: | |
| return self.registry.scope | |
| @property | |
| def context(self) -> MutableMapping[DependencyKey, Any]: | |
| warnings.warn( | |
| "`container.context` is deprecated", | |
| DeprecationWarning, | |
| stacklevel=2, | |
| ) | |
| return ContextProxy(cache=self._cache, context=self._context) | |
| def __call__( | |
| self, | |
| context: dict[Any, Any] | None = None, | |
| lock_factory: Callable[ | |
| [], AbstractAsyncContextManager[Any], | |
| ] | None = None, | |
| scope: BaseScope | None = None, | |
| ) -> HybridContextWrapper: | |
| """ | |
| Prepare container for entering the inner scope. | |
| :param context: Data which will available in inner scope | |
| :param lock_factory: Callable to create lock instance or None | |
| :param scope: target scope or None to enter next non-skipped scope | |
| :return: async context manager for inner scope | |
| """ | |
| if not self.child_registries: | |
| raise NoChildScopesError | |
| child = HybridContainer( | |
| *self.child_registries, | |
| parent_container=self, | |
| context=context, | |
| lock_factory=lock_factory, | |
| ) | |
| if scope is None: | |
| while child.registry.scope.skip: | |
| if not child.child_registries: | |
| raise NoNonSkippedScopesError | |
| child = HybridContainer( | |
| *child.child_registries, | |
| parent_container=child, | |
| context=context, | |
| lock_factory=lock_factory, | |
| close_parent=True, | |
| ) | |
| else: | |
| while child.registry.scope is not scope: | |
| if not child.child_registries: | |
| raise ChildScopeNotFoundError(scope, self.registry.scope) | |
| child = HybridContainer( | |
| *child.child_registries, | |
| parent_container=child, | |
| context=context, | |
| lock_factory=lock_factory, | |
| close_parent=True, | |
| ) | |
| return HybridContextWrapper(child) | |
| @overload | |
| async def aget( | |
| self, | |
| dependency_type: type[T], | |
| component: Component | None = DEFAULT_COMPONENT, | |
| ) -> T: | |
| ... | |
| @overload | |
| async def aget( | |
| self, | |
| dependency_type: Any, | |
| component: Component | None = DEFAULT_COMPONENT, | |
| ) -> Any: | |
| ... | |
| async def aget( | |
| self, | |
| dependency_type: Any, | |
| component: Component | None = DEFAULT_COMPONENT, | |
| ) -> Any: | |
| lock = self.lock | |
| key = DependencyKey(dependency_type, component) | |
| try: | |
| if not lock: | |
| return await self._aget_unlocked(key) | |
| async with lock: | |
| return await self._aget_unlocked(key) | |
| except (NoFactoryError, NoActiveFactoryError) as e: | |
| e.scope = self.scope | |
| raise | |
| async def _aget(self, key: DependencyKey) -> Any: | |
| lock = self.lock | |
| if not lock: | |
| return await self._aget_unlocked(key) | |
| async with lock: | |
| return await self._aget_unlocked(key) | |
| async def _aget_unlocked(self, key: DependencyKey) -> Any: | |
| if key in self._cache: | |
| return self._cache[key] | |
| compiled = self.registry.get_compiled_async(key) | |
| if not compiled: | |
| if not self.parent_container: | |
| abstract_dependencies = ( | |
| self.registry.get_more_abstract_factories(key) | |
| ) | |
| concrete_dependencies = ( | |
| self.registry.get_more_concrete_factories(key) | |
| ) | |
| raise NoFactoryError( | |
| key, | |
| suggest_abstract_factories=abstract_dependencies, | |
| suggest_concrete_factories=concrete_dependencies, | |
| ) | |
| try: | |
| return await self.parent_container._aget(key) # noqa: SLF001 | |
| except NoFactoryError as ex: | |
| abstract_dependencies = ( | |
| self.registry.get_more_abstract_factories(key) | |
| ) | |
| concrete_dependencies = ( | |
| self.registry.get_more_concrete_factories(key) | |
| ) | |
| ex.suggest_abstract_factories.extend(abstract_dependencies) | |
| ex.suggest_concrete_factories.extend(concrete_dependencies) | |
| raise | |
| try: | |
| return await compiled( | |
| self._aget_unlocked, | |
| self._exits, | |
| self._cache, | |
| self._context, | |
| ) | |
| except NoFactoryError as e: | |
| # cast is needed because registry.get_factory will always | |
| # return Factory. This happens because registry.get_compiled | |
| # uses the same method and returns None if the factory is not found | |
| # If None is returned, then go to the parent container | |
| e.add_path(cast(Factory, self.registry.get_factory(key))) | |
| raise | |
| except NoActiveFactoryError as e: | |
| e.add_path(cast(Factory, self.registry.get_factory(key))) | |
| raise | |
| @overload | |
| def get( | |
| self, | |
| dependency_type: type[T], | |
| component: Component | None = DEFAULT_COMPONENT, | |
| ) -> T: | |
| ... | |
| @overload | |
| def get( | |
| self, | |
| dependency_type: Any, | |
| component: Component | None = DEFAULT_COMPONENT, | |
| ) -> Any: | |
| ... | |
| def get( | |
| self, | |
| dependency_type: Any, | |
| component: Component | None = DEFAULT_COMPONENT, | |
| ) -> Any: | |
| key = DependencyKey(dependency_type, component) | |
| try: | |
| return self._get_unlocked(key) | |
| except (NoFactoryError, NoActiveFactoryError) as e: | |
| e.scope = self.scope | |
| raise | |
| def _get(self, key: DependencyKey) -> Any: | |
| lock = self.lock | |
| return self._get_unlocked(key) | |
| def _get_unlocked(self, key: DependencyKey) -> Any: | |
| if key in self._cache: | |
| return self._cache[key] | |
| compiled = self.registry.get_compiled(key) | |
| if not compiled: | |
| if not self.parent_container: | |
| abstract_dependencies = ( | |
| self.registry.get_more_abstract_factories(key) | |
| ) | |
| concrete_dependencies = ( | |
| self.registry.get_more_concrete_factories(key) | |
| ) | |
| raise NoFactoryError( | |
| key, | |
| suggest_abstract_factories=abstract_dependencies, | |
| suggest_concrete_factories=concrete_dependencies, | |
| ) | |
| try: | |
| return self.parent_container._get(key) # noqa: SLF001 | |
| except NoFactoryError as ex: | |
| abstract_dependencies = ( | |
| self.registry.get_more_abstract_factories(key) | |
| ) | |
| concrete_dependencies = ( | |
| self.registry.get_more_concrete_factories(key) | |
| ) | |
| ex.suggest_abstract_factories.extend(abstract_dependencies) | |
| ex.suggest_concrete_factories.extend(concrete_dependencies) | |
| raise | |
| try: | |
| return compiled( | |
| self._get_unlocked, | |
| self._exits, | |
| self._cache, | |
| self._context, | |
| ) | |
| except NoFactoryError as e: | |
| # cast is needed because registry.get_factory will always | |
| # return Factory. This happens because registry.get_compiled | |
| # uses the same method and returns None if the factory is not found | |
| # If None is returned, then go to the parent container | |
| e.add_path(cast(Factory, self.registry.get_factory(key))) | |
| raise | |
| except NoActiveFactoryError as e: | |
| e.add_path(cast(Factory, self.registry.get_factory(key))) | |
| raise | |
| def close(self, exception: BaseException | None = None) -> None: | |
| errors = [] | |
| for exit_generator in self._exits[::-1]: | |
| try: | |
| if exit_generator.type is FactoryType.ASYNC_GENERATOR: | |
| raise DishkaError("Cannot finalize async generator in sync close method") | |
| elif exit_generator.type is FactoryType.GENERATOR: | |
| exit_generator.callable.send(exception) # type: ignore[attr-defined] | |
| except StopIteration: # noqa: PERF203 | |
| pass | |
| except Exception as err: # noqa: BLE001 | |
| errors.append(err) | |
| self._cache = {CONTAINER_KEY: self} | |
| if self.close_parent and self.parent_container: | |
| try: | |
| self.parent_container.close(exception) | |
| except Exception as err: # noqa: BLE001 | |
| errors.append(err) | |
| if errors: | |
| raise ExitError("Cleanup context errors", errors) # noqa: TRY003 | |
| async def aclose(self, exception: BaseException | None = None) -> None: | |
| errors = [] | |
| for exit_generator in self._exits[::-1]: | |
| try: | |
| if exit_generator.type is FactoryType.ASYNC_GENERATOR: | |
| await exit_generator.callable.asend(exception) # type: ignore[attr-defined] | |
| elif exit_generator.type is FactoryType.GENERATOR: | |
| exit_generator.callable.send(exception) # type: ignore[attr-defined] | |
| except StopIteration: # noqa: PERF203 | |
| pass | |
| except StopAsyncIteration: | |
| pass | |
| except Exception as err: # noqa: BLE001 | |
| errors.append(err) | |
| self._cache = {CONTAINER_KEY: self} | |
| if self.close_parent and self.parent_container: | |
| try: | |
| await self.parent_container.aclose(exception) | |
| except Exception as err: # noqa: BLE001 | |
| errors.append(err) | |
| if errors: | |
| raise ExitError("Cleanup context errors", errors) # noqa: TRY003 | |
| async def _has(self, marker: Any) -> bool: | |
| compiled = self.registry.get_compiled_activation_async(marker) | |
| if not compiled: | |
| if not self.parent_container: | |
| return False | |
| return await self.parent_container._has(marker) # noqa: SLF001 | |
| return bool(await compiled( | |
| self._get_unlocked, | |
| self._exits, | |
| self._cache, | |
| self._context, | |
| )) | |
| def _has_context(self, marker: Any) -> bool: | |
| return marker in self._context | |
| class HybridContextWrapper: | |
| def __init__(self, container: HybridContainer): | |
| self.container = container | |
| async def __aenter__(self) -> HybridContainer: | |
| return self.container | |
| def __enter__(self) -> HybridContainer: | |
| return self.container | |
| def __exit__( | |
| self, | |
| exc_type: type[BaseException] | None = None, | |
| exc_val: BaseException | None = None, | |
| exc_tb: TracebackType | None = None, | |
| ) -> None: | |
| self.container.close(exception=exc_val) | |
| async def __aexit__( | |
| self, | |
| exc_type: type[BaseException] | None = None, | |
| exc_val: BaseException | None = None, | |
| exc_tb: TracebackType | None = None, | |
| ) -> None: | |
| await self.container.aclose(exception=exc_val) | |
| class HasProvider(Provider): | |
| @activate(Has) | |
| async def has( | |
| self, | |
| marker: DependencyKey, | |
| container: HybridContainer, | |
| ) -> bool: | |
| key = DependencyKey(marker.type_hint.value, marker.component) | |
| return await container._has(key) # noqa: SLF001 | |
| @activate(HasContext) | |
| def has_context( | |
| self, | |
| marker: HasContext, | |
| container: HybridContainer, | |
| ) -> bool: | |
| return container._has_context(marker.value) # noqa: SLF001 | |
| def make_hybrid_container( | |
| *providers: BaseProvider, | |
| scopes: type[BaseScope] = Scope, | |
| context: dict[Any, Any] | None = None, | |
| lock_factory: Callable[ | |
| [], AbstractAsyncContextManager[Any], | |
| ] | None = Lock, | |
| skip_validation: bool = False, | |
| start_scope: BaseScope | None = None, | |
| validation_settings: ValidationSettings = DEFAULT_VALIDATION, | |
| ) -> HybridContainer: | |
| context_provider = make_root_context_provider(providers, context, scopes) | |
| has_provider = HasProvider() | |
| builder = GraphBuilder( | |
| scopes=scopes, | |
| container_key=CONTAINER_KEY, | |
| skip_validation=skip_validation, | |
| validation_settings=validation_settings, | |
| ) | |
| builder.add_multicomponent_providers(has_provider) | |
| builder.add_providers(*providers) | |
| builder.add_providers(context_provider) | |
| registries = builder.build() | |
| container = HybridContainer( | |
| *registries, | |
| context=context, | |
| lock_factory=lock_factory, | |
| ) | |
| if start_scope is None: | |
| while container.registry.scope.skip: | |
| container = HybridContainer( | |
| *container.child_registries, | |
| parent_container=container, | |
| context=context, | |
| lock_factory=lock_factory, | |
| close_parent=True, | |
| ) | |
| else: | |
| while container.registry.scope is not start_scope: | |
| container = HybridContainer( | |
| *container.child_registries, | |
| parent_container=container, | |
| context=context, | |
| lock_factory=lock_factory, | |
| close_parent=True, | |
| ) | |
| return container | |
| CONTAINER_KEY = DependencyKey(HybridContainer, DEFAULT_COMPONENT) |
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Not working: