@JanSchm
Created June 27, 2022 16:58
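import math
import tensorflow as tf

# model, train_gen, val_gen, df_train, df_val and BATCH_SIZE are defined elsewhere (not part of this snippet).
# Write a checkpoint only when val_loss improves on the best value so far; epoch and loss go into the filename.
callbacks = [tf.keras.callbacks.ModelCheckpoint(
    'SiameseTriplet_AlbertBase_epoch{epoch:02d}_val-loss{val_loss:.6f}.hdf5',
    monitor='val_loss', save_best_only=True, verbose=1)]
# Train the network
history = model.fit(train_gen,
                    # ceil() avoids requesting one batch too many when the size is an exact multiple of BATCH_SIZE
                    steps_per_epoch=math.ceil(len(df_train) / BATCH_SIZE),
                    batch_size=None,  # batches come from the generator, so no batch size is set here
                    verbose=1,
                    epochs=25,
                    shuffle=True,  # has no effect if train_gen is a plain Python generator or a tf.data.Dataset
                    validation_data=val_gen,
                    validation_steps=math.ceil(len(df_val) / BATCH_SIZE),
                    callbacks=callbacks,
                    max_queue_size=3)  # at most 3 batches are prepared ahead of training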
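
Two details in the call above are easy to get wrong: how the checkpoint filename gets filled in, and how many steps make up one epoch. The standalone sketch below (plain Python, no TensorFlow needed) illustrates both, plus a simplified model of save_best_only. It assumes that ModelCheckpoint expands the filename with str.format, passing a 1-based epoch number and the epoch's logged metrics, and that monitor='val_loss' resolves to "min" mode. The helper names and the loss values are hypothetical and chosen only for illustration.

```python
import math

# Filename template copied from the ModelCheckpoint call above.
TEMPLATE = 'SiameseTriplet_AlbertBase_epoch{epoch:02d}_val-loss{val_loss:.6f}.hdf5'


def checkpoint_name(epoch_index: int, val_loss: float) -> str:
    """Expand the template the way ModelCheckpoint is assumed to:
    str.format with a 1-based epoch number and the logged val_loss."""
    return TEMPLATE.format(epoch=epoch_index + 1, val_loss=val_loss)


def steps_floor_plus_one(n_samples: int, batch_size: int) -> int:
    """The original formula: len(df) // BATCH_SIZE + 1."""
    return n_samples // batch_size + 1


def steps_ceil(n_samples: int, batch_size: int) -> int:
    """Number of batches needed to cover every sample once."""
    return math.ceil(n_samples / batch_size)


def saved_epochs(val_losses):
    """Simplified model of save_best_only in 'min' mode: an epoch writes a
    file only if its val_loss is strictly lower than the best seen so far.
    Returns the 0-based indices of those epochs."""
    best = math.inf
    saved = []
    for i, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            saved.append(i)
    return saved


if __name__ == '__main__':
    # Third epoch (index 2): val_loss is rounded to 6 decimals in the name
    print(checkpoint_name(2, 0.123456789))  # -> SiameseTriplet_AlbertBase_epoch03_val-loss0.123457.hdf5
    # 1000 samples, batch size 32: both formulas give 32 steps
    print(steps_floor_plus_one(1000, 32), steps_ceil(1000, 32))  # -> 32 32
    # 1024 samples, batch size 32: floor + 1 asks for 33 steps, but only 32 batches exist
    print(steps_floor_plus_one(1024, 32), steps_ceil(1024, 32))  # -> 33 32
    # Hypothetical validation losses, for illustration only
    print(saved_epochs([0.9, 0.7, 0.75, 0.6, 0.6]))  # -> [0, 1, 3]
```

Because the filename changes every epoch, save_best_only=True does not overwrite earlier files: each improvement leaves a new .hdf5 on disk, so in the example above epochs 1, 2 and 4 would each produce one checkpoint.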