Last active
January 23, 2024 21:38
-
-
Save draincoder/2290f0bd826a5356f66ab49c7d93f163 to your computer and use it in GitHub Desktop.
aiohttp base client for a JSON API
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ssl | |
from types import TracebackType | |
from typing import Any, Type | |
import certifi | |
from aiohttp import ClientSession, TCPConnector, ClientTimeout, ClientResponse | |
from aiohttp.typedefs import StrOrURL | |
from aiohttp import ClientResponseError | |
Headers = dict[str, str] | None | |
Body = dict[str, str] | None | |
class BadStatusError(ClientResponseError):
    """Raised when a response arrives with a status other than the expected one.

    Carries the standard ``ClientResponseError`` details (request info,
    redirect history, status, headers) plus the decoded response body.
    """

    def __init__(self, response: ClientResponse, payload: Body) -> None:
        """
        :param response: The response that carried the unexpected status
        :param payload: Decoded response body, or None when no body is available
        :return: None
        """
        super().__init__(
            request_info=response.request_info,
            history=response.history,
            status=response.status,
            # Forward the reason phrase so the error's string form is not an
            # empty message; the ``or ""`` guards against a missing reason.
            message=response.reason or "",
            headers=response.headers,
        )
        # Kept on the exception so handlers can inspect the API's error body.
        self.payload = payload
class BaseClient:
    """Small asynchronous HTTP client for JSON APIs, built on aiohttp.

    The ``ClientSession`` and its ``TCPConnector`` are created lazily on the
    first request, so constructing the client does not touch the event loop.
    If the session has been closed, the next request transparently builds a
    new session on a new connector.

    Can be used as an async context manager; leaving the ``async with`` block
    calls :meth:`close`.
    """

    def __init__(
        self,
        timeout: int | None = None,
        connection_limit: int = 0,
    ) -> None:
        """
        :param timeout: Total number of seconds for the whole request, use None to disable the timeout
        :param connection_limit: Total number of simultaneous connections, use 0 to disable the limit
        :return: None
        """
        self._ssl_context = ssl.create_default_context(cafile=certifi.where())
        self._connection_limit = connection_limit
        self._timeout = ClientTimeout(total=timeout)
        # Both are built on demand by _create_session().
        self._connector: TCPConnector | None = None
        self._session: ClientSession | None = None
        # The only status treated as success; anything else raises BadStatusError.
        self.ok_status = 200

    def _get_session(self) -> ClientSession:
        """Return the live session, creating (or re-creating) it when needed."""
        if self._session is None or self._session.closed:
            self._session = self._create_session()
        return self._session

    def _create_session(self) -> ClientSession:
        """Build a new session, with a fresh connector if the current one is unusable."""
        # A session closes the connector it owns when it is closed, and a
        # closed connector cannot serve a new session, so never reuse one.
        if self._connector is None or self._connector.closed:
            self._connector = TCPConnector(
                ssl=self._ssl_context,
                limit=self._connection_limit,
            )
        return ClientSession(connector=self._connector, timeout=self._timeout)

    async def _read_payload(self, response: ClientResponse) -> Any:
        """Decode the JSON body of ``response`` and enforce the expected status.

        :param response: Response received from the session
        :return: Decoded JSON body when the status equals ``ok_status``
        :raises BadStatusError: When the status differs from ``ok_status``
        """
        if response.status == self.ok_status:
            return await response.json()
        try:
            payload = await response.json()
        except (ClientResponseError, ValueError):
            # Error responses are often not JSON (e.g. an HTML error page);
            # report the bad status rather than a decoding failure.
            payload = None
        raise BadStatusError(response, payload)

    async def get(self, url: StrOrURL, headers: Headers = None) -> Any:
        """Send a GET request and return the decoded JSON body.

        :param url: Target URL
        :param headers: Optional extra request headers
        :return: Decoded JSON body
        :raises BadStatusError: When the response status differs from ``ok_status``
        """
        async with self._get_session().get(url, headers=headers) as response:
            return await self._read_payload(response)

    async def post(self, url: StrOrURL, body: Body = None, headers: Headers = None) -> Any:
        """Send a POST request and return the decoded JSON body.

        :param url: Target URL
        :param body: Optional request body, passed to aiohttp as ``data``
        :param headers: Optional extra request headers
        :return: Decoded JSON body
        :raises BadStatusError: When the response status differs from ``ok_status``
        """
        # NOTE(review): the body is sent via ``data=``, not ``json=`` - confirm
        # this is the encoding the target API expects.
        async with self._get_session().post(url, data=body, headers=headers) as response:
            return await self._read_payload(response)

    async def close(self) -> None:
        """Close the session and the connector if they exist and are still open."""
        if self._session is not None and not self._session.closed:
            await self._session.close()
        if self._connector is not None and not self._connector.closed:
            await self._connector.close()

    async def __aenter__(self) -> 'BaseClient':
        """Enter the async context and return this client."""
        return self

    async def __aexit__(
        self,
        exc_type: Type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Release network resources on leaving the async context."""
        await self.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment