Last active
April 18, 2019 17:11
-
-
Save athena15/5a5ff2a3d247cdefeafd97cbf5aba174 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Add the following to tables_custom_flavor.py:
class FlavorWrapper:
    """Wrap a plain callable so that it exposes a ``predict(data)`` method.

    The intended use is ``predict(data: pd.DataFrame)`` returning the model's
    output as a pandas DataFrame. The wrapper itself performs no conversion:
    ``predict`` returns exactly what ``function(data)`` returns.

    Extra keyword arguments are stored as instance attributes. This lets
    state variables be passed in alongside the wrapped function.
    """

    def __init__(self, function, **kwargs):
        """Store ``function`` and attach each keyword argument as an attribute.

        Args:
            function: Callable invoked by ``predict``.
            **kwargs: State values; each becomes ``self.<key> = value``.

        Raises:
            ValueError: If a keyword argument is named ``predict``. Storing
                it as an instance attribute would shadow the ``predict``
                method and break the wrapper's interface.
        """
        if 'predict' in kwargs:
            raise ValueError(
                "'predict' is reserved; it would shadow FlavorWrapper.predict")
        self.function = function
        # State variables supplied by the caller (e.g. stretch_factor=0.5).
        for key, value in kwargs.items():
            setattr(self, key, value)

    def predict(self, data):
        """Return ``self.function(data)`` unchanged."""
        return self.function(data)
# Example: how to save (log) and load a model in a Databricks notebook.
import tables_custom_flavor as custom_flavor | |
def my_func(y):
    """Return ``y`` squared (``y ** 2``); used as the example model function."""
    return y ** 2
# Log the model by wrapping the function in FlavorWrapper. Extra keyword
# arguments (here stretch_factor) are stored as attributes on the wrapper.
func_model = custom_flavor.FlavorWrapper(my_func, stretch_factor=0.5)
custom_flavor.log_model(func_model, 'model_path')

# Load the model back.
# NOTE(review): run_id is hard-coded. It must be the ID of the run that
# logged 'model_path' above; TODO confirm how custom_flavor/the tracking
# backend exposes the active run's ID instead of pasting it by hand.
loaded_model = custom_flavor.load_model(path='model_path', run_id='1af665df4f4d4a418dd9d913a5b18a74')
loaded_model.function(3)  # returns 9
loaded_model.predict(25)  # returns 625
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment