多分类交叉熵损失函数

理解交叉熵

关于样本集的两个概率分布\(p\)\(q\),设\(p\)为真实的分布,比如\([1,0,0]\)表示当前样本属于第一类,\(q\)为拟合的分布,比如\([0.7, 0.2, 0.1]\)

按照真实分布\(p\)来衡量识别一个样本所需的编码长度的期望,即平均编码长度(信息熵):

\[H(p)=-\sum^C_{i=1} p(x_i) \log (p(x_i)) \]

如果使用拟合分布\(q\)来表示来自真实分布\(p\)的编码长度的期望,即平均编码长度(交叉熵):

直观上,用\(p\)来描述样本是最完美的,用\(q\)描述样本就不那么完美,根据吉布斯不等式,\(H(p,q) \geq H(p)\)恒成立,当\(q\)为真实分布时取等,我们将由\(q\)得到的平均编码长度比由\(p\)得到的平均编码长度多出的bit数称为相对熵,也叫KL散度:

\[D(p||q) = H(p, q) - H(p) = \sum^C_{i=1} p(x_i) \log (\frac{p(x_i)}{q(x_i)}) \]

在机器学习的分类问题中,我们希望缩小模型预测和标签之间的差距,即KL散度越小越好,在这里由于KL散度中的 \(H(p)\) 项不变(在其他问题中未必),故在优化过程中只需要关注交叉熵就可以了,因此一般使用交叉熵作为损失函数。

多分类任务中的交叉熵损失函数

\[Loss = -\sum^{C-1}_{i=0}y_i \log(p_i) = -\log(p_c) \]

其中 \(p=[p_0, \dots ,p_{C−1}]\) 是一个概率分布,每个元素 pip_ip_i 表示样本属于第i类的概率; \(y=[y_0, \dots ,y_{C−1}]\) 是样本标签的onehot表示,当样本属于第类别\(i\)\(y_i=1\),否则\(y_i=0\); \(c\)是样本标签。

PyTorch中的交叉熵损失函数实现

PyTorch提供了两个类来计算交叉熵,分别是CrossEntropyLoss() 和NLLLoss()。

  • torch.nn.CrossEntropyLoss()

类定义如下

torch.nn.CrossEntropyLoss(
    weight=None,
    ignore_index=-100,
    reduction="mean",
)

z=[z_0, \dots, z_{C-1}]表示一个样本的非softmax输出,c表示该样本的标签,则损失函数公式描述如下,

\[loss(z,c) = -\log( \frac{\exp(z[c])}{\sum^{C-1}_{j=0} \exp(z[j])}) = -z[c] + \log(\sum^{C-1}_{j=0} \exp(z[j])) \]

如果weight被指定,

\[loss(z,c) = w \cdot(-z[c] + \log(\sum^{C-1}_{j=0} \exp(z[j]))) \]

其中,\(w=weight[c] \cdot 1\{c \neq ignore\_index\}\)

import torch
import torch.nn as nn

model = nn.Linear(10, 3)
criterion = nn.CrossEntropyLoss()

x = torch.randn(16, 10)
y = torch.randint(0, 3, size=(16,))  # (16, )
logits = model(x)  # (16, 3)

loss = criterion(logits, y)
  • torch.nn.NLLLoss()

类定义如下

torch.nn.NLLLoss(
    weight=None,
    ignore_index=-100,
    reduction="mean",
)

a=[a_0, \dots, a_{C-1}]表示一个样本对每个类别的对数似然(log-probabilities),\(c\) 表示该样本的标签,损失函数公式描述如下,

\[loss(a,c) = -w \cdot a[c] = -w \cdot \log(p_c) \]

其中, \(w=weight[c] \cdot 1\{c \neq ignore\_index\}\)

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(10, 3),
    nn.LogSoftmax()
)
criterion = nn.NLLLoss()

x = torch.randn(16, 10)
y = torch.randint(0, 3, size=(16,))  # (16, )
out = model(x)  # (16, 3)

loss = criterion(out, y)

总结

torch.nn.CrossEntropyLoss在一个类中组合了nn.LogSoftmax和nn.NLLLoss,

posted @ 2024-10-23 20:28  X1OO  阅读(147)  评论(0)    收藏  举报