@LukasHaas
Last active December 10, 2021 07:41
Modified LightGCN Model
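The model below relies on a LightGCNConv layer that is not included in this gist. As a point of reference, here is a minimal sketch of such a layer, assuming PyTorch Geometric's MessagePassing API is available: LightGCN propagation is just symmetrically degree-normalized neighbor aggregation, with no learned feature transformation and no nonlinearity. The class name and signature are chosen to match how the layer is called below; treat this as an illustration, not the author's original implementation.

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import degree

class LightGCNConv(MessagePassing):
    """Hypothetical LightGCN propagation layer (illustrative sketch)."""
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # sum the (normalized) neighbor messages
        self.in_channels = in_channels
        self.out_channels = out_channels

    def forward(self, x, edge_index):
        # Symmetric normalization: 1 / sqrt(deg(i) * deg(j)) for each edge (i, j)
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # Neighbors contribute their embeddings unchanged, scaled by the edge norm
        return norm.view(-1, 1) * x_j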
import torch
import torch.nn as nn

# LightGCNConv is assumed to be defined elsewhere (see the sketch above).
class LightGCN(torch.nn.Module):
    def __init__(self, train_data, num_layers, emb_size=16, initialize_with_words=False):
        super(LightGCN, self).__init__()
        assert num_layers >= 1, 'Number of layers must be >= 1'

        self.convs = nn.ModuleList()
        for _ in range(num_layers):
            self.convs.append(LightGCNConv(emb_size, emb_size))

        # Initialize the embedding table, optionally with custom word-based embeddings
        num_nodes = train_data.node_label_index.size(0)
        self.embeddings = nn.Embedding(num_nodes, emb_size)
        if initialize_with_words:
            self.embeddings.weight.data.copy_(train_data.node_features)

        self.loss_fn = nn.BCELoss()
        self.num_layers = num_layers
        self.emb_size = emb_size
        self.num_nodes = num_nodes

    def forward(self, data):
        edge_index = data.edge_index
        edge_label_index = data.edge_label_index
        node_label_index = data.node_label_index

        # We take the average of every layer's node embeddings (including layer 0)
        x = self.embeddings(node_label_index)
        layer_embeddings = [x]
        for i in range(self.num_layers):
            x = self.convs[i](x, edge_index)
            layer_embeddings.append(x)
        mean_layer = torch.stack(layer_embeddings, dim=0).mean(dim=0)

        # Prediction head is simply the dot product of the two endpoint embeddings
        nodes_first = torch.index_select(mean_layer, 0, edge_label_index[0, :].long())
        nodes_second = torch.index_select(mean_layer, 0, edge_label_index[1, :].long())
        out = torch.sum(nodes_first * nodes_second, dim=-1)  # raw scores, usable for ranking

        # Since we don't want a ranking output, we pass the dot product through a sigmoid
        pred = torch.sigmoid(out)
        return torch.flatten(pred)

    def loss(self, pred, label):
        return self.loss_fn(pred, label)
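A minimal usage sketch follows. It assumes train_data is a graph object exposing the fields that forward() reads (edge_index, edge_label_index, node_label_index) plus binary link labels in an edge_label attribute; the field names other than edge_label come from the model code above, while edge_label, the optimizer settings, and the epoch count are illustrative assumptions.

import torch

model = LightGCN(train_data, num_layers=3, emb_size=16)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()
for epoch in range(100):
    optimizer.zero_grad()
    pred = model(train_data)  # sigmoid link probabilities in [0, 1]
    # edge_label is assumed to hold 0/1 labels for the edges in edge_label_index
    loss = model.loss(pred, train_data.edge_label.float())
    loss.backward()
    optimizer.step()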