Created
November 27, 2025 06:16
-
-
Save g023/891d45858a2a80062d9acf7dbdc7b49e to your computer and use it in GitHub Desktop.
Memory-Efficient Backpropagation: Chunked Linear Layer
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # Author g023 - https://x.com/g023dev - https://github.com/g023 | |
| import torch | |
| import torch.nn as nn | |
| import gc | |
| import math | |
| import tracemalloc | |
| # Optional psutil for CPU (RSS) memory readings; if it is missing, RSS values are reported as n/a (tracemalloc and CUDA stats are still collected) | |
| try: | |
| import psutil | |
| _ps = psutil.Process() | |
| except Exception: | |
| _ps = None | |
class MemoryProfiler:
    """Lightweight memory profiler combining tracemalloc (Python allocations),
    optional psutil RSS, and torch.cuda peak stats.

    Usage:
        with MemoryProfiler('label') as mp:
            ... code to profile ...
        # then read mp attributes like mp.rss_before/mp.rss_after/mp.tracemalloc_peak/mp.cuda_peak

    Attributes (each stays ``None`` when its source is unavailable):
        rss_before / rss_after: process RSS in bytes, read after
            ``gc.collect()``. Requires the module-level psutil handle ``_ps``.
        tracemalloc_before / tracemalloc_peak: bytes traced by tracemalloc
            on entry, and the traced peak observed on exit. tracemalloc only
            sees allocations made through Python's allocators.
        cuda_before / cuda_peak: ``torch.cuda.memory_allocated()`` on entry
            and ``torch.cuda.max_memory_allocated()`` on exit. The peak
            counter is reset on entry, so ``cuda_peak`` is the high-water mark
            since entry. It includes memory that was already allocated on
            entry.

    If tracemalloc is already tracing when the block is entered, that session
    is reused: its peak is reset where supported (Python 3.9+), and it is
    left running on exit instead of being stopped.
    """

    def __init__(self, name=None):
        self.name = name or "mp"
        self.tracemalloc_enabled = hasattr(tracemalloc, "start")
        self.rss_before = None
        self.rss_after = None
        self.tracemalloc_before = None
        self.tracemalloc_peak = None
        self.cuda_before = None
        self.cuda_peak = None
        # True while this instance started the tracemalloc session and is
        # therefore responsible for stopping it on exit.
        self._owns_tracemalloc = False

    @staticmethod
    def _read_rss():
        """Return process RSS in bytes.

        Returns None when psutil is unavailable or the read fails; this is a
        best-effort measurement.
        """
        if _ps is None:
            return None
        try:
            return _ps.memory_info().rss
        except Exception:
            return None

    def __enter__(self):
        gc.collect()
        self.rss_before = self._read_rss()
        if torch.cuda.is_available():
            # reset CUDA peak counters then snapshot pre-run
            try:
                torch.cuda.reset_peak_memory_stats()
                self.cuda_before = torch.cuda.memory_allocated()
            except Exception:
                self.cuda_before = None
        if self.tracemalloc_enabled:
            if tracemalloc.is_tracing():
                # Someone else (an outer profiler or a test harness) is
                # already tracing. Reuse that session instead of restarting
                # it, and reset its peak so the reported peak covers this
                # block.
                reset_peak = getattr(tracemalloc, "reset_peak", None)
                if reset_peak is not None:
                    reset_peak()
            else:
                tracemalloc.start()
                self._owns_tracemalloc = True
            self.tracemalloc_before = tracemalloc.get_traced_memory()[0]
        return self

    def __exit__(self, exc_type, exc, tb):
        gc.collect()
        self.rss_after = self._read_rss()
        if self.tracemalloc_enabled and tracemalloc.is_tracing():
            self.tracemalloc_peak = tracemalloc.get_traced_memory()[1]
        if self._owns_tracemalloc:
            # Only stop a session this instance started. tracemalloc.stop()
            # also clears all collected traces, so stopping a caller's
            # session would destroy their data.
            tracemalloc.stop()
            self._owns_tracemalloc = False
        if torch.cuda.is_available():
            try:
                self.cuda_peak = torch.cuda.max_memory_allocated()
            except Exception:
                self.cuda_peak = None
        return False  # never suppress exceptions raised inside the block
