@pszemraj
Last active June 19, 2025 18:22
Prints an accurate, hierarchical summary of a PyTorch model.

from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
import torch.nn as nn


@dataclass
class _LayerSummary:
    """A dataclass to hold summary information for a single layer."""

    name: str
    param_shape: Optional[torch.Size]
    inclusive_total_params: int
    inclusive_trainable_params: int


def model_summary(
    model: nn.Module, max_depth: int = 4, show_param_shapes: bool = False
) -> None:
    """
    Prints a hierarchical summary of a PyTorch model.

    This function provides a detailed view of the model's layers, including
    inclusive parameter counts and trainable status, in a hierarchical format.

    :param model: The torch.nn.Module to summarize.
    :param max_depth: The maximum depth of the model hierarchy to display.
        A value of 0 shows only the total parameter counts; a value of 1
        shows only the top-level module. Defaults to 4.
    :param show_param_shapes: If True, displays the shape of the first direct
        trainable parameter for each layer. Defaults to False.
    """

    def _format_number(num: int) -> str:
        """Formats a number with commas or returns '--' if zero."""
        return f"{num:,}" if num > 0 else "--"

    def _format_shape(shape: Optional[torch.Size]) -> str:
        """Formats a torch.Size object into a string like '1x2x3'."""
        return "x".join(map(str, shape)) if shape else "N/A"

    def _count_module_params(module: nn.Module) -> Tuple[int, int]:
        """
        Counts the total and trainable parameters directly in a module (non-recursive).
        """
        total_params = 0
        trainable_params = 0
        for param in module.parameters(recurse=False):
            num_elements = param.numel()
            total_params += num_elements
            if param.requires_grad:
                trainable_params += num_elements
        return total_params, trainable_params

    def _create_summary_list(
        model: nn.Module, max_depth: int
    ) -> Tuple[List[_LayerSummary], int, int]:
        """
        Builds the hierarchical list of layer summaries using pre-order traversal.
        """
        summary_list = []

        def summarize_recursive(
            module: nn.Module, depth: int, prefix: str
        ) -> Tuple[int, int]:
            """Recursively summarizes a module and its children."""
            if depth > max_depth:
                # If max depth is reached, just count all params below and return
                total = 0
                trainable = 0
                for p in module.parameters(recurse=True):
                    total += p.numel()
                    if p.requires_grad:
                        trainable += p.numel()
                return total, trainable

            # 1. Add parent to the list first (pre-order traversal)
            current_summary_index = len(summary_list)
            param_shape = next(
                (p.shape for p in module.parameters(recurse=False) if p.requires_grad),
                None,
            )
            summary_list.append(
                _LayerSummary(
                    name=f"{prefix}{type(module).__name__}",
                    param_shape=param_shape,
                    inclusive_total_params=0,  # Placeholder, will be updated later
                    inclusive_trainable_params=0,  # Placeholder
                )
            )

            # 2. Recurse into children to get their counts and add them to the list
            exclusive_total, exclusive_trainable = _count_module_params(module)
            children_total = 0
            children_trainable = 0
            for child in module.children():
                child_total, child_trainable = summarize_recursive(
                    child, depth + 1, prefix + "  "
                )
                children_total += child_total
                children_trainable += child_trainable

            inclusive_total = exclusive_total + children_total
            inclusive_trainable = exclusive_trainable + children_trainable

            # 3. Update the parent's summary with the final inclusive counts
            summary_list[current_summary_index].inclusive_total_params = inclusive_total
            summary_list[
                current_summary_index
            ].inclusive_trainable_params = inclusive_trainable

            return inclusive_total, inclusive_trainable

        if max_depth <= 0:
            total_params = sum(p.numel() for p in model.parameters())
            trainable_params = sum(
                p.numel() for p in model.parameters() if p.requires_grad
            )
            return [], total_params, trainable_params

        total_params, trainable_params = summarize_recursive(model, 1, "")
        return summary_list, total_params, trainable_params

    summary_entries, total_params, trainable_params = _create_summary_list(
        model, max_depth
    )

    name_col_width = len("Layer (type)")
    if summary_entries:
        name_col_width = max(name_col_width, max(len(s.name) for s in summary_entries))

    shape_col_width = 0
    if show_param_shapes:
        shape_col_width = len("Param Shape")
        if summary_entries:
            shape_strings = [_format_shape(s.param_shape) for s in summary_entries]
            max_shape_str_len = (
                max(len(s) for s in shape_strings) if shape_strings else 0
            )
            shape_col_width = max(shape_col_width, max_shape_str_len)

    params_col_width = 12
    trainable_col_width = 10
    col_spacing = "  "

    # --- Printing ---
    header_parts = [f"{'Layer (type)':<{name_col_width}}"]
    if show_param_shapes:
        header_parts.append(f"{'Param Shape':>{shape_col_width}}")
    header_parts.append(f"{'Param #':>{params_col_width}}")
    header_parts.append(f"{'Trainable':>{trainable_col_width}}")
    header = col_spacing.join(header_parts)
    separator = "=" * len(header)

    print(separator)
    print(header)
    print(separator)
    for entry in summary_entries:
        line_parts = [f"{entry.name:<{name_col_width}}"]
        if show_param_shapes:
            shape_str = _format_shape(entry.param_shape)
            line_parts.append(f"{shape_str:>{shape_col_width}}")
        line_parts.append(
            f"{_format_number(entry.inclusive_total_params):>{params_col_width}}"
        )
        is_trainable = entry.inclusive_trainable_params > 0
        line_parts.append(f"{str(is_trainable):>{trainable_col_width}}")
        print(col_spacing.join(line_parts))
    print(separator)
    print(f"Total params: {_format_number(total_params)}")
    print(f"Trainable params: {_format_number(trainable_params)}")
    non_trainable_params = total_params - trainable_params
    print(f"Non-trainable params: {_format_number(non_trainable_params)}")
    print(separator)
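
The effect of max_depth and show_param_shapes is easiest to see on a small model. The snippet below is not part of the original gist, just a minimal sketch with arbitrary layer sizes:

# Minimal sketch (not in the original gist): a toy model to illustrate the options.
toy = nn.Sequential(
    nn.Linear(16, 32),
    nn.ReLU(),
    nn.Sequential(nn.Linear(32, 32), nn.ReLU()),
    nn.Linear(32, 4),
)
model_summary(toy, max_depth=2, show_param_shapes=True)  # per-layer rows; shapes like "32x16"
model_summary(toy, max_depth=0)  # totals only, no per-layer table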

from transformers import AutoTokenizer, AutoModelForPreTraining

tokenizer = AutoTokenizer.from_pretrained("google/electra-base-discriminator")
model = AutoModelForPreTraining.from_pretrained("google/electra-base-discriminator")

# assuming model_summary() is already defined
model_summary(model)
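
As a follow-up (not part of the gist), the Trainable column is a quick way to verify that freezing took effect. Assuming the usual transformers layout where the ELECTRA embeddings live at model.electra.embeddings, something like this should report that submodule as non-trainable:

# Hypothetical follow-up (not in the gist): freeze the embeddings, then confirm
# the summary shows False in the Trainable column for that submodule.
for p in model.electra.embeddings.parameters():
    p.requires_grad = False
model_summary(model, max_depth=3, show_param_shapes=True)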