@gau-nernst
Created December 12, 2024 12:44
Full CPU offload for single-GPU training
import torch
from torch import Tensor, nn
from tqdm import tqdm


class PerLayerOffloadWithBackwardGradient:
    "This version also offloads gradients. To ensure proper synchronization, it will take control over the optimizer."

    def __init__(
        self,
        model: nn.Module,
        optim_cls: type[torch.optim.Optimizer],
        optim_kwargs: dict | None = None,
        enable: bool = True,
    ):
        self.model = model
        self.enable = enable
        if not enable:
            return

        self.optim_cls = optim_cls
        self.optim_kwargs = optim_kwargs
        self.stream = torch.cuda.Stream()
        self.disable_forward_hook = False

        self.key2flat_gpu_buffer = dict()
        self.key2flat_cpu_params = dict()
        self.param2cpu_view = dict()
        self.param2optim = dict()
        self.param_queue = []  # we will run optimizer in this order

        manual_params = set()
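        # manual_params collects parameters that are NOT part of a repeated layer stack
        # (e.g. the embeddings, final norm and LM head of a transformer); they are moved
        # to the GPU once and stepped by a regular optimizer instead of being offloaded.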

        def traverse(module: nn.Module, key: tuple[str, ...] = ()):
            if (
                isinstance(module, (nn.ModuleList, nn.Sequential))
                and len(module) > 1
                and all(type(layer) == type(module[0]) for layer in module)
            ):
                self._register_sequential(module, key)
            else:
                for p in module.parameters(recurse=False):
                    manual_params.add(p)
                for name, child in module.named_children():
                    traverse(child, key + (name,))

        traverse(model)
        self.manual_tensors = list(manual_params) + list(self.model.buffers())
        self.manual_optim = optim_cls(manual_params, **(optim_kwargs or dict()))

    def cuda(self):
        if not self.enable:
            self.model.cuda()
        else:
            for p in self.manual_tensors:
                p.data = p.data.cuda(non_blocking=True)
        return self

    def cpu(self):
        if not self.enable:
            self.model.cpu()
        else:
            for p in self.manual_tensors:
                p.data = self.param2cpu_view.get(p, p.data.cpu())
        return self

    @staticmethod
    def _get_flat_param(module: nn.Module):
        return torch.cat([x.detach().view(-1) for x in module.parameters()], dim=0)

    @staticmethod
    @torch.compiler.disable()
    def _view_into_flat_param(module: nn.Module, flat_param: Tensor):
        offset = 0
        for p in module.parameters():
            p.data = flat_param[offset : offset + p.numel()].view(p.shape)
            offset += p.numel()

    def _register_sequential(self, module_list: nn.Sequential | nn.ModuleList, key: tuple[str, ...]):
        self.key2flat_gpu_buffer[key] = [
            self._get_flat_param(module_list[0]).cuda(),
            self._get_flat_param(module_list[-1]).cuda(),
        ]
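        # Double-buffering: buffer[0] ("compute") holds the params of the layer about to run,
        # buffer[1] ("transfer") is filled on self.stream with the next layer's params from
        # pinned CPU memory; the hooks below swap the two roles at every layer.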
        self.key2flat_cpu_params[key] = []

        def create_pre_forward_hook(idx: int):
            def pre_forward_hook(module: nn.Module, inputs: tuple):
                # when there is activation checkpointing, .forward() is re-run in backward pass.
                # we use this flag to disable forward hooks in this case. set it to True before
                # calling loss.backward() and set it back to False after that.
                if self.disable_forward_hook:
                    return

                compute_buffer, transfer_buffer = self.key2flat_gpu_buffer[key]
                self._view_into_flat_param(module, compute_buffer)

                current_stream = torch.cuda.current_stream()
                current_stream.wait_stream(self.stream)
                self.stream.wait_stream(current_stream)
                with torch.cuda.stream(self.stream):
                    next_layer_cpu = self.key2flat_cpu_params[key][(idx + 1) % len(module_list)]
                    transfer_buffer.copy_(next_layer_cpu, non_blocking=True)

                self.key2flat_gpu_buffer[key] = [transfer_buffer, compute_buffer]

            return pre_forward_hook

        def create_pre_backward_hook(idx: int):
            def pre_backward_hook(module, grad_output):
                transfer_buffer, compute_buffer = self.key2flat_gpu_buffer[key]
                self._view_into_flat_param(module, compute_buffer)

                current_stream = torch.cuda.current_stream()
                current_stream.wait_stream(self.stream)
                self.stream.wait_stream(current_stream)
                with torch.cuda.stream(self.stream):
                    next_layer_cpu = self.key2flat_cpu_params[key][(idx - 1) % len(module_list)]
                    transfer_buffer.copy_(next_layer_cpu, non_blocking=True)

                self.key2flat_gpu_buffer[key] = [compute_buffer, transfer_buffer]

            return pre_backward_hook

        # NOTE: apparently when nn.Module.register_full_backward_hook() fires, param.grad
        # is not guaranteed to be computed https://github.com/pytorch/pytorch/issues/86051
        # hence, we have to use Tensor.register_post_accumulate_grad_hook() to offload grads.
        def post_grad_hook(p: Tensor):
            # make sure p.grad finished being computed
            self.stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(self.stream):
                self.param2cpu_view[p].grad.copy_(p.grad, non_blocking=True)

            # we will execute optim step in this order
            self.param_queue.append((p, self.stream.record_event()))

            # free grad memory
            p.grad.record_stream(self.stream)
            p.grad = None

        desc = f"Copying params to pinned memory {key}"
        for i, curr_layer in enumerate(tqdm(module_list, desc=desc, dynamic_ncols=True)):
            flat_param = self._get_flat_param(curr_layer).cpu().pin_memory()
            self.key2flat_cpu_params[key].append(flat_param)

            offset = 0
            for p in curr_layer.parameters():
                cpu_param = flat_param[offset : offset + p.numel()].view(p.shape)
                offset += p.numel()
                self.param2cpu_view[p] = cpu_param

                # pre-allocate pinned memory for gradients, and install hooks to offload grads
                if p.requires_grad:
                    cpu_param.grad = torch.empty(p.shape, dtype=p.dtype, device="cpu", pin_memory=True)
                    self.param2optim[p] = self.optim_cls([cpu_param], **(self.optim_kwargs or dict()))
                    p.register_post_accumulate_grad_hook(post_grad_hook)

            curr_layer.register_forward_pre_hook(create_pre_forward_hook(i))
            curr_layer.register_full_backward_pre_hook(create_pre_backward_hook(i))

    @torch.no_grad()
    def optim_step(self):
        after_bwd_event = torch.cuda.current_stream().record_event()
        self.manual_optim.step()
        for p, sync_event in self.param_queue:
            sync_event.synchronize()  # wait for grad offload to finish
            self.param2optim[p].step()

        # manually prefetch 1st layer, since it won't be prefetched in pre-forward hook
        # make sure backward finishes
        self.stream.wait_event(after_bwd_event)
        with torch.cuda.stream(self.stream):
            for key in self.key2flat_cpu_params.keys():
                self.key2flat_gpu_buffer[key][0].copy_(self.key2flat_cpu_params[key][0], non_blocking=True)

    def optim_zero_grad(self):
        self.manual_optim.zero_grad()
        self.param_queue = []
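
The benchmark script below is the second file in this gist. It imports the class above via from offload import PerLayerOffloadWithBackwardGradient, so the file above is presumably saved as offload.py. A condensed sketch of the offload path it exercises (when offload = False the script instead runs a plain-GPU baseline):

offloader = PerLayerOffloadWithBackwardGradient(model, optim_cls, optim_kwargs)
offloader.cuda()                         # move non-offloaded params and buffers to GPU
loss = get_loss(model, tokens)           # forward pre-hooks stream layers from pinned CPU memory
offloader.disable_forward_hook = True    # forward re-runs under activation checkpointing
loss.backward()                          # backward pre-hooks stream layers; grads are copied to CPU
offloader.disable_forward_hook = False
offloader.optim_step()                   # per-parameter CPU optimizer steps + first-layer prefetch
offloader.optim_zero_grad()
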
import os
import time
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import datasets
import torch
import torch.nn.functional as F
import wandb
from torch.utils.data import DataLoader, IterableDataset
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from offload import PerLayerOffloadWithBackwardGradient


class TokenDataset(IterableDataset):
    def __init__(self, dataset_id: str, model_id: str, seq_len: int):
        self.ds = datasets.load_dataset(dataset_id, split="train", streaming=True)
        self.model_id = model_id
        self.seq_len = seq_len

    def __iter__(self):
        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        tokens = []
        for sample in self.ds:
            tokens.extend(tokenizer(sample["text"])["input_ids"])
            while len(tokens) >= self.seq_len + 1:
                yield torch.tensor(tokens[: self.seq_len + 1])
                tokens = tokens[self.seq_len + 1 :]


def get_loss(model, tokens):
    logits = model(tokens[:, :-1])[0]
    return cross_entropy(logits, tokens[:, 1:])


# wrap logits.float() and F.cross_entropy() in a compiled function to reduce memory
@torch.compile
def cross_entropy(logits, labels):
    return F.cross_entropy(logits.float().view(-1, logits.shape[-1]), labels.flatten())


if __name__ == "__main__":
    model_id = "meta-llama/Llama-3.2-1B"
    dtype = torch.bfloat16
    bsize = 4
    seq_len = 2048
    num_steps = 200
    use_compile = True
    offload = False
    profile = False
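    # NOTE: with offload = False (the default above) this script runs the plain GPU
    # baseline; set offload = True to exercise PerLayerOffloadWithBackwardGradient.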

    torch.manual_seed(2024)
    cfg = AutoConfig.from_pretrained(
        model_id,
        max_position_embeddings=seq_len,
        use_cache=False,
    )
    model = AutoModelForCausalLM.from_config(cfg, torch_dtype=dtype)
    model.gradient_checkpointing_enable()

    # currently there is a bug with model.compile() + module hooks
    # hence, we will manually compile .forward() instead
    # https://github.com/pytorch/pytorch/issues/142358
    if use_compile:
        for layer in model.model.layers:
            layer.forward = torch.compile(layer.forward)

    optim_cls = torch.optim.AdamW
    optim_kwargs = dict(lr=3e-4, weight_decay=0.0, fused=True)

    if offload:
        offloader = PerLayerOffloadWithBackwardGradient(model, optim_cls, optim_kwargs)
        offloader.cuda()
    else:
        model.cuda()
        optim = optim_cls(model.parameters(), **optim_kwargs)

    ds = TokenDataset("HuggingFaceFW/fineweb-edu", model_id, seq_len)
    dloader = DataLoader(ds, bsize, num_workers=1, pin_memory=True)
    dloader_iter = iter(dloader)

    if profile:
        torch._inductor.config.triton.unique_kernel_names = True
        prof = torch.profiler.profile()

    log_interval = 10
    pbar = tqdm(total=num_steps, dynamic_ncols=True)
    model.train()
    step = 0
    wandb.init(project="CPU offload", dir="/tmp", mode="disabled" if profile else None)

    torch.cuda.reset_peak_memory_stats()
    time0 = time.time()
    while step < num_steps:
        tokens = next(dloader_iter).cuda()

        # torch.compile(get_loss)(model, tokens) is faster for baseline,
        # but does not work for CPU offload (due to module hooks)
        loss = get_loss(model, tokens)

        if offload:
            offloader.disable_forward_hook = True
        loss.backward()
        if offload:
            offloader.disable_forward_hook = False

        if step % log_interval == 0:
            wandb.log(dict(loss=loss.item()), step=step)

        if offload:
            offloader.optim_step()
            offloader.optim_zero_grad()
        else:
            optim.step()
            optim.zero_grad()

        step += 1
        pbar.update()

        if profile:
            if step == 1:
                prof.start()
            elif step == 3:
                break

        if step % log_interval == 0:
            time1 = time.time()
            log_dict = dict(
                max_memory_allocated=torch.cuda.max_memory_allocated() / 1e9,
                tokens_per_second=bsize * seq_len * log_interval / (time1 - time0),
            )
            time0 = time1
            wandb.log(log_dict, step=step)

    wandb.finish()
    if profile:
        prof.stop()
        prof.export_chrome_trace("trace.json.gz")