@flaviut · Last active September 16, 2025
zstd compression results for random sample

Using generate_sample_file.py, I took a 1 GiB random sample of the files on my HDD, then benchmarked zstd across compression levels 1-19, with and without --long mode, measuring compression ratio, speed, and peak memory usage.

[Charts: speed vs. level, speed/ratio trade-off, and ratio vs. level; saved as chart_speed_vs_level.png, chart_speed_ratio_tradeoff.png, and chart_ratio_vs_level.png]
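For reference, the end-to-end workflow looks roughly like this. The plotting script's filename isn't recorded in this gist, so plot_results.py below is a placeholder:

    sudo python generate_sample_file.py              # writes filesystem_sample.bin (1 GiB)
    python benchmark_zstd_mem.py filesystem_sample.bin
    python plot_results.py                           # placeholder name for the plotting script below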
benchmark_zstd_mem.py

#!/usr/bin/env python3
"""
A Python script to benchmark the zstd command-line tool.

This script iterates through various zstd compression levels (1-19) and long-mode settings
to measure compression/decompression performance, effectiveness, and peak memory usage.

Input:  A file path provided as a command-line argument.
Output: An SQLite database file containing the benchmark results.

Dependencies:
- The `zstd` command-line tool must be installed and available in the system's PATH.
- The `psutil` and `tqdm` Python libraries: `pip install psutil tqdm`

Usage:
    python benchmark_zstd_mem.py /path/to/your/file.dat
    python benchmark_zstd_mem.py /path/to/your/file.dat -o custom_results.db
"""
import sys
import subprocess
import os
import time
import sqlite3
import argparse
import shutil
import random
import itertools
from typing import Tuple, List, Optional
from threading import Thread

from tqdm import tqdm

# --- Configuration ---
ZSTD_LEVELS = range(1, 20)
BYTES_TO_MB = 1 / (1024 * 1024)

# --- Dependency Checks ---
try:
    import psutil
except ImportError:
    print("Error: `psutil` library not found.", file=sys.stderr)
    print("Please install it to run this script: `pip install psutil`", file=sys.stderr)
    sys.exit(1)
def check_zstd_availability():
    """Check if the 'zstd' command is available in the system's PATH."""
    if not shutil.which("zstd"):
        print("Error: 'zstd' command not found in your PATH.", file=sys.stderr)
        print("Please install zstd to run this benchmark.", file=sys.stderr)
        sys.exit(1)
# --- Core Functions ---
def setup_database(db_path: str) -> Tuple[sqlite3.Connection, sqlite3.Cursor]:
    """Creates and sets up the SQLite database and table."""
    db_exists = os.path.exists(db_path)
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    if db_exists:
        cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name='benchmarks'"
        )
        if cursor.fetchone():
            while True:
                choice = (
                    input(
                        f"Database '{db_path}' already contains a 'benchmarks' table. Overwrite it? [y/N]: "
                    )
                    .lower()
                    .strip()
                )
                if choice == "y":
                    cursor.execute("DROP TABLE benchmarks")
                    break
                elif choice in ("n", ""):
                    print("Exiting without modifying the database.")
                    conn.close()
                    sys.exit(0)
    cursor.execute("""
        CREATE TABLE benchmarks (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            level INTEGER NOT NULL,
            long_mode BOOLEAN NOT NULL,
            original_size_bytes INTEGER NOT NULL,
            compressed_size_bytes INTEGER NOT NULL,
            compression_ratio REAL NOT NULL,
            compression_time_sec REAL NOT NULL,
            decompression_time_sec REAL NOT NULL,
            compression_speed_mbps REAL NOT NULL,
            decompression_speed_mbps REAL NOT NULL,
            compression_peak_mem_mb REAL,
            decompression_peak_mem_mb REAL,
            timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
        )
    """)
    conn.commit()
    print(f"Database '{db_path}' is ready.")
    return conn, cursor
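Once a run has completed, the results table can be queried directly. A minimal sketch, with column names taken from the CREATE TABLE statement above:

import sqlite3

con = sqlite3.connect("zstd_benchmark_mem.db")
# Top five configurations by compression ratio
for level, long_mode, ratio, speed in con.execute(
    "SELECT level, long_mode, compression_ratio, compression_speed_mbps "
    "FROM benchmarks ORDER BY compression_ratio DESC LIMIT 5"
):
    print(f"level={level} long={bool(long_mode)} ratio={ratio:.2f} {speed:.0f} MB/s")
con.close()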
def run_and_monitor_memory(
    cmd: List[str],
) -> Tuple[float, Optional[int], Optional[bytes]]:
    """
    Runs a command, measures its execution time, and monitors its peak memory usage.

    Returns:
        A tuple containing (execution_time, peak_memory_bytes, stderr).
    """
    peak_mem_bytes = [0]  # Use a list to be mutable inside the thread

    def monitor(process: psutil.Process):
        """Polls process memory usage until it terminates."""
        try:
            while process.is_running():
                try:
                    mem_info = process.memory_info()
                    # RSS (Resident Set Size) is a good proxy for memory usage
                    if mem_info.rss > peak_mem_bytes[0]:
                        peak_mem_bytes[0] = mem_info.rss
                except (psutil.NoSuchProcess, psutil.AccessDenied):
                    break  # Process ended before we could read memory
                time.sleep(0.01)  # Poll interval
        except Exception:
            # Broad exception to ensure thread doesn't die silently
            pass

    try:
        start_time = time.perf_counter()
        # Start the process without blocking
        proc = subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)
        # Start the memory monitoring thread
        ps_proc = psutil.Process(proc.pid)
        monitor_thread = Thread(target=monitor, args=(ps_proc,))
        monitor_thread.start()
        # Wait for the process and the monitor to finish
        stderr_output = proc.communicate()[1]
        monitor_thread.join()
        end_time = time.perf_counter()
        if proc.returncode != 0:
            print(f"\nError executing command: {' '.join(cmd)}", file=sys.stderr)
            print(f"Stderr: {stderr_output.decode()}", file=sys.stderr)
            return (end_time - start_time, None, stderr_output)
        return (end_time - start_time, peak_mem_bytes[0], None)
    except FileNotFoundError:
        print(f"\nError: Command not found: {cmd[0]}", file=sys.stderr)
        return (0, None, b"Command not found")
    except (psutil.NoSuchProcess, psutil.AccessDenied):
        # Process might finish so fast that psutil can't attach.
        # In this case, we can't measure memory, but the timing is still valid.
        end_time = time.perf_counter()
        return (end_time - start_time, 0, None)
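As a quick sanity check, this helper can be pointed at any short-lived command. A hedged example, assuming zstd is installed:

# Time `zstd --version` and report its peak RSS (may be 0 if the process
# exits before the monitor thread can take a sample).
elapsed, peak, err = run_and_monitor_memory(["zstd", "--version"])
if err is None:
    print(f"took {elapsed:.3f}s, peak RSS ~{(peak or 0) * BYTES_TO_MB:.1f} MB")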
def run_benchmark(input_file: str, level: int, use_long: bool) -> Optional[dict]:
    """Runs a single compression/decompression cycle and returns the results."""
    base_name = (
        f"{os.path.basename(input_file)}.{level}.{'long' if use_long else 'nolong'}"
    )
    compressed_file = f"{base_name}.zst"
    decompressed_file = f"{base_name}.decomp"
    results = {}

    original_size = os.path.getsize(input_file)
    results["original_size_bytes"] = original_size

    # --- Compression ---
    comp_cmd = ["zstd", f"-{level}", "-f", "-o", compressed_file, input_file]
    if use_long:
        comp_cmd.insert(1, "--long")
    comp_time, comp_mem, comp_err = run_and_monitor_memory(comp_cmd)
    if comp_err is not None:
        return None
    results["compression_time_sec"] = comp_time
    results["compression_peak_mem_mb"] = (
        comp_mem * BYTES_TO_MB if comp_mem is not None else None
    )
    try:
        results["compressed_size_bytes"] = os.path.getsize(compressed_file)
    except FileNotFoundError:
        print(
            f"\nError: Could not find '{compressed_file}'. Compression failed.",
            file=sys.stderr,
        )
        return None

    # --- Decompression ---
    decomp_cmd = ["zstd", "-d", "-f", "-o", decompressed_file, compressed_file]
    decomp_time, decomp_mem, decomp_err = run_and_monitor_memory(decomp_cmd)
    if decomp_err is not None:
        return None
    results["decompression_time_sec"] = decomp_time
    results["decompression_peak_mem_mb"] = (
        decomp_mem * BYTES_TO_MB if decomp_mem is not None else None
    )

    # --- Verification & Cleanup ---
    try:
        decompressed_size = os.path.getsize(decompressed_file)
        if original_size != decompressed_size:
            print(
                f"\nCRITICAL: Size mismatch! Original={original_size}, Decompressed={decompressed_size}",
                file=sys.stderr,
            )
        # Calculate derived metrics
        if results["compression_time_sec"] > 0:
            results["compression_speed_mbps"] = (original_size * BYTES_TO_MB) / results[
                "compression_time_sec"
            ]
        else:
            results["compression_speed_mbps"] = float("inf")
        if results["decompression_time_sec"] > 0:
            results["decompression_speed_mbps"] = (
                original_size * BYTES_TO_MB
            ) / results["decompression_time_sec"]
        else:
            results["decompression_speed_mbps"] = float("inf")
        if results["compressed_size_bytes"] > 0:
            results["compression_ratio"] = (
                original_size / results["compressed_size_bytes"]
            )
        else:
            results["compression_ratio"] = float("inf")
    finally:
        if os.path.exists(compressed_file):
            os.remove(compressed_file)
        if os.path.exists(decompressed_file):
            os.remove(decompressed_file)
    return results
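For a one-off measurement outside the full sweep, run_benchmark can be called directly. A minimal sketch; the sample filename is the one generate_sample_file.py produces:

res = run_benchmark("filesystem_sample.bin", level=3, use_long=True)
if res:
    print(f"ratio={res['compression_ratio']:.2f}, "
          f"compress={res['compression_speed_mbps']:.0f} MB/s, "
          f"peak mem={res['compression_peak_mem_mb'] or 0:.0f} MB")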
def main():
    """Main function to parse arguments and run the benchmark suite."""
    parser = argparse.ArgumentParser(
        description="A Python script to benchmark zstd performance and memory usage.",
        formatter_class=argparse.RawTextHelpFormatter,
        epilog="Requires `psutil` and `tqdm` libraries: `pip install psutil tqdm`",
    )
    parser.add_argument("input_file", help="The input file to use for benchmarking.")
    parser.add_argument(
        "-o",
        "--output_db",
        default="zstd_benchmark_mem.db",
        help="Path to the output SQLite database file (default: zstd_benchmark_mem.db).",
    )
    args = parser.parse_args()

    check_zstd_availability()

    input_file = args.input_file
    if not os.path.isfile(input_file):
        print(f"Error: Input file not found at '{input_file}'", file=sys.stderr)
        sys.exit(1)

    conn, cursor = setup_database(args.output_db)

    # 1. Generate all benchmark configurations (level, use_long)
    long_modes = [False, True]
    benchmark_configs = list(itertools.product(ZSTD_LEVELS, long_modes))

    # 2. Randomize the order of execution to get more accurate time estimates
    random.shuffle(benchmark_configs)
    print(f"Starting {len(benchmark_configs)} benchmark runs in random order...")

    try:
        # 3. Use tqdm to create a progress bar over the shuffled configurations
        for level, use_long in tqdm(benchmark_configs, desc="Running Benchmarks"):
            results = run_benchmark(input_file, level, use_long)
            if results:
                cursor.execute(
                    """
                    INSERT INTO benchmarks (
                        level, long_mode, original_size_bytes, compressed_size_bytes,
                        compression_ratio, compression_time_sec, decompression_time_sec,
                        compression_speed_mbps, decompression_speed_mbps,
                        compression_peak_mem_mb, decompression_peak_mem_mb
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    (
                        level,
                        use_long,
                        results["original_size_bytes"],
                        results["compressed_size_bytes"],
                        results["compression_ratio"],
                        results["compression_time_sec"],
                        results["decompression_time_sec"],
                        results["compression_speed_mbps"],
                        results["decompression_speed_mbps"],
                        results["compression_peak_mem_mb"],
                        results["decompression_peak_mem_mb"],
                    ),
                )
                conn.commit()
    except KeyboardInterrupt:
        print("\nBenchmark interrupted by user. Partial results are saved.")
    finally:
        conn.close()
        print("\nBenchmark finished.")
        print(f"Results have been saved to '{args.output_db}'")


if __name__ == "__main__":
    main()

generate_sample_file.py

import os
import random
import sys

from tqdm import tqdm
# --- Configuration ---
# The final size of the output file in Gibibytes (GiB)
TARGET_GIB = 1
TARGET_SIZE_BYTES = int(TARGET_GIB * 1024**3)
# The maximum amount of data to read from any single source file in Mebibytes (MiB)
MAX_CHUNK_MIB = 32
MAX_CHUNK_BYTES = int(MAX_CHUNK_MIB * 1024**2)
# The name of the generated file
OUTPUT_FILENAME = "filesystem_sample.bin"
# The starting directory for the file search.
# This script is specifically designed to scan from the root filesystem.
START_PATH = "/"
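The script stays on a single filesystem by comparing each directory's st_dev device ID against the root's (see the comments in create_sample_file below). A minimal illustration of that check:

import os

# Two paths live on the same filesystem iff their device IDs match.
# "/home" is just an example mount point here.
print(os.stat("/").st_dev == os.stat("/home").st_dev)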
def create_sample_file(root_dir, target_bytes, max_chunk_bytes, output_file_name):
    """
    Generates a large sample file by combining random chunks of other files
    from a single filesystem.
    """
    print(f"Starting file scan on the '{root_dir}' filesystem.")
    print("This will not cross into other mounted filesystems (e.g., /home, /boot).")
    print("The initial scan may take a while...")

    # Get the device ID of the starting filesystem.
    # This is the key to staying on one filesystem. This is a Unix-specific feature.
    try:
        root_dev = os.stat(root_dir).st_dev
    except OSError as e:
        print(
            f"FATAL: Could not stat root directory '{root_dir}': {e}", file=sys.stderr
        )
        print("Please ensure you are running with 'sudo'.", file=sys.stderr)
        sys.exit(1)

    all_files = []
    walk_iterator = os.walk(root_dir, topdown=True, onerror=lambda e: None)
    for dirpath, dirnames, filenames in tqdm(
        walk_iterator, desc="Scanning filesystem", unit=" dirs"
    ):
        # Check if the current directory is on the same device as the root.
        # If not, we have crossed a mount point.
        try:
            if os.stat(dirpath).st_dev != root_dev:
                # Prune the list of directories to prevent os.walk from descending
                # into this other filesystem.
                dirnames.clear()
                # Skip processing files in this non-root directory
                continue
        except OSError:
            # If we cannot stat a directory, don't descend into it.
            dirnames.clear()
            continue
        for filename in filenames:
            file_path = os.path.join(dirpath, filename)
            # Ensure it's a file and not a broken symlink or other non-file type
            if os.path.isfile(file_path):
                all_files.append(file_path)

    if not all_files:
        print("\nError: No files found to sample from. Exiting.", file=sys.stderr)
        return

    print(f"Found {len(all_files):,} files. Randomizing list...")
    random.shuffle(all_files)

    print(f"Beginning generation of '{output_file_name}' ({TARGET_GIB} GiB)...")
    try:
        # Initialize tqdm progress bar
        with tqdm(
            total=target_bytes,
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
            desc="Generating",
        ) as pbar:
            with open(output_file_name, "wb") as output_file:
                # Iterate through the shuffled list of files
                for file_path in all_files:
                    if pbar.n >= target_bytes:
                        break  # Target size reached
                    # Update tqdm description to show current file
                    pbar.set_description(f"Reading {file_path_short(file_path)}")
                    try:
                        with open(file_path, "rb") as input_file:
                            bytes_needed = target_bytes - pbar.n
                            bytes_to_read = min(max_chunk_bytes, bytes_needed)
                            data_chunk = input_file.read(bytes_to_read)
                            if data_chunk:
                                output_file.write(data_chunk)
                                # Update the progress bar by the number of bytes written
                                pbar.update(len(data_chunk))
                    except (IOError, PermissionError):
                        # Silently skip files we can't read
                        continue
    except IOError as e:
        print(f"\nFATAL ERROR: Could not write to output file '{output_file_name}'.")
        print(f"Reason: {e}", file=sys.stderr)
        return

    remaining_bytes = target_bytes - os.path.getsize(output_file_name)
    if remaining_bytes > 0:
        print(
            f"NOTE: Could not reach target size. Final size is smaller by {remaining_bytes / 1024**3:.3f} GiB."
        )
        print("This can happen if there aren't enough readable files.")

    final_size_gib = os.path.getsize(output_file_name) / 1024**3
    print(f"\nFinished! Wrote {final_size_gib:.3f} GiB to '{output_file_name}'.")
def file_path_short(path, max_len=40):
    """Truncates a file path for cleaner display in tqdm."""
    if len(path) > max_len:
        return "..." + path[-(max_len - 3):]
    return path.ljust(max_len)
if __name__ == "__main__":
    if os.name != "posix":
        print(
            "Error: This script's method for staying on one filesystem is specific to POSIX-compliant",
            file=sys.stderr,
        )
        print(
            "systems (like Linux and macOS) and will not work correctly on Windows.",
            file=sys.stderr,
        )
        sys.exit(1)

    if os.path.exists(OUTPUT_FILENAME):
        print(
            f"Error: Output file '{OUTPUT_FILENAME}' already exists.", file=sys.stderr
        )
        print("Please remove or rename it before running the script.", file=sys.stderr)
        sys.exit(1)

    create_sample_file(
        root_dir=START_PATH,
        target_bytes=TARGET_SIZE_BYTES,
        max_chunk_bytes=MAX_CHUNK_BYTES,
        output_file_name=OUTPUT_FILENAME,
    )

Plotting script

import sqlite3

import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker

DB_FILE = "zstd_benchmark_mem.db"


def load_data(db_path):
    """Loads benchmark data from the SQLite database into a pandas DataFrame."""
    print(f"Loading data from '{db_path}'...")
    con = sqlite3.connect(db_path)
    # Load data, sorting by level to ensure lines are drawn correctly
    df = pd.read_sql_query("SELECT * FROM benchmarks ORDER BY level", con)
    con.close()
    # Split data by long_mode for easier plotting
    df_normal = df[df["long_mode"] == 0].copy()
    df_long = df[df["long_mode"] == 1].copy()
    return df_normal, df_long
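A quick way to sanity-check the loaded frames before plotting, assuming the benchmark DB from above exists:

df_normal, df_long = load_data(DB_FILE)
print(df_normal[["level", "compression_ratio", "compression_speed_mbps"]].head())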
def plot_ratio_vs_level(df_normal, df_long):
    """Plots Compression Ratio vs. Compression Level."""
    fig, ax = plt.subplots(figsize=(12, 7))
    ax.plot(
        df_normal["level"],
        df_normal["compression_ratio"],
        marker="o",
        linestyle="-",
        label="Standard Mode",
    )
    if not df_long.empty:
        ax.plot(
            df_long["level"],
            df_long["compression_ratio"],
            marker="x",
            linestyle="--",
            label="Long Mode",
        )
    ax.set_title("Zstandard: Compression Ratio vs. Level", fontsize=16)
    ax.set_xlabel("Compression Level")
    ax.set_ylabel("Compression Ratio (Original / Compressed)")
    ax.legend()
    ax.grid(True, which="both", linestyle="--", linewidth=0.5)
    ax.xaxis.set_major_locator(mticker.MaxNLocator(integer=True))
    fig.tight_layout()
    plt.savefig("chart_ratio_vs_level.png")
def plot_speed_vs_level(df_normal, df_long):
    """Plots Compression and Decompression Speed vs. Compression Level."""
    fig, ax = plt.subplots(figsize=(12, 7))

    # Compression Speed
    ax.plot(
        df_normal["level"],
        df_normal["compression_speed_mbps"],
        marker="o",
        linestyle="-",
        color="C0",
        label="Compression Speed (Standard)",
    )
    if not df_long.empty:
        ax.plot(
            df_long["level"],
            df_long["compression_speed_mbps"],
            marker="x",
            linestyle="--",
            color="C1",
            label="Compression Speed (Long)",
        )
    ax.set_yscale("log")
    ax.yaxis.set_major_formatter(mticker.ScalarFormatter())
    ax.set_title("Zstandard: Speed vs. Level", fontsize=16)
    ax.set_xlabel("Compression Level")
    ax.set_ylabel("Compression Speed (MB/s) [Log Scale]")
    ax.grid(True, which="both", linestyle="--", linewidth=0.5)

    # Decompression Speed on a secondary y-axis
    ax2 = ax.twinx()
    ax2.plot(
        df_normal["level"],
        df_normal["decompression_speed_mbps"],
        marker="s",
        linestyle=":",
        color="C2",
        label="Decompression Speed (Standard)",
    )
    if not df_long.empty:
        ax2.plot(
            df_long["level"],
            df_long["decompression_speed_mbps"],
            marker="d",
            linestyle="-.",
            color="C3",
            label="Decompression Speed (from Long)",
        )
    ax2.set_ylabel("Decompression Speed (MB/s)")

    # Combine legends
    lines, labels = ax.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax2.legend(lines + lines2, labels + labels2, loc="upper right")

    ax.xaxis.set_major_locator(mticker.MaxNLocator(integer=True))
    fig.tight_layout()
    plt.savefig("chart_speed_vs_level.png")
def plot_tradeoff(df_normal, df_long):
    """Plots Compression Speed vs. Compression Ratio to show the trade-off."""
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.plot(
        df_normal["compression_ratio"],
        df_normal["compression_speed_mbps"],
        marker="o",
        linestyle="-",
        label="Standard Mode",
    )
    if not df_long.empty:
        ax.plot(
            df_long["compression_ratio"],
            df_long["compression_speed_mbps"],
            marker="x",
            linestyle="--",
            label="Long Mode",
        )

    # Annotate points with their compression level
    for df in (df_normal, df_long):
        if df.empty:
            continue
        for _, row in df.iterrows():
            ax.text(
                row["compression_ratio"],
                row["compression_speed_mbps"] * 1.1,
                f"{row['level']}",
                fontsize=8,
                ha="center",
            )

    ax.set_title("Zstandard: Speed vs. Ratio Trade-off", fontsize=16)
    ax.set_xlabel("Compression Ratio")
    ax.set_ylabel("Compression Speed (MB/s) [Log Scale]")
    ax.set_yscale("log")
    ax.yaxis.set_major_formatter(mticker.ScalarFormatter())
    ax.legend()
    ax.grid(True, which="both", linestyle="--", linewidth=0.5)
    fig.tight_layout()
    plt.savefig("chart_speed_ratio_tradeoff.png")
if __name__ == "__main__":
    # Set a nice style for the plots
    try:
        plt.style.use("seaborn-v0_8-whitegrid")
    except OSError:
        # Older matplotlib versions don't ship this style; fall back to ggplot
        plt.style.use("ggplot")

    # Load data from the database
    df_normal, df_long = load_data(DB_FILE)

    if df_normal.empty and df_long.empty:
        print("No data found in the database. Exiting.")
    else:
        # Generate and save the charts
        print("Generating plots...")
        plot_ratio_vs_level(df_normal, df_long)
        plot_speed_vs_level(df_normal, df_long)
        plot_tradeoff(df_normal, df_long)
        print("Charts saved to 'chart_*.png' in the current directory.")
        # Display the plots
        plt.show()