# SemanticSegmentationModel/semantic-segmentation/SemanticModel/.ipynb_checkpoints/metrics-checkpoint.py
from typing import Dict, Optional

import numpy as np


def compute_intersection_union(prediction, ground_truth, num_classes: int, ignore_index: int,
                               label_mapping: Optional[Dict[int, int]] = None,
                               reduce_labels: bool = False):
"""Computes intersection and union for IoU calculation.""" | |
if label_mapping: | |
for old_id, new_id in label_mapping.items(): | |
ground_truth[ground_truth == old_id] = new_id | |
prediction = np.array(prediction) | |
ground_truth = np.array(ground_truth) | |
if reduce_labels: | |
ground_truth[ground_truth == 0] = 255 | |
ground_truth = ground_truth - 1 | |
ground_truth[ground_truth == 254] = 255 | |
valid_mask = np.not_equal(ground_truth, ignore_index) | |
prediction = prediction[valid_mask] | |
ground_truth = ground_truth[valid_mask] | |
    # Pixels where the prediction matches the ground truth form the intersection.
    intersection_mask = prediction == ground_truth
    intersection = prediction[intersection_mask]
    # Bin class ids into num_classes bins; with range=(0, num_classes - 1),
    # each integer id falls into its own bin.
    area_intersection = np.histogram(intersection, bins=num_classes,
                                     range=(0, num_classes - 1))[0]
    area_prediction = np.histogram(prediction, bins=num_classes,
                                   range=(0, num_classes - 1))[0]
    area_ground_truth = np.histogram(ground_truth, bins=num_classes,
                                     range=(0, num_classes - 1))[0]
    area_union = area_prediction + area_ground_truth - area_intersection
    return area_intersection, area_union, area_prediction, area_ground_truth
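
# A minimal worked example (hypothetical values, not part of the original
# module): with 3 classes and 255 as the ignore index, the bottom-left
# ground-truth pixel is dropped and only matching pixels count as intersection.
#
#     pred = np.array([[0, 1], [2, 1]])
#     gt = np.array([[0, 1], [255, 2]])
#     inter, union, _, _ = compute_intersection_union(pred, gt, 3, 255)
#     # inter -> [1, 1, 0], union -> [1, 2, 1], so per-class IoU is [1.0, 0.5, 0.0]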

def compute_total_intersection_union(predictions, ground_truths, num_classes: int, ignore_index: int,
                                     label_mapping: Optional[Dict[int, int]] = None,
                                     reduce_labels: bool = False):
    """Accumulates per-class intersection and union areas across all samples."""
    totals = {
        'intersection': np.zeros((num_classes,), dtype=np.float64),
        'union': np.zeros((num_classes,), dtype=np.float64),
        'prediction': np.zeros((num_classes,), dtype=np.float64),
        'ground_truth': np.zeros((num_classes,), dtype=np.float64)
    }
    for pred, gt in zip(predictions, ground_truths):
        intersection, union, pred_area, gt_area = compute_intersection_union(
            pred, gt, num_classes, ignore_index, label_mapping, reduce_labels
        )
        totals['intersection'] += intersection
        totals['union'] += union
        totals['prediction'] += pred_area
        totals['ground_truth'] += gt_area
    # Dicts preserve insertion order (Python 3.7+), so this matches the
    # per-sample order: intersection, union, prediction, ground_truth.
    return tuple(totals.values())

def compute_mean_iou(predictions, ground_truths, num_classes: int, ignore_index: int,
                     nan_to_num: Optional[int] = None,
                     label_mapping: Optional[Dict[int, int]] = None,
                     reduce_labels: bool = False):
    """Computes mean IoU, mean pixel accuracy, overall accuracy, and per-class metrics."""
    intersection, union, prediction_area, ground_truth_area = compute_total_intersection_union(
        predictions, ground_truths, num_classes, ignore_index, label_mapping, reduce_labels
    )
    metrics = {}
    # Overall accuracy: correctly classified pixels over all labeled pixels.
    total_accuracy = intersection.sum() / ground_truth_area.sum()
    # Per-class IoU and accuracy. Classes absent from both prediction and
    # ground truth give 0 / 0 = NaN, which np.nanmean skips.
    iou_per_class = intersection / union
    accuracy_per_class = intersection / ground_truth_area
    metrics.update({
        "mean_iou": np.nanmean(iou_per_class),
        "mean_accuracy": np.nanmean(accuracy_per_class),
        "overall_accuracy": total_accuracy,
        "per_category_iou": iou_per_class,
        "per_category_accuracy": accuracy_per_class
    })
    if nan_to_num is not None:
        # Optionally replace NaNs (classes never seen) with a fixed fill value.
        metrics = {
            metric: np.nan_to_num(value, nan=nan_to_num)
            for metric, value in metrics.items()
        }
    return metrics
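

if __name__ == "__main__":
    # Minimal smoke test with hypothetical data (two 2x2 masks, 3 classes,
    # ignore index 255); a sketch for sanity-checking the pipeline, not part
    # of the original module.
    preds = [np.array([[0, 1], [2, 1]]), np.array([[2, 2], [1, 0]])]
    gts = [np.array([[0, 1], [255, 2]]), np.array([[2, 1], [1, 0]])]
    results = compute_mean_iou(preds, gts, num_classes=3, ignore_index=255, nan_to_num=0)
    for name, value in results.items():
        print(f"{name}: {value}")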