@prajwalsingh
Last active September 7, 2023 15:55
Multi-Positive Contrastive Loss
# f: encoder (backbone + projection MLP)
# tau: temperature
# minibatch x: [n, m, 3, h, w]
#   n captions, m images per caption
# As per the paper: n*m = 8192 with m = 6, so n ≈ 1366

def H(p, q):  # cross-entropy
    return -(p * torch.log(q)).sum(1).mean()

for x in loader:
    x = augment(x)
    x = x.flatten(0, 1)  # [n*m, 3, h, w]; keeps the m images of each
                         # caption contiguous, matching p's block layout
    h = f(x)

    # ground-truth distribution p: uniform over the other
    # m-1 images generated from the same caption
    p = torch.ones(m, m).to(rank)
    p.fill_diagonal_(0)
    p = torch.kron(torch.eye(n).to(rank), p)  # block-diagonal [n*m, n*m]
    p = p / p.sum(1, keepdim=True)

    # compute contrastive distribution q
    logits = h @ h.T / tau
    logits.fill_diagonal_(-1e9)  # mask self-similarity
    q = softmax(logits, dim=1)

    H(p, q).backward()
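The pseudocode above can be turned into a small runnable sketch. This is my own illustration, not the paper's code: the encoder is dropped and the loss is computed directly on embeddings, `n`, `m`, `d`, and `tau` are illustrative values, embeddings are assumed L2-normalized, an out-of-place `masked_fill` replaces the in-place diagonal fill, and `log_softmax` replaces `log(softmax(...))` for numerical stability.

```python
import torch
import torch.nn.functional as F

def multi_positive_contrastive_loss(h, n, m, tau=0.1):
    # h: [n*m, d] embeddings; the m images of each caption are contiguous
    h = F.normalize(h, dim=1)                      # cosine similarities below
    # ground-truth p: uniform over the other m-1 images of the same caption
    p = torch.ones(m, m)
    p.fill_diagonal_(0)
    p = torch.kron(torch.eye(n), p)                # block-diagonal [n*m, n*m]
    p = p / p.sum(1, keepdim=True)
    # contrastive q: softmax over similarities with self pairs masked out
    logits = h @ h.T / tau
    mask = torch.eye(n * m, dtype=torch.bool)
    logits = logits.masked_fill(mask, -1e9)
    log_q = F.log_softmax(logits, dim=1)           # stable log(softmax(...))
    return -(p * log_q).sum(1).mean()              # cross-entropy H(p, q)

n, m, d = 4, 3, 16
h = torch.randn(n * m, d, requires_grad=True)
loss = multi_positive_contrastive_loss(h, n, m)
loss.backward()                                    # gradients flow to h
```

A handy sanity check: if every embedding is identical, q is uniform over the n*m - 1 non-self pairs, so the loss equals log(n*m - 1).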