@LukasHaas
Last active December 10, 2021 07:44
Modified LightGCN Convolutional Layer
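import torch_scatter
from torch_geometric.nn import MessagePassing


class LightGCNConv(MessagePassing):
    """LightGCN-style layer: neighbor aggregation with no weights or nonlinearity."""

    def __init__(self, aggregation='mean', **kwargs):
        super().__init__(**kwargs)
        self.aggregation = aggregation  # reduce op passed to torch_scatter.scatter

    def forward(self, x, edge_index, size=None):
        # Use the same embeddings for source and target nodes.
        return self.propagate(edge_index, x=(x, x), size=size)

    def message(self, x_j):
        # Messages are the unmodified neighbor embeddings.
        return x_j

    def aggregate(self, inputs, index, dim_size=None):
        # Pass dim_size so nodes without incoming edges still get an output row.
        return torch_scatter.scatter(inputs, index, dim=self.node_dim,
                                     dim_size=dim_size, reduce=self.aggregation)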
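
To make the layer's behavior concrete, here is a minimal sketch of what `message` followed by a `'mean'` `aggregate` computes on a toy graph. It is written in plain Python so it runs without PyTorch Geometric. The helper `scatter_mean` and the toy data are hypothetical and exist only for this illustration; they are not part of the gist or of `torch_scatter`. The sketch assumes edges are stored as (source, target) rows and that messages flow from source to target.

```python
# Plain-Python sketch of mean aggregation over incoming edges (no PyG needed).
# `scatter_mean` is a hypothetical helper written for this illustration only.

def scatter_mean(messages, index, dim_size):
    """Average the message vectors that share the same target index."""
    width = len(messages[0]) if messages else 0
    sums = [[0.0] * width for _ in range(dim_size)]
    counts = [0] * dim_size
    for msg, target in zip(messages, index):
        counts[target] += 1
        for k, value in enumerate(msg):
            sums[target][k] += value
    # In this sketch, nodes with no incoming edges keep an all-zero row.
    return [[s / c for s in row] if c else row for row, c in zip(sums, counts)]


# Toy graph: 3 nodes with 2-dimensional embeddings (assumed data).
x = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
edge_index = [[0, 1, 2],   # row 0: source nodes
              [1, 0, 0]]   # row 1: target nodes

messages = [x[src] for src in edge_index[0]]                  # message(): x_j
out = scatter_mean(messages, edge_index[1], dim_size=len(x))  # aggregate()
print(out)  # [[1.0, 1.5], [1.0, 0.0], [0.0, 0.0]]
```

Node 0 receives the average of nodes 1 and 2, node 1 receives node 0 unchanged, and node 2 has no incoming edges. Supplying `dim_size` is what keeps a row for node 2 in the output.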