English
File size: 1,314 Bytes
ede298f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
from typing import Tuple

def _window_starts(length: int, crop: int, stride: int) -> Tuple[int, ...]:
    """Return the distinct start offsets of windows tiling one axis of ``length`` pixels.

    Windows start every ``stride`` pixels; if the last one does not reach the
    border, one extra window ending exactly at ``length`` is added. Each offset
    appears once, so no window is evaluated (or weighted) twice.
    """
    if crop <= 0 or stride <= 0:
        raise ValueError(
            f"crop and stride must be positive, got crop={crop}, stride={stride}"
        )
    if length <= 0:
        return ()
    if length <= crop:
        # A single window (possibly smaller than the crop) covers the whole axis.
        return (0,)
    if stride > crop:
        # Consecutive windows would not touch, leaving pixels with no prediction.
        raise ValueError(
            f"stride ({stride}) larger than crop ({crop}) leaves uncovered pixels"
        )
    starts = list(range(0, length - crop + 1, stride))
    if starts[-1] + crop < length:
        # Shift a final window back so it ends flush with the border.
        starts.append(length - crop)
    return tuple(starts)


def batch_sliding_window_inference(images: torch.Tensor, model: torch.nn.Module,
                                   device: torch.device,
                                   crop_size: Tuple[int, int], stride: Tuple[int, int]) -> torch.Tensor:
    """
    Apply sliding-window inference and return per-pixel class predictions (argmax).

    The batch is tiled with windows of ``crop_size`` placed every ``stride``
    pixels, with the last window on each axis shifted back to end at the image
    border. Each window's logits are accumulated into a full-resolution buffer,
    overlapping regions are averaged, and the argmax over classes is returned.

    Args:
        images: Batch indexed as (B, C, H, W); moved to ``device``.
        model: Called as ``model(pixel_values=patch)``; the result must expose
            a ``.logits`` tensor (B, num_labels, h, w), and ``model.config``
            must expose ``num_labels``. If (h, w) differs from the patch size
            (e.g. a head predicting at reduced resolution), the logits are
            bilinearly resized to the patch size before accumulation.
            ``model.eval()`` is NOT called here; set the mode beforehand.
        device: Device for inference and accumulation.
        crop_size: Window size (height, width); must be positive.
        stride: Window step (height, width); must be positive and not exceed
            the crop along any axis longer than the crop.

    Returns:
        Tensor (B, H, W) of predicted class indices.

    Raises:
        ValueError: if a crop/stride component is non-positive, or a stride
            exceeds the crop such that some pixels would never be covered.
    """
    B, _, H, W = images.shape
    ph, pw = crop_size
    sh, sw = stride
    images = images.to(device)

    num_classes = model.config.num_labels
    # float32 accumulators regardless of the model's output dtype.
    full_logits = torch.zeros((B, num_classes, H, W), device=device)
    count_map = torch.zeros((H, W), device=device)

    tops = _window_starts(H, ph, sh)
    lefts = _window_starts(W, pw, sw)

    with torch.no_grad():
        for top0 in tops:
            bottom = min(top0 + ph, H)
            for left0 in lefts:
                right = min(left0 + pw, W)
                patch = images[:, :, top0:bottom, left0:right].contiguous()
                logits = model(pixel_values=patch).logits.float()
                target_size = (bottom - top0, right - left0)
                if tuple(logits.shape[-2:]) != target_size:
                    # Heads like SegFormer emit logits below input resolution.
                    logits = torch.nn.functional.interpolate(
                        logits, size=target_size, mode="bilinear", align_corners=False
                    )
                full_logits[:, :, top0:bottom, left0:right] += logits
                count_map[top0:bottom, left0:right] += 1

    # count_map (H, W) broadcasts over (B, num_classes, H, W); clamp guards empty axes.
    avg_logits = full_logits / count_map.clamp(min=1)
    return avg_logits.argmax(dim=1)