@nihalpasham
Last active June 30, 2026 03:34
cuda-oxide: a speed-of-light GEMM in pure Rust (companion notes for the stream)

One Kernel, Eight Ideas: A Speed-of-Light GEMM in Pure Rust

This walkthrough reads one GPU kernel: gemm_sol_clc_multicast_4_stage_pipeline, a matrix multiply written in pure Rust that reaches 58% of the throughput of NVIDIA's hand-tuned library (cuBLASLt) on a Blackwell B200 GPU. It is a few hundred lines, and packed into it are about eight distinct ideas, each one solving a specific bottleneck.

The plan: understand the problem, meet the one piece of silicon that does the actual math (the tensor core), look at the kernel from the top, then walk it part by part. Every part teaches one idea: what it is, the bottleneck it removes, how it shows up in this exact kernel, and the Rust that expresses it.

The method is the same throughout: find what is stalling, fix exactly that, repeat.

Table of Contents

  1. The Problem: Multiplying Two Big Matrices
  2. A Kernel Is Just a Rust Function
  3. The Hardware: Threads, Warps, Blocks, SMs, Clusters
  4. The Engine: Tensor Cores and the 8x8 Brick
  5. The Kernel in One Picture
  6. Walking the Kernel, One Idea at a Time
  7. The Epilogue: Getting the Answer Out
  8. The Rust Toolbox
  9. The Payoff
  10. Key Takeaways
  11. Reproducing
  12. Source Material

The Problem: Multiplying Two Big Matrices

Everything in this kernel is one operation: multiply two matrices. Take two 4096-by-4096 grids of numbers, A and B, and produce a third, C = A times B.

        K (=4096)            N (=4096)              N (=4096)
     ┌───────────┐      ┌───────────────┐      ┌───────────────┐
   M │     A     │  ×  K│       B       │  =  M│       C       │
(4096│  (M × K)  │      │    (K × N)    │ (4096│    (M × N)    │
)    └───────────┘      └───────────────┘ )    └───────────────┘

  A is M rows by K columns.   B is K rows by N columns.
  C is M rows by N columns.   For us, M = K = N = 4096.

M, N, and K are the three sizes. K is the shared one: it lines A's columns up with B's rows, and it is the dimension that disappears in the product.

How is one cell of C computed? A single number C[i, j] is the dot product of row i of A with column j of B: multiply element by element, add it all up.

   C[i, j]  =  (row i of A)  ·  (column j of B)

            =  A[i,0]·B[0,j] + A[i,1]·B[1,j] + ... + A[i,4095]·B[4095,j]
               └──────────────── sum over all K = 4096 terms ────────────┘

The scale, so you feel why this is a GPU job. C has M × N = about 16.7 million cells. Each cell is a sum of K = 4096 multiply-adds. That is roughly 69 billion multiply-adds, or about 137 billion floating-point operations, for one product, and real workloads do thousands of these per second.
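
As a plain CPU reference, the dot-product definition above can be sketched in a few lines of Rust. This is not the kernel, only the math it has to reproduce; the names `matmul_naive` and `multiply_adds` are made up for this note, and the matrices are assumed to be row-major. One multiply-add is counted per (i, j, p) triple, so 4096-cubed gives about 68.7 billion of them, or about 137 billion floating-point operations if the multiply and the add are counted separately.

```rust
/// Reference matrix multiply: C = A x B, all matrices row-major.
/// A is m x k, B is k x n, C is m x n.
pub fn matmul_naive(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Vec<f32> {
    assert_eq!(a.len(), m * k);
    assert_eq!(b.len(), k * n);
    let mut c = vec![0.0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            // C[i, j] = (row i of A) . (column j of B)
            let mut sum = 0.0f32;
            for p in 0..k {
                sum += a[i * k + p] * b[p * n + j];
            }
            c[i * n + j] = sum;
        }
    }
    c
}

/// Multiply-adds in one product: one per (i, j, p) triple.
pub fn multiply_adds(m: u64, n: u64, k: u64) -> u64 {
    m * n * k
}
```

Everything the kernel does is a faster way of running these three loops.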

Two facts from this picture drive every decision in the kernel:

  • C is enormous (16.7M cells), so the work must be split up. That is grid tiling.
  • Each cell sums over all of K (4096 terms), and K is far too long to hold in fast memory at once, so it must be walked. That is the K-loop.

A Kernel Is Just a Rust Function

Groundwork before the big kernel: what GPU code looks like in cuda-oxide. Here is the simplest possible kernel, adding two vectors.

#[kernel]
pub fn vecadd(a: &[f32], b: &[f32], mut c: DisjointSlice<f32>) {
    let idx = thread::index_1d();
    let idx_raw = idx.get();
    if let Some(c_elem) = c.get_mut(idx) {
        *c_elem = a[idx_raw] + b[idx_raw];
    }
}

A normal Rust function with a #[kernel] attribute. The backend turns it into GPU assembly. Inputs are real slices. The output is a DisjointSlice, and that one type is what makes the write safe:

  CUDA C++ raw `float* c`             cuda-oxide `DisjointSlice<f32>`
  ─────────────────────               ──────────────────────────────
  thread 3 ─┐                         thread 3 ─► [ ][ ][✓][ ][ ]
  thread 5 ─┴─► [ ][ ][✗][ ]          each thread gets a unique,
            two threads, one cell     bounds-checked cell, proven
            (the race compiles)       by the type system

To write through a DisjointSlice you need a ThreadIndex, which only the hardware-register function index_1d() can mint, and which is !Copy. You cannot photocopy your proof of uniqueness, so two threads cannot get a write ticket to the same cell. The data race and the out-of-bounds write are unrepresentable in safe code. The big kernel uses this same type as its output.
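
The "ticket you cannot photocopy" idea can be sketched on the CPU with a move-only token. This is a toy model, not cuda-oxide's implementation: `Ticket`, `Cells`, and `mint` are hypothetical names, and here anyone can call `mint`, whereas the text says only the hardware-register function can produce a real ThreadIndex.

```rust
/// Hypothetical stand-in for a thread index: deliberately NOT Copy or Clone.
pub struct Ticket(usize);

/// Hypothetical stand-in for an output slice that gives out one cell per ticket.
pub struct Cells<'a, T>(&'a mut [T]);

impl<'a, T> Cells<'a, T> {
    pub fn new(data: &'a mut [T]) -> Self {
        Cells(data)
    }

    /// Consumes the ticket, so the same ticket cannot be presented twice.
    /// Returns None when the index is out of bounds.
    pub fn get_mut(&mut self, t: Ticket) -> Option<&mut T> {
        self.0.get_mut(t.0)
    }
}

/// In this toy anyone can mint a ticket; the real type restricts who can.
pub fn mint(i: usize) -> Ticket {
    Ticket(i)
}
```

Passing the same `Ticket` to `get_mut` a second time is a use-after-move error at compile time, which is the property the paragraph above relies on.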


The Hardware: Threads, Warps, Blocks, SMs, Clusters

The kernel talks in terms of "warps," "blocks," and "clusters." Five words of GPU vocabulary make it readable. They nest, smallest to largest, and each level up adds one new power: first more workers, then a shared scratchpad, then a scratchpad shared across blocks.

   thread     one worker, with its own registers
     │
     │  32 of them run the SAME instruction in lockstep
     ▼
   WARP       32 threads; the GPU hands out work one warp at a time
     │
     │  several warps placed on ONE physical core
     ▼
   BLOCK      a team of warps on one SM, sharing a fast on-chip scratchpad
   (a "CTA")
     │
     │  a few blocks on neighbouring SMs, allowed to see each other's memory
     ▼
   CLUSTER    blocks that can read and write each other's scratchpad
     │
     ▼
   GRID       every block of the whole kernel launch

  • Thread: the smallest worker, one lane with its own registers.
  • Warp = 32 threads. They run one instruction together, in lockstep (the GPU's "SIMT" model). Issue "load" on a warp and all 32 lanes load at once. Because work is handed out one warp at a time, this kernel will assign jobs per warp, not per thread.
  • Block (also called a CTA, cooperative thread array): a group of warps that runs on one SM and shares a small, fast on-chip scratchpad called shared memory (up to 228 KB on Blackwell). Threads in a block can synchronize and pass data through it. A block is the unit that owns one 128x128 output tile in the matmul.
  • SM (streaming multiprocessor): the physical core that runs a block. Each SM holds the CUDA cores, the tensor core, and the copy engine (all introduced next), plus the shared-memory scratchpad. A Blackwell B200 has 148 of them.
  • Cluster: a few blocks placed on neighbouring SMs that are allowed to read and write each other's shared memory (called distributed shared memory). This is new on Hopper and Blackwell, and it is the reason this kernel can make two blocks cooperate as one.
  • Grid: every block of one kernel launch, spread across all the SMs.

Memory follows the same ladder: each thread has registers, each block has shared memory, and the whole grid shares global memory (the big, slow HBM where A, B, and C actually live). Fast and tiny at the top, huge and slow at the bottom. Most of the kernel's cleverness is keeping data high on that ladder.


The Engine: Tensor Cores and the 8x8 Brick

Before reading the kernel, you have to know the one piece of silicon doing the multiply, because every optimization in it is about feeding that piece correctly. A modern GPU has two kinds of math unit:

  • CUDA cores: the ordinary, general-purpose lanes. They do the loads, the stores, the index math, the format conversions. Everything except the big matrix multiply.
  • Tensor cores: a dedicated matrix-multiply engine. This is the only part fast enough to make 137 billion floating-point operations tractable. Each SM has one.

The fundamental unit the tensor core works in is an 8x8 tile, the "core matrix." Eight rows by eight columns of fp16 numbers: that is 64 numbers, and in fp16 it is exactly 128 bytes. Think of it as a single brick. Every larger multiply the tensor core does is built out of these bricks.

   A 128×128 tile, the way the tensor core sees it: a wall of 8×8 bricks.

      8    8    8           8
    ┌────┬────┬────┬ ··· ┬────┐
  8 │ ▦  │ ▦  │ ▦  │     │ ▦  │   each ▦ is one 8×8 core matrix:
    ├────┼────┼────┼ ··· ┼────┤     64 fp16 numbers = 128 bytes
  8 │ ▦  │ ▦  │ ▦  │     │ ▦  │     = the tensor core's atom
    ├────┼────┼────┼ ··· ┼────┤
     ...                            16 bricks across × 16 bricks down
    └────┴────┴────┴ ··· ┴────┘     = 256 bricks make one 128×128 tile

You do not place bricks by hand. You hand the tensor core a small 32-bit descriptor that says "operands are fp16, accumulate in fp32, the shape is this big," and it does the rest. In CUDA you assemble that descriptor's bits with magic shifts. In cuda-oxide it is a const fn builder that folds to one constant at compile time, with compile-checked enums instead of shifts:

let idesc = Tcgen05InstructionDescriptor::builder()
    .shape(Tcgen05MmaShape::M256_N128)          // output shape (more on 256 later)
    .element_type(Tcgen05ElementType::F16)      // inputs are fp16
    .accumulator_type(Tcgen05AccumulatorType::F32)  // sum in fp32
    .build()
    .raw();

One detail that matters for Idea 3. The tensor core does not read a brick in one gulp. It reads it row by row, and the row stride it expects is hardwired to 16 bytes (8 fp16 numbers). It assumes the 8 rows of a brick sit 16 bytes apart. Remember that number. A whole optimization exists just to satisfy it.


The Kernel in One Picture

With that vocabulary in hand, here is the whole kernel from the top before we zoom in. Three things to take away from this picture: it runs blocks in pairs, each block is a crew of 6 warps with fixed jobs, and data flows down a fixed assembly line.

   #[cluster_launch(2, 1, 1)]   →   blocks run in pairs ("CTA pairs"),
                                     one block per SM, two SMs cooperating.

   ┌───────────────── one CTA pair: leader (SM 0) + follower (SM 1) ──────────────────┐
   │                                                                                  │
   │  Each block = 6 warps (192 threads). Jobs are fixed:                             │
   │                                                                                  │
   │   Warp 4  LOADER     ─ streams A and B from global memory into shared memory     │
   │                        (TMA copies, broadcast to both blocks).   [CUDA cores]    │
   │                              │ "buffer ready"      ▲ "buffer free"               │
   │                              ▼                      │                            │
   │   Warp 5  MULTIPLIER ─ leader issues ONE paired MMA spanning both blocks;        │
   │                        follower issues none.      [drives the TENSOR core]       │
   │                              │ "tile done"                                       │
   │                              ▼                                                   │
   │   Warps 0-3 STORE    ─ read the result, convert f32→bf16, write to global.       │
   │                                                                  [CUDA cores]    │
   └──────────────────────────────────────────────────────────────────────────────────┘

   The assembly line for the data:

     global memory  ──TMA, multicast──►  shared memory (4-stage ring buffer)
                                               │  tensor core reads bricks
                                               ▼
                                           tensor memory (the fp32 accumulator)
                                               │  epilogue drains it
                                               ▼
                                           shared → global  (the answer)

Now we walk it. Each idea below is one piece of this picture.


Walking the Kernel, One Idea at a Time

Idea 1: Grid Tiling (split the giant answer)

The bottleneck. C has 16.7 million cells. No single group of threads can compute all of them.

The idea. Cut C into small 128-by-128 tiles and give each tile to one block of threads (a "CTA"). For a 4096 matrix that is 32 tiles across and 32 down: 1024 tiles, spread across the whole GPU.

        N ─────────────────────────►
     ┌──────┬──────┬──────┬─── ... ──┐
   M │ tile │ tile │ tile │          │   Each block owns ONE 128×128
   │ │(0,0) │(0,1) │(0,2) │          │   tile of C, and reads:
   │ ├──────┼──────┼──────┤          │     • a 128-row BAND of A
   ▼ │ tile │ tile │      │          │     • a 128-col BAND of B
     │(1,0) │ ···  │      │          │
     └──────┴──────┴──────┴── ... ───┘
                 C (M × N)

A block computes its tile coordinates from its block id, reads the matching band of A and band of B, and produces one 128x128 chunk of C. Two payoffs: parallelism (many tiles keep every SM busy) and reuse (a whole tile hits the same A and B bands repeatedly, so you load them into fast memory once).

In the kernel. The output is a DisjointSlice<u32> (each u32 packs two bf16 results). Each block writes a different, non-overlapping region of it, so the writes provably cannot race:

pub unsafe fn gemm_sol_clc_multicast_4_stage_pipeline(
    a_tma: *const TmaDescriptor,    // A and B come pre-described for the
    b_tma: *const TmaDescriptor,    // copy engine (see Idea 3)
    mut out: DisjointSlice<u32>,    // the output tile surface
    n: i32, k: i32, tiles_m: u32, _tiles_n: u32,
) {
    let first_tile_m = tile_idx % tiles_m;   // which tile this block owns
    let first_tile_n = tile_idx / tiles_m;
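
The index arithmetic in that excerpt can be checked on the CPU. The sketch below reuses the same `%` and `/` mapping; the helper names `tile_coords` and `tile_origin` and the `TILE` constant are assumptions made for this note, with the 128 taken from the tile size in the text.

```rust
/// Tile edge length, from the text (128 x 128 output tiles).
const TILE: u32 = 128;

/// Same mapping as the excerpt: tiles are numbered down M first, then across N.
pub fn tile_coords(tile_idx: u32, tiles_m: u32) -> (u32, u32) {
    (tile_idx % tiles_m, tile_idx / tiles_m)
}

/// First row and first column of C covered by a tile.
pub fn tile_origin(tile_m: u32, tile_n: u32) -> (u32, u32) {
    (tile_m * TILE, tile_n * TILE)
}
```

With `tiles_m = 4096 / 128 = 32`, tile 33 maps to tile coordinates (1, 1), which starts at row 128, column 128 of C.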

Idea 2: The K-Loop (build a tile by walking K)

The bottleneck. A block owns a 128x128 tile, but each cell still sums over all K = 4096 terms. The A and B bands are 4096 long and do not fit in fast memory at once.

The idea. Walk K in chunks. Load a 64-wide slice of the bands, multiply, add the partial result into the tile, repeat until K is exhausted.

   The band, walked in chunks of 64 (4096 / 64 = 64 iterations):

   A band:  [ chunk0 ][ chunk1 ][ chunk2 ] ... [ chunk63 ]
   B band:  [ chunk0 ][ chunk1 ][ chunk2 ] ... [ chunk63 ]
                 │         │         │              │
   tile  =   A0×B0  +   A1×B1  +   A2×B2  + ... +  A63×B63

   First chunk:  tile  = A0 × B0      (overwrite)
   Every later:  tile += Ak × Bk      (accumulate)
   After chunk 63: the tile is the finished answer.

The partial sums never leave the tensor core's private fp32 accumulator (tensor memory) until the tile is complete. That is the inner loop of the whole algorithm.

In the kernel. k_iters = k / 64, and accumulate is false only on the very first multiply, true forever after:

let k_iters = k / 64;            // 4096 / 64 = 64 chunks
while k_idx < k_iters {
    // ... get this chunk's buffers, wait for data ...
    // j: index of the MMA within this chunk (the 4-MMA inner loop is elided here)
    let accumulate = k_idx > 0 || j > 0;   // overwrite once, then add
    tcgen05_mma_f16_cg2(tmem_addr + offset, a_desc, b_desc, idesc, accumulate);
    k_idx += 1;
}

(Each 64-wide chunk is actually 4 back-to-back MMA instructions, since one instruction digests 16 of the 64 K-elements. The brick wall from earlier is being multiplied 16 columns at a time.)
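
The overwrite-then-accumulate rhythm can be sketched on the CPU. This is a model of the loop structure only: `matmul_k_chunked` is a hypothetical name, the accumulator is an ordinary Vec rather than tensor memory, and it starts as NaN on purpose so the test fails if the first chunk adds instead of overwriting.

```rust
/// C = A x B (row-major), computed by walking K in chunks of `chunk`.
pub fn matmul_k_chunked(
    a: &[f32], b: &[f32], m: usize, n: usize, k: usize, chunk: usize,
) -> Vec<f32> {
    assert!(chunk > 0 && k % chunk == 0);
    // Deliberately garbage: the first chunk must overwrite it, not add to it.
    let mut acc = vec![f32::NAN; m * n];
    let k_iters = k / chunk;
    for k_idx in 0..k_iters {
        let accumulate = k_idx > 0; // overwrite once, then add
        for i in 0..m {
            for j in 0..n {
                // Partial dot product over this chunk of K only.
                let mut partial = 0.0f32;
                for p in k_idx * chunk..(k_idx + 1) * chunk {
                    partial += a[i * k + p] * b[p * n + j];
                }
                let cell = &mut acc[i * n + j];
                *cell = if accumulate { *cell + partial } else { partial };
            }
        }
    }
    acc
}
```

Because the first chunk overwrites, the accumulator never needs a separate zeroing pass, which is the point of the `accumulate` flag in the kernel.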

Idea 3: Swizzling (lay the bricks down right, in one pass)

The bottleneck. The loader copies each K-chunk into shared memory, and the tensor core reads its bricks from there. But a plain, natural row-by-row copy lays the data out wrong for the brick reader.

This is the picture that explains swizzling. A natural copy of a 64-wide chunk puts each row 128 bytes after the last (64 fp16 numbers per row). But the tensor core's brick reader steps in 16-byte rows (it wants 8-wide bricks, remember). So after the first row of each brick, it reads garbage.

   What a plain copy produces        What the tensor core's brick
   (rows 128 bytes apart):           reader expects (rows 16 bytes apart):

   row 0 ─► byte 0                    row 0 ─► byte 0
   row 1 ─► byte 128                  row 1 ─► byte 16
   row 2 ─► byte 256                  row 2 ─► byte 32
        the reader steps 16 bytes,         every row lands exactly where
        so it lands inside row 0's         the hardwired 16-byte stride
        data and multiplies garbage        points: correct bricks

How do you get the bricks laid out the way the reader wants? Two options:

  1. Chop the copy into eight skinny strips, each one already brick-shaped. This works, but it is eight transfers where you wanted one. Slow.
  2. Swizzle. Ask the copy engine (TMA) to scramble the bytes into brick order as it transfers them. One big copy lands already brick-aligned. That is "swizzling": a fixed, hardware-understood reshuffle applied in flight.

   Option 1: 8 skinny strips per chunk     Option 2: 1 swizzled copy
   ┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐                 ┌────────────────────────┐
   │ ││ ││ ││ ││ ││ ││ ││ │   ──────────►   │ scrambled into brick   │
   └─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘                 │ order by the TMA engine│
   8× transfer overhead                     └────────────────────────┘
                                             1× transfer, same layout

There is a second reason the scramble is shaped the way it is: bank conflicts. Shared memory is split into 32 "banks," and they work like 32 supermarket checkout lanes. 32 threads reading 32 different banks all get served in one cycle. But if many threads hit the same bank, they queue and get served one at a time.

   No swizzle: a brick's rows all          Swizzle: the reshuffle spreads
   land in the SAME bank → threads         those rows across different banks
   queue at one checkout lane              → all served in 1 cycle

      lane ▒ ▒ ▒ ▒ ▒ ▒ ▒ ▒                    lane ▒ . . . . . . .
           (one bank, serialized)                  . ▒ . . . . . .
                                                    . . ▒ . . . . .

The swizzle does both jobs at once: it lands the bricks in the layout the tensor core's 16-byte stride expects, and it spreads each brick's bytes across the 32 banks so reads never collide.

In the kernel. Swizzling is one named mode, threaded into the shared-memory descriptor that tells the tensor core where its bricks are. No change to the loop, the math, or the epilogue:

const SWIZZLE_128B: u8 = 2;     // a hardware-known reshuffle pattern
let a_desc = build_smem_descriptor(smem_a_base + off, LBO_BYTES, SBO_BYTES, SWIZZLE_128B);
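
The bank-spreading half of the story can be checked with a few lines of arithmetic. The sketch below uses the XOR pattern commonly described for a 128-byte swizzle (the 16-byte chunk index within a row is XORed with the row index modulo 8). Whether that matches byte for byte what this kernel's TMA descriptor applies is an assumption, and the function names are made up for this note.

```rust
/// Assumed 128-byte swizzle on a byte offset: bits 4..=6 (which 16-byte chunk
/// inside a 128-byte row) are XORed with bits 7..=9 (which row in a group of 8).
pub fn swizzle_128b(offset: u32) -> u32 {
    offset ^ (((offset >> 7) & 0x7) << 4)
}

/// Shared-memory bank of a byte offset: 32 banks, 4 bytes wide each.
pub fn bank(offset: u32) -> u32 {
    (offset / 4) % 32
}

/// Bank hit by the first word of 16-byte chunk `chunk` in each of 8 rows
/// that sit 128 bytes apart.
pub fn column_banks(chunk: u32, swizzled: bool) -> Vec<u32> {
    (0..8)
        .map(|row| {
            let off = row * 128 + chunk * 16;
            bank(if swizzled { swizzle_128b(off) } else { off })
        })
        .collect()
}
```

Without the swizzle, the same chunk in all eight rows starts in bank 0; with it, the eight rows start in banks 0, 4, 8, and so on up to 28, so no two rows share a bank. Applying the function twice returns the original offset, which is why the reader can undo the shuffle with the same rule.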

Idea 4: The 4-Stage Pipeline (stop waiting on loads)

The bottleneck. Inside the K-loop, the naive rhythm is: load a chunk, wait, multiply, wait. The tensor core sits idle during every load, and the copy engine sits idle during every multiply. They take turns when they could overlap.

The idea. Keep several shared-memory buffers (a ring buffer of "stages") so the loader can fill the next buffers while the tensor core chews on the current one. This kernel keeps four stages, deep enough that the loader runs three chunks ahead.

   1 buffer (serial):                 4 buffers (overlap):
     [load 0]                           [load 0][load 1][load 2][load 3]
            [calc 0]                            [calc 0][calc 1][calc 2]…
            [load 1]                     loader stays 3 chunks ahead;
                   [calc 1]              the tensor core never waits for a load

Why four and not two? Because of the next two ideas: clusters add a per-chunk handshake between blocks, and a deep buffer is what gives the loader enough runway to hide that handshake. Shallow pipelines stall on it.

In the kernel. The stage for each K-iteration is just the low two bits of a running counter, and which buffers/barriers to use is one exhaustive match:

let stage = global_k & 3;            // 0,1,2,3 rotating
let (smem_a, smem_b, tma_bar, mma_bar) = match stage {
    0 => ( &raw mut SMEM_A0, &raw mut SMEM_B0, .. ),
    1 => ( &raw mut SMEM_A1, &raw mut SMEM_B1, .. ),
    2 => ( &raw mut SMEM_A2, &raw mut SMEM_B2, .. ),
    _ => ( &raw mut SMEM_A3, &raw mut SMEM_B3, .. ),   // wildcard makes it total
};
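
The rotation is easy to tabulate. In the sketch below, `stage` mirrors the `global_k & 3` from the excerpt. `parity` is an assumption: it shows one common way to derive a barrier phase bit from the same counter (flip each time the ring wraps), not necessarily the expression this kernel uses.

```rust
/// Ring depth, from the text.
const STAGES: u32 = 4;

/// Ring-buffer slot for a running K counter (same as `global_k & 3`).
pub fn stage(global_k: u32) -> u32 {
    global_k & (STAGES - 1)
}

/// Assumed phase bit: flips each time the counter wraps around the ring.
pub fn parity(global_k: u32) -> u32 {
    (global_k / STAGES) & 1
}
```

Because 64 chunks per tile is a multiple of 4, a counter that keeps running across tiles lands back on stage 0 at every tile boundary.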

Idea 5: Warp Specialization (a crew, each with one job)

The bottleneck. Here is the subtle one. If a single team of threads issues "load, then multiply, then load," it is serial no matter how many buffers you give it. The 4-stage pipeline from Idea 4 does nothing on its own. The buffers were never the problem. Having one worker do both jobs was.

The idea. Warps are the lever here (recall: 32 threads, scheduled as one unit). This kernel runs 6 warps per block (192 threads) and gives each warp a single dedicated job, like a kitchen with a prep cook, a line cook, and a plating crew instead of one person doing all three.

   The block as a specialized crew: 6 warps, 32 threads each.

   Warp 0 ─┐
   Warp 1  │   STORE CREW (epilogue, 128 threads): once a tile is
   Warp 2  │   finished, read it out of tensor memory, convert
   Warp 3 ─┘   f32 → bf16, write it to global.        ← CUDA cores (STORE)

   Warp 4      LOADER (producer): does nothing but issue TMA copies,
               streaming A and B chunks into the buffers.   ← CUDA cores (LOAD)

   Warp 5      MULTIPLIER (consumer): does nothing but issue MMA
               instructions to the tensor core.   ← drives the TENSOR core (MATH)

This is the answer to "how many warps, and which ones do the CUDA-core work?" Six warps. The tensor core does only the multiply-add, and only warp 5 talks to it. Everything else is CUDA-core work: warp 4 does the loads, warps 0-3 do the format conversion and the stores. Load and store are the CUDA cores; multiply is the tensor core; and now they all run at once.

The loader (warp 4) and the multiplier (warp 5) hand buffers back and forth through lightweight signals (mbarriers): "stage ready" from loader to multiplier, "stage free" back the other way. No lockstep, no sync_threads in the K-loop.

   One team (serial):                 Specialized crew (overlapped):
   load, compute, load, compute…      Warp 4: load→load→load→…  (only loads)
                                       │ "ready"   ▲ "free"
                                       ▼           │
                                      Warp 5: wait→MMA→wait→MMA (only MMAs)

In the kernel. The roles are named constants, and the body splits into blocks gated on warp_id:

const TMA_WARP: u32 = 4;
const MMA_WARP: u32 = 5;
if warp_id == TMA_WARP { /* producer: only issues loads          */ }
if warp_id == MMA_WARP { /* consumer: only issues paired MMAs     */ }
if warp_id < 4         { /* epilogue: tensor memory → bf16 → global */ }

Pipelining (Idea 4) and warp specialization (Idea 5) are one optimization with two halves: the buffers are the racetrack, the specialized warps are the cars. Neither does anything without the other, which is why this kernel has both.
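
The racetrack-and-cars pairing has a direct CPU analogue: two threads and a small pool of buffers passed back and forth. In the sketch below, OS threads stand in for the loader and multiplier warps and standard-library channels stand in for the "ready" and "free" mbarriers. All names are hypothetical, and the "math" is just a sum.

```rust
use std::sync::mpsc::{channel, sync_channel};
use std::thread;

/// Ring depth, from the text.
const STAGES: usize = 4;

/// Streams `chunks` chunks of `chunk_len` values through a STAGES-deep pool of
/// buffers, with one loader thread and one multiplier thread. Returns the sum.
pub fn run_pipeline(chunks: usize, chunk_len: usize) -> f32 {
    // "stage ready": loader -> multiplier, carries a filled buffer.
    let (ready_tx, ready_rx) = sync_channel::<Vec<f32>>(STAGES);
    // "stage free": multiplier -> loader, returns an empty buffer.
    let (free_tx, free_rx) = channel::<Vec<f32>>();
    for _ in 0..STAGES {
        free_tx.send(Vec::with_capacity(chunk_len)).unwrap();
    }

    let loader = thread::spawn(move || {
        for c in 0..chunks {
            let mut buf = free_rx.recv().unwrap(); // wait for a free stage
            buf.clear();
            buf.extend((0..chunk_len).map(|i| (c * chunk_len + i) as f32));
            ready_tx.send(buf).unwrap(); // signal "ready"
        }
        // ready_tx is dropped here, which tells the multiplier the work is done.
    });

    let multiplier = thread::spawn(move || {
        let mut acc = 0.0f32;
        for buf in ready_rx {
            acc += buf.iter().sum::<f32>();
            let _ = free_tx.send(buf); // signal "free"; the loader may have exited
        }
        acc
    });

    loader.join().unwrap();
    multiplier.join().unwrap()
}
```

With only four buffers in circulation, the loader blocks on `free_rx.recv()` once all four are filled, which is the same back-pressure the stage barriers provide in the kernel.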

Idea 6: Persistent + Work Stealing (stop relaunching)

The bottleneck. If you launch one fresh block per tile, you pay launch overhead 1024 times, and the final step of each tile (writing it out) leaves the tensor core idle between tiles.

The idea. Launch one long-lived ("persistent") block per SM that loops: finish a tile, then grab the next tile to compute. How do blocks agree on who gets which tile? A shared counter in global memory works but becomes a traffic jam (every block hammering one address). This kernel instead lets the hardware scheduler hand out tiles (Cluster Launch Control, CLC), with no global-memory traffic at all.

   Global atomic counter (jam):       On-chip scheduler (CLC):
   blk ─┐    ┌──────────────┐         blk ─► "give me a tile"
   blk ─┼───►│ counter [73] │              ◄─ "take tile 73"
   blk ─┘    └──────────────┘         tens of cycles, no global traffic
   contention grows with tile count

In the kernel. The loader warp's body is two loops, one nested in the other. The block first computes its home tile (the one its blockIdx points at) by running the K-loop of Idea 2: 64 iterations (k_iters = K/64), each streaming one 64-wide K-slab into the pipeline (a 128x64 piece of A and a 64x64 piece of B, per block). That inner loop is the loader's entire contribution to one output tile. Then an outer loop asks the scheduler for another tile and runs the same 64-step K-loop again, until CLC reports no work is left. One block, many tiles.

// HOME tile: stream all 64 K-chunks (this inner while IS the K-loop of Idea 2)
while k_idx < k_iters { /* TMA-load one 64-wide K-slab of A and B */ }

loop {                                                      // then steal more tiles
    clc_try_cancel_multicast(resp_ptr, &raw mut CLC_BAR);   // ask for a tile
    if clc_query_is_canceled(resp_lo, resp_hi) == 0 { break; }  // none left → done
    let tile_idx = clc_query_get_first_ctaid_x(resp_lo, resp_hi) / 2;
    while k_idx < k_iters { /* the SAME 64-step K-loop, for the stolen tile */ }
}

The loader does not preload all 64 chunks at once: with only four pipeline stages, the mma_bar wait makes it pause until the multiplier frees a stage, so at most four chunks are in flight (the one being multiplied plus up to three ahead) and no more. A running tile counter keeps the stage rotation and barrier parity continuous across tile boundaries, so the pipeline does not hiccup at the seam between one tile and the next.

The block also runs two accumulator slots in tensor memory, so the store crew can be draining tile N while the multiplier already starts tile N+1. The tensor core never idles waiting for a write-out.
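
For contrast, the shared-counter scheme that CLC replaces is easy to model on the CPU. The sketch below shows that simpler alternative, not the kernel's method: threads stand in for persistent blocks, one atomic stands in for the global-memory counter, and the name `claim_tiles` is hypothetical.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;

/// Each worker is a "persistent block": it loops, claiming tile indices from one
/// shared counter until none are left. Returns the tiles claimed by each worker.
pub fn claim_tiles(workers: usize, total_tiles: usize) -> Vec<Vec<usize>> {
    let next = Arc::new(AtomicUsize::new(0));
    let handles: Vec<_> = (0..workers)
        .map(|_| {
            let next = Arc::clone(&next);
            thread::spawn(move || {
                let mut mine = Vec::new();
                loop {
                    // Every worker hits this one address: the contention point.
                    let tile = next.fetch_add(1, Ordering::Relaxed);
                    if tile >= total_tiles {
                        break; // none left
                    }
                    mine.push(tile); // the tile's K-loop would run here
                }
                mine
            })
        })
        .collect();
    handles.into_iter().map(|h| h.join().unwrap()).collect()
}
```

Every tile is claimed exactly once, whichever worker gets there first. The loop has the same shape as the kernel's outer loop; only the source of the next tile index differs.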

Bonus lever (large matrices only). Persistent blocks also let you choose the order tiles are computed. Sweeping a small band of columns before advancing rows keeps the A and B strips that neighbors share resident in the on-chip L2 cache, turning slow DRAM re-reads into cache hits. At 4096-cubed the data already fits and it is a no-op; at 16384-cubed this single reordering is worth nearly +90%. Same math, same output, just a better visit order.

Idea 7: Clusters + Multicast (load shared data once)

The bottleneck. Notice the #[cluster_launch(2, 1, 1)] on the function: it groups blocks into clusters of two that run together and can see each other's shared memory. The two blocks in a cluster work on neighboring output tiles that need overlapping data. Each one loading that shared data separately is redundant memory traffic.

The idea. Have one block load the shared operand once and let the hardware broadcast (multicast) it to both blocks in the cluster. One trip to memory, fanned out on-chip.

   Without multicast:                 With multicast:
   block 0 ─► loads its own copy       GMEM ─► block 0 ─► block 1
   block 1 ─► loads its own copy       one load, hardware fans it to both

A warning from history: naively widening the cluster and broadcasting added overhead, because the broadcast forces a per-chunk handshake across the cluster, and a shallow pipeline had no slack to hide it. That is exactly why this kernel pairs a narrow cluster (just two blocks) with the deep 4-stage pipeline of Idea 4. The depth pays for the handshake.

In the kernel. The load is a multicast bulk-copy; the cluster size lives in the #[cluster_launch(...)] attribute:

#[cluster_launch(2, 1, 1)]       // clusters of 2 blocks
// ...
cp_async_bulk_tensor_2d_g2s_multicast_cg2(smem_a_ptr, a_tma, k, m, aliased_bar, self_mask);
cp_async_bulk_tensor_2d_g2s_multicast_cg2(smem_b_ptr, b_tma, k, n, aliased_bar, self_mask);

Idea 8: CTA Pairs (two blocks, one instruction)

The bottleneck. Two blocks in a cluster, each driving its own tensor core, each issuing its own MMA and its own synchronization. Can the pair share more?

The idea. Pair the two blocks so that a single tensor-core instruction spans both. With cta_group::2, the leader block issues one MMA whose shape is M256_N128: it reads operands from both blocks' shared memory and writes results into both blocks' tensor memory. One instruction computes a 256-row output tile (two stacked 128-row tiles). The follower block just loads and stores; it never issues an MMA.

   Two 128-row tiles stacked into one 256-row tile, computed by ONE MMA:

   block 0 (leader)   ─ owns rows   0..127 ─┐
                                            ├─► one M256_N128 paired MMA
   block 1 (follower) ─ owns rows 128..255 ─┘    reads both blocks' SMEM,
                                                 writes both blocks' TMEM

For that paired MMA to fire, both blocks' "load done" signals have to land on the same barrier. The trick: mask the barrier address so both blocks point at the leader's barrier. The mask 0xFEFFFFF8 clears bit 24, the bit that distinguishes the two blocks, along with the low three bits. That makes the cross-block handshake nearly free.

   block 0's barrier addr ─┐
                           ├─► (addr & 0xFEFFFFF8) ─► leader's barrier
   block 1's barrier addr ─┘
   The mask clears the bit that distinguishes the two blocks, so their
   separate "I'm loaded" signals merge into one the leader waits on.

In the kernel. The dangerous pointer arithmetic is isolated behind one named constant and one clearly named local, instead of a magic number buried in the hot loop; the paired MMA is a single cta_group::2 call:

const PEER_BIT_MASK: u32 = 0xFEFFFFF8;
let aliased_bar = ((tma_bar_mut as u32) & PEER_BIT_MASK) as *mut Barrier;
// ... leader only:
tcgen05_mma_f16_cg2(tmem_addr + offset, a_desc, b_desc, idesc, accumulate);  // spans both blocks

This is what turns the cluster from a cost into a win: the pair shares both the load (multicast, Idea 7) and the multiply (paired MMA, Idea 8).
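
The barrier-aliasing mask can be checked in isolation. In the sketch below the constant is the one from the excerpt; the helper name `alias_to_leader` and the example addresses are made up, and reading bit 24 as the peer bit is inferred from the mask value plus the explanation above.

```rust
/// Same constant as the excerpt: every bit set except bit 24 and bits 0..=2.
const PEER_BIT_MASK: u32 = 0xFEFF_FFF8;

/// Collapse a barrier address onto the leader's copy by clearing the masked bits.
pub fn alias_to_leader(bar_addr: u32) -> u32 {
    bar_addr & PEER_BIT_MASK
}

/// The bits the mask clears, for inspection.
pub fn cleared_bits() -> u32 {
    !PEER_BIT_MASK
}
```

Two addresses that differ only in bit 24, say 0x0000_0040 and 0x0100_0040, both come out as 0x0000_0040, so both blocks signal the same barrier.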


The Epilogue: Getting the Answer Out

When the K-loop finishes, the answer for the tile lives in the tensor core's private fp32 accumulator (tensor memory). It is not in a form you can write to global memory yet. The store crew (warps 0-3) drains it:

   tensor memory (fp32)  ──read──►  registers  ──convert──►  bf16
        (CuSimd group)                              │
                                                    ▼
                              shared memory  ──coalesced store──►  global memory
                              (staged with stmatrix)               (the DisjointSlice)

  1. Read the accumulator out of tensor memory into a register group (CuSimd).
  2. Convert each fp32 pair to a packed bf16 pair (cvt_f32x2_bf16x2).
  3. Stage the bf16 results into shared memory in a tidy layout (stmatrix).
  4. Store that shared buffer to global memory through the DisjointSlice output, the kernel's only writable global surface.

Because the block runs two accumulator slots, this drain happens for tile N while the multiplier is already grinding tile N+1. The store latency hides behind the next tile's math.

let regs_a = tcgen05_ld_16x256b_pure(tmem_addr + ...);  // read fp32 from TMEM
tcgen05_load_wait();
let p0 = cvt_f32x2_bf16x2(regs_a[0], regs_a[1]);        // f32 → bf16
// p1: a second packed pair, computed in code elided from this excerpt
stmatrix_m8n8_x2(smem_addr, p0, p1);                    // stage in shared
// ... then a coalesced copy from SMEM_OUT to the global `out` slice
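
Step 2 is the only arithmetic in the epilogue, and it is small enough to model on the CPU. The sketch below is a software version of an f32-to-bf16 conversion with round-to-nearest-even, packing two results into one u32. The function names are hypothetical, and two details are assumptions rather than facts about cvt_f32x2_bf16x2: the rounding mode, and that the first argument lands in the low 16 bits.

```rust
/// Round one f32 to bf16 (round-to-nearest-even), returned as raw 16 bits.
pub fn f32_to_bf16_bits(x: f32) -> u16 {
    if x.is_nan() {
        return 0x7FC0; // a quiet NaN; rounding could otherwise turn NaN into infinity
    }
    let bits = x.to_bits();
    // Add just under half of the dropped range, plus one more when the kept
    // part is odd, so exact ties round to the even result.
    let round = 0x7FFF + ((bits >> 16) & 1);
    (bits.wrapping_add(round) >> 16) as u16
}

/// Pack two f32 values into one u32 holding two bf16 values.
/// Assumption: the first argument goes in the low 16 bits.
pub fn pack_bf16x2(lo: f32, hi: f32) -> u32 {
    (f32_to_bf16_bits(lo) as u32) | ((f32_to_bf16_bits(hi) as u32) << 16)
}
```

bf16 keeps the top 16 bits of an f32, so 1.0 (0x3F80_0000) becomes 0x3F80, and packing (1.0, 2.0) under the stated ordering gives 0x4000_3F80. This is why the output slice can be a DisjointSlice<u32> with two results per element.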

The Rust Toolbox

Three abstractions that are not GPU ideas, but are how Rust keeps the kernel's resource use correct.

Type-state: a lifecycle the compiler checks. Tensor memory and barriers must be allocated, used, then freed in order, or you get a leak or silent corruption. cuda-oxide encodes that order in the type:

   TmemUninit ──alloc()──► TmemReady ──dealloc()──► TmemDeallocated
                              │                          │
                         address()                  (no methods)
                                                  address() here?
                                                  ✗ COMPILE ERROR

alloc consumes the Uninit handle and returns Ready. address only exists on Ready. Use-before-alloc and use-after-free do not compile. Barriers get the same treatment, with a "kind" so the load barrier and the compute barrier are different types.
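A minimal host-side sketch of the same pattern, assuming three plain structs named after the states in the diagram. The fields, method signatures, and the address passed to alloc are hypothetical; the real cuda-oxide types wrap device resources and will look different.

```rust
/// Hypothetical stand-ins for the three states in the diagram.
struct TmemUninit;
struct TmemReady {
    addr: u32,
}
struct TmemDeallocated;

impl TmemUninit {
    /// Takes `self` by value, so the uninitialised handle is gone after this call.
    fn alloc(self, addr: u32) -> TmemReady {
        TmemReady { addr }
    }
}

impl TmemReady {
    /// Only the ready state exposes an address.
    fn address(&self) -> u32 {
        self.addr
    }

    /// Takes `self` by value, so the ready handle cannot be used afterwards.
    fn dealloc(self) -> TmemDeallocated {
        TmemDeallocated
    }
}

fn main() {
    let tmem = TmemUninit;
    // tmem.address();   // does not compile: TmemUninit has no such method
    let tmem = tmem.alloc(0x40);
    assert_eq!(tmem.address(), 0x40);
    let _done: TmemDeallocated = tmem.dealloc();
    // tmem.address();   // does not compile: `tmem` was moved by dealloc()
}
```

Both rejected lines fail for ordinary ownership reasons, so the ordering rule costs nothing at run time: the states are zero-sized or thin wrappers and the checks happen entirely in the type checker.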

CuSimd: register groups as one type. When the tensor core hands back its result, it arrives as a group of registers. cuda-oxide models that as one indexable value, not a pile of named fields and a giant switch:

   CUDA C++:  float r0,r1,…r31;        cuda-oxide:  CuSimd<f32, 32>
              switch (i) { … }                      regs[i]   // Index trait
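The indexable-group idea can be sketched in ordinary Rust with a const-generic wrapper. RegGroup below is a hypothetical stand-in, not the real CuSimd definition; it only shows how the Index and IndexMut traits let a fixed-size group be read and written as regs[i].

```rust
use std::ops::{Index, IndexMut};

/// Hypothetical stand-in for a register group: N lanes of T behind one type.
#[derive(Clone, Copy, Debug, PartialEq)]
struct RegGroup<T, const N: usize>([T; N]);

impl<T, const N: usize> Index<usize> for RegGroup<T, N> {
    type Output = T;
    fn index(&self, i: usize) -> &T {
        &self.0[i]
    }
}

impl<T, const N: usize> IndexMut<usize> for RegGroup<T, N> {
    fn index_mut(&mut self, i: usize) -> &mut T {
        &mut self.0[i]
    }
}

fn main() {
    let mut regs: RegGroup<f32, 32> = RegGroup([0.0; 32]);
    // Write every lane through IndexMut.
    for i in 0..32 {
        regs[i] = i as f32 * 0.5;
    }
    // Read every lane back through Index.
    let sum: f32 = (0..32).map(|i| regs[i]).sum();
    assert_eq!(sum, 248.0);
    println!("sum = {sum}");
}
```

Because the lane count is part of the type, a function that expects 32 lanes cannot be handed a group of 16 by mistake.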

Newtypes: no same-width mix-ups. A tensor-map is a TmaDescriptor, not a void*. A barrier token is a BarrierToken, not a u64. Zero-cost, but a wrong argument becomes a compile error instead of a silent swap.
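As a small illustration of the newtype point, the sketch below wraps two raw u64 values in distinct types. DescriptorAddr and wait_on are hypothetical names made up for the example, and this BarrierToken is a local stand-in rather than the cuda-oxide type.

```rust
/// Hypothetical newtypes over values of the same raw width.
#[repr(transparent)]
#[derive(Clone, Copy, Debug, PartialEq)]
struct BarrierToken(u64);

#[repr(transparent)]
#[derive(Clone, Copy, Debug, PartialEq)]
struct DescriptorAddr(u64);

/// Hypothetical function that accepts only a barrier token.
fn wait_on(token: BarrierToken) -> u64 {
    token.0
}

fn main() {
    let token = BarrierToken(7);
    let desc = DescriptorAddr(0x1000);

    assert_eq!(wait_on(token), 7);
    // wait_on(desc);   // does not compile: expected `BarrierToken`, found `DescriptorAddr`

    // Zero-cost: the wrapper is exactly as large as the integer it holds.
    assert_eq!(std::mem::size_of::<BarrierToken>(), std::mem::size_of::<u64>());
    println!("descriptor at {:#x}", desc.0);
}
```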


The Payoff

This one kernel runs at about 868 TFLOPS, 58% of a live cublasLtMatmul FP16 baseline at 4096-cubed on a B200. A few hundred lines of Rust, reaching well past half of a library NVIDIA has tuned for years.
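For a sense of scale, the quoted figures can be turned into a time per multiply with a few lines of arithmetic. The sketch below uses the usual 2·M·N·K operation count for a matrix multiply; the baseline rate it prints is only what the two quoted numbers imply, not a separate measurement, and since both inputs are rounded the outputs are approximate.

```rust
/// FLOPs in an (M x K) by (K x N) multiply: one multiply and one add per term.
fn gemm_flops(m: u64, n: u64, k: u64) -> u64 {
    2 * m * n * k
}

fn main() {
    let flops = gemm_flops(4096, 4096, 4096);
    assert_eq!(flops, 137_438_953_472);

    // Figures quoted in the text.
    let kernel_tflops = 868.0_f64;
    let fraction_of_baseline = 0.58_f64;

    // Time for one multiply at the quoted rate, in microseconds.
    let micros = flops as f64 / (kernel_tflops * 1e12) * 1e6;
    // Baseline rate implied by the two quoted figures.
    let implied_baseline = kernel_tflops / fraction_of_baseline;

    println!("{micros:.1} us per GEMM, implied baseline {implied_baseline:.0} TFLOPS");
}
```

At the quoted rate, one 4096-cubed multiply takes roughly 158 microseconds, and the implied library rate is on the order of 1,500 TFLOPS.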

The interesting part is how the ideas combine. Two of them do nothing on their own: the 4-stage pipeline (Idea 4) is dead weight without warp specialization (Idea 5) to put a second worker on it, and clustering + multicast (Idea 7) actually loses performance until the deep pipeline and CTA pairs (Idea 8) hide and amortize its handshake. That is why a speed-of-light kernel is not eight independent tricks bolted together. It is eight ideas chosen so each one pays for the next one's cost.


Key Takeaways

  1. The problem is simple to state. A (M×K) times B (K×N) equals C (M×N). Every cell of C is a dot product over K. C is huge and K is long; that is the entire source of the difficulty.
  2. The tensor core works in 8x8 bricks. Everything else in the kernel is about feeding that brick engine: laying bricks out right (swizzle), keeping it fed (pipeline + warp specialization), and sharing its inputs (multicast, pairs).
  3. CUDA cores load and store; the tensor core multiplies. In this kernel that is six warps: warp 4 loads, warps 0-3 store, warp 5 drives the tensor core, all at once.
  4. Clusters turn two blocks into one machine. Multicast shares the load, the paired MMA shares the multiply, and barrier aliasing makes the handshake between them nearly free.
  5. Ideas pay for each other. The pipeline needs specialized warps; the cluster needs the deep pipeline. The wins are in the combinations, not the parts.
  6. The Rust you know carries over. std::sync::atomic, exhaustive match, loop { ... break }, the Index trait, type-state, RAII, and const fn builders that fold to one constant so the readable form is the fast form.

Reproducing

# Run the kernels: correctness tests + benchmarks
cargo oxide run gemm_sol

# Watch the device-code pipeline (MIR -> dialect-mir -> LLVM -> PTX)
cargo oxide pipeline gemm_sol

The kernel lives in crates/rustc-codegen-cuda/examples/gemm_sol/src/main.rs as gemm_sol_clc_multicast_4_stage_pipeline, a pub unsafe fn inside the #[cuda_module] mod kernels block. Generated gemm_sol.ll and gemm_sol.ptx land next to the example. Requires sm_100+ (Blackwell) to execute; on older GPUs only PTX generation is verified.


Source Material

  • gemm_sol example: crates/rustc-codegen-cuda/examples/gemm_sol/
  • The abstractions live in the cuda-device crate (DisjointSlice, SharedArray, TmemGuard, ManagedBarrier, CuSimd, the tensor-core builders) and cuda-host (#[cuda_module], LaunchConfig).
  • Conceptual grounding: Modern GPU Programming for MLSys (https://mlc.ai/modern-gpu-programming-for-mlsys/index.html), especially the tensor-core, data-layout, and advanced-GEMM chapters.
  • The repo: https://github.com/NVlabs/cuda-oxide