#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2020/02/04 20:08
# @Author : dangxusheng
# @Email : dangxusheng163@163.com
# @File : isLand_loss.py
'''
Island Loss aims to reduce intra-class variation while at the same time enlarging
inter-class differences. It builds on center loss and additionally penalizes small
distances between the class centers.
Reference (Chinese blog): https://blog.csdn.net/heruili/article/details/88912074
L_island = L_center + lamda1 * penalty
Loss = L_softmax + lamda * L_island
A sketch of how this module is combined with a softmax loss is given at the end of the file.
'''
import torch
import torch.nn as nn
from torch.autograd import Function  # only needed by the commented-out custom-backward variant


class IslandLoss(nn.Module):
"""
paper: https://arxiv.org/pdf/1710.03144.pdf
url: https://blog.csdn.net/u013841196/article/details/89920441
"""
def __init__(self, features_dim, num_class=10, lamda=1., lamda1=10., scale=1.0, batch_size=64):
"""
初始化
:param features_dim: 特征维度 = c*h*w
:param num_class: 类别数量
:param lamda: island loss的权重系数
:param lamda1: island loss内部 特征中心距离惩罚项的权重系数
:param scale: 特征中心梯度的缩放因子
:param batch_size: 批次大小
"""
super(IslandLoss, self).__init__()
self.lamda = lamda
self.lamda1 = lamda1
self.num_class = num_class
self.scale = scale
self.batch_size = batch_size
self.feat_dim = features_dim
        # stores the center of each class; shape: (num_class, features_dim)
        self.feature_centers = nn.Parameter(torch.randn([num_class, features_dim]))
        # self.lossfunc = IslandLossFunc.apply

def forward(self, output_features, y_truth):
"""
损失计算
:param output_features: conv层输出的特征, [b,c,h,w]
:param y_truth: 标签值 [b,]
:return:
"""
batch_size = y_truth.size(0)
num_class = self.num_class
output_features = output_features.view(batch_size, -1)
assert output_features.size(-1) == self.feat_dim
factor = self.scale / batch_size
        # # # Option 1: use a hand-written backward (custom autograd Function)
        # return self.lossfunc(output_features, y_truth, self.feature_centers,
        #                      torch.Tensor([self.alpha, self.lamda, self.lamda1, self.scale]))
        # Option 2: rely on PyTorch's automatic differentiation
centers_batch = self.feature_centers.index_select(0, y_truth.long()) # [b,features_dim]
diff = output_features - centers_batch
        # 1. center loss
        loss_center = 1 / 2.0 * (diff.pow(2).sum()) * factor
        # 2. cosine distance between class centers
        # For every pair of centers take (cosine similarity + 1) so the value lies in [0, 2];
        # values close to 0 mean the classes differ strongly, so minimizing this term pushes the centers apart.
        centers = self.feature_centers
        # vector norms ||C_i||
        centers_mod = torch.sum(centers * centers, dim=1, keepdim=True).sqrt()  # [num_class, 1]
# ====================== method 1 =======================
        item1_sum = 0
        for j in range(num_class):
            dis_sum_j_others = 0
            # accumulate (cosine similarity + 1) between center j and every later center k
            for k in range(j + 1, num_class):
                dot_kj = torch.sum(centers[j] * centers[k])
                fenmu = centers_mod[j] * centers_mod[k] + 1e-9  # denominator ||C_j|| * ||C_k||
                cos_dis = dot_kj / fenmu
                dis_sum_j_others += cos_dis + 1.
            # print(dis_sum_j_others)
            item1_sum += dis_sum_j_others
loss_island = self.lamda * (loss_center + self.lamda1 * item1_sum)
        # ====================== method 2 (vectorized) =======================
        # # C x C.T
        # centers_mm = torch.matmul(centers, centers.t())  # [num_class, num_class]
        # centers_mod_mm = centers_mod.mm(centers_mod.t())  # [num_class, num_class]
        # # cosine-similarity matrix (symmetric)
        # centers_cos_dis = centers_mm / centers_mod_mm
        # centers_cos_dis += 1.
        # # keep only the strict upper triangle: the distance of a class to itself is not considered
        # centers_cos_dis_1 = torch.triu(centers_cos_dis, diagonal=1)
        # print(centers_cos_dis_1)
        # sum_centers_cos_dis = torch.sum(centers_cos_dis_1)
        # loss_island = self.lamda * (loss_center + self.lamda1 * sum_centers_cos_dis)
return loss_island


torch.manual_seed(1000)

if __name__ == '__main__':
    import random

    # test 1
    num_class = 10
    batch_size = 10
    feat_dim = 2
    ct = IslandLoss(feat_dim, num_class, lamda=0.1, lamda1=1., scale=1., batch_size=batch_size)
    y = torch.Tensor([random.choice(range(num_class)) for i in range(batch_size)])
    feat = torch.randn(batch_size, feat_dim).requires_grad_()  # a batch of features, [batch_size, feat_dim]
    print(feat)
    out = ct(feat, y)
    out.backward()
    print(ct.feature_centers.grad)
    print(feat.grad)
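
    # test 2: a minimal sketch of how IslandLoss is typically combined with a
    # softmax/cross-entropy loss, i.e. Loss = L_softmax + lamda * L_island.
    # The linear head `fc`, the SGD settings and the random data below are illustrative
    # assumptions; note that IslandLoss already multiplies by its own lamda internally.
    fc = nn.Linear(feat_dim, num_class)  # hypothetical classifier head on top of the features
    ce_loss = nn.CrossEntropyLoss()
    island_loss = IslandLoss(feat_dim, num_class, lamda=0.1, lamda1=1., scale=1., batch_size=batch_size)
    optimizer = torch.optim.SGD(list(fc.parameters()) + list(island_loss.parameters()), lr=0.1)

    feats = torch.randn(batch_size, feat_dim)  # stand-in for backbone features
    labels = torch.randint(0, num_class, (batch_size,))
    logits = fc(feats)
    total_loss = ce_loss(logits, labels) + island_loss(feats, labels)
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()  # updates both the classifier and the class centers
    print('total_loss:', total_loss.item())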