This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_sciki_xy(X, y):
    """Pack features and targets into two mirrored 2-D arrays.

    Helper that gets data into shape to pass to a scikit-learn search
    wrapper (the original note said "GridSearchClassifier"; presumably
    ``GridSearchCV`` was meant -- confirm). Each sample of ``X`` is flattened
    to a single row, and then:

    * ``X_sciki`` is ``[X_flat | y]`` (flattened features, then targets)
    * ``y_sciki`` is ``[y | X_flat]`` (targets, then flattened features)

    Parameters
    ----------
    X : numpy.ndarray
        Array whose first axis indexes samples. Any trailing shape,
        including none (1-D ``X``), is flattened per sample.
    y : numpy.ndarray
        Targets with the same number of samples as ``X``. A 1-D ``y`` is
        stacked as a single column (``np.column_stack`` semantics).

    Returns
    -------
    tuple of numpy.ndarray
        ``(X_sciki, y_sciki)``, each with one row per sample.
    """
    n_samples = X.shape[0]
    # int(): np.prod(()) is the float 1.0 for 1-D X, which reshape rejects.
    # An explicit width (rather than -1) keeps zero-sample input working,
    # because reshape(0, -1) is ambiguous and raises.
    n_features = int(np.prod(X.shape[1:]))
    # Flatten once and reuse for both outputs.
    X_flat = X.reshape((n_samples, n_features))
    X_sciki = np.column_stack([X_flat, y])
    y_sciki = np.column_stack([y, X_flat])
    return X_sciki, y_sciki
def do_cross_val(): | |
(x_train_, y_train_), (x_test_, y_test_) = load_mnist() #load the dataset | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def input_reshaper(X, n_labels=10, image_shape=(28, 28, 1)):
    """Split packed rows back into ``[images, labels]`` model inputs.

    Each row of ``X`` is expected to hold a flattened image followed by
    ``n_labels`` trailing columns. The defaults reproduce the original
    hard-coded layout: a 28x28x1 image plus 10 trailing columns (presumably
    one-hot MNIST labels -- confirm against the caller).

    Parameters
    ----------
    X : numpy.ndarray
        2-D array of shape ``(n_samples, prod(image_shape) + n_labels)``.
    n_labels : int, default 10
        Number of trailing columns returned as the second input.
    image_shape : tuple of int, default (28, 28, 1)
        Per-sample shape that the leading columns are reshaped to.

    Returns
    -------
    list of numpy.ndarray
        ``[images, labels]`` with shapes ``(n_samples, *image_shape)`` and
        ``(n_samples, n_labels)``.
    """
    # Compute the split point explicitly: slicing with -n_labels would break
    # for n_labels == 0, since X[:, :-0] is empty.
    split = X.shape[1] - n_labels
    images = X[:, :split].reshape((X.shape[0], *image_shape))
    labels = X[:, split:]
    # A list (not a tuple) is returned, matching the original's return type.
    return [images, labels]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import List | |
from sklearn.base import BaseEstimator, TransformerMixin | |
from sklearn.preprocessing import LabelEncoder, FunctionTransformer, OneHotEncoder | |
class MultiOutputTransformer(BaseEstimator, TransformerMixin): | |
def fit(self, y): | |
# Separate the two different 'y's into two arrays |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sklearn.base import BaseEstimator, TransformerMixin | |
from scikeras.wrappers import KerasClassifier | |
### Multi-Output Classifier | |
class MultiOutputTransformer(BaseEstimator, TransformerMixin): | |
#define your transformer | |
class MultiOutputClassifier(KerasClassifier): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from sklearn.datasets import make_classification | |
from tensorflow import keras | |
from tensorflow.keras.wrappers.scikit_learn import KerasClassifier | |
# Synthetic classification dataset for the example: 1000 samples with
# 20 features (10 of them informative), seeded so every run is identical.
X, y = make_classification(
    n_samples=1000,
    n_features=20,
    n_informative=10,
    random_state=0,
)
# Give the arrays explicit dtypes: float32 features and int64 labels.
X, y = X.astype(np.float32), y.astype(np.int64)