from tqdm import tqdm | |
from pathlib import Path | |
from PIL import Image | |
import torch | |
from .sliding_window import batch_sliding_window_inference | |
def inference(model, loader, device, crop_size, stride, output_dir: Path):
    """Run sliding-window inference over ``loader`` and save one mask per sample.

    ``model`` is put into eval mode and every batch is processed inside
    ``torch.no_grad()``. Each batch's predictions are moved to the CPU and
    converted to a NumPy array. Each sample's slice is then cast to
    ``uint8`` and written with PIL to ``output_dir / rel_path``.

    Args:
        model: Network forwarded to ``batch_sliding_window_inference``.
        loader: Iterable yielding ``(images, rel_paths)`` pairs. ``rel_paths``
            holds one output path per sample, joined onto ``output_dir``.
        device: Forwarded unchanged to ``batch_sliding_window_inference``.
        crop_size: Sliding-window crop size, forwarded unchanged.
        stride: Sliding-window stride, forwarded unchanged.
        output_dir: Root directory for the saved masks. It and each mask's
            parent directory are created if missing.
    """
    model.eval()
    output_dir.mkdir(parents=True, exist_ok=True)
    progress = tqdm(loader, desc="Inference", leave=False)
    with torch.no_grad():
        for batch_images, batch_rel_paths in progress:
            raw_preds = batch_sliding_window_inference(
                batch_images, model, device, crop_size, stride
            )
            pred_array = raw_preds.cpu().numpy()
            # Index into pred_array (rather than zip) so that a batch with fewer
            # predictions than paths still fails loudly instead of truncating.
            for idx, rel_path in enumerate(batch_rel_paths):
                target = output_dir / rel_path
                target.parent.mkdir(parents=True, exist_ok=True)
                # NOTE(review): two things to confirm here.
                # - Values outside [0, 255] do not survive the uint8 cast.
                # - PIL chooses the encoder from target's suffix, so a lossy
                #   format such as JPEG would alter label values.
                mask = pred_array[idx].astype('uint8')
                Image.fromarray(mask).save(target)