Aryan a-r-r-o-w

a-r-r-o-w / triton_autotune_wtf.py
Created August 31, 2025 12:35
triton autotuning somehow reports slower times
import torch
import torch._dynamo.config
import torch._inductor.config
import triton
import triton.language as tl
torch._dynamo.config.cache_size_limit = 10000
torch._inductor.config.triton.cudagraphs = False
torch._inductor.config.triton.cudagraph_trees = False
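Not the gist's actual kernel; a minimal sketch, assuming a vector-add workload, of the setup the title describes: a @triton.autotune-decorated kernel whose autotuner-reported best time can then be compared against an external measurement from triton.testing.do_bench.

import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 256}),
        triton.Config({"BLOCK_SIZE": 1024}),
    ],
    key=["n_elements"],
)
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

def add(x, y):
    out = torch.empty_like(x)
    n_elements = out.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    add_kernel[grid](x, y, out, n_elements)
    return out

x = torch.randn(1 << 20, device="cuda")
y = torch.randn(1 << 20, device="cuda")
add(x, y)  # first call triggers autotuning (set TRITON_PRINT_AUTOTUNING=1 to see its reported best time)
ms = triton.testing.do_bench(lambda: add(x, y))  # external re-measurement to compare against
print(f"{ms:.4f} ms")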
a-r-r-o-w / benchmark.sh
Created August 29, 2025 12:57
Wan 2.2 5B T2V benchmarks
#!/bin/bash
set -xe
export TORCH_LOGS="recompiles,inductor"
export CUDA_VISIBLE_DEVICES="3,2,1,0"
set_fa_op() {
  COMPUTE_CAPABILITY=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | head -n 1 | tr -d '.')
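The set_fa_op helper above keys off the GPU compute capability; the mapping below is an assumed illustration (the op names are hypothetical, not from the gist) of how such a value is typically used to pick an attention backend, shown in Python for consistency with the other snippets.

import torch

major, minor = torch.cuda.get_device_capability()
compute_capability = major * 10 + minor  # e.g. 90 for Hopper, 80 for Ampere
if compute_capability >= 90:
    fa_op = "flash_attention_3"  # hypothetical name
elif compute_capability >= 80:
    fa_op = "flash_attention_2"  # hypothetical name
else:
    fa_op = "sdpa_math"          # hypothetical fallback
print(f"Selected attention op: {fa_op}")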
a-r-r-o-w / flux_cuda_stream.py
Created August 18, 2025 09:40
Flux with cuda stream
import argparse
import contextlib
import functools
import pathlib
import math
from dataclasses import dataclass
from typing import Callable, List, Literal, Optional, Tuple
import numpy as np
import torch
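A minimal sketch of the CUDA-stream idea, not the gist's Flux pipeline: a host-to-device copy runs on a side stream while compute proceeds on the default stream, with an explicit wait before the copied data is consumed.

import torch

copy_stream = torch.cuda.Stream()
x_cpu = torch.randn(4096, 4096, pin_memory=True)
w = torch.randn(4096, 4096, device="cuda")
y = torch.randn(4096, 4096, device="cuda")

with torch.cuda.stream(copy_stream):
    x_gpu = x_cpu.to("cuda", non_blocking=True)  # async copy on the side stream

out = y @ w  # runs concurrently on the default stream

torch.cuda.current_stream().wait_stream(copy_stream)  # sync before consuming x_gpu
out = out + x_gpu @ w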
a-r-r-o-w / wan_t2v.py
Last active August 19, 2025 17:18
Benchmark of lossless, system-level-only optimizations for Wan
import argparse
import contextlib
import math
from dataclasses import dataclass
from typing import Callable, Literal, Optional, Tuple
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
import torch.profiler._utils
import math
import time
import torch
import triton
import triton.language as tl
import triton.runtime as runtime
import triton.tools.experimental_descriptor
import torch
import torch.nn as nn
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
class Model(nn.Module):
    """
    Simple model that performs a single square matrix multiplication (C = A * B)
    """

    def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
        # forward pass reconstructed from the docstring above
        return torch.matmul(A, B)
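The previews above read like benchmark harness code; as an assumed sketch (not from the gists), this is how such GPU measurements are commonly taken with CUDA events after a warmup.

import torch

def benchmark(fn, warmup: int = 10, iters: int = 50) -> float:
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average milliseconds per iteration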
a-r-r-o-w / attention_free_transformer.py
Created July 11, 2025 11:29
Attention-free transformer
"""
Implementation of "An Attention-Free Transformer": https://arxiv.org/abs/2105.14103
"""
import contextlib
import functools
import torch
import triton
import triton.language as tl
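For reference, a minimal PyTorch sketch of AFT-simple from the linked paper (the position biases w_{t,t'} set to zero); the gist itself appears to use Triton kernels and may implement the full variant.

import torch

def aft_simple(q, k, v):
    # q, k, v: (batch, seq_len, dim)
    weights = torch.softmax(k, dim=1)                 # softmax over the sequence axis
    context = (weights * v).sum(dim=1, keepdim=True)  # single global weighted sum of values
    return torch.sigmoid(q) * context                 # elementwise gate, broadcast over seq_len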
a-r-r-o-w / fused_adaln_zero_triton.py
Created July 8, 2025 08:16
Can be faster than torch.compile if you avoid masked loads, which is almost always possible in common transformer scenarios with aligned block sizes
import torch
import triton
import triton.language as tl
torch._dynamo.config.cache_size_limit = 10000
ENABLE_TRITON = True
ENABLE_DEEP_AUTOTUNE = True
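As a point of reference, a plain PyTorch sketch of the AdaLN-Zero modulation such a kernel presumably fuses (DiT-style scale/shift/gate; function and argument names are assumptions, not the gist's API).

import torch
import torch.nn.functional as F

def adaln_zero_modulate(x, shift, scale):
    # x: (batch, seq, dim); shift/scale: (batch, dim) predicted from the conditioning signal
    normed = F.layer_norm(x, (x.shape[-1],))  # LayerNorm without learned affine parameters
    return normed * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

def adaln_zero_residual(x, branch_out, gate):
    # gate: (batch, dim), initialized near zero so each block starts as the identity
    return x + gate.unsqueeze(1) * branch_out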
a-r-r-o-w / sequential_ring.py
Created June 30, 2025 11:36
sequential and templated ring/ulysses/unified attention implementation
import torch
torch.manual_seed(42)
def torch_sdpa(query, key, value):
    out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
        torch.ops.aten._scaled_dot_product_cudnn_attention(
            query=query,
            key=key,
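The log-sum-exp returned above is what makes sequence-parallel attention composable; a minimal sketch (assumed, not the gist's code) of the exact merge of two partial attention results computed over disjoint key/value shards.

import torch

def merge_attention_partials(out_a, lse_a, out_b, lse_b):
    # out_*: (batch, heads, seq_q, head_dim); lse_*: (batch, heads, seq_q)
    lse = torch.logaddexp(lse_a, lse_b)
    w_a = torch.exp(lse_a - lse).unsqueeze(-1)
    w_b = torch.exp(lse_b - lse).unsqueeze(-1)
    return w_a * out_a + w_b * out_b, lse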
import argparse
import contextlib
import math
import pathlib
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn