Last active
November 8, 2024 18:51
-
-
Save panasenco/27d01bd0dc3a11325f36f00001abdb7b to your computer and use it in GitHub Desktop.
Command line tool that loads one or more GZIP-compressed CSV files into Snowflake in a single transaction. Each file must contain three columns: value (a JSON string), process_started_at_utc (an ISO 8601 timestamp), and process_ended_at_utc (an ISO 8601 timestamp). Existing rows matching an optional filter can be deleted as part of the same transaction before the new data is loaded.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import logging | |
from pathlib import Path | |
from typing import Any, Iterable, Sequence | |
import backoff | |
import snowflake.connector.errors | |
from .snowflake_connection import get_snowflake_connection | |
@backoff.on_exception(
    backoff.expo,
    snowflake.connector.errors.ProgrammingError,
    max_tries=10,
    jitter=backoff.full_jitter,
    # Only retry Snowflake's "too many lock waiters" error; any other
    # ProgrammingError (bad SQL, permissions, ...) gives up immediately.
    # raw_msg is guarded with `or ""` in case it is present but None.
    giveup=lambda err: "number of waiters for this lock exceeds" not in (err.raw_msg or ""),
    # BUG FIX: the original called `get_logger`, which is defined nowhere in
    # this file and would raise NameError the first time a backoff fired.
    on_backoff=lambda details: logging.getLogger("elt").warning(
        f"Reached max table lock waiters in Snowflake, retrying {details['tries']}/9."
    ),
    backoff_log_level=logging.DEBUG,
    factor=20,  # wait between 0s and 20s, 40s, 80s, 160s, ... (increases until 640s and varies randomly due to jitter)
    max_value=640,
)
def load_to_snowflake(
    schema: str,
    table: str,
    files: Iterable[Path],
    temporary: bool = False,
    transient: bool = False,
    delete_filter: str | None = None,
    delete_filter_params: Sequence[Any] | None = None,
) -> None:
    """
    Load GZIP-compressed CSV files into a Snowflake table, creating it if needed.

    The optional delete (via ``delete_filter``) and the COPY INTO run in a
    single transaction, so readers never observe a partially-replaced table.
    Retries the whole operation with exponential backoff when Snowflake's
    table-lock waiter limit is hit.

    Args:
        schema: Schema of the target table. NOTE(review): interpolated into
            SQL unescaped — must come from trusted configuration, never from
            untrusted user input.
        table: Name of the target table (same injection caveat as ``schema``).
        files: Paths of the CSV files to load. May be any iterable, including
            a one-shot iterator — it is materialized once below.
        temporary: Create the table as TEMPORARY if it does not exist.
        transient: Create the table as TRANSIENT if it does not exist
            (ignored when ``temporary`` is set).
        delete_filter: Optional SQL predicate; matching rows are deleted in
            the same transaction before loading. Use parameter binding for
            any data values (see the CLI help below).
        delete_filter_params: Bound parameters for ``delete_filter``.
    """
    # BUG FIX: the original iterated `files` twice (PUT loop and COPY INTO
    # file list) and called len() on it, which breaks for generators and
    # violates the Iterable annotation. Materialize exactly once.
    file_list = list(files)
    with get_snowflake_connection() as warehouse_connection:
        with warehouse_connection.cursor() as c:
            c.execute(f"use schema {schema}")
            c.execute(
                f"""
                create{" temporary" if temporary else " transient" if transient else ""} table if not exists {table}
                (
                    value variant,
                    process_started_at_utc timestamp_ntz not null,
                    process_ended_at_utc timestamp_ntz not null
                )
                """,
            )
            # Upload to the table's internal stage. Done outside the
            # transaction below (PUT is a stage operation, not a DML write).
            for file in file_list:
                c.execute(f"put file://{file} @%{table}")
            c.execute("begin transaction")
            commit = False
            try:
                if delete_filter:
                    logging.info(
                        f"Deleting from table {schema}.{table} where {delete_filter} with params [{delete_filter_params}]"
                    )
                    # Filter values arrive via server-side parameter binding,
                    # not string interpolation.
                    c.execute(f"delete from {table} where {delete_filter}", params=delete_filter_params)
                if file_list:
                    c.execute(
                        f"""
                        copy into {table}
                        from @%{table}/
                        files=('{"', '".join([file.name for file in file_list])}')
                        file_format = (type = 'csv' field_delimiter = ',' field_optionally_enclosed_by = '\"' skip_header = 1)
                        purge=true
                        """
                    )
                commit = True
            finally:
                # Commit only if every statement above succeeded; any
                # exception path rolls the whole transaction back.
                c.execute("commit" if commit else "rollback")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="load_to_snowflake",
        description="CLI command to load a compressed CSV file into a Snowflake table.",
    )
    parser.add_argument(
        "-s",
        "--schema",
        help="Schema of the table to load the data into.",
        required=True,
    )
    parser.add_argument(
        "-t",
        "--table",
        help="Name of the table to load the data into.",
        required=True,
    )
    # Files may be given inline (-f, repeatable) or as a list in a file/stdin
    # (-F), but not both.
    file_group = parser.add_mutually_exclusive_group()
    file_group.add_argument(
        "-f",
        "--file",
        help="Path(s) to the file(s) to load into the table.",
        type=Path,
        action="append",
    )
    file_group.add_argument(
        "-F",
        "--files-file",
        help="File containing a list of files to load into the table. Use '-' for stdin.",
        type=argparse.FileType('r'),
    )
    parser.add_argument(
        "-r",
        "--transient",
        help="Set this flag to create the table as transient if it doesn't already exist.",
        action="store_true",
    )
    parser.add_argument(
        "-d",
        "--delete-filter",
        help=(
            "Delete rows from the table matching this filter before loading the data. "
            "DO NOT insert user-provided data directly into this filter as that could enable SQL injection attacks! "
            "Instead provide the data as parameters and use numeric parameter binding: "
            "https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-example#qmark-or-numeric-binding"
        ),
        required=False,
    )
    parser.add_argument(
        "delete_filter_params",
        nargs="*",
        help="Parameters to pass to the delete filter.",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        help="Be verbose. Include once for INFO output, twice for DEBUG output.",
        action="count",
        default=0,
    )
    args = parser.parse_args()
    # Map the -v count onto a logging level, capping at DEBUG.
    LOGGING_LEVELS = [logging.WARNING, logging.INFO, logging.DEBUG]
    logging.basicConfig(level=LOGGING_LEVELS[min(args.verbose, len(LOGGING_LEVELS) - 1)])  # cap to last level index
    if args.files_file:
        # ROBUSTNESS FIX: skip blank lines (e.g. a trailing newline when the
        # list is piped via stdin) so they don't become Path("") entries that
        # would later break the PUT statement.
        files = [Path(stripped) for line in args.files_file if (stripped := line.strip())]
    elif args.file:
        files = args.file
    else:
        files = []
    load_to_snowflake(
        schema=args.schema,
        table=args.table,
        files=files,
        transient=args.transient,
        delete_filter=args.delete_filter,
        delete_filter_params=args.delete_filter_params,
    )
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
backoff | |
snowflake-connector-python |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
import os | |
from pathlib import Path | |
from typing import Any | |
import backoff | |
import snowflake.connector | |
# Module-level side effect: switch the whole connector to server-side
# numeric parameter binding (":1", ":2", ...) so data values are bound by
# Snowflake rather than interpolated client-side.
# See: https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-example#using-qmark-or-numeric-binding
snowflake.connector.paramstyle = "numeric"
@backoff.on_exception(
    backoff.expo,
    snowflake.connector.errors.DatabaseError,
    max_tries=8,
    # Only retry JWT-token-related failures (key-pair auth token expiry);
    # every other DatabaseError gives up immediately.
    # BUG FIX: getattr's default only covers a *missing* attribute — if
    # raw_msg exists but is None, `"JWT" not in None` raises TypeError
    # inside the giveup predicate. Guard with `or ""`.
    giveup=lambda err: "JWT" not in (getattr(err, "raw_msg", "") or ""),
    on_backoff=lambda details: logging.info(f"Encountered JWT token issue, trying {details['tries']}/8."),
    backoff_log_level=logging.DEBUG,
)
def get_snowflake_connection(**kwargs: Any) -> snowflake.connector.connection.SnowflakeConnection:
    """
    Open a Snowflake connection, retrying transient JWT token errors.

    All keyword arguments are passed straight through to
    ``snowflake.connector.connect``. With no arguments the connector falls
    back to its own configuration sources — presumably the SNOWFLAKE_*
    environment variables per the linked docs (not verified here).
    See: https://docs.snowflake.com/en/developer-guide/snowflake-cli-v2/connecting/specify-credentials#use-environment-variables-for-snowflake-credentials
    """
    return snowflake.connector.connect(**kwargs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment