English
Antoine1091's picture
Upload folder using huggingface_hub
ede298f verified
from tqdm import tqdm
from pathlib import Path
from PIL import Image
import torch
from .sliding_window import batch_sliding_window_inference
def inference(model, loader, device, crop_size, stride, output_dir: Path):
    """Run batched sliding-window inference and save each prediction as an image.

    Args:
        model: Network forwarded to ``batch_sliding_window_inference``. It is
            switched to eval mode here and is left in eval mode on return.
        loader: Iterable yielding ``(images, rel_paths)`` batches; ``rel_paths``
            supplies one output path per item in the batch.
        device: Forwarded unchanged to ``batch_sliding_window_inference``.
        crop_size: Forwarded unchanged to ``batch_sliding_window_inference``.
        stride: Forwarded unchanged to ``batch_sliding_window_inference``.
        output_dir: Root directory for the saved masks; each prediction is
            written to ``output_dir / rel_path``. A ``str`` is also accepted.
    """
    # Accept plain strings as well as Path objects (Path(Path) is a no-op).
    output_dir = Path(output_dir)
    model.eval()
    output_dir.mkdir(parents=True, exist_ok=True)

    # Managing the tqdm bar as a context ensures it is closed even if
    # inference raises part-way through the loader.
    with torch.no_grad(), tqdm(loader, desc="Inference", leave=False) as pbar:
        for images, rel_paths in pbar:
            preds = batch_sliding_window_inference(images, model, device, crop_size, stride)
            # Transfer the whole batch to host memory once, then save per item.
            preds = preds.cpu().numpy()
            for b, rel_path in enumerate(rel_paths):
                # NOTE(review): if rel_path is absolute, `/` discards output_dir
                # and the file lands outside it - confirm the loader always
                # yields relative paths.
                out_path = output_dir / rel_path
                out_path.parent.mkdir(parents=True, exist_ok=True)
                # NOTE(review): astype('uint8') does not preserve values outside
                # 0..255 (integer inputs wrap modulo 256); presumably preds are
                # class indices below 256 - confirm.
                Image.fromarray(preds[b].astype('uint8')).save(out_path)