标签平滑实现

最近做二分类时发现模型对正样本输出的概率特别高,应该是过拟合了。标签平滑(label smoothing)可以缓解这种过度自信的问题,为了方便以后复用,把实现记录在这里。

原理讲解可以参考这篇写得不错的文章:https://blog.csdn.net/qq_41915623/article/details/124852409

import torch
import torch.nn as nn


class CrossEntropyWithLabelSmoothing(nn.Module):
    """Cross-entropy loss with label smoothing.

    The hard one-hot target is replaced by the smoothed distribution
    ``(1 - epsilon) * one_hot + epsilon / class_num``, which discourages
    the model from becoming over-confident on the true class.

    Args:
        epsilon: smoothing factor in [0, 1); 0 recovers plain cross-entropy.
        reduction: "mean" | "sum" | "none".
        ignore_index: target value whose samples contribute no loss.
    """

    def __init__(self, epsilon=0.1, reduction="mean", ignore_index=-100):
        super(CrossEntropyWithLabelSmoothing, self).__init__()
        self.epsilon = epsilon
        self.reduction = reduction
        self.ignore_index = ignore_index
        self.logsoftmax = nn.LogSoftmax(dim=-1)

    def forward(self, output, target, class_num=None):
        """Compute the smoothed cross-entropy.

        Args:
            output: (batch_size, class_num) raw logits.
            target: (batch_size,) integer class indices.
            class_num: number of classes; inferred from ``output`` if omitted.

        Returns:
            Scalar tensor for "mean"/"sum"; (batch_size,) tensor for "none".

        Raises:
            ValueError: if ``self.reduction`` is not a known mode.
        """
        if class_num is None:
            class_num = output.shape[-1]
        batch_size = output.shape[0]
        # Replace ignored targets with a safe index before the one-hot
        # scatter: an ignore_index such as -100 would otherwise index out
        # of range (or wrap around) in the scatter below.
        valid = target != self.ignore_index
        safe_target = torch.where(valid, target, torch.zeros_like(target))
        # Build the one-hot matrix on the same device/dtype as the logits
        # so the loss also works for GPU or non-float32 inputs.
        y = torch.zeros_like(output)
        y[torch.arange(batch_size, device=output.device), safe_target] = 1
        # Smoothed target: (1 - eps) on the true class, eps/K spread
        # uniformly over all K classes.
        smooth_y = (1 - self.epsilon) * y + self.epsilon / class_num
        loss = -(smooth_y * self.logsoftmax(output)).sum(-1)
        # Zero out samples whose target is ignore_index.
        loss = loss * valid.to(loss.dtype)
        if self.reduction == "sum":
            return loss.sum()
        if self.reduction == "mean":
            # Average over non-ignored samples only (matches
            # F.cross_entropy's ignore_index semantics).
            return loss.sum() / valid.sum().clamp(min=1)
        if self.reduction == "none":
            return loss  # (batch_size,)
        raise ValueError(f"unknown reduction: {self.reduction!r}")

if __name__ == "__main__":
    # Smoke test: 5 samples, 3 classes. Seeded so the printed loss is
    # reproducible. (The original snippet built an unused `output` tensor
    # and discarded the computed loss.)
    torch.manual_seed(0)
    logits = torch.randn((5, 3))
    labels = torch.randint(3, (5,))
    criterion = CrossEntropyWithLabelSmoothing()
    loss = criterion(logits, labels, 3)
    print(f"label-smoothed CE loss: {loss.item():.4f}")
posted @ 2022-08-13 17:25  アイラ  阅读(57)  评论(0)    收藏  举报