import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from pathlib import Path
from typing import Tuple


def batch_sliding_window_logits(images: torch.Tensor, model: torch.nn.Module,
                                device: torch.device,
                                crop_size: Tuple[int, int],
                                stride: Tuple[int, int]) -> torch.Tensor:
    """
    Sliding-window inference that returns logits averaged over overlapping
    crops, suitable for export.
    """
    B, C, H, W = images.shape
    ph, pw = crop_size
    sh, sw = stride
    images = images.to(device)

    num_classes = model.config.num_labels
    # Summed logits per pixel plus a visit counter, so overlapping windows
    # can be averaged at the end.
    full_logits = torch.zeros((B, num_classes, H, W), device=device)
    count_map = torch.zeros((H, W), device=device)

    with torch.no_grad():
        for top in range(0, H, sh):
            for left in range(0, W, sw):
                # Clamp the window to the image, then shift it back so every
                # crop keeps the full (ph, pw) size at the borders.
                bottom = min(top + ph, H)
                right = min(left + pw, W)
                top0 = max(bottom - ph, 0)
                left0 = max(right - pw, 0)
                patch = images[:, :, top0:bottom, left0:right]
                logits = model(pixel_values=patch).logits
                # Some segmentation heads (e.g. SegFormer) return logits at a
                # reduced resolution; upsample to the crop size before summing.
                # This is a no-op when the shapes already match.
                if logits.shape[-2:] != patch.shape[-2:]:
                    logits = F.interpolate(logits, size=patch.shape[-2:],
                                           mode='bilinear', align_corners=False)
                full_logits[:, :, top0:bottom, left0:right] += logits
                count_map[top0:bottom, left0:right] += 1

    # Average overlapping predictions; the clamp guards against division by
    # zero for pixels no window covered.
    avg_logits = full_logits / count_map.unsqueeze(0).unsqueeze(0).clamp(min=1)
    return avg_logits
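

# Hedged sanity sketch, not part of the original pipeline. It assumes only the
# interface used above (`model(pixel_values=...).logits`, `model.config.num_labels`).
# With a 1x1-conv head every pixel is independent, so the sliding-window
# average must match direct full-image inference exactly.
def _sanity_check_sliding_window():
    from types import SimpleNamespace

    class _ToySegModel(torch.nn.Module):
        def __init__(self, num_labels: int = 4):
            super().__init__()
            self.config = SimpleNamespace(num_labels=num_labels)
            self.head = torch.nn.Conv2d(3, num_labels, kernel_size=1)

        def forward(self, pixel_values):
            return SimpleNamespace(logits=self.head(pixel_values))

    model = _ToySegModel().eval()
    images = torch.rand(2, 3, 64, 64)
    avg = batch_sliding_window_logits(images, model, torch.device('cpu'),
                                      crop_size=(32, 32), stride=(16, 16))
    direct = model(pixel_values=images).logits
    assert torch.allclose(avg, direct, atol=1e-5)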


def export_logits_images(model, loader, device, crop_size, stride, output_dir: Path):
    """
    Runs batched sliding-window inference and exports per-class probabilities
    as uint8 .npy files (softmax values quantized to the 0-255 range).
    """
    model.eval()
    output_dir.mkdir(parents=True, exist_ok=True)

    for images, rel_paths in tqdm(loader, desc=f"Export logits to {output_dir}", leave=False):
        avg_logits = batch_sliding_window_logits(images, model, device, crop_size, stride)
        probs = torch.softmax(avg_logits, dim=1)
        # Quantize to uint8; rounding (rather than letting .byte() truncate)
        # halves the worst-case quantization error.
        probs = (probs * 255.0).round().clamp(0, 255).byte().cpu()

        B, C, H, W = probs.shape
        for b in range(B):
            # Store in HWC layout so each file reads back like an image array.
            arr = probs[b].permute(1, 2, 0).numpy()
            out_path = output_dir / rel_paths[b]
            out_path.parent.mkdir(parents=True, exist_ok=True)
            np.save(out_path.with_suffix('.npy'), arr)
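

# Example driver, a sketch only: the SegFormer checkpoint and the one-batch
# dummy "loader" below are illustrative assumptions, not part of this module.
# Any iterable yielding (images, rel_paths) batches works in place of a real
# DataLoader.
if __name__ == '__main__':
    from transformers import SegformerForSemanticSegmentation

    _sanity_check_sliding_window()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = SegformerForSemanticSegmentation.from_pretrained(
        'nvidia/segformer-b0-finetuned-ade-512-512').to(device)

    loader = [(torch.rand(1, 3, 640, 640), ['demo/image_0.png'])]
    export_logits_images(model, loader, device,
                         crop_size=(512, 512), stride=(256, 256),
                         output_dir=Path('exported_probs'))

    # Reading an export back: reverse the uint8 quantization.
    arr = np.load('exported_probs/demo/image_0.npy')  # (H, W, C) uint8
    probs = arr.astype(np.float32) / 255.0            # approximate probabilities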