|
import torch |
|
from typing import Tuple |
|
|
|
def _window_starts(length: int, window: int, step: int) -> list:
    """Return the start offsets of windows of size ``window`` sliding by ``step``.

    Offsets run 0, step, 2*step, ... up to ``length - window``. If that last
    position is not hit exactly, it is appended so the final window ends flush
    with the border. Every offset appears exactly once, so no window is
    evaluated (or counted in the average) twice. When ``window >= length`` the
    single offset 0 is returned (one window covering the whole extent).
    """
    last = max(length - window, 0)
    starts = list(range(0, last + 1, step))
    if starts[-1] != last:
        starts.append(last)
    return starts


def batch_sliding_window_inference(images: torch.Tensor, model: torch.nn.Module,
                                   device: torch.device,
                                   crop_size: Tuple[int, int],
                                   stride: Tuple[int, int]) -> torch.Tensor:
    """
    Applies sliding window inference with final argmax (prediction).

    ``images`` is split into ``crop_size`` patches placed every ``stride``
    pixels. The last row/column of windows is shifted back so it ends exactly
    on the image border. Each distinct window is run through ``model`` once.
    Per-pixel logits are averaged over all windows covering that pixel, and
    the class with the highest averaged logit is returned.

    Args:
        images: batch of shape (B, C, H, W); moved to ``device``.
        model: called as ``model(pixel_values=patch)``. It must return an
            object with a ``.logits`` tensor of shape (B, num_labels, h, w) and
            expose ``model.config.num_labels``. If (h, w) differs from the
            patch size (e.g. a model predicting at reduced resolution), the
            logits are bilinearly resized to the patch size. This function
            does not call ``model.eval()``; that is the caller's job.
        device: device on which inference and accumulation happen.
        crop_size: (patch_height, patch_width), both > 0. Patches are clipped
            to the image when the image is smaller than the crop.
        stride: (step_y, step_x), both > 0. If a step exceeds the matching
            crop dimension, some pixels are never covered; their logits stay
            zero, so they are predicted as class 0.

    Returns:
        A tensor (B, H, W) with predicted classes.

    Raises:
        ValueError: if any crop or stride component is not positive.
    """
    B, _, H, W = images.shape
    ph, pw = crop_size
    sh, sw = stride
    if ph <= 0 or pw <= 0:
        raise ValueError(f"crop_size must be positive, got {crop_size}")
    if sh <= 0 or sw <= 0:
        raise ValueError(f"stride must be positive, got {stride}")

    images = images.to(device)

    num_classes = model.config.num_labels
    full_logits = torch.zeros((B, num_classes, H, W), device=device)
    count_map = torch.zeros((H, W), device=device)

    # Deduplicated window positions. The naive range(0, H, sh) produces
    # repeated border windows that would be double-counted in the average.
    tops = _window_starts(H, ph, sh)
    lefts = _window_starts(W, pw, sw)

    with torch.no_grad():
        for top in tops:
            bottom = min(top + ph, H)
            for left in lefts:
                right = min(left + pw, W)

                patch = images[:, :, top:bottom, left:right].contiguous()
                logits = model(pixel_values=patch).logits

                # Some models predict at reduced resolution; bring the logits
                # back to the patch size before accumulating.
                if logits.shape[-2:] != patch.shape[-2:]:
                    logits = torch.nn.functional.interpolate(
                        logits, size=patch.shape[-2:],
                        mode="bilinear", align_corners=False)

                full_logits[:, :, top:bottom, left:right] += logits
                count_map[top:bottom, left:right] += 1

    # clamp avoids 0/0 for pixels no window covered (only possible when
    # stride > crop); count_map (H, W) broadcasts over (B, K, H, W).
    avg_logits = full_logits / count_map.clamp(min=1)

    return avg_logits.argmax(dim=1)
|
|