import os

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from basicsr.archs.mair_arch import MaIR
from basicsr.utils.img_util import tensor2img


class MaIR_Upsampler:
    """
    A self-contained inference wrapper for the MaIR model.
    Handles model loading, pre-processing, and tiled inference for large images.
    """

    def __init__(self, model_name, device=None):
        self.model_name = model_name

        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device

        print(f"Using device: {self.device} for model {self.model_name}")

        self.MODEL_CONFIGS = self._get_model_configs()

        if model_name not in self.MODEL_CONFIGS:
            raise ValueError(f"Model '{model_name}' not recognized. Available: {list(self.MODEL_CONFIGS.keys())}")

        self.model, self.scale = self._load_model()
        self.model.eval()
        self.model.to(self.device)

    def _get_model_configs(self):
        """Returns a dictionary of all supported model configurations."""
        mair_sr_base_params = {
            'img_size': 64, 'patch_size': 1, 'in_chans': 3, 'embed_dim': 180,
            'depths': (6, 6, 6, 6, 6, 6), 'drop_rate': 0., 'd_state': 16,
            'ssm_ratio': 2.0, 'mlp_ratio': 2.5, 'drop_path_rate': 0.1,
            'norm_layer': nn.LayerNorm, 'patch_norm': True, 'use_checkpoint': False,
            'img_range': 1., 'upsampler': 'pixelshuffle', 'resi_connection': '1conv',
            'dynamic_ids': True, 'scan_len': 4,
        }
        # Denoising uses the same backbone but no upsampling head.
        mair_cdn_base_params = mair_sr_base_params.copy()
        mair_cdn_base_params.update({'upscale': 1, 'upsampler': ''})

        return {
            'MaIR-SRx4': {'task': 'SR', 'scale': 4, 'filename': 'MaIR_SR_x4.pth', 'params': {**mair_sr_base_params, 'upscale': 4}},
            'MaIR-SRx2': {'task': 'SR', 'scale': 2, 'filename': 'MaIR_SR_x2.pth', 'params': {**mair_sr_base_params, 'upscale': 2}},
            'MaIR-CDN-s50': {'task': 'DN', 'scale': 1, 'filename': 'MaIR_CDN_s50.pth', 'params': mair_cdn_base_params},
        }

    def _load_model(self):
        """Loads the pretrained model weights from the local 'checkpoints' folder."""
        config = self.MODEL_CONFIGS[self.model_name]
        params = config['params']
        filename = config['filename']
        scale = config['scale']

        model_path = os.path.join('checkpoints', filename)
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Checkpoint not found: {model_path}. Ensure it's in a 'checkpoints' folder.")

        model = MaIR(**params)
        load_net = torch.load(model_path, map_location=self.device)
        # BasicSR checkpoints store weights under 'params_ema' (EMA weights) or 'params'.
        param_key = 'params_ema' if 'params_ema' in load_net else 'params'
        load_net = load_net[param_key]

        # Strip the 'module.' prefix left over from DataParallel training, if present.
        for k, v in list(load_net.items()):
            if k.startswith('module.'):
                load_net[k[7:]] = v
                del load_net[k]

        model.load_state_dict(load_net, strict=True)
        print(f"Model {self.model_name} loaded successfully from {model_path}.")
        return model, scale

    def _tile_inference(self, img_tensor):
        """Performs inference using a tiling strategy to handle large images."""
        b, c, h, w = img_tensor.size()
        tile_size, tile_pad = 200, 20
        num_tiles_h = int(np.ceil(h / tile_size))
        num_tiles_w = int(np.ceil(w / tile_size))

        # Reflect-pad the input so its dimensions are exact multiples of the tile size.
        pad_h, pad_w = num_tiles_h * tile_size - h, num_tiles_w * tile_size - w
        img_padded = F.pad(img_tensor, (0, pad_w, 0, pad_h), 'reflect')
        # Allocate the output canvas directly at the upscaled resolution.
        output_padded = img_padded.new_zeros(
            b, c, img_padded.shape[2] * self.scale, img_padded.shape[3] * self.scale)

        with torch.no_grad():
            for i in tqdm(range(num_tiles_h), desc='MaIR tiles'):
                for j in range(num_tiles_w):
                    h_start, h_end = i * tile_size, (i + 1) * tile_size
                    w_start, w_end = j * tile_size, (j + 1) * tile_size
                    # Expand each tile by `tile_pad` pixels on every side (clamped to
                    # the image bounds) so the model sees context past the tile edges.
                    h_start_pad, h_end_pad = max(0, h_start - tile_pad), min(img_padded.shape[2], h_end + tile_pad)
                    w_start_pad, w_end_pad = max(0, w_start - tile_pad), min(img_padded.shape[3], w_end + tile_pad)

                    tile_input = img_padded[:, :, h_start_pad:h_end_pad, w_start_pad:w_end_pad]
                    tile_output = self.model(tile_input)

                    # Crop the halo off the upscaled tile and paste only the core
                    # region into the output canvas.
                    out_h_start, out_h_end = h_start * self.scale, h_end * self.scale
                    out_w_start, out_w_end = w_start * self.scale, w_end * self.scale
                    cut_h_start = (h_start - h_start_pad) * self.scale
                    cut_h_end = cut_h_start + tile_size * self.scale
                    cut_w_start = (w_start - w_start_pad) * self.scale
                    cut_w_end = cut_w_start + tile_size * self.scale

                    output_padded[:, :, out_h_start:out_h_end, out_w_start:out_w_end] = \
                        tile_output[:, :, cut_h_start:cut_h_end, cut_w_start:cut_w_end]

        # Trim the padding back off so the output matches the original dimensions.
        return output_padded[:, :, :h * self.scale, :w * self.scale]
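
    # Worked example of the tiling arithmetic above (illustrative numbers, not from
    # the original code): a 300x500 input with tile_size=200, tile_pad=20, scale=4 is
    # reflect-padded to 400x600 (2x3 tiles). Tile (0, 1) covers columns 200:400 but is
    # fed to the model as columns 180:420 (a 20-pixel halo on each side); the upscaled
    # tile's columns 80:880 are cropped out and pasted into the 1600x2400 canvas,
    # which is finally trimmed to 1200x2000.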

    def process(self, img):
        """Runs the model on a BGR uint8 image (as returned by cv2.imread) and returns a BGR uint8 result."""
        # HWC BGR uint8 -> NCHW RGB float in [0, 1], the layout the model expects.
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img_tensor = torch.from_numpy(img_rgb.transpose(2, 0, 1)).float() / 255.0
        img_tensor = img_tensor.unsqueeze(0).to(self.device)

        output_tensor = self._tile_inference(img_tensor)

        # tensor2img converts back to an HWC uint8 array, swapping RGB -> BGR.
        return tensor2img(output_tensor, rgb2bgr=True)
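

# A minimal usage sketch, assuming the x4 SR checkpoint is present under 'checkpoints/'.
# 'input.png' and 'output_x4.png' are hypothetical paths; adjust them to your setup.
if __name__ == '__main__':
    upsampler = MaIR_Upsampler('MaIR-SRx4')
    image = cv2.imread('input.png', cv2.IMREAD_COLOR)  # BGR uint8, as process() expects
    if image is None:
        raise FileNotFoundError('Could not read input.png')
    result = upsampler.process(image)
    cv2.imwrite('output_x4.png', result)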