# inference_runner.py
import os

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

# --- Import model architecture from the repository ---
from basicsr.archs.mair_arch import MaIR
from basicsr.utils.img_util import tensor2img


class MaIR_Upsampler:
    """
    A self-contained class for the MaIR model for inference.
    Handles model loading, pre-processing, and tiling for large images.
    """

    def __init__(self, model_name, device=None, checkpoint_dir='checkpoints'):
        """Build the requested MaIR model and load its pretrained weights.

        Args:
            model_name: Key into the table returned by ``_get_model_configs``
                (e.g. 'MaIR-SRx4', 'MaIR-SRx2', 'MaIR-CDN-s50').
            device: Torch device (or device string) to run on. Defaults to
                CUDA when available, otherwise CPU.
            checkpoint_dir: Directory containing the ``.pth`` checkpoint files.
                A relative path is resolved against the current working
                directory. Defaults to 'checkpoints'.

        Raises:
            ValueError: If ``model_name`` is not a known configuration.
            FileNotFoundError: If the checkpoint file is missing.
        """
        self.model_name = model_name
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device
        # Must be set before _load_model(), which reads it.
        self.checkpoint_dir = checkpoint_dir
        print(f"Using device: {self.device} for model {self.model_name}")

        self.MODEL_CONFIGS = self._get_model_configs()
        if model_name not in self.MODEL_CONFIGS:
            raise ValueError(f"Model '{model_name}' not recognized. Available: {list(self.MODEL_CONFIGS.keys())}")

        self.model, self.scale = self._load_model()
        self.model.eval()
        self.model.to(self.device)

    def _get_model_configs(self):
        """Returns a dictionary of all supported model configurations.

        Each entry maps a model name to its task ('SR' or 'DN'), integer
        output scale, checkpoint filename and the keyword arguments passed to
        the ``MaIR`` constructor.
        """
        mair_sr_base_params = {
            'img_size': 64, 'patch_size': 1, 'in_chans': 3, 'embed_dim': 180,
            'depths': (6, 6, 6, 6, 6, 6), 'drop_rate': 0., 'd_state': 16,
            'ssm_ratio': 2.0, 'mlp_ratio': 2.5, 'drop_path_rate': 0.1,
            'norm_layer': nn.LayerNorm, 'patch_norm': True, 'use_checkpoint': False,
            'img_range': 1., 'upsampler': 'pixelshuffle', 'resi_connection': '1conv',
            'dynamic_ids': True, 'scan_len': 4,
        }
        # Denoising variant: same backbone, no upsampling head.
        mair_cdn_base_params = mair_sr_base_params.copy()
        mair_cdn_base_params.update({'upscale': 1, 'upsampler': ''})

        return {
            'MaIR-SRx4': {'task': 'SR', 'scale': 4, 'filename': 'MaIR_SR_x4.pth', 'params': {**mair_sr_base_params, 'upscale': 4}},
            'MaIR-SRx2': {'task': 'SR', 'scale': 2, 'filename': 'MaIR_SR_x2.pth', 'params': {**mair_sr_base_params, 'upscale': 2}},
            'MaIR-CDN-s50': {'task': 'DN', 'scale': 1, 'filename': 'MaIR_CDN_s50.pth', 'params': mair_cdn_base_params},
        }

    def _load_model(self):
        """Instantiate MaIR and load its pretrained weights from ``self.checkpoint_dir``.

        Returns:
            Tuple ``(model, scale)``, where ``scale`` is the integer output
            scale from the model config (1 for the denoising model).

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
        """
        config = self.MODEL_CONFIGS[self.model_name]
        params = config['params']
        filename = config['filename']
        scale = config['scale']

        model_path = os.path.join(self.checkpoint_dir, filename)
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Checkpoint not found: {model_path}. Ensure it's in a '{self.checkpoint_dir}' folder.")

        model = MaIR(**params)
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        load_net = torch.load(model_path, map_location=self.device)

        # Prefer the EMA weights, then 'params'. If neither key is present,
        # treat the file as a bare state_dict.
        if 'params_ema' in load_net:
            state_dict = load_net['params_ema']
        elif 'params' in load_net:
            state_dict = load_net['params']
        else:
            state_dict = load_net

        # Drop a leading 'module.' from parameter names (typically added when a
        # model is saved while wrapped in DataParallel).
        prefix = 'module.'
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }

        model.load_state_dict(state_dict, strict=True)
        print(f"Model {self.model_name} loaded successfully from {model_path}.")
        return model, scale

    def _tile_inference(self, img_tensor, tile_size=200, tile_pad=20):
        """Run the model tile by tile so that large images fit in memory.

        The input is padded on the bottom/right up to a multiple of
        ``tile_size``. Each tile is fed to the model together with up to
        ``tile_pad`` pixels of surrounding context on every side. Only the
        upscaled core of each tile, without that context, is written to the
        output, which reduces visible seams between tiles.

        Args:
            img_tensor: Input batch of shape (B, C, H, W) on ``self.device``.
            tile_size: Side length, in input pixels, of each tile core.
            tile_pad: Context margin, in input pixels, added around each tile.

        Returns:
            Tensor of shape (B, C, H * scale, W * scale).

        Raises:
            ValueError: If ``tile_size`` is not positive or ``tile_pad`` is negative.
        """
        if tile_size <= 0:
            raise ValueError(f"tile_size must be positive, got {tile_size}")
        if tile_pad < 0:
            raise ValueError(f"tile_pad must be non-negative, got {tile_pad}")

        b, c, h, w = img_tensor.size()
        scale = self.scale

        # Integer ceil-division: number of tiles needed to cover each axis.
        num_tiles_h = (h + tile_size - 1) // tile_size
        num_tiles_w = (w + tile_size - 1) // tile_size
        pad_h, pad_w = num_tiles_h * tile_size - h, num_tiles_w * tile_size - w

        # 'reflect' padding requires pad < input size on each padded axis,
        # which fails for small images (e.g. a side <= tile_size / 2).
        # Fall back to 'replicate' there instead of crashing.
        pad_mode = 'reflect' if (pad_h < h and pad_w < w) else 'replicate'
        img_padded = F.pad(img_tensor, (0, pad_w, 0, pad_h), pad_mode)
        padded_h, padded_w = img_padded.shape[2], img_padded.shape[3]

        # Same shape/dtype/device as the upscaled padded input.
        output_padded = img_padded.new_zeros((b, c, padded_h * scale, padded_w * scale))

        with torch.no_grad():
            for i in range(num_tiles_h):
                for j in range(num_tiles_w):
                    # Tile core in input coordinates.
                    h_start, h_end = i * tile_size, (i + 1) * tile_size
                    w_start, w_end = j * tile_size, (j + 1) * tile_size

                    # Tile core plus context, clamped to the padded image.
                    h_start_pad, h_end_pad = max(0, h_start - tile_pad), min(padded_h, h_end + tile_pad)
                    w_start_pad, w_end_pad = max(0, w_start - tile_pad), min(padded_w, w_end + tile_pad)

                    tile_input = img_padded[:, :, h_start_pad:h_end_pad, w_start_pad:w_end_pad]
                    tile_output = self.model(tile_input)

                    # Where the core lands in the output.
                    out_h_start, out_h_end = h_start * scale, h_end * scale
                    out_w_start, out_w_end = w_start * scale, w_end * scale

                    # Where the core sits inside the model output (skip context).
                    cut_h_start = (h_start - h_start_pad) * scale
                    cut_h_end = cut_h_start + tile_size * scale
                    cut_w_start = (w_start - w_start_pad) * scale
                    cut_w_end = cut_w_start + tile_size * scale

                    output_padded[:, :, out_h_start:out_h_end, out_w_start:out_w_end] = \
                        tile_output[:, :, cut_h_start:cut_h_end, cut_w_start:cut_w_end]

        # Crop away the alignment padding.
        return output_padded[:, :, :h * scale, :w * scale]

    def process(self, img, tile_size=200, tile_pad=20):
        """Main inference function: run the model on a single image.

        Args:
            img: HxWx3 image in BGR channel order (OpenCV convention). Values
                are divided by 255, so an 8-bit image is assumed.
            tile_size: Tile core size forwarded to ``_tile_inference``.
            tile_pad: Tile context margin forwarded to ``_tile_inference``.

        Returns:
            The model output converted back to a BGR image by basicsr's
            ``tensor2img`` (with ``rgb2bgr=True``).
        """
        # Pre-processing: BGR -> RGB, HWC -> CHW, scale to [0, 1], add batch dim.
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img_tensor = torch.from_numpy(img_rgb.transpose(2, 0, 1)).float() / 255.0
        img_tensor = img_tensor.unsqueeze(0).to(self.device)

        # Inference
        output_tensor = self._tile_inference(img_tensor, tile_size=tile_size, tile_pad=tile_pad)

        # Post-processing
        return tensor2img(output_tensor, rgb2bgr=True)