Skip to content

Instantly share code, notes, and snippets.

@panasenco
Last active November 8, 2024 18:51
Show Gist options
  • Save panasenco/27d01bd0dc3a11325f36f00001abdb7b to your computer and use it in GitHub Desktop.
Command line tool that takes one or more GZIP-compressed CSV files containing 3 columns: value (JSON string), process_started_at_utc (ISO 8601 timestamp), and process_ended_at_utc (ISO 8601 timestamp), and loads them all into Snowflake while optionally replacing existing data that matches the given filter, all in a single transaction!
import argparse
import logging
from pathlib import Path
from typing import Any, Iterable, Sequence
import backoff
import snowflake.connector.errors
from .snowflake_connection import get_snowflake_connection
@backoff.on_exception(
    backoff.expo,
    snowflake.connector.errors.ProgrammingError,
    max_tries=10,
    jitter=backoff.full_jitter,
    # Only retry table-lock contention errors; any other ProgrammingError re-raises immediately.
    giveup=lambda err: "number of waiters for this lock exceeds" not in err.raw_msg,
    # BUG FIX: original called undefined `get_logger("elt")` (NameError on first backoff);
    # use the stdlib logging API instead.
    on_backoff=lambda details: logging.getLogger("elt").warning(
        f"Reached max table lock waiters in Snowflake, retrying {details['tries']}/9."
    ),
    backoff_log_level=logging.DEBUG,
    factor=20,  # wait between 0s and 20s, 40s, 80s, 160s, ... (increases until 640s and varies randomly due to jitter)
    max_value=640,
)
def load_to_snowflake(
    schema: str,
    table: str,
    files: Iterable[Path],
    temporary: bool = False,
    transient: bool = False,
    delete_filter: str | None = None,
    delete_filter_params: Sequence[Any] | None = None,
) -> None:
    """Load GZIP-compressed CSV files into a Snowflake table in a single transaction.

    Creates the target table if it does not exist (columns: value VARIANT,
    process_started_at_utc / process_ended_at_utc TIMESTAMP_NTZ NOT NULL), PUTs each
    file to the table stage, then — inside one transaction — optionally deletes rows
    matching ``delete_filter`` and COPYs the staged files in (purging them on success).
    The whole function is retried with exponential backoff when Snowflake reports too
    many waiters on the table lock.

    Args:
        schema: Schema of the target table. NOTE(review): interpolated into SQL
            unescaped — must come from a trusted source, not end users.
        table: Name of the target table. Same injection caveat as ``schema``.
        files: Paths of the compressed CSV files to load. May be any iterable,
            including a one-shot generator.
        temporary: Create the table as TEMPORARY if it doesn't exist.
        transient: Create the table as TRANSIENT if it doesn't exist (ignored when
            ``temporary`` is set).
        delete_filter: Optional SQL predicate; matching rows are deleted before the
            load, in the same transaction. Use numeric parameter binding for any
            user-supplied values.
        delete_filter_params: Parameters bound into ``delete_filter``.
    """
    # BUG FIX: materialize once — the original iterated `files` twice and called
    # len() on it, which breaks for generator inputs.
    files = list(files)
    with get_snowflake_connection() as warehouse_connection:
        with warehouse_connection.cursor() as c:
            c.execute(f"use schema {schema}")
            c.execute(
                f"""
                create{" temporary" if temporary else " transient" if transient else ""} table if not exists {table}
                (
                    value variant,
                    process_started_at_utc timestamp_ntz not null,
                    process_ended_at_utc timestamp_ntz not null
                )
                """,
            )
            # Stage the files outside the transaction; COPY below consumes them.
            for file in files:
                c.execute(f"put file://{file} @%{table}")
            c.execute("begin transaction")
            commit = False
            try:
                if delete_filter:
                    logging.info(
                        f"Deleting from table {schema}.{table} where {delete_filter} with params [{delete_filter_params}]"
                    )
                    c.execute(f"delete from {table} where {delete_filter}", params=delete_filter_params)
                if len(files) > 0:
                    c.execute(
                        f"""
                        copy into {table}
                        from @%{table}/
                        files=('{"', '".join([file.name for file in files])}')
                        file_format = (type = 'csv' field_delimiter = ',' field_optionally_enclosed_by = '\"' skip_header = 1)
                        purge=true
                        """
                    )
                commit = True
            finally:
                # Roll back if anything above raised; commit only on full success.
                c.execute("commit" if commit else "rollback")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="load_to_snowflake",
description="CLI command to load a compressed CSV file into a Snowflake table.",
)
parser.add_argument(
"-s",
"--schema",
help="Schema of the table to load the data into.",
required=True,
)
parser.add_argument(
"-t",
"--table",
help="Name of the table to load the data into.",
required=True,
)
file_group = parser.add_mutually_exclusive_group()
file_group.add_argument(
"-f",
"--file",
help="Path(s) to the file(s) to load into the table.",
type=Path,
action="append",
)
file_group.add_argument(
"-F",
"--files-file",
help="File containing a list of files to load into the table. Use '-' for stdin.",
type=argparse.FileType('r'),
)
parser.add_argument(
"-r",
"--transient",
help="Set this flag to create the table as transient if it doesn't already exist.",
action="store_true",
)
parser.add_argument(
"-d",
"--delete-filter",
help=(
"Delete rows from the table matching this filter before loading the data. "
"DO NOT insert user-provided data directly into this filter as that could enable SQL injection attacks! "
"Instead provide the data as parameters and use numeric parameter binding: "
"https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-example#qmark-or-numeric-binding"
),
required=False,
)
parser.add_argument(
"delete_filter_params",
nargs="*",
help="Parameters to pass to the delete filter.",
)
parser.add_argument(
"-v",
"--verbose",
help="Be verbose. Include once for INFO output, twice for DEBUG output.",
action="count",
default=0,
)
args = parser.parse_args()
LOGGING_LEVELS = [logging.WARNING, logging.INFO, logging.DEBUG]
logging.basicConfig(level=LOGGING_LEVELS[min(args.verbose, len(LOGGING_LEVELS) - 1)]) # cap to last level index
if args.files_file:
files = [Path(line.strip()) for line in args.files_file]
elif args.file:
files = args.file
else:
files = []
load_to_snowflake(
schema=args.schema,
table=args.table,
files=files,
transient=args.transient,
delete_filter=args.delete_filter,
delete_filter_params=args.delete_filter_params,
)
backoff
snowflake-connector-python
import logging
import os
from pathlib import Path
from typing import Any
import backoff
import snowflake.connector
# Module-level connector configuration: switch every connection created via this
# module to server-side numeric parameter binding (e.g. `where id = :1`).
# See: https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-example#using-qmark-or-numeric-binding
snowflake.connector.paramstyle = "numeric"
@backoff.on_exception(
    backoff.expo,
    snowflake.connector.errors.DatabaseError,
    max_tries=8,
    giveup=lambda err: "JWT" not in getattr(err, "raw_msg", ""),
    on_backoff=lambda details: logging.info(f"Encountered JWT token issue, trying {details['tries']}/8."),
    backoff_log_level=logging.DEBUG,
)
def get_snowflake_connection(**kwargs: Any) -> snowflake.connector.connection.SnowflakeConnection:
    """Open a Snowflake connection, retrying transient JWT-token failures.

    All keyword arguments are forwarded verbatim to ``snowflake.connector.connect``;
    anything not supplied falls back to the connector's own defaults/environment
    configuration. ``DatabaseError``s whose message mentions "JWT" are retried with
    exponential backoff (up to 8 attempts); any other error propagates immediately.

    See: https://docs.snowflake.com/en/developer-guide/snowflake-cli-v2/connecting/specify-credentials#use-environment-variables-for-snowflake-credentials
    """
    connection = snowflake.connector.connect(**kwargs)
    return connection
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment