Distilling the Knowledge in a Neural Network
概
\[q_1 = \frac{\exp(z_i/T)}{\sum_j \exp(z_j/T)}.
\]
主要内容
这篇文章或许重点是在迁移学习上, 一个重点就是其认为soft labels (即概率向量)比hard target (one-hot向量)含有更多的信息. 比如, 数字模型判别数字\(2\)为\(3\)和\(7\)的概率分别是0.1, 0.01, 这说明这个数字\(2\)很有可能和\(3\)长的比较像, 这是one-hot无法带来的信息.
于是乎, 现在的情况是:
-
以及有一个训练好的且往往效果比较好但是计量大的模型\(t\);
-
我们打算用一个小的模型\(s\)去近似这个已有的模型;
-
策略是每个样本\(x\), 先根据\(t(x)\)获得soft logits \(z \in \mathbb{R}^K\), 其中\(K\)是类别数, 且\(z\)未经softmax.
-
最后我们希望根据下面的损失函数来训练\(s\):
\[\mathcal{L(x, y)} = T^2 \cdot \mathcal{L}_{soft}(x, y) + \lambda \cdot\mathcal{L}_{hard}(x, y) \]
其中
\[\mathcal{L}_{soft}(x, y) = -\sum_{i=1}^K p_i(x) \log q_i (x) = -\sum_{i=1}^K
\frac{\exp(v_i(x)/T)}{\sum_j \exp(v_j(x)/T)}
\log \frac{\exp(z_i(x)/T)}{\sum_j \exp(z_j(x)/T)}
\]
\[\mathcal{L}_{hard}(x, y) = -\log
\frac{\exp(z_y(x))}{\sum_j \exp(z_j(x))}
\]
至于\(T^2\)是怎么来的, 这是为了配平梯度的magnitude.
\[\begin{array}{ll}
\frac{\partial \mathcal{L}_{soft}}{\partial z_k}
&= -\sum_{i=1}^K \frac{p_i}{q_i} \frac{\partial q_i}{\partial z_k}
= -\frac{1}{T}p_k -\sum_{i=1}^K \frac{p_i}{q_i} \cdot (-\frac{1}{T}q_i q_k) \\
&= -\frac{1}{T} (p_k -\sum_{i=1}^K p_iq_k) = \frac{1}{T}(q_k-p_k) \\
&= \frac{1}{T} (\frac{e^{z_i/T}}{\sum_j e^{z_j/T}} - \frac{e^{v_i/T}}{\sum_j e^{v_j/T}}) .
\end{array}
\]
当\(T\)足够大的时候, 并假设\(\sum_j z_j=0 = \sum_j v_j =0\), 有
\[\frac{\partial \mathcal{L}_{soft}}{\partial z_k} \approx \frac{1}{KT^2} (z_k - v_k).
\]
故需要加个\(T^2\)取抵消这部分的影响.
代码
其实一直很好奇的一点是这部分代码在pytorch里是怎么实现的, 毕竟pytorch里的交叉熵是
\[-\log p_y(x)
\]
另外很恶心的一点是, 我看大家都用的是 KLDivLOSS, 但是其实现居然是:
\[\mathcal{L}(x, y) = y \cdot \log y - y \cdot x,
\]
注: 这里的\(\cdot\)是逐项的.
def kl_div(x, y):
return y * (torch.log(y) - x)
x = torch.randn(2, 3)
y = torch.randn(2, 3).abs() + 1
loss1 = F.kl_div(x, y, reduction="none")
loss2 = kl_div(x, y)
这时, 出来的结果长这样
tensor([[-1.5965, 2.2040, -0.8753],
[ 3.9795, 0.0910, 1.0761]])
tensor([[-1.5965, 2.2040, -0.8753],
[ 3.9795, 0.0910, 1.0761]])
又或者:
def kl_div(x, y):
return (y * (torch.log(y) - x)).sum(dim=1).mean()
torch.manual_seed(10086)
x = torch.randn(2, 3)
y = torch.randn(2, 3).abs() + 1
loss1 = F.kl_div(x, y, reduction="batchmean")
loss2 = kl_div(x, y)
print(loss1)
print(loss2)
tensor(2.4394)
tensor(2.4394)
所以如果真要弄, 应该要
def soft_loss(z, v, T=10.):
# z: logits
# v: targets
z = F.log_softmax(z / T, dim=1)
v = F.softmax(v / T, dim=1)
return F.kl_div(z, v, reduction="batchmean")

浙公网安备 33010602011771号