雲計算

BERT 蒸餾在垃圾輿情識別中的探索

image.png
近來 BERT等大規模預訓練模型在 NLP 領域各項子任務中取得了不凡的結果,但是模型海量參數,導致上線困難,不能滿足生產需求。輿情審核業務中包含大量的垃圾輿情,會耗費大量的人力。本文在垃圾輿情識別任務中嘗試 BERT 蒸餾技術,提升 textCNN 分類器性能,利用其小而快的優點,成功落地。

風險樣本如下:

image.png

一 傳統蒸餾方案

目前,對模型壓縮和加速的技術主要分為四種:

  • 參數剪枝和共享
  • 低秩因子分解
  • 轉移/緊湊卷積濾波器
  • 知識蒸餾

知識蒸餾就是將教師網絡的知識遷移到學生網絡上,使得學生網絡的性能表現如教師網絡一般。本文主要集中講解知識蒸餾的應用。

1 soft label

知識蒸餾最早是 2014 年 Caruana 等人提出方法。通過引入 teacher network(複雜網絡,效果好,但預測耗時久) 相關的軟標籤作為總體 loss 的一部分,來引導 student network(簡單網絡,效果稍差,但預測耗時低) 進行學習,來達到知識的遷移目的。這是一個通用而簡單的、不同的模型壓縮技術。

  • 大規模神經網絡 (teacher network)得到的類別預測包含了數據結構間的相似性。
  • 有了先驗的小規模神經網絡(student network)只需要很少的新場景數據就能夠收斂。
  • Softmax函數隨著溫度變量(temperature)的升高分佈更均勻。

Loss公式如下:

image.png

其中,

image.png

由此我們可以看出蒸餾有以下優點:

  • 學習到大模型的特徵表徵能力,也能學習到one-hot label中不存在的類別間信息。
  • 具有抗噪聲能力,如下圖,當有噪聲時,教師模型的梯度對學生模型梯度有一定的修正性。
  • 一定的程度上,加強了模型的泛化性。

image.png

2 using hints

(ICLR 2015) FitNets Romero等人的工作不僅利用教師網絡的最後輸出logits,還利用了中間隱層參數值,訓練學生網絡。獲得又深又細的FitNets。

image.png

中間層學習loss如下:

image.png

作者通過添加中間層loss的方式,通過teacher network 的參數限制student network的解空間的方式,使得參數的最優解更加靠近到teacher network,從而學習到teacher network的高階表徵,減少網絡參數的冗餘。

3 co-training

(arXiv 2019) Route Constrained Optimization (RCO) Jin和Peng等人的工作受課程學習(curriculum learning)啟發,並且知道學生和老師之間的gap很大導致蒸餾失敗,導致認知偏差,提出路由約束提示學習(Route Constrained Hint Learning),把學習路徑更改為每訓練一次teacher network,並把結果輸出給student network進行訓練。student network可以一步一步地根據這些中間模型慢慢學習,from easy-to-hard。

訓練路徑如下圖:
image.png

二 Bert2TextCNN蒸餾方案

為了提高模型的準確率,並且保障時效性,應對GPU資源緊缺,我們開始構建bert模型蒸餾至textcnn模型的方案。

方案1:離線logit textcnn 蒸餾

使用的是Caruana的傳統方法進行蒸餾。

image.png

方案2:聯合訓練 bert textcnn 蒸餾

參數隔離:teacher model 訓練一次,並把logit傳給student。teacher 的參數更新至受到label的影響,student 參數更新受到teacher loigt的soft label loss 和label 的 hard label loss 的影響。

image.png

方案3:聯合訓練 bert textcnn 蒸餾

參數不隔離: 與方案2類似,主要區別在於前一次迭代的student 的 soft label 的梯度會用於teacher參數的更新。

image.png

方案4:聯合訓練 bert textcnn loss 相加

teacher 和student 同時訓練,使用mutil-task的方式。

image.png

方案5:多teacher

大部分模型,在更新時候需要覆蓋線上歷史模型的樣本,使用線上歷史模型作為teacher,讓模型學習原有歷史模型的知識,保障對原有模型有較高的覆蓋。

image.png

實驗結果如下:

image.png

從以上的實驗,可以發現很有趣的現象。

1)方案2和方案3均使用先訓練teacher,再訓練student的方式,但是由於梯度返回更新是否隔離的差異,導致方案2低於方案3。是由於方案3中,每次訓練一次teacher,在訓練一次student,student學習完了的soft loss 會再反饋給teacher,讓teacher知道指如何導student是合適的,並且還提升了teacher的性能。

2)方案4採用共同更新的,同時反饋梯度的方式。反而textcnn 的性能迅速下降,雖然bert的性能基本沒有衰減,但是bert難以對textcnn每一步的反饋有個正確性的引導。

3)方案5中使用了歷史textcnn 的logit,主要是為了用替換線上模型時候,並保持對原有模型有較高的覆蓋率,雖然召回下降,但是整體的覆蓋率相比於單textcnn 提高了5%的召回率。

Reference

1.Dean, J. (n.d.). Distilling the Knowledge in a Neural Network. 1–9.
2.Romero A , Ballas N , Kahou S E , et al. FitNets: Hints for Thin Deep Nets[J].
3.Jin X , Peng B , Wu Y , et al. Knowledge Distillation via Route Constrained Optimization[J].

歡迎各位技術同路人加入螞蟻集團大安全機器智能團隊,我們專注於面向海量輿情藉助大數據技術和自然語言理解技術挖掘存在的金融風險、平颱風險,為用戶資金安全護航、提高用戶在螞蟻生態下的用戶體驗。內推直達 [email protected],有信必回。

Leave a Reply

Your email address will not be published. Required fields are marked *