RadeGS: compute_depth_order_loss in MILO
Goal: make the relative near/far ordering of the predicted depth `depth` agree with that of the prior depth `prior_depth` as much as possible (the prior's scale need not be accurate).
Pixel-pair sampling: for each pixel, randomly sample one nearby pixel (within an offset range controlled by max_pixel_shift_ratio), clamping the shifted coordinates to the image boundaries.
Neighbor depths: read shifted_depth and shifted_prior_depth at the randomly shifted positions.
Differences: compute the depth difference diff = (depth - shifted_depth) / scene_extent and the prior difference prior_diff = (prior_depth - shifted_prior_depth) / scene_extent.
Optional sign normalization of the prior: if normalize_loss=True, divide prior_diff by its (detached) absolute value, keeping only its sign and discarding its magnitude.
Order-penalty rule: compute diff * prior_diff; if it is positive (the orders agree), the loss is 0; if it is negative (the orders disagree), a positive penalty is produced: -(diff * prior_diff).clamp(max=0). A toy walkthrough is sketched after this list.
Optional robustification: if log_space=True, compress the penalty as log(1 + log_scale * loss) to dampen the influence of very large errors.
Output aggregation: depending on reduction, return the mean, the sum, or the per-pixel map (none); with debug=True, intermediate variables are returned as well.
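To make the penalty rule concrete, here is a minimal sketch of steps 4-7 on a single pixel pair (toy numbers chosen for illustration, scene_extent = 1; not from the source):

    import torch

    # One pixel pair: the prediction says this pixel is nearer than its neighbor,
    # while the prior says it is farther, so the two orderings disagree.
    depth = torch.tensor(2.0, requires_grad=True)
    shifted_depth = torch.tensor(3.0)
    prior_depth = torch.tensor(5.0)
    shifted_prior_depth = torch.tensor(1.0)

    diff = depth - shifted_depth                    # -1.0
    prior_diff = prior_depth - shifted_prior_depth  # +4.0

    # Sign normalization (normalize_loss=True): keep only the prior's ordering.
    prior_diff = prior_diff / prior_diff.detach().abs().clamp(min=1e-8)  # +1.0

    loss = -(diff * prior_diff).clamp(max=0)        # product -1.0 -> penalty 1.0
    loss.backward()
    print(loss.item(), depth.grad.item())           # 1.0 -1.0: descent pushes depth up

    # Optional log-space robustification (log_scale=20): log(1 + 20*1.0) ~= 3.04,
    # while a 5x larger penalty of 5.0 only grows to log(1 + 20*5.0) ~= 4.62.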
import torch


def compute_depth_order_loss(
    depth: torch.Tensor,
    prior_depth: torch.Tensor,
    scene_extent: float = 1.,
    max_pixel_shift_ratio: float = 0.05,
    normalize_loss: bool = True,
    log_space: bool = False,
    log_scale: float = 20.,
    reduction: str = "mean",
    debug: bool = False,
):
    """Compute a loss encouraging pixels in 'depth' to have the same relative depth
    order as in 'prior_depth'. This loss does not require prior depth maps to be
    multi-view consistent nor to have accurate relative scale.

    Args:
        depth (torch.Tensor): A tensor of shape (H, W), (H, W, 1) or (1, H, W)
            containing the depth values.
        prior_depth (torch.Tensor): A tensor of shape (H, W), (H, W, 1) or (1, H, W)
            containing the prior depth values.
        scene_extent (float): The extent of the scene, used to normalize the depth
            differences and make the loss invariant to the scene scale.
        max_pixel_shift_ratio (float, optional): The maximum pixel shift ratio.
            Defaults to 0.05, i.e. 5% of the image size.
        normalize_loss (bool, optional): Whether to normalize the prior difference to
            its sign, discarding its magnitude. Defaults to True.
        log_space (bool, optional): Whether to compress the penalty as
            log(1 + log_scale * loss). Defaults to False.
        log_scale (float, optional): Scale factor used when log_space is True.
            Defaults to 20.
        reduction (str, optional): The reduction to apply to the loss. Can be "mean",
            "sum" or "none". Defaults to "mean".
        debug (bool, optional): If True, return a dict of intermediate tensors in
            addition to the loss. Defaults to False.

    Returns:
        torch.Tensor: A scalar tensor. If reduction is "none", returns a tensor with
            the same shape as depth containing the pixel-wise depth order loss.
    """
    height, width = depth.squeeze().shape
    pixel_coords = torch.stack(torch.meshgrid(
        torch.linspace(0, height - 1, height, dtype=torch.long, device=depth.device),
        torch.linspace(0, width - 1, width, dtype=torch.long, device=depth.device),
        indexing='ij'
    ), dim=-1).view(-1, 2)

    # Get random pixel shifts
    # TODO: Change the sampling so that shifts of (0, 0) are not possible
    max_pixel_shift = max(round(max_pixel_shift_ratio * max(height, width)), 1)
    pixel_shifts = torch.randint(
        -max_pixel_shift, max_pixel_shift + 1,
        pixel_coords.shape, device=depth.device
    )

    # Apply pixel shifts to pixel coordinates and clamp to image boundaries
    shifted_pixel_coords = (pixel_coords + pixel_shifts).clamp(
        min=torch.tensor([0, 0], device=depth.device),
        max=torch.tensor([height - 1, width - 1], device=depth.device)
    )

    # Get depth values at shifted pixel coordinates
    shifted_depth = depth.squeeze()[
        shifted_pixel_coords[:, 0], shifted_pixel_coords[:, 1]
    ].reshape(depth.shape)
    shifted_prior_depth = prior_depth.squeeze()[
        shifted_pixel_coords[:, 0], shifted_pixel_coords[:, 1]
    ].reshape(depth.shape)

    # Compute pixel-wise depth order loss
    diff = (depth - shifted_depth) / scene_extent
    prior_diff = (prior_depth - shifted_prior_depth) / scene_extent
    if normalize_loss:
        # Keep only the sign of the prior difference, discarding its magnitude
        prior_diff = prior_diff / prior_diff.detach().abs().clamp(min=1e-8)
    # Penalize only pairs whose predicted order contradicts the prior order
    depth_order_loss = - (diff * prior_diff).clamp(max=0)
    if log_space:
        # Compress large penalties for robustness to outliers
        depth_order_loss = torch.log(1. + log_scale * depth_order_loss)

    # Reduce the loss
    if reduction == "mean":
        depth_order_loss = depth_order_loss.mean()
    elif reduction == "sum":
        depth_order_loss = depth_order_loss.sum()
    elif reduction == "none":
        pass
    else:
        raise ValueError(f"Invalid reduction: {reduction}")

    if debug:
        return {
            "depth_order_loss": depth_order_loss,
            "diff": diff,
            "prior_diff": prior_diff,
            "shifted_depth": shifted_depth,
            "shifted_prior_depth": shifted_prior_depth,
        }
    return depth_order_loss
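A minimal usage sketch (random tensors and hypothetical values, purely illustrative):

    H, W = 64, 64
    depth = torch.rand(H, W, requires_grad=True)  # e.g. a rendered depth map
    prior_depth = torch.rand(H, W)                # e.g. from a monocular depth network

    loss = compute_depth_order_loss(
        depth, prior_depth,
        scene_extent=10.0,            # hypothetical scene extent
        max_pixel_shift_ratio=0.05,
        normalize_loss=True,
    )
    loss.backward()                   # gradients flow back into `depth` only

On the TODO in the sampling step: one possible fix (a hypothetical sketch, not from the source) is to resample the rows whose shift is (0, 0) until none remain:

    pixel_shifts = torch.randint(-max_pixel_shift, max_pixel_shift + 1,
                                 pixel_coords.shape, device=depth.device)
    zero_rows = (pixel_shifts == 0).all(dim=-1)  # pairs comparing a pixel to itself
    while zero_rows.any():
        pixel_shifts[zero_rows] = torch.randint(
            -max_pixel_shift, max_pixel_shift + 1,
            (int(zero_rows.sum()), 2), device=depth.device)
        zero_rows = (pixel_shifts == 0).all(dim=-1)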
