lightweight openpose和hrnet的loss的区别

lightweight openpose的loss都用的是平方损失。

def l2_loss(input, target, mask, batch_size):
loss = (input - target) * mask
loss = (loss * loss) / 2 / batch_size

return loss.sum()

计算各个输出都用的是l2平方损失。

hrnet的loss都用的是关键点是平方损失而用的损失函数l1损失。并乘以权重

class HeatmapLoss(nn.Module):
def __init__(self):
super().__init__()

def forward(self, pred, gt, mask):
assert pred.size() == gt.size()
loss = ((pred - gt)**2) * mask
loss = loss.mean(dim=3).mean(dim=2).mean(dim=1).mean(dim=0)
return loss


class OffsetsLoss(nn.Module):
def __init__(self):
super().__init__()

def smooth_l1_loss(self, pred, gt, beta=1. / 9):
l1_loss = torch.abs(pred - gt)
cond = l1_loss < beta
loss = torch.where(cond, 0.5*l1_loss**2/beta, l1_loss-0.5*beta)
return loss

def forward(self, pred, gt, weights):
assert pred.size() == gt.size()
num_pos = torch.nonzero(weights > 0).size()[0]
loss = self.smooth_l1_loss(pred, gt) * weights
if num_pos == 0:
num_pos = 1.
loss = loss.sum() / num_pos
return loss

posted @ 2023-02-05 17:05  祥瑞哈哈哈  阅读(125)  评论(0)    收藏  举报