多分类交叉熵损失函数
理解交叉熵
关于样本集的两个概率分布\(p\)和\(q\),设\(p\)为真实的分布,比如\([1,0,0]\)表示当前样本属于第一类,\(q\)为拟合的分布,比如\([0.7, 0.2, 0.1]\)。
按照真实分布\(p\)来衡量识别一个样本所需的编码长度的期望,即平均编码长度(信息熵):
如果使用拟合分布\(q\)来表示来自真实分布\(p\)的编码长度的期望,即平均编码长度(交叉熵):
直观上,用\(p\)来描述样本是最完美的,用\(q\)描述样本就不那么完美,根据吉布斯不等式,\(H(p,q) \geq H(p)\)恒成立,当\(q\)为真实分布时取等,我们将由\(q\)得到的平均编码长度比由\(p\)得到的平均编码长度多出的bit数称为相对熵,也叫KL散度:
在机器学习的分类问题中,我们希望缩小模型预测和标签之间的差距,即KL散度越小越好,在这里由于KL散度中的 \(H(p)\) 项不变(在其他问题中未必),故在优化过程中只需要关注交叉熵就可以了,因此一般使用交叉熵作为损失函数。
多分类任务中的交叉熵损失函数
其中 \(p=[p_0, \dots ,p_{C−1}]\) 是一个概率分布,每个元素 pip_ip_i 表示样本属于第i类的概率; \(y=[y_0, \dots ,y_{C−1}]\) 是样本标签的onehot表示,当样本属于第类别\(i\)时\(y_i=1\),否则\(y_i=0\); \(c\)是样本标签。
PyTorch中的交叉熵损失函数实现
PyTorch提供了两个类来计算交叉熵,分别是CrossEntropyLoss() 和NLLLoss()。
- torch.nn.CrossEntropyLoss()
类定义如下
torch.nn.CrossEntropyLoss(
weight=None,
ignore_index=-100,
reduction="mean",
)
z=[z_0, \dots, z_{C-1}]
表示一个样本的非softmax输出,c表示该样本的标签,则损失函数公式描述如下,
如果weight被指定,
其中,\(w=weight[c] \cdot 1\{c \neq ignore\_index\}\)
import torch
import torch.nn as nn
model = nn.Linear(10, 3)
criterion = nn.CrossEntropyLoss()
x = torch.randn(16, 10)
y = torch.randint(0, 3, size=(16,)) # (16, )
logits = model(x) # (16, 3)
loss = criterion(logits, y)
- torch.nn.NLLLoss()
类定义如下
torch.nn.NLLLoss(
weight=None,
ignore_index=-100,
reduction="mean",
)
a=[a_0, \dots, a_{C-1}]
表示一个样本对每个类别的对数似然(log-probabilities),\(c\) 表示该样本的标签,损失函数公式描述如下,
其中, \(w=weight[c] \cdot 1\{c \neq ignore\_index\}\)
import torch
import torch.nn as nn
model = nn.Sequential(
nn.Linear(10, 3),
nn.LogSoftmax()
)
criterion = nn.NLLLoss()
x = torch.randn(16, 10)
y = torch.randint(0, 3, size=(16,)) # (16, )
out = model(x) # (16, 3)
loss = criterion(out, y)
总结
torch.nn.CrossEntropyLoss在一个类中组合了nn.LogSoftmax和nn.NLLLoss,