This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class LightGCN(torch.nn.Module): | |
def __init__(self, train_data, num_layers, emb_size=16, initialize_with_words=False): | |
super(LightGCN, self).__init__() | |
self.convs = nn.ModuleList() | |
assert (num_layers >= 1), 'Number of layers is not >=1' | |
for l in range(num_layers): | |
self.convs.append(LightGCNConv(input_dim, input_dim)) | |
# Initialize using custom embeddings if provided | |
num_nodes = train_data.node_label_index.size()[0] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class LightGCNConv(MessagePassing): | |
def __init__(self, aggregation='mean', **kwargs): | |
super(LightGCNConv, self).__init__(**kwargs) | |
self.aggregation = aggregation | |
def forward(self, x, edge_index, size = None): | |
out = self.propagate(edge_index, x=(x, x)) | |
return out | |
def message(self, x_j): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
import numpy as np | |
from sklearn.base import BaseEstimator, clone | |
from sklearn.metrics import r2_score | |
from decimal import Decimal | |
from typing import Tuple, Union, Dict | |
class QuantileRegressor(BaseEstimator): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import pandas as pd | |
import sklearn | |
from typing import Dict, Tuple | |
from sklearn.base import BaseEstimator | |
class RuleAugmentedEstimator(BaseEstimator): | |
"""Augments sklearn estimators with rule-based logic. |