Sharpness-Aware Minimization for Efficiently Improving Generalization

Foret P., Kleiner A., Mobahi H., Neyshabur B. Sharpness-aware minimization for efficiently improving generalization. In International Conference on Learning Representations.

在训练的时候对权重加扰动能增强泛化性.

主要内容

如上图所示, 一般的训练方法虽然能够收敛到一个不错的局部最优点, 但是往往这个局部最优点附近是非常不光滑的, 即对权重\(w\)添加微小的扰动\(w+\epsilon\) 可能就会导致不好的结果, 作者认为这与模型的泛化性有很大关系(实际上已有别的文章提出这一观点).

作者给出如下的理论分析:

在满足一定条件下有

\[L_{\mathscr{D}} (w) \le \max_{\|\epsilon \|_2 \le \rho} L_{\mathcal{S}} (w + \epsilon) + h(\|w\|_2^2/\rho^2). \]

其中\(h\)是一个严格单调递增函数, \(L_{\mathcal{S}}\)是在训练集\(\mathcal{S}\)上的损失,

\[L_{\mathscr{D}}(w) = \mathbb{E}_{(x, y) \sim \mathscr{D}} [l(x, y;w)]. \]

如果把\(h(\|w\|_2^2/\rho^2)\)看成\(\lambda \|w\|_2^2\)(即常用的weight decay), 我们的目标函数可以认为是

\[\min_w L_{\mathcal{S}}^{SAM} (w) + \lambda \|w\|_2^2, \]

\[L_{\mathcal{S}}^{SAM}(w) := \max_{\|\epsilon \|_p \le \rho} L_{\mathcal{S}} (w + \epsilon), \]

注: 这里\(\|\cdot \|_p\)而并不仅限于\(\|\cdot \|_2\).

采用近似的方法求解上面的问题(就和对抗样本一样):

\[\epsilon^* (w) := \mathop{\arg \max} \limits_{\|\epsilon\|_p\le \rho} L_{\mathcal{S}}(w + \epsilon) \approx \mathop{\arg \max} \limits_{\|\epsilon\|_p\le \rho} L_{\mathcal{S}}(w) + \epsilon^T \nabla_w L_{\mathcal{S}}(w) = \mathop{\arg \max} \limits_{\|\epsilon\|_p\le \rho} \epsilon^T \nabla_w L_{\mathcal{S}}(w). \]

就是一个对偶范数的问题.

虽然\(\epsilon^*(w)\)实际上是和\(w\)有关的, 但是在实际中只是当初普通的量带入, 这样就不用计算二阶导数了, 即

\[\nabla_w L_{\mathcal{S}}^{SAM}(w) \approx \nabla_w L_{\mathcal{S}}(w) |_{w + \hat{\epsilon}(w)}. \]

实验结果非常好, 不仅能够提高普通的正确率, 在标签受到污染的情况下也能有很好的鲁棒性.

代码

原文代码

posted @ 2021-06-30 17:17  馒头and花卷  阅读(554)  评论(0)    收藏  举报