# Listing header: file size 5,782 bytes, 5df9861.
# inference_runner.py
import cv2
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
# --- Import model architecture from the repository ---
from basicsr.archs.mair_arch import MaIR
from basicsr.utils.img_util import tensor2img
class MaIR_Upsampler:
    """
    A self-contained class for the MaIR model for inference.
    Handles model loading, pre-processing, and tiling for large images.

    Instance attributes set by ``__init__``:
        model_name: Key of the selected entry in ``MODEL_CONFIGS``.
        device: ``torch.device`` the model and inputs are placed on.
        MODEL_CONFIGS: Dict of supported model configurations.
        model: The loaded ``MaIR`` network, switched to eval mode.
        scale: Output/input size ratio of the model (1 for the denoising model).
    """

    def __init__(self, model_name, device=None):
        """Select a configuration and load its pretrained weights.

        Args:
            model_name: One of the keys returned by ``_get_model_configs``.
            device: Optional ``torch.device``. Defaults to CUDA when available,
                otherwise CPU.

        Raises:
            ValueError: If ``model_name`` is not a supported configuration.
            FileNotFoundError: If the checkpoint file is missing.
        """
        self.model_name = model_name
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device
        print(f"Using device: {self.device} for model {self.model_name}")
        self.MODEL_CONFIGS = self._get_model_configs()
        if model_name not in self.MODEL_CONFIGS:
            raise ValueError(f"Model '{model_name}' not recognized. Available: {list(self.MODEL_CONFIGS.keys())}")
        self.model, self.scale = self._load_model()
        self.model.eval()
        self.model.to(self.device)

    def _get_model_configs(self):
        """Returns a dictionary of all supported model configurations.

        Each entry maps a model name to a dict with keys ``'task'``,
        ``'scale'``, ``'filename'`` (checkpoint file name) and ``'params'``
        (keyword arguments passed to ``MaIR``).
        """
        mair_sr_base_params = {
            'img_size': 64, 'patch_size': 1, 'in_chans': 3, 'embed_dim': 180,
            'depths': (6, 6, 6, 6, 6, 6), 'drop_rate': 0., 'd_state': 16,
            'ssm_ratio': 2.0, 'mlp_ratio': 2.5, 'drop_path_rate': 0.1,
            'norm_layer': nn.LayerNorm, 'patch_norm': True, 'use_checkpoint': False,
            'img_range': 1., 'upsampler': 'pixelshuffle', 'resi_connection': '1conv',
            'dynamic_ids': True, 'scan_len': 4,
        }
        # The denoising model shares the SR backbone but keeps the resolution
        # (upscale 1) and uses no upsampler head.
        mair_cdn_base_params = mair_sr_base_params.copy()
        mair_cdn_base_params.update({'upscale': 1, 'upsampler': ''})
        return {
            'MaIR-SRx4': {'task': 'SR', 'scale': 4, 'filename': 'MaIR_SR_x4.pth', 'params': {**mair_sr_base_params, 'upscale': 4}},
            'MaIR-SRx2': {'task': 'SR', 'scale': 2, 'filename': 'MaIR_SR_x2.pth', 'params': {**mair_sr_base_params, 'upscale': 2}},
            'MaIR-CDN-s50': {'task': 'DN', 'scale': 1, 'filename': 'MaIR_CDN_s50.pth', 'params': mair_cdn_base_params},
        }

    def _load_model(self):
        """Loads the pretrained model weights from the local 'checkpoints' folder.

        Returns:
            tuple: ``(model, scale)`` — the ``MaIR`` network with the checkpoint
            loaded (strict key matching) and the configuration's scale.

        Raises:
            FileNotFoundError: If ``checkpoints/<filename>`` does not exist.
        """
        config = self.MODEL_CONFIGS[self.model_name]
        params = config['params']
        filename = config['filename']
        scale = config['scale']
        model_path = os.path.join('checkpoints', filename)
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Checkpoint not found: {model_path}. Ensure it's in a 'checkpoints' folder.")
        model = MaIR(**params)
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        load_net = torch.load(model_path, map_location=self.device)
        # Checkpoints may wrap the state dict under 'params_ema' (preferred) or
        # 'params'; otherwise treat the loaded object as a bare state dict
        # instead of failing with a KeyError.
        if 'params_ema' in load_net:
            load_net = load_net['params_ema']
        elif 'params' in load_net:
            load_net = load_net['params']
        # Strip a leading 'module.' (presumably left by a DataParallel-style
        # wrapper) so the keys match the bare model.
        for k, v in list(load_net.items()):
            if k.startswith('module.'):
                load_net[k[len('module.'):]] = v
                del load_net[k]
        model.load_state_dict(load_net, strict=True)
        print(f"Model {self.model_name} loaded successfully from {model_path}.")
        return model, scale

    def _tile_inference(self, img_tensor, tile_size=200, tile_pad=20):
        """Performs inference using a tiling strategy to handle large images.

        The input is padded on the bottom/right up to a multiple of
        ``tile_size``. Each tile is fed to the model together with up to
        ``tile_pad`` pixels of surrounding context, and only the upscaled
        region belonging to the tile itself is written into the output.

        Args:
            img_tensor: Batch tensor indexed as ``(b, c, h, w)``.
            tile_size: Side length, in input pixels, of each tile (> 0).
            tile_pad: Context border, in input pixels, around each tile (>= 0).

        Returns:
            torch.Tensor: Tensor of shape ``(b, c, h * scale, w * scale)``.

        Raises:
            ValueError: If ``tile_size`` is not positive or ``tile_pad`` is negative.
        """
        if tile_size <= 0 or tile_pad < 0:
            raise ValueError(f"Invalid tiling: tile_size={tile_size}, tile_pad={tile_pad}")
        b, c, h, w = img_tensor.size()
        scale = self.scale
        num_tiles_h = (h + tile_size - 1) // tile_size
        num_tiles_w = (w + tile_size - 1) // tile_size
        pad_h = num_tiles_h * tile_size - h
        pad_w = num_tiles_w * tile_size - w
        # 'reflect' padding requires the pad to be smaller than the dimension
        # being padded, which fails for any side of at most tile_size / 2
        # pixels. Fall back to 'replicate' (no such limit) only in that case,
        # so results for all other inputs are unchanged.
        pad_mode = 'reflect' if pad_h < h and pad_w < w else 'replicate'
        img_padded = F.pad(img_tensor, (0, pad_w, 0, pad_h), mode=pad_mode)
        padded_h, padded_w = img_padded.shape[2], img_padded.shape[3]
        # Output buffer with the same dtype/device as the input; assumes the
        # model keeps the channel count c (as the original allocation did).
        output_padded = img_padded.new_zeros((b, c, padded_h * scale, padded_w * scale))
        with torch.no_grad():
            for i in range(num_tiles_h):
                # Tile rows in the padded input, plus their context rows.
                h_start, h_end = i * tile_size, (i + 1) * tile_size
                h_start_pad = max(0, h_start - tile_pad)
                h_end_pad = min(padded_h, h_end + tile_pad)
                for j in range(num_tiles_w):
                    w_start, w_end = j * tile_size, (j + 1) * tile_size
                    w_start_pad = max(0, w_start - tile_pad)
                    w_end_pad = min(padded_w, w_end + tile_pad)
                    tile_input = img_padded[:, :, h_start_pad:h_end_pad, w_start_pad:w_end_pad]
                    # Assumes the model output is the input tile upscaled by `scale`.
                    tile_output = self.model(tile_input)
                    # Discard the upscaled context border; keep only this tile's region.
                    cut_h_start = (h_start - h_start_pad) * scale
                    cut_w_start = (w_start - w_start_pad) * scale
                    output_padded[
                        :, :, h_start * scale:h_end * scale, w_start * scale:w_end * scale
                    ] = tile_output[
                        :, :,
                        cut_h_start:cut_h_start + tile_size * scale,
                        cut_w_start:cut_w_start + tile_size * scale,
                    ]
        return output_padded[:, :, :h * scale, :w * scale]

    def process(self, img):
        """Main inference function.

        Args:
            img: HWC numpy image in OpenCV channel order — 3-channel BGR,
                4-channel BGRA (alpha is dropped), or single-channel
                grayscale. Values are divided by 255, so a 0-255 range is
                expected.

        Returns:
            The model output converted with ``tensor2img(..., rgb2bgr=True)``,
            i.e. back to BGR channel order.
        """
        # Pre-processing: bring every supported layout to 3-channel RGB,
        # which is what the network consumes (in_chans=3).
        if img.ndim == 2 or img.shape[2] == 1:
            img_rgb = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
        else:
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img_tensor = torch.from_numpy(img_rgb.transpose(2, 0, 1)).float() / 255.0
        img_tensor = img_tensor.unsqueeze(0).to(self.device)
        # Inference
        output_tensor = self._tile_inference(img_tensor)
        # Post-processing
        return tensor2img(output_tensor, rgb2bgr=True)