Spaces:
Running
Running
| from einops.einops import rearrange | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from romatch.utils.utils import get_gt_warp | |
| import wandb | |
| import romatch | |
| import math | |
| class RobustLosses(nn.Module): | |
| def __init__( | |
| self, | |
| robust=False, | |
| center_coords=False, | |
| scale_normalize=False, | |
| ce_weight=0.01, | |
| local_loss=True, | |
| local_dist=4.0, | |
| local_largest_scale=8, | |
| smooth_mask = False, | |
| depth_interpolation_mode = "bilinear", | |
| mask_depth_loss = False, | |
| relative_depth_error_threshold = 0.05, | |
| alpha = 1., | |
| c = 1e-3, | |
| ): | |
| super().__init__() | |
| self.robust = robust # measured in pixels | |
| self.center_coords = center_coords | |
| self.scale_normalize = scale_normalize | |
| self.ce_weight = ce_weight | |
| self.local_loss = local_loss | |
| self.local_dist = local_dist | |
| self.local_largest_scale = local_largest_scale | |
| self.smooth_mask = smooth_mask | |
| self.depth_interpolation_mode = depth_interpolation_mode | |
| self.mask_depth_loss = mask_depth_loss | |
| self.relative_depth_error_threshold = relative_depth_error_threshold | |
| self.avg_overlap = dict() | |
| self.alpha = alpha | |
| self.c = c | |
| def gm_cls_loss(self, x2, prob, scale_gm_cls, gm_certainty, scale): | |
| with torch.no_grad(): | |
| B, C, H, W = scale_gm_cls.shape | |
| device = x2.device | |
| cls_res = round(math.sqrt(C)) | |
| G = torch.meshgrid(*[torch.linspace(-1+1/cls_res, 1 - 1/cls_res, steps = cls_res,device = device) for _ in range(2)], indexing='ij') | |
| G = torch.stack((G[1], G[0]), dim = -1).reshape(C,2) | |
| GT = (G[None,:,None,None,:]-x2[:,None]).norm(dim=-1).min(dim=1).indices | |
| cls_loss = F.cross_entropy(scale_gm_cls, GT, reduction = 'none')[prob > 0.99] | |
| certainty_loss = F.binary_cross_entropy_with_logits(gm_certainty[:,0], prob) | |
| if not torch.any(cls_loss): | |
| cls_loss = (certainty_loss * 0.0) # Prevent issues where prob is 0 everywhere | |
| losses = { | |
| f"gm_certainty_loss_{scale}": certainty_loss.mean(), | |
| f"gm_cls_loss_{scale}": cls_loss.mean(), | |
| } | |
| wandb.log(losses, step = romatch.GLOBAL_STEP) | |
| return losses | |
| def delta_cls_loss(self, x2, prob, flow_pre_delta, delta_cls, certainty, scale, offset_scale): | |
| with torch.no_grad(): | |
| B, C, H, W = delta_cls.shape | |
| device = x2.device | |
| cls_res = round(math.sqrt(C)) | |
| G = torch.meshgrid(*[torch.linspace(-1+1/cls_res, 1 - 1/cls_res, steps = cls_res,device = device) for _ in range(2)]) | |
| G = torch.stack((G[1], G[0]), dim = -1).reshape(C,2) * offset_scale | |
| GT = (G[None,:,None,None,:] + flow_pre_delta[:,None] - x2[:,None]).norm(dim=-1).min(dim=1).indices | |
| cls_loss = F.cross_entropy(delta_cls, GT, reduction = 'none')[prob > 0.99] | |
| certainty_loss = F.binary_cross_entropy_with_logits(certainty[:,0], prob) | |
| if not torch.any(cls_loss): | |
| cls_loss = (certainty_loss * 0.0) # Prevent issues where prob is 0 everywhere | |
| losses = { | |
| f"delta_certainty_loss_{scale}": certainty_loss.mean(), | |
| f"delta_cls_loss_{scale}": cls_loss.mean(), | |
| } | |
| wandb.log(losses, step = romatch.GLOBAL_STEP) | |
| return losses | |
| def regression_loss(self, x2, prob, flow, certainty, scale, eps=1e-8, mode = "delta"): | |
| epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1) | |
| if scale == 1: | |
| pck_05 = (epe[prob > 0.99] < 0.5 * (2/512)).float().mean() | |
| wandb.log({"train_pck_05": pck_05}, step = romatch.GLOBAL_STEP) | |
| ce_loss = F.binary_cross_entropy_with_logits(certainty[:, 0], prob) | |
| a = self.alpha[scale] if isinstance(self.alpha, dict) else self.alpha | |
| cs = self.c * scale | |
| x = epe[prob > 0.99] | |
| reg_loss = cs**a * ((x/(cs))**2 + 1**2)**(a/2) | |
| if not torch.any(reg_loss): | |
| reg_loss = (ce_loss * 0.0) # Prevent issues where prob is 0 everywhere | |
| losses = { | |
| f"{mode}_certainty_loss_{scale}": ce_loss.mean(), | |
| f"{mode}_regression_loss_{scale}": reg_loss.mean(), | |
| } | |
| wandb.log(losses, step = romatch.GLOBAL_STEP) | |
| return losses | |
| def forward(self, corresps, batch): | |
| scales = list(corresps.keys()) | |
| tot_loss = 0.0 | |
| # scale_weights due to differences in scale for regression gradients and classification gradients | |
| scale_weights = {1:1, 2:1, 4:1, 8:1, 16:1} | |
| for scale in scales: | |
| scale_corresps = corresps[scale] | |
| scale_certainty, flow_pre_delta, delta_cls, offset_scale, scale_gm_cls, scale_gm_certainty, flow, scale_gm_flow = ( | |
| scale_corresps["certainty"], | |
| scale_corresps.get("flow_pre_delta"), | |
| scale_corresps.get("delta_cls"), | |
| scale_corresps.get("offset_scale"), | |
| scale_corresps.get("gm_cls"), | |
| scale_corresps.get("gm_certainty"), | |
| scale_corresps["flow"], | |
| scale_corresps.get("gm_flow"), | |
| ) | |
| if flow_pre_delta is not None: | |
| flow_pre_delta = rearrange(flow_pre_delta, "b d h w -> b h w d") | |
| b, h, w, d = flow_pre_delta.shape | |
| else: | |
| # _ = 1 | |
| b, _, h, w = scale_certainty.shape | |
| gt_warp, gt_prob = get_gt_warp( | |
| batch["im_A_depth"], | |
| batch["im_B_depth"], | |
| batch["T_1to2"], | |
| batch["K1"], | |
| batch["K2"], | |
| H=h, | |
| W=w, | |
| ) | |
| x2 = gt_warp.float() | |
| prob = gt_prob | |
| if self.local_largest_scale >= scale: | |
| prob = prob * ( | |
| F.interpolate(prev_epe[:, None], size=(h, w), mode="nearest-exact")[:, 0] | |
| < (2 / 512) * (self.local_dist[scale] * scale)) | |
| if scale_gm_cls is not None: | |
| gm_cls_losses = self.gm_cls_loss(x2, prob, scale_gm_cls, scale_gm_certainty, scale) | |
| gm_loss = self.ce_weight * gm_cls_losses[f"gm_certainty_loss_{scale}"] + gm_cls_losses[f"gm_cls_loss_{scale}"] | |
| tot_loss = tot_loss + scale_weights[scale] * gm_loss | |
| elif scale_gm_flow is not None: | |
| gm_flow_losses = self.regression_loss(x2, prob, scale_gm_flow, scale_gm_certainty, scale, mode = "gm") | |
| gm_loss = self.ce_weight * gm_flow_losses[f"gm_certainty_loss_{scale}"] + gm_flow_losses[f"gm_regression_loss_{scale}"] | |
| tot_loss = tot_loss + scale_weights[scale] * gm_loss | |
| if delta_cls is not None: | |
| delta_cls_losses = self.delta_cls_loss(x2, prob, flow_pre_delta, delta_cls, scale_certainty, scale, offset_scale) | |
| delta_cls_loss = self.ce_weight * delta_cls_losses[f"delta_certainty_loss_{scale}"] + delta_cls_losses[f"delta_cls_loss_{scale}"] | |
| tot_loss = tot_loss + scale_weights[scale] * delta_cls_loss | |
| else: | |
| delta_regression_losses = self.regression_loss(x2, prob, flow, scale_certainty, scale) | |
| reg_loss = self.ce_weight * delta_regression_losses[f"delta_certainty_loss_{scale}"] + delta_regression_losses[f"delta_regression_loss_{scale}"] | |
| tot_loss = tot_loss + scale_weights[scale] * reg_loss | |
| prev_epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1).detach() | |
| return tot_loss | |