Skip to content

Instantly share code, notes, and snippets.

@hcho3
Last active August 5, 2025 03:35
Show Gist options
  • Save hcho3/ce7b52f9be8cb2bbf70f21e6db1970e2 to your computer and use it in GitHub Desktop.
# Copyright (c) 2023-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Credit: Robert Maynard <[email protected]>
# Jinsol Park <[email protected]>
# Hyunsu Cho <[email protected]>
import argparse
import os
import os.path
import pprint
import re
import shutil
import subprocess
import sys
import tempfile
from collections import defaultdict
def tidy_kernel_name(x):
    """Pretty-print a demangled kernel name.

    Each template/call argument goes on its own line, indented two spaces
    per nesting level of ``<...>`` / ``(...)``.
    """
    delimiters = ("<", ">", "(", ")", ",")
    # Tokenize: split on the delimiters, dropping whitespace that directly
    # follows a comma (", float" -> "float") while keeping all other
    # spacing (e.g. "unsigned int") intact.
    tokens = []
    after_comma = False
    for piece in re.split(r"([<>(),])", x):
        if piece in delimiters:
            tokens.append(piece)
            after_comma = piece == ","
        else:
            text = piece.lstrip() if after_comma else piece
            if text:
                tokens.append(text)
            after_comma = False
    # Render: openers push the indent, closers pop it, and commas break
    # the line at the current indent level.
    rendered = []
    indent = 0
    for token in tokens:
        if token in ("<", "("):
            indent += 2
            rendered.append(token + "\n" + " " * indent)
        elif token in (">", ")"):
            indent -= 2
            rendered.append("\n" + " " * indent + token)
        elif token == ",":
            rendered.append(",\n" + " " * indent)
        else:
            rendered.append(token)
    return "".join(rendered)
def execute(command, args, **kwargs):
    """Run *command* with *args* and return its stdout as a list of bytes lines.

    Accepts an optional ``cwd`` keyword argument. On failure (missing binary
    or non-zero exit status) the error is printed and an empty list returned.
    """
    cmd_line = [command, *args]
    try:
        completed = subprocess.run(
            cmd_line,
            cwd=kwargs.get("cwd", None),
            check=True,
            capture_output=True,
        )
    except (OSError, subprocess.CalledProcessError) as err:
        print(err)
        return []
    return completed.stdout.splitlines()
def is_elf(file_path: str) -> bool:
    """Return True when *file_path* starts with the 4-byte ELF magic number."""
    with open(file_path, "rb") as handle:
        magic = handle.read(4)
    return magic == b"\x7fELF"
def extract_cubins(elf_file: str, dump_location: str):
    """Extract every embedded cubin of *elf_file* into *dump_location*.

    Returns the full paths of all files present in *dump_location* after
    running ``cuobjdump --extract-elf all``.
    """
    execute(
        "cuobjdump",
        ["--extract-elf", "all", os.path.abspath(elf_file)],
        cwd=dump_location,
    )
    return [
        os.path.join(dump_location, name) for name in os.listdir(dump_location)
    ]
class Symbol:
    """One symbol entry parsed from ``size -A`` output.

    Attributes:
        name: mangled (or, after demangling, human-readable) symbol name.
        type: section the entry came from (e.g. ".text", ".nv.info").
        size: entry size in bytes (parsed from the string given to __init__).
    """

    def __init__(self, name, raw_type, str_size) -> None:
        self.type = raw_type
        self.name = name
        self.size = int(str_size)

    def __repr__(self) -> str:
        return f"Symbol(name={self.name!r}, type={self.type!r}, size={self.size})"

    def __eq__(self, other):
        # Return NotImplemented instead of raising AttributeError when
        # compared against a non-Symbol (the original crashed on e.g.
        # `symbol == "text"`). Size is not compared, matching the original
        # behavior — presumably two entries for the same name/section are
        # considered the same symbol regardless of size; verify if relied on.
        if not isinstance(other, Symbol):
            return NotImplemented
        return self.name == other.name and self.type == other.type
def extract_info_from_cubin(file: str):
    """Parse ``size -A`` output for *file* into a list of Symbol records."""
    # Each symbol row looks like "<section>.<_mangled_name> <size> <addr>".
    # Splitting on whitespace OR on "._" (mangled names start with "_")
    # yields exactly four fields for symbol rows, which lets us discard
    # header/info lines that split into a different field count.
    splitter = re.compile(r"\s+|\._")
    parsed = []
    for raw_line in execute("size", ["-A", file]):
        fields = splitter.split(raw_line.decode("utf8"))
        if len(fields) != 4:
            continue
        section, stripped_name, size_text = fields[0], fields[1], fields[2]
        # Re-attach the leading underscore consumed by the "._" split.
        parsed.append(Symbol("_" + stripped_name, section, size_text))
    return parsed
def transform_to_demangled_names(symbols):
    """Return new Symbols whose names are demangled via llvm-cxxfilt.

    llvm-cxxfilt has a maximum input size, so names are passed to it in
    batches of 128 rather than all at once.
    """
    batch_size = 128
    demangled_symbols = []
    for start in range(0, len(symbols), batch_size):
        batch = symbols[start : start + batch_size]
        raw_names = execute("llvm-cxxfilt", [s.name for s in batch])
        for original, raw_name in zip(batch, raw_names):
            demangled_symbols.append(
                Symbol(raw_name.decode("utf8"), original.type, original.size)
            )
    return demangled_symbols
class SymbolCache:
    """Extracts cubins from ELF files and reports per-kernel SASS sizes.

    Parameters
    ----------
    inclusions:
        Optional list of regex pattern strings applied to demangled symbol
        names; when given, only matching symbols are reported.
    """

    def __init__(self, inclusions) -> None:
        # Always define self.inclusions (the original only set it when
        # patterns were supplied, leaving the attribute missing otherwise).
        self.inclusions = [re.compile(e) for e in inclusions] if inclusions else []
        self.has_inclusions = bool(self.inclusions)
        # Maps input path -> list of cubin file paths extracted from it.
        self.cubin_cache = {}
        self.tmpdir = tempfile.mkdtemp()

    def __del__(self):
        # Best-effort cleanup: __init__ may have failed before tmpdir was
        # assigned, and at interpreter shutdown rmtree can fail spuriously,
        # so guard with getattr and ignore_errors.
        tmpdir = getattr(self, "tmpdir", None)
        if tmpdir:
            shutil.rmtree(tmpdir, ignore_errors=True)

    def load(self, path) -> None:
        """Extract cubins from *path* (if it is an ELF file) into the cache."""
        if path not in self.cubin_cache and is_elf(path):
            dump_loc = os.path.join(self.tmpdir, os.path.basename(path))
            os.mkdir(dump_loc)
            self.cubin_cache[path] = extract_cubins(path, dump_loc)

    def display_sizes(self):
        """Print each kernel's total size and instantiation count, then a total."""

        # Determine if a name matches any of the inclusions regex;
        # with no patterns configured, everything matches.
        def has_match(name):
            if not self.has_inclusions:
                return True
            return any(regex.search(name) for regex in self.inclusions)

        sizes = defaultdict(int)
        counts = defaultdict(int)
        # When counting the number of duplications for each kernel,
        # ignore the presence of multiple code section types
        # (.text, .constant, .nv.info etc) by keying counts on
        # (name, section) and aggregating with max below.
        # (The original also accumulated all symbols into an unused list;
        # that dead local has been removed.)
        for cubins in self.cubin_cache.values():
            for cubin in cubins:
                current_symbols = transform_to_demangled_names(
                    extract_info_from_cubin(cubin)
                )
                for s in current_symbols:
                    if has_match(s.name):
                        sizes[s.name] += s.size
                        counts[(s.name, s.type)] += 1

        # counts_agg[symbol_name] <- max(counts[(symbol_name, *)])
        counts_agg = defaultdict(int)
        for (name, _section), count in counts.items():
            counts_agg[name] = max(counts_agg[name], count)

        total_size = sum(sizes.values())
        entries = {k: (sizes[k], v) for k, v in counts_agg.items()}
        # Report in ascending size order.
        for name, (size, n_inst) in sorted(entries.items(), key=lambda kv: kv[1][0]):
            if size > 0:  # Skip zero-byte entries
                print(f"{tidy_kernel_name(name)}: {size} bytes ({n_inst} instantiations)\n")
        print("Total uncompressed size of CUDA kernels: ", total_size, "bytes")
def main():
    """Entry point: parse CLI arguments, load the ELF inputs, print sizes."""
    parser = argparse.ArgumentParser(prog="report CUDA SASS Kernel sizes")
    parser.add_argument(
        "-i",
        "--include",
        type=str,
        nargs="+",
        help="only include symbols that match this pattern ( applied on demangled names)",
    )
    parser.add_argument(
        "input", nargs="+", type=str, help="elf file ( .so, .exe, .o ) or directory"
    )
    args = parser.parse_args()

    cache = SymbolCache(args.include)

    # Transform any directory into files
    candidates = []
    for entry in args.input:
        if not os.path.isdir(entry):
            candidates.append(entry)
            continue
        for child in os.listdir(entry):
            child_path = os.path.join(entry, child)
            if os.path.isfile(child_path):
                candidates.append(child_path)

    for candidate in candidates:
        if os.path.isfile(candidate):
            cache.load(candidate)

    if not cache.cubin_cache:
        print("Invalid input given")
        parser.print_help()
        sys.exit(1)

    cache.display_sizes()
# Run the reporter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment