# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" | |
Benchmark script to measure Numba fp16 vs fp32 memory cost for RNNT loss. | |
Usage: | |
# Compute and evaluate the default benchmark configuration | |
python numba_memory_benchmark.py --results_dir "./results_fp16_vs_fp32" | |
# Modifying benchmark parameters | |
python numba_memory_benchmark.py --results_dir "./results_fp16_vs_fp32" \ | |
-B "1,2,4,8,16,32" \ | |
-T "200,400" \ | |
-U "100,200" \ | |
-V "28,1024 \ | |
-H "640" | |
# Only evaluate previously computed benchmark results (without re-computation) | |
# Simplified results | |
python numba_memory_benchmark.py --results_dir "./results_fp16_vs_fp32" --no-compute | |
# Breakdown of results into [data - loss - total] | |
python numba_memory_benchmark.py --results_dir "./results_fp16_vs_fp32" --no-compute --full-results | |
# Calculate benchmark without allocating memory of the gradient tensor | |
python numba_memory_benchmark.py --results_dir "./results_fp16_vs_fp32" --no-grads | |
""" | |
import argparse
import os

# Set before importing Numba so that it uses the NVIDIA CUDA Python bindings
os.environ['NUMBA_CUDA_USE_NVIDIA_BINDING'] = "1"

import pickle
import subprocess
from typing import List, Tuple, Union

import numba
import torch
from pytorch_lightning import seed_everything

from nemo.collections.asr.losses.rnnt import RNNTLoss
from nemo.collections.asr.modules import rnnt
from nemo.core.utils import numba_utils
###################################################################################
# UTILITY FUNCTIONS


def log_system():
    """
    Log system information and whether Numba supports CUDA and FP16.
    """
    print("Torch :", torch.__version__)
    print("Numba :", numba.__version__)

    # Check whether Numba supports CUDA and FP16
    cuda_supported = numba_utils.numba_cuda_is_supported(numba.__version__)
    fp16_supported, reason = numba_utils.is_numba_cuda_fp16_supported(return_reason=True)
    print("Numba supports CUDA:", cuda_supported)
    print("Numba supports CUDA FP16:", fp16_supported)

    if not cuda_supported:
        print("CUDA support not available. Exiting program...")
        exit(1)

    if not fp16_supported:
        print("FP16 support not available. Reason:", reason)
        print("Exiting program...")
        exit(1)

    print()

    # Print CUDA environment
    result = subprocess.run(['nvidia-smi'], capture_output=True, text=True, encoding='utf-8')
    print(result.stdout)
    print()

    result = subprocess.run(['nvidia-smi', '-L'], capture_output=True, text=True, encoding='utf-8')
    print(result.stdout)
    print()

    torch.cuda.empty_cache()
    print("GPU Memory :", torch.cuda.memory_summary())
    print()
    print()


def load_results(path):
    """Load results from a pickle file."""
    with open(path, 'rb') as f:
        results = pickle.load(f)
    return results


def save_results(results, path):
    """Save results to a pickle file."""
    with open(path, 'wb') as f:
        pickle.dump(results, f)


def print_results(results_path, simplify_results=True):
    """Display the results from a saved benchmark run.

    Args:
        results_path: Path to the pickle file containing the results.
        simplify_results: If True, only display the total memory cost. Otherwise, display the breakdown of memory
            cost into [data - loss - total].
    """
    results = load_results(results_path)

    data_memory = [res['data_mem'] for res in results]  # type: List[monitor_cuda_mem]
    data_mem_fp16 = [data for data in data_memory if "float16" in data.scope]  # type: List[monitor_cuda_mem]
    data_mem_fp32 = [data for data in data_memory if "float32" in data.scope]  # type: List[monitor_cuda_mem]

    rnnt_memory = [res['rnnt_mem'] for res in results]  # type: List[monitor_cuda_mem]
    rnnt_mem_fp16 = [rnnt for rnnt in rnnt_memory if "float16" in rnnt.scope]  # type: List[monitor_cuda_mem]
    rnnt_mem_fp32 = [rnnt for rnnt in rnnt_memory if "float32" in rnnt.scope]  # type: List[monitor_cuda_mem]

    loss_memory = [res['loss_mem'] for res in results]  # type: List[monitor_cuda_mem]
    loss_mem_fp16 = [loss for loss in loss_memory if "float16" in loss.scope]  # type: List[monitor_cuda_mem]
    loss_mem_fp32 = [loss for loss in loss_memory if "float32" in loss.scope]  # type: List[monitor_cuda_mem]

    for data16, rnnt16, loss16, data32, rnnt32, loss32 in zip(
        data_mem_fp16, rnnt_mem_fp16, loss_mem_fp16, data_mem_fp32, rnnt_mem_fp32, loss_mem_fp32
    ):
        config = loss16.scope.replace("FP torch.float16", "").replace("FP torch.float32", "").strip()

        if simplify_results:
            fmt_str = (
                f"{config.upper():36} | "
                f"FP32 = {HumanBytes.format(max(data32.final_memory, rnnt32.final_memory, loss32.final_memory)):10} | "
                f"FP16 = {HumanBytes.format(max(data16.final_memory, rnnt16.final_memory, loss16.final_memory)):10}"
            )
        else:
            fmt_str = (
                f"{config.upper():36} | "
                f"Data + RNNT Dec+Joint 32 = {rnnt32.memory_diff_human:10} | Data + RNNT Dec+Joint 16 = {rnnt16.memory_diff_human:10} |||| "
                f"Loss Memory 32 = {loss32.memory_diff_human:10} | Loss Memory 16 = {loss16.memory_diff_human:10} |||| "
                f"FP32 Total Memory = {HumanBytes.format(max(data32.final_memory, rnnt32.final_memory, loss32.final_memory)):10} | "
                f"FP16 Total Memory = {HumanBytes.format(max(data16.final_memory, rnnt16.final_memory, loss16.final_memory)):10}"
            )

        print(fmt_str)
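

# Illustrative shape of a simplified results line printed by print_results above
# (the byte values here are hypothetical, not measured):
#
#   [B=4, T=200, U=100, V=1024, H=640]   | FP32 = 2.41 GiB   | FP16 = 1.23 GiB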
""" | |
Utility context manager to monitor CUDA memory usage using PyTorch CUDA API. | |
""" | |
class monitor_cuda_mem: | |
""" | |
Context manager to monitor CUDA memory usage using PyTorch CUDA API. | |
""" | |
_CONTEXT_DEPTH = 0 | |
ENABLED: bool = True # Globally enables or disabls the context manager | |
EMPTY: bool = False # If True, will perform torch.cuda.empty_cache() before and after the context | |
DEVICE: int = 0 # CUDA device to monitor | |
VERBOSE: bool = True # If true will print the scope of memory that was allocated and freed | |
PRECISION: int = 4 # Number of decimal places to print for memory usage | |
def __init__( | |
self, scope, empty=None, enabled: bool = None, device: int = None, verbose: bool = None, precision: int = None | |
): | |
self.scope = scope | |
self.empty = empty if empty is not None else monitor_cuda_mem.EMPTY | |
self.enabled = enabled if enabled is not None else monitor_cuda_mem.ENABLED | |
self.device = device if device is not None else monitor_cuda_mem.DEVICE | |
self.verbose = verbose if verbose is not None else monitor_cuda_mem.VERBOSE | |
self.precision = precision if precision is not None else monitor_cuda_mem.PRECISION | |
self.reset() | |
def reset(self): | |
self.memory_diff = None | |
self.memory_diff_human = None | |
self.initial_memory = None | |
self.final_memory = None | |
    def __enter__(self):
        monitor_cuda_mem._CONTEXT_DEPTH += 1

        if self.enabled:
            if self.verbose:
                self.print_pad()
                print(f"|> {self.scope}")

            self.initial_memory = torch.cuda.max_memory_allocated(self.device)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.enabled:
            if self.empty:
                torch.cuda.empty_cache()

            self.final_memory = torch.cuda.max_memory_allocated(self.device)
            self.memory_diff = self.final_memory - self.initial_memory
            self.memory_diff_human = HumanBytes.format(self.memory_diff, precision=self.precision)

            if self.verbose:
                self.print_pad()
                print(f"{self.scope} |> {self.memory_diff_human}")

        monitor_cuda_mem._CONTEXT_DEPTH -= 1

    @classmethod
    def print_pad(cls):
        print('\t' * (cls._CONTEXT_DEPTH - 1), end='')
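

# Minimal usage sketch for monitor_cuda_mem (assumes a CUDA device is available;
# the buffer size and printed value are illustrative):
#
#   with monitor_cuda_mem("allocate fp32 buffer", verbose=False) as mem:
#       buf = torch.empty(1024, 1024, dtype=torch.float32, device="cuda")
#   print(mem.memory_diff_human)  # ~"4.0000 MiB" at the default precision of 4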


# Shortened form of the answer from https://stackoverflow.com/a/63839503
# Used to format bytes into a human-readable string.
class HumanBytes:
    # fmt: off
    METRIC_LABELS: List[str] = ["B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"]
    BINARY_LABELS: List[str] = ["B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB"]
    PRECISION_OFFSETS: List[float] = [5 * (0.1 ** x) for x in range(1, 22)]  # PREDEFINED FOR SPEED.
    PRECISION_FORMATS: List[str] = [("{}{:." + str(ratio) + "f} {}") for ratio in range(len(PRECISION_OFFSETS))]  # PREDEFINED FOR SPEED.
    # fmt: on

    @staticmethod
    def format(num: Union[int, float], metric: bool = False, precision: int = 1) -> str:
        assert isinstance(num, (int, float)), "num must be an int or float"
        assert isinstance(metric, bool), "metric must be a bool"
        assert (
            isinstance(precision, int) and precision >= 0 and precision < len(HumanBytes.PRECISION_OFFSETS)
        ), "precision must be an int (range 0-20)"

        unit_labels = HumanBytes.METRIC_LABELS if metric else HumanBytes.BINARY_LABELS
        last_label = unit_labels[-1]
        unit_step = 1000 if metric else 1024
        unit_step_thresh = unit_step - HumanBytes.PRECISION_OFFSETS[precision]

        is_negative = num < 0
        if is_negative:  # Faster than ternary assignment or always running abs().
            num = abs(num)

        for unit in unit_labels:
            if num < unit_step_thresh:
                break
            if unit != last_label:
                num /= unit_step

        return HumanBytes.PRECISION_FORMATS[precision].format("-" if is_negative else "", num, unit)
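

# Worked examples for HumanBytes.format (these follow directly from the code above):
#   HumanBytes.format(1024)                                 -> "1.0 KiB"
#   HumanBytes.format(1_500_000, metric=True, precision=2)  -> "1.50 MB"
#   HumanBytes.format(-2048, precision=0)                   -> "-2 KiB"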


###################################################################################
# DATA UTILS

# Global input variables. Used to store data for the benchmark.
global x, x_len, y, y_len

DEVICE = "cuda"


def data_gen(bs, t=200, u=100, v=1024, h=640, dtype=torch.float32):
    """
    Generate seeded data for the benchmark. Every call to this function will generate the same data for a given
    set of input parameters.

    Args:
        bs: Batch Size
        t: Audio Timesteps
        u: Text Tokens
        v: Vocabulary Size
        h: RNNT Hidden size
        dtype: torch.dtype

    Returns:
        x: Audio data of shape [bs, h, t]
        x_len: Audio length of shape [bs]
        y: Text data of shape [bs, u - 1]
        y_len: Text length of shape [bs]
    """
    # utilize global variables for input to loss
    torch.cuda.empty_cache()
    torch.manual_seed(0)

    x = torch.randn(bs, h, t, dtype=dtype, device=DEVICE, requires_grad=False)
    x_len = torch.randint(t, size=[bs], device=DEVICE, dtype=torch.int64)
    y = torch.randint(v, size=[bs, u - 1], device=DEVICE, dtype=torch.int64)
    y_len = torch.randint(u, size=[bs], device=DEVICE, dtype=torch.int64)

    # enforce some RNNT input constraints
    rand_idx = torch.randint(bs, size=[1])
    x_len[rand_idx] = t
    y_len[rand_idx] = u - 1

    return x, x_len, y, y_len
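

# Example call (shapes follow the docstring above):
#   x, x_len, y, y_len = data_gen(4, t=200, u=100, v=28, h=640)
#   # x: [4, 640, 200] float32 | x_len: [4] int64 | y: [4, 99] int64 | y_len: [4] int64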


def str_to_int_list(string: str) -> List[int]:
    return [int(x) for x in string.split(',')] if string else []
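

# Example: str_to_int_list("1,2,4") -> [1, 2, 4]; an empty string yields [].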


###################################################################################
# MODEL UTILS


def rnnt_decoder_joint(
    v=1024, h=640, dtype=torch.float32, requires_grad=False
) -> Tuple[rnnt.RNNTDecoder, rnnt.RNNTJoint]:
    """Build an RNNTDecoder and RNNTJoint with the given parameters."""
    seed_everything(0)

    prednet = {'pred_hidden': h, 'pred_rnn_layers': 1}
    rnnt_decoder = rnnt.RNNTDecoder(prednet, vocab_size=v)

    jointnet = {'joint_hidden': h, 'encoder_hidden': h, 'pred_hidden': h, 'activation': 'relu'}
    rnnt_joint = rnnt.RNNTJoint(jointnet, num_classes=v, fuse_loss_wer=False)

    rnnt_decoder.to(dtype=dtype, device=DEVICE)
    rnnt_joint.to(dtype=dtype, device=DEVICE)

    # Setup zero grad of params if needed
    if requires_grad:
        with torch.no_grad():
            for p in rnnt_decoder.parameters():
                p.requires_grad = True
                p.grad = torch.zeros_like(p)

            for p in rnnt_joint.parameters():
                p.requires_grad = True
                p.grad = torch.zeros_like(p)

    return rnnt_decoder, rnnt_joint


def rnnt_forward(x, x_len, y, y_len, rnnt_decoder, rnnt_joint):
    """Run the forward pass of the RNNTDecoder and RNNTJoint."""
    g, target_length, states = rnnt_decoder(targets=y, target_length=y_len)
    acts = rnnt_joint(
        encoder_outputs=x, decoder_outputs=g, encoder_lengths=x_len, transcripts=y, transcript_lengths=y_len
    )
    return acts
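

# Back-of-the-envelope intuition (an estimate, not a measurement): the joint
# activations dominate RNNT memory, with on the order of b * t * u * v elements.
# For b=32, t=200, u=100, v=1024 that is ~6.6e8 elements, i.e. roughly 2.4 GiB
# in fp32 and about half that in fp16, which is the gap this benchmark quantifies.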


def check_memory_numba(rnnt_loss, x, x_len, y, y_len, requires_grad=False):
    """Compute the RNNT Loss on the activations and check the memory consumed by the Numba kernel."""
    loss = rnnt_loss(log_probs=x, targets=y, input_lengths=x_len, target_lengths=y_len)

    if requires_grad:
        loss.sum().backward()  # compute gradients

    return loss


def exec_closure(args):
    """
    Closure function for the benchmark. This function is called by the benchmarking script and is responsible for
    running the benchmark and returning the results.

    Returns:
        results_path: Path to a pickle file containing the list of measurements from the benchmark.
    """
    # List of measurements which we'll save to disk as results.
    results = []
    torch.cuda.empty_cache()

    basedir = args.results_dir
    if not os.path.exists(basedir):
        os.makedirs(basedir, exist_ok=True)

    results_path = os.path.join(basedir, f'rnnt_results_requires_grad_{str(args.requires_grad)}.pkl')

    # Parse benchmark arguments
    batchsizes = str_to_int_list(args.B)
    audio_lens = str_to_int_list(args.T)
    target_lens = str_to_int_list(args.U)
    vocab_sizes = str_to_int_list(args.V)
    hidden_sizes = str_to_int_list(args.H)

    print()
    print("*" * 100)
    print()
    print("Gradients computed :", args.requires_grad)
    print()

    # If we're not computing the benchmark, just return the results path
    if not args.compute:
        return results_path

    # Save the empty result list, thereby resetting results
    save_results(results, results_path)
    del results

    for b in batchsizes:  # 1, 4, 8, 16, 32, 64 (on 48 GB GPUs)
        for t in audio_lens:  # 200, 400, 600 (LibriSpeech with 4x and 8x stride, on 32 GB GPUs)
            for u in target_lens:  # 100, 200 (char enc, subword enc)
                for v in vocab_sizes:  # 28, 1024 (char encoding, Conformer RNNT vocab size)
                    for h in hidden_sizes:  # 640, 1024 (common hidden size of encoder, decoder, joint)
                        for dtype in [
                            torch.float32,
                            torch.float16,
                        ]:
                            # Access global dataset and reset
                            global x, x_len, y, y_len
                            x = None
                            x_len = None
                            y = None
                            y_len = None

                            # Reset CUDA memory stats
                            torch.cuda.empty_cache()
                            torch.cuda.reset_peak_memory_stats()
                            torch.cuda.reset_accumulated_memory_stats()
                            torch.cuda.reset_max_memory_cached()
                            torch.cuda.reset_max_memory_allocated()

                            # Setup CUDA monitor flags
                            monitor_cuda_mem.DEVICE = 0
                            monitor_cuda_mem.EMPTY = False
                            monitor_cuda_mem.PRECISION = 2
                            monitor_cuda_mem.VERBOSE = False

                            # sub_label is the row; description is the column
                            sub_label = f'[b={b}, t={t}, u={u}, v={v}, h={h}]'
                            print("Computing :", sub_label)

                            # Numba FP32 or FP16 depending on `dtype`
                            env = f'FP {dtype}'
                            # Build batch of samples with seed set
                            with monitor_cuda_mem(f'Data {dtype}', empty=False) as datagen_mem:
                                x, x_len, y, y_len = data_gen(b, t, u, v, h, dtype=dtype)

                                if args.requires_grad:
                                    x.requires_grad_(True)
                                    with torch.no_grad():
                                        x.grad = torch.zeros_like(x, dtype=dtype)

                            print("Batch data memory", datagen_mem.memory_diff_human)

                            # Build new RNNT decoder and joint
                            with monitor_cuda_mem(f'{env} {sub_label}', empty=False) as rnnt_mem:
                                dec, joint = rnnt_decoder_joint(v, h=h, dtype=dtype, requires_grad=args.requires_grad)

                                # Compute the joint activations
                                # (we don't need to perform log-softmax due to the fused kernel)
                                acts = rnnt_forward(x, x_len, y, y_len, dec, joint)

                            print("RNNT Decoder+Joint memory", rnnt_mem.memory_diff_human)

                            # Compute the loss and memory cost
                            with monitor_cuda_mem(f'{env} {sub_label}', empty=False) as loss_mem:
                                blank = acts.shape[-1] - 1  # blank id is the last class of the joint output
                                rnnt_loss = RNNTLoss(
                                    num_classes=blank,
                                    reduction='sum',
                                    loss_name='warprnnt_numba',
                                    loss_kwargs=dict(fastemit_lambda=0.0, clamp=-1.0),
                                )

                                # Compute the loss and check memory.
                                # Note: We are not measuring speed, so Numba JIT compile time is not measured;
                                # therefore we skip performing a warmup run of the loss function.
                                unused_value_ = check_memory_numba(
                                    rnnt_loss, acts, x_len, y, y_len, requires_grad=args.requires_grad
                                )
                            result = {
                                'data_mem': datagen_mem,
                                'rnnt_mem': rnnt_mem,
                                'loss_mem': loss_mem,
                            }

                            print(f"Loss memory ({dtype})", loss_mem.memory_diff_human)
                            print(f"Peak memory allocated : {HumanBytes.format(torch.cuda.max_memory_allocated())}")
                            print()

                            # Save results to disk
                            results = load_results(results_path)
                            results.append(result)
                            save_results(results, results_path)

                            # Clean up memory for the next benchmark
                            del results, unused_value_
                            del dec, joint, acts
                            del rnnt_loss, blank
                            del x, x_len, y, y_len
                            torch.cuda.empty_cache()

    return results_path


def parse_args():
    parser = argparse.ArgumentParser(description="RNNT Loss Benchmark")
    parser.add_argument(
        '-B', '--batch-size', dest='B', type=str, default='1,4,8,16,32', help="Batch sizes to benchmark"
    )
    parser.add_argument(
        '-T', '--audio-len', dest='T', type=str, default='200,400', help="Max audio lengths to benchmark"
    )
    parser.add_argument(
        '-U', '--text-len', dest='U', type=str, default='100,200', help="Max text sequence lengths to benchmark"
    )
    parser.add_argument('-V', '--vocab-size', dest='V', type=str, default='28,1024', help="Vocab sizes to benchmark")
    parser.add_argument('-H', '--hidden-size', dest='H', type=str, default='640', help="Hidden sizes of the RNNT Joint to benchmark")
    parser.add_argument("--results_dir", type=str, default='./numba_fp32_vs_fp16', help="Name of results directory")
    parser.add_argument(
        "--no-compute",
        dest='compute',
        action='store_false',
        help="Skip computing results; used when only printing previously saved results without recomputation.",
    )
    parser.add_argument(
        '--no-grads',
        dest='requires_grad',
        action='store_false',
        help="Skip calculating gradients when comparing memory usage",
    )
    parser.add_argument(
        '--full-results',
        dest='simplify_results',
        action='store_false',
        help="Print full results, including a breakdown of data storage size and activations",
    )
    parser.set_defaults(compute=True, requires_grad=True)

    args = parser.parse_args()
    return args


def main(args):
    log_system()
    results_path = exec_closure(args)

    print("\n\n")
    print("Results:")
    print_results(results_path, simplify_results=args.simplify_results)


if __name__ == '__main__':
    args = parse_args()
    main(args)