@ezyang
Created October 30, 2025 01:59
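from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate, Shard

R, S = Replicate(), Shard  # shorthands: R is a placement instance, S(d) shards tensor dim d

# Assumed setup (not in the original): run under torchrun with 8 processes; the 2x2x2 mesh shape is illustrative.
mesh = init_device_mesh("cpu", (2, 2, 2), mesh_dim_names=("m", "n", "k"))

# arange_nd is a helper not shown in the original; it is assumed to return a torch.Tensor.
x = DTensor.from_local(arange_nd(15), mesh["m", "n", "k"], [R, R, R])
# Eliminate M (from_local without placements defaults to Replicate on every mesh dim)
x = DTensor.from_local(x.redistribute(placements=[R, R, S(0)]).to_local(), mesh["m", "n"])  # shard K
x = DTensor.from_local(x.redistribute(placements=[R, S(0)]).to_local(), mesh["m"])  # shard N
x = x.redistribute(placements=[S(0)]).to_local()  # shard M
x = DTensor.from_local(x, mesh["n"], [S(0)]).redistribute(placements=[R])  # unshard N (all-gather over n)
x = DTensor.from_local(x.to_local(), mesh["n", "k"], [R, S(0)]).redistribute(placements=[R, R])  # unshard K
# Eliminate N
x = DTensor.from_local(x.redistribute(placements=[R, S(0)]).to_local(), mesh["n"])  # shard K
x = x.redistribute(placements=[S(0)]).to_local()  # shard N
x = DTensor.from_local(x, mesh["k"], [S(0)]).redistribute(placements=[R])  # unshard K (all-gather over k)
# Eliminate K
x = x.redistribute(placements=[S(0)]).to_local()  # shard K
print(mesh.get_coordinate(), x)  # each rank prints its (m, n, k) coordinate and its final local shard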
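The sequence above can be traced without launching any processes. Below is a minimal pure-Python sketch, assuming every shard is evenly divisible: it models `Shard(0)` on one mesh dim as keeping the rank's contiguous chunk, and unsharding (`Shard(0)` to `Replicate`) as concatenating the chunks in mesh order, then records which global indices each `(m, n, k)` coordinate holds at the end. The helper names (`shard`, `unshard`, `eliminate_m`, `eliminate_n`, `eliminate_k`, `trace_layout`, `direct_layout`) and the mesh sizes are hypothetical and for illustration only; the sketch does not call DTensor and does not model uneven shards.

```python
from itertools import product


def shard(xs, size, idx):
    """Model Shard(0) over one mesh dim: rank idx keeps the idx-th of `size` equal contiguous chunks."""
    assert len(xs) % size == 0, "this sketch only models evenly divisible shards"
    step = len(xs) // size
    return xs[idx * step:(idx + 1) * step]


def unshard(pieces):
    """Model Shard(0) -> Replicate over one mesh dim: all-gather concatenates chunks in mesh order."""
    return [v for piece in pieces for v in piece]


def eliminate_m(xs, M, N, K, m):
    # shard K, shard N, shard M, then unshard N (inner) and K (outer); the result depends only on m
    return unshard(
        unshard(shard(shard(shard(xs, K, k), N, n), M, m) for n in range(N))
        for k in range(K)
    )


def eliminate_n(xs, N, K, n):
    # shard K, shard N, then unshard K; the result depends only on (m, n)
    return unshard(shard(shard(xs, K, k), N, n) for k in range(K))


def eliminate_k(xs, K, k):
    # shard K and keep the local chunk
    return shard(xs, K, k)


def trace_layout(T, M, N, K):
    """Map each mesh coordinate (m, n, k) to the global indices it holds after all three phases."""
    xs = list(range(T))
    return {
        (m, n, k): eliminate_k(eliminate_n(eliminate_m(xs, M, N, K, m), N, K, n), K, k)
        for m, n, k in product(range(M), range(N), range(K))
    }


def direct_layout(T, M, N, K):
    """Shard dim 0 directly: first over K, then over N, then over M."""
    xs = list(range(T))
    return {
        (m, n, k): shard(shard(shard(xs, K, k), N, n), M, m)
        for m, n, k in product(range(M), range(N), range(K))
    }


if __name__ == "__main__":
    for coord, held in sorted(trace_layout(8, 2, 2, 2).items()):
        print(coord, held)
    # Compare the traced layout with direct K -> N -> M sharding on a non-cubic mesh.
    print(trace_layout(24, 2, 3, 2) == direct_layout(24, 2, 3, 2))
```

With a 2x2x2 mesh and 8 elements, coordinate `(m, n, k)` ends up holding element `m + 2n + 4k`. Under the even-divisibility assumption, the traced layout coincides with sharding dim 0 directly over K, then N, then M, which the last line checks for a 2x3x2 mesh.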