Skip to content

Instantly share code, notes, and snippets.

@g023
Created November 27, 2025 06:16
Show Gist options
  • Select an option

  • Save g023/891d45858a2a80062d9acf7dbdc7b49e to your computer and use it in GitHub Desktop.

Select an option

Save g023/891d45858a2a80062d9acf7dbdc7b49e to your computer and use it in GitHub Desktop.
Memory-Efficient Backpropagation: Chunked Linear Layer
# Author g023 - https://x.com/g023dev - https://github.com/g023
import torch
import torch.nn as nn
import gc
import math
import tracemalloc
# Optional psutil for CPU (RSS) memory readings. If it is missing, or the
# process handle cannot be created, RSS values are reported as None and only
# tracemalloc / CUDA statistics are available.
try:
    import psutil
    _ps = psutil.Process()
except Exception:  # deliberate best-effort: psutil is an optional dependency
    _ps = None


class MemoryProfiler:
    """Lightweight memory profiler usable as a context manager.

    Combines three sources of information:

    * ``tracemalloc`` (Python-level allocations only),
    * optional ``psutil`` resident-set size (RSS) of the current process,
    * ``torch.cuda`` allocated / peak-allocated bytes when CUDA is available.

    Usage:
        with MemoryProfiler('label') as mp:
            ... code to profile ...
        # then read mp.rss_before / mp.rss_after / mp.tracemalloc_peak /
        # mp.cuda_before / mp.cuda_peak (any of them may be None when the
        # corresponding source is unavailable).

    If ``tracemalloc`` is already tracing when the block is entered, the
    profiler joins that session instead of owning it: tracing is left running
    on exit, and ``tracemalloc_peak`` is the peak of the whole session.
    """

    def __init__(self, name=None):
        self.name = name or "mp"
        self.tracemalloc_enabled = hasattr(tracemalloc, "start")
        self.rss_before = None
        self.rss_after = None
        self.tracemalloc_before = None
        self.tracemalloc_peak = None
        self.cuda_before = None
        self.cuda_peak = None
        # True only while this instance is the one that started tracemalloc.
        self._started_tracemalloc = False

    @staticmethod
    def _read_rss():
        """Return the process RSS in bytes, or None when it cannot be read."""
        if _ps is None:
            return None
        try:
            return _ps.memory_info().rss
        except Exception:  # best-effort: a failed reading must not abort profiling
            return None

    def __enter__(self):
        gc.collect()
        self.rss_before = self._read_rss()
        if torch.cuda.is_available():
            # Reset CUDA peak counters, then snapshot the pre-run allocation.
            try:
                torch.cuda.reset_peak_memory_stats()
                self.cuda_before = torch.cuda.memory_allocated()
            except Exception:
                self.cuda_before = None
        if self.tracemalloc_enabled:
            # Only take ownership of tracing if nobody else started it, so
            # that __exit__ never tears down a session it does not own.
            self._started_tracemalloc = not tracemalloc.is_tracing()
            if self._started_tracemalloc:
                tracemalloc.start()
            self.tracemalloc_before = tracemalloc.get_traced_memory()[0]
        return self

    def __exit__(self, exc_type, exc, tb):
        gc.collect()
        self.rss_after = self._read_rss()
        if self.tracemalloc_enabled and tracemalloc.is_tracing():
            _current, peak = tracemalloc.get_traced_memory()
            self.tracemalloc_peak = peak
            if self._started_tracemalloc:
                tracemalloc.stop()
        self._started_tracemalloc = False
        if torch.cuda.is_available():
            try:
                self.cuda_peak = torch.cuda.max_memory_allocated()
            except Exception:
                self.cuda_peak = None
        # Never suppress exceptions raised inside the profiled block.
        return False
