Skip to content

Instantly share code, notes, and snippets.

@pashri
Last active March 19, 2025 13:01
Show Gist options
  • Save pashri/e19a85354e973b1f1175b51032a9f3b1 to your computer and use it in GitHub Desktop.
Save pashri/e19a85354e973b1f1175b51032a9f3b1 to your computer and use it in GitHub Desktop.
S3 Mapping Agent
"""S3 Mapping Agent: Allows a dict-like mapping from a CSV in AWS S3"""
from collections.abc import Iterator, Mapping
import json
from typing import Any, Final, Literal
import boto3
# Delimiter requested between JSON records in the S3 Select output and used to
# split the returned byte stream back into records. U+241E is the Unicode
# "SYMBOL FOR RECORD SEPARATOR" glyph.
RS: Final[str] = '\u241e' # Record separator character
BRS: Final[bytes] = RS.encode() # Record separator encoded with str.encode()'s default (UTF-8)
LRS: Final[int] = len(BRS) # Number of bytes in the encoded record separator
def split_s3_uri(uri: str) -> tuple[str, str]:
    """Split an S3 URI of the form ``scheme://bucket/key`` into its parts.

    The scheme itself is not checked. Everything after the first ``/`` that
    follows the bucket name is returned as the key, so keys may contain
    further slashes. A trailing slash with nothing after it yields an empty
    key.

    Args:
        uri: URI such as ``s3://my-bucket/path/to/file.csv``.

    Returns:
        A ``(bucket, key)`` tuple.

    Raises:
        ValueError: If the URI has no ``://`` separator, no bucket name, or
            no ``/`` separating the bucket from the key.
    """
    _scheme, scheme_sep, remainder = uri.partition('://')
    bucket, key_sep, key = remainder.partition('/')
    # Fail loudly on malformed input instead of raising an opaque unpacking
    # error or silently returning the wrong path components.
    if not scheme_sep or not key_sep or not bucket:
        raise ValueError(
            f'Invalid S3 URI `{uri}`; expected the form `s3://bucket/key`.'
        )
    return bucket, key
class MappingAgentS3CSV(Mapping):  # pragma: no cover
    """Read-only, dict-like view over two columns of a CSV object in AWS S3.

    Each mapping operation is turned into an S3 Select SQL expression and run
    through ``select_object_content``. Nothing is cached: every lookup,
    iteration and ``len()`` call issues a new request.

    The CSV must have a header row (``FileHeaderInfo`` is ``USE``).
    ``key_col`` and ``value_col`` are interpolated verbatim into the SQL as
    column names, so they must come from trusted code rather than user input.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        uri: str,
        key_col: str,
        value_col: str,
        *,
        sep: str = ',',
        compression: Literal['NONE', 'GZIP', 'BZIP2'] = 'NONE',
        session: 'boto3.Session | None' = None,
    ):
        """Bind the mapping to one CSV object.

        Args:
            uri: Location of the CSV, e.g. ``s3://bucket/path/file.csv``.
            key_col: Header name of the column holding the mapping keys.
            value_col: Header name of the column holding the mapping values.
            sep: Field delimiter used in the CSV.
            compression: Compression type of the stored object.
            session: boto3 session to create the S3 client from; a default
                session is created when omitted.
        """
        self.bucket, self.key = split_s3_uri(uri)
        self.key_col = key_col
        self.value_col = value_col
        self.sep = sep
        self.compression = compression
        session = session or boto3.Session()
        self.client = session.client('s3')

    def __getitem__(self, key: str) -> str:
        """Return the value whose key column equals ``key``.

        Raises:
            KeyError: If no row matches.
            LookupError: If more than one row matches.
        """
        # Double any single quote so the key stays inside the SQL string
        # literal instead of terminating it (broken query / SQL injection).
        literal = str(key).replace("'", "''")
        query = f'''
            SELECT s.{self.key_col}, s.{self.value_col}
            FROM S3Object s
            WHERE s.{self.key_col} = '{literal}';
        '''
        results: tuple[dict[str, str], ...] = tuple(self.query(query=query))
        if not results:
            raise KeyError(f'Key `{key}` not found.')
        if len(results) > 1:
            raise LookupError(f'Multiple values for key `{key}` found.')
        return results[0][self.value_col]

    def __iter__(self) -> Iterator[str]:
        """Yield every value of the key column, in the order S3 returns rows."""
        query = f'SELECT {self.key_col} FROM S3Object s;'
        for row in self.query(query=query):
            yield row[self.key_col]

    def __len__(self) -> int:
        """Return the number of data rows in the CSV."""
        query = 'SELECT count(*) FROM S3Object s;'
        results: tuple[dict[str, str], ...] = tuple(self.query(query=query))
        # The unnamed count(*) column is read back under the key `_1`.
        return int(results[0]['_1'])

    def query(self, query: str) -> Iterator[dict[str, Any]]:
        """Query S3 CSV with headers

        See query style at
        https://docs.aws.amazon.com/AmazonS3/latest/API/API_SelectObjectContent.html

        Records are requested as JSON objects separated by ``RS`` and are
        yielded one at a time as the response stream arrives. Splitting is
        done on the raw separator bytes, so this relies on the separator not
        occurring inside the data itself.

        Args:
            query: S3 Select SQL expression to run against the object.

        Yields:
            One dict per returned record.

        Note:
        -----
        `moto` cannot currently handle `select_object_content`
        """
        response = self.client.select_object_content(
            Bucket=self.bucket,
            Key=self.key,
            Expression=query,
            ExpressionType='SQL',
            InputSerialization={
                'CSV': {
                    'FileHeaderInfo': 'USE',
                    'FieldDelimiter': self.sep,
                    'AllowQuotedRecordDelimiter': True,
                },
                'CompressionType': self.compression,
            },
            OutputSerialization={
                'JSON': {'RecordDelimiter': RS},
            },
        )
        pending: bytes = b''
        for event in response['Payload']:
            # Events without a `Records` entry contribute no bytes.
            pending += event.get('Records', {}).get('Payload', b'')
            # Everything before the last separator is complete; the tail may
            # be a partial record (or a partial separator) that continues in
            # the next event, so it is carried over.
            *complete, pending = pending.split(BRS)
            for raw in complete:
                # Skip empty chunks: a separator at offset 0 must not stall
                # parsing of the records that follow it.
                if raw:
                    yield json.loads(raw.decode('utf-8'))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment