
What are the strategies for using transfer learning?


In this article, we will discuss strategies for using transfer learning. We'll use an image classification problem with a CNN as the running example, but these concepts apply to most problem types.

Before moving on to how we can leverage the transfer learning paradigm, note that almost all networks, whether pre-trained or your own, share a common structure.

This is illustrated below:

[Figure: Common structure of a CNN network]

Let’s understand the architecture shown above.

  • The convolutional base is a feature extractor that learns and extracts features from an image. It contains the convolution and pooling layers.
  • The multi-layer perceptron (MLP) is a stack of dense (fully connected) layers that serves as the classifier in a CNN. It produces the output by applying a softmax activation function, yielding a probability for each class (a minimal code sketch of this structure follows the list below).
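
To make this structure concrete, here is a minimal Keras sketch of a small CNN with a convolutional base followed by an MLP head. The layer sizes, the 224×224×3 input, and the 10-class softmax output are illustrative assumptions, not values from the article.

```python
import tensorflow as tf
from tensorflow.keras import layers, models

# Illustrative CNN: convolutional base (feature extractor) + MLP head.
model = models.Sequential([
    # Convolutional base: convolution + pooling layers extract features.
    layers.Conv2D(32, 3, activation="relu", input_shape=(224, 224, 3)),
    layers.MaxPooling2D(),
    layers.Conv2D(64, 3, activation="relu"),
    layers.MaxPooling2D(),
    # MLP head: dense layers ending in a softmax over the classes.
    layers.Flatten(),
    layers.Dense(128, activation="relu"),
    layers.Dense(10, activation="softmax"),  # 10 classes assumed
])
model.summary()  # prints the base/head structure
```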

This means that even the pre-trained models we may use have this same structure. So, we can leverage those pre-trained models using the transfer learning paradigm in three ways, each illustrated with a short code sketch after the list:

  • As pre-trained models: Modern ConvNets can take two to three weeks or more to train across multiple GPUs on a large dataset. If your problem lies within the scope of an already pre-trained model, a common approach is to take that model as-is and predict the labels for your images. This is the simplest way to use transfer learning. However, it will only provide accurate results for data similar to that on which the model was trained.
  • Fixed feature extractor: In this approach, you take a ConvNet pre-trained on the ImageNet dataset and remove the last fully connected layers (i.e., the MLP in the architecture discussed above). You then treat the rest of the ConvNet as a fixed feature extractor for the new dataset. The benefit of this approach is that, even if you have little data, the model is far less prone to overfitting, because you only train the MLP that you add on top of the convolutional base (i.e., you retrain the classifier only and freeze the convolutional base).
  • Fine-tuning: The last strategy is to not only replace and retrain the classifier on top of the ConvNet on the new dataset, but also to fine-tune the weights of the pre-trained network by continuing back-propagation. It is possible to fine-tune all the layers of the ConvNet, but that can be time-consuming. It is also possible to keep some of the earlier layers fixed (due to overfitting concerns) and fine-tune only the higher-level portion of the network.
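
As a concrete example of the first strategy, the following sketch loads a ResNet50 pre-trained on ImageNet with Keras and predicts labels directly. ResNet50 is one possible choice of pre-trained model, and "elephant.jpg" is a hypothetical input file; both are assumptions for illustration.

```python
import numpy as np
from tensorflow.keras.applications.resnet50 import (
    ResNet50, decode_predictions, preprocess_input)
from tensorflow.keras.preprocessing import image

# Load the full pre-trained model, MLP classifier head included.
model = ResNet50(weights="imagenet")

# "elephant.jpg" is a hypothetical local image file.
img = image.load_img("elephant.jpg", target_size=(224, 224))
x = preprocess_input(np.expand_dims(image.img_to_array(img), axis=0))

preds = model.predict(x)
print(decode_predictions(preds, top=3)[0])  # top-3 ImageNet labels
```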
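The fixed feature extractor strategy might look like the sketch below: the pre-trained convolutional base is frozen, and only a newly added MLP head is trained. The head sizes, the 5-class output, and the train_ds dataset are hypothetical.

```python
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import ResNet50

num_classes = 5  # hypothetical number of classes in the new dataset

# Pre-trained convolutional base only (include_top=False drops the MLP).
base = ResNet50(weights="imagenet", include_top=False,
                input_shape=(224, 224, 3))
base.trainable = False  # freeze the base: only the new head will train

model = models.Sequential([
    base,
    layers.GlobalAveragePooling2D(),
    layers.Dense(256, activation="relu"),            # new MLP head
    layers.Dense(num_classes, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="categorical_crossentropy",
              metrics=["accuracy"])
# model.fit(train_ds, epochs=5)  # train_ds: your labeled dataset (assumed)
```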
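For fine-tuning, one possible sketch (using the same setup) unfreezes only the higher layers of the base and continues back-propagation with a small learning rate. The cut-off index of 100 and the learning rate are illustrative choices, not prescribed values.

```python
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import ResNet50

num_classes = 5  # hypothetical number of classes in the new dataset

base = ResNet50(weights="imagenet", include_top=False,
                input_shape=(224, 224, 3))
model = models.Sequential([
    base,
    layers.GlobalAveragePooling2D(),
    layers.Dense(256, activation="relu"),
    layers.Dense(num_classes, activation="softmax"),
])

# Keep the earlier, more generic layers fixed (overfitting concerns)
# and fine-tune only the higher-level portion of the base.
base.trainable = True
for layer in base.layers[:100]:  # 100 is an illustrative cut-off
    layer.trainable = False

# A small learning rate avoids destroying the pre-trained weights.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
              loss="categorical_crossentropy",
              metrics=["accuracy"])
# model.fit(train_ds, epochs=5)  # continue training end-to-end (assumed data)
```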

RELATED TAGS

transfer learning
deep learning
cnn
machine learning