Typesetting math: 2%

softmax求导的过程


(图出自李宏毅老师的PPT

对机器学习/深度学习有所了解的同学肯定不会对 softmax 陌生,它时而出现在多分类中用于得到每个类别的概率,时而出现在二分类中用于得到正样本的概率(当然,这个时候 softmaxsigmoid 的形式出现)。

1. 从 sigmoidsoftmax

sigmoid 出现的频率在机器学习/深度学习中不可谓不高,从 logistic 回归到深度学习中的激活函数。先来看一下它的形式:

sigmoid(z)=σ(z)=11+ez,zRsigmoid(z)=σ(z)=11+ez,zR

我们把它的图像画一下:

import numpy as np
imort matplotlib.pyplot as plt

x = np.arange(-5, -5, 0.05)
z = x / (1 + np.exp(-x))

plt.plot(x, z)
plt.show()

sigmoid
显然,sigmoid 将实数域的值压缩到了 (0,1) 之间。那么在二分类中又是怎么用的呢?以 logistic 回归为例。在逻辑回归中通常假设样本的类别呈现为伯努利分布,即:

P(y=1)=p1P(y=0)=p0

且有 p1+p0=1。用逻辑回归解决二分类问题时,我们用建模的时样本为正样本的概率(根据伯努利分布,为负样本的概率显而易见):将逻辑回归得到的回归值输入到 sigmoid 中就得到了 p1,即:

z=f(z)p1=11+ez

其中 f 即是一个回归函数。那么 sigmoid 又与 softmax 有什么关系呢?先来看一下 softmax 的定义:

softmax(z)=ezniezi,zRn

显然,softmax 是一个向量的函数,正如本文开头的图示一样。softmax 将一个向量中的值进行了归一化,我们在多分类中常将其视作样本属于不同类别的概率值。我们将设有一个二分类模型,其输出是一个向量 zR2,其中 z0,z1 分别是样本属于类别 0, 1 的未归一化的值。如果我们用 softmax 得到类别概率,则:

p0=ez0ez0+ez1p1=ez1ez0+ez1

p0,p1 即是样本分别属于 0, 1 的概率,我们对其进行以下变形:

p1=ez1ez0+ez1=ez1ez0(ez0z0+ez1z0)=ez1z0ez0z0+ez1z0=eΔe0+eΔ=eΔ1+eΔ=11+eΔ

其中 Δ=z1z0,同理可得 p0=1eΔ+1。回想到逻辑回归是对样本属于 1 概率进行建模,那么 Δ 就是逻辑回归进入 sigmoid 之前的预测值,这里我们来看看 Δ 到底是什么:

p1p0=1eΔ=eΔlogp1p0=Δ

由上可知,逻辑回归是对样本为 1 和 为 0 的概率的比值(几率)的对数(对数几率)进行回归。所以说逻辑回归也是一种回归,只不过回归的是样本的对数几率,得到了对数几率后再来得到样本属于类别 1 的概率。

再说回 sigmoidsoftmax 的关系,其实从上面的的世子我们也看出来了其实 sigmoid 只是 softmax 的一种情况,sigmoid 隐式地包含了另一个元素地 softmax 值。在对分类任务进行建模时,我们通常将二分类任务中的一个类别进行建模,假设其服从伯努利分布;或者建模为二项分布,分别建模样本属于每个类别地概率(即 z 中的每一位表示样本为对应类别的对数几率,softmax(z) 中的每一位表示样本为对应类别的概率)。

2. softmax 损失的求导

在多分类任务中,我们通常使用对数损失(在二分类中就是交叉熵损失):

L=1NNi=1Cj=1yijlogˆyij

其中,N,C 分别为样本数和类别数,yij{0,1} 表示样本 xi 是否属于类别 jˆyij 表示对应的预测概率。在多分类中,概率值通常通过 softmax 获得,即:

ˆyij=softmax(zi)j=ezijCk=1ezik

这里我们只考虑一个样本的损失,即:

l=Cj=1yjlogˆyjˆyj=softmax(z)j=ezjCk=1ezk

好了,开始重头戏,求多分类对数损失对 zk 的偏导:

lzk=zk(Cj=1yjlogˆyj)=Cj=1yjlogˆyjzk=Cj=1yj1ˆyjˆyjzk

到这里其实以及很简单了,只需要算出 ˆyjzk 就行了,那就来吧:

ˆyjzk=zk(ezjCc=1ezc)=ezjzkezjzk()2=ezjzkezjezk()2

其中 =Cc=1ezc,其中 ezjzk 需要分情况讨论一下:

ezjzk={ezjk=j0kj

因此,

ˆyjzk={ezj  (ezj)2()2k=j0ezjezk()2kj

看起来有点复杂,我们来化简以下:

ˆyjzk={ˆyj(1ˆyj)k=jˆyjˆykkj

收工了吗?不!我们的目的是求 lzk

lzk=[yk1ˆykˆykzk+jkyj1ˆyjˆyjzk]=[yk1ˆykˆyk(1ˆyk)+jkyj1ˆyj(ˆyjˆyk)]=[yk(1ˆyk)jkyjˆyk]=[ykykˆykjkyjˆyk]=[ykjyjˆyk]=jyjˆykyk=ˆykjyjyk=ˆykyk

别看求起来有一点复杂,但是最后的结果还是很优雅的嘛:预测值 - 真实值

这里有几个要注意的点:

  • 上式第二部中依据 j 是否等于 k 将求和分成了两部分;
  • 上式的倒数第二步中,利用了多分类的目标中只有一个为 1,即 jyj=1.

以上!

posted @ 2022-02-25 21:21  Milkha  阅读(1370)  评论(0)    收藏  举报
编辑推荐:
· 35+程序员的转型之路:经济寒冬中的希望与策略
· JavaScript中如何遍历对象?
· 领域模型应用
· 记一次 ADL 导致的 C++ 代码编译错误
· MySQL查询执行顺序:一张图看懂SQL是如何工作的
阅读排行:
· 35+程序员的转型之路:经济寒冬中的希望与策略
· 全球首位 AI 程序员 Devin 诞生了,对于程序员的影响到底多大?
· 使用 OpenAuth.Net 快速搭建 .NET 企业级权限工作流系统
· 做这么个免费在线拼图工具,如何赚钱呢
· 我在厂里搞wine的日子
点击右上角即可分享
微信分享提示