# Benchmark: vanilla cross attention vs. PyTorch SDPA vs. SDPA + torch.compile
# for a Stable Diffusion pipeline.
# Source: GitHub Gist by @patil-suraj (gist id patil-suraj/ea4ddb77ce9a8908560c5032dd1fc238),
# last active February 9, 2023 14:34.
import torch
import torch.utils.benchmark as benchmark
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from diffusers.models.cross_attention import TorchAttentionProcessor
def benchmark_torch_function(f, *args, min_run_time=1, **kwargs):
    """Time ``f(*args, **kwargs)`` and return its mean runtime in seconds.

    Uses ``torch.utils.benchmark.Timer.blocked_autorange``, which repeatedly
    runs the statement until at least ``min_run_time`` seconds of
    measurements have been collected.

    Args:
        f: Callable to benchmark.
        *args: Positional arguments forwarded to ``f``.
        min_run_time: Minimum total measurement time in seconds, passed to
            ``blocked_autorange``. Defaults to 1, the value that was
            previously hard-coded. It is keyword-only and is consumed here,
            i.e. it is NOT forwarded to ``f``.
        **kwargs: Keyword arguments forwarded to ``f``.

    Returns:
        Mean time per call in seconds, rounded to 2 decimal places. Calls
        faster than ~5 ms therefore report as 0.0.
    """
    # Pass the callable and its arguments through the Timer's globals so the
    # statement string can refer to them by name.
    timer = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
    )
    measurement = timer.blocked_autorange(min_run_time=min_run_time)
    return round(measurement.mean, 2)
# ---------------------------------------------------------------------------
# Benchmark configuration
# ---------------------------------------------------------------------------
model_id = "CompVis/stable-diffusion-v1-4"
prompt = "A photo of an astronaut riding a horse on mars."
steps = 50
batch_size = 10
dtype = torch.float16

# ---------------------------------------------------------------------------
# Pipeline setup. A single pipeline object is loaded here; each benchmark
# stage below then mutates it in place (attention-processor swap, then
# torch.compile of the UNet).
# ---------------------------------------------------------------------------
pipe = StableDiffusionPipeline.from_pretrained(
    model_id, torch_dtype=dtype, safety_checker=None
).to("cuda")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.set_progress_bar_config(disable=True)


def _generate_images():
    """Run the pipeline once with the benchmark settings and return its images.

    ``pipe`` and the settings are looked up at call time, so every call
    exercises whatever UNet / attention configuration ``pipe`` currently has.
    """
    return pipe(prompt, num_inference_steps=steps, num_images_per_prompt=batch_size).images


# Stage 1: whatever attention processors the freshly loaded pipeline uses.
print("Running benchmark for vanilla cross attention...")
f = _generate_images
time_vanilla = benchmark_torch_function(f)

# Stage 2: switch the UNet's attention layers to the PyTorch SDPA processor.
print("Running benchmark for PyTorch SDPA...")
pipe.unet.set_attn_processor(TorchAttentionProcessor())
time_sdpa = benchmark_torch_function(f)

# Stage 3: keep SDPA and additionally wrap the UNet with torch.compile.
print("Running benchmark for PyTorch SDPA with torch.compile...")
pipe.unet = torch.compile(pipe.unet)
# Untimed warmup call: torch.compile compiles lazily on first use, so this
# keeps the compilation cost out of the timed runs.
f()
time_sdpa_torch_compile = benchmark_torch_function(f)

# ---------------------------------------------------------------------------
# Report: mean time per full pipeline call, as returned by
# benchmark_torch_function.
# ---------------------------------------------------------------------------
print(f"Model: {model_id}, dtype: {dtype}, steps: {steps}, batch_size: {batch_size}")
for label, seconds in (
    ("Vanilla Cross Attention", time_vanilla),
    ("PyTorch SDPA", time_sdpa),
    ("PyTorch SDPA with torch.compile", time_sdpa_torch_compile),
):
    print(f"{label}: {seconds} s")