Skip to content

Instantly share code, notes, and snippets.

@vlasenkoalexey
Last active March 13, 2020 18:13

Revisions

  1. vlasenkoalexey revised this gist Mar 13, 2020. 1 changed file with 3 additions and 1 deletion.
    4 changes: 3 additions & 1 deletion train_estimator_linear.py
    Original file line number Diff line number Diff line change
    @@ -23,6 +23,8 @@ def train_estimator_linear(model_dir):
    logging.info('training and evaluating linear estimator model')
    tf.estimator.train_and_evaluate(
    estimator,
    train_spec=tf.estimator.TrainSpec(input_fn=lambda: get_dataset('train'), max_steps=get_max_steps(), hooks=hooks),
    train_spec=tf.estimator.TrainSpec(input_fn=lambda: get_dataset('train'),
    max_steps=get_max_steps(),
    hooks=hooks),
    eval_spec=tf.estimator.EvalSpec(input_fn=lambda: get_dataset('test')))
    logging.info('done evaluating estimator model')
  2. vlasenkoalexey created this gist Mar 13, 2020.
    28 changes: 28 additions & 0 deletions train_estimator_linear.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,28 @@
    def train_estimator_linear(model_dir):
    global ARGS

    logging.info('training for {} steps'.format(get_max_steps()))
    config = tf.estimator.RunConfig().replace(save_summary_steps=10)

    hooks = []
    if ARGS.profiler:
    profiler_hook = tf.estimator.ProfilerHook(
    save_steps=get_training_steps_per_epoch(),
    output_dir=os.path.join(model_dir, "profiler"),
    show_dataflow=True,
    show_memory=True)
    hooks.append(profiler_hook)

    feature_columns = create_feature_columns()
    estimator = tf.estimator.LinearClassifier(
    feature_columns=feature_columns,
    optimizer=GradientDescentOptimizer(learning_rate=0.001),
    model_dir=model_dir,
    config=config
    )
    logging.info('training and evaluating linear estimator model')
    tf.estimator.train_and_evaluate(
    estimator,
    train_spec=tf.estimator.TrainSpec(input_fn=lambda: get_dataset('train'), max_steps=get_max_steps(), hooks=hooks),
    eval_spec=tf.estimator.EvalSpec(input_fn=lambda: get_dataset('test')))
    logging.info('done evaluating estimator model')