import copy
from collections import defaultdict
from pathlib import Path
from typing import Dict, List

import matplotlib.pyplot as plt
import numpy as np
import torch
from loguru import logger
from torchvision.ops import box_iou

class Validator:
    def __init__(
        self,
        gt: List[Dict[str, torch.Tensor]],
        preds: List[Dict[str, torch.Tensor]],
        conf_thresh: float = 0.5,
        iou_thresh: float = 0.5,
    ) -> None:
        """
        Expected input format:
            gt = [{'labels': tensor([0]), 'boxes': tensor([[561.0, 297.0, 661.0, 359.0]])}, ...]
        len(gt) is the number of images.
        Boxes are in [x1, y1, x2, y2] format, absolute pixel coordinates.
        """
        self.gt = gt
        self.preds = preds
        self.conf_thresh = conf_thresh
        self.iou_thresh = iou_thresh
        self.thresholds = np.arange(0.2, 1.0, 0.05)
        self.conf_matrix = None

    def compute_metrics(self, extended=False) -> Dict[str, float]:
        filtered_preds = filter_preds(copy.deepcopy(self.preds), self.conf_thresh)
        metrics = self._compute_main_metrics(filtered_preds)
        if not extended:
            metrics.pop("extended_metrics", None)
        return metrics

    def _compute_main_metrics(self, preds):
        (
            self.metrics_per_class,
            self.conf_matrix,
            self.class_to_idx,
        ) = self._compute_metrics_and_confusion_matrix(preds)
        tps, fps, fns = 0, 0, 0
        ious = []
        extended_metrics = {}
        for key, value in self.metrics_per_class.items():
            tps += value["TPs"]
            fps += value["FPs"]
            fns += value["FNs"]
            ious.extend(value["IoUs"])

            extended_metrics[f"precision_{key}"] = (
                value["TPs"] / (value["TPs"] + value["FPs"])
                if value["TPs"] + value["FPs"] > 0
                else 0
            )
            extended_metrics[f"recall_{key}"] = (
                value["TPs"] / (value["TPs"] + value["FNs"])
                if value["TPs"] + value["FNs"] > 0
                else 0
            )
            extended_metrics[f"iou_{key}"] = (
                np.mean(value["IoUs"]).item() if value["IoUs"] else 0
            )

        precision = tps / (tps + fps) if (tps + fps) > 0 else 0
        recall = tps / (tps + fns) if (tps + fns) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        iou = np.mean(ious).item() if ious else 0
        return {
            "f1": f1,
            "precision": precision,
            "recall": recall,
            "iou": iou,
            "TPs": tps,
            "FPs": fps,
            "FNs": fns,
            "extended_metrics": extended_metrics,
        }

    def _compute_matrix_multi_class(self, preds):
        metrics_per_class = defaultdict(lambda: {"TPs": 0, "FPs": 0, "FNs": 0, "IoUs": []})
        for pred, gt in zip(preds, self.gt):
            pred_boxes = pred["boxes"]
            pred_labels = pred["labels"]
            gt_boxes = gt["boxes"]
            gt_labels = gt["labels"]

            labels = torch.unique(torch.cat([pred_labels, gt_labels]))
            for label in labels:
                pred_cl_boxes = pred_boxes[pred_labels == label]
                gt_cl_boxes = gt_boxes[gt_labels == label]

                n_preds = len(pred_cl_boxes)
                n_gts = len(gt_cl_boxes)
                if not (n_preds or n_gts):
                    continue
                if not n_preds:
                    metrics_per_class[label.item()]["FNs"] += n_gts
                    metrics_per_class[label.item()]["IoUs"].extend([0] * n_gts)
                    continue
                if not n_gts:
                    metrics_per_class[label.item()]["FPs"] += n_preds
                    metrics_per_class[label.item()]["IoUs"].extend([0] * n_preds)
                    continue

                ious = box_iou(pred_cl_boxes, gt_cl_boxes)
                ious_mask = ious >= self.iou_thresh

                pred_indices, gt_indices = torch.nonzero(ious_mask, as_tuple=True)

                if not pred_indices.numel():
                    # No pair clears the IoU threshold: every GT is a miss, every prediction a false alarm.
                    metrics_per_class[label.item()]["FNs"] += n_gts
                    metrics_per_class[label.item()]["IoUs"].extend([0] * n_gts)
                    metrics_per_class[label.item()]["FPs"] += n_preds
                    metrics_per_class[label.item()]["IoUs"].extend([0] * n_preds)
                    continue

                iou_values = ious[pred_indices, gt_indices]

                # Greedy one-to-one matching: walk candidate pairs in order of decreasing IoU.
                sorted_indices = torch.argsort(-iou_values)
                pred_indices = pred_indices[sorted_indices]
                gt_indices = gt_indices[sorted_indices]
                iou_values = iou_values[sorted_indices]

                matched_preds = set()
                matched_gts = set()
                for pred_idx, gt_idx, iou in zip(pred_indices, gt_indices, iou_values):
                    if gt_idx.item() not in matched_gts and pred_idx.item() not in matched_preds:
                        matched_preds.add(pred_idx.item())
                        matched_gts.add(gt_idx.item())
                        metrics_per_class[label.item()]["TPs"] += 1
                        metrics_per_class[label.item()]["IoUs"].append(iou.item())

                # Leftover predictions are false positives; leftover GTs are false negatives.
                unmatched_preds = set(range(n_preds)) - matched_preds
                unmatched_gts = set(range(n_gts)) - matched_gts
                metrics_per_class[label.item()]["FPs"] += len(unmatched_preds)
                metrics_per_class[label.item()]["IoUs"].extend([0] * len(unmatched_preds))
                metrics_per_class[label.item()]["FNs"] += len(unmatched_gts)
                metrics_per_class[label.item()]["IoUs"].extend([0] * len(unmatched_gts))
        return metrics_per_class

    def _compute_metrics_and_confusion_matrix(self, preds):
        metrics_per_class = defaultdict(lambda: {"TPs": 0, "FPs": 0, "FNs": 0, "IoUs": []})

        # Collect every class id seen in predictions or ground truth; the extra
        # row/column of the confusion matrix represents the background (no object).
        all_classes = set()
        for pred in preds:
            all_classes.update(pred["labels"].tolist())
        for gt in self.gt:
            all_classes.update(gt["labels"].tolist())
        all_classes = sorted(all_classes)
        class_to_idx = {cls_id: idx for idx, cls_id in enumerate(all_classes)}
        n_classes = len(all_classes)
        conf_matrix = np.zeros((n_classes + 1, n_classes + 1), dtype=int)

        for pred, gt in zip(preds, self.gt):
            pred_boxes = pred["boxes"]
            pred_labels = pred["labels"]
            gt_boxes = gt["boxes"]
            gt_labels = gt["labels"]

            n_preds = len(pred_boxes)
            n_gts = len(gt_boxes)

            if n_preds == 0 and n_gts == 0:
                continue

            ious = box_iou(pred_boxes, gt_boxes) if n_preds > 0 and n_gts > 0 else torch.tensor([])

            matched_pred_indices = set()
            matched_gt_indices = set()

            if ious.numel() > 0:
                ious_mask = ious >= self.iou_thresh
                pred_indices, gt_indices = torch.nonzero(ious_mask, as_tuple=True)
                iou_values = ious[pred_indices, gt_indices]

                # Greedy one-to-one matching across classes, highest IoU first.
                sorted_indices = torch.argsort(-iou_values)
                pred_indices = pred_indices[sorted_indices]
                gt_indices = gt_indices[sorted_indices]
                iou_values = iou_values[sorted_indices]

                for pred_idx, gt_idx, iou in zip(pred_indices, gt_indices, iou_values):
                    if (
                        pred_idx.item() in matched_pred_indices
                        or gt_idx.item() in matched_gt_indices
                    ):
                        continue
                    matched_pred_indices.add(pred_idx.item())
                    matched_gt_indices.add(gt_idx.item())

                    pred_label = pred_labels[pred_idx].item()
                    gt_label = gt_labels[gt_idx].item()

                    pred_cls_idx = class_to_idx[pred_label]
                    gt_cls_idx = class_to_idx[gt_label]

                    conf_matrix[gt_cls_idx, pred_cls_idx] += 1

                    if pred_label == gt_label:
                        metrics_per_class[gt_label]["TPs"] += 1
                        metrics_per_class[gt_label]["IoUs"].append(iou.item())
                    else:
                        # Localized correctly but classified wrongly: a miss for the GT class
                        # and a false alarm for the predicted class.
                        metrics_per_class[gt_label]["FNs"] += 1
                        metrics_per_class[pred_label]["FPs"] += 1
                        metrics_per_class[gt_label]["IoUs"].append(0)
                        metrics_per_class[pred_label]["IoUs"].append(0)

            # Unmatched predictions fall into the background row (true label = background).
            unmatched_pred_indices = set(range(n_preds)) - matched_pred_indices
            for pred_idx in unmatched_pred_indices:
                pred_label = pred_labels[pred_idx].item()
                pred_cls_idx = class_to_idx[pred_label]
                conf_matrix[n_classes, pred_cls_idx] += 1
                metrics_per_class[pred_label]["FPs"] += 1
                metrics_per_class[pred_label]["IoUs"].append(0)

            # Unmatched ground truths fall into the background column (predicted = background).
            unmatched_gt_indices = set(range(n_gts)) - matched_gt_indices
            for gt_idx in unmatched_gt_indices:
                gt_label = gt_labels[gt_idx].item()
                gt_cls_idx = class_to_idx[gt_label]
                conf_matrix[gt_cls_idx, n_classes] += 1
                metrics_per_class[gt_label]["FNs"] += 1
                metrics_per_class[gt_label]["IoUs"].append(0)

        return metrics_per_class, conf_matrix, class_to_idx

    def save_plots(self, path_to_save) -> None:
        path_to_save = Path(path_to_save)
        path_to_save.mkdir(parents=True, exist_ok=True)

        if self.conf_matrix is not None:
            class_labels = [str(cls_id) for cls_id in self.class_to_idx.keys()] + ["background"]

            plt.figure(figsize=(10, 8))
            plt.imshow(self.conf_matrix, interpolation="nearest", cmap=plt.cm.Blues)
            plt.title("Confusion Matrix")
            plt.colorbar()
            tick_marks = np.arange(len(class_labels))
            plt.xticks(tick_marks, class_labels, rotation=45)
            plt.yticks(tick_marks, class_labels)

            # Annotate each cell with its count, using white text on dark cells.
            thresh = self.conf_matrix.max() / 2.0
            for i in range(self.conf_matrix.shape[0]):
                for j in range(self.conf_matrix.shape[1]):
                    plt.text(
                        j,
                        i,
                        format(self.conf_matrix[i, j], "d"),
                        horizontalalignment="center",
                        color="white" if self.conf_matrix[i, j] > thresh else "black",
                    )

            plt.ylabel("True label")
            plt.xlabel("Predicted label")
            plt.tight_layout()
            plt.savefig(path_to_save / "confusion_matrix.png")
            plt.close()

        thresholds = self.thresholds
        precisions, recalls, f1_scores = [], [], []

        # Re-filter the raw predictions at each confidence threshold and recompute the metrics.
        original_preds = copy.deepcopy(self.preds)

        for threshold in thresholds:
            filtered_preds = filter_preds(copy.deepcopy(original_preds), threshold)
            metrics = self._compute_main_metrics(filtered_preds)
            precisions.append(metrics["precision"])
            recalls.append(metrics["recall"])
            f1_scores.append(metrics["f1"])

        plt.figure()
        plt.plot(thresholds, precisions, label="Precision", marker="o")
        plt.plot(thresholds, recalls, label="Recall", marker="o")
        plt.xlabel("Threshold")
        plt.ylabel("Value")
        plt.title("Precision and Recall vs Threshold")
        plt.legend()
        plt.grid(True)
        plt.savefig(path_to_save / "precision_recall_vs_threshold.png")
        plt.close()

        plt.figure()
        plt.plot(thresholds, f1_scores, label="F1 Score", marker="o")
        plt.xlabel("Threshold")
        plt.ylabel("F1 Score")
        plt.title("F1 Score vs Threshold")
        plt.grid(True)
        plt.savefig(path_to_save / "f1_score_vs_threshold.png")
        plt.close()

        # Report the highest F1; ties are broken in favor of the largest threshold.
        best_idx = len(f1_scores) - np.argmax(f1_scores[::-1]) - 1
        best_threshold = thresholds[best_idx]
        best_f1 = f1_scores[best_idx]

        logger.info(
            f"Best Threshold: {round(best_threshold, 2)} with F1 Score: {round(best_f1, 3)}"
        )

def filter_preds(preds, conf_thresh):
    """Keep only detections whose confidence score is at least conf_thresh (modifies preds in place)."""
    for pred in preds:
        keep_idxs = pred["scores"] >= conf_thresh
        pred["scores"] = pred["scores"][keep_idxs]
        pred["boxes"] = pred["boxes"][keep_idxs]
        pred["labels"] = pred["labels"][keep_idxs]
    return preds
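
# Illustrative example (toy values, not real model output): with conf_thresh=0.5, a prediction
#   {"scores": tensor([0.90, 0.30]), "boxes": tensor([[0., 0., 10., 10.], [5., 5., 15., 15.]]), "labels": tensor([1, 2])}
# is reduced by filter_preds to just its first detection, because 0.30 < 0.50.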

def scale_boxes(boxes, orig_shape, resized_shape):
    """
    Rescale boxes (in place) from the resized image back to the original image.

    boxes are in format [x1, y1, x2, y2], absolute values.
    orig_shape and resized_shape are [height, width].
    """
    scale_x = orig_shape[1] / resized_shape[1]
    scale_y = orig_shape[0] / resized_shape[0]
    boxes[:, 0] *= scale_x
    boxes[:, 2] *= scale_x
    boxes[:, 1] *= scale_y
    boxes[:, 3] *= scale_y
    return boxes
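

if __name__ == "__main__":
    # Minimal illustrative sketch (toy tensors, not real model output): rescale one predicted
    # box from a hypothetical 640x640 network input back to a hypothetical original image of
    # height 960 and width 1280, then score it against a single ground-truth box.
    gt = [{"labels": torch.tensor([0]), "boxes": torch.tensor([[561.0, 297.0, 661.0, 359.0]])}]
    preds = [
        {
            "labels": torch.tensor([0]),
            "scores": torch.tensor([0.9]),
            # Box coordinates in the 640x640 resized space.
            "boxes": torch.tensor([[280.0, 198.0, 330.0, 240.0]]),
        }
    ]
    # Map the predicted box back to the original image ([height, width] shapes).
    preds[0]["boxes"] = scale_boxes(preds[0]["boxes"], orig_shape=(960, 1280), resized_shape=(640, 640))

    validator = Validator(gt, preds, conf_thresh=0.5, iou_thresh=0.5)
    # The rescaled box overlaps the ground truth well above the 0.5 IoU threshold,
    # so the report shows one TP with precision, recall and F1 equal to 1.0.
    logger.info(f"Metrics: {validator.compute_metrics()}")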