Pytorch分类问题中的交叉熵损失函数使用

本文主要介绍一下分类问题中损失函数的使用,对于二分类、多分类、多标签这个三个不同的场景,在 Pytorch 中的损失函数使用稍有区别。


 

 

 

损失函数

Softmax

在介绍损失函数前,先介绍一下什么是 Softmax,通常在分类问题中会将 Softmax 搭配 Cross Entropy 一同使用。Softmax 的计算公式定义如下:

$$\mathtt{softmax(x_i)={exp(x_i) \over {\sum_{j} exp(x_j)}}}$$

例如,我们现在有一个数组 [1, 2, 3],这三个数的 Softmax 输出是:

$$\mathtt{softmax(1)={exp(1) \over exp(1)+exp(2)+exp(3)}=0.09}$$

$$\mathtt{softmax(2)={exp(2) \over exp(1)+exp(2)+exp(3)}=0.2447}$$

$$\mathtt{softmax(3)={exp(3) \over exp(3)+exp(2)+exp(3)}=0.6652}$$

所以 Softmax 直白来说就是将原来输出是 [1, 2, 3] 的数组,通过 Softmax 函数作用后映射成范围为 (0, 1) 的值,而这些值的累积和为 1(满足概率性质),那么我们就可以将它理解成概率,在最后选取输出结点的时候,我们就可以选取概率最大(也就是值对应最大的)结点,作为我们的预测目标。

 

Cross Entropy

对于 Cross Entropy,以下是我见过最喜欢的一个解释:

 

在机器学习中,P 往往用来表示样本的真实分布,比如 [1, 0, 0] 表示当前样本属于第一类;Q 往往用来表示模型所预测的分布,比如 [0.7, 0.2, 0.1]。这里直观的理解就是,如果用 P 来描述样本,那就非常完美,而用 Q 来描述样本,虽然可以大致描述,但是不是那么的完美,信息量不足,需要额外的一些信息增量才能达到和 P 一样完美的描述。如果我们的 Q 通过反复训练,也能趋近完美地描述样本,那么就不再需要额外的信息增量,这时 Q 等价于 P。

 

所以,如果按照真实分布 P 来衡量描述一个样本所需的编码长度的期望,即平均编码长度,或称信息熵:

$$\mathtt{H(p) = -\sum_{i=1}^C p(x_i)log(p(x_i))}$$

如果使用拟合分布 Q 来表示来自真实分布 P 的编码长度的期望,即平均编码长度,或称交叉熵:

$$\mathtt{H(p,q)=-\sum_{i=1}^C p(x_i)log(q(x_i))}$$

所以 H(p, q) >= H(q) 恒成立。我们把 Q 得到的平均编码长度比 P 得到的平均编码长度多出的 bit 数称为“相对熵”,也叫“KL散度”,用来衡量两个分布的差异:

$$\mathtt{D_{KL}(p||q)=H(p,q)-H(p)=\sum_{i=1}^C p(x_i)log({p(x_i) \over q(x_i)})}$$

在机器学习的分类问题中,我们希望通过训练来缩小模型预测和标签之间的差距,即“KL散度”越小越好,根据上面公式,“KL散度”中 H(p) 项不变,所以在优化过程中我们只需要关注“交叉熵”即可,这就是我们使用“交叉熵”作为损失函数的原因。

 

交叉熵如何计算?

在做完 Softmax 之后,再计算交叉熵,作为损失函数:

$$\mathtt{L(\widehat{y}, y) =- \sum_{i=1}^C y_i log(\widehat{y}\_{i})}$$

这里的 $\mathtt{\widehat{y}}$ 指的是预测值(Softmax 层的输出)。$\mathtt{y}$ 指的是真实值,是一个 one-hot 编码后的 C 维向量。什么是 One-hot Encoding?如果一个样例 x 是类别 i,则它标签 y 的第 i 维的值为 1,其余维的值为 0。例如,x 是类别 2,共 4 类,则其标签 y 的 值为 [0, 1, 0, 0]。

 

多分类和多标签:

1. 二分类:

表示分类任务中有两个类别。

[Pytorch Example]

2. 多分类:

一个样本属于且只属于多个类中的一个,一个样本只能属于一个类,不同类之间是互斥的。对于多分类问题,我们通常有几种处理方法:1)直接多分类;2)one vs one;3)one vs rest 训练 N 个分类器,测试的时候若仅有一个分类器预测为正的类别则对应的类别标记作为最终分类结果,若有多个分类器预测为正类,则选择置信度最大的类别作为最终分类结果。

3. 多标签:

一个样本可以属于多个类别(或标签),不同类之间是有关联的。

 

激活函数和损失函数搭配选择:

Problem Actv Func Loss Func
binary sigmoid BCE
Multiclass softmax CE
Multilabel sigmoid BCE

 



 

 

Pytorch 中的交叉熵损失函数使用

以下详细介绍在 Pytorch 如何使用交叉熵作为损失函数。

 

Pytorch 中有各种交叉熵的实现,总的来说包含两类:nn.xxx 和 F.xxx 两种。nn.xxx 是包装好的类,而 F.xxx 是可以直接调用的函数,其实 nn.xxx 里面就是包装的 F.xxx:

torch.nn (nn) torch.nn.functional (F)
nn.CrossEntropyLoss() F.cross_entropy()
nn.LogSoftmax()+nn.NLLLoss() F.log_softmax()+F.nll_loss()
nn.BCELoss() F.binary_cross_entropy()
nn.BCEWithLogitsLoss() F.binary_cross_entropy_with_logits()

参数说明:

1. weight:

  • CE 和 BCE 系列都有此参数,用于为每个类别的 loss 设置权重,常用于类别不均衡问题;
  • weight 必须是 float 类型的 1D tensor,长度和类别长度一致:weight = torch.from_numpy(np.array([0.6, 0.2, 0.2])).float().to(device)
  • 注意:weight 加起来未必一定要等于 1,类 c 对应的 weight 为 W_c = (N-N_c) / N,数目越多的类,weight 越小,weight 越大,此类得到的 loss 被放大;

2. ignore_index:

  • 其中 BCE 系列没有此参数,此参数用于指定忽略某些类别的 loss;

3. size_average:

  • 该参数指定 loss 是否在一个 batch 内平均,即是否除以 N,目前此参数已经被弃用

4. reduce:

  • 目前此参数已经被弃用

5. reduction:

  • 此参数在新版本中是为了取代 ”size_average“ 和 "reduce" 参数的;
  • mean (default):返回 N 个 loss 的平均值;
  • sum:返回 N 个 loss 的 sum;
  • None:直接返回一个 batch 中的 N 个 loss;

6. pos_weight:

  • 只有 BCEWithLogits 系列有次参数;
  • 与 weight 参数的区别是:WIP;

 

Binary Case

  • BCEWithLogits:无需手动做 Sigmoid (作用是将 pd 缩放到 0~1)
    • 图像分类:
      nn. & F. Shape Type

      criterion = nn.BCEWithLogitsLoss()
      loss = criterion(pd, gt)

      pd: [N]
      gt: [N]

      float32
      float32

      loss = F.binary_cross_entropy_with_logits(pd, gt)

      pd: [N]
      gt: [N]

      float32
      float32

    •  语义分割:
      nn. & F. Shape Type

      criterion = nn.BCEWithLogitsLoss()
      loss = criterion (pd, gt)

      pd: [N, 1, H, W]
      gt: [N, 1, H, W]

      float32
      float32

      loss = F.binary_cross_entropy_with_logits(pd, gt)

      pd: [N, 1, H, W]
      gt: [N, 1, H, W]

      float32
      float32

    • 代码:
      import torch
      import torch.nn as nn
      import torch.nn.functional as F
      
      
      # Binary Case Image Classification: BCEWithLogits vs Sigmoid + BCE
      print("Binary Case Image Classification: BCEWithLogits vs Sigmoid + BCE")
      N = 16
      C_in = 3
      H = 5
      W = 5
      C_out = 1
      
      input = torch.randn(N, C_in, H, W)
      cnn = nn.Conv2d(C_in, C_out, kernel_size=(3, 3), padding=1)
      pd = cnn(input)
      pd = pd.view(pd.size(0), C_out*H*W)
      fc = nn.Linear(C_out*H*W, C_out)
      pd = fc(pd)
      gt = torch.empty(N, C_out, dtype=torch.float).random_(0, 2)
      print("input shape: {} / input type: {}".format(input.shape, input.dtype))
      print("pd shape: {} / pd type: {}".format(pd.shape, pd.dtype))
      print("gt shape: {} / gt type: {}".format(gt.shape, gt.dtype))
      
      loss_func_BCEWithLogits = nn.BCEWithLogitsLoss()
      loss_BCEWithLogits_nn = loss_func_BCEWithLogits(pd, gt)
      loss_BCEWithLogits_F = F.binary_cross_entropy_with_logits(pd, gt)
      print("BCEWithLogits: nn({}) / F({})".format(loss_BCEWithLogits_nn, loss_BCEWithLogits_F))
      
      m = nn.Sigmoid()
      loss_func_BCE = nn.BCELoss()
      loss_BCE_nn = loss_func_BCE(m(pd), gt)
      loss_BCE_F = F.binary_cross_entropy(m(pd), gt)
      print("Sigmoid + BCE: nn({}) / F({})".format(loss_BCE_nn, loss_BCE_F))
      print("-----------------------------------------------------------------")
  • Sigmoid + BCE:
    • 图像分类:
      nn. & F. Shape Type

      act = nn.Sigmoid()
      criterion = nn.BCELoss()
      loss = criterion(act(pd), gt)

      pd: [N]
      gt: [N]

      float32
      float32

      loss = F.binary_cross_entropy(torch.sigmoid(pd), gt)

      pd: [N]
      gt: [N]

      float32
      float32

    • 语义分割:
      nn. & F. Shape Type

      act = nn.Sigmoid()
      criterion = nn.BCELoss()
      loss = criterion(act(pd), gt)

      pd: [N, 1, H, W]
      gt: [N, 1, H, W]

      float32
      float32

      loss = F.binary_cross_entropy(torch.sigmoid(pd), gt)

      pd: [N, 1, H, W]
      gt: [N, 1, H, W]

      float32
      float32

    • 代码:
      import torch
      import torch.nn as nn
      import torch.nn.functional as F
      
      
      
      # Binary Case Image Segmentation: BCEWithLogits vs Sigmoid + BCE
      print("Binary Case Image Segmentation: BCEWithLogits vs Sigmoid + BCE")
      N = 16
      C_in = 3
      H = 5
      W = 5
      C_out = 1
      
      input = torch.randn(N, C_in, H, W)
      cnn = nn.Conv2d(C_in, C_out, kernel_size=(3, 3), padding=1)
      pd = cnn(input)
      gt = torch.empty(N, C_out, H, W, dtype=torch.float).random_(0, 2)
      print("input shape: {} / input type: {}".format(input.shape, input.dtype))
      print("pd shape: {} / pd type: {}".format(pd.shape, pd.dtype))
      print("gt shape: {} / gt type: {}".format(gt.shape, gt.dtype))
      
      loss_func_BCEWithLogits = nn.BCEWithLogitsLoss()
      loss_BCEWithLogits_nn = loss_func_BCEWithLogits(pd, gt)
      loss_BCEWithLogits_F = F.binary_cross_entropy_with_logits(pd, gt)
      print("BCEWithLogits: nn({}) / F({})".format(loss_BCEWithLogits_nn, loss_BCEWithLogits_F))
      
      m = nn.Sigmoid()
      loss_func_BCE = nn.BCELoss()
      loss_BCE_nn = loss_func_BCE(m(pd), gt)
      loss_BCE_F = F.binary_cross_entropy(torch.sigmoid(pd), gt)
      print("Sigmoid + BCE: nn({}) / F({})".format(loss_BCE_nn, loss_BCE_F))
      print("-----------------------------------------------------------------")

       

       

 

 

Multi-Class Case

  • CE:无需手动做 Softmax
    • 图像分类:
      nn. & F. Shape Type

      criterion = nn.CrossEntropyLoss()
      loss = criterion(pd, gt)

      pd: [N, C]
      gt: [N]

      float32
      long

      loss = F.cross_entropy(pd, gt)

      pd: [N, C]
      gt: [N]

      float32
      long

    • 语义分割:
      nn. & F. Shape Type

      criterion = nn.CrossEntropyLoss()
      loss = criterion(pd, gt)

      pd: [N, C, H, W]
      gt: [N, H, W]

      float32
      long

      loss = F.cross_entropy(pd, gt)

      pd: [N, C, H, W]
      gt: [N, H, W]

      float32
      long

    • 代码:
      import torch
      import torch.nn as nn
      import torch.nn.functional as F
      
      
      
      # Multi-Class Case Image Classification: CE vs LogSoftmax + NLL
      print("Multi-Class Case Image Classification: CE vs LogSoftmax + NLL")
      N = 16
      C_in = 3
      H = 5
      W = 5
      C_out = 10
      
      input = torch.randn(N, C_in, H, W)
      cnn = nn.Conv2d(C_in, C_out, kernel_size=(3, 3), padding=1)
      pd = cnn(input)
      pd = pd.view(pd.size(0), C_out*H*W)
      fc = nn.Linear(C_out*H*W, C_out)
      pd = fc(pd)
      gt = torch.empty(N, dtype=torch.long).random_(0, C_out)    # should be long
      print("input shape: {} / input type: {}".format(input.shape, input.dtype))
      print("pd shape: {} / pd type: {}".format(pd.shape, pd.dtype))
      print("gt shape: {} / gt type: {}".format(gt.shape, gt.dtype))
      
      loss_func_CE = nn.CrossEntropyLoss()
      loss_CE_nn = loss_func_CE(pd, gt)
      loss_CE_F = F.cross_entropy(pd, gt)
      print("CE: nn({}) / F({})".format(loss_CE_nn, loss_CE_F))
      
      m = nn.LogSoftmax(dim=1)
      loss_func_NLL = nn.NLLLoss()
      loss_NLL_nn = loss_func_NLL(m(pd), gt)
      loss_NLL_F = F.nll_loss(m(pd), gt)
      print("NLL: nn({}) / F({})".format(loss_NLL_nn, loss_NLL_F))
      print("-----------------------------------------------------------------")
  • LogSoftmax + NLL:
    • 图像分类:
      nn. & F. Shape Type

      m = nn.LogSoftmax(dim=1)
      criterion = nn.NLLLoss()
      loss = criterion(m(pd), gt)

      pd: [N, C]
      gt: [N]

      float32
      long

      loss = F.nll_loss(F.log_softmax(pd, dim=1), gt)

      pd: [N, C]
      gt: [N]

      float32
      long

    • 语义分割:
      nn. & F. Shape Type

      m = nn.LogSoftmax(dim=1)
      criterion = nn.NLLLoss()
      loss = criterion(m(pd), gt)

      pd: [N, C, H, W]
      gt: [N, H, W]

      float32
      long

      loss = F.nll_loss(F.log_softmax(pd, dim=1), gt)

      pd: [N, C, H, W]
      gt: [N, H, W]

      float32
      long

    • 代码:
      import torch
      import torch.nn as nn
      import torch.nn.functional as F
      
      
      
      # Multi-Class Case Image Segmentation: CE vs LogSoftmax + NLL
      print("Multi-Class Case Image Segmentation: CE vs LogSoftmax + NLL")
      N = 16
      C_in = 1
      H = 5
      W = 5
      C_out = 10
      
      input = torch.randn(N, C_in, H, W)
      cnn = nn.Conv2d(C_in, C_out, kernel_size=(3, 3), padding=1)
      pd = cnn(input)
      # pd = pd.view(pd.size(0), C_out*H*W)
      # fc = nn.Linear(C_out*H*W, C_out)
      # pd = fc(pd)
      gt = torch.empty(N, H, W, dtype=torch.long).random_(0, C_out)
      print("input shape: {} / input type: {}".format(input.shape, input.dtype))
      print("pd shape: {} / pd type: {}".format(pd.shape, pd.dtype))
      print("gt shape: {} / gt type: {}".format(gt.shape, gt.dtype))
      
      loss_func_CE = nn.CrossEntropyLoss()
      loss_CE_nn = loss_func_CE(pd, gt)
      loss_CE_F = F.cross_entropy(pd, gt)
      print("CE: nn({}) / F({})".format(loss_CE_nn, loss_CE_F))
      
      m = nn.LogSoftmax(dim=1)
      loss_func_NLL = nn.NLLLoss()
      loss_NLL_nn = loss_func_NLL(m(pd), gt)
      loss_NLL_F = F.nll_loss(m(pd), gt)
      print("NLL: nn({}) / F({})".format(loss_NLL_nn, loss_NLL_F))
      print("-----------------------------------------------------------------")

 

 

 

Multi-Class Case

WIP

posted @ 2021-03-11 08:55  hmlovetech  阅读(2525)  评论(0编辑  收藏  举报