import sys
import math
from typing import Callable, Mapping

import skimage.metrics as sk_metrics
import torch
import torch.nn.functional as F
from ignite.engine import Engine
from ignite.exceptions import NotComputableError
from ignite.metrics import Metric
from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce

import pulp


def median_scaling(
    depth_gt: torch.Tensor,
    depth_pred: torch.Tensor,
):
    """Scale the prediction so its median matches the ground-truth median over valid pixels."""
    # TODO: ensure this works for any batch size
    mask = depth_gt > 0
    # Set invalid pixels to NaN so nanmedian ignores them; work on copies to
    # avoid mutating the caller's tensors.
    depth_gt = depth_gt.clone()
    depth_pred_masked = depth_pred.clone()
    depth_gt[~mask] = torch.nan
    depth_pred_masked[~mask] = torch.nan
    scaling = (
        torch.nanmedian(depth_gt.flatten(-2, -1), dim=-1).values
        / torch.nanmedian(depth_pred_masked.flatten(-2, -1), dim=-1).values
    )
    depth_pred = scaling[..., None, None] * depth_pred
    return depth_pred


def l2_scaling(
    depth_gt: torch.Tensor,
    depth_pred: torch.Tensor,
):
    """Fit an affine map (scale and shift) from prediction to ground truth by least squares."""
    # TODO: ensure this works for any batch size
    mask = depth_gt > 0
    depth_gt_ = depth_gt[mask]
    depth_pred_ = depth_pred[mask]
    # Solve [pred, 1] @ [scale, shift]^T = gt over the valid pixels.
    depth_pred_ = torch.stack((depth_pred_, torch.ones_like(depth_pred_)), dim=-1)
    x = torch.linalg.lstsq(
        depth_pred_.to(torch.float32), depth_gt_.unsqueeze(-1).to(torch.float32)
    ).solution.squeeze()
    depth_pred = depth_pred * x[0] + x[1]
    return depth_pred


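# A hedged sanity check for the two scaling helpers above (hypothetical shapes:
# a [1, H, W] depth map with zeros marking invalid pixels). After alignment,
# the prediction should match the ground truth on the valid pixels.
def _example_scaling():
    torch.manual_seed(0)
    depth_gt = torch.rand(1, 32, 32) * 10 + 1
    depth_gt[:, :4] = 0  # simulate invalid pixels
    depth_pred = depth_gt * 0.5  # prediction off by a global scale
    aligned = median_scaling(depth_gt, depth_pred)
    print(torch.allclose(aligned[depth_gt > 0], depth_gt[depth_gt > 0]))  # True
    aligned = l2_scaling(depth_gt, depth_pred * 0.5 + 1.0)  # scale and shift error
    print(torch.allclose(aligned[depth_gt > 0], depth_gt[depth_gt > 0], atol=1e-3))

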
def compute_depth_metrics(
    depth_gt: torch.Tensor,
    depth_pred: torch.Tensor,
    scaling_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None,
):
    """Standard monocular depth metrics, averaged over the valid (depth != 0) pixels."""
    # TODO: find out if dim -3 is dummy dimension or part of the batch
    # TODO: Test if works for batches of images
    if scaling_fn:
        depth_pred = scaling_fn(depth_gt, depth_pred)
    depth_pred = torch.clamp(depth_pred, 1e-3, 80)
    mask = depth_gt != 0
    num_valid = mask.to(torch.float).flatten(-2, -1).sum(dim=-1)
    # Threshold accuracies: fraction of pixels with max(gt/pred, pred/gt) < 1.25**k.
    max_ratio = torch.maximum((depth_gt / depth_pred), (depth_pred / depth_gt))
    a_scores = {}
    for name, thresh in {"a1": 1.25, "a2": 1.25**2, "a3": 1.25**3}.items():
        within_thresh = (max_ratio < thresh).to(torch.float)
        within_thresh[~mask] = 0.0
        a_scores[name] = within_thresh.flatten(-2, -1).sum(dim=-1) / num_valid
    square_error = (depth_gt - depth_pred) ** 2
    square_error[~mask] = 0.0
    log_square_error = (torch.log(depth_gt) - torch.log(depth_pred)) ** 2
    log_square_error[~mask] = 0.0
    abs_error = torch.abs(depth_gt - depth_pred)
    abs_error[~mask] = 0.0
    rmse = (square_error.flatten(-2, -1).sum(dim=-1) / num_valid) ** 0.5
    rmse_log = (log_square_error.flatten(-2, -1).sum(dim=-1) / num_valid) ** 0.5
    # Note: abs_rel and sq_rel are plain means of the (squared) relative error;
    # only the RMSE variants take a square root.
    abs_rel = abs_error / depth_gt
    abs_rel[~mask] = 0.0
    abs_rel = abs_rel.flatten(-2, -1).sum(dim=-1) / num_valid
    sq_rel = square_error / depth_gt
    sq_rel[~mask] = 0.0
    sq_rel = sq_rel.flatten(-2, -1).sum(dim=-1) / num_valid
    metrics_dict = {
        "abs_rel": abs_rel,
        "sq_rel": sq_rel,
        "rmse": rmse,
        "rmse_log": rmse_log,
        "a1": a_scores["a1"],
        "a2": a_scores["a2"],
        "a3": a_scores["a3"],
    }
    return metrics_dict


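# A hedged end-to-end call of compute_depth_metrics with median scaling
# (hypothetical shapes and values): a prediction that is correct up to a global
# scale should give abs_rel close to 0 and a1 equal to 1 on the valid pixels.
def _example_depth_metrics():
    torch.manual_seed(0)
    depth_gt = torch.rand(1, 32, 32) * 10 + 1
    depth_gt[:, :4] = 0  # invalid pixels
    depth_pred = depth_gt * 0.7
    metrics = compute_depth_metrics(depth_gt, depth_pred, scaling_fn=median_scaling)
    print(metrics["abs_rel"], metrics["a1"])  # ~0 and 1.0

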
def compute_occ_metrics(
    occupancy_pred: torch.Tensor, occupancy_gt: torch.Tensor, is_visible: torch.Tensor
):
    # Only points that are not visible can be occupied. Use an out-of-place op
    # so the caller's tensor is not mutated.
    occupancy_gt = occupancy_gt & ~is_visible
    is_occupied_acc = (occupancy_pred == occupancy_gt).float().mean().item()
    is_occupied_prec = occupancy_gt[occupancy_pred].float().mean().item()
    is_occupied_rec = occupancy_pred[occupancy_gt].float().mean().item()
    not_occupied_not_visible_ratio = (
        ((~occupancy_gt) & (~is_visible)).float().mean().item()
    )
    total_ie = ((~occupancy_gt) & (~is_visible)).float().sum().item()
    ie_acc = (occupancy_pred == occupancy_gt)[(~is_visible)].float().mean().item()
    ie_prec = (~occupancy_gt)[(~occupancy_pred) & (~is_visible)].float().mean().item()
    ie_rec = (~occupancy_pred)[(~occupancy_gt) & (~is_visible)].float().mean().item()
    total_no_nop_nv = (
        ((~occupancy_gt) & (~occupancy_pred))[(~is_visible) & (~occupancy_gt)]
        .float()
        .sum()
        .item()
    )
    return {
        "o_acc": is_occupied_acc,
        "o_rec": is_occupied_rec,
        "o_prec": is_occupied_prec,
        "ie_acc": ie_acc,
        "ie_rec": ie_rec,
        "ie_prec": ie_prec,
        "ie_r": not_occupied_not_visible_ratio,
        "t_ie": total_ie,
        "t_no_nop_nv": total_no_nop_nv,
    }


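# Hedged toy check for compute_occ_metrics on random boolean grids (shapes and
# thresholds are hypothetical). Note that ground-truth occupancy is first
# restricted to non-visible points inside the function.
def _example_occ_metrics():
    torch.manual_seed(0)
    occupancy_gt = torch.rand(32, 32) > 0.5
    occupancy_pred = torch.rand(32, 32) > 0.5
    is_visible = torch.rand(32, 32) > 0.7
    metrics = compute_occ_metrics(occupancy_pred, occupancy_gt, is_visible)
    print(metrics["o_acc"], metrics["ie_acc"])

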
def compute_nvs_metrics(data, lpips):
    # TODO: This is only correct for batch size 1!
    # Following Tucker et al. and others, we crop 5% on all sides.
    # Index of the stereo frame (the target frame is always the "stereo" frame).
    sf_id = data["rgb_gt"].shape[1] // 2
    imgs_gt = data["rgb_gt"][:1, sf_id : sf_id + 1]
    imgs_pred = data["fine"][0]["rgb"][:1, sf_id : sf_id + 1]
    imgs_gt = imgs_gt.squeeze(0).permute(0, 3, 1, 2)
    imgs_pred = imgs_pred.squeeze(0).squeeze(-2).permute(0, 3, 1, 2)
    n, c, h, w = imgs_gt.shape
    y0 = int(math.ceil(0.05 * h))
    y1 = int(math.floor(0.95 * h))
    x0 = int(math.ceil(0.05 * w))
    x1 = int(math.floor(0.95 * w))
    imgs_gt = imgs_gt[:, :, y0:y1, x0:x1]
    imgs_pred = imgs_pred[:, :, y0:y1, x0:x1]
    imgs_gt_np = imgs_gt.detach().squeeze().permute(1, 2, 0).cpu().numpy()
    imgs_pred_np = imgs_pred.detach().squeeze().permute(1, 2, 0).cpu().numpy()
    # Note: the deprecated `multichannel` flag is superseded by `channel_axis`.
    ssim_score = sk_metrics.structural_similarity(
        imgs_pred_np, imgs_gt_np, data_range=1, channel_axis=-1
    )
    psnr_score = sk_metrics.peak_signal_noise_ratio(
        imgs_pred_np, imgs_gt_np, data_range=1
    )
    lpips_score = lpips(imgs_pred, imgs_gt, normalize=False).mean().item()
    metrics_dict = {
        "ssim": torch.tensor([ssim_score], device=imgs_gt.device),
        "psnr": torch.tensor([psnr_score], device=imgs_gt.device),
        "lpips": torch.tensor([lpips_score], device=imgs_gt.device),
    }
    return metrics_dict


def compute_dino_metrics(data):
    dino_gt = data["dino_gt"]
    if "dino_features_downsampled" in data["coarse"][0]:
        dino_pred = data["coarse"][0]["dino_features_downsampled"].squeeze(-2)
    else:
        dino_pred = data["coarse"][0]["dino_features"].squeeze(-2)
    l1_loss = F.l1_loss(dino_pred, dino_gt, reduction="none").mean(dim=(0, 2, 3, 4))
    l2_loss = F.mse_loss(dino_pred, dino_gt, reduction="none").mean(dim=(0, 2, 3, 4))
    cos_sim = F.cosine_similarity(dino_pred, dino_gt, dim=-1).mean(dim=(0, 2, 3))
    metrics_dict = {
        "l1": torch.tensor([l1_loss.mean()], device=dino_gt.device),
        "l2": torch.tensor([l2_loss.mean()], device=dino_gt.device),
        "cos_sim": torch.tensor([cos_sim.mean()], device=dino_gt.device),
    }
    # Also report per-frame values.
    for i in range(len(l1_loss)):
        metrics_dict[f"l1_{i}"] = torch.tensor([l1_loss[i]], device=dino_gt.device)
        metrics_dict[f"l2_{i}"] = torch.tensor([l2_loss[i]], device=dino_gt.device)
        metrics_dict[f"cos_sim_{i}"] = torch.tensor(
            [cos_sim[i]], device=dino_gt.device
        )
    return metrics_dict


def compute_stego_metrics(data):
    if "stego_corr" not in data["segmentation"]:
        return {}
    metrics_dict = {
        "stego_self_corr": data["segmentation"]["stego_corr"]["stego_self_corr"],
        "stego_nn_corr": data["segmentation"]["stego_corr"]["stego_nn_corr"],
        "stego_random_corr": data["segmentation"]["stego_corr"]["stego_random_corr"],
    }
    return metrics_dict


def compute_seg_metrics(data, n_classes, gt_classes):
    segs_gt = data["segmentation"]["target"].flatten()
    valid_mask = segs_gt >= 0
    segs_gt = segs_gt[valid_mask]
    metrics_dict = {}
    for result_key, result in data["segmentation"]["results"].items():
        if "pseudo_segs_pred" in result:
            segs_pred = result["pseudo_segs_pred"][:, 0].flatten()
        else:
            segs_pred = result["segs_pred"][:, 0].flatten()
        segs_pred = segs_pred[valid_mask]
        # Encode (gt, pred) pairs as a single index and histogram them to build
        # the confusion matrix (rows: ground truth, columns: prediction).
        confusion_matrix = torch.bincount(
            n_classes * segs_gt + segs_pred, minlength=n_classes * gt_classes
        ).reshape(gt_classes, n_classes)
        metrics_dict[result_key] = confusion_matrix
    return metrics_dict


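# Toy illustration of the bincount trick used above (hypothetical labels):
# encoding each (gt, pred) pair as gt * n_classes + pred and reshaping the
# histogram yields the confusion matrix.
def _example_confusion_matrix():
    n_classes = 3
    segs_gt = torch.tensor([0, 0, 1, 2, 2])
    segs_pred = torch.tensor([0, 1, 1, 2, 0])
    cm = torch.bincount(
        n_classes * segs_gt + segs_pred, minlength=n_classes * n_classes
    ).reshape(n_classes, n_classes)
    print(cm)  # rows: ground truth, columns: prediction

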
class MeanMetric(Metric):
    def __init__(self, output_transform=lambda x: x["output"], device="cpu"):
        super(MeanMetric, self).__init__(
            output_transform=output_transform, device=device
        )
        self._sum = torch.tensor(0, device=self._device, dtype=torch.float32)
        self._num_examples = 0
        self.required_output_keys = ()

    def reset(self):
        self._sum = torch.tensor(0, device=self._device, dtype=torch.float32)
        self._num_examples = 0
        super(MeanMetric, self).reset()

    def update(self, value):
        if torch.any(torch.isnan(torch.as_tensor(value))):
            raise ValueError("NaN values present in metric!")
        self._sum += value
        self._num_examples += 1

    def compute(self):
        if self._num_examples == 0:
            raise NotComputableError(
                "MeanMetric must have at least one example before it can be computed."
            )
        return self._sum.item() / self._num_examples

    def iteration_completed(self, engine: Engine) -> None:
        # engine.state.output has keys: "output", "loss_dict", "timings_dict", "metrics_dict"
        output = self._output_transform(engine.state.output)
        self.update(output)


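# Minimal sketch: MeanMetric used as a standalone running mean, outside an
# ignite Engine (values are hypothetical).
def _example_mean_metric():
    m = MeanMetric()
    for v in (1.0, 2.0, 3.0):
        m.update(v)
    print(m.compute())  # 2.0

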
class DictMeanMetric(Metric):
    def __init__(
        self, name: str, output_transform=lambda x: x["output"], device="cpu"
    ):
        self._name = name
        self._sums: dict[str, torch.Tensor] = {}
        self._num_examples = 0
        self.required_output_keys = ()
        super(DictMeanMetric, self).__init__(
            output_transform=output_transform, device=device
        )

    def reset(self):
        self._sums = {}
        self._num_examples = 0
        super(DictMeanMetric, self).reset()

    def update(self, value):
        num_examples = None
        for key, metric in value.items():
            if key not in self._sums:
                self._sums[key] = torch.tensor(
                    0, device=self._device, dtype=torch.float32
                )
            if torch.any(torch.isnan(metric)):
                # TODO: integrate into logging
                print(f"Warning: Metric {self._name}/{key} has a nan value")
                continue
            self._sums[key] += metric.sum().to(self._device)
            # TODO: check if this works with batches
            if num_examples is None:
                num_examples = metric.shape[0]
        self._num_examples += 1

    def compute(self):
        if self._num_examples == 0:
            raise NotComputableError(
                "DictMeanMetric must have at least one example before it can be computed."
            )
        return {
            f"{self._name}_{key}": metric.item() / self._num_examples
            for key, metric in self._sums.items()
        }

    def iteration_completed(self, engine: Engine) -> None:
        output = self._output_transform(engine.state.output["output"])
        self.update(output)

    def completed(self, engine: Engine, name: str) -> None:
        """Helper method to compute the metric's value and put it into the engine state.
        It is automatically attached to the `engine` with
        :meth:`~ignite.metrics.metric.Metric.attach`. If the metric's value is a torch
        tensor, it is explicitly sent to the CPU device.

        Args:
            engine: the engine to which the metric must be attached
            name: the name of the metric used as key in dict `engine.state.metrics`

        .. note::
            Changed from the default implementation: don't add the whole result dict
            to the engine state, but only the individual values.
        """
        result = self.compute()
        if isinstance(result, Mapping):
            if name in result.keys():
                raise ValueError(
                    f"Argument name '{name}' is conflicting with mapping keys: {list(result.keys())}"
                )
            for key, value in result.items():
                engine.state.metrics[key] = value
        else:
            if isinstance(result, torch.Tensor):
                if len(result.size()) == 0:
                    result = result.item()
                elif "cpu" not in result.device.type:
                    result = result.cpu()
            engine.state.metrics[name] = result


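# Hedged usage sketch: attaching a DictMeanMetric to an ignite Engine. The
# process function and the "rmse"/"depth" names are hypothetical; the engine's
# output is expected to carry a dict of per-example metric tensors under "output".
def _example_attach_metric():
    def _process(engine, batch):
        return {"output": {"rmse": torch.rand(1)}}

    engine = Engine(_process)
    metric = DictMeanMetric("depth", output_transform=lambda x: x)
    metric.attach(engine, "depth")
    engine.run([None] * 4)
    print(engine.state.metrics)  # e.g. {"depth_rmse": ...}

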
class SegmentationMetric(DictMeanMetric):
    def __init__(
        self,
        name: str,
        output_transform=lambda x: x["output"],
        device="cpu",
        assign_pseudo=True,
    ):
        super(SegmentationMetric, self).__init__(name, output_transform, device)
        self.assign_pseudo = assign_pseudo
        # Cityscapes classes: [road, sidewalk, building, wall, fence, pole,
        # traffic light, traffic sign, vegetation, terrain, sky, person, rider,
        # car, truck, bus, train, motorcycle, bicycle]
        self.weights = torch.Tensor(
            [4, 2, 2, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 2, 1, 1, 1, 1, 1]
        )
        # Normalize to mean 1 so the weighted mean stays comparable to the plain mean.
        self.weights = self.weights / self.weights.mean()

    def update(self, value):
        for key, metric in value.items():
            if key not in self._sums:
                self._sums[key] = torch.zeros(
                    metric.shape, device=self._device, dtype=torch.int32
                )
            if torch.any(torch.isnan(metric)):
                print(f"Warning: Metric {self._name}/{key} has a nan value")
                continue
            self._sums[key] += metric.to(self._device)
        self._num_examples += 1

    def compute(self):
        if self._num_examples == 0:
            raise NotComputableError(
                "SegmentationMetric must have at least one example before it can be computed."
            )
        result = {}
        for key, _sum in self._sums.items():
            if self.assign_pseudo:
                assignment = self._calculate_pseudo_label_assignment(_sum)
                gt_classes = _sum.size(0)
                confusion_matrix = torch.zeros(
                    (gt_classes, gt_classes), dtype=_sum.dtype
                )
                # Merge the column of each predicted class into its assigned
                # ground-truth class.
                confusion_matrix.scatter_add_(
                    1, assignment.unsqueeze(0).expand(gt_classes, -1), _sum
                )
                result[key + "_assignment"] = assignment
            else:
                confusion_matrix = _sum
            # confusion_matrix axes: (actual, prediction)
            true_positives = confusion_matrix.diag()
            false_negatives = torch.sum(confusion_matrix, dim=1) - true_positives
            false_positives = torch.sum(confusion_matrix, dim=0) - true_positives
            denominator = true_positives + false_positives + false_negatives
            per_class_iou = torch.where(
                denominator > 0,
                true_positives / denominator,
                torch.zeros_like(denominator),
            )
            result[key + "_per_class_iou"] = per_class_iou
            result[key + "_miou"] = per_class_iou.mean().item()
            result[key + "_weighted_miou"] = (per_class_iou * self.weights).mean().item()
            result[key + "_acc"] = (
                confusion_matrix.diag().sum().item() / confusion_matrix.sum().item()
            )
            result[key + "_confusion_matrix"] = confusion_matrix
        return result

    def _calculate_pseudo_label_assignment(self, metric_matrix):
        """Solve a capacitated assignment ILP; implemented this way to generalize to over-segmentation."""
        gt_classes, n_classes = metric_matrix.size()
        costs = metric_matrix.cpu().numpy()
        problem = pulp.LpProblem("CapacitatedAssignment", pulp.LpMaximize)
        x = [
            [pulp.LpVariable(f"x_{i}_{j}", cat="Binary") for j in range(n_classes)]
            for i in range(gt_classes)
        ]
        # Maximize the overlap counts captured by the assignment.
        problem += pulp.lpSum(
            costs[i][j] * x[i][j] for i in range(gt_classes) for j in range(n_classes)
        )
        # Every predicted (pseudo) class is assigned to exactly one ground-truth class.
        for j in range(n_classes):
            problem += pulp.lpSum(x[i][j] for i in range(gt_classes)) == 1, f"AssignPseudoLabel_{j}"
        # Every ground-truth class receives at least one predicted class.
        for i in range(gt_classes):
            problem += pulp.lpSum(x[i][j] for j in range(n_classes)) >= 1, f"MinAssignActualLabel_{i}"
        problem.solve()
        print("Status:", pulp.LpStatus[problem.status])
        print("Objective:", pulp.value(problem.objective))
        assignment = torch.zeros(n_classes, dtype=torch.int64)
        for j in range(n_classes):
            assignment[j] = next(i for i in range(gt_classes) if pulp.value(x[i][j]) == 1)
        return assignment


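# Hedged toy run of the assignment ILP above (counts are hypothetical): with
# 2 ground-truth classes and 3 predicted classes, every predicted class gets
# exactly one ground-truth label and the total captured count is maximized.
def _example_pseudo_assignment():
    metric = SegmentationMetric("seg")
    counts = torch.tensor([[8, 1, 0], [0, 9, 2]])  # rows: gt classes, cols: pred
    assignment = metric._calculate_pseudo_label_assignment(counts)
    print(assignment)  # expected: tensor([0, 1, 1])

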
class ConcatenateMetric(DictMeanMetric):
    def update(self, value, every_nth=100):
        # Collect a subsampled concatenation of all values instead of a mean.
        for key, metric in value.items():
            if torch.any(torch.isnan(metric)):
                print(f"Warning: Metric {self._name}/{key} has a nan value")
                continue
            metric_flat = metric.flatten().to(self._device)[::every_nth]
            if key in self._sums:
                self._sums[key] = torch.cat([self._sums[key], metric_flat])
            else:
                self._sums[key] = metric_flat
        self._num_examples += 1

    def compute(self):
        return self._sums


class FG_ARI(Metric):
    def __init__(self, output_transform=lambda x: x["output"], device="cpu"):
        # super().__init__ must run first: it sets self._device.
        super(FG_ARI, self).__init__(output_transform=output_transform, device=device)
        self._sum_fg_aris = torch.tensor(0, device=self._device, dtype=torch.float32)
        self._num_examples = 0
        self.required_output_keys = ()

    def reset(self):
        self._sum_fg_aris = torch.tensor(0, device=self._device, dtype=torch.float32)
        self._num_examples = 0
        super(FG_ARI, self).reset()

    def update(self, data):
        true_masks = data["segs"]  # list of fc tensors [n, h, w]
        pred_masks = data["slot_masks"]  # [n, fc, sc, h, w]
        n, fc, sc, h, w = pred_masks.shape
        true_masks = [
            F.interpolate(tm.to(float).unsqueeze(1), (h, w), mode="nearest")
            .squeeze(1)
            .to(int)
            for tm in true_masks
        ]
        for i in range(n):
            for f in range(fc):
                true_mask = true_masks[f][i]
                pred_mask = pred_masks[i, f]
                true_mask = true_mask.view(-1)
                pred_mask = pred_mask.view(sc, -1)
                # Skip frames without any foreground object.
                if torch.max(true_mask) == 0:
                    continue
                foreground = true_mask > 0
                true_mask = true_mask[foreground]
                pred_mask = pred_mask[:, foreground].permute(1, 0)
                true_mask = F.one_hot(true_mask)
                # Filter out empty true groups
                not_empty = torch.any(true_mask, dim=0)
                true_mask = true_mask[:, not_empty]
                # Filter out empty predicted groups
                not_empty = torch.any(pred_mask, dim=0)
                pred_mask = pred_mask[:, not_empty]
                true_mask.unsqueeze_(0)
                pred_mask.unsqueeze_(0)
                _, n_points, n_true_groups = true_mask.shape
                n_pred_groups = pred_mask.shape[-1]
                if n_points <= n_true_groups and n_points <= n_pred_groups:
                    print(
                        "adjusted_rand_index requires n_groups < n_points.",
                        file=sys.stderr,
                    )
                    continue
                true_group_ids = torch.argmax(true_mask, -1)
                pred_group_ids = torch.argmax(pred_mask, -1)
                true_mask_oh = true_mask.to(torch.float32)
                pred_mask_oh = F.one_hot(pred_group_ids, n_pred_groups).to(
                    torch.float32
                )
                n_points = torch.sum(true_mask_oh, dim=[1, 2]).to(torch.float32)
                # Contingency table between predicted and true groups.
                nij = torch.einsum("bji,bjk->bki", pred_mask_oh, true_mask_oh)
                a = torch.sum(nij, dim=1)
                b = torch.sum(nij, dim=2)
                # Adjusted Rand index: (RI - E[RI]) / (max RI - E[RI]).
                rindex = torch.sum(nij * (nij - 1), dim=[1, 2])
                aindex = torch.sum(a * (a - 1), dim=1)
                bindex = torch.sum(b * (b - 1), dim=1)
                expected_rindex = aindex * bindex / (n_points * (n_points - 1))
                max_rindex = (aindex + bindex) / 2
                ari = (rindex - expected_rindex) / (
                    max_rindex - expected_rindex + 1e-12
                )
                _all_equal = lambda values: torch.all(
                    torch.eq(values, values[..., :1]), dim=-1
                )
                # If both masks assign everything to a single cluster, the ARI is
                # undefined; count it as a perfect match.
                both_single_cluster = torch.logical_and(
                    _all_equal(true_group_ids), _all_equal(pred_group_ids)
                )
                self._sum_fg_aris += torch.where(
                    both_single_cluster, torch.ones_like(ari), ari
                ).squeeze()
                self._num_examples += 1

    def compute(self):
        if self._num_examples == 0:
            raise NotComputableError(
                "FG_ARI must have at least one example before it can be computed."
            )
        return self._sum_fg_aris.item() / self._num_examples

    def iteration_completed(self, engine: Engine) -> None:
        output = self._output_transform(engine.state.output)
        self.update(output)


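# Hedged sanity check for FG_ARI: a prediction whose slots exactly reproduce
# the foreground segments should score an ARI of ~1. Shapes are hypothetical
# (n=1 image, fc=1 frame, sc=3 slots, 8x8 masks; label 0 is background).
def _example_fg_ari():
    seg = torch.zeros(1, 8, 8, dtype=torch.long)
    seg[:, :, 4:] = 1  # one foreground object
    seg[:, 4:, :4] = 2  # a second foreground object
    slot_masks = torch.stack(
        [(seg == k).float() for k in range(3)], dim=1
    ).unsqueeze(1)  # [n=1, fc=1, sc=3, h, w]
    ari = FG_ARI(output_transform=lambda x: x)
    ari.update({"segs": [seg], "slot_masks": slot_masks})
    print(ari.compute())  # ~1.0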