Last active
December 10, 2021 07:41
-
-
Save LukasHaas/8e0e6004f2f266f719bd3bf2d19de2ca to your computer and use it in GitHub Desktop.
Modified LightGCN Model
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class LightGCN(torch.nn.Module):
    """Link predictor that scores node pairs from layer-averaged embeddings.

    A learnable embedding table is propagated through ``num_layers``
    ``LightGCNConv`` layers. The final node representation is the mean of the
    initial embeddings and every layer's output; an edge score is the sigmoid
    of the dot product of its two endpoint representations.
    """

    def __init__(self, train_data, num_layers, emb_size=16, initialize_with_words=False):
        """Build the embedding table and the propagation layers.

        Args:
            train_data: Graph object; ``node_label_index`` is read to size the
                embedding table and, when ``initialize_with_words`` is true,
                ``node_features`` is copied into the embedding weights.
            num_layers: Number of ``LightGCNConv`` layers (must be >= 1).
            emb_size: Width of each node embedding.
            initialize_with_words: Copy ``train_data.node_features`` into the
                embedding table instead of keeping the random initialization.

        Raises:
            ValueError: If ``num_layers`` is smaller than 1.
        """
        super().__init__()
        if num_layers < 1:
            raise ValueError('Number of layers is not >=1')

        # Each layer maps emb_size-wide inputs to emb_size-wide outputs so the
        # per-layer results can be averaged with the initial embeddings.
        self.convs = nn.ModuleList(
            LightGCNConv(emb_size, emb_size) for _ in range(num_layers)
        )

        num_nodes = train_data.node_label_index.size(0)
        self.embeddings = nn.Embedding(num_nodes, emb_size)

        # Initialize using custom embeddings if provided.
        # NOTE(review): assumes train_data.node_features has shape
        # (num_nodes, emb_size) -- confirm against the dataset builder.
        if initialize_with_words:
            with torch.no_grad():
                self.embeddings.weight.copy_(train_data.node_features)

        self.loss_fn = nn.BCELoss()
        self.num_layers = num_layers
        self.emb_size = emb_size
        self.num_nodes = num_nodes
        # Historical misspelling, kept so existing readers do not break.
        self.num_modes = num_nodes

    def forward(self, data):
        """Return one probability per column of ``data.edge_label_index``.

        Args:
            data: Object exposing ``edge_index`` (passed to every convolution),
                ``edge_label_index`` (two rows of endpoint indices to score)
                and ``node_label_index`` (indices looked up in the embedding
                table).

        Returns:
            Flattened tensor of sigmoid-activated dot-product scores.
        """
        edge_index = data.edge_index
        edge_label_index = data.edge_label_index.long()

        x = self.embeddings(data.node_label_index)

        # Accumulate out of place: the first convolution received this very
        # tensor, so mutating it would corrupt what autograd may have saved.
        layer_sum = x
        for conv in self.convs:
            x = conv(x, edge_index)
            layer_sum = layer_sum + x

        # We take an average of every layer's node embeddings, counting the
        # initial embeddings as layer 0.
        mean_layer = layer_sum / (len(self.convs) + 1)

        # Prediction head is simply a dot product of the averaged embeddings.
        nodes_first = torch.index_select(mean_layer, 0, edge_label_index[0])
        nodes_second = torch.index_select(mean_layer, 0, edge_label_index[1])
        out = torch.sum(nodes_first * nodes_second, dim=-1)

        # Squash to (0, 1) because the loss is binary cross-entropy rather
        # than a ranking objective.
        pred = torch.sigmoid(out)
        return torch.flatten(pred)

    def loss(self, pred, label):
        """Return the binary cross-entropy between ``pred`` and ``label``."""
        return self.loss_fn(pred, label)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment