Skip to content

Instantly share code, notes, and snippets.

@gau-nernst
Created April 8, 2025 02:32
Show Gist options
  • Save gau-nernst/74f3742148d547d344e474a9e0e8fc1d to your computer and use it in GitHub Desktop.
Save gau-nernst/74f3742148d547d344e474a9e0e8fc1d to your computer and use it in GitHub Desktop.
PyTorch int4mm_cpu
import torch
print(torch.__version__)
group_size = 32
w = torch.randn(512, 1024)
w_groups = w.unflatten(1, (-1, group_size))
min_val = w_groups.amin(2, keepdim=True)
max_val = w_groups.amax(2, keepdim=True)
scale = (max_val - min_val) / 15 # scale (max-min) to 15
zero_point = min_val + scale * 8
w_int = (w_groups - min_val) * scale.reciprocal()
w_int = w_int.clip(0, 15).to(torch.int32).reshape(w.shape)
w_packed = torch._convert_weight_to_int4pack_for_cpu(w_int, 1)
scales_zeros = torch.cat([scale, zero_point], dim=2).transpose(0, 1).contiguous()
x = torch.randn(1, w.shape[1])
out = torch._weight_int4pack_mm_for_cpu(x, w_packed, group_size, scales_zeros)
w_dequant = (w_int.unflatten(1, (-1, group_size)) * scale + min_val).reshape(w.shape)
diff_abs = (out - x @ w_dequant.T).abs()
print(diff_abs.max(), diff_abs.mean())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment