TensorFlow callbacks are an important way to customize the behaviour of Keras models during training or evaluation. We can either use the predefined callbacks from TensorFlow or write our own. Callbacks help to save model data, log stats during the process, evaluate data at certain steps, or make decisions based on model performance. The callbacks provided by TensorFlow-Keras that we cover here are as follows.
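- Base Callback class
- Model Checkpoint
- TensorBoard
- Early Stopping
- Remote Monitor
- Learning Rate Scheduler
- CSV Logger
- Backup and Restore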
Let's check how these callbacks work and how we can use them in our TensorFlow sequential model.
Base Callback class
The base Callback class can be subclassed to override its existing methods and run our own code in them. Here is a basic example of a callback using epoch end and training end. With custom callbacks we can perform different operations, such as getting model results for a validation or testing dataset and visualizing them, storing output (images, logs, text etc.) in a file or directory, or writing data to TensorBoard.
import tensorflow as tf

# write a custom callback for epoch end and training end
class MyCustomCallback(tf.keras.callbacks.Callback):
    # we define all methods we want to use
    def on_epoch_end(self, epoch, logs=None):
        # logs holds the metrics we are using, e.g. loss, accuracy, val_loss, val_accuracy
        # perform some operation on epoch end
        print("Epoch Finished: Logs are: ", logs)

    def on_train_end(self, logs=None):
        # perform some operation on training end
        print("Training Finished")
We can perform any operation we want inside custom callbacks. There are also other callbacks predefined for certain tasks like writing logs, saving checkpoints and model files etc.
Model Checkpoint
The model checkpoint callback is used with the model.fit() method and writes checkpoints according to the parameters defined. We can define the checkpoint save behaviour using arguments such as whether to save weights only or the whole model, the metric to monitor (accuracy, loss, validation accuracy), the save path and the save frequency. We can also choose to save only when there is an increase in model accuracy or a decrease in loss, based on the monitored metric. A simple model checkpoint can be defined as follows.
# Model Checkpoint
# note: newer Keras releases may require a ".weights.h5" suffix for weights-only checkpoints
checkpoint_filepath = "checkpoints/ckpt"
model_ckpt = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_filepath,     # path to save checkpoint
    monitor="val_loss",      # metric to monitor (can also use accuracy, loss, val_accuracy)
    verbose=0,               # verbosity mode, 0 or 1
    save_best_only=True,     # True: save only when the monitored metric improves; False: save every time
    save_weights_only=True,  # True: save only weights; False: save the full model
    mode="auto",             # one of auto, min or max
    save_freq="epoch",       # "epoch" or an integer number of batches
)
Now, when we run training, it will save a checkpoint at the end of an epoch only if the current validation loss is lower than the best validation loss seen so far, so that we always keep the model weights with the lowest validation loss. We can set save_freq to "epoch" or to an integer number of batches, depending on our choice.
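The "save best only" rule described above can be sketched in plain Python. This is a simplified illustration, not the Keras implementation: the function name `epochs_that_save` and the sample loss values are made up for the example, and the mode has to be given explicitly as "min" or "max".

```python
def epochs_that_save(values, mode="min"):
    """Return the epoch indices at which a 'save best only' rule would save."""
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    best = None
    saved = []
    for epoch, value in enumerate(values):
        improved = (
            best is None
            or (mode == "min" and value < best)
            or (mode == "max" and value > best)
        )
        if improved:
            best = value
            saved.append(epoch)
    return saved


# Hypothetical validation losses for five epochs
val_loss = [0.90, 0.70, 0.75, 0.60, 0.65]
print(epochs_that_save(val_loss, mode="min"))  # [0, 1, 3]
```

Epochs 2 and 4 do not save because their loss is higher than the best value seen so far, so the checkpoint on disk always holds the weights from the best epoch.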
TensorBoard
TensorBoard is a tool provided by TensorFlow for visualization of models, logs and outputs. We can visualize the model architecture (layers, graph etc.) and also plot graphs of model performance such as accuracy, loss, validation data details and other custom logs that we write. We can also visualize embeddings in a lower dimensional space, histograms of weights and biases, and data like text, images and audio. View more details on TensorBoard on the TensorFlow page.
We can write output like logs, images and other data during training using callbacks, which can later be visualized using TensorBoard. Here is a simple callback for using TensorBoard in TensorFlow models.
log_directory = "logs"
tensorboard_cb = tf.keras.callbacks.TensorBoard(
    log_dir=log_directory,           # path to logs directory
    histogram_freq=0,                # compute histograms every N epochs (0 means no histograms)
    write_graph=True,                # whether to write the graph or not
    write_images=False,              # visualize model weights as images
    write_steps_per_second=False,    # log steps per second (both batch and epoch)
    update_freq="epoch",             # "batch", "epoch" or an integer number of batches
    profile_batch=2,                 # profile the second batch (set 0 to disable)
    embeddings_freq=0,               # embedding layers visualization frequency
    embeddings_metadata=None,        # embedding metadata as dictionary
)
Once we execute the script, it will start writing data to the specified logs directory. We can then open TensorBoard and check the outputs using the following command.
# log dir in example is 'logs'
tensorboard --logdir logs
Now you can navigate to http://localhost:6006 and view the visualizations of the model logs.
Early Stopping
When training our model we want to reduce the loss and increase the accuracy, but in some cases, after a certain amount of time, the loss either stays constant or starts increasing again. So we need a check that stops training if the loss does not decrease or the accuracy does not increase, on training or validation data. EarlyStopping is a Keras callback that does this: it monitors a given metric and stops training based on the given patience and other arguments, so that we keep the best model. Here we use it to monitor validation loss and stop training if the validation loss does not decrease for a given number of epochs.
early_stopping_cb = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",          # monitor validation loss
    min_delta=0,                 # minimum change in the monitored quantity to count as improvement
    patience=4,                  # how many epochs to wait without improvement before stopping
    verbose=0,                   # verbosity mode
    mode="auto",                 # one of auto, min or max
    baseline=None,               # if the model does not show improvement over baseline, stop training
    restore_best_weights=False,  # if True, restore the weights from the best epoch
)
It will monitor validation loss and if it does not decrease for 4 epochs, it will stop training.
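To make the patience rule concrete, here is a minimal plain-Python sketch of the stopping decision. It is a simplified re-implementation for illustration only, not the Keras source; the function name `early_stop_epoch` and the loss values are hypothetical, and it assumes a "min" metric such as validation loss.

```python
def early_stop_epoch(val_losses, patience=4, min_delta=0.0):
    """Return (stop_epoch, best_epoch); stop_epoch is None if training runs to the end."""
    best = float("inf")
    best_epoch = None
    wait = 0  # epochs since the last improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return None, best_epoch


# Hypothetical validation losses: the best value is reached at epoch 2
losses = [1.0, 0.8, 0.7, 0.71, 0.72, 0.75, 0.73, 0.74]
stop_epoch, best_epoch = early_stop_epoch(losses, patience=4)
print(stop_epoch, best_epoch)  # 6 2
```

Training stops at epoch 6 (counting from 0), four epochs after the best epoch. With restore_best_weights=False the model keeps the weights from the last epoch rather than from epoch 2, which is why setting it to True is often useful.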
Remote Monitor
If we are training a model remotely and want to monitor its performance from elsewhere, we can use the RemoteMonitor callback to send data to a specified endpoint, where we can write it to a file or save it to a database. So, we need an API endpoint which we pass to this callback, and it sends data to the endpoint after each epoch.
remote_monitor_cb = tf.keras.callbacks.RemoteMonitor(
    root="http://localhost:8080",  # base url
    path="/report/epoch-end/",     # url path
    field="data",                  # field in form (used when send_as_json is False)
    send_as_json=True,             # send in json format
)
This callback is very useful if we have a model training remotely: we can create a simple visualization tool which receives data through this API and visualizes the results as we want to show them.
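A receiving endpoint for these reports could look roughly like the sketch below, which uses only the Python standard library. The names `EpochReportHandler`, `start_receiver` and `received` are hypothetical, and it assumes the callback posts a JSON object (the epoch index plus the logged metrics) to the path used in the example above.

```python
import json
import threading
from http.server import BaseHTTPRequestHandler, HTTPServer

received = []  # payloads collected from the training process


class EpochReportHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        # Read the request body and store it if it was sent to the expected path
        length = int(self.headers.get("Content-Length", 0))
        body = self.rfile.read(length)
        if self.path == "/report/epoch-end/":
            received.append(json.loads(body))
            self.send_response(200)
        else:
            self.send_response(404)
        self.end_headers()

    def log_message(self, format, *args):
        pass  # keep the console quiet


def start_receiver(host="localhost", port=8080):
    """Start the endpoint in a background thread and return the server object."""
    server = HTTPServer((host, port), EpochReportHandler)
    thread = threading.Thread(target=server.serve_forever, daemon=True)
    thread.start()
    return server
```

Calling start_receiver() before training begins lets each report accumulate in `received`, from where it can be written to a file, stored in a database or plotted. Calling shutdown() on the returned server stops it.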
Learning Rate Scheduler
When training a model, we can change the learning rate as training progresses, for which we can use the learning rate scheduler callback from TensorFlow. We define a custom function which takes the epoch index and the current learning rate and returns the learning rate to use, for example lowering it after a specific number of epochs.
def scheduler(epoch, lr):
    if epoch < 10:  # if epoch is less than 10, keep the current learning rate
        return lr
    return float(lr * tf.math.exp(-0.1))  # else decrease it

# define learning rate scheduler callback
lr_callback = tf.keras.callbacks.LearningRateScheduler(
    scheduler,  # schedule function
    verbose=0,  # verbosity mode
)
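To see what learning rates such a schedule produces, it can be simulated without training a model. The sketch below is an illustration under assumptions: it uses math.exp instead of tf.math.exp so it runs without TensorFlow, the helper `simulate` is hypothetical, and the initial rate of 0.01 is an arbitrary choice. It feeds the returned rate back in as the current rate for the next epoch, which is how the decay compounds.

```python
import math


def scheduler(epoch, lr):
    if epoch < 10:  # keep the current learning rate for the first 10 epochs
        return lr
    return lr * math.exp(-0.1)  # then decay it by a constant factor each epoch


def simulate(initial_lr, epochs):
    """Apply the schedule epoch by epoch and collect the learning rates."""
    lr = initial_lr
    rates = []
    for epoch in range(epochs):
        lr = scheduler(epoch, lr)
        rates.append(lr)
    return rates


rates = simulate(0.01, 15)
for epoch in (0, 9, 10, 14):
    print(epoch, round(rates[epoch], 6))
```

The rate stays at 0.01 through epoch 9 and is then multiplied by exp(-0.1) at every later epoch, so after five decayed epochs it equals 0.01 * exp(-0.5).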
CSV Logger
We can also write the output of the model to a CSV file, which is easy for everyone to read, so we can show results without any post-processing of logs. The callback writes each epoch's results as a row, separated by the delimiter provided, and a flag controls whether it appends to an existing file or overwrites it.
csv_filepath = "logs/output.csv"  # path to csv
csv_logger_cb = tf.keras.callbacks.CSVLogger(
    csv_filepath,   # path to csv
    separator=",",  # comma separated
    append=False,   # False: overwrite the file if it exists; True: append to the existing file
)
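Reading the logged file back is straightforward with the standard csv module. The sketch below assumes the file has a header row with an epoch column followed by metric columns. The helper `load_history` is hypothetical, and the sample rows are made up so that the snippet runs on its own without a training run.

```python
import csv
import os
import tempfile


def load_history(csv_path, separator=","):
    """Read a metrics CSV into a dict mapping column name to a list of floats."""
    columns = {}
    with open(csv_path, newline="") as handle:
        for row in csv.DictReader(handle, delimiter=separator):
            for name, value in row.items():
                columns.setdefault(name, []).append(float(value))
    return columns


# Made-up sample file standing in for the logger's output
sample_path = os.path.join(tempfile.mkdtemp(), "output.csv")
with open(sample_path, "w", newline="") as handle:
    handle.write("epoch,accuracy,loss\n0,0.61,0.92\n1,0.74,0.63\n")

history = load_history(sample_path)
print(history["loss"])  # [0.92, 0.63]
```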
Backup and Restore
If we are training a model and it stops at some specific epoch because of power failure or any other issue, we can resume training from that epoch using BackupAndRestore callback. It creates a backup of model in a directory and if we start training again, it will check for backup and if that exists it will resume training from there.
backup_dir = "tmp/backup_tf"
backup_cb = tf.keras.callbacks.BackupAndRestore(backup_dir=backup_dir)
Here are some limitations and details about BackupAndRestore from the TensorFlow documentation.
- This callback is not compatible with disabling eager execution.
- A checkpoint is saved at the end of each epoch. When restoring, we'll redo any partial work from an unfinished epoch in which the training got restarted (so the work done before an interruption doesn't affect the final model state).
- This works for both single worker and multi-worker mode, but only MirroredStrategy and MultiWorkerMirroredStrategy are supported for now.
Now we can use the callbacks defined above in our TensorFlow model. For example, to use them in training, we pass a list of these callbacks to model.fit() like this.
model.fit(
    train_x,  # training data
    train_y,  # training labels
    epochs=10,  # number of epochs
    validation_data=(test_x, test_y),  # validation data
    # List of callbacks, we are currently using 3:
    # custom, model checkpoint and tensorboard
    callbacks=[MyCustomCallback(), model_ckpt, tensorboard_cb],
)
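For a quick end-to-end check, the following is a minimal self-contained sketch that trains a tiny model on random data with two callbacks attached. Everything specific in it is an assumption made for illustration: the `LossHistory` callback, the toy data shapes, the layer sizes and the temporary CSV path are not part of the example above.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class LossHistory(tf.keras.callbacks.Callback):
    """Hypothetical custom callback that records the loss after every epoch."""

    def on_train_begin(self, logs=None):
        self.losses = []

    def on_epoch_end(self, epoch, logs=None):
        self.losses.append((logs or {}).get("loss"))


# Random toy data: 64 samples, 4 features, binary labels
rng = np.random.default_rng(0)
train_x = rng.normal(size=(64, 4)).astype("float32")
train_y = rng.integers(0, 2, size=(64, 1)).astype("float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

csv_path = os.path.join(tempfile.mkdtemp(), "output.csv")
loss_history = LossHistory()
history = model.fit(
    train_x,
    train_y,
    epochs=3,
    batch_size=16,
    validation_split=0.25,
    verbose=0,
    callbacks=[loss_history, tf.keras.callbacks.CSVLogger(csv_path)],
)
print(loss_history.losses)
```

After training, `loss_history.losses` holds one value per epoch and the CSV file contains one row per epoch, which is a quick way to confirm that both callbacks were invoked.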
TensorFlow-Keras also provides some other callbacks which are not discussed here.
For complete details and documentation, you can view the TensorFlow callbacks page in the official documentation.