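# Upload metadata scraped from the Hugging Face file page (kept as comments so the module parses):
# English
# Antoine1091's picture
# Upload folder using huggingface_hub
# ede298f verified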
import torch
import numpy as np
from tqdm import tqdm
from pathlib import Path
from typing import Dict, Tuple, List
def batch_sliding_window_logits(images: torch.Tensor, model: torch.nn.Module,
                                device: torch.device,
                                crop_size: Tuple[int, int], stride: Tuple[int, int]) -> torch.Tensor:
    """
    Run sliding-window inference over a batch and return per-pixel averaged logits.

    The image is tiled with ``crop_size`` windows whose origins are ``stride``
    pixels apart. The last window on each axis is shifted back so that it ends
    exactly at the image border. Each distinct window is run through the model
    exactly once, and logits from overlapping windows are averaged per pixel.

    Args:
        images: batch tensor of shape (B, C, H, W); it is moved to ``device``.
        model: module exposing ``config.num_labels``. When called with
            ``pixel_values=patch`` it must return an object with a ``.logits``
            tensor of shape (B, num_labels, h, w). If (h, w) differs from the
            patch size (models that predict at reduced resolution), the logits
            are bilinearly resized to the patch size before accumulation.
        device: device on which inference and accumulation happen.
        crop_size: window size as (height, width). An image smaller than the
            crop along an axis is processed as one smaller window on that axis.
        stride: step between window origins as (height, width); both must be positive.

    Returns:
        Tensor of shape (B, num_labels, H, W) on ``device``. Pixels covered by
        no window, which only happens when stride > crop_size, are zero.

    Raises:
        ValueError: if either stride component is not positive.
    """
    B, _, H, W = images.shape
    ph, pw = crop_size
    sh, sw = stride
    if sh <= 0 or sw <= 0:
        raise ValueError(f"stride must be positive, got {stride}")

    images = images.to(device)
    num_classes = model.config.num_labels
    full_logits = torch.zeros((B, num_classes, H, W), device=device)
    count_map = torch.zeros((H, W), device=device)

    # Exact number of distinct window positions per axis. Iterating
    # range(0, H, sh) instead would revisit the border-clamped last window
    # several times, wasting forward passes and over-weighting it in the average.
    n_rows = max(H - ph + sh - 1, 0) // sh + 1
    n_cols = max(W - pw + sw - 1, 0) // sw + 1

    with torch.no_grad():
        for row in range(n_rows):
            for col in range(n_cols):
                bottom = min(row * sh + ph, H)
                right = min(col * sw + pw, W)
                # Shift the window back so it keeps the full crop size when
                # possible (only the last window per axis is actually shifted).
                top0 = max(bottom - ph, 0)
                left0 = max(right - pw, 0)
                patch = images[:, :, top0:bottom, left0:right]
                logits = model(pixel_values=patch).logits
                if logits.shape[-2:] != patch.shape[-2:]:
                    logits = torch.nn.functional.interpolate(
                        logits, size=patch.shape[-2:], mode="bilinear", align_corners=False)
                full_logits[:, :, top0:bottom, left0:right] += logits
                count_map[top0:bottom, left0:right] += 1

    # count_map (H, W) broadcasts over batch and class dims; clamp avoids 0/0
    # for uncovered pixels.
    avg_logits = full_logits / count_map.clamp(min=1)
    return avg_logits
def export_logits_images(model, loader, device, crop_size, stride, output_dir: Path):
    """
    Apply batch sliding-window inference and export per-image probabilities as .npy files.

    For every ``(images, rel_paths)`` batch yielded by ``loader``:
    - the averaged logits from ``batch_sliding_window_logits`` are converted to
      softmax probabilities over the class dimension;
    - the probabilities are quantized to uint8 as ``round(p * 255)``;
    - each image's probabilities are saved as an (H, W, num_classes) array at
      ``output_dir / rel_path``, with the file suffix replaced by ``.npy``.

    Args:
        model: segmentation model forwarded to ``batch_sliding_window_logits``.
            It is switched to eval mode and is NOT switched back afterwards.
        loader: iterable of ``(images, rel_paths)`` pairs, with one relative
            path per image in the batch.
        device: device used for inference.
        crop_size: sliding-window size as (height, width).
        stride: sliding-window step as (height, width).
        output_dir: root directory for exported files (``Path`` or ``str``).
            It and any needed subdirectories are created if missing.

    Raises:
        ValueError: if a batch's ``rel_paths`` length differs from its image count.
    """
    output_dir = Path(output_dir)
    model.eval()
    output_dir.mkdir(parents=True, exist_ok=True)
    for images, rel_paths in tqdm(loader, desc=f"Export logits to {output_dir}", leave=False):
        avg_logits = batch_sliding_window_logits(images, model, device, crop_size, stride)
        probs = torch.softmax(avg_logits, dim=1)
        # Round to the nearest of the 256 levels. A bare .byte() cast truncates,
        # which biases every stored probability downward.
        probs_u8 = (probs * 255.0).round().clamp(0, 255).to(torch.uint8).cpu()
        if len(rel_paths) != probs_u8.shape[0]:
            raise ValueError(
                f"got {len(rel_paths)} paths for a batch of {probs_u8.shape[0]} images")
        for prob_map, rel_path in zip(probs_u8, rel_paths):
            arr = prob_map.permute(1, 2, 0).numpy()  # (C, H, W) -> (H, W, C)
            out_path = output_dir / rel_path
            out_path.parent.mkdir(parents=True, exist_ok=True)
            np.save(out_path.with_suffix('.npy'), arr)