Learn to checkpoint deep learning (DL) models in Keras and use the ModelCheckpoint callback class to save model weights under specified conditions.

Checkpointing is the process of saving the computational state of an application so that it can be recovered after a system failure. A checkpoint is a snapshot of the system state that we can use as the starting point for a new run of the application. During the DL model training phase, we can checkpoint the model weights. These weights can later be used to make predictions or to train a new model. Let’s discuss how to use the ModelCheckpoint callback class of Keras to save model weights.

The ModelCheckpoint callback class

The ModelCheckpoint callback class lets us checkpoint model weights under conditions we specify. The class monitors a metric, such as accuracy or loss, on either the training or the validation dataset, and checks whether that metric improves over time. Moreover, we can embed variable information, such as the epoch number or the value of the monitored metric, in the checkpoint filename when the weights are stored.
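As a minimal sketch, the callback can be configured as follows; the filename template, monitored metric, and parameter values here are illustrative choices, not prescribed by the lesson:

```python
from tensorflow import keras

# A filename template: at save time, Keras substitutes the current epoch
# number and the monitored metric's value into the placeholders.
checkpoint = keras.callbacks.ModelCheckpoint(
    "weights-{epoch:02d}-{val_accuracy:.2f}.weights.h5",  # placeholder path
    monitor="val_accuracy",   # the metric to watch
    save_best_only=True,      # save only when the metric improves
    save_weights_only=True,   # store the weights rather than the full model
    mode="max",               # a larger val_accuracy counts as an improvement
    verbose=1,
)
```

Because each improvement substitutes fresh values into the template, every saved checkpoint gets a distinct, self-describing filename.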

We pass a ModelCheckpoint instance to the model-training process through the callbacks argument of the fit() method. We can save network weights in various formats, such as the HDF5 format.

Loading a checkpointed model

We can load a checkpointed model either to continue training it or to make predictions with it. Because the checkpoint contains only the network weights, it assumes that we already know the network structure. The checkpoint stores the model weights in HDF5 format, and the command model.load_weights('path') loads them back into a model.
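A minimal sketch of this workflow, assuming a small illustrative architecture and a placeholder file name (best.weights.h5):

```python
import numpy as np
from tensorflow import keras

def build_model():
    # The architecture must match the one that produced the checkpoint,
    # because the checkpoint stores weights only, not structure.
    model = keras.Sequential([
        keras.Input(shape=(8,)),
        keras.layers.Dense(12, activation="relu"),
        keras.layers.Dense(1, activation="sigmoid"),
    ])
    model.compile(optimizer="adam", loss="binary_crossentropy",
                  metrics=["accuracy"])
    return model

original = build_model()
original.save_weights("best.weights.h5")   # stand-in for a real checkpoint

restored = build_model()                   # rebuild the same structure
restored.load_weights("best.weights.h5")   # weights are now identical

# The restored model can make predictions or continue training.
preds = restored.predict(np.random.rand(4, 8), verbose=0)
```

Rebuilding the model through one shared function keeps the two architectures in sync, which is exactly the assumption load_weights() relies on.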

The following code checkpoints model weights when validation accuracy improves. The code also loads the checkpointed model to evaluate its accuracy.
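A self-contained sketch of that workflow on synthetic data; all file names, layer sizes, and epoch counts below are illustrative:

```python
import numpy as np
from tensorflow import keras

# Synthetic binary-classification data, used only for demonstration.
X = np.random.rand(200, 8)
y = np.random.randint(0, 2, size=(200,))

def build_model():
    model = keras.Sequential([
        keras.Input(shape=(8,)),
        keras.layers.Dense(12, activation="relu"),
        keras.layers.Dense(1, activation="sigmoid"),
    ])
    model.compile(optimizer="adam", loss="binary_crossentropy",
                  metrics=["accuracy"])
    return model

model = build_model()

# Save the weights only when validation accuracy improves.
checkpoint = keras.callbacks.ModelCheckpoint(
    "best.weights.h5",
    monitor="val_accuracy",
    save_best_only=True,
    save_weights_only=True,
    mode="max",
)
model.fit(X, y, validation_split=0.2, epochs=5,
          callbacks=[checkpoint], verbose=0)

# Rebuild the architecture, load the checkpointed weights, and evaluate.
best = build_model()
best.load_weights("best.weights.h5")
loss, acc = best.evaluate(X, y, verbose=0)
```

Because the data is random, the reported accuracy is not meaningful here; the point is the save-then-restore round trip.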
