PyTorch NCE Loss
import torch
import torch.nn as nn
import pytorch_lightning as pl


class NCE(pl.LightningModule):
    """
    NT-Xent (normalized temperature-scaled cross-entropy) loss from SimCLR.
    This implementation is adapted from
    https://github.com/Spijkervet/SimCLR/blob/master/simclr/modules/nt_xent.py
    The mask_correlated_samples function has been modified to be much faster to
    compute, so it can be called at train time without a predefined batch size.
    """

    def __init__(self, temperature=0.1):
        super().__init__()
        self.temperature = temperature
        self.criterion = nn.CrossEntropyLoss(reduction="sum")
        self.similarity_f = nn.CosineSimilarity(dim=2)

    def mask_correlated_samples(self, batch_size):
        # Boolean mask selecting the negatives: everything except the main
        # diagonal (self-similarity) and the two off-diagonals holding the
        # positive pairs (i, i + batch_size) and (i + batch_size, i).
        N = 2 * batch_size
        mask = torch.ones((N, N), dtype=torch.bool, device=self.device)
        mask.fill_diagonal_(0)
        mask[:batch_size, batch_size:] = mask[:batch_size, :batch_size]
        mask[batch_size:, :batch_size] = mask[:batch_size, :batch_size]
        return mask

    def forward(self, z_i, z_j):
        """
        We do not sample negative examples explicitly.
        Instead, given a positive pair, similar to (Chen et al., 2017),
        we treat the other 2(N − 1) augmented examples within a minibatch
        as negative examples.
        """
        batch_size = z_i.shape[0]
        mask = self.mask_correlated_samples(batch_size)
        N = 2 * batch_size

        z = torch.cat((z_i, z_j), dim=0)
        # Pairwise cosine similarity between all 2N embeddings, scaled by temperature
        sim = self.similarity_f(z.unsqueeze(1), z.unsqueeze(0)) / self.temperature

        # Similarity between each i and its positive j, and the reverse as well
        sim_i_j = torch.diag(sim, batch_size)
        sim_j_i = torch.diag(sim, -batch_size)

        # We have 2N samples, so the positives form an (N, 1) column
        positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)
        negative_samples = sim[mask].reshape(N, -1)

        # The positive logit sits in column 0, so every label is class 0
        labels = torch.zeros(N, dtype=torch.long, device=positive_samples.device)
        logits = torch.cat((positive_samples, negative_samples), dim=1)
        loss = self.criterion(logits, labels)
        loss /= N
        return loss
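A minimal usage sketch (the batch size and embedding dimension below are arbitrary placeholders, not part of the gist): z_i and z_j each hold the projected embeddings of one augmented view of the same batch, so row k of z_i and row k of z_j form a positive pair.

# Hypothetical shapes, chosen for illustration only
nce = NCE(temperature=0.1)
batch_size, embedding_dim = 32, 128
z_i = torch.randn(batch_size, embedding_dim)  # projections of view 1
z_j = torch.randn(batch_size, embedding_dim)  # projections of view 2
loss = nce(z_i, z_j)  # scalar: summed NT-Xent loss divided by 2 * batch_size

# Sanity check: for batch_size=2 the negatives mask excludes the diagonal
# and the two positive pairs, leaving 2 * (batch_size - 1) negatives per row
print(nce.mask_correlated_samples(2))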