Created
April 7, 2026 21:36
-
-
Save TomAugspurger/24b668b2cdc28909f77f86c610b999c6 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| # SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| """ | |
| Analyze an Nsight Systems report (``.nsys-rep``) exported to SQLite. | |
| Subcommands: | |
| - ``summary`` — per IR node type: summed host NVTX time and attributed GPU | |
| kernel time in the chosen NVTX domain (default ``cudf_polars``). | |
| - ``io`` — libkvikio ``RemoteHandle::read`` ranges: payload bytes and rough | |
| throughput (GB/s) from NVTX in domain ``libkvikio``. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import re | |
| import sqlite3 | |
| import struct | |
| import subprocess | |
| import sys | |
| from collections import defaultdict | |
| from pathlib import Path | |
| from rich.console import Console | |
| from rich.table import Table | |
# IR node labels from nvtx_annotate_cudf_polars on do_evaluate in dsl/ir.py
DEFAULT_IR_LABELS: frozenset[str] = frozenset(
    {
        "Scan",
        "Sink",
        "Cache",
        "DataFrameScan",
        "Select",
        "Reduce",
        "Rolling",
        "GroupBy",
        "ConditionalJoin",
        "Join",
        "HStack",
        "Distinct",
        "Sort",
        "Slice",
        "Filter",
        "Projection",
        "MergeSorted",
        "MapFunction",
        "Union",
        "HConcat",
        "Empty",
    }
)
# Wrapper-level ranges excluded from per-node accounting by default (see --exclude).
DEFAULT_EXCLUDE_LABELS: frozenset[str] = frozenset({"ConvertIR", "ExecuteIR"})
# nvtxPayloadType_t (nvToolsExt.h): kvikio uses UINT64 for read size in KVIKIO_NVTX_FUNC_RANGE(size).
NVTX_PAYLOAD_TYPE_UNSIGNED_INT64 = 1
NVTX_PAYLOAD_TYPE_INT64 = 2
# NVTX domain name kvikio registers its ranges under.
DEFAULT_KVIKIO_NVTX_DOMAIN = "libkvikio"
# RemoteHandle::read uses __PRETTY_FUNCTION__ (see kvikio remote_handle.cpp).
KVIKIO_REMOTE_READ_SUBSTR = "RemoteHandle::read"
# Matches top-level ranges like "Query 3 - Iteration 0"; groups are (N, M).
QUERY_ITER_RE = re.compile(
    r"^Query\s+(\d+)\s+-\s+Iteration\s+(\d+)\s*$", re.IGNORECASE
)
| def _default_sqlite_path(nsys_rep: Path) -> Path: | |
| return nsys_rep.with_suffix(".sqlite") | |
def export_sqlite(nsys_rep: Path, sqlite_out: Path, *, force: bool) -> None:
    """
    Export *nsys_rep* to SQLite via ``nsys export``.

    No-op when *sqlite_out* already exists and *force* is false; otherwise the
    parent directory is created and ``nsys`` is invoked with check=True so a
    failing export raises ``CalledProcessError``.
    """
    if sqlite_out.exists() and not force:
        return
    sqlite_out.parent.mkdir(parents=True, exist_ok=True)
    subprocess.run(
        [
            "nsys",
            "export",
            "--type",
            "sqlite",
            "--force-overwrite=true",
            "-q",
            "true",
            "-o",
            str(sqlite_out),
            str(nsys_rep),
        ],
        check=True,
    )
| def _nvtx_range_event_types(conn: sqlite3.Connection) -> list[int]: | |
| cur = conn.execute( | |
| """ | |
| SELECT id FROM ENUM_NSYS_EVENT_TYPE | |
| WHERE name IN ( | |
| 'NvtxPushPopRange', | |
| 'NvtxStartEndRange', | |
| 'NvtxtPushPopRange', | |
| 'NvtxtStartEndRange' | |
| ) | |
| """ | |
| ) | |
| return [row[0] for row in cur.fetchall()] | |
| def _domain_ids_for_name(conn: sqlite3.Connection, domain: str) -> list[int]: | |
| cur = conn.execute( | |
| """ | |
| SELECT DISTINCT domainId FROM NVTX_EVENTS | |
| WHERE eventType = (SELECT id FROM ENUM_NSYS_EVENT_TYPE WHERE name = 'NvtxDomainCreate' LIMIT 1) | |
| AND text = ? | |
| """, | |
| (domain,), | |
| ) | |
| return [row[0] for row in cur.fetchall()] | |
| def _launch_name_ids(conn: sqlite3.Connection) -> list[int]: | |
| cur = conn.execute( | |
| """ | |
| SELECT id FROM StringIds | |
| WHERE value LIKE 'cudaLaunchKernel%' OR value LIKE 'cudaLaunch%' | |
| """ | |
| ) | |
| return [row[0] for row in cur.fetchall()] | |
def load_query_iteration_events_in_domain(
    conn: sqlite3.Connection,
    event_types: list[int],
    domain_ids: list[int],
) -> list[tuple[int, int, int, int]]:
    """
    ``Query N - Iteration M`` NVTX ranges in the given domain.
    Returns ``(start, end, query_num, iteration_num)`` per range (nanoseconds).
    """
    type_marks = ",".join("?" * len(event_types))
    domain_marks = ",".join("?" * len(domain_ids))
    cur = conn.execute(
        f"""
        SELECT n.start, n.end, COALESCE(t.value, n.text) AS lbl
        FROM NVTX_EVENTS n
        LEFT JOIN StringIds t ON n.textId = t.id
        WHERE n.eventType IN ({type_marks})
        AND n.domainId IN ({domain_marks})
        AND n.end IS NOT NULL
        """,
        [*event_types, *domain_ids],
    )
    events: list[tuple[int, int, int, int]] = []
    for start, end, label in cur:
        # Skip NULL/empty labels and anything not shaped like "Query N - Iteration M".
        match = QUERY_ITER_RE.match(label.strip()) if label else None
        if match is None:
            continue
        events.append(
            (int(start), int(end), int(match.group(1)), int(match.group(2)))
        )
    return events
| def query_iteration_wall_span_ns(slices: list[tuple[int, int]]) -> int | None: | |
| """Elapsed wall time: ``max(end) - min(start)`` over ranges, or ``None`` if empty.""" | |
| if not slices: | |
| return None | |
| return max(e for _, e in slices) - min(s for s, _ in slices) | |
| def filter_windows_for_cli( | |
| qi_events: list[tuple[int, int, int, int]], | |
| *, | |
| query_id: int | None, | |
| iteration: int | None, | |
| ) -> tuple[list[tuple[int, int]] | None, list[tuple[int, int]]]: | |
| """ | |
| IR filter windows (or ``None``) and wall-clock slices for reporting. | |
| When ``query_id`` / ``iteration`` are set, both use the same filtered | |
| ``(start, end)`` list. Otherwise IR filter is disabled and wall slices are | |
| all query-iteration ranges in the domain. | |
| """ | |
| if query_id is not None or iteration is not None: | |
| windows = [ | |
| (s, e) | |
| for s, e, nq, ni in qi_events | |
| if (query_id is None or nq == query_id) | |
| and (iteration is None or ni == iteration) | |
| ] | |
| if not windows: | |
| raise SystemExit( | |
| "No NVTX ranges matched the query/iteration filter in the requested domain " | |
| f"(query={query_id!r}, iteration={iteration!r})." | |
| ) | |
| return windows, windows | |
| all_slices = [(s, e) for s, e, _, _ in qi_events] | |
| return None, all_slices | |
def intervals_overlap(a0: int, a1: int, b0: int, b1: int) -> bool:
    """True when half-open intervals ``[a0, a1)`` and ``[b0, b1)`` share any point."""
    return not (a1 <= b0 or b1 <= a0)
| def filter_by_windows( | |
| rows: list[tuple[int, int, int, str]], | |
| windows: list[tuple[int, int]] | None, | |
| ) -> list[tuple[int, int, int, str]]: | |
| if not windows: | |
| return rows | |
| out: list[tuple[int, int, int, str]] = [] | |
| for start, end, tid, label in rows: | |
| if any(intervals_overlap(start, end, w0, w1) for w0, w1 in windows): | |
| out.append((start, end, tid, label)) | |
| return out | |
| def _table_columns(conn: sqlite3.Connection, table: str) -> frozenset[str]: | |
| cur = conn.execute(f'PRAGMA table_info("{table}")') | |
| return frozenset(str(row[1]) for row in cur.fetchall()) | |
def decode_nvtx_payload_uint64(
    raw: object,
    payload_type: int | None,
) -> int | None:
    """
    Interpret Nsight-exported NVTX payload as an unsigned byte count (kvikio style).
    ``payloadType`` follows ``nvtxPayloadType_t``; kvikio passes ``size`` as UINT64.
    """
    if raw is None:
        return None
    # Type 0 (UNKNOWN) is tolerated because some exports omit the type entirely.
    acceptable_types = (0, NVTX_PAYLOAD_TYPE_UNSIGNED_INT64, NVTX_PAYLOAD_TYPE_INT64)
    if payload_type is not None and payload_type not in acceptable_types:
        return None
    if isinstance(raw, int):
        # A negative INT64 cannot be a byte count.
        if payload_type == NVTX_PAYLOAD_TYPE_INT64 and raw < 0:
            return None
        return raw & 0xFFFFFFFFFFFFFFFF
    if isinstance(raw, (bytes, bytearray, memoryview)):
        blob = bytes(raw)
        if len(blob) >= 8:
            return struct.unpack_from("<Q", blob, 0)[0]
        # Four-byte payloads: accept only for unknown or 32-bit-ish type codes.
        if len(blob) >= 4 and payload_type in (None, 0, 4, 5, 6):
            return int(struct.unpack_from("<I", blob, 0)[0])
    return None
def load_kvikio_remote_read_events(
    conn: sqlite3.Connection,
    *,
    event_types: list[int],
    kvikio_domain: str,
) -> list[tuple[int, int, int, str, int | None]]:
    """
    libkvikio ``RemoteHandle::read`` NVTX ranges with decoded payload bytes.
    Returns ``(start_ns, end_ns, globalTid, label, nbytes_or_none)`` per event.
    """
    domain_ids = _domain_ids_for_name(conn, kvikio_domain)
    if not domain_ids:
        return []
    cur = conn.execute(
        f"""
        SELECT n.start, n.end, n.globalTid, COALESCE(t.value, n.text) AS lbl, n.uint64Value as payload
        FROM NVTX_EVENTS n
        LEFT JOIN StringIds t ON n.textId = t.id
        WHERE n.eventType IN ({",".join("?" * len(event_types))})
        AND n.domainId IN ({",".join("?" * len(domain_ids))})
        AND n.end IS NOT NULL
        """,
        [*event_types, *domain_ids],
    )
    # Fixed: the accumulator annotation previously read
    # list[tuple[int, int, int, int | None]] — it omitted the ``str`` label
    # element and disagreed with the declared return type.
    out: list[tuple[int, int, int, str, int | None]] = []
    for row in cur:
        start, end, tid, lbl, payload = row
        if not lbl:
            continue
        text = str(lbl).strip()
        if KVIKIO_REMOTE_READ_SUBSTR not in text:
            continue
        # NOTE(review): payload is used as-is from ``uint64Value``;
        # ``decode_nvtx_payload_uint64`` exists in this file but is not applied
        # here — confirm whether blob payload columns need decoding for this
        # export schema.
        nbytes = payload
        out.append((int(start), int(end), int(tid), text, nbytes))
    return out
| def filter_kvikio_by_windows( | |
| rows: list[tuple[int, int, int, str, int | None]], | |
| windows: list[tuple[int, int]] | None, | |
| ) -> list[tuple[int, int, int, str, int | None]]: | |
| if not windows: | |
| return rows | |
| out: list[tuple[int, int, int, str, int | None]] = [] | |
| for start, end, tid, lbl, nb in rows: | |
| if any(intervals_overlap(start, end, w0, w1) for w0, w1 in windows): | |
| out.append((start, end, tid, lbl, nb)) | |
| return out | |
| def summarize_kvikio_reads( | |
| rows: list[tuple[int, int, int, str, int | None]], | |
| ) -> tuple[int, int, int, float]: | |
| """ | |
| Returns ``(count_with_payload, missing_payload, total_bytes, total_duration_ns)``. | |
| Duration is sum of per-range (end - start); overlapping ranges double-count time. | |
| """ | |
| total_b = 0 | |
| total_ns = 0 | |
| missing = 0 | |
| for start, end, _tid, _lbl, nb in rows: | |
| dur = end - start | |
| if dur <= 0: | |
| continue | |
| if nb is None: | |
| missing += 1 | |
| continue | |
| total_b += nb | |
| total_ns += dur | |
| have = len(rows) - missing | |
| return have, missing, total_b, total_ns | |
def print_kvikio_read_throughput(
    console: Console,
    *,
    rows: list[tuple[int, int, int, str, int | None]],
    kvikio_domain: str,
    filtered: bool,
) -> None:
    """
    Render the libkvikio read summary (counts, bytes, aggregate GB/s) to *console*.

    Fixed: the per-event header previously said "bytes/s" while the values
    printed below it are GB/s; the label now matches the printed units.
    """
    have, missing, total_b, total_ns = summarize_kvikio_reads(rows)
    console.print()
    console.print(
        "[bold]libkvikio RemoteHandle::read[/bold] "
        f"(domain [cyan]{kvikio_domain!r}[/cyan]"
        + (", filtered to query/iteration windows" if filtered else "")
        + ")"
    )
    if not rows:
        console.print(" No matching NVTX ranges (trace may lack libkvikio or read() was not used).")
        return
    console.print(f" Matching ranges: {len(rows)} (with payload bytes: {have}, missing payload: {missing})")
    if have == 0 or total_ns <= 0:
        console.print(
            " Cannot compute throughput (no UINT64 payload on these events, or NVTX_EVENTS "
            "export has no payload column — inspect with `sqlite3 … '.schema NVTX_EVENTS'`)."
        )
        return
    gb = total_b / 1e9
    s = total_ns / 1e9
    gbs = gb / s if s > 0 else 0.0
    console.print(f" Sum of read sizes: {total_b} bytes ({gb:.6f} GB)")
    console.print(
        f" Sum of range durations: {format_ns(total_ns)} "
        "(sum over events; overlaps on multiple threads inflate this vs wall time)"
    )
    console.print(f" [bold]Aggregate bytes / sum(duration): {gbs:.4f} GB/s[/bold]")
    if len(rows) <= 20:
        # Only dump per-event detail for small traces to keep output readable.
        console.print(" Per-event (bytes, duration, GB/s):")
        for start, end, tid, lbl, nb in sorted(rows, key=lambda r: r[0]):
            dur = end - start
            if nb is None or dur <= 0:
                console.print(f" tid={tid} {format_ns(dur)} payload=n/a")
            else:
                # GB/s = (bytes/1e9) / (ns/1e9) = bytes/ns with bytes and ns as given.
                per_gbs = nb / dur
                console.print(
                    f" tid={tid} bytes={nb} dur={format_ns(dur)} ~{per_gbs:.4f} GB/s"
                )
def load_ir_nvtx_rows(
    conn: sqlite3.Connection,
    *,
    domain_ids: list[int],
    event_types: list[int],
    ir_only: bool,
    ir_labels: frozenset[str],
    exclude_labels: frozenset[str],
) -> list[tuple[int, int, int, str]]:
    """
    Closed NVTX ranges in the given domain(s) as ``(start, end, globalTid, label)``.

    Blank labels and *exclude_labels* are dropped; with *ir_only*, only labels
    present in *ir_labels* survive. Raises ``SystemExit`` when *domain_ids* is empty.
    """
    if not domain_ids:
        raise SystemExit("No NvtxDomainCreate events found for the requested domain name.")
    sql = """
        SELECT n.start, n.end, n.globalTid, COALESCE(t.value, n.text) AS lbl
        FROM NVTX_EVENTS n
        LEFT JOIN StringIds t ON n.textId = t.id
        WHERE n.eventType IN ({})
        AND n.domainId IN ({})
        AND n.end IS NOT NULL
    """.format(
        ",".join("?" * len(event_types)),
        ",".join("?" * len(domain_ids)),
    )
    kept: list[tuple[int, int, int, str]] = []
    for start, end, tid, raw_label in conn.execute(sql, [*event_types, *domain_ids]):
        label = (raw_label or "").strip()
        if not label or label in exclude_labels:
            continue
        if ir_only and label not in ir_labels:
            continue
        kept.append((start, end, tid, label))
    return kept
def load_kernels_with_launch(
    conn: sqlite3.Connection, launch_name_ids: list[int]
) -> list[tuple[int, int, int, int]]:
    """(kernel_start, kernel_end, launch_start, globalTid) per kernel, joined on correlationId."""
    if not launch_name_ids:
        return []
    sql = """
        SELECT k.start, k.end, r.start, r.globalTid
        FROM CUPTI_ACTIVITY_KIND_KERNEL k
        INNER JOIN CUPTI_ACTIVITY_KIND_RUNTIME r
        ON k.correlationId = r.correlationId
        WHERE r.nameId IN ({})
        AND r.globalTid IS NOT NULL
    """.format(",".join("?" * len(launch_name_ids)))
    return [
        (int(k_start), int(k_end), int(launch_start), int(tid))
        for k_start, k_end, launch_start, tid in conn.execute(sql, launch_name_ids)
    ]
def critical_path_innermost_ns(
    rows: list[tuple[int, int, int, str]],
) -> defaultdict[str, int]:
    """
    Per ``globalTid``, wall time where each IR label is the innermost active
    range (NVTX stack top). Sums across threads.
    Intervals are sweep-processed with ends-before-starts at the same timestamp,
    inner ends / outer starts ordered first so LIFO nesting is respected.
    Overlapping non-nested intervals on one thread are healed by popping until
    the expected uid is on top (may distort rare malformed traces).
    """
    # Bucket intervals per thread; the enumeration index is a uid that
    # disambiguates identical (start, end, label) triples during the sweep.
    by_tid: dict[int, list[tuple[int, int, str, int]]] = defaultdict(list)
    for i, (s, e, tid, lab) in enumerate(rows):
        if e <= s:
            continue  # drop empty or inverted intervals
        by_tid[tid].append((s, e, lab, i))
    out: defaultdict[str, int] = defaultdict(int)
    for _tid, ivs in by_tid.items():
        # Sweep events: (time, phase, priority, uid, label).
        # phase: 0 = end, 1 = start — at equal times ends sort before starts.
        # priority: -e on starts puts outer (longer) starts first;
        #           -s on ends puts inner (later-started) ends first.
        evs: list[tuple[int, int, int, int, str]] = []
        for s, e, lab, uid in ivs:
            evs.append((s, 1, -e, uid, lab))
            evs.append((e, 0, -s, uid, lab))
        evs.sort(key=lambda x: (x[0], x[1], x[2], x[3]))
        stack: list[tuple[int, str]] = []
        prev_t: int | None = None
        for t, phase, _prio, uid, lab in evs:
            # Credit the elapsed slice to whichever range is currently innermost.
            if prev_t is not None and t > prev_t and stack:
                out[stack[-1][1]] += t - prev_t
            if phase == 1:
                stack.append((uid, lab))
            else:
                # Heal malformed (non-LIFO) overlap: pop until this uid tops the stack.
                while stack and stack[-1][0] != uid:
                    stack.pop()
                if stack:
                    stack.pop()
            prev_t = t
    return out
def attribute_device_by_innermost_launch(
    nvtx_rows: list[tuple[int, int, int, str]],
    kernels: list[tuple[int, int, int, int]],
) -> defaultdict[str, int]:
    """
    For each kernel, pick the active NVTX range on the same thread with the
    largest ``start`` still <= launch_start (innermost at launch). Add full
    kernel duration to that label only.
    """
    ranges_by_tid: dict[int, list[tuple[int, int, str]]] = defaultdict(list)
    for rng_start, rng_end, tid, label in nvtx_rows:
        ranges_by_tid[tid].append((rng_start, rng_end, label))
    for per_thread in ranges_by_tid.values():
        per_thread.sort(key=lambda rng: rng[0])
    device_ns: defaultdict[str, int] = defaultdict(int)
    for kernel_start, kernel_end, launch_start, gtid in kernels:
        # Ranges on the launching thread that were open at launch time.
        active = [
            rng
            for rng in ranges_by_tid.get(gtid, [])
            if rng[0] <= launch_start < rng[1]
        ]
        if active:
            # The latest start among active ranges is the innermost one.
            innermost = max(active, key=lambda rng: rng[0])
            device_ns[innermost[2]] += kernel_end - kernel_start
    return device_ns
def format_ns(ns: float) -> str:
    """Human-readable duration from nanoseconds, tiered as s / ms / µs / ns."""
    for threshold, unit in ((1e9, "s"), (1e6, "ms"), (1e3, "µs")):
        if ns >= threshold:
            return f"{ns / threshold:.3f} {unit}"
    return f"{ns:.0f} ns"
def _export_and_open_sqlite(
    nsys_rep: Path,
    sqlite_path: Path,
    *,
    no_export: bool,
    force_export: bool,
) -> int:
    """Run ``nsys export`` if needed; return 0 or a process exit code."""
    if no_export:
        # Caller opted out of exporting: the SQLite file must already exist.
        if sqlite_path.is_file():
            return 0
        print(f"SQLite not found (--no-export): {sqlite_path}", file=sys.stderr)
        return 1
    try:
        export_sqlite(nsys_rep, sqlite_path, force=force_export)
    except subprocess.CalledProcessError as e:
        print(
            "nsys export failed. If the report is newer than your nsys build, "
            "upgrade Nsight Systems or use a matching ``nsys``.\n"
            f"Command exit code: {e.returncode}",
            file=sys.stderr,
        )
        # A zero returncode from a failed call still maps to a nonzero exit.
        return e.returncode or 1
    return 0
| def _add_export_args(p: argparse.ArgumentParser) -> None: | |
| p.add_argument( | |
| "report", | |
| type=Path, | |
| help="Path to .nsys-rep (SQLite is created beside it by default)", | |
| ) | |
| p.add_argument( | |
| "--sqlite", | |
| type=Path, | |
| default=None, | |
| help="SQLite output path (default: <report>.sqlite)", | |
| ) | |
| p.add_argument( | |
| "--no-export", | |
| action="store_true", | |
| help="Do not run nsys export; require existing SQLite", | |
| ) | |
| p.add_argument( | |
| "--force-export", | |
| action="store_true", | |
| help="Re-run nsys export even if SQLite exists", | |
| ) | |
def run_summary(
    conn: sqlite3.Connection,
    console: Console,
    *,
    nsys_rep: Path,
    sqlite_path: Path,
    domain: str,
    query_id: int | None,
    iteration: int | None,
    all_types: bool,
    exclude: frozenset[str],
) -> int:
    """IR node host/device table. Returns 0 on success."""
    # Resolve the NVTX range event-type ids and the requested domain first;
    # without either the rest of the report cannot be produced.
    event_types = _nvtx_range_event_types(conn)
    if not event_types:
        print("Could not resolve NVTX range event types from ENUM_NSYS_EVENT_TYPE.", file=sys.stderr)
        return 1
    domain_ids = _domain_ids_for_name(conn, domain)
    if not domain_ids:
        print(
            f"No NvtxDomainCreate rows with text={domain!r}. "
            "Try --domain rapidsmpf or check the trace.",
            file=sys.stderr,
        )
        return 1
    # Optional 'Query N - Iteration M' windows narrow which IR ranges count.
    qi_events = load_query_iteration_events_in_domain(conn, event_types, domain_ids)
    windows, wall_query_slices = filter_windows_for_cli(
        qi_events, query_id=query_id, iteration=iteration
    )
    query_wall_ns = query_iteration_wall_span_ns(wall_query_slices)
    rows = load_ir_nvtx_rows(
        conn,
        domain_ids=domain_ids,
        event_types=event_types,
        ir_only=not all_types,
        ir_labels=DEFAULT_IR_LABELS,
        exclude_labels=exclude,
    )
    rows = filter_by_windows(rows, windows)
    # Host time: inclusive sum of NVTX range durations per label.
    host_by_label: defaultdict[str, int] = defaultdict(int)
    for start, end, _tid, label in rows:
        host_by_label[label] += end - start
    # Device time: kernel durations attributed to the innermost launching range.
    launch_ids = _launch_name_ids(conn)
    kernels = load_kernels_with_launch(conn, launch_ids)
    device_by_label = attribute_device_by_innermost_launch(rows, kernels)
    # NOTE(review): cp_by_label only widens the label set below; the
    # critical-path values themselves are never shown as a table column —
    # confirm whether a column was intended.
    cp_by_label = critical_path_innermost_ns(rows)
    labels = sorted(set(host_by_label) | set(device_by_label) | set(cp_by_label))
    total_host = sum(host_by_label.values())
    total_device = sum(device_by_label.values())
    # IR wall span: min start → max end over the included IR ranges only.
    wall_start = min((r[0] for r in rows), default=None)
    wall_end = max((r[1] for r in rows), default=None)
    wall_ns = (wall_end - wall_start) if wall_start is not None else 0
    print(f"Report: {nsys_rep}")
    print(f"SQLite: {sqlite_path}")
    print(f"Domain: {domain!r} (domainId(s) {domain_ids})")
    if windows is not None:
        print(
            f"Query filter: query={query_id!r} iteration={iteration!r} "
            f"({len(windows)} window(s))"
        )
    print(f"IR ranges: {len(rows)}")
    if query_wall_ns is not None:
        n_qi = len(wall_query_slices)
        print(
            f"Query/iteration wall: {format_ns(query_wall_ns)} "
            f"({n_qi} top-level NVTX range(s) in domain {domain!r}; "
            "max(end) - min(start))"
        )
    else:
        print(
            f"Query/iteration wall: n/a (no 'Query N - Iteration M' NVTX in domain {domain!r})"
        )
    print(f"IR wall span: {format_ns(wall_ns)} (min start → max end of included IR NVTX)")
    print()
    # Methodology notes printed ahead of the table so readers can interpret it.
    print(
        "Host time = sum of (end - start) per NVTX range instance (inclusive; "
        "nested parent/child ranges double-count wall time on CPU)."
    )
    print(
        "Device time = sum of kernel durations attributed to the innermost IR range "
        "active on the launching thread at cuda launch (avoids double-count across labels)."
    )
    print(
        "Concurrent CUDA streams can run kernels in parallel; summed device time can "
        "exceed wall time. Host/Device % columns use totals over this table."
    )
    print()
    table = Table(show_header=True, header_style="bold", padding=(0, 1))
    table.add_column("Node Type", no_wrap=True)
    table.add_column("Host Time", justify="right")
    table.add_column("Host %", justify="right")
    table.add_column("Device Time", justify="right")
    table.add_column("Device %", justify="right")
    for lab in labels:
        h = host_by_label.get(lab, 0)
        d = device_by_label.get(lab, 0)
        # Percentages are relative to this table's totals, not to wall time.
        hp = 100.0 * h / total_host if total_host else 0.0
        dp = 100.0 * d / total_device if total_device else 0.0
        table.add_row(
            lab,
            format_ns(h),
            f"{hp:.1f}%",
            format_ns(d),
            f"{dp:.1f}%",
        )
    table.add_row(
        "TOTAL",
        format_ns(total_host),
        "100.0%",
        format_ns(total_device),
        "100.0%",
        style="bold",
    )
    console.print(table)
    return 0
def run_io(
    conn: sqlite3.Connection,
    console: Console,
    *,
    nsys_rep: Path,
    sqlite_path: Path,
    kvikio_domain: str,
    query_domain: str,
    query_id: int | None,
    iteration: int | None,
) -> int:
    """libkvikio read throughput summary."""
    print(f"Report: {nsys_rep}")
    print(f"SQLite: {sqlite_path}")
    event_types = _nvtx_range_event_types(conn)
    if event_types:
        return run_kvikio_read_stats(
            conn,
            console,
            event_types=event_types,
            kvikio_domain=kvikio_domain,
            query_domain=query_domain,
            query_id=query_id,
            iteration=iteration,
        )
    print("Could not resolve NVTX range event types from ENUM_NSYS_EVENT_TYPE.", file=sys.stderr)
    return 1
def run_kvikio_read_stats(
    conn: sqlite3.Connection,
    console: Console,
    *,
    event_types: list[int],
    kvikio_domain: str,
    query_domain: str,
    query_id: int | None,
    iteration: int | None,
) -> int:
    """
    Print libkvikio read throughput summary. Returns 0, or 1 if query windows
    were requested but ``query_domain`` has no matching NVTX domain.
    """
    want_windows = query_id is not None or iteration is not None
    windows: list[tuple[int, int]] | None = None
    if want_windows:
        domain_ids = _domain_ids_for_name(conn, query_domain)
        if not domain_ids:
            print(
                f"No NvtxDomainCreate rows with text={query_domain!r}; "
                "cannot apply --query/--iteration windows.",
                file=sys.stderr,
            )
            return 1
        qi_events = load_query_iteration_events_in_domain(
            conn, event_types, domain_ids
        )
        # Only the window list matters here; wall slices are ignored.
        windows, _ = filter_windows_for_cli(
            qi_events, query_id=query_id, iteration=iteration
        )
    reads = load_kvikio_remote_read_events(
        conn, event_types=event_types, kvikio_domain=kvikio_domain
    )
    print_kvikio_read_throughput(
        console,
        rows=filter_kvikio_by_windows(reads, windows),
        kvikio_domain=kvikio_domain,
        filtered=want_windows,
    )
    return 0
def main(argv: list[str] | None = None) -> int:
    """CLI entry point: parse args, export/open the SQLite trace, dispatch a subcommand."""
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    sub = parser.add_subparsers(dest="command", required=True)
    # --- summary subcommand: per-IR-node host/device time table -------------
    p_sum = sub.add_parser(
        "summary",
        help="Per IR node: host NVTX time and attributed GPU kernel time.",
    )
    _add_export_args(p_sum)
    p_sum.add_argument(
        "--domain",
        default="cudf_polars",
        help="NVTX domain for IR ranges and for 'Query N - Iteration M' windows",
    )
    p_sum.add_argument(
        "--query",
        type=int,
        default=None,
        help="Only include IR ranges overlapping an NVTX range 'Query N - Iteration M'",
    )
    p_sum.add_argument(
        "--iteration",
        type=int,
        default=None,
        help="With --query, filter by iteration index M (0-based as in NVTX text)",
    )
    p_sum.add_argument(
        "--all-types",
        action="store_true",
        help="Include every range in the domain, not only known IR node labels",
    )
    p_sum.add_argument(
        "--exclude",
        default=",".join(sorted(DEFAULT_EXCLUDE_LABELS)),
        help="Comma-separated NVTX labels to skip (default: ConvertIR,ExecuteIR)",
    )
    # --- io subcommand: libkvikio remote-read throughput --------------------
    p_io = sub.add_parser(
        "io",
        help="libkvikio RemoteHandle::read: bytes from NVTX payload and GB/s estimate.",
    )
    _add_export_args(p_io)
    p_io.add_argument(
        "--domain",
        default="cudf_polars",
        help="NVTX domain containing 'Query N - Iteration M' ranges (for --query / --iteration)",
    )
    p_io.add_argument(
        "--kvikio-domain",
        default=DEFAULT_KVIKIO_NVTX_DOMAIN,
        help=f"NVTX domain for kvikio reads (default: {DEFAULT_KVIKIO_NVTX_DOMAIN!r})",
    )
    p_io.add_argument(
        "--query",
        type=int,
        default=None,
        help="Only include reads overlapping a 'Query N - Iteration M' range in --domain",
    )
    p_io.add_argument(
        "--iteration",
        type=int,
        default=None,
        help="With --query, filter by iteration index M",
    )
    args = parser.parse_args(argv)
    nsys_rep = args.report.expanduser().resolve()
    if not nsys_rep.is_file():
        print(f"Report not found: {nsys_rep}", file=sys.stderr)
        return 1
    sqlite_path = (args.sqlite or _default_sqlite_path(nsys_rep)).expanduser().resolve()
    rc = _export_and_open_sqlite(
        nsys_rep,
        sqlite_path,
        no_export=args.no_export,
        force_export=args.force_export,
    )
    if rc != 0:
        return rc
    console = Console()
    # Read-only URI connection: never mutate the exported trace database.
    conn = sqlite3.connect(f"file:{sqlite_path}?mode=ro", uri=True)
    try:
        if args.command == "summary":
            # Drop empty entries from e.g. "--exclude a,,b" or a trailing comma.
            exclude = frozenset(x.strip() for x in args.exclude.split(",") if x.strip())
            return run_summary(
                conn,
                console,
                nsys_rep=nsys_rep,
                sqlite_path=sqlite_path,
                domain=args.domain,
                query_id=args.query,
                iteration=args.iteration,
                all_types=args.all_types,
                exclude=exclude,
            )
        if args.command == "io":
            return run_io(
                conn,
                console,
                nsys_rep=nsys_rep,
                sqlite_path=sqlite_path,
                kvikio_domain=args.kvikio_domain,
                query_domain=args.domain,
                query_id=args.query,
                iteration=args.iteration,
            )
        # Unreachable: add_subparsers(required=True) rejects unknown commands.
        raise AssertionError(f"unknown command {args.command!r}")
    finally:
        conn.close()
if __name__ == "__main__":
    # Propagate main()'s integer return value as the process exit status.
    raise SystemExit(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment