Last active
August 5, 2025 03:35
-
-
Save hcho3/ce7b52f9be8cb2bbf70f21e6db1970e2 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2023-2025, NVIDIA CORPORATION. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# | |
# Credit: Robert Maynard <[email protected]> | |
# Jinsol Park <[email protected]> | |
# Hyunsu Cho <[email protected]> | |
import argparse | |
import os | |
import os.path | |
import pprint | |
import re | |
import shutil | |
import subprocess | |
import sys | |
import tempfile | |
from collections import defaultdict | |
def tidy_kernel_name(x):
    """Pretty-print a demangled kernel name with indentation and line breaks.

    Splits the name on angle brackets, parentheses and commas, then lays the
    pieces out one per line, indenting two spaces per nesting level.
    """
    delimiters = ("<", ">", "(", ")", ",")
    # --- Tokenize ---------------------------------------------------------
    tokens = []
    buf = ""
    eat_spaces = False  # True right after a comma: drop the space(s) that follow
    for ch in x:
        if ch in delimiters:
            if buf:
                tokens.append(buf)
            tokens.append(ch)
            buf = ""
            eat_spaces = ch == ","
        elif ch.isspace():
            if not eat_spaces:
                buf += ch
        else:
            buf += ch
            eat_spaces = False
    if buf:
        tokens.append(buf)
    # --- Render with indents and line breaks ------------------------------
    pieces = []
    depth = 0
    for tok in tokens:
        if tok in ("<", "("):
            depth += 2
            pieces.append(tok + "\n" + " " * depth)
        elif tok in (">", ")"):
            depth -= 2
            pieces.append("\n" + " " * depth + tok)
        elif tok == ",":
            pieces.append(",\n" + " " * depth)
        else:
            pieces.append(tok)
    return "".join(pieces)
def execute(command, args, **kwargs):
    """Run *command* with *args* and return its stdout as a list of byte lines.

    Accepts an optional ``cwd`` keyword argument for the working directory.
    On any failure (missing binary, non-zero exit code) the error is printed
    and an empty list is returned -- callers treat that as "no output".
    """
    argv = [command, *args]
    try:
        completed = subprocess.run(
            argv,
            cwd=kwargs.get("cwd", None),
            check=True,
            capture_output=True,
        )
    except (OSError, subprocess.CalledProcessError) as err:
        print(err)
        return []
    return completed.stdout.splitlines()
def is_elf(file_path: str) -> bool:
    """Return True if the file at *file_path* starts with the ELF magic bytes."""
    with open(file_path, "rb") as handle:
        magic = handle.read(4)
    return magic == b"\x7fELF"
def extract_cubins(elf_file: str, dump_location: str):
    """Dump every embedded cubin from *elf_file* into *dump_location*.

    Runs ``cuobjdump --extract-elf all`` with *dump_location* as the working
    directory, then returns the full paths of everything found there.
    """
    execute(
        "cuobjdump",
        ["--extract-elf", "all", os.path.abspath(elf_file)],
        cwd=dump_location,
    )
    return [
        os.path.join(dump_location, name) for name in os.listdir(dump_location)
    ]
class Symbol:
    """A single symbol-table entry: a name, its section type, and its size.

    Equality compares name and section type only; ``size`` is deliberately
    ignored so identical instantiations from different cubins compare equal
    even when their sizes differ.
    """

    def __init__(self, name, raw_type, str_size) -> None:
        self.type = raw_type  # section type, e.g. ".text"
        self.name = name  # (possibly mangled) symbol name
        self.size = int(str_size)  # size in bytes

    def __eq__(self, other):
        # Bug fix: comparing against a non-Symbol used to raise
        # AttributeError; return NotImplemented so Python falls back to the
        # reflected comparison instead.
        if not isinstance(other, Symbol):
            return NotImplemented
        return self.name == other.name and self.type == other.type

    def __hash__(self):
        # Defining __eq__ alone implicitly set __hash__ to None, making
        # instances unhashable; hash on the same fields equality uses.
        return hash((self.name, self.type))

    def __repr__(self):
        return f"Symbol(name={self.name!r}, type={self.type!r}, size={self.size})"
def extract_info_from_cubin(file: str):
    """Parse ``size -A`` output for *file* into a list of Symbol objects.

    Each symbol row has the format ``<section>._<mangled>  <size>  <addr>``.
    We abuse the fact that mangled symbols start with an underscore and split
    on whitespace or on the literal "._" separator; only rows that split into
    exactly four fields are symbol entries, which filters out header and
    info lines.
    """
    splitter = re.compile(r"\s+|\._")
    symbols = []
    for raw_line in execute("size", ["-A", file]):
        fields = splitter.split(raw_line.decode("utf8"))
        if len(fields) != 4:
            continue
        section, mangled, str_size = fields[0], fields[1], fields[2]
        # Re-attach the leading underscore consumed by the "._" split.
        symbols.append(Symbol("_" + mangled, section, str_size))
    return symbols
def transform_to_demangled_names(symbols, chunk_size=128):
    """Demangle the names of *symbols*, returning a new list of Symbols.

    Parameters
    ----------
    symbols : list[Symbol]
        Symbols with (possibly) mangled names.
    chunk_size : int, optional
        How many names to hand to ``llvm-cxxfilt`` per invocation.  The
        demangler has a maximum input size, so we cannot pass every entry
        at once; the previous hard-coded 128 remains the default.
    """

    def chunk_iter(seq):
        for i in range(0, len(seq), chunk_size):
            yield seq[i : i + chunk_size]

    symbols_out = []
    # Process chunk_size symbols at a time to save llvm-cxxfilt invocations.
    for chunk in chunk_iter(symbols):
        demangled = [
            n.decode("utf8")
            for n in execute("llvm-cxxfilt", [s.name for s in chunk])
        ]
        # execute() returns [] on failure; zip then yields nothing and the
        # chunk is dropped, matching the original best-effort behavior.
        for name, sym in zip(demangled, chunk):
            symbols_out.append(Symbol(name, sym.type, sym.size))
    return symbols_out
class SymbolCache:
    """Extracts cubins from ELF files and reports per-kernel SASS sizes.

    Parameters
    ----------
    inclusions : list[str] | None
        Optional regex patterns; when given, only demangled symbol names
        matching at least one pattern are reported.
    """

    def __init__(self, inclusions) -> None:
        # Bug fix: always define self.inclusions. Previously it was only set
        # when patterns were supplied, leaving the attribute missing otherwise.
        self.inclusions = (
            [re.compile(e) for e in inclusions] if inclusions else []
        )
        self.has_inclusions = bool(self.inclusions)
        self.cubin_cache = {}  # elf path -> list of extracted cubin paths
        self.tmpdir = tempfile.mkdtemp()

    def __del__(self):
        # Bug fix: __del__ runs even if __init__ failed before setting
        # tmpdir; guard with getattr and ignore racy rmtree errors.
        tmpdir = getattr(self, "tmpdir", None)
        if tmpdir:
            shutil.rmtree(tmpdir, ignore_errors=True)

    def load(self, path) -> None:
        """Extract and cache the cubins of *path* if it is an ELF file."""
        if path not in self.cubin_cache and is_elf(path):
            dump_loc = os.path.join(self.tmpdir, os.path.basename(path))
            os.mkdir(dump_loc)
            self.cubin_cache[path] = extract_cubins(path, dump_loc)

    def display_sizes(self):
        """Print each kernel's total size and instantiation count, then a total."""

        # Determine if a name matches any of the inclusion regexes.
        def has_match(name):
            if not self.has_inclusions:
                return True
            return any(regex.search(name) for regex in self.inclusions)

        sizes = defaultdict(int)
        counts = defaultdict(int)
        # When counting the number of duplications for each kernel,
        # ignore the presence of multiple code section types
        # (.text, .constant, .nv.info etc).
        for cubins in self.cubin_cache.values():
            for cubin in cubins:
                current_symbols = transform_to_demangled_names(
                    extract_info_from_cubin(cubin)
                )
                for s in current_symbols:
                    if has_match(s.name):
                        sizes[s.name] += s.size
                        counts[(s.name, s.type)] += 1
        # counts_agg[symbol_name] <- max(counts[(symbol_name, *)])
        counts_agg = defaultdict(int)
        for (name, _section), cnt in counts.items():
            counts_agg[name] = max(counts_agg[name], cnt)
        total_size = sum(sizes.values())
        entries = {k: (sizes[k], v) for k, v in counts_agg.items()}
        for k, v in sorted(entries.items(), key=lambda kv: kv[1][0]):
            if v[0] > 0:  # Skip zero-byte entries
                print(f"{tidy_kernel_name(k)}: {v[0]} bytes ({v[1]} instantiations)\n")
        print("Total uncompressed size of CUDA kernels: ", total_size, "bytes")
def main():
    """CLI entry point: report SASS kernel sizes for the given ELF inputs."""
    parser = argparse.ArgumentParser(
        # Bug fix: this sentence was previously passed as ``prog`` (the
        # program *name* shown in usage lines); ``description`` is the
        # correct slot for explanatory text.
        description="report CUDA SASS Kernel sizes"
    )
    parser.add_argument(
        "-i",
        "--include",
        type=str,
        nargs="+",
        help="only include symbols that match this pattern ( applied on demangled names)",
    )
    parser.add_argument(
        "input", nargs="+", type=str, help="elf file ( .so, .exe, .o ) or directory"
    )
    args = parser.parse_args()
    cache = SymbolCache(args.include)

    # Expand any directory argument into its top-level (non-recursive) files.
    items = []
    for item in args.input:
        if os.path.isdir(item):
            for possible_item in os.listdir(item):
                pitem = os.path.join(item, possible_item)
                if os.path.isfile(pitem):
                    items.append(pitem)
        else:
            items.append(item)

    for item in items:
        if os.path.isfile(item):
            cache.load(item)

    # Nothing loadable was found: bail out with usage information.
    if len(cache.cubin_cache) == 0:
        print("Invalid input given")
        parser.print_help()
        sys.exit(1)
    cache.display_sizes()


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment