Skip to content

Instantly share code, notes, and snippets.

@LukasHaas
LukasHaas / RuleAugmentedEstimator.py
Last active June 24, 2022 17:54
Augment Sklearn Models with Rule-Based Logic
import numpy as np
import pandas as pd
import sklearn
from typing import Dict, Tuple
from sklearn.base import BaseEstimator
class RuleAugmentedEstimator(BaseEstimator):
"""Augments sklearn estimators with rule-based logic.
torch.manual_seed(42)
x_tensor = torch.from_numpy(x).float()
y_tensor = torch.from_numpy(y).float()
# Builds dataset with ALL data
dataset = TensorDataset(x_tensor, y_tensor)
# Splits randomly into train and validation datasets
train_dataset, val_dataset = random_split(dataset, [80, 20])