Created
August 21, 2025 14:03
-
-
Save TomAugspurger/c2237f0764ea65d7e0fa99011260d605 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Outputs the following (partial) IR for query 3. Note the multiple `Select` nodes. | |
Sort( | |
schema={'l_orderkey': <DataType(polars=Int64, plc=<type_id.INT64: 4>)>, 'revenue': <DataType(polars=Float64, plc=<type_id.FLOAT64: 10>)>, 'o_orderdate': <DataType(polars=Datetime(time_unit='ms', time_zone=None), plc=<type_id.TIMESTAMP_MILLISECONDS: 14>)>, 'o_shippriority': <DataType(polars=Int32, plc=<type_id.INT32: 3>)>}, | |
by=(NamedExpr(revenue, Col(<DataType(polars=Float64, plc=<type_id.FLOAT64: 10>)>, 'revenue')), NamedExpr(o_orderdate, Col(<DataType(polars=Datetime(time_unit='ms', time_zone=None), plc=<type_id.TIMESTAMP_MILLISECONDS: 14>)>, 'o_orderdate'))), | |
order=(<order.DESCENDING: 1>, <order.ASCENDING: 0>), | |
null_order=(<null_order.AFTER: 0>, <null_order.BEFORE: 1>), | |
stable=False, | |
zlice=(0, 10), | |
Select( | |
schema={'l_orderkey': <DataType(polars=Int64, plc=<type_id.INT64: 4>)>, 'revenue': <DataType(polars=Float64, plc=<type_id.FLOAT64: 10>)>, 'o_orderdate': <DataType(polars=Datetime(time_unit='ms', time_zone=None), plc=<type_id.TIMESTAMP_MILLISECONDS: 14>)>, 'o_shippriority': <DataType(polars=Int32, plc=<type_id.INT32: 3>)>}, | |
exprs=( | |
NamedExpr(l_orderkey, Col(<DataType(polars=Int64, plc=<type_id.INT64: 4>)>, 'o_orderkey')), | |
NamedExpr(revenue, Col(<DataType(polars=Float64, plc=<type_id.FLOAT64: 10>)>, 'revenue')), | |
NamedExpr(o_orderdate, Col(<DataType(polars=Datetime(time_unit='ms', time_zone=None), plc=<type_id.TIMESTAMP_MILLISECONDS: 14>)>, 'o_orderdate')), | |
NamedExpr(o_shippriority, Col(<DataType(polars=Int32, plc=<type_id.INT32: 3>)>, 'o_shippriority')) | |
), | |
should_broadcast=True, | |
Select( | |
schema={'o_orderkey': <DataType(polars=Int64, plc=<type_id.INT64: 4>)>, 'o_orderdate': <DataType(polars=Datetime(time_unit='ms', time_zone=None), plc=<type_id.TIMESTAMP_MILLISECONDS: 14>)>, 'o_shippriority': <DataType(polars=Int32, plc=<type_id.INT32: 3>)>, 'revenue': <DataType(polars=Float64, plc=<type_id.FLOAT64: 10>)>}, | |
exprs=( | |
NamedExpr(o_orderkey, Col(<DataType(polars=Int64, plc=<type_id.INT64: 4>)>, 'o_orderkey')), | |
NamedExpr(o_orderdate, Col(<DataType(polars=Datetime(time_unit='ms', time_zone=None), plc=<type_id.TIMESTAMP_MILLISECONDS: 14>)>, 'o_orderdate')), | |
NamedExpr(o_shippriority, Col(<DataType(polars=Int32, plc=<type_id.INT32: 3>)>, 'o_shippriority')), | |
NamedExpr(revenue, UnaryFunction(<DataType(polars=Float64, plc=<type_id.FLOAT64: 10>)>, 'fill_null', (), Col(<DataType(polars=Float64, plc=<type_id.FLOAT64: 10>)>, 'revenue'), Literal(<DataType(polars=Float64, plc=<type_id.FLOAT64: 10>)>, 0))) | |
), | |
should_broadcast=True, | |
GroupBy( | |
schema={'o_orderkey': <DataType(polars=Int64, plc=<type_id.INT64: 4>)>, 'o_orderdate': <DataType(polars=Datetime(time_unit='ms', time_zone=None), plc=<type_id.TIMESTAMP_MILLISECONDS: 14>)>, 'o_shippriority': <DataType(polars=Int32, plc=<type_id.INT32: 3>)>, 'revenue': <DataType(polars=Float64, plc=<type_id.FLOAT64: 10>)>}, | |
keys=( | |
NamedExpr(o_orderkey, Col(<DataType(polars=Int64, plc=<type_id.INT64: 4>)>, 'o_orderkey')), | |
NamedExpr(o_orderdate, Col(<DataType(polars=Datetime(time_unit='ms', time_zone=None), plc=<type_id.TIMESTAMP_MILLISECONDS: 14>)>, 'o_orderdate')), | |
NamedExpr(o_shippriority, Col(<DataType(polars=Int32, plc=<type_id.INT32: 3>)>, 'o_shippriority')) | |
), | |
agg_requests=(NamedExpr(revenue, Agg(<DataType(polars=Float64, plc=<type_id.FLOAT64: 10>)>, 'sum', None, Col(<DataType(polars=Float64, plc=<type_id.FLOAT64: 10>)>, 'revenue'))),), | |
maintain_order=False, | |
zlice=None, | |
... | |
""" | |
import shlex | |
import polars as pl | |
import rich | |
from cudf_polars.dsl.ir import Scan | |
from cudf_polars.dsl.translate import Translator | |
from cudf_polars.dsl.traversal import traversal | |
from cudf_polars.experimental.benchmarks.pdsh import PDSHQueries | |
from cudf_polars.experimental.benchmarks.utils import ( | |
RunConfig, | |
get_executor_options, | |
parse_args, | |
) | |
from cudf_polars.utils.config import ConfigOptions | |
def main():
    """Translate PDS-H query 3 to cudf-polars IR and pretty-print it.

    The benchmark command line is hard-coded below, parsed with the
    benchmark utilities' ``parse_args``, and used to configure a
    ``pl.GPUEngine``.  The lazy query is then handed to ``Translator``
    and the resulting IR is printed with ``rich``.
    """
    # Fixed command line, tokenised the same way a shell would.
    cli_text = "--path /datasets/toaugspurger/tpch/scale-100/ --no-print-results --iterations 2 --executor streaming --scheduler distributed --n-workers=8 --protocol=ucx --shuffle=rapidsmpf --rapidsmpf-dask-statistics --no-rapidsmpf-print-statistics --rmm-async --explain 3"
    argv = shlex.split(cli_text)

    query_suite = PDSHQueries
    parsed = parse_args(argv, num_queries=22)
    # Inject the query-set name directly into the parsed namespace.
    vars(parsed)["query_set"] = "pdsh"

    config = RunConfig.from_args(parsed)
    exec_opts = get_executor_options(config, benchmark=query_suite)
    gpu_engine = pl.GPUEngine(
        raise_on_fail=True,
        executor=config.executor,
        executor_options=exec_opts,
    )

    # Build query 3 and lower it to the cudf-polars IR.
    lazy_query = query_suite.q3(config)
    visitor = lazy_query._ldf.visit()
    translator = Translator(visitor, gpu_engine)
    plan = translator.translate_ir()
    rich.print(plan)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment