import torch
import torch.distributed as dist
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.utils import export_to_video
from finetrainers._metadata import ParamId, CPInput, CPOutput
from finetrainers.parallel.ptd import apply_context_parallel
from finetrainers.models.attention_dispatch import attention_provider, attention_dispatch

torch.nn.functional.scaled_dot_product_attention = attention_dispatch
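# The assignment above monkey-patches PyTorch's SDPA entry point so that every
# scaled_dot_product_attention call inside the model is routed through finetrainers'
# dispatcher; the attention_provider context manager used later then selects the actual
# backend and the context-parallel communication behavior for those calls.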

def apply_compile(model: torch.nn.Module, compile_scope: str) -> torch.nn.Module:
    r"""Apply torch.compile to a model or its submodules if not already compiled."""
    if getattr(model, "_torch_compiled", False):
        return model  # Already compiled
    if compile_scope == "full":
        model = torch.compile(model)
        setattr(model, "_torch_compiled", True)
    elif compile_scope == "regional":
        if isinstance(model, torch.nn.ModuleList):
            for name, module in model.named_children():
                if not getattr(module, "_torch_compiled", False):
                    compiled_module = torch.compile(module, mode="max-autotune-no-cudagraphs", fullgraph=False, dynamic=False)
                    setattr(compiled_module, "_torch_compiled", True)
                    model.register_module(name, compiled_module)
        else:
            for name, module in model.named_children():
                apply_compile(module, compile_scope)
    else:
        raise ValueError(f"Unknown compile mode: {compile_scope}. Use 'full' or 'regional'.")
    return model
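
# A minimal usage sketch (hypothetical toy module, for illustration only): "regional"
# walks the module tree and compiles each entry of a ModuleList (e.g. the transformer
# blocks) as its own graph, which keeps compile time and recompilations local to a block,
# while "full" wraps the whole module in a single torch.compile call.
#
#   blocks_only = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(4)])
#   apply_compile(blocks_only, compile_scope="regional")  # compiles each Linear separately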

torch.manual_seed(0)

dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)
cp_mesh = dist.device_mesh.init_device_mesh("cuda", [world_size], mesh_dim_names=["cp"])

cp_plan = {
    "rope": {
        ParamId(index=0): CPInput(2, 4, split_output=True),
    },
    "blocks.*": {
        ParamId("encoder_hidden_states", 1): CPInput(1, 3),
    },
    "blocks.0": {
        ParamId("hidden_states", 0): CPInput(1, 3),
    },
    "proj_out": [CPOutput(1, 3)],
}
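# Reading of the plan above (inferred from the ParamId/CPInput/CPOutput call signatures,
# not from finetrainers documentation): ParamId selects a forward argument by keyword name
# and/or positional index; CPInput(dim, ndim) shards that tensor along `dim` across the
# "cp" mesh, with the second argument looking like an expected tensor rank and
# split_output=True apparently sharding the module's output instead of its input; and
# CPOutput(dim, ndim) gathers a module output back along `dim`. Only "blocks.0" shards
# hidden_states because later blocks receive the already-sharded activations, and
# "proj_out" gathers the full sequence back before the pipeline's final reshape.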

try:
    model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
    vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
    pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    apply_context_parallel(pipe.transformer, mesh=cp_mesh, plan=cp_plan)
    apply_compile(pipe.transformer, compile_scope="regional")
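
    # Ordering note (an inference from this script, not documented behavior): the
    # context-parallel hooks are installed before torch.compile so that the compiled
    # per-block graphs are traced with the sharded tensor shapes already in place.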

    prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
    negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

    with torch.no_grad():
        prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
            prompt=prompt, negative_prompt=negative_prompt, device="cuda",
        )
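
    # Prompts are encoded once, outside the denoising loops, so the warmup pass and the
    # full run below reuse identical embeddings and the text encoder runs only once.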

    attention_backend = "_native_flash"
    generator = torch.Generator().manual_seed(0)

    # Warmup
    with attention_provider(attention_backend, mesh=cp_mesh, convert_to_fp32=True, rotate_method="alltoall"):
        latents = pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            height=480,
            width=832,
            num_frames=81,
            num_inference_steps=2,
            guidance_scale=5.0,
            output_type="latent",
            generator=generator,
        ).frames[0]
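
    # The 2-step warmup pass exists to trigger torch.compile graph capture/autotuning and
    # to exercise the context-parallel communication path before the full run. The two runs
    # also differ in rotate_method ("alltoall" vs. "allgather"), which presumably selects
    # how K/V shards circulate between ranks inside context-parallel attention; that
    # reading is inferred from the argument names, not from documentation.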

    # Inference
    with attention_provider(attention_backend, mesh=cp_mesh, convert_to_fp32=True, rotate_method="allgather"):
        latents = pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            height=480,
            width=832,
            num_frames=81,
            guidance_scale=5.0,
            num_inference_steps=30,
            output_type="latent",
            generator=generator,
        ).frames[0]

    with torch.no_grad():
        latents = latents.to(pipe.vae.dtype)
        latents_mean = (
            torch.tensor(pipe.vae.config.latents_mean)
            .view(1, pipe.vae.config.z_dim, 1, 1, 1)
            .to(latents.device, latents.dtype)
        )
        latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view(1, pipe.vae.config.z_dim, 1, 1, 1).to(
            latents.device, latents.dtype
        )
        latents = latents / latents_std + latents_mean
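
        # Note that latents_std above stores the *reciprocal* of the configured std, so
        # `latents / latents_std + latents_mean` is equivalent to `latents * std + mean`:
        # the inverse of the VAE's latent normalization, mapping the denoised latents back
        # into the space the VAE decoder expects.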
        video = pipe.vae.decode(latents, return_dict=False)[0]
        video = pipe.video_processor.postprocess_video(video, output_type="pil")[0]

    # Only rank 0 writes the output file, so multiple processes never race on the same path.
    if rank == 0:
        export_to_video(video, "output.mp4", fps=16)
finally:
    dist.destroy_process_group()