torch.nn.attention.varlen.varlen_attn is the clear winner on GPU in this benchmark
GPU: NVIDIA GeForce RTX 4090 Laptop GPU
PyTorch version: 2.10.0
JAX version: 0.7.2
Config: batch_size=8, num_heads=32, head_dim=128, seq_min=128, seq_max=2048, dtype=bfloat16, is_causal=True, warmup=2, iters=30, seed=42
Seq lens: [299, 1614, 1385, 971, 959, 1777, 293, 1467]
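Variable-length attention kernels consume a packed (ragged) batch described by cumulative sequence offsets rather than a padded [batch, seq] tensor. A minimal sketch of deriving those offsets from the benchmark's sequence lengths (plain Python; the cu_seqlens naming follows the common varlen-attention convention and is an assumption about the exact argument name):

```python
import itertools

# Sequence lengths from the benchmark config above.
seq_lens = [299, 1614, 1385, 971, 959, 1777, 293, 1467]

# Cumulative offsets: cu_seqlens[i] is where sequence i starts in the
# packed token dimension; cu_seqlens[-1] is the total token count.
cu_seqlens = [0] + list(itertools.accumulate(seq_lens))

total_tokens = cu_seqlens[-1]
print(cu_seqlens)   # [0, 299, 1913, 3298, 4269, 5228, 7005, 7298, 8765]
print(total_tokens) # 8765
```

In the packed layout, q/k/v are shaped [total_tokens, num_heads, head_dim] (here [8765, 32, 128]), and the offsets tell the kernel where each of the 8 sequences begins and ends, so no compute is wasted on padding.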