Skip to content

Instantly share code, notes, and snippets.

@TomAugspurger
Created August 21, 2025 14:03
Show Gist options
  • Save TomAugspurger/c2237f0764ea65d7e0fa99011260d605 to your computer and use it in GitHub Desktop.
Save TomAugspurger/c2237f0764ea65d7e0fa99011260d605 to your computer and use it in GitHub Desktop.
"""
Outputs the following (partial) IR for query 3. Note the multiple `Select` nodes.
Sort(
schema={'l_orderkey': <DataType(polars=Int64, plc=<type_id.INT64: 4>)>, 'revenue': <DataType(polars=Float64, plc=<type_id.FLOAT64: 10>)>, 'o_orderdate': <DataType(polars=Datetime(time_unit='ms', time_zone=None), plc=<type_id.TIMESTAMP_MILLISECONDS: 14>)>, 'o_shippriority': <DataType(polars=Int32, plc=<type_id.INT32: 3>)>},
by=(NamedExpr(revenue, Col(<DataType(polars=Float64, plc=<type_id.FLOAT64: 10>)>, 'revenue')), NamedExpr(o_orderdate, Col(<DataType(polars=Datetime(time_unit='ms', time_zone=None), plc=<type_id.TIMESTAMP_MILLISECONDS: 14>)>, 'o_orderdate'))),
order=(<order.DESCENDING: 1>, <order.ASCENDING: 0>),
null_order=(<null_order.AFTER: 0>, <null_order.BEFORE: 1>),
stable=False,
zlice=(0, 10),
Select(
schema={'l_orderkey': <DataType(polars=Int64, plc=<type_id.INT64: 4>)>, 'revenue': <DataType(polars=Float64, plc=<type_id.FLOAT64: 10>)>, 'o_orderdate': <DataType(polars=Datetime(time_unit='ms', time_zone=None), plc=<type_id.TIMESTAMP_MILLISECONDS: 14>)>, 'o_shippriority': <DataType(polars=Int32, plc=<type_id.INT32: 3>)>},
exprs=(
NamedExpr(l_orderkey, Col(<DataType(polars=Int64, plc=<type_id.INT64: 4>)>, 'o_orderkey')),
NamedExpr(revenue, Col(<DataType(polars=Float64, plc=<type_id.FLOAT64: 10>)>, 'revenue')),
NamedExpr(o_orderdate, Col(<DataType(polars=Datetime(time_unit='ms', time_zone=None), plc=<type_id.TIMESTAMP_MILLISECONDS: 14>)>, 'o_orderdate')),
NamedExpr(o_shippriority, Col(<DataType(polars=Int32, plc=<type_id.INT32: 3>)>, 'o_shippriority'))
),
should_broadcast=True,
Select(
schema={'o_orderkey': <DataType(polars=Int64, plc=<type_id.INT64: 4>)>, 'o_orderdate': <DataType(polars=Datetime(time_unit='ms', time_zone=None), plc=<type_id.TIMESTAMP_MILLISECONDS: 14>)>, 'o_shippriority': <DataType(polars=Int32, plc=<type_id.INT32: 3>)>, 'revenue': <DataType(polars=Float64, plc=<type_id.FLOAT64: 10>)>},
exprs=(
NamedExpr(o_orderkey, Col(<DataType(polars=Int64, plc=<type_id.INT64: 4>)>, 'o_orderkey')),
NamedExpr(o_orderdate, Col(<DataType(polars=Datetime(time_unit='ms', time_zone=None), plc=<type_id.TIMESTAMP_MILLISECONDS: 14>)>, 'o_orderdate')),
NamedExpr(o_shippriority, Col(<DataType(polars=Int32, plc=<type_id.INT32: 3>)>, 'o_shippriority')),
NamedExpr(revenue, UnaryFunction(<DataType(polars=Float64, plc=<type_id.FLOAT64: 10>)>, 'fill_null', (), Col(<DataType(polars=Float64, plc=<type_id.FLOAT64: 10>)>, 'revenue'), Literal(<DataType(polars=Float64, plc=<type_id.FLOAT64: 10>)>, 0)))
),
should_broadcast=True,
GroupBy(
schema={'o_orderkey': <DataType(polars=Int64, plc=<type_id.INT64: 4>)>, 'o_orderdate': <DataType(polars=Datetime(time_unit='ms', time_zone=None), plc=<type_id.TIMESTAMP_MILLISECONDS: 14>)>, 'o_shippriority': <DataType(polars=Int32, plc=<type_id.INT32: 3>)>, 'revenue': <DataType(polars=Float64, plc=<type_id.FLOAT64: 10>)>},
keys=(
NamedExpr(o_orderkey, Col(<DataType(polars=Int64, plc=<type_id.INT64: 4>)>, 'o_orderkey')),
NamedExpr(o_orderdate, Col(<DataType(polars=Datetime(time_unit='ms', time_zone=None), plc=<type_id.TIMESTAMP_MILLISECONDS: 14>)>, 'o_orderdate')),
NamedExpr(o_shippriority, Col(<DataType(polars=Int32, plc=<type_id.INT32: 3>)>, 'o_shippriority'))
),
agg_requests=(NamedExpr(revenue, Agg(<DataType(polars=Float64, plc=<type_id.FLOAT64: 10>)>, 'sum', None, Col(<DataType(polars=Float64, plc=<type_id.FLOAT64: 10>)>, 'revenue'))),),
maintain_order=False,
zlice=None,
...
"""
import shlex
import polars as pl
import rich
from cudf_polars.dsl.ir import Scan
from cudf_polars.dsl.translate import Translator
from cudf_polars.dsl.traversal import traversal
from cudf_polars.experimental.benchmarks.pdsh import PDSHQueries
from cudf_polars.experimental.benchmarks.utils import (
RunConfig,
get_executor_options,
parse_args,
)
from cudf_polars.utils.config import ConfigOptions
def main():
    """Translate PDS-H query 3 into cudf-polars IR and print it.

    The benchmark command line is hard-coded: it is tokenized with
    ``shlex.split`` and handed to the benchmark argument parser, the
    resulting namespace is turned into a ``RunConfig``, and a
    ``pl.GPUEngine`` is configured from that run configuration. Query 3
    is then built, its lazy plan is translated to IR, and the IR is
    printed with ``rich.print``.
    """
    cli = "--path /datasets/toaugspurger/tpch/scale-100/ --no-print-results --iterations 2 --executor streaming --scheduler distributed --n-workers=8 --protocol=ucx --shuffle=rapidsmpf --rapidsmpf-dask-statistics --no-rapidsmpf-print-statistics --rmm-async --explain 3"
    query_set = PDSHQueries

    parsed = parse_args(shlex.split(cli), num_queries=22)
    # Patch the query-set name onto the parsed namespace in place before
    # the run configuration is built from it.
    vars(parsed)["query_set"] = "pdsh"
    run_config = RunConfig.from_args(parsed)

    executor_options = get_executor_options(run_config, benchmark=query_set)
    gpu_engine = pl.GPUEngine(
        raise_on_fail=True,
        executor=run_config.executor,
        executor_options=executor_options,
    )

    query = query_set.q3(run_config)
    # Translate the query's underlying lazy plan to IR (nothing is executed
    # by this script beyond building and printing the plan).
    translator = Translator(query._ldf.visit(), gpu_engine)
    rich.print(translator.translate_ir())
# Run the IR dump only when this file is executed as a script, so that
# importing it has no side effects.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment