#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0
"""
Analyze an Nsight Systems report (``.nsys-rep``) exported to SQLite.
Subcommands:
- ``summary`` — per IR node type: summed host NVTX time and attributed GPU
kernel time in the chosen NVTX domain (default ``cudf_polars``).
- ``io`` — libkvikio ``RemoteHandle::read`` ranges: payload bytes and rough
throughput (GB/s) from NVTX in domain ``libkvikio``.
"""
from __future__ import annotations
import argparse
import re
import sqlite3
import struct
import subprocess
import sys
from collections import defaultdict
from pathlib import Path
from rich.console import Console
from rich.table import Table
# IR node labels from nvtx_annotate_cudf_polars on do_evaluate in dsl/ir.py
DEFAULT_IR_LABELS: frozenset[str] = frozenset(
{
"Scan",
"Sink",
"Cache",
"DataFrameScan",
"Select",
"Reduce",
"Rolling",
"GroupBy",
"ConditionalJoin",
"Join",
"HStack",
"Distinct",
"Sort",
"Slice",
"Filter",
"Projection",
"MergeSorted",
"MapFunction",
"Union",
"HConcat",
"Empty",
}
)
DEFAULT_EXCLUDE_LABELS: frozenset[str] = frozenset({"ConvertIR", "ExecuteIR"})
# nvtxPayloadType_t (nvToolsExt.h): kvikio uses UINT64 for read size in KVIKIO_NVTX_FUNC_RANGE(size).
NVTX_PAYLOAD_TYPE_UNSIGNED_INT64 = 1
NVTX_PAYLOAD_TYPE_INT64 = 2
DEFAULT_KVIKIO_NVTX_DOMAIN = "libkvikio"
# RemoteHandle::read uses __PRETTY_FUNCTION__ (see kvikio remote_handle.cpp).
KVIKIO_REMOTE_READ_SUBSTR = "RemoteHandle::read"
QUERY_ITER_RE = re.compile(
r"^Query\s+(\d+)\s+-\s+Iteration\s+(\d+)\s*$", re.IGNORECASE
)
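# e.g. QUERY_ITER_RE.match("Query 3 - Iteration 0") captures ("3", "0"),
# while plain IR labels such as "Select" do not match and are skipped.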
def _default_sqlite_path(nsys_rep: Path) -> Path:
return nsys_rep.with_suffix(".sqlite")
def export_sqlite(nsys_rep: Path, sqlite_out: Path, *, force: bool) -> None:
if sqlite_out.exists() and not force:
return
sqlite_out.parent.mkdir(parents=True, exist_ok=True)
cmd = [
"nsys",
"export",
"--type",
"sqlite",
"--force-overwrite=true",
"-q",
"true",
"-o",
str(sqlite_out),
str(nsys_rep),
]
subprocess.run(cmd, check=True)
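# The subprocess call above is equivalent to running, for example:
#   nsys export --type sqlite --force-overwrite=true -q true -o report.sqlite report.nsys-rep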
def _nvtx_range_event_types(conn: sqlite3.Connection) -> list[int]:
cur = conn.execute(
"""
SELECT id FROM ENUM_NSYS_EVENT_TYPE
WHERE name IN (
'NvtxPushPopRange',
'NvtxStartEndRange',
'NvtxtPushPopRange',
'NvtxtStartEndRange'
)
"""
)
return [row[0] for row in cur.fetchall()]
def _domain_ids_for_name(conn: sqlite3.Connection, domain: str) -> list[int]:
cur = conn.execute(
"""
SELECT DISTINCT domainId FROM NVTX_EVENTS
WHERE eventType = (SELECT id FROM ENUM_NSYS_EVENT_TYPE WHERE name = 'NvtxDomainCreate' LIMIT 1)
AND text = ?
""",
(domain,),
)
return [row[0] for row in cur.fetchall()]
def _launch_name_ids(conn: sqlite3.Connection) -> list[int]:
cur = conn.execute(
"""
SELECT id FROM StringIds
WHERE value LIKE 'cudaLaunchKernel%' OR value LIKE 'cudaLaunch%'
"""
)
return [row[0] for row in cur.fetchall()]
def load_query_iteration_events_in_domain(
conn: sqlite3.Connection,
event_types: list[int],
domain_ids: list[int],
) -> list[tuple[int, int, int, int]]:
"""
``Query N - Iteration M`` NVTX ranges in the given domain.
Returns ``(start, end, query_num, iteration_num)`` per range (nanoseconds).
"""
cur = conn.execute(
"""
SELECT n.start, n.end, COALESCE(t.value, n.text) AS lbl
FROM NVTX_EVENTS n
LEFT JOIN StringIds t ON n.textId = t.id
WHERE n.eventType IN ({})
AND n.domainId IN ({})
AND n.end IS NOT NULL
""".format(
",".join("?" * len(event_types)),
",".join("?" * len(domain_ids)),
),
[*event_types, *domain_ids],
)
out: list[tuple[int, int, int, int]] = []
for start, end, lbl in cur:
if not lbl:
continue
m = QUERY_ITER_RE.match(lbl.strip())
if not m:
continue
out.append((int(start), int(end), int(m.group(1)), int(m.group(2))))
return out
def query_iteration_wall_span_ns(slices: list[tuple[int, int]]) -> int | None:
"""Elapsed wall time: ``max(end) - min(start)`` over ranges, or ``None`` if empty."""
if not slices:
return None
return max(e for _, e in slices) - min(s for s, _ in slices)
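# Worked example (hypothetical timestamps, ns): slices [(100, 250), (200, 400)]
# span 400 - 100 = 300 ns; any gaps between the ranges are included in the span.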
def filter_windows_for_cli(
qi_events: list[tuple[int, int, int, int]],
*,
query_id: int | None,
iteration: int | None,
) -> tuple[list[tuple[int, int]] | None, list[tuple[int, int]]]:
"""
IR filter windows (or ``None``) and wall-clock slices for reporting.
When ``query_id`` / ``iteration`` are set, both use the same filtered
``(start, end)`` list. Otherwise IR filter is disabled and wall slices are
all query-iteration ranges in the domain.
"""
if query_id is not None or iteration is not None:
windows = [
(s, e)
for s, e, nq, ni in qi_events
if (query_id is None or nq == query_id)
and (iteration is None or ni == iteration)
]
if not windows:
raise SystemExit(
"No NVTX ranges matched the query/iteration filter in the requested domain "
f"(query={query_id!r}, iteration={iteration!r})."
)
return windows, windows
all_slices = [(s, e) for s, e, _, _ in qi_events]
return None, all_slices
def intervals_overlap(a0: int, a1: int, b0: int, b1: int) -> bool:
return a0 < b1 and b0 < a1
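# Intervals are half-open: (0, 10) and (10, 20) do not overlap (b0 < a1 fails),
# while (0, 11) and (10, 20) do.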
def filter_by_windows(
rows: list[tuple[int, int, int, str]],
windows: list[tuple[int, int]] | None,
) -> list[tuple[int, int, int, str]]:
if not windows:
return rows
out: list[tuple[int, int, int, str]] = []
for start, end, tid, label in rows:
if any(intervals_overlap(start, end, w0, w1) for w0, w1 in windows):
out.append((start, end, tid, label))
return out
def _table_columns(conn: sqlite3.Connection, table: str) -> frozenset[str]:
cur = conn.execute(f'PRAGMA table_info("{table}")')
return frozenset(str(row[1]) for row in cur.fetchall())
def decode_nvtx_payload_uint64(
raw: object,
payload_type: int | None,
) -> int | None:
"""
Interpret Nsight-exported NVTX payload as an unsigned byte count (kvikio style).
``payloadType`` follows ``nvtxPayloadType_t``; kvikio passes ``size`` as UINT64.
"""
if raw is None:
return None
if payload_type is not None and payload_type not in (
0, # UNKNOWN: still try if raw is present (some exports omit type)
NVTX_PAYLOAD_TYPE_UNSIGNED_INT64,
NVTX_PAYLOAD_TYPE_INT64,
):
return None
if isinstance(raw, int):
if raw < 0 and payload_type == NVTX_PAYLOAD_TYPE_INT64:
return None
return int(raw) & 0xFFFFFFFFFFFFFFFF
if isinstance(raw, (bytes, bytearray, memoryview)):
buf = bytes(raw)
if len(buf) >= 8:
return struct.unpack_from("<Q", buf, 0)[0]
if len(buf) >= 4 and payload_type in (None, 0, 4, 5, 6):
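            # nvtxPayloadType_t values 4/5/6 are the 32-bit variants
            # (UNSIGNED_INT32 / INT32 / FLOAT); read 4 little-endian bytes.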
return int(struct.unpack_from("<I", buf, 0)[0])
return None
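# Illustrative decodes (hypothetical inputs): an SQLite INTEGER of -1 with
# payload_type NVTX_PAYLOAD_TYPE_UNSIGNED_INT64 masks to 2**64 - 1, and an
# 8-byte blob struct.pack("<Q", 4096) decodes to 4096.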
def load_kvikio_remote_read_events(
conn: sqlite3.Connection,
*,
event_types: list[int],
kvikio_domain: str,
) -> list[tuple[int, int, int, str, int | None]]:
"""
libkvikio ``RemoteHandle::read`` NVTX ranges with decoded payload bytes.
Returns ``(start_ns, end_ns, globalTid, label, nbytes_or_none)`` per event.
"""
domain_ids = _domain_ids_for_name(conn, kvikio_domain)
if not domain_ids:
return []
cur = conn.execute(
f"""
SELECT n.start, n.end, n.globalTid, COALESCE(t.value, n.text) AS lbl, n.uint64Value as payload
FROM NVTX_EVENTS n
LEFT JOIN StringIds t ON n.textId = t.id
WHERE n.eventType IN ({",".join("?" * len(event_types))})
AND n.domainId IN ({",".join("?" * len(domain_ids))})
AND n.end IS NOT NULL
""",
[*event_types, *domain_ids],
)
    out: list[tuple[int, int, int, str, int | None]] = []
for row in cur:
start, end, tid, lbl, payload = row
if not lbl:
continue
text = str(lbl).strip()
if KVIKIO_REMOTE_READ_SUBSTR not in text:
continue
        # Normalize the raw payload (SQLite INTEGER or 8-byte blob) to an unsigned count.
        nbytes = decode_nvtx_payload_uint64(payload, None)
out.append((int(start), int(end), int(tid), text, nbytes))
return out
def filter_kvikio_by_windows(
rows: list[tuple[int, int, int, str, int | None]],
windows: list[tuple[int, int]] | None,
) -> list[tuple[int, int, int, str, int | None]]:
if not windows:
return rows
out: list[tuple[int, int, int, str, int | None]] = []
for start, end, tid, lbl, nb in rows:
if any(intervals_overlap(start, end, w0, w1) for w0, w1 in windows):
out.append((start, end, tid, lbl, nb))
return out
def summarize_kvikio_reads(
rows: list[tuple[int, int, int, str, int | None]],
) -> tuple[int, int, int, int]:
"""
Returns ``(count_with_payload, missing_payload, total_bytes, total_duration_ns)``.
    Duration sums (end - start) only over ranges that carry a payload;
    overlapping ranges double-count time.
"""
total_b = 0
total_ns = 0
missing = 0
for start, end, _tid, _lbl, nb in rows:
dur = end - start
if dur <= 0:
continue
if nb is None:
missing += 1
continue
total_b += nb
total_ns += dur
have = len(rows) - missing
return have, missing, total_b, total_ns
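# Sketch of the downstream throughput math (hypothetical numbers): two reads of
# 1e9 bytes taking 0.5e9 ns each give total_b=2e9 and total_ns=1e9, so the
# aggregate figure printed below is (2e9 / 1e9) / (1e9 / 1e9) = 2.0 GB/s.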
def print_kvikio_read_throughput(
console: Console,
*,
rows: list[tuple[int, int, int, str, int | None]],
kvikio_domain: str,
filtered: bool,
) -> None:
have, missing, total_b, total_ns = summarize_kvikio_reads(rows)
console.print()
console.print(
"[bold]libkvikio RemoteHandle::read[/bold] "
f"(domain [cyan]{kvikio_domain!r}[/cyan]"
+ (", filtered to query/iteration windows" if filtered else "")
+ ")"
)
if not rows:
console.print(" No matching NVTX ranges (trace may lack libkvikio or read() was not used).")
return
console.print(f" Matching ranges: {len(rows)} (with payload bytes: {have}, missing payload: {missing})")
if have == 0 or total_ns <= 0:
console.print(
" Cannot compute throughput (no UINT64 payload on these events, or NVTX_EVENTS "
"export has no payload column — inspect with `sqlite3 … '.schema NVTX_EVENTS'`)."
)
return
gb = total_b / 1e9
s = total_ns / 1e9
gbs = gb / s if s > 0 else 0.0
console.print(f" Sum of read sizes: {total_b} bytes ({gb:.6f} GB)")
console.print(
f" Sum of range durations: {format_ns(total_ns)} "
"(sum over events; overlaps on multiple threads inflate this vs wall time)"
)
console.print(f" [bold]Aggregate bytes / sum(duration): {gbs:.4f} GB/s[/bold]")
if len(rows) <= 20:
console.print(" Per-event (bytes, duration, bytes/s):")
        for start, end, tid, _lbl, nb in sorted(rows, key=lambda r: r[0]):
dur = end - start
if nb is None or dur <= 0:
console.print(f" tid={tid} {format_ns(dur)} payload=n/a")
else:
# GB/s = (bytes/1e9) / (ns/1e9) = bytes/ns with bytes and ns as given.
per_gbs = nb / dur
console.print(
f" tid={tid} bytes={nb} dur={format_ns(dur)} ~{per_gbs:.4f} GB/s"
)
def load_ir_nvtx_rows(
conn: sqlite3.Connection,
*,
domain_ids: list[int],
event_types: list[int],
ir_only: bool,
ir_labels: frozenset[str],
exclude_labels: frozenset[str],
) -> list[tuple[int, int, int, str]]:
if not domain_ids:
raise SystemExit("No NvtxDomainCreate events found for the requested domain name.")
cur = conn.execute(
"""
SELECT n.start, n.end, n.globalTid, COALESCE(t.value, n.text) AS lbl
FROM NVTX_EVENTS n
LEFT JOIN StringIds t ON n.textId = t.id
WHERE n.eventType IN ({})
AND n.domainId IN ({})
AND n.end IS NOT NULL
""".format(
",".join("?" * len(event_types)),
",".join("?" * len(domain_ids)),
),
[*event_types, *domain_ids],
)
rows: list[tuple[int, int, int, str]] = []
for start, end, tid, lbl in cur:
if lbl is None:
lbl = ""
lbl = lbl.strip()
if not lbl:
continue
if lbl in exclude_labels:
continue
if ir_only and lbl not in ir_labels:
continue
rows.append((start, end, tid, lbl))
return rows
def load_kernels_with_launch(
conn: sqlite3.Connection, launch_name_ids: list[int]
) -> list[tuple[int, int, int, int]]:
"""(kernel_start, kernel_end, launch_start, globalTid) per kernel."""
if not launch_name_ids:
return []
placeholders = ",".join("?" * len(launch_name_ids))
cur = conn.execute(
f"""
SELECT k.start, k.end, r.start, r.globalTid
FROM CUPTI_ACTIVITY_KIND_KERNEL k
INNER JOIN CUPTI_ACTIVITY_KIND_RUNTIME r
ON k.correlationId = r.correlationId
WHERE r.nameId IN ({placeholders})
AND r.globalTid IS NOT NULL
""",
launch_name_ids,
)
return [(int(a), int(b), int(c), int(d)) for a, b, c, d in cur.fetchall()]
def critical_path_innermost_ns(
rows: list[tuple[int, int, int, str]],
) -> defaultdict[str, int]:
"""
Per ``globalTid``, wall time where each IR label is the innermost active
range (NVTX stack top). Sums across threads.
    Events are swept in time order; at equal timestamps, range ends are
    processed before starts, with inner ends and outer starts ordered first so
    LIFO nesting is respected. Overlapping non-nested intervals on one thread
    are healed by popping until the expected uid is on top (this may distort
    rare malformed traces).
"""
by_tid: dict[int, list[tuple[int, int, str, int]]] = defaultdict(list)
for i, (s, e, tid, lab) in enumerate(rows):
if e <= s:
continue
by_tid[tid].append((s, e, lab, i))
out: defaultdict[str, int] = defaultdict(int)
for _tid, ivs in by_tid.items():
evs: list[tuple[int, int, int, int, str]] = []
for s, e, lab, uid in ivs:
evs.append((s, 1, -e, uid, lab))
evs.append((e, 0, -s, uid, lab))
evs.sort(key=lambda x: (x[0], x[1], x[2], x[3]))
stack: list[tuple[int, str]] = []
prev_t: int | None = None
for t, phase, _prio, uid, lab in evs:
if prev_t is not None and t > prev_t and stack:
out[stack[-1][1]] += t - prev_t
if phase == 1:
stack.append((uid, lab))
else:
while stack and stack[-1][0] != uid:
stack.pop()
if stack:
stack.pop()
prev_t = t
return out
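# Worked example (one thread, hypothetical ranges): an outer "Join" [0, 100)
# with a nested "Select" [20, 60) credits 40 ns to Select (while it is the
# stack top) and the remaining 60 ns to Join, summing to the thread's wall time.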
def attribute_device_by_innermost_launch(
nvtx_rows: list[tuple[int, int, int, str]],
kernels: list[tuple[int, int, int, int]],
) -> defaultdict[str, int]:
"""
For each kernel, pick the active NVTX range on the same thread with the
largest ``start`` still <= launch_start (innermost at launch). Add full
kernel duration to that label only.
"""
by_tid: dict[int, list[tuple[int, int, str]]] = defaultdict(list)
for start, end, tid, label in nvtx_rows:
by_tid[tid].append((start, end, label))
for tid in by_tid:
by_tid[tid].sort(key=lambda x: x[0])
device_ns: defaultdict[str, int] = defaultdict(int)
for ks, ke, launch_start, gtid in kernels:
candidates = by_tid.get(gtid)
if not candidates:
continue
best: tuple[int, int, str] | None = None
for s, e, lab in candidates:
if s <= launch_start < e:
if best is None or s > best[0]:
best = (s, e, lab)
if best is not None:
device_ns[best[2]] += ke - ks
return device_ns
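# e.g. a kernel whose launch call at t=30 falls inside both "Join" [0, 100) and
# nested "Select" [20, 60) on the launching thread is attributed to "Select"
# (largest start <= 30), and its full duration counts toward that label only.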
def format_ns(ns: float) -> str:
if ns >= 1e9:
return f"{ns / 1e9:.3f} s"
if ns >= 1e6:
return f"{ns / 1e6:.3f} ms"
if ns >= 1e3:
return f"{ns / 1e3:.3f} µs"
return f"{ns:.0f} ns"
def _export_and_open_sqlite(
nsys_rep: Path,
sqlite_path: Path,
*,
no_export: bool,
force_export: bool,
) -> int:
"""Run ``nsys export`` if needed; return 0 or a process exit code."""
if not no_export:
try:
export_sqlite(nsys_rep, sqlite_path, force=force_export)
except subprocess.CalledProcessError as e:
print(
"nsys export failed. If the report is newer than your nsys build, "
"upgrade Nsight Systems or use a matching ``nsys``.\n"
f"Command exit code: {e.returncode}",
file=sys.stderr,
)
return e.returncode or 1
elif not sqlite_path.is_file():
print(f"SQLite not found (--no-export): {sqlite_path}", file=sys.stderr)
return 1
return 0
def _add_export_args(p: argparse.ArgumentParser) -> None:
p.add_argument(
"report",
type=Path,
help="Path to .nsys-rep (SQLite is created beside it by default)",
)
p.add_argument(
"--sqlite",
type=Path,
default=None,
help="SQLite output path (default: <report>.sqlite)",
)
p.add_argument(
"--no-export",
action="store_true",
help="Do not run nsys export; require existing SQLite",
)
p.add_argument(
"--force-export",
action="store_true",
help="Re-run nsys export even if SQLite exists",
)
def run_summary(
conn: sqlite3.Connection,
console: Console,
*,
nsys_rep: Path,
sqlite_path: Path,
domain: str,
query_id: int | None,
iteration: int | None,
all_types: bool,
exclude: frozenset[str],
) -> int:
"""IR node host/device table. Returns 0 on success."""
event_types = _nvtx_range_event_types(conn)
if not event_types:
print("Could not resolve NVTX range event types from ENUM_NSYS_EVENT_TYPE.", file=sys.stderr)
return 1
domain_ids = _domain_ids_for_name(conn, domain)
if not domain_ids:
print(
f"No NvtxDomainCreate rows with text={domain!r}. "
"Try --domain rapidsmpf or check the trace.",
file=sys.stderr,
)
return 1
qi_events = load_query_iteration_events_in_domain(conn, event_types, domain_ids)
windows, wall_query_slices = filter_windows_for_cli(
qi_events, query_id=query_id, iteration=iteration
)
query_wall_ns = query_iteration_wall_span_ns(wall_query_slices)
rows = load_ir_nvtx_rows(
conn,
domain_ids=domain_ids,
event_types=event_types,
ir_only=not all_types,
ir_labels=DEFAULT_IR_LABELS,
exclude_labels=exclude,
)
rows = filter_by_windows(rows, windows)
host_by_label: defaultdict[str, int] = defaultdict(int)
for start, end, _tid, label in rows:
host_by_label[label] += end - start
launch_ids = _launch_name_ids(conn)
kernels = load_kernels_with_launch(conn, launch_ids)
device_by_label = attribute_device_by_innermost_launch(rows, kernels)
cp_by_label = critical_path_innermost_ns(rows)
labels = sorted(set(host_by_label) | set(device_by_label) | set(cp_by_label))
total_host = sum(host_by_label.values())
total_device = sum(device_by_label.values())
wall_start = min((r[0] for r in rows), default=None)
wall_end = max((r[1] for r in rows), default=None)
wall_ns = (wall_end - wall_start) if wall_start is not None else 0
print(f"Report: {nsys_rep}")
print(f"SQLite: {sqlite_path}")
print(f"Domain: {domain!r} (domainId(s) {domain_ids})")
if windows is not None:
print(
f"Query filter: query={query_id!r} iteration={iteration!r} "
f"({len(windows)} window(s))"
)
print(f"IR ranges: {len(rows)}")
if query_wall_ns is not None:
n_qi = len(wall_query_slices)
print(
f"Query/iteration wall: {format_ns(query_wall_ns)} "
f"({n_qi} top-level NVTX range(s) in domain {domain!r}; "
"max(end) - min(start))"
)
else:
print(
f"Query/iteration wall: n/a (no 'Query N - Iteration M' NVTX in domain {domain!r})"
)
print(f"IR wall span: {format_ns(wall_ns)} (min start → max end of included IR NVTX)")
print()
print(
"Host time = sum of (end - start) per NVTX range instance (inclusive; "
"nested parent/child ranges double-count wall time on CPU)."
)
print(
"Device time = sum of kernel durations attributed to the innermost IR range "
"active on the launching thread at cuda launch (avoids double-count across labels)."
)
print(
"Concurrent CUDA streams can run kernels in parallel; summed device time can "
"exceed wall time. Host/Device % columns use totals over this table."
)
print()
table = Table(show_header=True, header_style="bold", padding=(0, 1))
table.add_column("Node Type", no_wrap=True)
table.add_column("Host Time", justify="right")
table.add_column("Host %", justify="right")
table.add_column("Device Time", justify="right")
table.add_column("Device %", justify="right")
for lab in labels:
h = host_by_label.get(lab, 0)
d = device_by_label.get(lab, 0)
hp = 100.0 * h / total_host if total_host else 0.0
dp = 100.0 * d / total_device if total_device else 0.0
table.add_row(
lab,
format_ns(h),
f"{hp:.1f}%",
format_ns(d),
f"{dp:.1f}%",
)
table.add_row(
"TOTAL",
format_ns(total_host),
"100.0%",
format_ns(total_device),
"100.0%",
style="bold",
)
console.print(table)
return 0
def run_io(
conn: sqlite3.Connection,
console: Console,
*,
nsys_rep: Path,
sqlite_path: Path,
kvikio_domain: str,
query_domain: str,
query_id: int | None,
iteration: int | None,
) -> int:
"""libkvikio read throughput summary."""
print(f"Report: {nsys_rep}")
print(f"SQLite: {sqlite_path}")
event_types = _nvtx_range_event_types(conn)
if not event_types:
print("Could not resolve NVTX range event types from ENUM_NSYS_EVENT_TYPE.", file=sys.stderr)
return 1
return run_kvikio_read_stats(
conn,
console,
event_types=event_types,
kvikio_domain=kvikio_domain,
query_domain=query_domain,
query_id=query_id,
iteration=iteration,
)
def run_kvikio_read_stats(
conn: sqlite3.Connection,
console: Console,
*,
event_types: list[int],
kvikio_domain: str,
query_domain: str,
query_id: int | None,
iteration: int | None,
) -> int:
"""
Print libkvikio read throughput summary. Returns 0, or 1 if query windows
were requested but ``query_domain`` has no matching NVTX domain.
"""
windows: list[tuple[int, int]] | None = None
filtered = False
if query_id is not None or iteration is not None:
domain_ids = _domain_ids_for_name(conn, query_domain)
if not domain_ids:
print(
f"No NvtxDomainCreate rows with text={query_domain!r}; "
"cannot apply --query/--iteration windows.",
file=sys.stderr,
)
return 1
qi_events = load_query_iteration_events_in_domain(
conn, event_types, domain_ids
)
windows, _ = filter_windows_for_cli(
qi_events, query_id=query_id, iteration=iteration
)
filtered = True
rows = load_kvikio_remote_read_events(
conn, event_types=event_types, kvikio_domain=kvikio_domain
)
rows = filter_kvikio_by_windows(rows, windows)
print_kvikio_read_throughput(
console,
rows=rows,
kvikio_domain=kvikio_domain,
filtered=filtered,
)
return 0
def main(argv: list[str] | None = None) -> int:
parser = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
sub = parser.add_subparsers(dest="command", required=True)
p_sum = sub.add_parser(
"summary",
help="Per IR node: host NVTX time and attributed GPU kernel time.",
)
_add_export_args(p_sum)
p_sum.add_argument(
"--domain",
default="cudf_polars",
help="NVTX domain for IR ranges and for 'Query N - Iteration M' windows",
)
p_sum.add_argument(
"--query",
type=int,
default=None,
help="Only include IR ranges overlapping an NVTX range 'Query N - Iteration M'",
)
p_sum.add_argument(
"--iteration",
type=int,
default=None,
help="With --query, filter by iteration index M (0-based as in NVTX text)",
)
p_sum.add_argument(
"--all-types",
action="store_true",
help="Include every range in the domain, not only known IR node labels",
)
p_sum.add_argument(
"--exclude",
default=",".join(sorted(DEFAULT_EXCLUDE_LABELS)),
help="Comma-separated NVTX labels to skip (default: ConvertIR,ExecuteIR)",
)
p_io = sub.add_parser(
"io",
help="libkvikio RemoteHandle::read: bytes from NVTX payload and GB/s estimate.",
)
_add_export_args(p_io)
p_io.add_argument(
"--domain",
default="cudf_polars",
help="NVTX domain containing 'Query N - Iteration M' ranges (for --query / --iteration)",
)
p_io.add_argument(
"--kvikio-domain",
default=DEFAULT_KVIKIO_NVTX_DOMAIN,
help=f"NVTX domain for kvikio reads (default: {DEFAULT_KVIKIO_NVTX_DOMAIN!r})",
)
p_io.add_argument(
"--query",
type=int,
default=None,
help="Only include reads overlapping a 'Query N - Iteration M' range in --domain",
)
p_io.add_argument(
"--iteration",
type=int,
default=None,
help="With --query, filter by iteration index M",
)
args = parser.parse_args(argv)
nsys_rep = args.report.expanduser().resolve()
if not nsys_rep.is_file():
print(f"Report not found: {nsys_rep}", file=sys.stderr)
return 1
sqlite_path = (args.sqlite or _default_sqlite_path(nsys_rep)).expanduser().resolve()
rc = _export_and_open_sqlite(
nsys_rep,
sqlite_path,
no_export=args.no_export,
force_export=args.force_export,
)
if rc != 0:
return rc
console = Console()
conn = sqlite3.connect(f"file:{sqlite_path}?mode=ro", uri=True)
try:
if args.command == "summary":
exclude = frozenset(x.strip() for x in args.exclude.split(",") if x.strip())
return run_summary(
conn,
console,
nsys_rep=nsys_rep,
sqlite_path=sqlite_path,
domain=args.domain,
query_id=args.query,
iteration=args.iteration,
all_types=args.all_types,
exclude=exclude,
)
if args.command == "io":
return run_io(
conn,
console,
nsys_rep=nsys_rep,
sqlite_path=sqlite_path,
kvikio_domain=args.kvikio_domain,
query_domain=args.domain,
query_id=args.query,
iteration=args.iteration,
)
raise AssertionError(f"unknown command {args.command!r}")
finally:
conn.close()
if __name__ == "__main__":
raise SystemExit(main())