# Provenance: uploaded to the Hugging Face Hub by Antoine1091
# ("Upload folder using huggingface_hub", commit ede298f, verified).
import torch
from typing import Tuple
def _window_starts(length: int, crop: int, step: int) -> Tuple[int, ...]:
    """Return unique window start offsets covering [0, length) with stride `step`.

    Starts are 0, step, 2*step, ... up to ``length - crop``; if that last
    position is not hit exactly, one extra window flush with the end is added
    so the border is always covered. When ``length <= crop`` a single window
    starting at 0 is returned.
    """
    last = max(length - crop, 0)
    starts = list(range(0, last + 1, step))
    if starts[-1] != last:
        # Shift the final window back so it ends exactly at the border.
        starts.append(last)
    return tuple(starts)


def batch_sliding_window_inference(images: torch.Tensor, model: torch.nn.Module,
                                   device: torch.device,
                                   crop_size: Tuple[int, int],
                                   stride: Tuple[int, int]) -> torch.Tensor:
    """
    Apply sliding-window inference and return per-pixel class predictions.

    The batch is tiled with ``crop_size`` windows whose starts are ``stride``
    apart. The last window along each axis is placed flush with the image
    border, so the border is always covered and no window is evaluated twice.
    Logits from overlapping windows are averaged, then argmax-ed over the class
    dimension.

    Args:
        images: Batch of shape (B, C, H, W); moved to ``device`` first.
        model: Called as ``model(pixel_values=patch)``; the result must expose a
            ``.logits`` tensor of shape (B, num_classes, h, w) (for example a
            Hugging Face semantic-segmentation model). If (h, w) differs from the
            patch size, e.g. for models that predict at reduced resolution, the
            logits are bilinearly upsampled to the patch size before averaging.
        device: Device on which inference and accumulation take place.
        crop_size: Window size as (height, width). An image smaller than the
            crop is processed as a single smaller window.
        stride: Distance between window starts as (height, width). If the stride
            exceeds the crop size, some interior pixels are covered by no window;
            they keep all-zero logits and are predicted as class 0.

    Returns:
        Long tensor of shape (B, H, W) holding the predicted class per pixel.

    Raises:
        ValueError: If any component of ``crop_size`` or ``stride`` is not positive.

    Note:
        This function does not call ``model.eval()``; put the model in the
        desired mode beforehand.
    """
    ph, pw = crop_size
    sh, sw = stride
    if min(ph, pw, sh, sw) <= 0:
        raise ValueError(
            f"crop_size and stride must be positive, got crop_size={crop_size}, stride={stride}"
        )

    B, _, H, W = images.shape
    images = images.to(device)
    if H == 0 or W == 0:
        # No pixels to predict; same shape/dtype as an argmax over an empty map.
        return torch.zeros((B, H, W), dtype=torch.long, device=device)

    row_starts = _window_starts(H, ph, sh)
    col_starts = _window_starts(W, pw, sw)

    # Allocated lazily: the class count is taken from the first logits tensor.
    full_logits = None
    count_map = torch.zeros((H, W), device=device)
    with torch.no_grad():
        for top in row_starts:
            bottom = min(top + ph, H)
            for left in col_starts:
                right = min(left + pw, W)
                patch = images[:, :, top:bottom, left:right].contiguous()
                logits = model(pixel_values=patch).logits
                patch_hw = tuple(patch.shape[-2:])
                if tuple(logits.shape[-2:]) != patch_hw:
                    # Some models emit logits at a lower resolution than the input.
                    logits = torch.nn.functional.interpolate(
                        logits, size=patch_hw, mode="bilinear", align_corners=False
                    )
                if full_logits is None:
                    full_logits = torch.zeros((B, logits.shape[1], H, W), device=device)
                full_logits[:, :, top:bottom, left:right] += logits
                count_map[top:bottom, left:right] += 1

    # count_map (H, W) broadcasts over (B, C, H, W); clamp guards uncovered pixels.
    avg_logits = full_logits / count_map.clamp(min=1)
    return avg_logits.argmax(dim=1)