@a-r-r-o-w
Created February 21, 2025 03:46
Tests multiple offloading mechanisms and gathers their CPU and CUDA memory/time usage on a single A100 GPU for Flux.
import argparse
import functools
import json
import os
import pathlib
import psutil
import time
import torch
from diffusers import FluxPipeline
from diffusers.hooks import apply_group_offloading
from memory_profiler import profile

def get_memory_usage():
    # Resident set size (RSS) of the current process, in bytes
    process = psutil.Process(os.getpid())
    mem_bytes = process.memory_info().rss
    return mem_bytes

@profile(precision=2)
def apply_offload(pipe: FluxPipeline, method: str) -> None:
    if method == "full_cuda":
        pipe.to("cuda")
    elif method == "model_offload":
        pipe.enable_model_cpu_offload()
    elif method == "sequential_offload":
        pipe.enable_sequential_cpu_offload()
    elif method == "group_offload_block_1":
        offloader_fn = functools.partial(
            apply_group_offloading,
            onload_device=torch.device("cuda"),
            offload_device=torch.device("cpu"),
            offload_type="block_level",
            num_blocks_per_group=1,
            use_stream=False,
        )
        # Group-offload the text encoders and transformer; the VAE stays on the GPU
        list(map(offloader_fn, [pipe.text_encoder, pipe.text_encoder_2, pipe.transformer]))
        pipe.vae.to("cuda")
    elif method == "group_offload_leaf":
        offloader_fn = functools.partial(
            apply_group_offloading,
            onload_device=torch.device("cuda"),
            offload_device=torch.device("cpu"),
            offload_type="leaf_level",
            use_stream=False,
        )
        list(map(offloader_fn, [pipe.text_encoder, pipe.text_encoder_2, pipe.transformer]))
        pipe.vae.to("cuda")
    elif method == "group_offload_block_1_stream":
        offloader_fn = functools.partial(
            apply_group_offloading,
            onload_device=torch.device("cuda"),
            offload_device=torch.device("cpu"),
            offload_type="block_level",
            num_blocks_per_group=1,
            use_stream=True,
        )
        list(map(offloader_fn, [pipe.text_encoder, pipe.text_encoder_2, pipe.transformer]))
        pipe.vae.to("cuda")
    elif method == "group_offload_leaf_stream":
        offloader_fn = functools.partial(
            apply_group_offloading,
            onload_device=torch.device("cuda"),
            offload_device=torch.device("cpu"),
            offload_type="leaf_level",
            use_stream=True,
        )
        list(map(offloader_fn, [pipe.text_encoder, pipe.text_encoder_2, pipe.transformer]))
        pipe.vae.to("cuda")

@profile(precision=2)
def load_pipeline():
    cache_dir = "/raid/.cache/huggingface"
    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-Dev", torch_dtype=torch.bfloat16, cache_dir=cache_dir)
    return pipe

@torch.no_grad()
def main(args):
    pipe = load_pipeline()
    apply_offload(pipe, args.method)
    apply_offload_memory_usage = get_memory_usage()

    torch.cuda.reset_peak_memory_stats()
    cuda_model_memory = torch.cuda.max_memory_reserved()

    output_dir = pathlib.Path(args.output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)

    # CPU RSS is sampled after every denoising step via the pipeline callback
    run_inference_memory_usage_list = []

    def cpu_mem_callback():
        nonlocal run_inference_memory_usage_list
        run_inference_memory_usage_list.append(get_memory_usage())

    @profile(precision=2)
    def run_inference():
        prompt = "A cat holding a sign that says hello world"
        image = pipe(
            prompt,
            height=1024,
            width=1024,
            num_inference_steps=30,
            guidance_scale=6.0,
            generator=torch.Generator().manual_seed(42),
            # Sample CPU memory at the end of every denoising step
            callback_on_step_end=lambda *args, **kwargs: [cpu_mem_callback(), kwargs][1],
        ).images[0]
        image.save(output_dir / f"output_{args.method}.png")

    t1 = time.time()
    run_inference()
    torch.cuda.synchronize()
    t2 = time.time()

    cuda_inference_memory = torch.cuda.max_memory_reserved()
    time_required = t2 - t1
    run_inference_memory_usage = sum(run_inference_memory_usage_list) / len(run_inference_memory_usage_list)
    print(f"Run inference memory usage list: {run_inference_memory_usage_list}")

    info = {
        "time": round(time_required, 2),
        "cuda_model_memory": round(cuda_model_memory / 1024**3, 2),
        "cuda_inference_memory": round(cuda_inference_memory / 1024**3, 2),
        "cpu_offload_memory": round(apply_offload_memory_usage / 1024**3, 2),
        "cpu_inference_memory": round(run_inference_memory_usage / 1024**3, 2),
    }
    with open(output_dir / f"memory_usage_{args.method}.json", "w") as f:
        json.dump(info, f, indent=4)

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--method", type=str, default="full_cuda", choices=["full_cuda", "model_offload", "sequential_offload", "group_offload_block_1", "group_offload_leaf", "group_offload_block_1_stream", "group_offload_leaf_stream"])
    parser.add_argument("--output_dir", type=str, default="offload_profiling")
    return parser.parse_args()

if __name__ == "__main__":
    args = get_args()
    main(args)
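
Usage: each configuration in the table below corresponds to one run of the script; the filename offload_profiling.py is assumed here (it is not part of the gist):

    python offload_profiling.py --method full_cuda --output_dir offload_profiling
    python offload_profiling.py --method group_offload_leaf_stream --output_dir offload_profiling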

Results (single A100, FLUX.1-Dev, 1024x1024, 30 inference steps):

| Configuration | Time (s) | CUDA Model Memory (GB) | CUDA Inference Memory (GB) | CPU Offload Memory (GB) | CPU Inference Memory (GB) |
|---|---|---|---|---|---|
| full_cuda | 25.77 | 31.45 | 36.07 | 0.86 | 1.41 |
| model_offload | 230.95 | 0.0 | 23.22 | 0.8 | 10.51 |
| sequential_offload | 2660.6 | 0.0 | 2.4 | 0.92 | 32.69 |
| group_offload_block_1 | 306.01 | 0.17 | 13.41 | 0.91 | 37.55 |
| group_offload_leaf | 375.29 | 0.17 | 4.56 | 0.92 | 37.41 |
| group_offload_block_1_stream | 58.79 | 0.17 | 14.49 | 47.99 | 57.84 |
| group_offload_leaf_stream | 55.26 | 0.17 | 5.52 | 47.99 | 48.54 |
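
The script writes one `memory_usage_<method>.json` file per configuration, so the table above can be rebuilt from those files. A minimal sketch, assuming the default `--output_dir` of `offload_profiling` and that every method has already been run:

```python
import json
import pathlib

output_dir = pathlib.Path("offload_profiling")
methods = [
    "full_cuda", "model_offload", "sequential_offload",
    "group_offload_block_1", "group_offload_leaf",
    "group_offload_block_1_stream", "group_offload_leaf_stream",
]
for method in methods:
    path = output_dir / f"memory_usage_{method}.json"
    if not path.exists():
        continue  # skip configurations that have not been profiled yet
    info = json.loads(path.read_text())
    print(
        method,
        info["time"],
        info["cuda_model_memory"],
        info["cuda_inference_memory"],
        info["cpu_offload_memory"],
        info["cpu_inference_memory"],
    )
```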

For group offloading, only the text encoders and transformer are offloaded; the VAE always stays on the GPU.
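
A possible variant that also group-offloads the VAE, instead of pinning it to the GPU, is sketched below. This is an untested assumption on top of the gist (the numbers above do not cover it); it simply points the same `apply_group_offloading` helper at `pipe.vae` as well:

```python
import functools

import torch
from diffusers import FluxPipeline
from diffusers.hooks import apply_group_offloading

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-Dev", torch_dtype=torch.bfloat16)

offloader_fn = functools.partial(
    apply_group_offloading,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
)
# Offload the VAE together with the text encoders and transformer
for module in [pipe.text_encoder, pipe.text_encoder_2, pipe.transformer, pipe.vae]:
    offloader_fn(module)
```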
