Transferring Knowledge from BERT to Neural Networks

Learn about the task-specific teacher BERT and the student network, and how to transfer task-specific knowledge from BERT to a simple neural network.

Let's look at an interesting paper in which the researchers explain how to perform knowledge distillation and transfer task-specific knowledge from BERT to a simple neural network. Let's get into the details and learn how exactly this works.

To understand how exactly we transfer task-specific knowledge from BERT to a neural network, first, let's take a look at the teacher BERT and student network in detail.

The teacher BERT

We use the pre-trained BERT-large model as the teacher BERT. Note that we are transferring task-specific knowledge from the teacher to the student, so we first take the pre-trained BERT-large model, fine-tune it for a specific task, and then use it as the teacher.

Example: Sentiment analysis task

Suppose we want to train our student network for the sentiment analysis task. In that case, we take the pre-trained BERT-large model, fine-tune it for sentiment analysis, and then use it as the teacher. Thus, our teacher is the pre-trained BERT model, fine-tuned for the specific task of interest.
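Below is a rough sketch of how such a teacher could be prepared with the Hugging Face transformers library. The model name, the tiny toy dataset, and the training loop are illustrative assumptions, not details from the paper:

```python
# A minimal sketch of preparing the teacher: fine-tuning pre-trained
# BERT-large for binary sentiment classification (1 = positive, 0 = negative).
import torch
from transformers import BertTokenizer, BertForSequenceClassification

tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")
teacher = BertForSequenceClassification.from_pretrained(
    "bert-large-uncased", num_labels=2  # positive / negative
)

# Toy labeled batch for illustration only
sentences = ["I love Paris", "I hate traffic"]
labels = torch.tensor([1, 0])

inputs = tokenizer(sentences, padding=True, return_tensors="pt")
optimizer = torch.optim.AdamW(teacher.parameters(), lr=2e-5)

teacher.train()
for _ in range(3):  # a few illustrative fine-tuning steps
    outputs = teacher(**inputs, labels=labels)
    outputs.loss.backward()
    optimizer.step()
    optimizer.zero_grad()

# After fine-tuning, the teacher's logits carry the task-specific
# knowledge that we will later distill into the student
teacher.eval()
with torch.no_grad():
    teacher_logits = teacher(**inputs).logits
```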

The student network

The student network is a simple bidirectional LSTM (BiLSTM). The architecture of the student network changes based on the task. Let's look at the architecture of the student network for a single-sentence classification task.

Example: Sentiment analysis task

Suppose we are performing sentiment analysis. Say we have a sentence: 'I love Paris'. First, we get the embeddings of the sentence, and then we feed the input embeddings to the bidirectional LSTM. The bidirectional LSTM reads the sentence in both directions (that is, forward and backward), so we obtain the forward and backward hidden states from the bidirectional LSTM.

Next, we feed the forward and backward hidden states to a fully connected layer with ReLU activation, which returns the logits as output. We then feed the logits to the softmax function and obtain the probabilities of the sentence belonging to the positive and negative classes.
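Below is a minimal PyTorch sketch of such a student network. The embedding size, hidden size, and vocabulary size are illustrative assumptions; the paper's exact hyperparameters may differ:

```python
# A minimal sketch of the student: a single-layer BiLSTM followed by a
# fully connected layer with ReLU and an output layer producing logits.
import torch
import torch.nn as nn
import torch.nn.functional as F

class BiLSTMStudent(nn.Module):
    def __init__(self, vocab_size, embed_dim=300, hidden_dim=300, num_classes=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.bilstm = nn.LSTM(embed_dim, hidden_dim,
                              batch_first=True, bidirectional=True)
        self.fc = nn.Linear(2 * hidden_dim, hidden_dim)  # forward + backward states
        self.out = nn.Linear(hidden_dim, num_classes)

    def forward(self, token_ids):
        embedded = self.embedding(token_ids)        # (batch, seq_len, embed_dim)
        _, (hidden, _) = self.bilstm(embedded)      # hidden: (2, batch, hidden_dim)
        # Concatenate the final forward and backward hidden states
        combined = torch.cat([hidden[0], hidden[1]], dim=1)
        logits = self.out(F.relu(self.fc(combined)))
        return logits

# Usage: softmax over the logits gives the positive/negative probabilities
student = BiLSTMStudent(vocab_size=10000)
token_ids = torch.randint(0, 10000, (1, 3))  # e.g., token IDs for "I love Paris"
probs = F.softmax(student(token_ids), dim=-1)
```

Note how much smaller this student is than BERT-large: a single BiLSTM layer and two linear layers, which is what makes distilling the teacher's knowledge into it attractive.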