def _fmt_bytes(b):
if b is None:
return "n/a"
return f"{b/1024/1024:.3f} MB"
class MemoryEfficientLinear(torch.autograd.Function):
    """Chunked, rematerializing implementation of ``sum(sigmoid(x @ w))``.

    The logits ``Z = x @ w`` are never materialized in full. Forward walks
    the last dimension of ``w`` in slices of ``chunk_size`` columns and keeps
    only a running scalar; backward recomputes each logit slice on the fly to
    build the gradients.

    Call as ``MemoryEfficientLinear.apply(x, w, chunk_size)``.
    """

    @staticmethod
    def forward(ctx, x, w, chunk_size=1024):
        """Compute ``sum(sigmoid(x @ w))`` without storing the full logits.

        Args:
            x: Input tensor of shape ``[..., D]``.
            w: Weight tensor of shape ``[D, V]``.
            chunk_size: Positive number of columns of ``w`` processed per step.

        Returns:
            A 0-dim tensor holding the summed sigmoid activations.

        Raises:
            ValueError: If ``chunk_size`` is not positive or ``w`` is not 2-D.
        """
        chunk = int(chunk_size)
        if chunk <= 0:
            raise ValueError(f"chunk_size must be > 0, got {chunk_size}")
        if w.dim() != 2:
            raise ValueError(f"w must be 2-D [D, V], got shape {tuple(w.shape)}")

        # Both inputs go through save_for_backward: this stores references
        # (no copies) and lets autograd detect in-place modification of x or
        # w between forward and backward instead of silently producing wrong
        # gradients.
        ctx.save_for_backward(x, w)
        ctx.chunk_size = chunk

        total_loss = x.new_zeros(())
        vocab_size = w.shape[1]
        for start in range(0, vocab_size, chunk):
            # Only a [..., chunk] block of logits is alive at any time.
            z_chunk = x @ w[:, start:start + chunk]
            total_loss = total_loss + torch.sigmoid(z_chunk).sum()
        return total_loss

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_output):
        """Recompute logit chunks on the fly and accumulate ``dx`` and ``dw``.

        Args:
            grad_output: Gradient of the loss w.r.t. the scalar forward output.

        Returns:
            ``(dx, dw, None)``; ``dx`` / ``dw`` are None for inputs that do
            not require a gradient. ``chunk_size`` never receives a gradient.

        Raises:
            NotImplementedError: If ``grad_output`` is not a single element.
        """
        x, w = ctx.saved_tensors
        # The forward output is a scalar, so only a scalar upstream gradient
        # is meaningful here.
        if grad_output.numel() != 1:
            raise NotImplementedError(
                "MemoryEfficientLinear.backward currently supports scalar output only"
            )

        need_dx, need_dw = ctx.needs_input_grad[:2]
        chunk = ctx.chunk_size
        # Collapse any leading dimensions so the matmuls below are plain 2-D.
        x2d = x.reshape(-1, x.shape[-1])

        dx = (
            torch.zeros_like(x2d, memory_format=torch.contiguous_format)
            if need_dx
            else None
        )
        dw = torch.zeros_like(w) if need_dw else None

        # once_differentiable already runs this body without grad tracking,
        # so the recomputation does not build a second autograd graph.
        for start in range(0, w.shape[1], chunk):
            w_chunk = w[:, start:start + chunk]
            # Rematerialize a small block of logits and its sigmoid.
            sig_z = torch.sigmoid(x2d @ w_chunk)
            # d/dz sigmoid(z) = sigmoid(z) * (1 - sigmoid(z))
            local_grad = grad_output * (sig_z * (1 - sig_z))
            if need_dw:
                # Each column slice of dw is produced by exactly one chunk.
                dw[:, start:start + chunk] = x2d.t() @ local_grad
            if need_dx:
                # dx accumulates contributions from every chunk, in place.
                dx.addmm_(local_grad, w_chunk.t())
            # Drop chunk temporaries before the next iteration allocates.
            del sig_z, local_grad

        if dx is not None:
            dx = dx.reshape(x.shape)
        return dx, dw, None
# --- VERIFICATION ---
# Compare the chunked implementation against the naive one: same loss, same
# gradients, while profiling the memory used by each run.

# Setup Data
BS, SEQ, HD, VOCAB = 2, 128, 64, 4096  # Small dims for demo
chunk_size = 512
# Place the tensors on the GPU when one is present. The profiler records CUDA
# statistics whenever CUDA is available, and those numbers only describe this
# workload if the tensors actually live on the device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.randn(BS*SEQ, HD, device=device, requires_grad=True)
w = torch.randn(HD, VOCAB, device=device, requires_grad=True)

# 1. Naive Implementation (High Memory)
# Materializes full [256, 4096] matrix
x_naive = x.clone().detach().requires_grad_(True)
w_naive = nn.Parameter(w.clone().detach())
# prepare and profile naive run
with MemoryProfiler("naive") as mp_naive:
    logits = x_naive @ w_naive
    loss_naive = torch.sigmoid(logits).sum()
    loss_naive.backward()
# capture profiler outputs
mem_before_naive = mp_naive.rss_before
mem_after_naive = mp_naive.rss_after
cuda_before_naive = mp_naive.cuda_before
cuda_peak_naive = mp_naive.cuda_peak
tracemalloc_peak_naive = mp_naive.tracemalloc_peak

# 2. Efficient Implementation (Low Memory)
# Never materializes full matrix
x_eff = x.clone().detach().requires_grad_(True)
w_eff = nn.Parameter(w.clone().detach())
# cleanup before efficient run to reduce carryover and profile
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
with MemoryProfiler("eff") as mp_eff:
    loss_eff = MemoryEfficientLinear.apply(x_eff, w_eff, chunk_size)
    loss_eff.backward()
mem_before_eff = mp_eff.rss_before
mem_after_eff = mp_eff.rss_after
cuda_before_eff = mp_eff.cuda_before
cuda_peak_eff = mp_eff.cuda_peak
tracemalloc_peak_eff = mp_eff.tracemalloc_peak

# Check Results (more detailed diagnostics)
loss_match = torch.allclose(loss_naive, loss_eff)
dx_close = torch.allclose(x_naive.grad, x_eff.grad, atol=1e-5)
dw_close = torch.allclose(w_naive.grad, w_eff.grad, atol=1e-5)
# numeric diffs
loss_diff = (loss_naive - loss_eff).abs().item()
dx_diff = (x_naive.grad - x_eff.grad)
dw_diff = (w_naive.grad - w_eff.grad)
dx_max = dx_diff.abs().max().item()
dx_mean = dx_diff.abs().mean().item()
dw_max = dw_diff.abs().max().item()
dw_mean = dw_diff.abs().mean().item()
# relative error in the L2 norm; the epsilon guards against a zero gradient
dx_rel = dx_diff.norm().item() / (x_naive.grad.norm().item() + 1e-12)
dw_rel = dw_diff.norm().item() / (w_naive.grad.norm().item() + 1e-12)
print(f"Loss naive: {loss_naive.item():.6e}, eff: {loss_eff.item():.6e}, abs diff: {loss_diff:.3e}, equal: {loss_match}")
print(f"dX equal: {dx_close}, max_abs_diff: {dx_max:.3e}, mean_abs_diff: {dx_mean:.3e}, rel_norm_diff: {dx_rel:.3e}")
print(f"dW equal: {dw_close}, max_abs_diff: {dw_max:.3e}, mean_abs_diff: {dw_mean:.3e}, rel_norm_diff: {dw_rel:.3e}")
print(f"||x_naive.grad||_2: {x_naive.grad.norm().item():.6e}, ||w_naive.grad||_2: {w_naive.grad.norm().item():.6e}")
print(f"dX contains NaN: {torch.isnan(x_naive.grad).any().item() or torch.isnan(x_eff.grad).any().item()}")
print(f"dW contains NaN: {torch.isnan(w_naive.grad).any().item() or torch.isnan(w_eff.grad).any().item()}")
print("\nIf these diagnostics look small and equalities are True, the algorithm is correct!")
# --- MEMORY SUMMARY ---
# Report what each run cost in process RSS (if psutil is installed) and in
# CUDA peak allocation (if CUDA is available).
print('\n--- Memory summary ---')
if _ps is not None:
    naive_rss_diff = None if (mem_before_naive is None or mem_after_naive is None) else (mem_after_naive - mem_before_naive)
    eff_rss_diff = None if (mem_before_eff is None or mem_after_eff is None) else (mem_after_eff - mem_before_eff)
    print(f"Naive RSS before: {_fmt_bytes(mem_before_naive)}, after: {_fmt_bytes(mem_after_naive)}, delta: {_fmt_bytes(naive_rss_diff)}")
    print(f"Eff RSS before: {_fmt_bytes(mem_before_eff)}, after: {_fmt_bytes(mem_after_eff)}, delta: {_fmt_bytes(eff_rss_diff)}")
    if naive_rss_diff is not None and eff_rss_diff is not None:
        saved = naive_rss_diff - eff_rss_diff
        # A percentage of the naive delta is undefined when that delta is
        # zero; say so explicitly instead of dividing by a tiny epsilon.
        if naive_rss_diff != 0:
            perc = 100.0 * saved / naive_rss_diff
            perc_text = f"{perc:.2f}%"
        else:
            perc = None
            perc_text = "n/a"
        print(f"RSS saved (naive - eff): {_fmt_bytes(saved)} ({perc_text} of naive delta)")
else:
    print("psutil not available: CPU RSS measurements not shown. Install psutil to enable CPU memory reporting: pip install psutil")
if torch.cuda.is_available():
    print(f"CUDA allocated before naive: {_fmt_bytes(cuda_before_naive)}, peak naive: {_fmt_bytes(cuda_peak_naive)}")
    print(f"CUDA allocated before eff: {_fmt_bytes(cuda_before_eff)}, peak eff: {_fmt_bytes(cuda_peak_eff)}")
    if cuda_peak_naive is not None and cuda_peak_eff is not None:
        saved_cuda = cuda_peak_naive - cuda_peak_eff
        # Same guard as above: a zero naive peak has no meaningful percentage.
        if cuda_peak_naive != 0:
            perc_cuda = 100.0 * saved_cuda / cuda_peak_naive
            perc_cuda_text = f"{perc_cuda:.2f}%"
        else:
            perc_cuda = None
            perc_cuda_text = "n/a"
        print(f"CUDA peak saved (naive - eff): {_fmt_bytes(saved_cuda)} ({perc_cuda_text} of naive peak)")
else:
    print("CUDA not available or not used; CUDA memory stats not shown.")
# --- SPECIFICATIONS ---
#
# Purpose
# - MemoryEfficientLinear computes the scalar value sum(sigmoid(x @ w)) without
# materializing the full [N, V] logits matrix Z. It does this by processing
# the vocabulary dimension in chunks.
#
# API
# - Forward: MemoryEfficientLinear.apply(x, w, chunk_size)
# - x: Tensor of shape [N, D], requires_grad may be True if you want dx
# - w: Tensor of shape [D, V], typically a model parameter
# - chunk_size: positive int controlling the max width of each weight slice
# - Returns: scalar tensor (0-dim) containing sum(sigmoid(x @ w)).
#
# Output / Gradients
# - Output is a scalar. Backward returns gradients (dx, dw, None) matching
# shapes of x and w respectively.
# - Current implementation assumes scalar output (loss reduced to a single
# value). To support per-token or per-sample outputs, the backward logic
# must be adapted so grad_output has the same shape as the forward output.
#
# Device / dtype
# - All work is performed on the device/dtype of the provided tensors. x and
# w must be on the same device. Both x and w are kept alive by reference
# (no copy is made of either) from forward until backward has run.
#
# Complexity and tradeoffs
# - Memory: avoids storing the full [N, V] logits (Z). Still stores x and a
# reference to w. Peak memory for Z is O(N * chunk_size) instead of O(N * V).
# - Compute: extra matmuls are performed during backward (rematerialization).
# Each chunk costs one extra forward-sized matmul during backward, so the
# total is about 4 matmul passes (1 forward + 3 backward) versus 3 for the
# naive approach (1 forward + 2 backward), i.e. roughly 1.33x matmul compute.
#
# Robustness and checks
# - Validates chunk_size > 0 and raises a clear error if backward receives a
# non-scalar grad_output (not supported by this example implementation).
# - Uses torch.no_grad() while rematerializing chunks in backward to avoid
# building a secondary autograd graph and to reduce temporary allocations.
#
# Recommendations / Extensions
# - For non-scalar losses, adapt the local_grad computation so shapes align
# and per-element grads are handled correctly.
# - Consider using torch.autograd.gradcheck (finite differences) with small
# inputs to validate correctness in edge cases.
# - If w is also very large and you want to avoid keeping a full reference,
# consider sharding w on the filesystem or a parameter server and streaming
# chunks into the backward pass (advanced use-case).
#
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment