import sys import math from typing import Callable, Mapping import skimage.metrics as sk_metrics import torch import torch.nn.functional as F from ignite.engine import Engine from ignite.exceptions import NotComputableError from ignite.metrics import Metric from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce import pulp def median_scaling( depth_gt: torch.Tensor, depth_pred: torch.Tensor, ): # TODO: ensure this works for any batch size mask = depth_gt > 0 depth_gt[mask] = torch.nan depth_pred[mask] = torch.nan scaling = torch.nanmedian(depth_gt.flatten(-2, -1), dim=-1) / torch.nanmedian( depth_pred.flatten(-2, -1), dim=-1 ) depth_pred = scaling[..., None, None] * depth_pred return depth_pred def l2_scaling( depth_gt: torch.Tensor, depth_pred: torch.Tensor, ): # TODO: ensure this works for any batch size mask = depth_gt > 0 depth_pred = depth_pred depth_gt_ = depth_gt[mask] depth_pred_ = depth_pred[mask] depth_pred_ = torch.stack((depth_pred_, torch.ones_like(depth_pred_)), dim=-1) x = torch.linalg.lstsq( depth_pred_.to(torch.float32), depth_gt_.unsqueeze(-1).to(torch.float32) ).solution.squeeze() depth_pred = depth_pred * x[0] + x[1] return depth_pred def compute_depth_metrics( depth_gt: torch.Tensor, depth_pred: torch.Tensor, scaling_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None, ): # TODO: find out if dim -3 is dummy dimension or part of the batch # TODO: Test if works for batches of images if scaling_fn: depth_pred = scaling_fn(depth_gt, depth_pred) depth_pred = torch.clamp(depth_pred, 1e-3, 80) mask = depth_gt != 0 max_ratio = torch.maximum((depth_gt / depth_pred), (depth_pred / depth_gt)) a_scores = {} for name, thresh in {"a1": 1.25, "a2": 1.25**2, "a3": 1.25**3}.items(): within_thresh = (max_ratio < thresh).to(torch.float) within_thresh[~mask] = 0.0 a_scores[name] = within_thresh.flatten(-2, -1).sum(dim=-1) / mask.to( torch.float ).flatten(-2, -1).sum(dim=-1) square_error = (depth_gt - depth_pred) ** 2 square_error[~mask] = 0.0 log_square_error = (torch.log(depth_gt) - torch.log(depth_pred)) ** 2 log_square_error[~mask] = 0.0 abs_error = torch.abs(depth_gt - depth_pred) abs_error[~mask] = 0.0 rmse = ( square_error.flatten(-2, -1).sum(dim=-1) / mask.to(torch.float).flatten(-2, -1).sum(dim=-1) ) ** 0.5 rmse_log = ( log_square_error.flatten(-2, -1).sum(dim=-1) / mask.to(torch.float).flatten(-2, -1).sum(dim=-1) ) ** 0.5 abs_rel = abs_error / depth_gt abs_rel[~mask] = 0.0 abs_rel = ( abs_rel.flatten(-2, -1).sum(dim=-1) / mask.to(torch.float).flatten(-2, -1).sum(dim=-1) ) ** 0.5 sq_rel = square_error / depth_gt sq_rel[~mask] = 0.0 sq_rel = ( sq_rel.flatten(-2, -1).sum(dim=-1) / mask.to(torch.float).flatten(-2, -1).sum(dim=-1) ) ** 0.5 metrics_dict = { "abs_rel": abs_rel, "sq_rel": sq_rel, "rmse": rmse, "rmse_log": rmse_log, "a1": a_scores["a1"], "a2": a_scores["a2"], "a3": a_scores["a3"], } return metrics_dict def compute_occ_metrics( occupancy_pred: torch.Tensor, occupancy_gt: torch.Tensor, is_visible: torch.Tensor ): # Only not visible points can be occupied occupancy_gt &= ~is_visible is_occupied_acc = (occupancy_pred == occupancy_gt).float().mean().item() is_occupied_prec = occupancy_gt[occupancy_pred].float().mean().item() is_occupied_rec = occupancy_pred[occupancy_gt].float().mean().item() not_occupied_not_visible_ratio = ( ((~occupancy_gt) & (~is_visible)).float().mean().item() ) total_ie = ((~occupancy_gt) & (~is_visible)).float().sum().item() ie_acc = (occupancy_pred == occupancy_gt)[(~is_visible)].float().mean().item() ie_prec = (~occupancy_gt)[(~occupancy_pred) & (~is_visible)].float().mean() ie_rec = (~occupancy_pred)[(~occupancy_gt) & (~is_visible)].float().mean() total_no_nop_nv = ( ((~occupancy_gt) & (~occupancy_pred))[(~is_visible) & (~occupancy_gt)] .float() .sum() ) return { "o_acc": is_occupied_acc, "o_rec": is_occupied_rec, "o_prec": is_occupied_prec, "ie_acc": ie_acc, "ie_rec": ie_rec, "ie_prec": ie_prec, "ie_r": not_occupied_not_visible_ratio, "t_ie": total_ie, "t_no_nop_nv": total_no_nop_nv, } def compute_nvs_metrics(data, lpips): # TODO: This is only correct for batchsize 1! # Following tucker et al. and others, we crop 5% on all sides # idx of stereo frame (the target frame is always the "stereo" frame). sf_id = data["rgb_gt"].shape[1] // 2 imgs_gt = data["rgb_gt"][:1, sf_id : sf_id + 1] imgs_pred = data["fine"][0]["rgb"][:1, sf_id : sf_id + 1] imgs_gt = imgs_gt.squeeze(0).permute(0, 3, 1, 2) imgs_pred = imgs_pred.squeeze(0).squeeze(-2).permute(0, 3, 1, 2) n, c, h, w = imgs_gt.shape y0 = int(math.ceil(0.05 * h)) y1 = int(math.floor(0.95 * h)) x0 = int(math.ceil(0.05 * w)) x1 = int(math.floor(0.95 * w)) imgs_gt = imgs_gt[:, :, y0:y1, x0:x1] imgs_pred = imgs_pred[:, :, y0:y1, x0:x1] imgs_gt_np = imgs_gt.detach().squeeze().permute(1, 2, 0).cpu().numpy() imgs_pred_np = imgs_pred.detach().squeeze().permute(1, 2, 0).cpu().numpy() ssim_score = sk_metrics.structural_similarity( imgs_pred_np, imgs_gt_np, multichannel=True, data_range=1, channel_axis=-1 ) psnr_score = sk_metrics.peak_signal_noise_ratio( imgs_pred_np, imgs_gt_np, data_range=1 ) lpips_score = lpips(imgs_pred, imgs_gt, normalize=False).mean() metrics_dict = { "ssim": torch.tensor([ssim_score], device=imgs_gt.device), "psnr": torch.tensor([psnr_score], device=imgs_gt.device), "lpips": torch.tensor([lpips_score], device=imgs_gt.device), } return metrics_dict def compute_dino_metrics(data): dino_gt = data["dino_gt"] if "dino_features_downsampled" in data["coarse"][0]: dino_pred = data["coarse"][0]["dino_features_downsampled"].squeeze(-2) else: dino_pred = data["coarse"][0]["dino_features"].squeeze(-2) l1_loss = F.l1_loss(dino_pred, dino_gt, reduction="none").mean(dim=(0, 2, 3, 4)) l2_loss = F.mse_loss(dino_pred, dino_gt, reduction="none").mean(dim=(0, 2, 3, 4)) cos_sim = F.cosine_similarity(dino_pred, dino_gt, dim=-1).mean(dim=(0, 2, 3)) metrics_dict = { "l1": torch.tensor([l1_loss.mean()], device=dino_gt.device), "l2": torch.tensor([l2_loss.mean()], device=dino_gt.device), "cos_sim": torch.tensor([cos_sim.mean()], device=dino_gt.device) } for i in range(len(l1_loss)): metrics_dict[f"l1_{i}"] = torch.tensor([l1_loss[i]], device=dino_gt.device) metrics_dict[f"l2_{i}"] = torch.tensor([l2_loss[i]], device=dino_gt.device) metrics_dict[f"cos_sim_{i}"] = torch.tensor([cos_sim[i]], device=dino_gt.device) return metrics_dict def compute_stego_metrics(data): if "stego_corr" not in data["segmentation"]: return {} metrics_dict = { "stego_self_corr": data["segmentation"]["stego_corr"]["stego_self_corr"], "stego_nn_corr": data["segmentation"]["stego_corr"]["stego_nn_corr"], "stego_random_corr": data["segmentation"]["stego_corr"]["stego_random_corr"], } return metrics_dict def compute_seg_metrics(data, n_classes, gt_classes): segs_gt = data["segmentation"]["target"].flatten() valid_mask = segs_gt >= 0 segs_gt = segs_gt[valid_mask] metrics_dict = {} for result_key, result in data["segmentation"]["results"].items(): if "pseudo_segs_pred" in result: segs_pred = result["pseudo_segs_pred"][:, 0].flatten() else: segs_pred = result["segs_pred"][:, 0].flatten() segs_pred = segs_pred[valid_mask] confusion_matrix = torch.bincount(n_classes * segs_gt + segs_pred, minlength=n_classes * gt_classes).reshape(gt_classes, n_classes) metrics_dict[result_key] = confusion_matrix return metrics_dict class MeanMetric(Metric): def __init__(self, output_transform=lambda x: x["output"], device="cpu"): super(MeanMetric, self).__init__( output_transform=output_transform, device=device ) self._sum = torch.tensor(0, device=self._device, dtype=torch.float32) self._num_examples = 0 self.required_output_keys = () @reinit__is_reduced def reset(self): self._sum = torch.tensor(0, device=self._device, dtype=torch.float32) self._num_examples = 0 super(MeanMetric, self).reset() @reinit__is_reduced def update(self, value): if torch.any(torch.isnan(torch.tensor(value))): raise ValueError("NaN values present in metric!") self._sum += value self._num_examples += 1 @sync_all_reduce("_num_examples:SUM", "_sum:SUM") def compute(self): if self._num_examples == 0: raise NotComputableError( "CustomAccuracy must have at least one example before it can be computed." ) return self._sum.item() / self._num_examples @torch.no_grad() def iteration_completed(self, engine: Engine) -> None: output = self._output_transform( engine.state.output ) ## engine.state.output.keys() == dict_keys(['output', 'loss_dict', 'timings_dict', 'metrics_dict']) self.update(output) class DictMeanMetric(Metric): def __init__(self, name: str, output_transform=lambda x: x["output"], device="cpu"): self._name = name self._sums: dict[str, torch.Tensor] = {} self._num_examples = 0 self.required_output_keys = () super(DictMeanMetric, self).__init__( output_transform=output_transform, device=device ) @reinit__is_reduced def reset(self): self._sums = {} self._num_examples = 0 super(DictMeanMetric, self).reset() @reinit__is_reduced def update(self, value): num_examples = None for key, metric in value.items(): if not key in self._sums: self._sums[key] = torch.tensor( 0, device=self._device, dtype=torch.float32 ) if torch.any(torch.isnan(metric)): # TODO: integrate into logging print(f"Warining: Metric {self._name}/{key} has a nan value") continue self._sums[key] += metric.sum().to(self._device) # TODO: check if this works with batches if num_examples is None: num_examples = metric.shape[0] self._num_examples += 1 @sync_all_reduce("_num_examples:SUM", "_sum:SUM") def compute(self): if self._num_examples == 0: raise NotComputableError( "CustomAccuracy must have at least one example before it can be computed." ) return { f"{self._name}_{key}": metric.item() / self._num_examples for key, metric in self._sums.items() } @torch.no_grad() def iteration_completed(self, engine: Engine) -> None: output = self._output_transform(engine.state.output["output"]) self.update(output) def completed(self, engine: Engine, name: str) -> None: """Helper method to compute metric's value and put into the engine. It is automatically attached to the `engine` with :meth:`~ignite.metrics.metric.Metric.attach`. If metrics' value is torch tensor, it is explicitly sent to CPU device. Args: engine: the engine to which the metric must be attached name: the name of the metric used as key in dict `engine.state.metrics` .. changes from default implementation: don't add whole result dict to engine state, but only the values """ result = self.compute() if isinstance(result, Mapping): if name in result.keys(): raise ValueError( f"Argument name '{name}' is conflicting with mapping keys: {list(result.keys())}" ) for key, value in result.items(): engine.state.metrics[key] = value else: if isinstance(result, torch.Tensor): if len(result.size()) == 0: result = result.item() elif "cpu" not in result.device.type: result = result.cpu() engine.state.metrics[name] = result class SegmentationMetric(DictMeanMetric): def __init__(self, name: str, output_transform=lambda x: x["output"], device="cpu", assign_pseudo=True): super(SegmentationMetric, self).__init__( name, output_transform, device ) self.assign_pseudo = assign_pseudo # [road, sidewalk, building, wall, fence, pole, traffic light, traffic sign, vegetation, terrain, sky, person, rider, car, truck, bus, train, motorcycle, bicycle] self.weights = torch.Tensor([4, 2, 2, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 2, 1, 1, 1, 1, 1]) self.weights = self.weights / self.weights.mean() @reinit__is_reduced def update(self, value): for key, metric in value.items(): if not key in self._sums: self._sums[key] = torch.zeros(metric.shape, device=self._device, dtype=torch.int32) if torch.any(torch.isnan(metric)): print(f"Warining: Metric {self._name}/{key} has a nan value") continue self._sums[key] += metric.to(self._device) self._num_examples += 1 @sync_all_reduce("_num_examples:SUM", "_sum:SUM") def compute(self): if self._num_examples == 0: raise NotComputableError( "CustomAccuracy must have at least one example before it can be computed." ) result = {} for key, _sum in self._sums.items(): if self.assign_pseudo: assignment = self._calculate_pseudo_label_assignment(_sum) gt_classes = _sum.size(0) confusion_matrix = torch.zeros((gt_classes, gt_classes), dtype=_sum.dtype) confusion_matrix.scatter_add_( 1, assignment.unsqueeze(0).expand(gt_classes, -1), _sum ) result[key + "_assignment"] = assignment else: confusion_matrix = _sum # confusion_matrix axes: (actual, prediction) true_positives = confusion_matrix.diag() false_negatives = torch.sum(confusion_matrix, dim=1) - true_positives false_positives = torch.sum(confusion_matrix, dim=0) - true_positives denominator = true_positives + false_positives + false_negatives per_class_iou = torch.where(denominator > 0, true_positives / denominator, torch.zeros_like(denominator)) result[key + "_per_class_iou"] = per_class_iou result[key + "_miou"] = per_class_iou.mean().item() result[key + "_weighted_miou"] = (per_class_iou * self.weights).mean().item() result[key + "_acc"] = confusion_matrix.diag().sum().item() / confusion_matrix.sum().item() result[key + "_confusion_matrix"] = confusion_matrix return result def _calculate_pseudo_label_assignment(self, metric_matrix): """Implemented this way to generalize to over-segmentation""" gt_classes, n_classes = metric_matrix.size() costs = metric_matrix.cpu().numpy() problem = pulp.LpProblem("CapacitatedAssignment", pulp.LpMaximize) x = [[pulp.LpVariable(f"x_{i}_{j}", cat="Binary") for j in range(n_classes)] for i in range(gt_classes)] problem += pulp.lpSum(costs[i][j] * x[i][j] for i in range(gt_classes) for j in range(n_classes)) for j in range(n_classes): problem += pulp.lpSum(x[i][j] for i in range(gt_classes)) == 1, f"AssignPseudoLabel_{j}" for i in range(gt_classes): problem += pulp.lpSum(x[i][j] for j in range(n_classes)) >= 1, f"MinAssignActualLabel_{i}" problem.solve() print("Status:", pulp.LpStatus[problem.status]) print("Objective:", pulp.value(problem.objective)) assignment = torch.zeros(n_classes, dtype=torch.int64) for j in range(n_classes): assignment[j] = next(i for i in range(gt_classes) if pulp.value(x[i][j]) == 1) return assignment class ConcatenateMetric(DictMeanMetric): @reinit__is_reduced def update(self, value, every_nth=100): n_bins = 50 for key, metric in value.items(): if not key in self._sums: self._sums[key] = torch.zeros((n_bins,), device=self._device, dtype=torch.int32) if torch.any(torch.isnan(metric)): print(f"Warning: Metric {self._name}/{key} has a nan value") continue metric_flat = metric.flatten().to(self._device)[::every_nth] if key in self._sums: self._sums[key] = torch.cat([self._sums[key], metric_flat]) else: self._sums[key] = metric_flat self._num_examples += 1 @sync_all_reduce("_num_examples:SUM", "_sum:SUM") def compute(self): return self._sums class FG_ARI(Metric): def __init__(self, output_transform=lambda x: x["output"], device="cpu"): self._sum_fg_aris = torch.tensor(0, device=self._device, dtype=torch.float32) self._num_examples = 0 self.required_output_keys = () super(FG_ARI, self).__init__(output_transform=output_transform, device=device) @reinit__is_reduced def reset(self): self._sum_fg_aris = torch.tensor(0, device=self._device, dtype=torch.float32) self._num_examples = 0 super(FG_ARI, self).reset() @reinit__is_reduced def update(self, data): true_masks = data["segs"] # fc [n, h, w] pred_masks = data["slot_masks"] # n, fc, sc, h, w n, fc, sc, h, w = pred_masks.shape true_masks = [ F.interpolate(tm.to(float).unsqueeze(1), (h, w), mode="nearest") .squeeze(1) .to(int) for tm in true_masks ] for i in range(n): for f in range(fc): true_mask = true_masks[f][i] pred_mask = pred_masks[i, f] true_mask = true_mask.view(-1) pred_mask = pred_mask.view(sc, -1) if torch.max(true_mask) == 0: continue foreground = true_mask > 0 true_mask = true_mask[foreground] pred_mask = pred_mask[:, foreground].permute(1, 0) true_mask = F.one_hot(true_mask) # Filter out empty true groups not_empty = torch.any(true_mask, dim=0) true_mask = true_mask[:, not_empty] # Filter out empty predicted groups not_empty = torch.any(pred_mask, dim=0) pred_mask = pred_mask[:, not_empty] true_mask.unsqueeze_(0) pred_mask.unsqueeze_(0) _, n_points, n_true_groups = true_mask.shape n_pred_groups = pred_mask.shape[-1] if n_points <= n_true_groups and n_points <= n_pred_groups: print( "adjusted_rand_index requires n_groups < n_points.", file=sys.stderr, ) continue true_group_ids = torch.argmax(true_mask, -1) pred_group_ids = torch.argmax(pred_mask, -1) true_mask_oh = true_mask.to(torch.float32) pred_mask_oh = F.one_hot(pred_group_ids, n_pred_groups).to( torch.float32 ) n_points = torch.sum(true_mask_oh, dim=[1, 2]).to(torch.float32) nij = torch.einsum("bji,bjk->bki", pred_mask_oh, true_mask_oh) a = torch.sum(nij, dim=1) b = torch.sum(nij, dim=2) rindex = torch.sum(nij * (nij - 1), dim=[1, 2]) aindex = torch.sum(a * (a - 1), dim=1) bindex = torch.sum(b * (b - 1), dim=1) expected_rindex = aindex * bindex / (n_points * (n_points - 1)) max_rindex = (aindex + bindex) / 2 ari = (rindex - expected_rindex) / ( max_rindex - expected_rindex + 0.000000000001 ) _all_equal = lambda values: torch.all( torch.eq(values, values[..., :1]), dim=-1 ) both_single_cluster = torch.logical_and( _all_equal(true_group_ids), _all_equal(pred_group_ids) ) self._sum_fg_aris += torch.where( both_single_cluster, torch.ones_like(ari), ari ).squeeze() self._num_examples += 1 @sync_all_reduce("_num_examples:SUM", "_sum_fg_aris:SUM") def compute(self): if self._num_examples == 0: raise NotComputableError( "CustomAccuracy must have at least one example before it can be computed." ) return self._sum_fg_aris.item() / self._num_examples @torch.no_grad() def iteration_completed(self, engine: Engine) -> None: output = self._output_transform(engine.state.output) self.update(output)