Created
June 8, 2026 20:32
-
-
Save TomAugspurger/cebf449f7621f69ff70a5452f1ca5706 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python | |
| # SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| """Small benchmark for shuffle-based joins in cudf-polars.""" | |
| from __future__ import annotations | |
| import argparse | |
| import concurrent.futures | |
| import contextlib | |
| import dataclasses | |
| import json | |
| import tempfile | |
| import textwrap | |
| import time | |
| from pathlib import Path | |
| from typing import Any, TYPE_CHECKING | |
| import numpy as np | |
| import polars as pl | |
| from cudf_polars import Translator | |
| from cudf_polars.dsl.traversal import traversal | |
| from cudf_polars.engine.options import StreamingOptions | |
| from cudf_polars.streaming.parallel import lower_ir_graph | |
| from cudf_polars.streaming.shuffle import Shuffle | |
| from cudf_polars.streaming.statistics import collect_statistics | |
| from cudf_polars.utils.config import ConfigOptions | |
| if TYPE_CHECKING: | |
| from collections.abc import Iterator | |
| try: | |
| import nvtx | |
| except ImportError: # pragma: no cover | |
| nvtx = None | |
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the shuffle-join benchmark.

    Besides the benchmark's own flags, this registers the ``StreamingOptions``
    CLI arguments and overrides some of their defaults (static planning,
    ``max_rows_per_partition=250_000``, ``target_partition_size=1`` and
    ``broadcast_limit=1``), which per the parser description is meant to make
    the planner prefer shuffle joins.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """

    def positive_int(text: str) -> int:
        # ``type=`` callable that rejects values below 1 at parse time. For
        # example, ``--iterations 0`` would otherwise leave no timings to
        # summarize and crash after all the expensive setup work.
        value = int(text)
        if value < 1:
            raise argparse.ArgumentTypeError(f"must be >= 1, got {value}")
        return value

    parser = argparse.ArgumentParser(
        description="Benchmark a cudf-polars join configured to prefer shuffle joins."
    )
    parser.add_argument("--n-left", type=int, default=8_000_000)
    parser.add_argument("--n-right", type=int, default=2_000_000)
    parser.add_argument(
        "--distinct-keys",
        type=int,
        default=2_000_000,
        help=(
            "Size of the key range sampled by the left side. Left keys are drawn "
            "from [0, max(distinct-keys, n-right)), so every right-side key can match."
        ),
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--how",
        type=str,
        default="inner",
        choices=["inner", "left", "right", "full", "semi", "anti"],
    )
    parser.add_argument(
        "--iterations",
        type=positive_int,
        default=3,
        help="Number of timed query executions (must be >= 1).",
    )
    parser.add_argument(
        "--data-mode",
        type=str,
        default="in-memory",
        choices=["in-memory", "parquet"],
        help="Use DataFrameScan inputs or write+scan temporary parquet files.",
    )
    parser.add_argument(
        "--frontend",
        required=True,
        type=str,
        choices=["dask", "duckdb", "in-memory", "polars-cpu", "ray", "spmd"],
        help=textwrap.dedent(
            """\
            Execution frontend:
              - dask : Dask distributed multi-GPU execution
              - duckdb : DuckDB CPU execution
              - in-memory : Single-process GPU, in-memory evaluation
              - polars-cpu : Polars CPU streaming engine (no GPU)
              - ray : Ray actor-based multi-GPU execution
              - spmd : SPMD execution via rrun launcher"""
        ),
    )
    parser.add_argument(
        "--connect",
        dest="connect",
        default=None,
        type=str,
        help=textwrap.dedent(
            """\
            Connect to an existing cluster instead of creating a local one.
            Only supported with --frontend dask or ray:
              - dask : a TCP address (e.g. tcp://host:8786) or a scheduler file path
              - ray : a Ray address (e.g. ray://host:10001 or "auto")"""
        ),
    )
    parser.add_argument(
        "--num-gpus",
        dest="num_gpus",
        default=None,
        type=int,
        help="Number of GPUs for local cluster creation (--frontend ray/dask only). "
        "Cannot be used with --connect. Defaults to all visible GPUs.",
    )
    parser.add_argument(
        "--verify-plan",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Count Shuffle nodes in lowered IR (static planning only).",
    )
    parser.add_argument(
        "--check-broadcast-sensitivity",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Also inspect shuffle count with a high broadcast limit "
        "(requires --verify-plan).",
    )
    parser.add_argument(
        "--sensitivity-broadcast-limit",
        type=int,
        default=1_000_000_000,
        help="Broadcast limit to use for sensitivity check.",
    )
    StreamingOptions._add_cli_args(parser)
    # Defaults chosen so that, for the synthetic tables, the planner is pushed
    # towards a shuffle join rather than a broadcast join.
    parser.set_defaults(
        dynamic_planning=False,
        max_rows_per_partition=250_000,
        target_partition_size=1,
        broadcast_limit=1,
    )
    return parser
