
Transfer Learning using ResNet Model

Learn how to train, evaluate, and visualize the performance of a ResNet model.

We train the ResNet model by calling the train_one_epoch function once per epoch for the desired number of epochs. Because we are fine-tuning a pretrained network rather than training it from scratch, a few epochs are enough.
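The loop structure can be shown without a ResNet. Below is a minimal, self-contained sketch that uses only NumPy. It makes several assumptions. A hypothetical `ToyState` plays the role of the Flax train state. A hypothetical `toy_train_one_epoch` performs one full-batch gradient step of logistic regression in place of a pass over ResNet batches. A hypothetical `train_for` runs the epoch loop. The part that carries over is the pattern: each epoch takes the current state and returns a new state along with that epoch's metrics, and the caller rebinds its variable to the returned state.

```python
from typing import NamedTuple

import numpy as np


class ToyState(NamedTuple):
    """Hypothetical stand-in for a Flax train state: an immutable record."""
    w: np.ndarray
    b: float
    lr: float


def toy_train_one_epoch(state, data):
    """Hypothetical stand-in for train_one_epoch: one full-batch gradient
    step of logistic regression. Returns a NEW state plus metrics computed
    with the parameters from before the update."""
    x, y = data
    probs = 1.0 / (1.0 + np.exp(-(x @ state.w + state.b)))
    eps = 1e-12  # avoids log(0)
    loss = -np.mean(y * np.log(probs + eps) + (1 - y) * np.log(1 - probs + eps))
    accuracy = np.mean((probs > 0.5) == y)
    grad = (probs - y) / len(y)  # gradient of the mean loss w.r.t. the logits
    new_state = state._replace(
        w=state.w - state.lr * (x.T @ grad),
        b=state.b - state.lr * grad.sum(),
    )
    return new_state, {"loss": float(loss), "accuracy": float(accuracy)}


def train_for(state, data, epochs):
    """Run the epoch loop, rebinding `state` to each returned state."""
    history = []
    for epoch in range(1, epochs + 1):
        state, metrics = toy_train_one_epoch(state, data)
        history.append(metrics)
        print(f"Epoch: {epoch}, loss: {metrics['loss']:.4f}, "
              f"accuracy: {metrics['accuracy'] * 100:.1f}")
    return state, history


# Usage on synthetic, linearly separable 2-D data.
rng = np.random.default_rng(0)
x = rng.normal(size=(200, 2))
y = (x[:, 0] + x[:, 1] > 0).astype(float)
state0 = ToyState(w=np.zeros(2), b=0.0, lr=0.5)
final_state, history = train_for(state0, (x, y), epochs=20)
```

Because `toy_train_one_epoch` never mutates its input, `state0` still holds the initial zero weights after training. Only the returned `final_state` carries the learned parameters. Flax train states behave the same way, since `apply_gradients` returns a new state, so the state returned from each epoch must be kept.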

Set up TensorBoard in Flax

To monitor training in TensorBoard, we write the training and validation metrics to a log directory that TensorBoard reads.

from torch.utils.tensorboard import SummaryWriter
logdir = "flax_logs"
writer = SummaryWriter(logdir)

In the code above:

  • Line 1: We import the SummaryWriter class from torch.utils.tensorboard. It writes event files that TensorBoard can display.

  • Line 2: We define the logdir variable to hold the path of the logging directory.

  • Line 3: We create a SummaryWriter instance that writes its logs to logdir.
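Before starting a long fine-tuning run, it can help to check that logging works. The sketch below writes a few scalars to a temporary directory and then lists the event file that SummaryWriter creates. It assumes that `torch` and the `tensorboard` package are installed. The loss values are made up for illustration. It also uses the hypothetical names `demo_logdir` and `demo_writer` so that it does not overwrite the lesson's `logdir` and `writer`. Tags that share a prefix before the slash, such as `Loss/train` and `Loss/test`, are grouped into one section of the TensorBoard dashboard. To view the real logs, run `tensorboard --logdir flax_logs` from a shell. In a notebook, run `%load_ext tensorboard` followed by `%tensorboard --logdir flax_logs`.

```python
import os
import tempfile

from torch.utils.tensorboard import SummaryWriter

# Write to a throwaway directory so this check does not touch flax_logs.
demo_logdir = tempfile.mkdtemp()
demo_writer = SummaryWriter(demo_logdir)

# Made-up loss values, for illustration only.
fake_losses = [0.9, 0.6, 0.4]
for step, loss in enumerate(fake_losses, start=1):
    # add_scalar(tag, value, global_step) records one point on the "Loss/train" chart.
    demo_writer.add_scalar("Loss/train", loss, step)

# close() flushes any pending events and closes the event file.
demo_writer.close()

event_files = [f for f in os.listdir(demo_logdir) if f.startswith("events.out.tfevents")]
print(event_files)
```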

Train model

We define a function to train and evaluate the model while writing the metrics to TensorBoard.

# Take one validation batch and scale its pixel values to [0, 1].
(test_images, test_labels) = next(iter(validation_loader))
test_images = test_images / 255.0

# Per-epoch metric histories, kept for plotting after training.
training_loss = []
training_accuracy = []
testing_loss = []
testing_accuracy = []

def train_model(state, epochs):
    # Take the train state as an argument and return the updated one: assigning
    # to `state` inside the function without this would raise UnboundLocalError.
    for epoch in range(1, epochs + 1):
        state, train_metrics = train_one_epoch(state, train_loader)
        test_metrics = evaluate_model(state, test_images, test_labels)

        # Convert the metrics to Python floats for logging and printing.
        train_loss = float(train_metrics['loss'])
        train_acc = float(train_metrics['accuracy'])
        test_loss = float(test_metrics['loss'])
        test_acc = float(test_metrics['accuracy'])

        training_loss.append(train_loss)
        training_accuracy.append(train_acc)
        testing_loss.append(test_loss)
        testing_accuracy.append(test_acc)

        writer.add_scalar('Loss/train', train_loss, epoch)
        writer.add_scalar('Loss/test', test_loss, epoch)
        writer.add_scalar('Accuracy/train', train_acc, epoch)
        writer.add_scalar('Accuracy/test', test_acc, epoch)

        print(f"Epoch: {epoch}, training loss: {train_loss:.4f}, "
              f"training accuracy: {train_acc * 100:.2f}%, "
              f"validation loss: {test_loss:.4f}, "
              f"validation accuracy: {test_acc * 100:.2f}%")

    writer.flush()  # Ensure all logged scalars are written to disk.
    return state

In the code above: ...