
Saving and Loading Methods

Explore how to implement save and load methods in a PyTorch training class to manage checkpoints. Understand how to switch between training and evaluation modes when making predictions, and learn to convert data between NumPy arrays and PyTorch tensors.


Saving and loading

Most of the code here is the same as in the chapter Rethinking the Training Loop. The only difference is that we now read everything from the instance's attributes (`self.model`, `self.optimizer`, and so on) instead of from local variables.
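To make the contrast concrete, here is a minimal sketch of the pattern. It assumes a stripped-down, hypothetical `TinyTrainer` class in place of `StepByStep`, and a hypothetical `build_checkpoint` helper. The standalone version must receive every value as an argument, while the method version reads the same values from `self`. `setattr` attaches the function to the class after the class has been defined, which is how the methods in this lesson are added.

```python
class TinyTrainer:
    """Hypothetical stand-in for StepByStep with only the attributes we need."""

    def __init__(self):
        self.total_epochs = 0
        self.losses = []
        self.val_losses = []


# Standalone version: every piece of state arrives as a local variable.
def build_checkpoint_standalone(epoch, losses, val_losses):
    return {'epoch': epoch, 'loss': losses, 'val_loss': val_losses}


# Method version: the same state is read from the instance's attributes.
def build_checkpoint(self):
    return {
        'epoch': self.total_epochs,
        'loss': self.losses,
        'val_loss': self.val_losses,
    }


# Attach the function to the class after the fact, as the lesson does.
setattr(TinyTrainer, 'build_checkpoint', build_checkpoint)

trainer = TinyTrainer()
trainer.total_epochs = 2
trainer.losses = [0.9, 0.5]
trainer.val_losses = [1.0, 0.6]

# Both versions produce the same dictionary
print(trainer.build_checkpoint() == build_checkpoint_standalone(2, [0.9, 0.5], [1.0, 0.6]))  # True
```

The method version needs no arguments beyond `self`, so callers cannot accidentally pass a stale epoch count or loss list.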

The updated method for saving checkpoints looks like this:

```python
import torch

def save_checkpoint(self, filename):
    # Builds dictionary with all elements for resuming training
    checkpoint = {
        'epoch': self.total_epochs,
        'model_state_dict': self.model.state_dict(),
        'optimizer_state_dict': self.optimizer.state_dict(),
        'loss': self.losses,
        'val_loss': self.val_losses
    }
    torch.save(checkpoint, filename)

# Attaches the function to the (previously defined) StepByStep class as a method
setattr(StepByStep, 'save_checkpoint', save_checkpoint)
```
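One way to check what `save_checkpoint` writes is to call it on a small model and read the file back. The sketch below assumes a hypothetical, minimal `StepByStep` stand-in that only sets the attributes the method touches (the real class defines many more). The model, optimizer, loss values, and the `model_checkpoint.pth` filename are illustrative only.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim


class StepByStep:
    # Hypothetical minimal stand-in: only the attributes save_checkpoint reads.
    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer
        self.total_epochs = 0
        self.losses = []
        self.val_losses = []


def save_checkpoint(self, filename):
    # Same body as the method above
    checkpoint = {
        'epoch': self.total_epochs,
        'model_state_dict': self.model.state_dict(),
        'optimizer_state_dict': self.optimizer.state_dict(),
        'loss': self.losses,
        'val_loss': self.val_losses,
    }
    torch.save(checkpoint, filename)


setattr(StepByStep, 'save_checkpoint', save_checkpoint)

torch.manual_seed(42)
model = nn.Linear(1, 1)
sbs = StepByStep(model, optim.SGD(model.parameters(), lr=0.1))
sbs.total_epochs = 3  # pretend we trained for three epochs
sbs.losses = [0.50, 0.30, 0.20]
sbs.val_losses = [0.55, 0.35, 0.25]

# Write to a temporary directory so the example leaves no files behind
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'model_checkpoint.pth')
    sbs.save_checkpoint(path)
    saved = torch.load(path)

print(sorted(saved.keys()))
# ['epoch', 'loss', 'model_state_dict', 'optimizer_state_dict', 'val_loss']
print(saved['epoch'], saved['loss'])  # 3 [0.5, 0.3, 0.2]
```

The checkpoint is an ordinary dictionary, so anything else needed to resume training can be added as another key.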

Likewise, the method for loading a checkpoint looks like this:

```python
def load_checkpoint(self, filename):
    # Loads dictionary
    checkpoint = torch.load(filename)
    # Restore state for model and optimizer
    self.model.load_state_dict(checkpoint['model_state_dict'])
    self.optimizer.load_state_dict(
        checkpoint['optimizer_state_dict']
    )
    # Restore the epoch counter and the loss histories
    self.total_epochs = checkpoint['epoch']
    self.losses = checkpoint['loss']
    self.val_losses = checkpoint['val_loss']
    self.model.train() # always use TRAIN for resuming training

setattr(StepByStep, 'load_checkpoint', load_checkpoint)
```
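To see a full save-and-resume round trip, the sketch below reuses both method bodies on a hypothetical minimal `StepByStep` stand-in. It also assumes a hypothetical `make_sbs` helper, a small model with a dropout layer, and SGD with momentum, so the optimizer carries state worth restoring. The resumed instance starts with different weights and is deliberately left in evaluation mode; after loading, it should hold the original weights, epoch count, and training mode.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim


class StepByStep:
    # Hypothetical minimal stand-in holding only what the two methods use.
    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer
        self.total_epochs = 0
        self.losses = []
        self.val_losses = []


def save_checkpoint(self, filename):
    checkpoint = {
        'epoch': self.total_epochs,
        'model_state_dict': self.model.state_dict(),
        'optimizer_state_dict': self.optimizer.state_dict(),
        'loss': self.losses,
        'val_loss': self.val_losses,
    }
    torch.save(checkpoint, filename)


def load_checkpoint(self, filename):
    checkpoint = torch.load(filename)
    self.model.load_state_dict(checkpoint['model_state_dict'])
    self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    self.total_epochs = checkpoint['epoch']
    self.losses = checkpoint['loss']
    self.val_losses = checkpoint['val_loss']
    self.model.train()  # always use TRAIN for resuming training


setattr(StepByStep, 'save_checkpoint', save_checkpoint)
setattr(StepByStep, 'load_checkpoint', load_checkpoint)


def make_sbs(seed):
    # Fresh model + SGD with momentum, so the optimizer has per-parameter state
    torch.manual_seed(seed)
    model = nn.Sequential(nn.Linear(2, 4), nn.Dropout(0.5), nn.Linear(4, 1))
    return StepByStep(model, optim.SGD(model.parameters(), lr=0.1, momentum=0.9))


# "Train" the original for one step and record illustrative losses
original = make_sbs(seed=42)
x, y = torch.randn(8, 2), torch.randn(8, 1)
loss = nn.functional.mse_loss(original.model(x), y)
loss.backward()
original.optimizer.step()
original.total_epochs, original.losses, original.val_losses = 1, [loss.item()], [0.9]

# A new instance with different initial weights, left in eval mode on purpose
resumed = make_sbs(seed=13)
resumed.model.eval()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'model_checkpoint.pth')
    original.save_checkpoint(path)
    resumed.load_checkpoint(path)

# Compare every tensor in the two state dictionaries
same_weights = all(
    torch.equal(a, b)
    for a, b in zip(original.model.state_dict().values(),
                    resumed.model.state_dict().values())
)
print(same_weights, resumed.total_epochs, resumed.model.training)  # True 1 True
```

The final `train()` call is what switches the dropout layer back on. If a checkpoint is loaded for making predictions rather than for resuming training, call `eval()` on the model afterwards instead.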
...
...