from pathlib import Path
from typing import Union

import torch
from PIL import Image
from tqdm import tqdm

from .sliding_window import batch_sliding_window_inference


def inference(model, loader, device, crop_size, stride, output_dir: Union[str, Path]):
    """Run batched sliding-window inference and save one predicted mask per sample.

    Args:
        model: Network handed to ``batch_sliding_window_inference``. It is
            switched to eval mode here and is left in eval mode afterwards.
        loader: Iterable yielding ``(images, rel_paths)`` batches. Each entry
            of ``rel_paths`` is the output location of the corresponding
            sample, relative to ``output_dir``.
        device: Forwarded unchanged to ``batch_sliding_window_inference``.
        crop_size: Forwarded unchanged to ``batch_sliding_window_inference``.
        stride: Forwarded unchanged to ``batch_sliding_window_inference``.
        output_dir: Root directory for the saved masks. It is created (with
            parents) if missing. Either a ``str`` or a ``Path`` is accepted.

    Returns:
        None. Masks are written to disk as a side effect. Any intermediate
        directories implied by ``rel_paths`` are created as needed.
    """
    # Coerce so callers may pass a plain string path; a Path is returned unchanged.
    output_dir = Path(output_dir)
    model.eval()
    output_dir.mkdir(parents=True, exist_ok=True)

    # The tqdm context manager guarantees the progress bar is closed even if
    # inference raises part-way through the loader.
    with torch.no_grad(), tqdm(loader, desc="Inference", leave=False) as pbar:
        for images, rel_paths in pbar:
            preds = batch_sliding_window_inference(images, model, device, crop_size, stride)
            # NOTE(review): preds is presumably a (B, H, W) tensor of per-pixel
            # class indices; Image.fromarray needs a 2-D array per sample for a
            # single-channel mask. TODO confirm against batch_sliding_window_inference.
            preds = preds.cpu().numpy()
            for b, rel_path in enumerate(rel_paths):
                out_path = output_dir / rel_path
                out_path.parent.mkdir(parents=True, exist_ok=True)
                # NOTE(review): astype('uint8') silently wraps values outside
                # 0..255. Pillow also picks the file format from out_path's
                # suffix, so a lossy suffix such as .jpg would corrupt label
                # values. Confirm rel_paths use a lossless extension (e.g. .png).
                Image.fromarray(preds[b].astype('uint8')).save(out_path)