实例分割loss示例

@LOSSES.register_module()
class LaneCustomWeightedLoss(nn.Module):
    """Cross-entropy segmentation loss with OHEM negatives and per-class weights.

    Pixels whose label is ``> 0`` are positives; pixels whose label is ``0``
    are background (negatives).  All positives contribute to the loss, while
    only the hardest negatives (lowest background logit) are kept, at most
    ``negative_ratio`` negatives per positive.  Pixels whose label is listed
    in ``special_classes`` get their loss multiplied by
    ``1 + special_weight``.

    Args:
        negative_ratio: Maximum number of mined negatives per positive pixel.
        loss_weight: Global multiplier applied to the final loss.
        type_num: Number of channels (classes, background included) in the
            prediction.
        special_classes: Label values that receive the extra weight.
        special_weight: Extra weight added on top of 1 for special classes.
    """

    # NOTE(review): with the default ``type_num=3`` the default special class
    # ``4`` can never be a valid label; the defaults are kept unchanged for
    # backward compatibility and are expected to be overridden by the config.
    def __init__(self,
                 negative_ratio=3,
                 loss_weight=1.0,
                 type_num=3,
                 special_classes=(4,),
                 special_weight=5):
        super().__init__()
        self.type_num = type_num
        self.negative_ratio = negative_ratio
        self.loss_weight = loss_weight
        # Store a private copy so a caller-owned (or default) sequence is
        # never shared between instances.
        self.special_classes = (
            list(special_classes) if special_classes is not None else [])
        self.special_weight = special_weight

    def forward(self, pred_seg, gt_seg_cur, quality_weight):
        """Compute the weighted OHEM cross-entropy loss.

        Args:
            pred_seg: Class logits of shape ``(B, type_num, H, W)``.
            gt_seg_cur: Integer label map of shape ``(B, H, W)``; ``0`` is
                background.
            quality_weight: One loss weight per sample (sequence or 1-D
                tensor of length ``B``), broadcast over every pixel of that
                sample.

        Returns:
            A tuple ``(loss, pos_mask)`` where ``loss`` is a scalar tensor
            (positive term + mined negative term, both scaled by
            ``loss_weight``) and ``pos_mask`` is the boolean ``(B, H, W)``
            mask of positive pixels.

        Raises:
            ValueError: If ``quality_weight`` does not hold ``B`` entries.
        """
        B, H, W = gt_seg_cur.shape
        if len(quality_weight) != B:
            raise ValueError(
                f"quality_weight must have {B} entries (one per sample), "
                f"got {len(quality_weight)}")

        gt_flat = gt_seg_cur.reshape(-1)
        logits = pred_seg.permute(0, 2, 3, 1).reshape(-1, self.type_num)

        pmask = gt_flat > 0    # positive pixels (label > 0)
        nmask = gt_flat == 0   # background pixels (label == 0)
        n_pos = int(pmask.sum().item())
        n_neg_available = int(nmask.sum().item())

        # Online hard example mining on the background pixels: keep the k
        # background pixels with the lowest background logit.
        bg_logit = logits[:, 0]
        if n_neg_available > 0:
            # At most negative_ratio negatives per positive, never more than
            # what exists, and at least one so the negative term stays
            # defined when there is no positive (or the ratio is fractional).
            n_neg = min(int(n_pos * self.negative_ratio), n_neg_available)
            n_neg = max(n_neg, 1)
            hardest, _ = torch.topk(bg_logit[nmask], k=n_neg, largest=False)
            # topk returns sorted values, so the last one is the threshold.
            # Ties at the threshold may select slightly more than n_neg
            # pixels; the normaliser stays n_neg.
            nmask_ohem = (bg_logit <= hardest[-1]) & nmask
        else:
            # No background pixel at all: nothing to mine, the negative term
            # is zero (torch.topk would raise on an empty tensor).
            n_neg = 1
            nmask_ohem = torch.zeros_like(nmask)

        # Per-pixel cross entropy: negative log-probability of the true class.
        log_p = F.log_softmax(logits, dim=-1)
        nll = -log_p.gather(dim=1, index=gt_flat.unsqueeze(1).long()).squeeze(1)

        # Special classes are weighted 1 + special_weight, all others 1.
        special_mask = torch.zeros_like(gt_flat, dtype=nll.dtype)
        for cls in self.special_classes:
            special_mask += (gt_flat == cls).to(nll.dtype)
        weighted_nll = nll * (1 + self.special_weight * special_mask)

        # Scale every pixel by the quality weight of the sample it belongs to.
        # as_tensor avoids the copy/warning torch.tensor emits for tensors.
        per_sample = torch.as_tensor(
            quality_weight, dtype=nll.dtype, device=nll.device).reshape(B)
        per_pixel = per_sample[:, None, None].expand(B, H, W).reshape(-1)
        weighted_nll = weighted_nll * per_pixel

        pos_sum = torch.sum(weighted_nll * pmask.to(nll.dtype))
        neg_sum = torch.sum(weighted_nll * nmask_ohem.to(nll.dtype))

        # With no positive pixel pos_sum is already 0, so dividing by 1 keeps
        # the term at zero without a special case.
        loss_pos = pos_sum / max(n_pos, 1) * self.loss_weight
        loss_neg = neg_sum / n_neg * self.loss_weight

        return loss_pos + loss_neg, pmask.view(B, H, W)

  

 

posted @ 2026-01-22 15:19  Picassooo  阅读(0)  评论(0)    收藏  举报