from pathlib import Path
from typing import Dict, Tuple, List

import numpy as np
import torch

try:
    from tqdm import tqdm
except ImportError:  # tqdm only drives the progress bar; degrade to plain iteration
    def tqdm(iterable, *args, **kwargs):
        """Fallback used when tqdm is unavailable: iterate without a progress bar."""
        return iterable


def _window_starts(length: int, window: int, step: int) -> List[int]:
    """Return the unique, ascending start offsets of the windows along one axis.

    Starts advance by ``step`` until a window would cross the border; one final
    window aligned with the border (``length - window``, or 0 when the axis is
    shorter than the window) is then appended. Each offset appears exactly once.
    """
    if length <= 0:
        return []
    last = max(length - window, 0)
    starts = list(range(0, last, step))
    starts.append(last)
    return starts


def batch_sliding_window_logits(images: torch.Tensor,
                                model: torch.nn.Module,
                                device: torch.device,
                                crop_size: Tuple[int, int],
                                stride: Tuple[int, int]) -> torch.Tensor:
    """Run sliding-window inference and return per-pixel averaged logits.

    The batch is moved to ``device`` and tiled with windows of ``crop_size``
    placed every ``stride`` pixels. Windows that would cross the bottom/right
    border are shifted back so they end on the border. Every distinct window
    is evaluated exactly once, so no window is over-weighted in the average.

    Args:
        images: Batch of shape ``(B, C, H, W)``.
        model: Called as ``model(pixel_values=patch)``; the result must expose
            ``.logits``, and ``model.config.num_labels`` gives the number of
            output channels.
        device: Device on which accumulation and inference take place.
        crop_size: ``(height, width)`` of each window.
        stride: ``(vertical, horizontal)`` step between window starts.

    Returns:
        Tensor of shape ``(B, num_labels, H, W)`` holding, for every pixel,
        the mean of the logits of all windows covering it. Pixels covered by
        no window (possible when the stride exceeds the crop size) stay zero.

    Raises:
        ValueError: If any crop or stride component is not positive.
    """
    B, _, H, W = images.shape
    ph, pw = crop_size
    sh, sw = stride
    if min(ph, pw, sh, sw) <= 0:
        raise ValueError(
            f"crop_size and stride must be positive, got {crop_size} and {stride}"
        )

    images = images.to(device)
    num_classes = model.config.num_labels

    full_logits = torch.zeros((B, num_classes, H, W), device=device)
    count_map = torch.zeros((H, W), device=device)

    row_starts = _window_starts(H, ph, sh)
    col_starts = _window_starts(W, pw, sw)

    with torch.no_grad():
        for top in row_starts:
            bottom = min(top + ph, H)
            for left in col_starts:
                right = min(left + pw, W)

                patch = images[:, :, top:bottom, left:right]
                logits = model(pixel_values=patch).logits

                # Some segmentation heads return logits at a different spatial
                # size than their input; bring them back to the patch size so
                # they can be accumulated.
                if logits.shape[-2:] != patch.shape[-2:]:
                    logits = torch.nn.functional.interpolate(
                        logits,
                        size=patch.shape[-2:],
                        mode="bilinear",
                        align_corners=False,
                    )

                full_logits[:, :, top:bottom, left:right] += logits
                count_map[top:bottom, left:right] += 1

    # clamp avoids dividing by zero for pixels that no window covered.
    return full_logits / count_map.clamp(min=1)


def export_logits_images(model, loader, device, crop_size, stride, output_dir: Path):
    """Export sliding-window class probabilities of every sample as ``.npy``.

    The model is switched to eval mode. For each ``(images, rel_paths)`` batch
    yielded by ``loader``, the averaged logits are soft-maxed over the class
    axis, scaled to 0..255 and cast to uint8 (the cast drops the fractional
    part). Each sample is saved channel-last (``H x W x C``) at
    ``output_dir / rel_paths[i]`` with its suffix replaced by ``.npy``;
    missing parent directories are created.

    Args:
        model: Segmentation model accepted by ``batch_sliding_window_logits``.
        loader: Iterable of ``(images, rel_paths)`` batches, where
            ``rel_paths`` is indexable by the sample position in the batch.
        device: Device used for inference.
        crop_size: ``(height, width)`` of the sliding window.
        stride: ``(vertical, horizontal)`` step of the sliding window.
        output_dir: Root directory of the exported files.
    """
    model.eval()
    output_dir.mkdir(parents=True, exist_ok=True)

    for images, rel_paths in tqdm(loader, desc=f"Export logits to {output_dir}", leave=False):
        avg_logits = batch_sliding_window_logits(images, model, device, crop_size, stride)
        probs = torch.softmax(avg_logits, dim=1)
        probs = (probs * 255.0).clamp(0, 255).byte().cpu()

        for idx, sample in enumerate(probs):
            arr = sample.permute(1, 2, 0).numpy()  # H×W×C
            out_path = output_dir / rel_paths[idx]
            out_path.parent.mkdir(parents=True, exist_ok=True)
            np.save(out_path.with_suffix('.npy'), arr)