| def _fmt_bytes(b): | |
| if b is None: | |
| return "n/a" | |
| return f"{b/1024/1024:.3f} MB" | |
class MemoryEfficientLinear(torch.autograd.Function):
    """Chunked ``sum(sigmoid(x @ w))`` that never materializes the full logits.

    The forward pass slices ``w`` along its second (vocabulary) dimension into
    blocks of at most ``chunk_size`` columns and accumulates the scalar result
    block by block. The backward pass recomputes each logit block instead of
    storing it.

    Call as ``MemoryEfficientLinear.apply(x, w, chunk_size)``:
        x: tensor whose last dimension is D, e.g. ``[N, D]``. Any leading
           dimensions are flattened for the gradient computation.
        w: tensor of shape ``[D, V]``.
        chunk_size: positive int, maximum number of columns of ``w`` per block.
    Returns a 0-dim tensor.
    """

    @staticmethod
    def forward(ctx, x, w, chunk_size=1024):
        """
        Forward pass: Computes output without storing the huge logits Z.

        Raises:
            ValueError: if ``chunk_size`` is not a positive integer.
        """
        # Validate before touching ctx so a bad call leaves no partial state.
        step = int(chunk_size)
        if step <= 0:
            raise ValueError(f"chunk_size must be > 0, got {chunk_size}")

        # save_for_backward stores references, not copies, so saving w here
        # costs no extra memory. It also lets autograd's version counter
        # detect an in-place change to x or w (e.g. an optimizer step) made
        # before backward runs. A plain ctx attribute would silently yield
        # wrong gradients in that case.
        ctx.save_for_backward(x, w)
        ctx.chunk_size = step

        # Example Function f: Sum(Sigmoid(XW)) computed in chunks to avoid
        # materializing the full [N, V] logits matrix.
        total_loss = x.new_zeros(())
        vocab_size = w.shape[1]
        for start in range(0, vocab_size, step):
            z_chunk = x @ w[:, start:start + step]
            # accumulate scalar loss; z_chunk is not saved
            total_loss = total_loss + torch.sigmoid(z_chunk).sum()
        return total_loss

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass: Recomputes Z chunks on-the-fly to calculate gradients.

        Returns:
            ``(dx, dw, None)``. ``dx``/``dw`` is None when autograd does not
            need that input's gradient. The trailing None is for the
            non-tensor ``chunk_size``.

        Raises:
            RuntimeError: if the tensors saved in forward are missing.
            NotImplementedError: if ``grad_output`` is not a single element.
        """
        saved = ctx.saved_tensors
        if len(saved) < 2:
            raise RuntimeError("Expected saved input tensor(s) in ctx")
        x, w = saved[0], saved[1]
        step = int(ctx.chunk_size)

        # This implementation assumes the forward returned a scalar loss (as
        # it does). Non-scalar outputs would need a different local_grad.
        if grad_output.numel() != 1:
            raise NotImplementedError(
                "MemoryEfficientLinear.backward currently supports scalar output only"
            )

        need_dx, need_dw = ctx.needs_input_grad[0], ctx.needs_input_grad[1]
        if not (need_dx or need_dw):
            return None, None, None

        # View x as [M, D] so the transpose for dW is well defined whatever
        # leading dimensions x has.
        x2d = x.reshape(-1, x.shape[-1])
        dx2d = torch.zeros_like(x2d) if need_dx else None
        # Every column block of dW is written exactly once below, so no
        # zero-initialization is needed.
        dw = torch.empty_like(w) if need_dw else None

        vocab_size = w.shape[1]
        # Recompute z chunks without tracking autograd to avoid building a
        # secondary graph during backward.
        with torch.no_grad():
            for start in range(0, vocab_size, step):
                w_chunk = w[:, start:start + step]
                # Rematerialize one small logit block.
                sig_z = torch.sigmoid(x2d @ w_chunk)
                # d sigmoid(z)/dz = sigmoid(z) * (1 - sigmoid(z)), scaled by
                # the upstream (scalar) gradient.
                local_grad = grad_output * (sig_z * (1 - sig_z))
                if need_dw:
                    dw[:, start:start + step].copy_(x2d.t() @ local_grad)
                if need_dx:
                    dx2d.add_(local_grad @ w_chunk.t())
                # free temporaries explicitly before the next chunk
                del sig_z, local_grad

        dx = dx2d.reshape(x.shape) if need_dx else None
        return dx, dw, None
# --- VERIFICATION ---
# Setup Data
BS, SEQ, HD, VOCAB = 2, 128, 64, 4096  # Small dims for demo
chunk_size = 512
# Put the demo tensors on the GPU when one exists: the profiler's CUDA peak
# counters only describe work that actually runs on the GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)  # reproducible inputs -> comparable diagnostics across runs
x = torch.randn(BS * SEQ, HD, device=DEVICE, requires_grad=True)
w = torch.randn(HD, VOCAB, device=DEVICE, requires_grad=True)

# 1. Naive Implementation (High Memory)
# Materializes the full [BS*SEQ, VOCAB] logits matrix ([256, 4096] here).
x_naive = x.clone().detach().requires_grad_(True)
w_naive = nn.Parameter(w.clone().detach())

# prepare and profile naive run
with MemoryProfiler("naive") as mp_naive:
    logits = x_naive @ w_naive
    loss_naive = torch.sigmoid(logits).sum()
    loss_naive.backward()
# Release the full logits now. Left alive, they would still be allocated
# during the efficient run and inflate its "before" and peak readings.
del logits

# capture profiler outputs
mem_before_naive = mp_naive.rss_before
mem_after_naive = mp_naive.rss_after
cuda_before_naive = mp_naive.cuda_before
cuda_peak_naive = mp_naive.cuda_peak
tracemalloc_peak_naive = mp_naive.tracemalloc_peak
# 2. Efficient Implementation (Low Memory)
# The custom autograd Function works on [BS*SEQ, chunk_size] logit blocks
# instead of the full logits matrix.
x_eff = x.clone().detach().requires_grad_(True)
w_eff = nn.Parameter(w.clone().detach())

# Collect garbage (and drop cached CUDA blocks) so leftovers from the naive
# run are not attributed to this one.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

with MemoryProfiler("eff") as mp_eff:
    loss_eff = MemoryEfficientLinear.apply(x_eff, w_eff, chunk_size)
    loss_eff.backward()

# Copy the readings out of the profiler for the summary section.
(mem_before_eff, mem_after_eff,
 cuda_before_eff, cuda_peak_eff,
 tracemalloc_peak_eff) = (mp_eff.rss_before, mp_eff.rss_after,
                          mp_eff.cuda_before, mp_eff.cuda_peak,
                          mp_eff.tracemalloc_peak)
# Check Results: compare the chunked implementation against the naive one.
loss_match = torch.allclose(loss_naive, loss_eff)
dx_close = torch.allclose(x_naive.grad, x_eff.grad, atol=1e-5)
dw_close = torch.allclose(w_naive.grad, w_eff.grad, atol=1e-5)

# Numeric differences: absolute max/mean, plus the norm of the difference
# relative to the naive gradient's norm (1e-12 guards against a zero norm).
loss_diff = (loss_naive - loss_eff).abs().item()
dx_diff = x_naive.grad - x_eff.grad
dw_diff = w_naive.grad - w_eff.grad
dx_abs, dw_abs = dx_diff.abs(), dw_diff.abs()
dx_max, dx_mean = dx_abs.max().item(), dx_abs.mean().item()
dw_max, dw_mean = dw_abs.max().item(), dw_abs.mean().item()
ref_dx_norm = x_naive.grad.norm().item()
ref_dw_norm = w_naive.grad.norm().item()
dx_rel = dx_diff.norm().item() / (ref_dx_norm + 1e-12)
dw_rel = dw_diff.norm().item() / (ref_dw_norm + 1e-12)

dx_has_nan = torch.isnan(x_naive.grad).any().item() or torch.isnan(x_eff.grad).any().item()
dw_has_nan = torch.isnan(w_naive.grad).any().item() or torch.isnan(w_eff.grad).any().item()

print(f"Loss naive: {loss_naive.item():.6e}, eff: {loss_eff.item():.6e}, abs diff: {loss_diff:.3e}, equal: {loss_match}")
print(f"dX equal: {dx_close}, max_abs_diff: {dx_max:.3e}, mean_abs_diff: {dx_mean:.3e}, rel_norm_diff: {dx_rel:.3e}")
print(f"dW equal: {dw_close}, max_abs_diff: {dw_max:.3e}, mean_abs_diff: {dw_mean:.3e}, rel_norm_diff: {dw_rel:.3e}")
print(f"||x_naive.grad||_2: {ref_dx_norm:.6e}, ||w_naive.grad||_2: {ref_dw_norm:.6e}")
print(f"dX contains NaN: {dx_has_nan}")
print(f"dW contains NaN: {dw_has_nan}")
print("\nIf these diagnostics look small and equalities are True, the algorithm is correct!")
# --- MEMORY SUMMARY ---
def _percent_of(part, whole):
    """Return ``part`` as a percentage of ``whole``.

    Returns None when either value is missing or ``whole`` is not positive,
    because the ratio would be meaningless.
    """
    if part is None or whole is None or whole <= 0:
        return None
    return 100.0 * part / whole


def _fmt_percent(pct):
    """Format a value from _percent_of as "12.34%", or "n/a" for None."""
    return "n/a" if pct is None else f"{pct:.2f}%"


print('\n--- Memory summary ---')
if _ps is not None:
    naive_rss_diff = None if (mem_before_naive is None or mem_after_naive is None) else (mem_after_naive - mem_before_naive)
    eff_rss_diff = None if (mem_before_eff is None or mem_after_eff is None) else (mem_after_eff - mem_before_eff)
    print(f"Naive RSS before: {_fmt_bytes(mem_before_naive)}, after: {_fmt_bytes(mem_after_naive)}, delta: {_fmt_bytes(naive_rss_diff)}")
    print(f"Eff RSS before: {_fmt_bytes(mem_before_eff)}, after: {_fmt_bytes(mem_after_eff)}, delta: {_fmt_bytes(eff_rss_diff)}")
    if naive_rss_diff is not None and eff_rss_diff is not None:
        # An RSS delta can be zero or negative (memory returned to the OS).
        # The percentage is only meaningful against a positive baseline.
        saved = naive_rss_diff - eff_rss_diff
        print(f"RSS saved (naive - eff): {_fmt_bytes(saved)} ({_fmt_percent(_percent_of(saved, naive_rss_diff))} of naive delta)")
else:
    print("psutil not available: CPU RSS measurements not shown. Install psutil to enable CPU memory reporting: pip install psutil")

# CUDA counters only describe this demo when its tensors live on the GPU.
if torch.cuda.is_available() and x.is_cuda:
    print(f"CUDA allocated before naive: {_fmt_bytes(cuda_before_naive)}, peak naive: {_fmt_bytes(cuda_peak_naive)}")
    print(f"CUDA allocated before eff: {_fmt_bytes(cuda_before_eff)}, peak eff: {_fmt_bytes(cuda_peak_eff)}")
    cuda_stats = (cuda_before_naive, cuda_peak_naive, cuda_before_eff, cuda_peak_eff)
    if all(v is not None for v in cuda_stats):
        # max_memory_allocated() is an absolute high-water mark. Compare
        # growth above each run's starting allocation instead; otherwise
        # tensors that merely outlive the naive run (its inputs and grads)
        # would be charged to the efficient run.
        naive_growth = cuda_peak_naive - cuda_before_naive
        eff_growth = cuda_peak_eff - cuda_before_eff
        saved_cuda = naive_growth - eff_growth
        print(f"CUDA peak above starting allocation: naive {_fmt_bytes(naive_growth)}, eff {_fmt_bytes(eff_growth)}")
        print(f"CUDA peak saved (naive - eff): {_fmt_bytes(saved_cuda)} ({_fmt_percent(_percent_of(saved_cuda, naive_growth))} of naive peak growth)")
else:
    print("CUDA not available or not used; CUDA memory stats not shown.")
| # --- SPECIFICATIONS --- | |
| # | |
| # Purpose | |
| # - MemoryEfficientLinear computes the scalar value sum(sigmoid(x @ w)) without | |
| # materializing the full [N, V] logits matrix Z. It does this by processing | |
| # the vocabulary dimension in chunks. | |
| # | |
| # API | |
| # - Forward: MemoryEfficientLinear.apply(x, w, chunk_size) | |
| # - x: Tensor of shape [N, D], requires_grad may be True if you want dx | |
| # - w: Tensor of shape [D, V], typically a model parameter | |
| # - chunk_size: positive int controlling the max width of each weight slice | |
| # - Returns: scalar tensor (0-dim) containing sum(sigmoid(x @ w)). | |
| # | |
| # Output / Gradients | |
| # - Output is a scalar. Backward returns gradients (dx, dw, None) matching | |
| # shapes of x and w respectively. | |
| # - Current implementation assumes scalar output (loss reduced to a single | |
| # value). To support per-token or per-sample outputs, the forward must return | |
| # a non-scalar tensor and the backward must apply grad_output (which always | |
| # has the same shape as the forward output) element-wise to each logit chunk. | |
| # | |
| # Device / dtype | |
| # - All work is performed on the devices/dtypes of the provided tensors. x and | |
| # w must be on the same device and share a dtype, since they are multiplied | |
| # together. Both are kept alive by reference until backward; no copies are made. | |
| # | |
| # Complexity and tradeoffs | |
| # - Memory: avoids storing the full [N, V] logits (Z). x and w are still kept | |
| # alive until backward. Peak memory for Z is O(N * chunk_size) instead of O(N * V). | |
| # - Compute: backward recomputes each logit chunk (x @ w_chunk) before the two | |
| # gradient matmuls (for dX and dW), so it costs 3 matmuls where the naive | |
| # backward costs 2. Including the forward matmul, total matmul work is about | |
| # 4/3 of the naive version (4 vs 3), not 2x. | |
| # | |
| # Robustness and checks | |
| # - Validates chunk_size > 0 and raises clear errors if the state saved for | |
| # backward (input/weight tensors) is missing, or if the caller supplies a | |
| # non-scalar grad_output (not supported by this example implementation). | |
| # - Uses torch.no_grad() while rematerializing chunks in backward to avoid | |
| # building a secondary autograd graph and to reduce temporary allocations. | |
| # | |
| # Recommendations / Extensions | |
| # - For non-scalar losses, adapt the local_grad computation so shapes align | |
| # and per-element grads are handled correctly. | |
| # - Consider using torch.autograd.gradcheck (finite differences) with small | |
| # float64 inputs to validate correctness in edge cases (float32 is too | |
| # imprecise for its finite-difference tolerances). | |
| # - If w is also very large and you want to avoid keeping a full reference, | |
| # consider sharding w on the filesystem or a parameter server and streaming | |
| # chunks into the backward pass (advanced use-case). | |
| # |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment