Distillation of Transformer Layer

The transformer layer is essentially the encoder layer. The encoder layer computes the attention matrix using multi-head attention and returns the hidden state representation as its output. In transformer layer distillation, we therefore transfer knowledge from the teacher to the student along both of these: the attention matrix and the hidden state.
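To make these two quantities concrete, here is a minimal sketch, assuming the Hugging Face transformers library and PyTorch, of how the attention matrices and hidden states can be obtained from a teacher and a student BERT. The checkpoint names are only illustrative examples:

```python
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
teacher = AutoModel.from_pretrained("bert-base-uncased")                      # 12-layer teacher BERT
student = AutoModel.from_pretrained("huawei-noah/TinyBERT_General_4L_312D")  # 4-layer student BERT

inputs = tokenizer("Paris is a beautiful city.", return_tensors="pt")

with torch.no_grad():  # the teacher is frozen during distillation
    teacher_out = teacher(**inputs, output_attentions=True, output_hidden_states=True)
student_out = student(**inputs, output_attentions=True, output_hidden_states=True)

# attentions: one tensor per layer, shape [batch, heads, seq_len, seq_len]
# hidden_states: one tensor per layer (plus the embedding layer), shape [batch, seq_len, hidden_size]
teacher_attn, teacher_hidden = teacher_out.attentions, teacher_out.hidden_states
student_attn, student_hidden = student_out.attentions, student_out.hidden_states
```

These attention matrices and hidden states are exactly the quantities matched between teacher and student in the two distillation methods described next.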

Methods of transformer layer distillation

Thus, transformer layer distillation involves the following two types of distillation:

  • Attention-based distillation

  • Hidden state-based distillation

First, let's take a look at how attention-based distillation works, and then we'll look at hidden state-based distillation.

Attention-based distillation

In attention-based distillation, we transfer the knowledge of the attention matrix from the teacher BERT to the student BERT. But why do we need to perform attention-based distillation? The attention matrix encodes useful linguistic information, such as language syntax and coreference relations, which helps in understanding the language in general. Transferring the knowledge of the attention matrix from the teacher to the student is therefore very useful.

Performing attention-based distillation

To perform attention-based distillation, we train the student network by minimizing the mean squared error between the attention matrices of the student and the teacher BERT. The attention-based distillation loss, $L_{\text{attn}}$, is expressed as follows:
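The expression below follows the standard TinyBERT formulation of this loss; the symbols $h$, $A_i^S$, and $A_i^T$ are introduced here:

$$L_{\text{attn}} = \frac{1}{h} \sum_{i=1}^{h} \text{MSE}\left(A_i^S, A_i^T\right)$$

Here, $h$ denotes the number of attention heads, $A_i^S$ denotes the attention matrix of the $i$-th head of the student, $A_i^T$ denotes the attention matrix of the $i$-th head of the teacher, and MSE denotes the mean squared error.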

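As a concrete illustration, here is a minimal sketch of this loss in PyTorch, assuming the attention matrices have already been extracted as in the earlier snippet, that the student and teacher use the same number of heads and the same tokenization, and that each student layer has been mapped to a teacher layer via a layer-mapping list. The function name and the `layer_map` argument are illustrative:

```python
import torch
import torch.nn.functional as F

def attention_distillation_loss(student_attn, teacher_attn, layer_map):
    """MSE between student and teacher attention matrices.

    student_attn, teacher_attn: tuples of tensors of shape [batch, heads, seq_len, seq_len]
    layer_map: list mapping each student layer index to a teacher layer index,
               e.g. [2, 5, 8, 11] for a 4-layer student and a 12-layer teacher
    """
    loss = 0.0
    for s_idx, t_idx in enumerate(layer_map):
        a_s = student_attn[s_idx]   # [batch, h, seq_len, seq_len]
        a_t = teacher_attn[t_idx]
        # F.mse_loss averages over all elements, which covers the 1/h average
        # over attention heads in the loss formula above
        loss = loss + F.mse_loss(a_s, a_t)
    return loss / len(layer_map)

# Example usage with the tensors extracted earlier:
# loss = attention_distillation_loss(student_attn, teacher_attn, layer_map=[2, 5, 8, 11])
# loss.backward()  # gradients flow only into the student, since the teacher was run under no_grad
```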