English
Antoine1091's picture
Upload folder using huggingface_hub
ede298f verified
import numpy as np
from tqdm import tqdm
from sklearn.metrics import confusion_matrix
import torch
from utils.metrics import compute_metrics
from inference.sliding_window import batch_sliding_window_inference
def _batch_confusion_matrix(true, pred, num_classes, ignore_index=None):
    """Return the ``(num_classes, num_classes)`` confusion matrix of one batch.

    Rows index the ground-truth class and columns the predicted class, the
    same layout as ``sklearn.metrics.confusion_matrix``. Pixels whose target
    equals ``ignore_index`` are skipped. Pixels whose target or prediction
    falls outside ``[0, num_classes)`` are also skipped; sklearn likewise
    drops values that are not in ``labels``.

    Args:
        true: Integer array of ground-truth labels (any shape).
        pred: Integer array of predicted labels, same shape as ``true``.
        num_classes: Number of classes ``C``.
        ignore_index: Target value to exclude, or ``None`` to exclude nothing.

    Returns:
        Integer array of shape ``(C, C)`` with per-(true, pred) pixel counts.
    """
    true = np.asarray(true).astype(np.int64, copy=False).ravel()
    pred = np.asarray(pred).astype(np.int64, copy=False).ravel()
    valid = (true >= 0) & (true < num_classes) & (pred >= 0) & (pred < num_classes)
    if ignore_index is not None:
        valid &= true != ignore_index
    # Encode each (true, pred) pair as one flat index so a single bincount
    # tallies the whole batch in native code instead of a per-image loop.
    flat = num_classes * true[valid] + pred[valid]
    counts = np.bincount(flat, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)


def eval_epoch(model, loader, criterion, device, crop_size, stride, num_classes):
    """Evaluate ``model`` on ``loader`` using sliding-window inference.

    Accumulates one confusion matrix over every batch, then reduces it with
    ``compute_metrics``. ``model.eval()`` is called and not undone, so the
    caller must switch back to training mode itself.

    Args:
        model: Network to evaluate; forwarded to
            ``batch_sliding_window_inference``.
        loader: Iterable yielding ``(images, masks)`` batches, where ``masks``
            is a torch tensor of integer labels.
        criterion: Loss object. It is only read for an ``ignore_index``
            attribute; if that attribute is absent, no target value is ignored.
        device: Forwarded to ``batch_sliding_window_inference``.
        crop_size: Sliding-window crop size, forwarded unchanged.
        stride: Sliding-window stride, forwarded unchanged.
        num_classes: Number of classes; sets the confusion-matrix size.

    Returns:
        ``(ious, miou, mf1)``: the 1st, 2nd and 4th values returned by
        ``compute_metrics``. These are presumably per-class IoU, mean IoU and
        mean F1; confirm against ``utils.metrics``.
    """
    model.eval()
    # Use an explicit int64: ``dtype=int`` is 32-bit on Windows with NumPy < 2,
    # and per-pixel counts over a full validation set can exceed 2**31.
    conf_mat = np.zeros((num_classes, num_classes), dtype=np.int64)
    ignore_index = getattr(criterion, "ignore_index", None)
    with torch.no_grad():
        for images, masks in tqdm(loader, desc="Validation", leave=False):
            # Assumes the helper returns per-pixel class labels shaped like
            # ``masks``; the per-pixel comparison below requires it.
            preds = batch_sliding_window_inference(
                images, model, device, crop_size, stride
            )
            conf_mat += _batch_confusion_matrix(
                masks.cpu().numpy(),
                preds.cpu().numpy(),
                num_classes,
                ignore_index,
            )
    ious, miou, _, mf1 = compute_metrics(conf_mat)
    return ious, miou, mf1