English
File size: 855 Bytes
ede298f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from tqdm import tqdm
from pathlib import Path
from PIL import Image
import torch

from .sliding_window import batch_sliding_window_inference


def inference(model, loader, device, crop_size, stride, output_dir: Path):
    """
    Run sliding-window prediction over every batch and write the masks to disk.

    The model is put into eval mode and ``output_dir`` is created if missing.
    Each item yielded by ``loader`` is unpacked as ``(images, rel_paths)``.
    The images are passed to ``batch_sliding_window_inference`` together with
    ``model``, ``device``, ``crop_size`` and ``stride``; the result is moved
    to the CPU and converted to a NumPy array. Entry ``i`` of that array is
    cast to ``uint8`` and saved as an image at ``output_dir / rel_paths[i]``,
    creating any missing parent directories first.

    Returns nothing; the saved files are the only output.
    """
    model.eval()
    output_dir.mkdir(parents=True, exist_ok=True)
    progress = tqdm(loader, desc="Inference", leave=False)

    # Gradients are not needed for prediction, so disable autograd tracking.
    with torch.no_grad():
        for images, rel_paths in progress:
            batch_masks = batch_sliding_window_inference(
                images, model, device, crop_size, stride
            )
            batch_masks = batch_masks.cpu().numpy()

            # Index by position so predictions stay paired with their paths.
            for index, rel_path in enumerate(rel_paths):
                destination = output_dir / rel_path
                destination.parent.mkdir(parents=True, exist_ok=True)
                mask_image = Image.fromarray(batch_masks[index].astype('uint8'))
                mask_image.save(destination)