标签平滑实现
最近做二分类时发现模型对正样本输出的概率特别高,应该是过拟合了。标签平滑(label smoothing)可以缓解这种过度自信的问题;为了方便以后复用,把实现记录在这里。
看了一篇原理写得不错的博客:https://blog.csdn.net/qq_41915623/article/details/124852409
import torch
import torch.nn as nn
class CrossEntropyWithLabelSmoothing(nn.Module):
    """Cross-entropy loss with label smoothing.

    The hard one-hot target distribution is replaced by
    ``(1 - epsilon) * one_hot + epsilon / class_num``, which discourages the
    model from producing over-confident probabilities (a common symptom of
    overfitting, e.g. near-1.0 outputs for positive samples in binary
    classification).

    Args:
        epsilon: smoothing factor in [0, 1); 0 reproduces plain cross entropy.
        reduction: "mean", "sum", or "none".
        ignore_index: target value whose samples contribute nothing to the
            loss (matches ``nn.CrossEntropyLoss`` semantics).
    """

    def __init__(self, epsilon=0.1, reduction="mean", ignore_index=-100):
        super().__init__()
        self.epsilon = epsilon
        self.reduction = reduction
        self.ignore_index = ignore_index
        self.logsoftmax = nn.LogSoftmax(dim=-1)

    def forward(self, output, target, class_num=None):
        """
        Args:
            output: raw logits, shape (batch_size, class_num).
            target: class indices, shape (batch_size,).
            class_num: number of classes; inferred from ``output`` if None.

        Returns:
            Scalar loss for "mean"/"sum"; per-sample losses of shape
            (batch_size,) for "none" (ignored samples contribute 0).

        Raises:
            ValueError: if ``reduction`` is not one of the supported modes
                (the original version silently returned None).
        """
        if class_num is None:
            class_num = output.shape[-1]
        log_probs = self.logsoftmax(output)

        # Samples whose target equals ignore_index are masked out of the
        # loss (the original stored ignore_index but never honored it).
        valid = target != self.ignore_index
        # Replace ignored targets with a safe in-range index so scatter_
        # does not fault; those rows are zeroed afterwards anyway.
        safe_target = target.masked_fill(~valid, 0)

        # Build the smoothed target distribution on the same device/dtype
        # as the logits. The original always allocated on CPU with default
        # dtype, which broke for CUDA or half-precision inputs.
        smooth_y = torch.full_like(log_probs, self.epsilon / class_num)
        smooth_y.scatter_(
            -1,
            safe_target.unsqueeze(-1),
            1.0 - self.epsilon + self.epsilon / class_num,
        )

        loss = -(smooth_y * log_probs).sum(-1)
        loss = loss * valid.to(loss.dtype)  # zero out ignored samples

        if self.reduction == "sum":
            return loss.sum()
        if self.reduction == "mean":
            # Average over non-ignored samples only, matching
            # nn.CrossEntropyLoss; clamp avoids 0/0 when everything is
            # ignored (returns 0 instead of NaN).
            return loss.sum() / valid.sum().clamp(min=1).to(loss.dtype)
        if self.reduction == "none":
            return loss
        raise ValueError(f"unsupported reduction: {self.reduction!r}")
# Usage example: a batch of 5 random logit rows over 3 classes.
# (The original also built an `output` tensor that was never used; removed.)
criterion = CrossEntropyWithLabelSmoothing()
loss = criterion(torch.randn((5, 3)), torch.randint(3, (5,)), 3)

浙公网安备 33010602011771号