def make_tables(args: argparse.Namespace) -> tuple[pl.DataFrame, pl.DataFrame]:
    """Generate the synthetic ``(left, right)`` frames joined by the benchmark.

    ``right`` holds exactly one row per key ``0 .. n_right - 1``. ``left``
    draws its keys uniformly, with replacement, from
    ``[0, max(distinct_keys, n_right))``, so every right key can be matched.
    Left keys ``>= n_right`` have no partner. Payload columns are random
    integers in ``[0, 1_000_000)``. The ``*_row_id`` columns are 0-based row
    positions. The output is deterministic for a given ``args.seed``.

    Raises:
        ValueError: if ``n_left``, ``n_right`` or ``distinct_keys`` is below 1.
    """
    if min(args.n_left, args.n_right) < 1:
        raise ValueError("Both --n-left and --n-right must be >= 1.")
    if args.distinct_keys < 1:
        raise ValueError("--distinct-keys must be >= 1.")

    generator = np.random.default_rng(args.seed)
    n_left, n_right = args.n_left, args.n_right
    key_upper = max(args.distinct_keys, n_right)

    # Draw order is part of the seeded output: left keys, then left payload,
    # then right payload.
    left_keys = generator.integers(0, key_upper, size=n_left, dtype=np.int64)
    left_payload = generator.integers(0, 1_000_000, size=n_left, dtype=np.int64)
    right_payload = generator.integers(0, 1_000_000, size=n_right, dtype=np.int64)

    left_frame = pl.DataFrame(
        {
            "key": left_keys,
            "left_payload": left_payload,
            "left_row_id": np.arange(n_left, dtype=np.int64),
        }
    )
    right_frame = pl.DataFrame(
        {
            "key": np.arange(n_right, dtype=np.int64),
            "right_payload": right_payload,
            "right_row_id": np.arange(n_right, dtype=np.int64),
        }
    )
    return left_frame, right_frame
@contextlib.contextmanager
def build_query(args: argparse.Namespace) -> Iterator[pl.LazyFrame]:
    """Yield a LazyFrame that joins the synthetic tables on ``key``.

    With ``args.data_mode == "in-memory"``, the join is built from lazy views
    of the in-memory frames. Any other mode writes both tables as parquet
    files into a temporary directory and scans them. That directory is
    deleted when the context exits, so the query must be collected while the
    context is open. The join type is ``args.how``.
    """
    left, right = make_tables(args)
    if args.data_mode == "in-memory":
        yield left.lazy().join(right.lazy(), on="key", how=args.how)
    else:
        with tempfile.TemporaryDirectory(prefix="cudf_polars_shuffle_join_") as tmp:
            directory = Path(tmp)
            left_file = directory / "left.parquet"
            right_file = directory / "right.parquet"
            left.write_parquet(left_file)
            right.write_parquet(right_file)
            left_scan = pl.scan_parquet(left_file)
            right_scan = pl.scan_parquet(right_file)
            yield left_scan.join(right_scan, on="key", how=args.how)
def create_engine(args: argparse.Namespace) -> pl.GPUEngine:
    """Build a ``pl.GPUEngine`` that uses the ``streaming`` executor.

    Executor and engine options are derived from the ``StreamingOptions`` CLI
    arguments parsed into ``args``. ``raise_on_fail`` defaults to ``True``
    unless those options already set it.
    """
    opts = StreamingOptions._from_argparse(args)
    executor_kwargs = opts.to_executor_options()
    engine_kwargs = opts.to_engine_options()
    engine_kwargs.setdefault("raise_on_fail", True)
    return pl.GPUEngine(
        executor="streaming", executor_options=executor_kwargs, **engine_kwargs
    )
