Last active
January 22, 2021 17:55
-
-
Save data-hound/b94ee5b9157ae547be72a8b722090524 to your computer and use it in GitHub Desktop.
SciKeras Tutorial 5: Wrapping the MIMO Estimator and Running Cross-Validation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_sciki_xy(X, y):
    """Pack samples and labels into the 2-D arrays passed to GridSearchCV.

    Each sample in ``X`` is flattened to one row, and ``y`` is attached to it
    on both sides:

    * ``X_sciki`` row = flattened sample followed by its ``y`` row.
    * ``y_sciki`` row = the ``y`` row followed by the flattened sample.

    Parameters
    ----------
    X : numpy.ndarray
        Samples along axis 0; every remaining axis is flattened.
    y : numpy.ndarray
        Labels with the same number of rows as ``X`` (1-D or 2-D).

    Returns
    -------
    tuple of numpy.ndarray
        ``(X_sciki, y_sciki)``, each with ``len(y)`` rows.

    Raises
    ------
    ValueError
        If ``X`` and ``y`` do not have the same number of rows.
    """
    n_samples = y.shape[0]
    if X.shape[0] != n_samples:
        raise ValueError(
            "X and y must have the same number of rows, "
            f"got {X.shape[0]} and {n_samples}"
        )
    # int(): np.prod(()) is the float 1.0 for 1-D X, which reshape rejects.
    # An explicit feature count (instead of -1) also keeps zero-sample input working.
    n_features = int(np.prod(X.shape[1:]))
    X_flat = X.reshape((n_samples, n_features))  # flatten once, reuse twice
    X_sciki = np.column_stack([X_flat, y])
    y_sciki = np.column_stack([y, X_flat])
    return X_sciki, y_sciki
def do_cross_val():
    """Grid-search MIMOEstimator hyper-parameters with 5-fold cross-validation.

    Loads the data with ``load_mnist``, keeps the first 1000 training
    samples, packs them with ``get_sciki_xy`` and runs ``GridSearchCV`` over
    ``model__n_filters_c1`` and ``model__routings``.

    Returns
    -------
    tuple
        ``(best_estimator_, best_params_)`` of the fitted ``GridSearchCV``.
    """
    # Only the training split is used by the grid search.
    (x_train_, y_train_), _ = load_mnist()  # load the dataset

    # Optional - trim the data size for faster epochs
    x_train, y_train = get_sciki_xy(x_train_[:1000], y_train_[:1000])

    # Create a MIMOEstimator with the get_model function.
    # Parameters that need to be passed to get_model are prefixed with model__
    clf = MIMOEstimator(
        model=get_model,
        model__input_shape=x_train_.shape[1:],
        # argmax over axis 1: assumes y_train_ is one-hot encoded — TODO confirm
        model__n_class=len(np.unique(np.argmax(y_train_, 1))),
        model__routings=args.routings,
        model__batch_size=args.batch_size,
        model__n_filters_c1=256,
        # epochs=args.epochs,
        # callbacks=[log, checkpoint, lr_decay],
        model__model_type='train',
    )

    # Print the shapes of X and Y
    print("X input shape = ", x_train.shape)
    print("Y input shape = ", y_train.shape)

    # Define the parameter grid to perform Grid-Search
    params = {'model__n_filters_c1': [128, 256],
              'model__routings': [4, 5]}

    # The number of examples per CV split should be divisible by batch_size
    # (1000 samples with cv=5 gives 800-train / 200-validation splits).
    gs = GridSearchCV(estimator=clf, param_grid=params, cv=5, verbose=True)
    gs_res = gs.fit(X=x_train, y=y_train)
    print("Grid Search Results: ")
    print(gs_res)

    best_est = gs_res.best_estimator_
    best_score = gs_res.best_score_
    best_params = gs_res.best_params_
    print('Best score obtained after GridSearchCV: ', best_score)
    return best_est, best_params
# Run the grid search; keep the best fitted estimator and its parameters.
est, params = do_cross_val()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment