Last active
March 19, 2025 13:01
-
-
Save pashri/e19a85354e973b1f1175b51032a9f3b1 to your computer and use it in GitHub Desktop.
S3 Mapping Agent
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""S3 Mapping Agent: Allows a dict-like mapping from a CSV in AWS S3""" | |
from collections.abc import Iterator, Mapping | |
import json | |
from typing import Any, Final, Literal | |
import boto3 | |
# Unicode "symbol for record separator" (U+241E); used as the delimiter
# between records in the query output.
RS: Final[str] = '\u241e'
# The same delimiter as UTF-8 bytes, for scanning the raw response stream.
BRS: Final[bytes] = RS.encode('utf-8')
# Number of bytes occupied by the encoded delimiter.
LRS: Final[int] = len(BRS)
def split_s3_uri(uri: str) -> tuple[str, str]:
    """Split an S3 URI into its bucket name and object key.

    The URI is expected to look like ``s3://bucket/path/to/key``: the text
    between the second and third slash is taken as the bucket, and
    everything after the third slash (including any further slashes) is
    taken as the key.

    Args:
        uri: The S3 URI to split.

    Returns:
        A ``(bucket, key)`` tuple.

    Raises:
        ValueError: If the URI has no bucket or no key component.
    """
    # At most three splits, so slashes inside the object key are preserved.
    parts = uri.split('/', 3)
    if len(parts) < 4 or not parts[2]:
        raise ValueError(
            f'Invalid S3 URI `{uri}`: expected the form `s3://bucket/key`.'
        )
    _, _, bucket, key = parts
    return bucket, key
class MappingAgentS3CSV(Mapping[str, str]):  # pragma: no cover
    """S3 Mapping Agent: Allows a dict-like mapping from a CSV in AWS S3

    Each mapping operation (item lookup, iteration, ``len``) issues its own
    ``select_object_content`` request against the CSV object; nothing is
    cached locally.

    Args:
        uri: S3 URI of the CSV object, e.g. ``s3://bucket/path/to/file.csv``.
        key_col: Name of the CSV column holding the mapping keys.
        value_col: Name of the CSV column holding the mapping values.
        sep: Field delimiter of the CSV file.
        compression: Compression type of the stored object.
        session: boto3 session used to create the S3 client; a default
            session is created when omitted.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        uri: str,
        key_col: str,
        value_col: str,
        *,
        sep: str = ',',
        compression: Literal['NONE', 'GZIP', 'BZIP2'] = 'NONE',
        session: boto3.Session | None = None,
    ) -> None:
        self.bucket, self.key = split_s3_uri(uri)
        self.key_col = key_col
        self.value_col = value_col
        self.sep = sep
        self.compression = compression
        session = session or boto3.Session()
        self.client = session.client('s3')

    def __getitem__(self, key: str) -> str:
        """Return the value whose key column equals ``key``.

        Raises:
            KeyError: If no row matches ``key``.
            LookupError: If more than one row matches ``key``.
        """
        # The key is embedded in a SQL string literal, so a single quote in
        # it must be doubled; otherwise it would terminate the literal and
        # break (or alter) the query.
        escaped = str(key).replace("'", "''")
        query = f'''
            SELECT s.{self.key_col}, s.{self.value_col}
            FROM S3Object s
            WHERE s.{self.key_col} = '{escaped}';
        '''
        results: tuple[dict[str, str], ...] = tuple(self.query(query=query))
        if not results:
            raise KeyError(f'Key `{key}` not found.')
        if len(results) > 1:
            raise LookupError(f'Multiple values for key `{key}` found.')
        return results[0][self.value_col]

    def __iter__(self) -> Iterator[str]:
        """Yield the key-column value of every row in the CSV."""
        query = f'SELECT {self.key_col} FROM S3Object s;'
        for result in self.query(query=query):
            yield result[self.key_col]

    def __len__(self) -> int:
        """Return the number of rows in the CSV (duplicates included)."""
        query = 'SELECT count(*) FROM S3Object s;'
        results: tuple[dict[str, str], ...] = tuple(self.query(query=query))
        # The unnamed aggregate column comes back under the key `_1`.
        return int(results[0]['_1'])

    def query(self, query: str) -> Iterator[dict[str, Any]]:
        """Query S3 CSV with headers

        See query style at
        https://docs.aws.amazon.com/AmazonS3/latest/API/API_SelectObjectContent.html

        Args:
            query: SQL expression to run against the CSV object.

        Yields:
            One dict per result record, decoded from the JSON output.

        Note:
        -----
        `moto` cannot currently handle `select_object_content`
        """
        response = self.client.select_object_content(
            Bucket=self.bucket,
            Key=self.key,
            Expression=query,
            ExpressionType='SQL',
            InputSerialization={
                'CSV': {
                    'FileHeaderInfo': 'USE',
                    'FieldDelimiter': self.sep,
                    'AllowQuotedRecordDelimiter': True,
                },
                'CompressionType': self.compression,
            },
            OutputSerialization={
                'JSON': {'RecordDelimiter': RS},
            },
        )

        # A record may be split across payload segments, so keep the
        # unterminated tail in `pending` until its delimiter arrives.
        pending: bytes = b''
        for segment in response['Payload']:
            # Segments that are not `Records` events contribute no bytes.
            pending += segment.get('Records', {}).get('Payload', b'')
            # Everything before the last delimiter is a complete record;
            # the remainder (possibly empty) is carried to the next segment.
            *complete, pending = pending.split(BRS)
            for raw in complete:
                # Skip empty chunks (e.g. a leading delimiter) rather than
                # stalling on them or feeding '' to the JSON parser.
                if raw:
                    yield json.loads(raw.decode('utf-8'))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment