Prints an accurate, hierarchical summary of a PyTorch model.
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
import torch.nn as nn


@dataclass
class _LayerSummary:
    """A dataclass to hold summary information for a single layer."""

    name: str
    param_shape: Optional[torch.Size]
    inclusive_total_params: int
    inclusive_trainable_params: int


def model_summary(
    model: nn.Module, max_depth: int = 4, show_param_shapes: bool = False
) -> None:
    """
    Prints a hierarchical summary of a PyTorch model.

    This function provides a detailed view of the model's layers, including
    inclusive parameter counts and trainable status, in a hierarchical format.

    :param model: The torch.nn.Module to summarize.
    :param max_depth: The maximum depth of the model hierarchy to display.
        A value of 0 will only show the total parameter counts.
        A value of 1 will only show the main model layer.
        Defaults to 4.
    :param show_param_shapes: If True, displays the shape of the first direct
        trainable parameter for each layer. Defaults to False.
    """

    def _format_number(num: int) -> str:
        """Formats a number with commas or returns '--' if zero."""
        return f"{num:,}" if num > 0 else "--"

    def _format_shape(shape: Optional[torch.Size]) -> str:
        """Formats a torch.Size object into a string like '1x2x3'."""
        return "x".join(map(str, shape)) if shape else "N/A"

    def _count_module_params(module: nn.Module) -> Tuple[int, int]:
        """
        Counts the total and trainable parameters directly in a module (non-recursive).
        """
        total_params = 0
        trainable_params = 0
        for param in module.parameters(recurse=False):
            num_elements = param.numel()
            total_params += num_elements
            if param.requires_grad:
                trainable_params += num_elements
        return total_params, trainable_params

    def _create_summary_list(
        model: nn.Module, max_depth: int
    ) -> Tuple[List[_LayerSummary], int, int]:
        """
        Builds the hierarchical list of layer summaries using pre-order traversal.
        """
        summary_list = []

        def summarize_recursive(
            module: nn.Module, depth: int, prefix: str
        ) -> Tuple[int, int]:
            """Recursively summarizes a module and its children."""
            if depth > max_depth:
                # If max depth is reached, just count all params below and return
                total = 0
                trainable = 0
                for p in module.parameters(recurse=True):
                    total += p.numel()
                    if p.requires_grad:
                        trainable += p.numel()
                return total, trainable

            # 1. Add parent to the list first (pre-order traversal)
            current_summary_index = len(summary_list)
            param_shape = next(
                (p.shape for p in module.parameters(recurse=False) if p.requires_grad),
                None,
            )
            summary_list.append(
                _LayerSummary(
                    name=f"{prefix}{type(module).__name__}",
                    param_shape=param_shape,
                    inclusive_total_params=0,  # Placeholder, will be updated later
                    inclusive_trainable_params=0,  # Placeholder
                )
            )

            # 2. Recurse into children to get their counts and add them to the list
            exclusive_total, exclusive_trainable = _count_module_params(module)
            children_total = 0
            children_trainable = 0
            for child in module.children():
                child_total, child_trainable = summarize_recursive(
                    child, depth + 1, prefix + " "
                )
                children_total += child_total
                children_trainable += child_trainable

            inclusive_total = exclusive_total + children_total
            inclusive_trainable = exclusive_trainable + children_trainable

            # 3. Update the parent's summary with the final inclusive counts
            summary_list[current_summary_index].inclusive_total_params = inclusive_total
            summary_list[
                current_summary_index
            ].inclusive_trainable_params = inclusive_trainable

            return inclusive_total, inclusive_trainable

        if max_depth <= 0:
            total_params = sum(p.numel() for p in model.parameters())
            trainable_params = sum(
                p.numel() for p in model.parameters() if p.requires_grad
            )
            return [], total_params, trainable_params

        total_params, trainable_params = summarize_recursive(model, 1, "")
        return summary_list, total_params, trainable_params

    summary_entries, total_params, trainable_params = _create_summary_list(
        model, max_depth
    )

    name_col_width = len("Layer (type)")
    if summary_entries:
        name_col_width = max(name_col_width, max(len(s.name) for s in summary_entries))

    shape_col_width = 0
    if show_param_shapes:
        shape_col_width = len("Param Shape")
        if summary_entries:
            shape_strings = [_format_shape(s.param_shape) for s in summary_entries]
            max_shape_str_len = (
                max(len(s) for s in shape_strings) if shape_strings else 0
            )
            shape_col_width = max(shape_col_width, max_shape_str_len)

    params_col_width = 12
    trainable_col_width = 10
    col_spacing = " "

    # --- Printing ---
    header_parts = [f"{'Layer (type)':<{name_col_width}}"]
    if show_param_shapes:
        header_parts.append(f"{'Param Shape':>{shape_col_width}}")
    header_parts.append(f"{'Param #':>{params_col_width}}")
    header_parts.append(f"{'Trainable':>{trainable_col_width}}")

    header = col_spacing.join(header_parts)
    separator = "=" * len(header)

    print(separator)
    print(header)
    print(separator)

    for entry in summary_entries:
        line_parts = [f"{entry.name:<{name_col_width}}"]
        if show_param_shapes:
            shape_str = _format_shape(entry.param_shape)
            line_parts.append(f"{shape_str:>{shape_col_width}}")
        line_parts.append(
            f"{_format_number(entry.inclusive_total_params):>{params_col_width}}"
        )
        is_trainable = entry.inclusive_trainable_params > 0
        line_parts.append(f"{str(is_trainable):>{trainable_col_width}}")
        print(col_spacing.join(line_parts))

    print(separator)
    print(f"Total params: {_format_number(total_params)}")
    print(f"Trainable params: {_format_number(trainable_params)}")
    non_trainable_params = total_params - trainable_params
    print(f"Non-trainable params: {_format_number(non_trainable_params)}")
    print(separator)
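For a quick sanity check that needs nothing beyond torch, the function can be run on a small hand-built module. This is a minimal sketch; the toy model below is hypothetical, and any nn.Module works the same way.

import torch.nn as nn

# Hypothetical toy model, used only to illustrate the hierarchical output.
toy = nn.Sequential(
    nn.Linear(16, 32),
    nn.ReLU(),
    nn.Linear(32, 4),
)

# Show every level and include each layer's first trainable parameter shape.
model_summary(toy, max_depth=2, show_param_shapes=True)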
from transformers import AutoTokenizer, AutoModelForPreTraining

tokenizer = AutoTokenizer.from_pretrained("google/electra-base-discriminator")
model = AutoModelForPreTraining.from_pretrained("google/electra-base-discriminator")

# assuming model_summary() already defined
model_summary(model)
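For larger models, lowering max_depth keeps the table readable. The calls below are a sketch reusing the ELECTRA model loaded above:

# Only the top two levels of the module hierarchy, with parameter shapes.
model_summary(model, max_depth=2, show_param_shapes=True)

# With max_depth=0 no per-layer rows are printed, only the parameter totals.
model_summary(model, max_depth=0)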