Liutong Zhou (LiutongZhou) · Apple · New York
@LiutongZhou
LiutongZhou / Benchmark_Results.md
Last active February 27, 2026 16:09
Benchmarking JAX 0.7+ vs PyTorch 2.10+ Attention Variants' Speed and Memory

torch.nn.attention.varlen.varlen_attn is the clear winner on GPU in this benchmark

GPU: NVIDIA GeForce RTX 4090 Laptop GPU
PyTorch version: 2.10.0
JAX version: 0.7.2
Config: Config(batch_size=8, num_heads=32, head_dim=128, seq_min=128, seq_max=2048, dtype=<DType.BFLOAT16: 'bfloat16'>, is_causal=True, warmup=2, iters=30, seed=42, _seq_lens=(299, 1614, 1385, 971, 959, 1777, 293, 1467))
Seq lens: [299, 1614, 1385, 971, 959, 1777, 293, 1467]
@LiutongZhou
LiutongZhou / moe.py
Created February 4, 2026 03:32
MoE Parallelism with Token Dropping in JAX
"""Mixture of Experts (MoE) Layer with token dropping
Using Ragged All-to-All Communication and Ragged Dot in JAX.
"""
__author__ = "Liutong Zhou"
from __future__ import annotations
from dataclasses import dataclass, field
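Token dropping caps how many tokens each expert may receive so per-expert buffers stay fixed-size; overflow tokens are silently dropped. A pure-Python sketch of the capacity logic (the capacity-factor formula is a common convention, not necessarily what this gist uses):

```python
import math


def drop_overflow(expert_ids, num_experts, capacity_factor=1.0):
    """Return kept token indices per expert, dropping tokens past capacity.

    expert_ids: top-1 routing decision per token (one expert index per token).
    capacity = ceil(num_tokens / num_experts * capacity_factor)
    """
    num_tokens = len(expert_ids)
    capacity = math.ceil(num_tokens / num_experts * capacity_factor)
    kept = {e: [] for e in range(num_experts)}
    for tok, e in enumerate(expert_ids):
        if len(kept[e]) < capacity:  # room left for this expert
            kept[e].append(tok)      # keep the token; otherwise drop it
    return kept, capacity
```

With fixed capacities, the per-expert buffers have static shapes, which is what lets the real layer use ragged all-to-all plus `ragged_dot` under `jit`.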
@LiutongZhou
LiutongZhou / JAX Sharding Cheat Sheet (v0.8+).ipynb
Last active February 10, 2026 03:45
JAX Distributed Zero to Hero
@LiutongZhou
LiutongZhou / flash_attention.py
Created November 21, 2025 21:18
FlashAttention
"""FlashAttention: reference implementation of the core logic"""
import math
import torch
from einops import einsum
from jaxtyping import Float, Int
from torch import nn, Tensor
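The core trick FlashAttention builds on is the online (streaming) softmax: scores are consumed one piece at a time while a running max and normalizer are rescaled, so the full score row never has to be materialized. A pure-Python sketch of that recurrence (illustrative, not this gist's implementation):

```python
import math


def online_softmax_weighted_sum(scores, values):
    """Streaming softmax(scores)-weighted sum of values, one element at a time."""
    m = float("-inf")  # running max of scores seen so far
    z = 0.0            # running normalizer: sum of exp(s - m)
    acc = 0.0          # running exp-weighted sum of values
    for s, v in zip(scores, values):
        m_new = max(m, s)
        scale = math.exp(m - m_new)  # rescale old state to the new max
        z = z * scale + math.exp(s - m_new)
        acc = acc * scale + math.exp(s - m_new) * v
        m = m_new
    return acc / z
```

FlashAttention applies the same rescaling per tile of keys/values instead of per element, which is why it never needs the full attention matrix in memory.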
@LiutongZhou
LiutongZhou / sdpa_moe_openai_oss.py
Last active February 4, 2026 02:18
sdpa and MOE of OpenAI OSS
"""OpenAI OSS sdpa and moe implementations that are suitable for both training and inference."""
from typing import Final
import torch
import torch.nn.functional as F
from einops import einsum, rearrange, repeat
from torch import Tensor, nn
__all__ = ["sdpa", "MOEBlock"]
"""Distributed Data Parallel Inference for Hugging Face Transformers."""
from typing import Union
import torch
from accelerate import Accelerator
from accelerate.utils import gather_object
from tqdm import tqdm
from transformers import (
PreTrainedModel,
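The pattern in this snippet is: each process takes a contiguous shard of the inputs, runs the model locally, then results are gathered back in order (`accelerate`'s `gather_object` does the real cross-process gather). A pure-Python sketch of just the sharding arithmetic:

```python
def shard(items, rank, world_size):
    """Contiguous shard of items for one process, as even as possible."""
    per_rank = -(-len(items) // world_size)  # ceiling division
    return items[rank * per_rank:(rank + 1) * per_rank]


def gather(shards):
    """Concatenate per-rank results back into the original order."""
    return [x for s in shards for x in s]
```

Because shards are contiguous and gathered in rank order, the concatenated outputs line up with the original inputs with no reindexing.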
@LiutongZhou
LiutongZhou / universal_decorator.py
Last active May 13, 2024 15:23
Universal Decorator
"""Universal Decorator
Universal decorators can decorate functions, classes, bound methods
(class method / instance method) referenced outside of class definition
and descriptors (class methods, static methods) defined inside class definition.
"""
from __future__ import annotations
import inspect
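The hard case a universal decorator must handle is being applied on top of a descriptor (a `classmethod`/`staticmethod` object), which is not directly callable. One common approach, sketched here (not this gist's code), is to implement `__get__` and wrap whatever the binding step produces:

```python
import functools


class logged:
    """Decorator that works on plain functions and on static/class/instance methods."""

    def __init__(self, wrapped):
        self.wrapped = wrapped  # a function OR a descriptor (staticmethod, ...)
        self.calls = 0

    def __call__(self, *args, **kwargs):
        # Hit when decorating a plain function used outside a class.
        self.calls += 1
        return self.wrapped(*args, **kwargs)

    def __get__(self, obj, objtype=None):
        # Delegate binding to the wrapped object (function or descriptor),
        # then count the call at invocation time.
        bound = self.wrapped.__get__(obj, objtype)

        @functools.wraps(bound)
        def wrapper(*args, **kwargs):
            self.calls += 1
            return bound(*args, **kwargs)

        return wrapper
```

Delegating to the wrapped object's own `__get__` is what makes one decorator cover both cases: functions bind to instance methods, while `staticmethod`/`classmethod` objects apply their own binding rules.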
@LiutongZhou
LiutongZhou / memory_efficient_training.md
Last active July 11, 2023 15:37
Memory Efficient Training of LLMs
"""Data Strutures that extend OrderedDict"""
from collections import Counter, OrderedDict
from typing import Any, Hashable, Optional, Tuple, List
from hypothesis import given, strategies as st
__all__ = ["OrderedDefaultDict", "MinMaxCounter"]
class OrderedDefaultDict(OrderedDict):
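An `OrderedDefaultDict` can be built by adding `defaultdict`'s `__missing__` hook to `OrderedDict`; `dict.__getitem__` calls `__missing__` automatically on a miss. A sketch under that assumption (the gist's version may differ):

```python
from collections import OrderedDict


class OrderedDefaultDict(OrderedDict):
    """OrderedDict that creates missing values via a default_factory."""

    def __init__(self, default_factory=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.default_factory = default_factory

    def __missing__(self, key):
        # Called by dict.__getitem__ when key is absent.
        if self.default_factory is None:
            raise KeyError(key)
        self[key] = value = self.default_factory()
        return value
```

Unlike a plain `dict` (insertion-ordered since Python 3.7), `OrderedDict` also keeps `move_to_end` and order-sensitive equality, which is why it is the base class here.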
@LiutongZhou
LiutongZhou / docker_tips.md
Last active December 27, 2025 03:34
Docker Tips

Move default docker storage to another location

nano /etc/docker/daemon.json

## add this config
{
  "data-root": "/newlocation"
}