File size: 855 Bytes
ede298f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
from tqdm import tqdm
from pathlib import Path
from PIL import Image
import torch
from .sliding_window import batch_sliding_window_inference
def inference(model, loader, device, crop_size, stride, output_dir: Path):
    """
    Performs batch inference and saves predicted masks.

    For every ``(images, rel_paths)`` batch yielded by ``loader``, the images
    are passed to ``batch_sliding_window_inference`` and each per-sample
    prediction is cast to ``uint8`` and written with PIL to
    ``output_dir / rel_path``. Parent directories are created as needed.

    Args:
        model: Network forwarded to ``batch_sliding_window_inference``. It is
            switched to eval mode for the duration of the call; if it was in
            training mode beforehand, training mode is restored on exit
            (including when an exception is raised).
        loader: Iterable yielding ``(images, rel_paths)`` pairs.
        device: Forwarded unchanged to ``batch_sliding_window_inference``.
        crop_size: Forwarded unchanged to ``batch_sliding_window_inference``.
        stride: Forwarded unchanged to ``batch_sliding_window_inference``.
        output_dir: Root directory for the saved masks; a ``str`` is also
            accepted. Created if it does not exist.
    """
    output_dir = Path(output_dir)

    # Remember the caller's mode so that invoking this from inside a training
    # loop does not silently leave dropout / batch-norm layers in eval mode.
    was_training = getattr(model, "training", False)
    model.eval()
    try:
        output_dir.mkdir(parents=True, exist_ok=True)
        # The tqdm bar is context-managed so it is closed even if a batch raises.
        with torch.no_grad(), tqdm(loader, desc="Inference", leave=False) as pbar:
            for images, rel_paths in pbar:
                preds = batch_sliding_window_inference(
                    images, model, device, crop_size, stride
                )
                preds = preds.cpu().numpy()
                for b, rel_path in enumerate(rel_paths):
                    out_path = output_dir / rel_path
                    out_path.parent.mkdir(parents=True, exist_ok=True)
                    # NOTE(review): astype('uint8') wraps values outside
                    # [0, 255] (e.g. -1 -> 255) and truncates floats; this
                    # presumably expects small integer class labels — confirm.
                    # The image format is chosen from rel_path's suffix, so a
                    # lossy suffix such as .jpg would alter label values.
                    Image.fromarray(preds[b].astype('uint8')).save(out_path)
    finally:
        if was_training:
            model.train()
|