@m3hrdadfi
Last active June 2, 2022 09:10
CE HF Transformer (cross-entropy loss over token logits: Hugging Face Transformers-style flattening vs. 3D input)
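import torch

n_cls = 5  # number of classes
# Logits shaped (batch, seq_len, n_cls), as a token-classification head produces them
a = torch.rand((1, 3, n_cls))
# One target class index per token, shaped (batch, seq_len)
b = torch.tensor([[0, 1, 2]])
print(a.shape)
print(b.shape)
# > torch.Size([1, 3, 5])
# > torch.Size([1, 3])
fc = torch.nn.CrossEntropyLoss()  # default reduction="mean"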
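# HF style: flatten to (batch * seq_len, n_cls) logits and (batch * seq_len,) targets
print(fc(a.view(-1, n_cls), b.view(-1)))
# > tensor(1.8672)
# 3D style: CrossEntropyLoss also accepts (batch, n_cls, seq_len) logits with
# (batch, seq_len) targets, so move the class dimension to position 1
print(fc(a.transpose(2, 1), b))
# > tensor(1.8672)
# Both calls average the loss over the same 3 tokens, so they print the same value;
# the exact number depends on the unseeded torch.rand logits.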
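A minimal sketch of why the two calls agree, assuming the same shapes as above (batch 1, sequence length 3, 5 classes): with the default `reduction="mean"`, cross-entropy is the average of `-log_softmax(logits)[target]` over all tokens, so it does not matter whether the tokens are laid out as a flat `(3, 5)` tensor or as a `(1, 5, 3)` tensor with the class dimension second. The helper `manual_ce` and the fixed seed are illustrative additions, not part of the original gist. The sketch also uses `ignore_index` (default `-100` in `torch.nn.CrossEntropyLoss`) to show that a masked token, such as padding, is dropped from the average in both layouts.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)  # fixed seed so the run is repeatable

n_cls = 5
logits = torch.rand((1, 3, n_cls))    # (batch, seq_len, n_cls)
targets = torch.tensor([[0, 1, 2]])   # (batch, seq_len)


def manual_ce(logits, targets, ignore_index=-100):
    """Hypothetical helper: mean token-level cross-entropy computed by hand."""
    flat_logits = logits.reshape(-1, logits.size(-1))  # (batch * seq_len, n_cls)
    flat_targets = targets.reshape(-1)                 # (batch * seq_len,)
    keep = flat_targets != ignore_index                # drop ignored tokens
    log_probs = F.log_softmax(flat_logits[keep], dim=-1)
    # Pick the log-probability of the correct class for each kept token
    picked = log_probs.gather(1, flat_targets[keep].unsqueeze(1)).squeeze(1)
    return -picked.mean()


loss_fn = torch.nn.CrossEntropyLoss()  # reduction="mean", ignore_index=-100

flat_loss = loss_fn(logits.view(-1, n_cls), targets.view(-1))  # HF style
loss_3d = loss_fn(logits.transpose(2, 1), targets)             # (N, C, d1) style
by_hand = manual_ce(logits, targets)
print(flat_loss.item(), loss_3d.item(), by_hand.item())

# Mark the last token as ignored (e.g. padding); both layouts skip it
masked = torch.tensor([[0, 1, -100]])
flat_masked = loss_fn(logits.view(-1, n_cls), masked.view(-1))
masked_3d = loss_fn(logits.transpose(2, 1), masked)
print(flat_masked.item(), masked_3d.item(), manual_ce(logits, masked).item())
```

Setting `reduction="none"` instead returns one loss per token, shaped `(3,)` for the flattened call and `(1, 3)` for the 3D call; the mean of those per-token values is what both calls in the gist print.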