@a-r-r-o-w · Last active May 13, 2025
import torch
import torch.distributed as dist
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.utils import export_to_video
from finetrainers._metadata import ParamId, CPInput, CPOutput
from finetrainers.parallel.ptd import apply_context_parallel
from finetrainers.models.attention_dispatch import attention_provider, attention_dispatch
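# NOTE: the assignment below replaces torch's scaled_dot_product_attention with finetrainers'
# attention_dispatch, so attention calls inside the diffusers transformer are routed to the
# backend selected by the attention_provider(...) context manager used further down
# (including its context-parallel settings).
# This is a multi-GPU script; launch it with a distributed launcher, e.g.:
#   torchrun --nproc_per_node=<num_gpus> this_script.py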
torch.nn.functional.scaled_dot_product_attention = attention_dispatch
def apply_compile(model: torch.nn.Module, compile_scope: str) -> torch.nn.Module:
    r"""Apply torch.compile to a model or its submodules if not already compiled."""
    if getattr(model, "_torch_compiled", False):
        return model  # Already compiled
    if compile_scope == "full":
        model = torch.compile(model)
        setattr(model, "_torch_compiled", True)
    elif compile_scope == "regional":
        if isinstance(model, torch.nn.ModuleList):
            for name, module in model.named_children():
                if not getattr(module, "_torch_compiled", False):
                    compiled_module = torch.compile(module, mode="max-autotune-no-cudagraphs", fullgraph=False, dynamic=False)
                    setattr(compiled_module, "_torch_compiled", True)
                    model.register_module(name, compiled_module)
        else:
            for name, module in model.named_children():
                apply_compile(module, compile_scope)
    else:
        raise ValueError(f"Unknown compile mode: {compile_scope}. Use 'full' or 'regional'.")
    return model
torch.manual_seed(0)
dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)
cp_mesh = dist.device_mesh.init_device_mesh("cuda", [world_size], mesh_dim_names=["cp"])
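# Context-parallel plan for the Wan transformer: it maps submodule names (glob patterns like
# "blocks.*" are allowed) to how their inputs/outputs should be split across, or gathered from,
# the "cp" device mesh. The integer arguments follow finetrainers' CPInput/CPOutput convention
# (which tensor dimension to shard, expected tensor rank); see finetrainers._metadata for details.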
cp_plan = {
    "rope": {
        ParamId(index=0): CPInput(2, 4, split_output=True),
    },
    "blocks.*": {
        ParamId("encoder_hidden_states", 1): CPInput(1, 3),
    },
    "blocks.0": {
        ParamId("hidden_states", 0): CPInput(1, 3),
    },
    "proj_out": [CPOutput(1, 3)],
}
try:
    model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
    vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
    pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
    pipe.to("cuda")
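    # Shard the transformer across the "cp" mesh according to cp_plan, then compile each
    # transformer block individually ("regional" scope), which typically keeps compile time
    # and recompilations lower than compiling the full model at once.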
    apply_context_parallel(pipe.transformer, mesh=cp_mesh, plan=cp_plan)
    apply_compile(pipe.transformer, compile_scope="regional")

    prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
    negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
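    # Encode the positive and negative prompts once up front, so the denoising loop below
    # only has to run the (context-parallel, compiled) transformer.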
    with torch.no_grad():
        prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
            prompt=prompt, negative_prompt=negative_prompt, device="cuda",
        )

    attention_backend = "_native_flash"
    generator = torch.Generator().manual_seed(0)
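    # The short 2-step warmup run triggers torch.compile compilation and any lazy initialization
    # before the full 30-step run. Note that the two runs also use different rotate_method
    # settings for the context-parallel attention provider ("alltoall" vs "allgather").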
    # Warmup
    with attention_provider(attention_backend, mesh=cp_mesh, convert_to_fp32=True, rotate_method="alltoall"):
        latents = pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            height=480,
            width=832,
            num_frames=81,
            num_inference_steps=2,
            guidance_scale=5.0,
            output_type="latent",
            generator=generator,
        ).frames[0]
    # Inference
    with attention_provider(attention_backend, mesh=cp_mesh, convert_to_fp32=True, rotate_method="allgather"):
        latents = pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            height=480,
            width=832,
            num_frames=81,
            guidance_scale=5.0,
            num_inference_steps=30,
            output_type="latent",
            generator=generator,
        ).frames[0]
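    # Un-normalize the latents with the VAE's per-channel mean/std before decoding them to
    # frames; only rank 0 writes the output video.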
    with torch.no_grad():
        latents = latents.to(pipe.vae.dtype)
        latents_mean = (
            torch.tensor(pipe.vae.config.latents_mean)
            .view(1, pipe.vae.config.z_dim, 1, 1, 1)
            .to(latents.device, latents.dtype)
        )
        latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view(1, pipe.vae.config.z_dim, 1, 1, 1).to(
            latents.device, latents.dtype
        )
        latents = latents / latents_std + latents_mean
        video = pipe.vae.decode(latents, return_dict=False)[0]
        video = pipe.video_processor.postprocess_video(video, output_type="pil")[0]
        if rank == 0:
            export_to_video(video, "output.mp4", fps=16)
finally:
    dist.destroy_process_group()