@karpathy
Created June 15, 2023 04:13
pytorch strangeness
import torch
import torch.nn as nn
torch.manual_seed(42)
x = torch.randn(2, 768)
# matrix multiply "ignores" the second row when calculating the first row
w = torch.randn(768, 768)
z1 = x[0] @ w
z2 = (x @ w)[0]
print((z1-z2).abs().max().item()) # prints 0 (should be 0, OK)
# linear does not!
m = nn.Linear(768, 768, bias=False)
with torch.no_grad():
    m.weight.copy_(w.T)
q1 = m(x[0])
q2 = m(x)[0]
print((q1-q2).abs().max().item()) # prints ~2e-5 (should be 0?!)
# and z1 != q1
print((z1-q1).abs().max().item()) # prints ~9e-5 (should be 0?!)
@itsAnanth commented May 24, 2025

I looked at the torch source code and found this:

inline Tensor linear(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias = {}) {
  if (input.dim() == 2 && bias.defined()) {
    // fused op is marginally faster
    return torch::addmm(bias, input, weight.t());
  } else {
    auto output = input.matmul(weight.t());
    if (bias.defined()) {
      output += bias;
    }
    return output;
  }
}

Could this be caused by the difference in which fused operation is applied for batched inputs? A sketch of one way to check is below.
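One way to probe that hypothesis is to bypass nn.Linear and call the two paths directly on the transposed weight, the way the linear() snippet above does. This is only a minimal sketch: the names wt, q1, q2, r1, r2 are mine, and the exact printed values depend on the BLAS backend and dtype; the point is just to compare the 1-D (unbatched) and 2-D (batched) code paths.

import torch
import torch.nn.functional as F

torch.manual_seed(42)
x = torch.randn(2, 768)
w = torch.randn(768, 768)

# nn.Linear stores weight as (out_features, in_features), i.e. the transpose of
# the matrix used above, and computes input.matmul(weight.t()) when no bias is set.
wt = w.T.contiguous()  # what m.weight holds after m.weight.copy_(w.T)

# F.linear with a 1-D input vs. a 2-D (batched) input
q1 = F.linear(x[0], wt)   # unbatched path
q2 = F.linear(x, wt)[0]   # batched path, then take row 0
print((q1 - q2).abs().max().item())

# same comparison with plain matmul on the transposed weight view,
# mirroring the input.matmul(weight.t()) branch in the C++ snippet above
r1 = x[0] @ wt.t()        # vector @ matrix: a vector-matrix (gemv-style) kernel
r2 = (x @ wt.t())[0]      # matrix @ matrix: a gemm kernel, possibly a different reduction order
print((r1 - r2).abs().max().item())

Per the snippet above, the fused addmm branch only fires when a bias is defined, so with bias=False both calls should go through matmul; if these prints show the same ~1e-5-scale gap as the gist, the discrepancy would then come from the 1-D vs 2-D matmul dispatch on the transposed weight view rather than from addmm.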
