@sayakpaul
Created June 19, 2025 13:20
import argparse
import json
import os

import psutil
import torch
import torch.utils.benchmark as benchmark
from diffusers import DiffusionPipeline


def benchmark_fn(f, *args, **kwargs):
    # Time `f(*args, **kwargs)` with torch.utils.benchmark and return the mean runtime (seconds) as a string.
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
        num_threads=torch.get_num_threads(),
    )
    return f"{(t0.blocked_autorange().mean):.3f}"


def run_inference(pipe, pipe_kwargs):
    _ = pipe(**pipe_kwargs)


def initialize_pipeline():
    # Load FLUX.1-dev in bfloat16 and silence the per-step progress bar.
    pipe = DiffusionPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    )
    pipe.set_progress_bar_config(disable=True)
    return pipe


def maybe_apply_offloading(pipe, args):
    if not args.model_cpu_offload and not args.seq_cpu_offload and not args.group_offload:
        # No offloading requested: keep the whole pipeline on the GPU.
        pipe = pipe.to("cuda")
    else:
        if args.model_cpu_offload:
            pipe.enable_model_cpu_offload()
        elif args.seq_cpu_offload:
            pipe.enable_sequential_cpu_offload()
        elif args.group_offload:
            # Group offloading applies only to the transformer: blocks are moved
            # between the offload device (CPU or disk) and the GPU one group at a time.
            pipe.transformer.enable_group_offload(
                onload_device=torch.device("cuda"),
                offload_device=torch.device("cpu"),
                offload_type="block_level",
                num_blocks_per_group=1,
                use_stream=True,
                non_blocking=True,
                offload_to_disk_path="." if args.offload_to_disk else None,
                record_stream=True,
                _enable_deepnvme_disk_offloading=args.nvme,
            )
            if args.compile:
                torch._dynamo.config.cache_size_limit = 10000
                pipe.transformer.compile()
            # For the rest of the components, just place on CUDA.
            for name, component in pipe.components.items():
                if name != "transformer" and isinstance(component, torch.nn.Module):
                    component.cuda()
    return pipe


def main(args):
    process = psutil.Process(os.getpid())
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()

    pipe = initialize_pipeline()
    pipe = maybe_apply_offloading(pipe, args)

    pipe_kwargs = {
        "prompt": "A cat holding a sign that says hello world",
        "height": 1024,
        "width": 1024,
        "guidance_scale": 3.5,
        "num_inference_steps": 28,
        "max_sequence_length": 512,
        "generator": torch.manual_seed(0),
    }
    time = benchmark_fn(run_inference, pipe, pipe_kwargs)
    inference_memory = torch.cuda.max_memory_allocated() / (1024 ** 3)
    inference_memory = float(f"{inference_memory:.2f}")
    ram_bytes = process.memory_info().rss
    ram_gb = ram_bytes / (1024 ** 3)

    # Report timing and memory.
    print(f"Peak GPU memory: {inference_memory} GB")
    print(f"Resident CPU memory (RSS): {ram_gb:.2f} GB")

    # Build an output prefix that encodes the CLI configuration.
    prefix = "base"
    for key, value in vars(args).items():
        prefix += f"_{key}@{value}"

    # Save a sample image and a JSON artifact with the measurements and flags.
    image = pipe(**pipe_kwargs).images[0]
    image.save(f"{prefix}.png")
    artifact_dict = {"time": time, "memory": inference_memory, "ram": ram_gb}
    artifact_dict.update(vars(args))
    with open(f"{prefix}.json", "w") as f:
        json.dump(artifact_dict, f)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_cpu_offload", action="store_true")
parser.add_argument("--seq_cpu_offload", action="store_true")
parser.add_argument("--group_offload", action="store_true")
parser.add_argument("--offload_to_disk", action="store_true")
parser.add_argument("--nvme", action="store_true")
parser.add_argument("--compile", action="store_true")
args = parser.parse_args()
main(args)
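Usage: the script is run once per configuration, e.g. `python benchmark_flux_offload.py --group_offload --offload_to_disk --compile` (script filename assumed here). Each run prints peak GPU memory and resident CPU memory, and writes a `<prefix>.png` sample plus a `<prefix>.json` artifact containing the measured time, memory figures, and the CLI flags used, where the prefix encodes the chosen options.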