def _executor_summary(engine: pl.GPUEngine) -> dict[str, object]:
    """Return selected resolved executor settings of ``engine`` as a dict.

    ``dynamic_planning`` is converted with ``dataclasses.asdict`` when it is
    set, and is left as ``None`` otherwise.
    """
    executor = ConfigOptions.from_polars_engine(engine).executor
    planning = executor.dynamic_planning
    planning_dict = None if planning is None else dataclasses.asdict(planning)
    return {
        "cluster": executor.cluster,
        "max_rows_per_partition": executor.max_rows_per_partition,
        "target_partition_size": executor.target_partition_size,
        "broadcast_limit": executor.broadcast_limit,
        "dynamic_planning": planning_dict,
        "min_device_size": executor.min_device_size,
    }
def count_shuffle_nodes(
    query: pl.LazyFrame, engine: pl.GPUEngine
) -> int | None:
    """Count the ``Shuffle`` nodes in the statically lowered plan of ``query``.

    The query is translated to cudf-polars IR, input statistics are collected
    using a two-thread pool, and the IR is lowered with ``engine``'s config.
    Returns ``None`` without lowering when ``engine`` has dynamic planning
    enabled.
    """
    options = ConfigOptions.from_polars_engine(engine)
    if options.executor.dynamic_planning is not None:
        return None
    translated = Translator(query._ldf.visit(), engine).translate_ir()
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
        stats = collect_statistics(translated, options, pool)
        lowered, _ = lower_ir_graph(translated, options, stats)
    # Summing booleans counts the matching nodes.
    return sum(isinstance(node, Shuffle) for node in traversal([lowered]))
def maybe_verify_plan(
    args: argparse.Namespace, query: pl.LazyFrame, engine: pl.GPUEngine
) -> None:
    """Print static-plan diagnostics for ``query`` when requested on the CLI.

    With ``--verify-plan``, this prints how many ``Shuffle`` nodes the lowered
    plan contains for ``engine``, or a skip message under dynamic planning.
    With ``--check-broadcast-sensitivity``, it also re-plans with dynamic
    planning disabled and ``broadcast_limit=--sensitivity-broadcast-limit``,
    then prints that shuffle count. The sensitivity check only runs together
    with ``--verify-plan``. If it is requested without ``--verify-plan``, a
    note is printed instead of silently ignoring it.
    """
    if not args.verify_plan:
        if args.check_broadcast_sensitivity:
            print("Broadcast sensitivity check skipped: it requires --verify-plan.")
        return

    shuffle_count = count_shuffle_nodes(query, engine)
    if shuffle_count is None:
        print("Plan verification skipped: dynamic planning is enabled.")
    else:
        print(f"Lowered Shuffle node count: {shuffle_count}")
        if shuffle_count > 0:
            print("Plan check: shuffle/hash join path detected.")
        else:
            print("Plan check: no shuffle nodes detected (likely broadcast join).")

    if not args.check_broadcast_sensitivity:
        return

    # Re-plan the same query with static planning and a (by default very
    # large) broadcast limit so the two shuffle counts can be compared.
    streaming_options = StreamingOptions._from_argparse(args)
    executor_options = streaming_options.to_executor_options()
    executor_options["dynamic_planning"] = None
    executor_options["broadcast_limit"] = args.sensitivity_broadcast_limit
    engine_options = streaming_options.to_engine_options()
    engine_options.setdefault("raise_on_fail", True)
    sensitivity_engine = pl.GPUEngine(
        executor="streaming",
        executor_options=executor_options,
        **engine_options,
    )
    sensitivity_shuffle_count = count_shuffle_nodes(query, sensitivity_engine)
    print(
        "Sensitivity check shuffle count "
        f"(broadcast_limit={args.sensitivity_broadcast_limit}): "
        f"{sensitivity_shuffle_count}"
    )
def _validate_benchmark_args(args: argparse.Namespace) -> None:
    """Reject unsupported or contradictory options before any expensive work."""
    if args.frontend in {"dask", "duckdb", "polars-cpu"}:
        raise NotImplementedError(
            f"--frontend {args.frontend!r} is not implemented in this microbenchmark. "
            "Use in-memory, ray, or spmd."
        )
    if args.iterations < 1:
        # The timing summary (min/max/mean) needs at least one sample.
        raise ValueError("--iterations must be >= 1.")
    if args.connect is not None and args.num_gpus is not None:
        raise ValueError("--num-gpus cannot be used with --connect.")
    if args.frontend != "ray" and (
        args.connect is not None or args.num_gpus is not None
    ):
        # Only the Ray path consumes these options. Fail loudly rather than
        # silently ignoring them.
        raise ValueError(
            "--connect and --num-gpus are only used by --frontend ray in this "
            "microbenchmark."
        )


@contextlib.contextmanager
def _frontend_engine(
    args: argparse.Namespace, planning_engine: pl.GPUEngine
) -> Iterator[Any]:
    """Yield the execution engine for ``args.frontend`` and close it on exit.

    ``in-memory`` reuses ``planning_engine``. ``ray`` and ``spmd`` enter a
    ``RayEngine`` / ``SPMDEngine`` context built from the ``StreamingOptions``
    CLI arguments.
    """
    if args.frontend == "in-memory":
        yield planning_engine
        return

    streaming_options = StreamingOptions._from_argparse(args)
    # Both multi-GPU engines are constructed without the ``cluster`` executor
    # option. Copy first so the dict returned by ``to_executor_options`` is
    # not mutated.
    executor_options = dict(streaming_options.to_executor_options())
    executor_options.pop("cluster", None)
    engine_options = streaming_options.to_engine_options()
    engine_options.setdefault("raise_on_fail", True)

    if args.frontend == "ray":
        from cudf_polars.engine.ray import RayEngine

        ray_init_options: dict[str, object] = {}
        if args.connect is not None:
            ray_init_options["address"] = args.connect
        if args.num_gpus is not None:
            ray_init_options["num_gpus"] = args.num_gpus
        # NOTE(review): this unconditionally runs Ray workers under Nsight
        # profiling. That presumably requires ``nsys`` on every worker and may
        # perturb the measured timings.
        ray_init_options["runtime_env"] = {
            "nsight": {
                "python-backtrace": "cuda",
                "python-sampling": "true",
                "trace": "cuda,osrt,nvtx,python-gil,ucx",
            }
        }
        with RayEngine(
            rapidsmpf_options=streaming_options.to_rapidsmpf_options(),
            executor_options=executor_options,
            engine_options=engine_options,
            ray_init_options=ray_init_options,
        ) as engine:
            yield engine
    elif args.frontend == "spmd":
        from cudf_polars.engine.spmd import SPMDEngine

        with SPMDEngine(
            rapidsmpf_options=streaming_options.to_rapidsmpf_options(),
            executor_options=executor_options,
            engine_options=engine_options,
        ) as engine:
            yield engine
    else:
        raise AssertionError(f"Unexpected frontend {args.frontend!r}")


def run_benchmark(args: argparse.Namespace) -> None:
    """Run the join benchmark described by ``args`` and print timings.

    The options are validated up front. The function then prints the resolved
    executor options, builds the query, and optionally prints plan
    diagnostics. Finally it collects the query ``args.iterations`` times on
    the selected frontend. For each iteration it prints the elapsed time and
    the row count, followed by a min/max/mean summary. With ``spmd``, ranks
    other than 0 run the same collects but print no timing output.

    Raises:
        NotImplementedError: for ``--frontend`` dask, duckdb or polars-cpu.
        ValueError: for ``--iterations < 1``, or when ``--connect`` /
            ``--num-gpus`` are combined or used outside ``--frontend ray``.
    """
    # Validate first so bad options fail before data generation and planning.
    _validate_benchmark_args(args)

    planning_engine = create_engine(args)
    print("Resolved streaming executor options:")
    print(
        json.dumps(
            _executor_summary(planning_engine),
            default=str,
            indent=2,
            sort_keys=True,
        )
    )

    durations: list[float] = []
    with build_query(args) as query:
        maybe_verify_plan(args, query, planning_engine)
        with _frontend_engine(args, planning_engine) as engine:
            if args.frontend == "spmd" and getattr(engine, "rank", 0) != 0:
                # Non-root ranks still execute the query but don't print
                # benchmark output.
                for _ in range(args.iterations):
                    query.collect(engine=engine)
                return
            for i in range(args.iterations):
                annotation = (
                    contextlib.nullcontext()
                    if nvtx is None
                    else nvtx.annotate(
                        message=f"shuffle-join iteration {i}",
                        domain="cudf_polars",
                        color="green",
                    )
                )
                with annotation:
                    # perf_counter is the clock intended for measuring durations.
                    start = time.perf_counter()
                    result = query.collect(engine=engine)
                    durations.append(time.perf_counter() - start)
                print(f"Iteration {i}: {durations[-1]:.4f}s, rows={result.height}")

    mean_duration = sum(durations) / len(durations)
    print(
        "Timing summary (seconds): "
        f"min={min(durations):.4f}, max={max(durations):.4f}, mean={mean_duration:.4f}"
    )
def main() -> None:
    """Command-line entry point: parse ``sys.argv`` and run the benchmark."""
    run_benchmark(build_parser().parse_args())


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment