Class-Balanced Loss Based on Effective Number of Samples - 2 - Code Study

Reference: https://github.com/vandit15/Class-balanced-loss-pytorch

Its class_balanced_loss.py:

import numpy as np
import torch
import torch.nn.functional as F



def focal_loss(labels, logits, alpha, gamma):
    """Compute the focal loss between `logits` and the ground truth `labels`.
    Focal loss = -alpha_t * (1-pt)^gamma * log(pt)
    where pt is the probability of being classified to the true class.
    pt = p (if true class), otherwise pt = 1 - p. p = sigmoid(logit).
    Args:
      labels: A float tensor of size [batch, num_classes].
      logits: A float tensor of size [batch, num_classes].
      alpha: A float tensor of size [batch, num_classes]
        specifying per-example weight for balanced cross entropy.
      gamma: A float scalar modulating loss from hard and easy examples.
    Returns:
      focal_loss: A float32 scalar representing normalized total loss.
    """    
    BCLoss = F.binary_cross_entropy_with_logits(input=logits, target=labels, reduction="none")

    if gamma == 0.0:
        modulator = 1.0
    else:
        modulator = torch.exp(-gamma * labels * logits - gamma * torch.log(1 + 
            torch.exp(-1.0 * logits)))

    loss = modulator * BCLoss

    weighted_loss = alpha * loss
    focal_loss = torch.sum(weighted_loss)

    focal_loss /= torch.sum(labels)
    return focal_loss



def CB_loss(labels, logits, samples_per_cls, no_of_classes, loss_type, beta, gamma):
    """Compute the Class Balanced Loss between `logits` and the ground truth `labels`.
    Class Balanced Loss: ((1-beta)/(1-beta^n))*Loss(labels, logits)
    where Loss is one of the standard losses used for Neural Networks.
    Args:
      labels: An int tensor of size [batch].
      logits: A float tensor of size [batch, no_of_classes].
      samples_per_cls: A python list of size [no_of_classes].
      no_of_classes: total number of classes. int
      loss_type: string. One of "sigmoid", "focal", "softmax".
      beta: float. Hyperparameter for Class balanced loss.
      gamma: float. Hyperparameter for Focal loss.
    Returns:
      cb_loss: A float tensor representing class balanced loss
    """
    effective_num = 1.0 - np.power(beta, samples_per_cls)
    weights = (1.0 - beta) / np.array(effective_num)
    weights = weights / np.sum(weights) * no_of_classes

    labels_one_hot = F.one_hot(labels, no_of_classes).float()

    weights = torch.tensor(weights).float()
    weights = weights.unsqueeze(0)
    weights = weights.repeat(labels_one_hot.shape[0],1) * labels_one_hot
    weights = weights.sum(1)
    weights = weights.unsqueeze(1)
    weights = weights.repeat(1,no_of_classes)

    if loss_type == "focal":
        cb_loss = focal_loss(labels_one_hot, logits, weights, gamma)
    elif loss_type == "sigmoid":
        cb_loss = F.binary_cross_entropy_with_logits(input=logits, target=labels_one_hot, weight=weights)
    elif loss_type == "softmax":
        pred = logits.softmax(dim=1)
        cb_loss = F.binary_cross_entropy(input=pred, target=labels_one_hot, weight=weights)
    return cb_loss



if __name__ == '__main__':
    no_of_classes = 5
    logits = torch.rand(10, no_of_classes).float()
    labels = torch.randint(0, no_of_classes, size=(10,))
    beta = 0.9999
    gamma = 2.0
    samples_per_cls = [2, 3, 1, 2, 2]
    loss_type = "focal"
    cb_loss = CB_loss(labels, logits, samples_per_cls, no_of_classes, loss_type, beta, gamma)
    print(cb_loss)
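The heart of the method is the effective number of samples, E_n = (1 - beta^n) / (1 - beta); each class is weighted by its inverse, (1 - beta) / (1 - beta^n). Below is a minimal sketch of just this weight computation, using the same samples_per_cls as the demo above:

import numpy as np

beta = 0.9999
samples_per_cls = [2, 3, 1, 2, 2]

# effective number per class: E_n = (1 - beta^n) / (1 - beta);
# the class weight is its inverse, (1 - beta) / (1 - beta^n)
effective_num = 1.0 - np.power(beta, samples_per_cls)
weights = (1.0 - beta) / effective_num

# normalize so the weights sum to the number of classes
weights = weights / np.sum(weights) * len(samples_per_cls)
print(weights)  # the rarest class (n = 1) gets the largest weight

These are exactly the values printed in the run below: [0.88236332 0.58827163 1.76463841 0.88236332 0.88236332].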

 

A version with comments and debug prints added:

#coding:utf-8
import numpy as np
import torch
import torch.nn.functional as F



def focal_loss(labels, logits, alpha, gamma):
    """Compute the focal loss between `logits` and the ground truth `labels`.
    Focal loss = -alpha_t * (1-pt)^gamma * log(pt)
    where pt is the probability of being classified to the true class.
    pt = p (if true class), otherwise pt = 1 - p. p = sigmoid(logit).
    Args:
      labels: A float tensor of size [batch, num_classes].
      logits: A float tensor of size [batch, num_classes].
      alpha: A float tensor of size [batch, num_classes]
        specifying per-example weight for balanced cross entropy.
      gamma: A float scalar modulating loss from hard and easy examples.
    Returns:
      focal_loss: A float32 scalar representing normalized total loss.
    """    
    BCLoss = F.binary_cross_entropy_with_logits(input=logits, target=labels, reduction="none")

    if gamma == 0.0:
        modulator = 1.0
    else:
        modulator = torch.exp(-gamma * labels * logits - gamma * torch.log(1 + 
            torch.exp(-1.0 * logits)))

    loss = modulator * BCLoss

    weighted_loss = alpha * loss
    
    # sum the weighted loss, then normalize by the number of positive labels
    # (labels is one-hot here, so torch.sum(labels) equals the batch size)
    focal_loss = torch.sum(weighted_loss)
    focal_loss /= torch.sum(labels)
    return focal_loss



def CB_loss(labels, logits, samples_per_cls, no_of_classes, loss_type, beta, gamma):
    """Compute the Class Balanced Loss between `logits` and the ground truth `labels`.
    Class Balanced Loss: ((1-beta)/(1-beta^n))*Loss(labels, logits)
    where Loss is one of the standard losses used for Neural Networks.
    Args:
      labels: An int tensor of size [batch].
      logits: A float tensor of size [batch, no_of_classes].
      samples_per_cls: A python list of size [no_of_classes].
      no_of_classes: total number of classes. int
      loss_type: string. One of "sigmoid", "focal", "softmax".
      beta: float. Hyperparameter for Class balanced loss.
      gamma: float. Hyperparameter for Focal loss.
    Returns:
      cb_loss: A float tensor representing class balanced loss
    """

    # The operations below compute (1-beta)/(1-beta^n), i.e. the per-class weight used in the loss function
    effective_num = 1.0 - np.power(beta, samples_per_cls)
    print('effective_num shape: ', effective_num.shape)
    print(effective_num)
    weights = (1.0 - beta) / np.array(effective_num)
    print('weights shape : ', weights.shape)
    print(weights)
    weights = weights / np.sum(weights) * no_of_classes # normalize so the weights sum to no_of_classes
    print('weights shape : ', weights.shape)
    print(weights)

    labels_one_hot = F.one_hot(labels, no_of_classes).float()
    print('labels_one_hot shape: ', labels_one_hot.shape)
    print(labels_one_hot)

    print('-'*50)
    weights = torch.tensor(weights).float()
    weights = weights.unsqueeze(0)
    print('unsqueeze weights shape : ', weights.shape)
    print(weights)
    # labels_one_hot.shape[0] is the batch size; the 1 in repeat(batch, 1) keeps
    # dim 1 unchanged, so the single [1,5] row is repeated batch-size times
    print(weights.repeat(labels_one_hot.shape[0], 1)) # [1,5] becomes [10,5]
    weights = weights.repeat(labels_one_hot.shape[0], 1) * labels_one_hot
    print('repeat weights shape : ', weights.shape)
    print(weights)
    weights = weights.sum(1) # sum over dim=1, leaving dim=0: each sample gets the weight of its true class
    print('sum weights shape : ', weights.shape)
    print(weights)
    weights = weights.unsqueeze(1)
    print('unsqueeze weights shape : ', weights.shape)
    print(weights)
    weights = weights.repeat(1, no_of_classes) # broadcast each sample's weight across the classes: [10,1] becomes [10,5]
    print('repeat weights shape : ', weights.shape)
    print(weights)
    print('-'*50)

    if loss_type == "focal":
        cb_loss = focal_loss(labels_one_hot, logits, weights, gamma)
    elif loss_type == "sigmoid":
        cb_loss = F.binary_cross_entropy_with_logits(input=logits, target=labels_one_hot, weight=weights)
    elif loss_type == "softmax":
        pred = logits.softmax(dim=1)
        cb_loss = F.binary_cross_entropy(input=pred, target=labels_one_hot, weight=weights)
    return cb_loss



if __name__ == '__main__':
    no_of_classes = 5 # 10 samples, 5 classes
    logits = torch.rand(10, no_of_classes).float() # random logits for the 10 samples over the 5 classes
    print('logits shape : ', logits.shape)
    print(logits)
    labels = torch.randint(0, no_of_classes, size=(10,)) # ground-truth class of each of the 10 samples
    print('labels shape : ', labels.shape)
    print(labels)

    beta = 0.9999 # hyperparameter for the class-balanced term
    gamma = 2.0 # hyperparameter for the focal term
    samples_per_cls = [2, 3, 1, 2, 2] # number of samples in each class
    loss_type = "focal"
    cb_loss = CB_loss(labels, logits, samples_per_cls, no_of_classes, loss_type, beta, gamma)
    print(cb_loss)

Output:

(deeplearning) bogon:work_gender_age wanghui$ python test_delete.py 
logits shape :  torch.Size([10, 5])
tensor([[0.1505, 0.9621, 0.8622, 0.0237, 0.0270],
        [0.6218, 0.2745, 0.3015, 0.1501, 0.1728],
        [0.3590, 0.1760, 0.0807, 0.7440, 0.6973],
        [0.9401, 0.7118, 0.1725, 0.1843, 0.3226],
        [0.4655, 0.8319, 0.4336, 0.8718, 0.5842],
        [0.9423, 0.3339, 0.1081, 0.4718, 0.4329],
        [0.5122, 0.7010, 0.1736, 0.5903, 0.0712],
        [0.6442, 0.1365, 0.1391, 0.8278, 0.5986],
        [0.1245, 0.5662, 0.9571, 0.8515, 0.9883],
        [0.4654, 0.8924, 0.0224, 0.9056, 0.4517]])
labels shape :  torch.Size([10])
tensor([0, 3, 4, 2, 2, 1, 4, 0, 1, 2])
effective_num shape:  (5,)
[1.99990000e-04 2.99970001e-04 1.00000000e-04 1.99990000e-04
 1.99990000e-04]
weights shape :  (5,)
[0.500025   0.33336667 1.         0.500025   0.500025  ]
weights shape :  (5,)
[0.88236332 0.58827163 1.76463841 0.88236332 0.88236332]
labels_one_hot shape:  torch.Size([10, 5])
tensor([[1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1.],
        [0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.]])
--------------------------------------------------
unsqueeze weights shape :  torch.Size([1, 5])
tensor([[0.8824, 0.5883, 1.7646, 0.8824, 0.8824]])
tensor([[0.8824, 0.5883, 1.7646, 0.8824, 0.8824],
        [0.8824, 0.5883, 1.7646, 0.8824, 0.8824],
        [0.8824, 0.5883, 1.7646, 0.8824, 0.8824],
        [0.8824, 0.5883, 1.7646, 0.8824, 0.8824],
        [0.8824, 0.5883, 1.7646, 0.8824, 0.8824],
        [0.8824, 0.5883, 1.7646, 0.8824, 0.8824],
        [0.8824, 0.5883, 1.7646, 0.8824, 0.8824],
        [0.8824, 0.5883, 1.7646, 0.8824, 0.8824],
        [0.8824, 0.5883, 1.7646, 0.8824, 0.8824],
        [0.8824, 0.5883, 1.7646, 0.8824, 0.8824]])
repeat weights shape :  torch.Size([10, 5])
tensor([[0.8824, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.8824, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.8824],
        [0.0000, 0.0000, 1.7646, 0.0000, 0.0000],
        [0.0000, 0.0000, 1.7646, 0.0000, 0.0000],
        [0.0000, 0.5883, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.8824],
        [0.8824, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.5883, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 1.7646, 0.0000, 0.0000]])
sum weights shape :  torch.Size([10])
tensor([0.8824, 0.8824, 0.8824, 1.7646, 1.7646, 0.5883, 0.8824, 0.8824, 0.5883,
        1.7646])
unsqueeze weights shape :  torch.Size([10, 1])
tensor([[0.8824],
        [0.8824],
        [0.8824],
        [1.7646],
        [1.7646],
        [0.5883],
        [0.8824],
        [0.8824],
        [0.5883],
        [1.7646]])
repeat weights shape :  torch.Size([10, 5])
tensor([[0.8824, 0.8824, 0.8824, 0.8824, 0.8824],
        [0.8824, 0.8824, 0.8824, 0.8824, 0.8824],
        [0.8824, 0.8824, 0.8824, 0.8824, 0.8824],
        [1.7646, 1.7646, 1.7646, 1.7646, 1.7646],
        [1.7646, 1.7646, 1.7646, 1.7646, 1.7646],
        [0.5883, 0.5883, 0.5883, 0.5883, 0.5883],
        [0.8824, 0.8824, 0.8824, 0.8824, 0.8824],
        [0.8824, 0.8824, 0.8824, 0.8824, 0.8824],
        [0.5883, 0.5883, 0.5883, 0.5883, 0.5883],
        [1.7646, 1.7646, 1.7646, 1.7646, 1.7646]])
--------------------------------------------------
tensor(1.9583)
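Incidentally, the unsqueeze/repeat sequence that produces the [10, 5] weight matrix can be written more compactly with indexing and broadcasting. A sketch equivalent to the printed steps above (values taken from the run):

import torch

weights_per_class = torch.tensor([0.8824, 0.5883, 1.7646, 0.8824, 0.8824])
labels = torch.tensor([0, 3, 4, 2, 2, 1, 4, 0, 1, 2])

# look up each sample's class weight, then broadcast across the class dimension
weights = weights_per_class[labels].unsqueeze(1).expand(-1, 5)
print(weights.shape)  # torch.Size([10, 5])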

As you can see, the code can compute the loss as a set of binary classifications mainly because the labels are converted to one-hot format:

labels_one_hot = F.one_hot(labels, no_of_classes).float()
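A standalone look at what that conversion does (a small sketch): each label index becomes a row of no_of_classes binary targets, which is exactly what binary_cross_entropy_with_logits operates on.

import torch
import torch.nn.functional as F

labels = torch.tensor([0, 3, 4])
labels_one_hot = F.one_hot(labels, num_classes=5).float()
print(labels_one_hot)
# tensor([[1., 0., 0., 0., 0.],
#         [0., 0., 0., 1., 0.],
#         [0., 0., 0., 0., 1.]])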

 

The trickiest parts are in the focal loss implementation:

1) BCLoss = F.binary_cross_entropy_with_logits(input=logits, target=labels_one_hot, reduction="none")

With reduction="none" this is the element-wise binary cross entropy

BCLoss = -(y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x)))

where x is a logit and y the corresponding entry of the one-hot labels.

2) modulator

The modulator is the focal term (1 - pt)^gamma, computed in log space for numerical stability:

modulator = exp(-gamma * y * x - gamma * log(1 + exp(-x)))

For y = 1 the exponent is -gamma * log(1 + exp(x)) = gamma * log(1 - pt), since pt = sigmoid(x); for y = 0 it is -gamma * log(1 + exp(-x)) = gamma * log(1 - pt), since pt = 1 - sigmoid(x). Either way modulator = (1 - pt)^gamma; see the numerical check after this list.

3) weight, i.e. the alpha argument passed in: the class-balanced factor (1 - beta)/(1 - beta^n) of each sample's true class, already normalized and broadcast to shape [batch, no_of_classes] by CB_loss before the call.
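A quick numerical sanity check of 1) and 2) above (a minimal sketch with random logits and binary targets, not part of the original repo): the element-wise BCE matches the explicit formula, and the log-space modulator matches the direct (1 - pt)^gamma.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 5)                    # logits
y = torch.randint(0, 2, (4, 5)).float()  # binary (one-hot style) targets
gamma = 2.0
p = torch.sigmoid(x)

# 1) element-wise BCE vs. the explicit formula
bce = F.binary_cross_entropy_with_logits(input=x, target=y, reduction="none")
bce_manual = -(y * torch.log(p) + (1 - y) * torch.log(1 - p))
print(torch.allclose(bce, bce_manual, atol=1e-6))  # True

# 2) log-space modulator vs. the direct (1 - pt)^gamma
pt = torch.where(y == 1, p, 1 - p)
direct = (1 - pt) ** gamma
stable = torch.exp(-gamma * y * x - gamma * torch.log(1 + torch.exp(-x)))
print(torch.allclose(direct, stable, atol=1e-6))   # True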
