This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
import mlflow | |
logger = logging.getLogger(__name__) | |
class BaseMLLogger: | |
""" | |
Base class for tracking experiments. | |
This class can be extended to implement custom logging backends like MLFlow, Tensorboard, or Sacred. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def load_model(run_id, model_name):
    """Load the Spark pipeline model that was logged under an MLflow run.

    Args:
        run_id: ID of the MLflow run whose artifact location is looked up.
        model_name: Artifact sub-path of the model inside that run.

    Returns:
        The result of ``PipelineModel.load`` for the resolved location.
    """
    artifact_root = mlflow.get_run(run_id).info.artifact_uri
    # The previous version appended the literal text "/model_name/sparkml",
    # so the ``model_name`` argument was ignored and every call resolved to
    # the same path. Interpolate the argument instead.
    # NOTE(review): "sparkml" is presumably the sub-directory the MLflow
    # Spark flavor writes the native pipeline to -- confirm against the
    # MLflow version in use.
    model_location = artifact_root + "/" + model_name + "/sparkml"
    return PipelineModel.load(model_location)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def save_model(model, model_name):
    """Log a model to MLflow through ``mlflow.spark.log_model``.

    Args:
        model: Model object, forwarded as the first positional argument.
        model_name: Forwarded as the second positional argument
            (presumably the artifact path to log under -- confirm against
            the installed MLflow version).
    """
    spark_flavor = mlflow.spark
    spark_flavor.log_model(model, model_name)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_parameter(run_id, param_name):
    """Return a logged run parameter parsed as a Python literal.

    Args:
        run_id: ID of the MLflow run to read from.
        param_name: Key of the parameter in the run's ``data.params`` mapping.

    Returns:
        The stored parameter value after ``ast.literal_eval`` has parsed it.

    Raises:
        KeyError: If ``param_name`` is not among the run's parameters.
    """
    run = mlflow.get_run(run_id)
    raw_value = run.data.params[param_name]
    # literal_eval only accepts Python literal syntax; it does not execute
    # arbitrary expressions.
    return ast.literal_eval(raw_value)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def update_learning(recall,recall_live): | |
s3_client = boto3.Session(profile_name=None).client('s3') | |
s3_resource = boto3.resource('s3') | |
artifact_bucket = 'YOUR ARTIFACT BUCKET ON S3' | |
if recall>recall_live: | |
# Push live champion to history | |
try: | |
object = s3_client.get_object(Bucket=artifact_bucket, Key='mlflow/'+proj_id+'/live_model_run_history') |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_run(proj_id): | |
''' | |
get active and live runs for this model | |
''' | |
s3_client = boto3.Session(profile_name=None).client('s3') | |
artifact_bucket = 'YOUR ARTIFACT BUCKET ON S3' | |
s3_object = s3_client.get_object(Bucket=artifact_bucket, Key='mlflow/'+proj_id+'/active_model_run') | |
active_run = s3_object['Body'].read().decode("utf-8") | |
s3_object = s3_client.get_object(Bucket=artifact_bucket, Key='mlflow/'+proj_id+'/live_model_run') |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def set_run(proj_id): | |
''' | |
+ This creates a file called "active_run" in S3 and writes current run_id into it. | |
+ If a file named "live_run" does not exist, it creates one and throws active_run into it | |
''' | |
s3_client = boto3.Session(profile_name=None).client('s3') | |
s3_resource = boto3.resource('s3') | |
artifact_bucket = 'YOUR ARTIFACT BUCKET ON S3' | |
active_run_id = mlflow.active_run().info.run_id |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def create_experiment(proj_id, project_description):
    """Create an MLflow experiment and attach a description note to it.

    If ``mlflow.create_experiment`` fails, the function returns without
    tagging.

    NOTE(review): the original source had lost its indentation; this assumes
    the tagging call sat inside the one-shot loop, so that ``continue``
    skipped it on failure. Confirm against the original file.

    Args:
        proj_id: Name of the experiment to create.
        project_description: Text stored under the ``mlflow.note.content``
            tag of the experiment.
    """
    try:
        mlflow.create_experiment(proj_id)
    except Exception:
        # Deliberately best-effort, as before. The bare ``except:`` used
        # previously also swallowed KeyboardInterrupt/SystemExit; those now
        # propagate. The single-iteration ``for`` loop that only existed to
        # make ``continue`` legal is replaced by this early return.
        return

    experiment = mlflow.get_experiment_by_name(proj_id)
    # ``client`` is not defined in this block; presumably a module-level
    # MLflow client object -- verify where it is created.
    client.set_experiment_tag(
        experiment.experiment_id,
        "mlflow.note.content",
        project_description,
    )