Distilling the Knowledge in a Neural Network

 

url: https://arxiv.org/abs/1503.02531
year: NIPS 2014

69DCE989-A8B0-423F-A36D-877BCDDD231F

简介

将大模型的泛化能力转移到小模型的一种显而易见的方法是使用由大模型产生的类概率作为训练小模型的“软目标”

其中, T(temperature, 蒸馏温度), 通常设置为1的。使用较高的T值可以产生更软的类别概率分布。 也就是, 较高的 T 值, 让学生的概率分布可以更加的接近与老师的概率分布,

下面通过一个直观的例子来感受下

def softmax_with_T(logits, temperature):

    for t in temperature:
        total = 0
        prob = []
        for logit in logits:
            total += np.exp(logit/t)
        for logit in logits:
            prob.append(np.exp(logit/t) / total)
        print('T={:<4d}'.format(t), end='  ')
        for p in prob:
            print('{:0.3f}'.format(p), end='  ')
        print()

可以看出, softmax 输出的项比例与 logits原始比例之间的关系与 logits 本身的模长以及 T 值大小相关, 感觉 T 值需要仔细调整下, 至少能反应 logits 之间的大致关系, 而且可以看出, softmax_with_T 受两个变量的影响, 直接来比较的话, 比较难分析. 当 T 远大于 logits 的模长时, softmax 的输出尺度在相同的数量级下(如logits=[6,3,1], T=25), 这样看的话, 即使老师和学生的 logit 相差很远, 经过具有很大 T 的 softamx 之后, 数量级几乎相同, 这样是不合理的. 但是, 下面的公式推导结果加上实验结果表明, 认真看梯度才是王道, 看输出的话, 完全找不到感觉, 对于软标签交叉熵损失

梯度推导

softmax+cross entropy梯度求导

 

Czi=1T(qipi)=1T(ezi/Tjezj/Tevi/Tjevj/T)∂C∂zi=1T(qi−pi)=1T(ezi/T∑jezj/T−evi/T∑jevj/T)
<script id="MathJax-Element-1" type="math/tex; mode=display">// </script>

 

exex <script id="MathJax-Element-2" type="math/tex">// </script> 泰勒展开

 

ex1+x+x22!+x33!++xnn!x0,ex1+xex≈1+x+x22!+x33!+⋯+xnn!x→0,ex≈1+x
<script id="MathJax-Element-3" type="math/tex; mode=display">// </script>

 

TT→∞ <script id="MathJax-Element-4" type="math/tex">// </script> 时, ZiT0ZiT→0 <script id="MathJax-Element-5" type="math/tex">// </script>

Czi1T(1+zi/TN+zj/T1+vi/TN+vj/T)∂C∂zi≈1T(1+zi/TN+∑zj/T−1+vi/TN+∑vj/T)
<script id="MathJax-Element-6" type="math/tex; mode=display">// </script>

 

假设logits已经单独进行了zero-center中心化处理,那么,

 

jzj=jvj=0Czi1NT2(zivi)∑jzj=∑jvj=0⇓∂C∂zi≈1NT2(zi−vi)
<script id="MathJax-Element-7" type="math/tex; mode=display">// </script>

 

这样的话, 当T值最够大, 方法就变为求老师和学生的 logits 的 L2 距离了.

术语说明
qsoftqsoft <script id="MathJax-Element-8" type="math/tex">// </script> 老师模型的 softmax 输出软标签
qhardqhard <script id="MathJax-Element-9" type="math/tex">// </script> 训练集 one-hot 硬标签
psoftpsoft <script id="MathJax-Element-10" type="math/tex">// </script> 学生模型的 softmax 输出软标签
phardphard <script id="MathJax-Element-11" type="math/tex">// </script> 学生模型的 softmax 输出硬标签(T=1)

 

loss_cross_entpopy=αT2qsoftln(psoft)+(1α)qhardln(phard)loss_cross_entpopy=α⋅T2⋅qsoft⋅ln⁡(psoft)+(1−α)⋅qhard⋅ln⁡(phard)
<script id="MathJax-Element-12" type="math/tex; mode=display">// </script>

 

论文中发现通常给予硬标签损失函数 可忽略不计的较低权重 <script id="MathJax-Element-13" type="math/tex">// </script> 可以获得最佳结果。 由于软目标产生的梯度的大小为 1T21T2 <script id="MathJax-Element-14" type="math/tex">// </script> ,因此当使用硬目标和软目标时,将它们乘以 T2T2 <script id="MathJax-Element-15" type="math/tex">// </script> 是很重要的, 这确保软硬标签对梯度相对贡献在一个数量级。

实验结果

思考

软标签交叉熵函数与 KL 散度的联系
5FEBBF21-BEAC-4102-AC36-6A4FDE89D5E9
86CD9883-393B-48BA-89A7-6BFB9CD7A787

上式中, 由于 p 为老师的预测结果, 模型蒸馏时候, 老师模型被冻结, 从梯度反传来看, 软标签交叉熵函数 等价于 KL 散度.

对于我而言, 这篇论文相对于 Do Deep Nets Really Need to be Deep? 贡献就在于, 将 L2距离 和 KL 散度统一到一个公式中了, 由于到 T 足够大, KL 散度的梯度与 L2 距离的一样. 这篇论文中其他部分没有读懂, 没有看到其他想要的东西. 后面知识积累了有机会在看看有没有新感受吧.

蒸馏入门的话, 推荐 Do Deep Nets Really Need to be Deep? 这篇论文. 从实验分析来说, 各种分析都很到位, 分析的方式也是易读的, 容易理解. 就工程效果来看, 实际上Distilling the Knowledge in a Neural Network 这篇论文有效时候, T一般都挺大的, 那么KL 散度的实际的效果就是 L2 距离, 不如直接用 L2 距离, 理解上简单, 调节超参少, 效果也非常好.

posted on 2019-09-23 15:47  Hello_zhengXinTang  阅读(215)  评论(0)    收藏  举报