Print out a summary of a PyTorch model
from typing import List, Optional, Tuple

import torch.nn as nn
from transformers import PreTrainedModel


def model_summary(
    model: PreTrainedModel, max_depth: int = 4, show_input_size: bool = False
) -> None:
    """
    Prints a tree-style summary of the model. Per-module counts include children,
    but the final totals are computed once from the root module, so parameters
    are never double-counted.

    :param PreTrainedModel model: torch model to summarize
    :param int max_depth: maximum depth of the module tree to print, defaults to 4
    :param bool show_input_size: whether to show the shape of each module's first
        trainable parameter, defaults to False
    """

    def format_params(num_params: int) -> str:
        return f"{num_params:,}" if num_params > 0 else "--"

    def format_size(size: Optional[List[int]]) -> str:
        return "x".join(str(x) for x in size) if size else "N/A"

    def count_parameters(module: nn.Module) -> Tuple[int, int]:
        total_params = sum(p.numel() for p in module.parameters())
        trainable_params = sum(
            p.numel() for p in module.parameters() if p.requires_grad
        )
        return total_params, trainable_params

    def recursive_summarize(
        module: nn.Module, depth: int, idx: List[int], prefix: str = ""
    ) -> List[Tuple[str, int, int, int, Optional[List[int]], nn.Module]]:
        summary = []
        total_params, trainable_params = count_parameters(module)
        if depth <= max_depth:
            layer_index = ".".join(map(str, idx))
            # include depth-idx in the row label so it matches the table header
            layer_name = f"{prefix}{type(module).__name__}: {depth}-{layer_index}"
            # shape of the first trainable parameter owned directly by this module
            param_shape = next(
                (p.shape for p in module.parameters(recurse=False) if p.requires_grad),
                None,
            )
            summary.append(
                (layer_name, depth, total_params, trainable_params, param_shape, module)
            )
            for i, (name, child) in enumerate(module.named_children(), 1):
                child_summary = recursive_summarize(
                    child, depth + 1, idx + [i], prefix + " "
                )
                summary.extend(child_summary)
        return summary

    summary = recursive_summarize(model, 1, [1])
    # keep each column at least as wide as its header
    max_name_length = max(
        max(len(name) for name, _, _, _, _, _ in summary),
        len("Layer (type:depth-idx)"),
    )
    max_shape_length = max(
        max(len(format_size(shape)) for _, _, _, _, shape, _ in summary),
        len("Param Shape"),
    )

    print("=" * (max_name_length + 50))
    # the shape column shows parameter shapes, not layer outputs, so label it as such
    header = (
        f"{'Layer (type:depth-idx)':<{max_name_length}} "
        f"{'Param Shape':>{max_shape_length}} {'Param #':>12} {'Trainable':>10}"
    )
    print(header)
    print("=" * (max_name_length + 50))
    for name, depth, num_params, trainable_params, shape, _ in summary:
        shape_str = format_size(shape) if show_input_size else ""
        print(
            f"{name:<{max_name_length}} {shape_str:>{max_shape_length}} "
            f"{format_params(num_params):>12} {str(trainable_params > 0):>10}"
        )

    total_params, trainable_params = count_parameters(model)
    print("=" * (max_name_length + 50))
    print(f"Total params: {format_params(total_params)}")
    print(f"Trainable params: {format_params(trainable_params)}")
    print(f"Non-trainable params: {format_params(total_params - trainable_params)}")
    print("=" * (max_name_length + 50))

# Example usage:
# from transformers import T5ForConditionalGeneration
# model = T5ForConditionalGeneration.from_pretrained("t5-small")
# model_summary(model, max_depth=4, show_input_size=True)
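# Note: the annotation says PreTrainedModel, but the function only relies on the
# plain nn.Module API (parameters(), named_children()), so any torch module
# should work. A minimal sketch with a small, made-up Sequential net:
#
# import torch.nn as nn
# net = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
# model_summary(net, max_depth=2, show_input_size=True)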