RadeGS: compute_depth_order_loss in MILO

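Goal: encourage the relative front-to-back order of the predicted depth map depth to match that of the prior depth map prior_depth as closely as possible. The prior does not need an accurate scale.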

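Pixel-pair sampling: for each pixel, randomly pick one nearby pixel by adding an integer offset drawn uniformly from [-s, s], independently for the row and the column, where s = max(round(max_pixel_shift_ratio * max(H, W)), 1). The shifted coordinates are then clamped to the image bounds, so a pixel can occasionally be paired with itself (zero offset, or a clamp at the border); such pairs contribute zero loss.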
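A minimal sketch of this sampling step on coordinates alone, assuming no depth maps are needed yet. sample_neighbor_coords is a hypothetical helper name, not part of MILO, and it builds the integer grid with torch.arange rather than torch.linspace.

```python
import torch


def sample_neighbor_coords(height, width, max_pixel_shift_ratio=0.05, generator=None):
    """Hypothetical helper: pair every pixel with one random neighbor.

    Returns (coords, shifted_coords, max_pixel_shift); both coordinate tensors
    have shape (height * width, 2) and hold (row, col) indices.
    """
    # Shift radius in pixels: a ratio of the larger image side, but at least 1
    max_pixel_shift = max(round(max_pixel_shift_ratio * max(height, width)), 1)

    # Row-major grid of all (row, col) coordinates
    rows, cols = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
    coords = torch.stack([rows, cols], dim=-1).view(-1, 2)

    # Independent integer offsets in [-max_pixel_shift, max_pixel_shift] for rows and columns
    shifts = torch.randint(-max_pixel_shift, max_pixel_shift + 1, coords.shape, generator=generator)

    # Clamp each axis to its own bound; a clamped neighbor may coincide with the pixel itself
    lower = torch.zeros(2, dtype=torch.long)
    upper = torch.tensor([height - 1, width - 1])
    shifted_coords = (coords + shifts).clamp(min=lower, max=upper)
    return coords, shifted_coords, max_pixel_shift


gen = torch.Generator().manual_seed(0)
coords, shifted, radius = sample_neighbor_coords(100, 200, generator=gen)
```

For a 100 x 200 image and the default ratio of 0.05 the radius is 10 pixels; for very small images the max(..., 1) keeps it at least 1.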

Neighbor depths: read shifted_depth from depth and shifted_prior_depth from prior_depth at the sampled neighbor positions.
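The gather relies on PyTorch advanced indexing: indexing an (H, W) map with two integer index tensors returns the values at the matching (row, col) pairs. The toy 3 x 4 map and coordinates below are made up for illustration. Because indexing is differentiable, gradients of the loss reach both a pixel and its sampled neighbor, and they accumulate when the same neighbor is sampled more than once.

```python
import torch

# Toy 3 x 4 depth map whose value at (row, col) is row * 4 + col
depth = torch.arange(12, dtype=torch.float32).view(3, 4).requires_grad_()

# (row, col) neighbor coordinates, one per row, as produced by the sampling step
shifted_coords = torch.tensor([[0, 0], [1, 2], [2, 3], [1, 2]])

# Advanced indexing gathers depth[0, 0], depth[1, 2], depth[2, 3], depth[1, 2]
shifted_depth = depth[shifted_coords[:, 0], shifted_coords[:, 1]]

# Equivalent lookup through the flattened map
flat_index = shifted_coords[:, 0] * depth.shape[1] + shifted_coords[:, 1]
same_values = torch.equal(shifted_depth, depth.reshape(-1)[flat_index])

# Gradients accumulate on every gathered pixel; (1, 2) was sampled twice
shifted_depth.sum().backward()
```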

Differences: compute the depth difference diff = (depth - shifted_depth) / scene_extent and the prior difference prior_diff = (prior_depth - shifted_prior_depth) / scene_extent.
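A quick numeric check of why diff is divided by scene_extent: if the whole scene, its depths and scene_extent alike, is rescaled by the same factor, diff does not change. The numbers are made up for illustration. The same division is applied to prior_diff; when normalize_loss=True it cancels out anyway, because only the sign of prior_diff is kept.

```python
import torch


def normalized_diff(depth, shifted_depth, scene_extent):
    """Depth difference between a pixel and its neighbor, in units of the scene extent."""
    return (depth - shifted_depth) / scene_extent


depth = torch.tensor([1.0, 2.0, 4.0])
shifted_depth = torch.tensor([2.0, 2.0, 1.0])

diff = normalized_diff(depth, shifted_depth, scene_extent=1.0)
# The same scene with every length multiplied by 10 (e.g. a different unit)
diff_rescaled = normalized_diff(10.0 * depth, 10.0 * shifted_depth, scene_extent=10.0)
```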

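Optional sign-only prior: if normalize_loss=True, divide prior_diff by its own detached absolute value (clamped below at 1e-8). This keeps approximately only the sign of the prior difference and discards its magnitude.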
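With normalize_loss=True the expression prior_diff / prior_diff.detach().abs().clamp(min=1e-8) behaves like torch.sign(prior_diff) wherever |prior_diff| >= 1e-8; values below the clamp threshold are only scaled up, not mapped to ±1. The toy values below illustrate both cases. The detach keeps the normalizer out of the autograd graph.

```python
import torch

prior_diff = torch.tensor([0.3, -2.5, 0.0, 1e-12])

# Divide by the detached magnitude, clamped away from zero to avoid 0 / 0
normalized = prior_diff / prior_diff.detach().abs().clamp(min=1e-8)

# Entries 0-2 match torch.sign exactly: 1, -1 and 0
matches_sign = torch.equal(normalized[:3], torch.sign(prior_diff[:3]))

# The tiny entry is below the 1e-8 clamp, so it becomes about 1e-12 / 1e-8 = 1e-4, not 1
tiny = normalized[3].item()
```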

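Penalty rule: compute diff * prior_diff. If the product is positive (same order as the prior), the loss is 0; if it is negative (reversed order), the penalty is positive: -(diff * prior_diff).clamp(max=0). A zero product, for example when the prior sees no difference, also gives zero loss.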
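A worked example of the penalty rule on three made-up pixel pairs, with prior_diff already sign-normalized: the first pair agrees with the prior, the second is reversed, and in the third the prior sees no difference. Only the reversed pair is penalized, and its gradient pushes diff toward the sign of prior_diff.

```python
import torch

diff = torch.tensor([0.5, -0.5, 0.2], requires_grad=True)
prior_diff = torch.tensor([1.0, 1.0, 0.0])

# Hinge on the product: zero when the orders agree, -(diff * prior_diff) when they disagree
loss = -(diff * prior_diff).clamp(max=0)
# Expected pixel-wise values: 0 (agree), 0.5 (reversed), 0 (prior tie)

loss.sum().backward()
# d loss / d diff is -prior_diff on the reversed pair and 0 elsewhere,
# so a gradient-descent step increases diff[1] toward the prior's order
```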

Optional robust variant: if log_space=True, compress the penalty as log(1 + log_scale * loss), which reduces the influence of very large errors.
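To see what log_space=True does, the sketch below evaluates log(1 + log_scale * loss) at a few made-up penalty values with the default log_scale = 20 and checks its slope with autograd. Near zero the slope equals log_scale, so small order violations are effectively weighted by log_scale, while for large penalties the slope decays roughly like 1 / loss.

```python
import torch

log_scale = 20.0  # the default of compute_depth_order_loss
loss = torch.tensor([0.0, 0.01, 0.1, 1.0], requires_grad=True)

# Robust compression used when log_space=True
compressed = torch.log(1.0 + log_scale * loss)

# Slope of the compression: log_scale / (1 + log_scale * loss)
compressed.sum().backward()
```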

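Aggregation: depending on reduction, return the mean ("mean"), the sum ("sum"), or the per-pixel map ("none"). With debug=True, the function instead returns a dict containing the loss together with the intermediate tensors.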
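Putting the steps together, the sketch below is a simplified, deterministic variant written for illustration only: instead of sampling random neighbors it compares each pixel with its right-hand neighbor (order_loss_right_neighbor is a hypothetical name, and log_space is omitted). It checks the property stated in the goal: the loss vanishes when depth is an increasing transform of prior_depth, is positive when the order is reversed, and, with normalize_loss=True, does not depend on the scale of the prior.

```python
import torch


def order_loss_right_neighbor(depth, prior_depth, scene_extent=1.0, normalize_loss=True, reduction="mean"):
    """Hypothetical simplified depth-order loss on (H, W) maps, comparing each pixel
    with its right-hand neighbor instead of a randomly sampled one."""
    diff = (depth[:, :-1] - depth[:, 1:]) / scene_extent
    prior_diff = (prior_depth[:, :-1] - prior_depth[:, 1:]) / scene_extent
    if normalize_loss:
        # Keep only the (approximate) sign of the prior difference
        prior_diff = prior_diff / prior_diff.detach().abs().clamp(min=1e-8)
    # Zero where the orders agree, positive where they are reversed
    pixel_loss = -(diff * prior_diff).clamp(max=0)
    if reduction == "mean":
        return pixel_loss.mean()
    if reduction == "sum":
        return pixel_loss.sum()
    if reduction == "none":
        return pixel_loss
    raise ValueError(f"Invalid reduction: {reduction}")


torch.manual_seed(0)
prior = torch.rand(8, 8) + 0.1  # strictly positive toy prior depth

loss_affine = order_loss_right_neighbor(3.0 * prior + 5.0, prior)  # increasing affine map: expected 0
loss_square = order_loss_right_neighbor(prior ** 2, prior)         # increasing on positive depths: expected 0
loss_reversed = order_loss_right_neighbor(-prior, prior)           # reversed order: expected > 0
# Rescaling the prior by 10 leaves the sign-normalized loss unchanged
loss_reversed_scaled_prior = order_loss_right_neighbor(-prior, 10.0 * prior)
per_pixel = order_loss_right_neighbor(-prior, prior, reduction="none")  # shape (8, 7)
```

With the sign-normalized prior, a fully reversed depth map is penalized by the mean absolute neighbor difference, which is why the loss behaves like a hinge on the signed depth difference.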

import torch


def compute_depth_order_loss(
        depth: torch.Tensor,
        prior_depth: torch.Tensor,
        scene_extent: float = 1.,
        max_pixel_shift_ratio: float = 0.05,
        normalize_loss: bool = True,
        log_space: bool = False,
        log_scale: float = 20.,
        reduction: str = "mean",
        debug: bool = False,
):
    """Compute a loss encouraging pixels in 'depth' to have the same relative depth order as in 'prior_depth'.
    This loss does not require prior depth maps to be multi-view consistent nor to have accurate relative scale.

    Args:
        depth (torch.Tensor): A tensor of shape (H, W), (H, W, 1) or (1, H, W) containing the depth values.
        prior_depth (torch.Tensor): A tensor of shape (H, W), (H, W, 1) or (1, H, W) containing the prior depth values.
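        scene_extent (float, optional): The extent of the scene, used to normalize depth differences so that the loss is invariant to the scene scale. Defaults to 1.
        max_pixel_shift_ratio (float, optional): The maximum pixel shift, as a ratio of the larger image side. Defaults to 0.05, i.e. 5% of the image size.
        normalize_loss (bool, optional): Whether to replace prior depth differences by (approximately) their sign, so that only the order given by the prior matters, not its scale. Defaults to True.
        log_space (bool, optional): Whether to compress the pixel-wise loss with log(1 + log_scale * loss). Defaults to False.
        log_scale (float, optional): The scale factor used inside the logarithm when log_space is True. Defaults to 20.
        reduction (str, optional): The reduction to apply to the loss. Can be "mean", "sum" or "none". Defaults to "mean".
        debug (bool, optional): Whether to return a dict containing the loss and intermediate tensors instead of the loss alone. Defaults to False.

    Returns:
        torch.Tensor or dict: A scalar tensor if reduction is "mean" or "sum".
            If reduction is "none", returns a tensor with the same shape as depth containing the pixel-wise depth order loss.
            If debug is True, returns a dict with keys "depth_order_loss", "diff", "prior_diff", "shifted_depth" and "shifted_prior_depth".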
    """
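    # Give prior_depth the same layout as depth, so the element-wise products below cannot broadcast to a wrong shape
    prior_depth = prior_depth.reshape(depth.shape)
    height, width = depth.squeeze().shape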
    pixel_coords = torch.stack(torch.meshgrid(
        torch.arange(height, device=depth.device),
        torch.arange(width, device=depth.device),
        indexing='ij'
    ), dim=-1).view(-1, 2)

    # Get random pixel shifts
    # TODO: Change the sampling so that shifts of (0, 0) are not possible
    max_pixel_shift = max(round(max_pixel_shift_ratio * max(height, width)), 1)
    pixel_shifts = torch.randint(-max_pixel_shift, max_pixel_shift + 1, pixel_coords.shape, device=depth.device)

    # Apply pixel shifts to pixel coordinates and clamp to image boundaries
    shifted_pixel_coords = (pixel_coords + pixel_shifts).clamp(
        min=torch.tensor([0, 0], device=depth.device),
        max=torch.tensor([height - 1, width - 1], device=depth.device)
    )

    # Get depth values at shifted pixel coordinates
    shifted_depth = depth.squeeze()[
        shifted_pixel_coords[:, 0],
        shifted_pixel_coords[:, 1]
    ].reshape(depth.shape)
    shifted_prior_depth = prior_depth.squeeze()[
        shifted_pixel_coords[:, 0],
        shifted_pixel_coords[:, 1]
    ].reshape(depth.shape)

    # Compute pixel-wise depth order loss
    diff = (depth - shifted_depth) / scene_extent
    prior_diff = (prior_depth - shifted_prior_depth) / scene_extent
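    if normalize_loss:
        # Keep (approximately) only the sign of the prior difference, so the prior's scale does not matter
        prior_diff = prior_diff / prior_diff.detach().abs().clamp(min=1e-8)
    # Zero where diff agrees in sign with prior_diff, positive where the order is reversed
    depth_order_loss = - (diff * prior_diff).clamp(max=0)
    if log_space:
        # Compress large violations: log(1 + log_scale * loss)
        depth_order_loss = torch.log(1. + log_scale * depth_order_loss)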

    # Reduce the loss
    if reduction == "mean":
        depth_order_loss = depth_order_loss.mean()
    elif reduction == "sum":
        depth_order_loss = depth_order_loss.sum()
    elif reduction == "none":
        pass
    else:
        raise ValueError(f"Invalid reduction: {reduction}")

    if debug:
        return {
            "depth_order_loss": depth_order_loss,
            "diff": diff,
            "prior_diff": prior_diff,
            "shifted_depth": shifted_depth,
            "shifted_prior_depth": shifted_prior_depth,
        }
    return depth_order_loss

 

posted @ 2026-01-04 15:19  蘑菇味的花魂