Computer Science/구글 BERT의 정석

Tinybert

  • -
728x90
반응형

Bert에서 모델 경량화를 시킨 모델이라고 생각해주면 된다.

모델 경량화 방법에는 다음과 같은 3가지 방법이 존재한다.

1. Quantization

2. Weight Pruning

3. Knowledge Distilation

 

위 3가지 중 본 논문은 3번째 방법인 Knowledge Distilation을 활용하여 모델을 경량화 하는 방법을 제안한다.

해당 방식을 사전 학습과 fine-tuning 단계 모두 진쟁하게 된다. 따라서 위 모델은 사전학습을 통해 얻을 수 있는 general domain에 대한 지식과 fine-tuning 단계에서 얻을 수 있는 task-specific한 지식까지 얻을 수 있게 되는 것이다.

 

위 모델의 핵심적인 내용은 3가지의 loss를 사용했다는 점과, 2단계의 distillation을 진행하였다는 것이다.

 

먼저 3가지 loss를 사용한 것에 대해서 살펴보도록 하자.

loss가 어떤 것인지 살펴보면 다음과 같다.

1. embedding layer의 output

2. Transformer layer에 있는 hidden vector와 attention matrix

3. prediction layer의 output

 

2단계의 distillation은 다음과 같다.

1. general distillation

2. task-specific distillation(1번에서 얻은 general한 모델을 시작점으로 하여, augmentation을 시키고 그것을 fine-tuning하는 단계를 거치게 된다.)

 

Knowledge distaillation은 비유하자면, 선생님 네트워크 T의 지식을 학생 네트워크 S에게 전달하는 것이다.

즉, 학생 네트워크는 선생님 네트워크의 행동을 비슷하게 수행하게끔 학습하는 것이다.

(단, $f_T$는 선생님의 행동함수, $f_S$는 학생의 행동 함수)

 

이때 행동함수란 Transformer의 multi-head attention, FFN layer, representation layer등이 해당한다.

즉, 위 두 네트워크 사이의 차이를 작게끔 훈련되게 되는 것이다.

기본적으로 선생님에 해당하는 layer가 더 많으므로 학생 모델의 layer에 해당하는 개수만큼을 선택해주어야 한다.

예를 들어, 선생님에 해당하는 layer의 개수를 N, 학생에 해당하는 layer의 개수를 M이라고 하자.

그러면, N개 중에 M개를 선택해주어야 한다. 이후에는 각 선생님 layer와 학생 layer 사이에 매핑을 진행해주면 된다.

그 이후, 다음 손실함수를 최소화하는 방향으로 학습을 진행하게 된다.

 

Transformer layer에서는 attention과 hidden state 각각에 대해 지식을 distillation하게 되는데

첫번째로 학생 네트워크는 multi-attention의 결과 나오는 attention matrix와 유사하게끔 학습을 진행하게 된다.

두번째로 transformer layer의 output값과 유가하게끔 학습을 진행하게 된다.

embedding layer distillation의 경우 hidden state distillation과 유사하게끔 진행하게 된다.

최종 laye의 지식을 학생 모델의 distillation을 하게 되고, 이는 soft cross entropy loss를 활용하여 진행하게 된다.

위에서 나온 loss들을 결합하게 되면 다음과 같은 결과가 된다.

반응형
Contents

포스팅 주소를 복사했습니다

이 글이 도움이 되었다면 공감 부탁드립니다.