TensorFlow callbacks are an important way to customize the behaviour of Keras TensorFlow models during training or evaluation. We can either use the predefined callbacks that TensorFlow provides or write our own callbacks to perform custom processing. Callbacks help to save model data, log stats during training, evaluate data at certain steps, or make decisions based on model performance. The callbacks provided by TensorFlow-Keras that we will cover are as follows.

  • Base Callback class
  • ModelCheckpoint
  • TensorBoard
  • EarlyStopping
  • RemoteMonitor
  • LearningRateScheduler
  • CSVLogger
  • BackupAndRestore

We will use these callbacks in a TensorFlow sequential model and check how they work.

Callbacks

Let's check how these callbacks work and how we can use them in our TensorFlow model.

Base Callback class

The base Callback class can be subclassed to override its methods and run our own code at different points during training. Here is a basic example of a callback that hooks into epoch end and training end. With custom callbacks we can perform different operations, like getting model results on a validation or test dataset and visualizing them, storing output (images, logs, text, etc.) in a file or directory, or writing data to TensorBoard.

import tensorflow as tf

# write a custom callback for epoch end and training end
class MyCustomCallback(tf.keras.callbacks.Callback):
    # we only define the methods we want to hook into
    def on_epoch_end(self, epoch, logs=None):
        # logs contains the metrics we are tracking, e.g. loss, accuracy, val_loss, val_accuracy
        # perform some operation on epoch end
        print("Epoch", epoch, "finished. Logs are:", logs)

    def on_train_end(self, logs=None):
        # perform some operation on training end
        print("Training Finished")

We can perform any operation we want inside custom callbacks; for example, the sketch below appends the epoch logs to a text file. There are also other predefined callbacks for common tasks like writing logs, saving checkpoints and model files, etc.
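As an illustration, here is a minimal sketch of a custom callback that appends the metrics of each epoch to a text file; the path logs/epoch_logs.txt is just an example and can be changed as needed.

import os
import tensorflow as tf

class FileLoggerCallback(tf.keras.callbacks.Callback):
    def __init__(self, filepath="logs/epoch_logs.txt"): # example path, not from the original article
        super().__init__()
        self.filepath = filepath

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        os.makedirs(os.path.dirname(self.filepath), exist_ok=True)
        # append one line per epoch with all tracked metrics
        with open(self.filepath, "a") as f:
            f.write(f"epoch {epoch}: {logs}\n")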

Model Checkpoint

The ModelCheckpoint callback is used with the model.fit() method and writes checkpoints according to the parameters we define. We can control the save behaviour using arguments like whether to save only the weights or the whole model, which metric to monitor (accuracy, loss, validation accuracy), the save path, and the frequency. We can also choose to save weights only when there is an increase in accuracy or a decrease in loss for the monitored metric. A simple model checkpoint can be defined as follows.

# Model Checkpoint
checkpoint_filepath = "checkpoints/ckpt"
model_ckpt = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_filepath, # path to save checkpoint
    monitor="val_loss", # monitor val loss, also can use (accuracy, loss, val_accuracy)
    verbose=0, # verbosity mode, 0 or 1
    save_best_only=True, # save only when the monitored metric improves (set False to save after every epoch)
    save_weights_only=True, # save only weights instead of the whole model
    mode="auto", # mode as auto, min or max
    save_freq="epoch", # save frequency
)

Now, when we run training, it will save a model checkpoint after each epoch whose validation loss is lower than the best validation loss seen so far, so that we always keep the weights with the lowest validation loss. We can also set save_freq to "epoch" or to an integer number of batches, depending on our choice.
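After training, assuming the trained model is available as model, we can load the best weights (the ones with the lowest validation loss) back from this checkpoint path before evaluation or inference.

# restore the best weights saved by the checkpoint callback
model.load_weights(checkpoint_filepath)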

TensorBoard

TensorBoard is a tool provided by TensorFlow for visualizing models, logs, and outputs. We can visualize the model architecture (layers, graph, etc.) as well as plots of model performance such as accuracy, loss, validation metrics, and any other custom logs we write. We can also visualize embeddings in a lower-dimensional space, histograms of weights and biases, and data such as text, images, and audio. More details are available on the TensorFlow TensorBoard page.

https://www.tensorflow.org/tensorboard

We can write logs, images, and other data during training using this callback and later visualize them with TensorBoard. Here is a simple callback for using TensorBoard in TensorFlow models.

log_directory = "logs"
tensorboard_cb = tf.keras.callbacks.TensorBoard(
    log_dir=log_directory, # path to logs directory
    histogram_freq=0, # frequency (in epochs) at which to compute weight histograms (0 means do not compute)
    write_graph=True, # whether to write the graph or not
    write_images=False, # visualize model weights as image
    write_steps_per_second=False, # save steps per second time (both batch and epoch)
    update_freq="epoch", # batch or epoch
    profile_batch=2, # profile second batch (set 0 to disable)
    embeddings_freq=0, # embedding layers visualization frequency
    embeddings_metadata=None, # embedding metadata as dictionary
)

Once we execute the script, it will start writing data to the specified logs directory, and we can open TensorBoard and check the output using the following command.

# log dir in example is 'logs'
tensorboard --logdir=logs

Now you can navigate to http://localhost:6006 and view the visualizations generated from the model logs.

Early Stopping

When training a model we want to reduce the loss and increase the accuracy, but in some cases, after a certain amount of time, the loss stops changing or even starts increasing again. So we need a check that stops training if the loss does not decrease, the accuracy does not increase, or a validation metric stops improving. EarlyStopping is a Keras callback that does this: it monitors a given metric and stops training based on the given patience and other arguments so that we keep the best model. Below, we define a callback which monitors validation loss and stops training if it does not decrease for a given number of epochs.

early_stop_cb = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss", # monitor validation loss
    min_delta=0, # minimum change in the monitored quantity to count as an improvement
    patience=4, # how many epochs to wait without improvement before stopping training
    verbose=0, # verbosity mode
    mode="auto", # mode
    baseline=None, # if the model does not show improvement over this baseline, stop training
    restore_best_weights=False, # if set to True, restore the weights from the best epoch
)

It will monitor validation loss and if it does not decrease for 4 epochs, it will stop training.

Remote Monitor

If we are training a model remotely and want to monitor its performance, we can use the RemoteMonitor callback to send data to a specified endpoint, where we can write it to a file or save it to a database. We need an API endpoint, which we pass to this callback, and it sends the data to that endpoint at the end of each epoch.

remote_monitor_cb = tf.keras.callbacks.RemoteMonitor(
    root="http://localhost:8080", # base url
    path="/report/epoch-end/", # url path
    field="data", # form field used for the payload when not sending as JSON
    headers=None, # optional custom HTTP headers
    send_as_json=True, # send the payload as JSON
)

This callback is very useful: if we have a model training remotely, we can create a simple visualization tool that receives the data from this API and displays the results however we want.
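For illustration, here is a minimal sketch of such a receiving endpoint, assuming Flask is installed and that the callback is configured with send_as_json=True as above; the route and port simply match the root and path arguments used earlier.

from flask import Flask, request

app = Flask(__name__)

@app.route("/report/epoch-end/", methods=["POST"])
def report_epoch_end():
    payload = request.get_json(force=True) # metrics sent by RemoteMonitor at each epoch end
    print("Received:", payload) # here we could write to a file or a database instead
    return "ok"

if __name__ == "__main__":
    app.run(port=8080)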

Learning Rate Scheduler

When training a model, we may want to change the learning rate as training progresses, for which we can use the LearningRateScheduler callback from TensorFlow. We define a scheduler function that modifies the learning rate after a specific number of epochs or based on model stats.

def scheduler(epoch, lr):
    if epoch < 10: # if epoch is less than 10, use predefined learning rate
        return lr
    return lr * tf.math.exp(-0.1) # else decrease it

# define learning rate scheduler callback
lr_callback = tf.keras.callbacks.LearningRateScheduler(
    scheduler, # scheduler function defined above
    verbose=0 # verbosity mode
)

CSV Logger

We can also write the model's training output to a CSV file, which is easy to read for everyone, so we can share results without any post-processing of logs. It writes the epoch results as a CSV file separated by the delimiter provided, and with the append flag we can either append to an existing file or overwrite it.

csv_filepath = "logs/output.csv" # path to csv
csv_logger_cb = tf.keras.callbacks.CSVLogger(
    csv_filepath, # path to csv
    separator=",", # comma separated
    append=False # if False, overwrite an existing file; if True, append to it
)
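After training, the resulting file can be loaded back for inspection, for example with pandas (assuming it is installed).

import pandas as pd

history_df = pd.read_csv(csv_filepath) # one row per epoch with the logged metrics
print(history_df.tail())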

Backup and Restore

If training stops at some epoch because of a power failure or any other interruption, we can resume training from that epoch using the BackupAndRestore callback. It creates a backup of the model in a directory, and if we start training again, it checks for a backup and, if one exists, resumes training from there.

backup_dir = "tmp/backup_tf" # directory to store the backup
backup_cb = tf.keras.callbacks.experimental.BackupAndRestore(backup_dir)

Here are some limitations and details about BackupAndRestore from the TensorFlow documentation.

  Note
  1. This callback is not compatible with disabling eager execution.
  2. A checkpoint is saved at the end of each epoch, when restoring we'll redo any partial work from an unfinished epoch in which the training got restarted (so the work done before a interruption doesn't affect the final model state).
  3. This works for both single worker and multi-worker mode, only MirroredStrategy and MultiWorkerMirroredStrategy are supported for now.

Usage

Now we can use the callbacks defined above in our TensorFlow model. For example, to use them during training, we pass a list of callbacks like this.

model.fit(
    train_x, # training data
    train_y, # training labels
    epochs=10, # number of epochs
    validation_data=(test_x, test_y), # validation data
    # list of callbacks, we are currently using three:
    # the custom callback, model checkpoint and TensorBoard
    callbacks=[MyCustomCallback(), model_ckpt, tensorboard_cb]
)
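We could also pass the full set of callbacks defined in this article, using the variable names assigned in the snippets above; remote_monitor_cb is left out here since it requires the receiving endpoint to be running.

model.fit(
    train_x, # training data
    train_y, # training labels
    epochs=10, # number of epochs
    validation_data=(test_x, test_y), # validation data
    callbacks=[
        MyCustomCallback(),
        model_ckpt,
        tensorboard_cb,
        early_stop_cb,
        lr_callback,
        csv_logger_cb,
        backup_cb,
    ],
)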

There are some other callbacks provided by TensorFlow-Keras which are not discussed here.

You can view details for these callbacks in the official TensorFlow documentation. For complete details, see the TensorFlow callbacks page.