Skip to content

Instantly share code, notes, and snippets.

@athena15
Last active April 18, 2019 17:11
Show Gist options
  • Save athena15/5a5ff2a3d247cdefeafd97cbf5aba174 to your computer and use it in GitHub Desktop.
Save athena15/5a5ff2a3d247cdefeafd97cbf5aba174 to your computer and use it in GitHub Desktop.
# change in tables_custom_flavor.py
class FlavorWrapper(object):
    """
    Wrapper class that creates a predict function such that
    predict(data: pd.DataFrame) -> model's output as pd.DataFrame (pandas DataFrame)

    The wrapper does not convert or validate ``data`` itself: ``predict``
    hands its argument to the wrapped callable unchanged and returns whatever
    that callable returns, so the DataFrame-in/DataFrame-out contract is the
    responsibility of the wrapped function.
    """

    def __init__(self, function, **kwargs):
        """
        Store the wrapped callable plus any extra state variables.

        :param function: callable taking a single argument; it is invoked by
            ``predict`` and is also reachable as ``self.function``.
        :param kwargs: state variables; each one becomes an instance
            attribute named after its keyword.
        :raises TypeError: if ``function`` is not callable.
        :raises ValueError: if a keyword would shadow a method of the class
            (for example ``predict``).
        """
        # Fail at construction time rather than on the first predict() call.
        if not callable(function):
            raise TypeError(
                "function must be callable, got {}".format(type(function).__name__)
            )
        self.function = function

        # the intent is to allow state variables to be passed into the class here
        for key, value in kwargs.items():
            # An instance attribute takes precedence over a method of the same
            # name, so e.g. predict=... would silently disable predict().
            if callable(getattr(type(self), key, None)):
                raise ValueError(
                    "state variable {!r} would shadow a method of {}".format(
                        key, type(self).__name__
                    )
                )
            setattr(self, key, value)

    def predict(self, data):
        """Return ``self.function(data)``; ``data`` is passed through as-is."""
        return self.function(data)
# example of how to save and load a model in a Databricks Notebook
import tables_custom_flavor as custom_flavor
def my_func(y):
    """Return ``y`` squared (``y ** 2``); toy model function for the example."""
    return y ** 2
# Log the model: wrap my_func in FlavorWrapper, then hand the wrapper to log_model.
# stretch_factor is passed as an extra keyword (state variable); my_func itself
# never reads it.
func_model = custom_flavor.FlavorWrapper(my_func, stretch_factor=0.5)
# NOTE(review): presumably 'model_path' is the artifact path the model is saved
# under for the active run -- confirm against tables_custom_flavor.log_model.
custom_flavor.log_model(func_model, 'model_path')
# Load the model back using the same path plus the id of the run that logged it.
# NOTE(review): the run_id below is a hard-coded example value; replace it with
# the id of your own run.
loaded_model = custom_flavor.load_model(path='model_path', run_id='1af665df4f4d4a418dd9d913a5b18a74')
loaded_model.function(3) # returns 9 (calls the wrapped my_func directly)
loaded_model.predict(25) # returns 625 (predict delegates to the wrapped function)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment