English
File size: 1,086 Bytes
ede298f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import numpy as np
from tqdm import tqdm
from sklearn.metrics import confusion_matrix
import torch

from utils.metrics import compute_metrics
from inference.sliding_window import batch_sliding_window_inference

def _accumulate_confusion(conf_mat, true, pred, num_classes, ignore_index=None):
    """Add the (true, pred) label-pair counts of ``true``/``pred`` into ``conf_mat`` in place.

    Rows index the true label and columns the predicted label. Entries whose true
    or predicted label lies outside ``[0, num_classes)`` are dropped, as are
    entries whose true label equals ``ignore_index`` (when given). This matches
    summing ``sklearn.metrics.confusion_matrix(t, p, labels=range(num_classes))``
    over the same elements, but runs as one vectorised ``np.bincount``.

    Raises:
        ValueError: if ``true`` and ``pred`` do not have the same shape.
    """
    true = np.asarray(true)
    pred = np.asarray(pred)
    if true.shape != pred.shape:
        raise ValueError(
            f"label shape {true.shape} does not match prediction shape {pred.shape}"
        )
    true = true.reshape(-1)
    pred = pred.reshape(-1)

    keep = (true >= 0) & (true < num_classes) & (pred >= 0) & (pred < num_classes)
    if ignore_index is not None:
        keep &= true != ignore_index

    # Widen to int64 before multiplying so narrow mask dtypes (e.g. uint8) cannot overflow.
    pair_ids = true[keep].astype(np.int64) * num_classes + pred[keep].astype(np.int64)
    counts = np.bincount(pair_ids, minlength=num_classes * num_classes)
    conf_mat += counts.reshape(num_classes, num_classes)


def eval_epoch(model, loader, criterion, device, crop_size, stride, num_classes):
    """
    Evaluate the model in validation mode using sliding window inference.

    Args:
        model: network to evaluate. It is switched to eval mode and is left in
            eval mode on return.
        loader: iterable yielding ``(images, masks)`` batches; ``images.size(0)``
            is assumed to be the batch size and ``masks`` tensors convertible via
            ``.cpu().numpy()``.
        criterion: loss object. Only its ``ignore_index`` attribute is read, when
            present, to exclude those pixels from the metrics. No loss is computed.
        device, crop_size, stride: forwarded unchanged to
            ``batch_sliding_window_inference``.
        num_classes: number of classes. True or predicted labels outside
            ``[0, num_classes)`` are not counted.

    Returns:
        ``(ious, miou, mf1)`` taken from ``compute_metrics`` on the confusion
        matrix accumulated over the whole loader.
    """
    model.eval()
    # Not every loss exposes ignore_index; without one, no label is excluded
    # beyond the out-of-range filtering.
    ignore_index = getattr(criterion, "ignore_index", None)
    # int64 so pixel counts over a large validation set cannot overflow
    # (plain ``int`` is 32-bit on some platforms).
    conf_mat = np.zeros((num_classes, num_classes), dtype=np.int64)
    pbar = tqdm(loader, desc="Validation", leave=False)

    with torch.no_grad():
        for images, masks in pbar:
            masks = masks.cpu().numpy()
            preds = batch_sliding_window_inference(images, model, device, crop_size, stride).cpu().numpy()
            # Confusion counts are additive, so the whole batch is counted at once
            # instead of image by image.
            _accumulate_confusion(conf_mat, masks, preds, num_classes, ignore_index)

    ious, miou, _, mf1 = compute_metrics(conf_mat)
    return ious, miou, mf1