A walkthrough of how the sac-autonaming-revamp branch applies name-based
selective activation checkpointing.
There are three pieces:
Selective Activation Checkpointing (SAC) saves a subset of activations
during the forward pass so that they do not have to be recomputed during
the backward pass; every activation that is not saved is recomputed. The
main branch of TorchTitan uses a counter-based policy that saves every
other matmul output. The sac-autonaming-revamp branch replaces this with a
naming-based policy: each matmul output gets a name such as
layers.0.attention.mm_2_0, and an explicit save list specifies which
tensors to keep.
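To make the difference concrete, here is a minimal pure-Python sketch of both policies, assuming the forward pass is recorded as (module path, op name) pairs. The Policy enum is a local stand-in mirroring two members of torch.utils.checkpoint.CheckpointPolicy, and the helpers and the "<module>.mm_<i>" naming scheme are hypothetical; they do not reproduce the branch's exact mm_2_0 format. In PyTorch itself, a per-op policy function is passed to torch.utils.checkpoint.create_selective_checkpoint_contexts.

```python
from collections import Counter
from enum import Enum, auto


class Policy(Enum):
    # Local stand-in mirroring two members of torch.utils.checkpoint.CheckpointPolicy.
    MUST_SAVE = auto()
    PREFER_RECOMPUTE = auto()


def counter_policy(ops):
    """Counter-based: save the 1st, 3rd, 5th, ... matmul output; recompute the rest."""
    decisions = []
    mm_count = 0
    for _module, op in ops:
        if op == "mm":
            mm_count += 1
            save = mm_count % 2 == 1
        else:
            save = False  # simplification: only matmuls are considered in this sketch
        decisions.append(Policy.MUST_SAVE if save else Policy.PREFER_RECOMPUTE)
    return decisions


def name_matmuls(ops):
    """Hypothetical naming scheme: '<module path>.mm_<i>', with i counted per module."""
    seen = Counter()
    names = []
    for module, op in ops:
        if op == "mm":
            names.append(f"{module}.mm_{seen[module]}")
            seen[module] += 1
        else:
            names.append(None)  # non-matmul ops get no name
    return names


def name_policy(ops, save_list):
    """Name-based: save a matmul output only if its name appears in the save list."""
    return [
        Policy.MUST_SAVE if name in save_list else Policy.PREFER_RECOMPUTE
        for name in name_matmuls(ops)
    ]


ops = [
    ("layers.0.attention", "mm"),
    ("layers.0.attention", "mm"),
    ("layers.0.attention", "softmax"),
    ("layers.0.attention", "mm"),
    ("layers.0.feed_forward", "mm"),
]
save_list = {"layers.0.attention.mm_2", "layers.0.feed_forward.mm_0"}
print(name_matmuls(ops))
print(counter_policy(ops))
print(name_policy(ops, save_list))
```

In this sketch the counter's decision depends on a matmul's global position in the forward pass, so inserting one extra matmul early flips every later decision; with the name-based policy, names in untouched modules stay the same, and the save list states the choice explicitly.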

"""Zero-copy gradient packing into a contiguous buffer.

Demonstrates how a custom autograd Function can write gradients directly
into a pre-allocated contiguous buffer, and have AccumulateGrad steal
the views (Case 1.1) so that .grad points into the buffer with no copy.

Key requirements for the steal path:
1. The gradient tensor must obey the layout contract (strides match the parameter)
2. The gradient tensor's refcount must be <= num_expected_refs (typically 1)
   - Clear ctx references before returning from backward
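Requirement 1 comes down to bookkeeping: each parameter gets a slice of the flat buffer whose strides match the parameter's. The sketch below computes that layout in pure Python, assuming every parameter is contiguous (row-major); pack_layout and contiguous_strides are hypothetical helpers, not part of the gist or of PyTorch. With torch, a slice would typically become a gradient view via buffer[offset:offset + numel].view(shape), which is contiguous and so has exactly these strides.

```python
from math import prod


def contiguous_strides(shape):
    """Row-major (C-contiguous) strides, in elements, for a tensor of this shape."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def pack_layout(shapes):
    """Assign each parameter shape a slice of one flat buffer.

    Returns (total_numel, [(offset, numel, strides), ...]). Viewing the flat
    buffer at `offset` with these strides yields a gradient whose strides
    match a contiguous parameter of the same shape.
    """
    layout = []
    offset = 0
    for shape in shapes:
        numel = prod(shape)
        layout.append((offset, numel, contiguous_strides(shape)))
        offset += numel  # next parameter starts right after this one
    return offset, layout


total, layout = pack_layout([(4, 3), (3,), (2, 2, 5)])
print(total)  # 12 + 3 + 20 = 35
for entry in layout:
    print(entry)
```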

super().apply() path:
- is_executable = True, so the autograd graph is built
- _trace_pre_record / _trace_post_record are no-ops
- dummy has requires_grad=True: it was created with torch.empty((0,), requires_grad=True) on line 1615, so any_variable_requires_grad returns True
- _checkpoint_hook is NOT yet active: this .apply() call happens before the with _checkpoint_hook(new_frame) block (line 1623), so SavedTensorDefaultHooks::is_enabled() returns False at this point (the checkpoint's pack/unpack hooks aren't installed yet for this apply call)

import torch
import functools
import contextlib
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map_only
from torch.utils.weak import WeakTensorKeyDictionary
from torch.utils.checkpoint import CheckpointPolicy, _policy_from_bool
from collections import namedtuple
import weakref

import torch
from torch.utils.weak import WeakTensorKeyDictionary
import weakref
from dataclasses import dataclass
import dataclasses
from typing import *
import sys

@dataclass
class CacheEntry:
    one: Optional[Union[torch.Tensor, weakref.ReferenceType]] = None

import torch
import functools
from torch.utils._python_dispatch import TorchDispatchMode
import torch.utils._pytree as pytree
from torch.utils.weak import WeakTensorKeyDictionary

class RecomputableTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, t, func, args):

import torch
from torch.nested._internal.nested_tensor import jagged_from_list

a = torch.randn(2, 7, 256, requires_grad=True, dtype=torch.float32)
b = torch.randn(3, 7, 256, requires_grad=True, dtype=torch.float32)
c = torch.randn(4, 7, 256, requires_grad=True, dtype=torch.float32)
d = torch.randn(5, 7, 256, requires_grad=True, dtype=torch.float32)
nt1 = jagged_from_list([a, b, c, d], None)[0]
nt2 = jagged_from_list([a, b, c, d], None)[0]
nt1_view = nt1.select(2, 1)

import torch

class T(torch.Tensor):
    def __new__(cls, elem):
        return torch.Tensor._make_wrapper_subclass(cls, elem.shape, dtype=elem.dtype)

    def __init__(self, elem):
        self.elem = elem

    @classmethod

# Technically, even in the "easy case" of t._base.requires_grad == t.requires_grad,
# I need to perform two views to recreate that view authentically. Why?
# There are actually two things I need to recreate: (1) the autograd
# graph relationship and (2) the view relationship.
# The reason we don't handle this today is that this autograd connectivity information
# is not accessible during tracing, and hence not relevant to compile, in part because Dynamo
# doesn't support grad_fn access.