Full CPU offload for single-GPU training
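The gist contains two files. The first, offload.py, implements PerLayerOffloadWithBackwardGradient: for every homogeneous nn.ModuleList/nn.Sequential in the model, the per-layer parameters are kept flattened in pinned CPU memory, and only two flat GPU buffers stay resident per sequence, one holding the layer that is currently computing while the other is prefetched on a side CUDA stream. Gradients are copied back to pinned CPU memory from a register_post_accumulate_grad_hook(), and the optimizer steps run on the CPU copies in the order the gradients arrive, so parameters, gradients, and optimizer state never all live on the GPU at once. Below is a minimal usage sketch; the toy model, sizes, and placeholder loss are hypothetical and not part of the gist, a CUDA device is assumed, and the real training script further down additionally uses activation checkpointing.

# minimal usage sketch -- hypothetical toy model and loss, assumes a CUDA device
import torch
from torch import nn

from offload import PerLayerOffloadWithBackwardGradient

# the repeated blocks must sit in an nn.Sequential/nn.ModuleList of same-typed
# layers so that the offloader registers them; everything else stays on the GPU.
model = nn.Sequential(
    nn.Embedding(1000, 256),
    nn.Sequential(*[nn.Linear(256, 256) for _ in range(8)]),
    nn.Linear(256, 1000),
)

offloader = PerLayerOffloadWithBackwardGradient(model, torch.optim.AdamW, dict(lr=1e-4))
offloader.cuda()  # moves only the non-offloaded ("manual") tensors to the GPU

for _ in range(3):
    tokens = torch.randint(0, 1000, (4, 128), device="cuda")
    loss = model(tokens).float().pow(2).mean()  # placeholder loss
    offloader.disable_forward_hook = True   # only matters when activation checkpointing re-runs forward
    loss.backward()
    offloader.disable_forward_hook = False
    offloader.optim_step()       # CPU optimizer steps in grad-arrival order, then prefetches layer 0
    offloader.optim_zero_grad()

The offloader class itself (offload.py):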
import torch
from torch import Tensor, nn
from tqdm import tqdm


class PerLayerOffloadWithBackwardGradient:
    "This version also offloads gradients. To ensure proper synchronization, it will take control over the optimizer."
    def __init__(
        self,
        model: nn.Module,
        optim_cls: type[torch.optim.Optimizer],
        optim_kwargs: dict | None = None,
        enable: bool = True,
    ):
        self.model = model
        self.enable = enable
        if not enable:
            return

        self.optim_cls = optim_cls
        self.optim_kwargs = optim_kwargs
        self.stream = torch.cuda.Stream()
        self.disable_forward_hook = False

        self.key2flat_gpu_buffer = dict()
        self.key2flat_cpu_params = dict()
        self.param2cpu_view = dict()
        self.param2optim = dict()
        self.param_queue = []  # we will run optimizer in this order

        manual_params = set()

        def traverse(module: nn.Module, key: tuple[str, ...] = ()):
            if (
                isinstance(module, (nn.ModuleList, nn.Sequential))
                and len(module) > 1
                and all(type(layer) == type(module[0]) for layer in module)
            ):
                self._register_sequential(module, key)
            else:
                for p in module.parameters(recurse=False):
                    manual_params.add(p)
                for name, child in module.named_children():
                    traverse(child, key + (name,))

        traverse(model)
        self.manual_tensors = list(manual_params) + list(self.model.buffers())
        self.manual_optim = optim_cls(manual_params, **(optim_kwargs or dict()))
    def cuda(self):
        if not self.enable:
            self.model.cuda()
        else:
            for p in self.manual_tensors:
                p.data = p.data.cuda(non_blocking=True)
        return self
    def cpu(self):
        if not self.enable:
            self.model.cpu()
        else:
            for p in self.manual_tensors:
                p.data = self.param2cpu_view.get(p, p.data.cpu())
        return self
    @staticmethod
    def _get_flat_param(module: nn.Module):
        return torch.cat([x.detach().view(-1) for x in module.parameters()], dim=0)

    @staticmethod
    @torch.compiler.disable()
    def _view_into_flat_param(module: nn.Module, flat_param: Tensor):
        offset = 0
        for p in module.parameters():
            p.data = flat_param[offset : offset + p.numel()].view(p.shape)
            offset += p.numel()
    def _register_sequential(self, module_list: nn.Sequential | nn.ModuleList, key: tuple[str, ...]):
        self.key2flat_gpu_buffer[key] = [
            self._get_flat_param(module_list[0]).cuda(),
            self._get_flat_param(module_list[-1]).cuda(),
        ]
        self.key2flat_cpu_params[key] = []

        def create_pre_forward_hook(idx: int):
            def pre_forward_hook(module: nn.Module, inputs: tuple):
                # when there is activation checkpointing, .forward() is re-run in backward pass.
                # we use this flag to disable forward hooks in this case. set it to True before
                # calling loss.backward() and set it back to False after that.
                if self.disable_forward_hook:
                    return

                compute_buffer, transfer_buffer = self.key2flat_gpu_buffer[key]
                self._view_into_flat_param(module, compute_buffer)

                current_stream = torch.cuda.current_stream()
                current_stream.wait_stream(self.stream)
                self.stream.wait_stream(current_stream)
                with torch.cuda.stream(self.stream):
                    next_layer_cpu = self.key2flat_cpu_params[key][(idx + 1) % len(module_list)]
                    transfer_buffer.copy_(next_layer_cpu, non_blocking=True)

                self.key2flat_gpu_buffer[key] = [transfer_buffer, compute_buffer]

            return pre_forward_hook

        def create_pre_backward_hook(idx: int):
            def pre_backward_hook(module, grad_output):
                transfer_buffer, compute_buffer = self.key2flat_gpu_buffer[key]
                self._view_into_flat_param(module, compute_buffer)

                current_stream = torch.cuda.current_stream()
                current_stream.wait_stream(self.stream)
                self.stream.wait_stream(current_stream)
                with torch.cuda.stream(self.stream):
                    next_layer_cpu = self.key2flat_cpu_params[key][(idx - 1) % len(module_list)]
                    transfer_buffer.copy_(next_layer_cpu, non_blocking=True)

                self.key2flat_gpu_buffer[key] = [compute_buffer, transfer_buffer]

            return pre_backward_hook

        # NOTE: apparently when nn.Module.register_full_backward_hook() fires, param.grad
        # is not guaranteed to be computed https://github.com/pytorch/pytorch/issues/86051
        # hence, we have to use Tensor.register_post_accumulate_grad_hook() to offload grads.
        def post_grad_hook(p: Tensor):
            # make sure p.grad finished being computed
            self.stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(self.stream):
                self.param2cpu_view[p].grad.copy_(p.grad, non_blocking=True)

            # we will execute optim step in this order
            self.param_queue.append((p, self.stream.record_event()))

            # free grad memory
            p.grad.record_stream(self.stream)
            p.grad = None

        desc = f"Copying params to pinned memory {key}"
        for i, curr_layer in enumerate(tqdm(module_list, desc=desc, dynamic_ncols=True)):
            flat_param = self._get_flat_param(curr_layer).cpu().pin_memory()
            self.key2flat_cpu_params[key].append(flat_param)

            offset = 0
            for p in curr_layer.parameters():
                cpu_param = flat_param[offset : offset + p.numel()].view(p.shape)
                offset += p.numel()
                self.param2cpu_view[p] = cpu_param

                # pre-allocate pinned memory for gradients, and install hooks to offload grads
                if p.requires_grad:
                    cpu_param.grad = torch.empty(p.shape, dtype=p.dtype, device="cpu", pin_memory=True)
                    self.param2optim[p] = self.optim_cls([cpu_param], **(self.optim_kwargs or dict()))
                    p.register_post_accumulate_grad_hook(post_grad_hook)

            curr_layer.register_forward_pre_hook(create_pre_forward_hook(i))
            curr_layer.register_full_backward_pre_hook(create_pre_backward_hook(i))
    @torch.no_grad()
    def optim_step(self):
        after_bwd_event = torch.cuda.current_stream().record_event()
        self.manual_optim.step()
        for p, sync_event in self.param_queue:
            sync_event.synchronize()  # wait for grad offload to finish
            self.param2optim[p].step()

        # manually prefetch 1st layer, since it won't be prefetched in pre-forward hook
        # make sure backward finishes
        self.stream.wait_event(after_bwd_event)
        with torch.cuda.stream(self.stream):
            for key in self.key2flat_cpu_params.keys():
                self.key2flat_gpu_buffer[key][0].copy_(self.key2flat_cpu_params[key][0], non_blocking=True)

    def optim_zero_grad(self):
        self.manual_optim.zero_grad()
        self.param_queue = []
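The pre-forward and pre-backward hooks above follow a standard double-buffering pattern: the compute stream and the dedicated transfer stream synchronize with wait_stream() in both directions, the next layer's pinned-CPU flat parameters are copied into the idle GPU buffer with non_blocking=True, and the two buffers swap roles after every layer. A stripped-down sketch of just that pattern, outside the class (hypothetical layer count and sizes, assumes a CUDA device):

# standalone sketch of the double-buffered host-to-device prefetch used by the hooks above
# (hypothetical layer count and sizes; assumes a CUDA device)
import torch

n_layers, numel = 8, 1_000_000
cpu_layers = [torch.randn(numel, pin_memory=True) for _ in range(n_layers)]

transfer_stream = torch.cuda.Stream()
buffers = [cpu_layers[0].cuda(), cpu_layers[-1].cuda()]  # [compute, transfer]

for i in range(n_layers):
    compute_buf, transfer_buf = buffers
    current = torch.cuda.current_stream()

    # compute must wait until this layer's prefetch has landed in compute_buf,
    # and the next prefetch must wait until the previous layer (which read from
    # what is now transfer_buf) has finished computing
    current.wait_stream(transfer_stream)
    transfer_stream.wait_stream(current)

    with torch.cuda.stream(transfer_stream):
        transfer_buf.copy_(cpu_layers[(i + 1) % n_layers], non_blocking=True)

    out = compute_buf.sum()  # stand-in for the layer's actual compute on the default stream

    buffers = [transfer_buf, compute_buf]  # swap roles for the next layer

torch.cuda.synchronize()

The second file in the gist is the training script, which benchmarks the offloader against plain GPU training on Llama-3.2-1B (toggle the offload flag):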
import os
import time

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import datasets
import torch
import torch.nn.functional as F
import wandb
from torch.utils.data import DataLoader, IterableDataset
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from offload import PerLayerOffloadWithBackwardGradient


class TokenDataset(IterableDataset):
    def __init__(self, dataset_id: str, model_id: str, seq_len: int):
        self.ds = datasets.load_dataset(dataset_id, split="train", streaming=True)
        self.model_id = model_id
        self.seq_len = seq_len

    def __iter__(self):
        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        tokens = []
        for sample in self.ds:
            tokens.extend(tokenizer(sample["text"])["input_ids"])
            while len(tokens) >= self.seq_len + 1:
                yield torch.tensor(tokens[: self.seq_len + 1])
                tokens = tokens[self.seq_len + 1 :]


def get_loss(model, tokens):
    logits = model(tokens[:, :-1])[0]
    return cross_entropy(logits, tokens[:, 1:])


# wrap logits.float() and F.cross_entropy() in a compiled function to reduce memory
@torch.compile
def cross_entropy(logits, labels):
    return F.cross_entropy(logits.float().view(-1, logits.shape[-1]), labels.flatten())
if __name__ == "__main__":
    model_id = "meta-llama/Llama-3.2-1B"
    dtype = torch.bfloat16
    bsize = 4
    seq_len = 2048
    num_steps = 200
    use_compile = True
    offload = False
    profile = False

    torch.manual_seed(2024)
    cfg = AutoConfig.from_pretrained(
        model_id,
        max_position_embeddings=seq_len,
        use_cache=False,
    )
    model = AutoModelForCausalLM.from_config(cfg, torch_dtype=dtype)
    model.gradient_checkpointing_enable()
    # currently there is a bug with model.compile() + module hooks,
    # hence we manually compile .forward() of each layer instead
    # https://github.com/pytorch/pytorch/issues/142358
    if use_compile:
        for layer in model.model.layers:
            layer.forward = torch.compile(layer.forward)
    optim_cls = torch.optim.AdamW
    optim_kwargs = dict(lr=3e-4, weight_decay=0.0, fused=True)

    if offload:
        offloader = PerLayerOffloadWithBackwardGradient(model, optim_cls, optim_kwargs)
        offloader.cuda()
    else:
        model.cuda()
        optim = optim_cls(model.parameters(), **optim_kwargs)

    ds = TokenDataset("HuggingFaceFW/fineweb-edu", model_id, seq_len)
    dloader = DataLoader(ds, bsize, num_workers=1, pin_memory=True)
    dloader_iter = iter(dloader)

    if profile:
        torch._inductor.config.triton.unique_kernel_names = True
        prof = torch.profiler.profile()

    log_interval = 10
    pbar = tqdm(total=num_steps, dynamic_ncols=True)
    model.train()
    step = 0

    wandb.init(project="CPU offload", dir="/tmp", mode="disabled" if profile else None)

    torch.cuda.reset_peak_memory_stats()
    time0 = time.time()

    while step < num_steps:
        tokens = next(dloader_iter).cuda()
        # torch.compile(get_loss)(model, tokens) is faster for baseline,
        # but does not work for CPU offload (due to module hooks)
        loss = get_loss(model, tokens)

        if offload:
            offloader.disable_forward_hook = True
        loss.backward()
        if offload:
            offloader.disable_forward_hook = False

        if step % log_interval == 0:
            wandb.log(dict(loss=loss.item()), step=step)

        if offload:
            offloader.optim_step()
            offloader.optim_zero_grad()
        else:
            optim.step()
            optim.zero_grad()

        step += 1
        pbar.update()

        if profile:
            if step == 1:
                prof.start()
            elif step == 3:
                break

        if step % log_interval == 0:
            time1 = time.time()
            log_dict = dict(
                max_memory_allocated=torch.cuda.max_memory_allocated() / 1e9,
                tokens_per_second=bsize * seq_len * log_interval / (time1 - time0),
            )
            time0 = time1
            wandb.log(log_dict, step=step)

    wandb.finish()

    if profile:
        prof.stop()
        prof.export_chrome_trace("trace.json.gz")