RadeGS——depth_order_loss/ranking_loss
```python
import torch


def sample_pixel_pairs(mask, num_pairs):
    """
    Randomly sample pixel pairs from a validity mask.
    mask: H×W bool tensor
    return: two 1-D tensors (u, v) of flattened pixel indices, or None
    """
    # Flatten the mask and collect the indices of all True pixels
    idx = torch.nonzero(mask.flatten(), as_tuple=False).squeeze(1)
    if idx.numel() < 2:
        return None
    # Sample 2 * num_pairs indices with replacement and split them into pairs
    perm = torch.randint(0, idx.numel(), (num_pairs * 2,), device=mask.device)
    u = idx[perm[:num_pairs]]
    v = idx[perm[num_pairs:]]
    return u, v
```
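Each sampled pair (u, v) is turned into ordinal supervision: the MoGe depth decides which of the two pixels should be farther, pairs whose MoGe depth difference is within tau are discarded as ambiguous, and the rendered depth is pushed to respect the remaining orderings with a logistic ranking loss. Writing $D$ for the rendered expected depth, $\hat{D}$ for the MoGe depth, $\ell_{uv}=\operatorname{sign}(\hat{D}_u-\hat{D}_v)$ for the ordinal label, and $\mathcal{P}$ for the pairs with $|\hat{D}_u-\hat{D}_v|>\tau$ (notation mine, matching the code):

$$
\mathcal{L}_{\text{order}}=\frac{1}{|\mathcal{P}|}\sum_{(u,v)\in\mathcal{P}}\log\!\bigl(1+e^{-\ell_{uv}\,(D_u-D_v)}\bigr)
$$

This is the quantity `depth_order_loss` below computes.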
```python
import torch
import torch.nn.functional as F


def depth_order_loss(
    pred_depth,        # rendered_expected_depth (H×W)
    gt_depth,          # MoGe depth (H×W)
    mask,              # valid mask (H×W)
    num_pairs=8192,
    tau=0.02,
):
    """
    tau: threshold on the (relative) MoGe depth difference below which a pair
         is treated as ambiguous and dropped (how to choose it is explained later).
    """
    device = pred_depth.device
    H, W = pred_depth.shape
    pred = pred_depth.flatten()
    gt = gt_depth.flatten()

    pairs = sample_pixel_pairs(mask, num_pairs)
    if pairs is None:
        return torch.tensor(0.0, device=device)
    u, v = pairs

    # MoGe depth difference
    d_gt = gt[u] - gt[v]

    # ordinal label: +1 if u is farther than v, -1 if closer, 0 if too close to tell
    label = torch.zeros_like(d_gt)
    label[d_gt > tau] = 1.0
    label[d_gt < -tau] = -1.0

    valid = label != 0
    if valid.sum() < 16:
        return torch.tensor(0.0, device=device)

    d_pred = pred[u] - pred[v]

    # logistic ranking loss; softplus(-x) == log(1 + exp(-x)), but numerically stable
    loss = F.softplus(-label[valid] * d_pred[valid])
    return loss.mean()
```
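A minimal sketch of how the loss could be wired into a training step, assuming `rendered_expected_depth` comes from the Gaussian rasterizer and `moge_depth` / `moge_mask` from running MoGe on the training view; the weight `lambda_order`, `photometric_loss`, and the tensor names are assumptions for illustration, not part of RadeGS itself:

```python
# Usage sketch (names and the weight are assumptions, not RadeGS API):
# rendered_expected_depth : H×W depth rendered by the Gaussian rasterizer
# moge_depth, moge_mask   : H×W MoGe depth and its validity mask
lambda_order = 0.1  # hypothetical weight for the ranking term

l_order = depth_order_loss(
    rendered_expected_depth,
    moge_depth,
    moge_mask,
    num_pairs=8192,
    tau=0.02,
)
loss = photometric_loss + lambda_order * l_order  # photometric_loss: existing RGB loss
loss.backward()
```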
