@pszemraj
Created July 23, 2024 04:19
print out a summary of a PyTorch model
from typing import List, Optional, Tuple

import torch.nn as nn
from transformers import PreTrainedModel


def model_summary(
    model: PreTrainedModel, max_depth: int = 4, show_input_size: bool = False
) -> None:
    """
    Print an accurate summary of the model, avoiding double-counting of parameters.

    :param PreTrainedModel model: torch model to summarize
    :param int max_depth: maximum depth of the model to print, defaults to 4
    :param bool show_input_size: whether to show the shape of each layer's first
        trainable parameter (the "Param Shape" column), defaults to False
    """

    def format_params(num_params: int) -> str:
        # format parameter counts with thousands separators; "--" for zero
        return f"{num_params:,}" if num_params > 0 else "--"

    def format_size(size: Optional[List[int]]) -> str:
        # render a shape like torch.Size([768, 512]) as "768x512"
        return "x".join(str(x) for x in size) if size else "N/A"

    def count_parameters(module: nn.Module) -> Tuple[int, int]:
        # total and trainable parameter counts for a module, including its children
        total_params = sum(p.numel() for p in module.parameters())
        trainable_params = sum(
            p.numel() for p in module.parameters() if p.requires_grad
        )
        return total_params, trainable_params

    def recursive_summarize(
        module: nn.Module, depth: int, idx: List[int], prefix: str = ""
    ) -> List[Tuple[str, int, int, int, Optional[List[int]], nn.Module]]:
        # walk the module tree, collecting one summary row per module up to max_depth
        summary = []
        total_params, trainable_params = count_parameters(module)
        if depth <= max_depth:
            layer_index = ".".join(map(str, idx))
            layer_name = f"{prefix}{type(module).__name__}: {layer_index}"
            # shape of the first directly-owned trainable parameter, if any
            param_shape = next(
                (p.shape for p in module.parameters(recurse=False) if p.requires_grad),
                None,
            )
            summary.append(
                (layer_name, depth, total_params, trainable_params, param_shape, module)
            )
            for i, (name, child) in enumerate(module.named_children(), 1):
                child_summary = recursive_summarize(
                    child, depth + 1, idx + [i], prefix + " "
                )
                summary.extend(child_summary)
        return summary

    summary = recursive_summarize(model, 1, [1])
    # column widths: at least as wide as the header labels
    max_name_length = max(
        [len(name) for name, _, _, _, _, _ in summary] + [len("Layer (type:depth-idx)")]
    )
    max_shape_length = max(
        [len(format_size(shape)) for _, _, _, _, shape, _ in summary]
        + [len("Param Shape")]
    )

    print("=" * (max_name_length + 50))
    header = (
        f"{'Layer (type:depth-idx)':<{max_name_length}} "
        f"{'Param Shape':>{max_shape_length}} {'Param #':>12} {'Trainable':>10}"
    )
    print(header)
    print("=" * (max_name_length + 50))
    for name, depth, num_params, trainable_params, shape, _ in summary:
        shape_str = format_size(shape) if show_input_size else ""
        print(
            f"{name:<{max_name_length}} {shape_str:>{max_shape_length}} "
            f"{format_params(num_params):>12} {str(trainable_params > 0):>10}"
        )

    total_params, trainable_params = count_parameters(model)
    print("=" * (max_name_length + 50))
    print(f"Total params: {format_params(total_params)}")
    print(f"Trainable params: {format_params(trainable_params)}")
    print(f"Non-trainable params: {format_params(total_params - trainable_params)}")
    print("=" * (max_name_length + 50))
# Example usage:
# from transformers import T5ForConditionalGeneration
# model = T5ForConditionalGeneration.from_pretrained("t5-small")
# model_summary(model, max_depth=4, show_input_size=True)
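# Note: although the signature type-hints PreTrainedModel, the function only uses
# the generic nn.Module API (parameters() / named_children()), so it should also
# work on plain torch models. Hypothetical sketch with a toy model -- the layer
# sizes below are arbitrary examples, not from the gist:
# import torch.nn as nn
# toy = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
# model_summary(toy, max_depth=2, show_input_size=True)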