import torch
from typing import List, Tuple


def _window_starts(length: int, window: int, step: int) -> List[int]:
    """Return the distinct start offsets of sliding windows along one axis.

    Windows of size ``window`` start every ``step`` elements. An extra final
    window is appended, shifted back so that it ends exactly at ``length``.
    When ``length <= window`` a single start of 0 is returned; the caller
    truncates that window to ``length``. Offsets are unique, so no region is
    fed to the model more than once.
    """
    last = max(length - window, 0)
    starts = list(range(0, last + 1, step))
    if starts[-1] != last:
        # Cover the trailing strip with a window flush against the edge.
        starts.append(last)
    return starts


def batch_sliding_window_inference(
    images: torch.Tensor,
    model: torch.nn.Module,
    device: torch.device,
    crop_size: Tuple[int, int],
    stride: Tuple[int, int],
) -> torch.Tensor:
    """Apply sliding-window inference and return per-pixel class predictions.

    The (B, C, H, W) batch is tiled with windows of ``crop_size`` placed every
    ``stride`` pixels. The last row/column of windows is shifted back to end at
    the image border. Each window is passed to ``model(pixel_values=patch)``,
    and the ``.logits`` of overlapping windows are averaged before a final
    argmax over the class dimension.

    Args:
        images: Input batch of shape (B, C, H, W); moved to ``device``.
        model: Callable taking ``pixel_values=`` and returning an object with a
            ``.logits`` tensor of shape (B, num_classes, h, w). If (h, w)
            differs from the patch size (e.g. Hugging Face SegFormer returns
            1/4-resolution logits), the logits are bilinearly resized to the
            patch size.
        device: Device on which the images, logits and accumulators live.
        crop_size: Window size ``(height, width)``; must be positive.
        stride: Window step ``(height, width)``. Each value must be positive
            and no larger than the matching crop dimension; otherwise some
            pixels would never be covered by any window.

    Returns:
        Tensor of shape (B, H, W) holding the argmax class index per pixel.

    Raises:
        ValueError: If ``crop_size`` or ``stride`` are non-positive, or if a
            stride exceeds the corresponding crop dimension.
    """
    B, C, H, W = images.shape
    ph, pw = crop_size
    sh, sw = stride
    if ph <= 0 or pw <= 0:
        raise ValueError(f"crop_size must be positive, got {crop_size}")
    if sh <= 0 or sw <= 0:
        raise ValueError(f"stride must be positive, got {stride}")
    if sh > ph or sw > pw:
        # Larger steps leave gaps whose logits would stay zero and silently
        # be predicted as class 0.
        raise ValueError(
            f"stride {stride} must not exceed crop_size {crop_size}"
        )

    images = images.to(device)
    if H == 0 or W == 0:
        return torch.zeros((B, H, W), dtype=torch.long, device=device)

    row_starts = _window_starts(H, ph, sh)
    col_starts = _window_starts(W, pw, sw)

    # Allocated lazily so the class count comes from the model's own output.
    full_logits = None
    count_map = torch.zeros((1, 1, H, W), device=device)

    with torch.no_grad():
        for top in row_starts:
            bottom = min(top + ph, H)
            for left in col_starts:
                right = min(left + pw, W)
                patch = images[:, :, top:bottom, left:right].contiguous()
                logits = model(pixel_values=patch).logits
                if logits.shape[-2:] != patch.shape[-2:]:
                    # Some segmentation heads predict at reduced resolution.
                    logits = torch.nn.functional.interpolate(
                        logits,
                        size=tuple(patch.shape[-2:]),
                        mode="bilinear",
                        align_corners=False,
                    )
                if full_logits is None:
                    full_logits = torch.zeros(
                        (B, logits.shape[1], H, W), device=device
                    )
                full_logits[:, :, top:bottom, left:right] += logits
                count_map[:, :, top:bottom, left:right] += 1

    # Every pixel is covered at least once (stride <= crop_size), so the
    # divisor is strictly positive.
    avg_logits = full_logits / count_map
    return avg_logits.argmax(dim=1)