What is transfer learning and why is it needed?

In this article, we discuss transfer learning, a widely used technique for building complex deep learning models on top of pre-trained models.

What is transfer learning?

We humans have an inherent ability to transfer knowledge across tasks. We use what we learn from one task to solve other, similar tasks. The more closely related the tasks are, the easier it is to transfer or cross-utilize that knowledge. Let’s understand this with some examples:

  • If you know how to ride a motorbike, then you can learn how to drive a car
  • If you know math and statistics, then you can learn machine learning
  • If you know how to play classical piano, then you can learn how to play jazz piano

Conventional machine learning and deep learning algorithms are designed to work in isolation: each model is trained to solve one specific task. When the feature-space distribution changes, the model has to be rebuilt from scratch.

Transfer learning is the idea of overcoming the isolated learning paradigm and utilizing knowledge acquired for one task to solve related ones.

In fact, Andrew Ng, a renowned professor and data scientist who has been associated with Google Brain, Baidu, Stanford, and Coursera, gave a tutorial at NIPS 2016 called ‘Nuts and bolts of building AI applications using Deep Learning’, in which he said:

“After supervised learning, transfer learning will be the next driver of ML commercial success.”

Have a look at the image below to better understand how transfer learning differs from traditional ML techniques.

How Traditional ML differs from Transfer Learning

Some formal definitions of transfer learning are:

Transfer learning and domain adaptation refer to the situation where what has been learned in one setting is exploited to improve generalization in another setting.

Transfer learning is the improvement of learning a new task through the transfer of knowledge from a related task that has already been learned.

Examples of transfer learning

Transfer learning in vision (image data)

In computer vision, you deal with image or video data, and deep learning models are very commonly used for it. People generally do not train a network from scratch for their own use case. Instead, they start from a deep learning model that has been pre-trained on a large and challenging image classification task, such as ImageNet, a 1000-class photograph classification competition.

The researchers who develop models for this type of competition often release the final model weights under a permissive license so that others can reuse them.

Some of the models are:

  • Oxford VGG Model
  • Google Inception Model
  • Microsoft ResNet Model

Transfer learning in NLP (text data)

In natural language processing, you deal with text data as input or output. Word embeddings are commonly used for these problems. An embedding maps words to a high-dimensional continuous vector space in which words with similar meanings have similar vector representations.
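To make "similar vector representations" concrete, here is a minimal sketch that compares word vectors using cosine similarity. The three vectors are hypothetical values picked by hand for illustration; they do not come from word2vec, GloVe, or any other trained model.

```python
import math

# Hypothetical 3-dimensional vectors, hand-picked for illustration only.
# Real pre-trained embeddings are learned from a large text corpus and
# usually have many more dimensions.
toy_embeddings = {
    "king": [0.90, 0.80, 0.10],
    "queen": [0.85, 0.90, 0.15],
    "banana": [0.10, 0.05, 0.95],
}


def cosine_similarity(a, b):
    """Return the cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


related = cosine_similarity(toy_embeddings["king"], toy_embeddings["queen"])
unrelated = cosine_similarity(toy_embeddings["king"], toy_embeddings["banana"])

print(f"king vs queen:  {related:.3f}")
print(f"king vs banana: {unrelated:.3f}")
```

With these toy values, the pair of related words scores much closer to 1 than the unrelated pair. With real pre-trained embeddings the same comparison would be run on the vectors loaded from the released model.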

Researchers who train such models on a very large corpus of text often release them under a permissive license.

Some popular pre-trained models are:

  • Google’s word2vec Model
  • Stanford’s GloVe Model

Why is transfer learning a better choice?

Let’s take an example.

Suppose you want to build an image classifier, and you have a dataset of around 2000 images. If you want the model to learn complex features, you will need a complex model with many convolutional layers, and you might end up training around 10 million parameters (for example).

If you are familiar with neural networks, you will recognize the problem: a model with that many parameters and so little data is likely to overfit within the first few epochs. It memorizes the training images instead of generalizing, so its results on test data will be very poor.
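To see how quickly a network reaches that order of magnitude, the sketch below counts the trainable parameters of a small convolutional network. The architecture (128x128 RGB input, three 3x3 convolution blocks that each halve the spatial size, a 256-unit dense layer, and 10 output classes) is a hypothetical example chosen for this calculation, not one taken from the article.

```python
def conv2d_params(kernel_size, in_channels, out_channels):
    """Weights plus one bias per output channel for a square-kernel conv layer."""
    return (kernel_size * kernel_size * in_channels + 1) * out_channels


def dense_params(in_features, out_features):
    """Weights plus one bias per output unit for a fully connected layer."""
    return (in_features + 1) * out_features


# Hypothetical architecture (an assumption for illustration):
# each block is a 3x3 convolution that keeps the spatial size,
# followed by 2x2 pooling that halves height and width.
height = width = 128
channels = 3
total = 0

for filters in (32, 64, 128):
    total += conv2d_params(3, channels, filters)
    channels = filters
    height //= 2
    width //= 2

flattened = height * width * channels  # size of the vector fed to the dense layer
total += dense_params(flattened, 256)  # hidden dense layer
total += dense_params(256, 10)         # output layer, assuming 10 classes

num_images = 2000
print(f"trainable parameters: {total:,}")
print(f"parameters per training image: {total / num_images:,.0f}")
```

Even this modest network has several million trainable parameters, which is thousands of parameters for every image in a 2000-image dataset.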

How to solve the above problem

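This is where transfer learning comes in. The idea is to start from a network that has already been trained on a large dataset, reuse the weights it has learned, and train only a small part of the model on your own data.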
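A common way to reuse a pre-trained network is to freeze its layers and train only a small new "head" on top of them. The toy sketch below shows that pattern in plain Python. The frozen weights and the tiny dataset are hypothetical, hand-picked stand-ins; a real project would load the weights of a released model instead.

```python
import math

# Stand-in for the body of a pre-trained network. These values are
# hypothetical and hand-picked; in practice they would be weights learned
# on a large source dataset. They are never updated below.
FROZEN_WEIGHTS = [[1.0, 1.0], [1.0, -1.0]]


def extract_features(x):
    """Frozen layer: a linear map followed by ReLU."""
    return [max(0.0, sum(w * v for w, v in zip(row, x))) for row in FROZEN_WEIGHTS]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


# Tiny labeled dataset for the new (target) task, made up for illustration.
target_data = [
    ([2.0, 1.0], 1), ([1.0, 2.0], 1), ([1.5, 0.5], 1),
    ([-1.0, -2.0], 0), ([-2.0, -1.0], 0), ([-0.5, -1.5], 0),
]

# Features are computed once, because the extractor does not change.
features = [(extract_features(x), y) for x, y in target_data]

# Trainable head: logistic regression on top of the frozen features.
head_weights = [0.0, 0.0]
head_bias = 0.0
learning_rate = 0.5

for _ in range(500):
    grad_w = [0.0, 0.0]
    grad_b = 0.0
    for f, y in features:
        z = head_weights[0] * f[0] + head_weights[1] * f[1] + head_bias
        error = sigmoid(z) - y
        grad_w[0] += error * f[0]
        grad_w[1] += error * f[1]
        grad_b += error
    n = len(features)
    # Only the head parameters are updated; FROZEN_WEIGHTS stays untouched.
    head_weights = [w - learning_rate * g / n for w, g in zip(head_weights, grad_w)]
    head_bias -= learning_rate * grad_b / n


def predict(x):
    f = extract_features(x)
    z = sum(w * v for w, v in zip(head_weights, f)) + head_bias
    return 1 if z > 0 else 0


accuracy = sum(predict(x) == y for x, y in target_data) / len(target_data)
print(f"training accuracy: {accuracy:.2f}")
```

Only three numbers (two head weights and one bias) are learned here, which is why a small dataset can be enough. Deep learning libraries follow the same pattern at a larger scale by marking the pre-trained layers as non-trainable, for example with `layer.trainable = False` in Keras or `requires_grad = False` on parameters in PyTorch.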

Some well-known pre-trained models are:

  1. AlexNet
  2. VGG19
  3. VGG16
  4. MobileNet
  5. ResNet
  6. word2vec
  7. GloVe, and many more
