|
import numpy as np |
|
from tqdm import tqdm |
|
from sklearn.metrics import confusion_matrix |
|
import torch |
|
|
|
from utils.metrics import compute_metrics |
|
from inference.sliding_window import batch_sliding_window_inference |
|
|
|
def _confusion_from_labels(true, pred, num_classes, ignore_index=None):
    """Return a (num_classes, num_classes) int64 confusion matrix for two label arrays.

    Rows index the ground-truth class and columns the predicted class, the same
    layout as ``sklearn.metrics.confusion_matrix``. Entries whose ground truth
    equals ``ignore_index`` are dropped, as are entries whose true or predicted
    label lies outside ``[0, num_classes)`` (sklearn likewise discards those when
    called with ``labels=range(num_classes)``).

    Raises:
        ValueError: if ``true`` and ``pred`` do not hold the same number of labels.
    """
    # Cast before arithmetic: masks are often uint8, and num_classes * true
    # would wrap around in that dtype.
    true = np.asarray(true, dtype=np.int64).ravel()
    pred = np.asarray(pred, dtype=np.int64).ravel()
    if true.shape != pred.shape:
        raise ValueError(
            f"mask and prediction sizes differ: {true.size} vs {pred.size}"
        )

    valid = (true >= 0) & (true < num_classes) & (pred >= 0) & (pred < num_classes)
    if ignore_index is not None:
        valid &= true != ignore_index

    # Encode each (true, pred) pair as one flat index so a single bincount
    # pass tallies the whole matrix.
    flat = num_classes * true[valid] + pred[valid]
    counts = np.bincount(flat, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)


def eval_epoch(model, loader, criterion, device, crop_size, stride, num_classes):
    """
    Evaluates the model in validation mode using sliding window inference.

    Args:
        model: network handed to ``batch_sliding_window_inference``. It is put
            into eval mode here and left in eval mode on return.
        loader: iterable yielding ``(images, masks)`` batches of tensors.
        criterion: only its ``ignore_index`` attribute is read, if present;
            ground-truth pixels carrying that label are excluded from the
            metrics. No loss is computed.
        device, crop_size, stride: forwarded unchanged to
            ``batch_sliding_window_inference``.
        num_classes: number of classes. Labels outside ``[0, num_classes)``
            (in either masks or predictions) are not counted.

    Returns:
        ``(ious, miou, mf1)``: the first, second and fourth values returned by
        ``compute_metrics`` on the accumulated confusion matrix (presumably
        per-class IoU, mean IoU and mean F1 judging by the names; confirm in
        ``utils.metrics``).
    """
    model.eval()

    ignore_index = getattr(criterion, "ignore_index", None)
    # Explicit int64: the platform default int may be 32-bit (e.g. Windows with
    # NumPy < 2), and per-class pixel totals over a whole validation set can
    # exceed 2**31.
    conf_mat = np.zeros((num_classes, num_classes), dtype=np.int64)

    with torch.no_grad():
        for images, masks in tqdm(loader, desc="Validation", leave=False):
            preds = batch_sliding_window_inference(
                images, model, device, crop_size, stride
            )
            # Tally the whole batch at once; flattening across images yields
            # the same counts as summing one matrix per image.
            conf_mat += _confusion_from_labels(
                masks.cpu().numpy(),
                preds.cpu().numpy(),
                num_classes,
                ignore_index,
            )

    ious, miou, _, mf1 = compute_metrics(conf_mat)
    return ious, miou, mf1
|
|