# inference_runner.py
import cv2
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

# --- Import model architecture from the repository ---
from basicsr.archs.mair_arch import MaIR
from basicsr.utils.img_util import tensor2img

class MaIR_Upsampler:
    """
    A self-contained class for the MaIR model for inference.
    Handles model loading, pre-processing, and tiling for large images.
    """
    def __init__(self, model_name, device=None):
        self.model_name = model_name
        
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device
        
        print(f"Using device: {self.device} for model {self.model_name}")

        self.MODEL_CONFIGS = self._get_model_configs()
        
        if model_name not in self.MODEL_CONFIGS:
            raise ValueError(f"Model '{model_name}' not recognized. Available: {list(self.MODEL_CONFIGS.keys())}")

        self.model, self.scale = self._load_model()
        self.model.eval()
        self.model.to(self.device)

    def _get_model_configs(self):
        """Returns a dictionary of all supported model configurations."""
        mair_sr_base_params = {
            'img_size': 64, 'patch_size': 1, 'in_chans': 3, 'embed_dim': 180, 
            'depths': (6, 6, 6, 6, 6, 6), 'drop_rate': 0., 'd_state': 16, 
            'ssm_ratio': 2.0, 'mlp_ratio': 2.5, 'drop_path_rate': 0.1, 
            'norm_layer': nn.LayerNorm, 'patch_norm': True, 'use_checkpoint': False, 
            'img_range': 1., 'upsampler': 'pixelshuffle', 'resi_connection': '1conv', 
            'dynamic_ids': True, 'scan_len': 4,
        }
        mair_cdn_base_params = mair_sr_base_params.copy()
        mair_cdn_base_params.update({'upscale': 1, 'upsampler': ''})

        return {
            'MaIR-SRx4': {'task': 'SR', 'scale': 4, 'filename': 'MaIR_SR_x4.pth', 'params': {**mair_sr_base_params, 'upscale': 4}},
            'MaIR-SRx2': {'task': 'SR', 'scale': 2, 'filename': 'MaIR_SR_x2.pth', 'params': {**mair_sr_base_params, 'upscale': 2}},
            'MaIR-CDN-s50': {'task': 'DN', 'scale': 1, 'filename': 'MaIR_CDN_s50.pth', 'params': mair_cdn_base_params},
        }

    def _load_model(self):
        """Loads the pretrained model weights from the local 'checkpoints' folder."""
        config = self.MODEL_CONFIGS[self.model_name]
        params = config['params']
        filename = config['filename']
        scale = config['scale']
        
        model_path = os.path.join('checkpoints', filename)
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Checkpoint not found: {model_path}. Ensure it's in a 'checkpoints' folder.")
        
        model = MaIR(**params)
        load_net = torch.load(model_path, map_location=self.device)
        # BasicSR checkpoints usually wrap the weights under 'params_ema' or
        # 'params'; fall back to the raw dict for plain state_dict files.
        if 'params_ema' in load_net:
            load_net = load_net['params_ema']
        elif 'params' in load_net:
            load_net = load_net['params']

        # Strip the 'module.' prefix left behind by DataParallel training.
        for k, v in list(load_net.items()):
            if k.startswith('module.'):
                load_net[k[7:]] = v
                del load_net[k]

        model.load_state_dict(load_net, strict=True)
        print(f"Model {self.model_name} loaded successfully from {model_path}.")
        return model, scale
        
    def _tile_inference(self, img_tensor):
        """Performs inference using a tiling strategy to handle large images."""
        b, c, h, w = img_tensor.size()
        tile_size, tile_pad = 200, 20
        num_tiles_h = int(np.ceil(h / tile_size))
        num_tiles_w = int(np.ceil(w / tile_size))
        pad_h, pad_w = num_tiles_h * tile_size - h, num_tiles_w * tile_size - w
        # 'reflect' padding requires the pad to be smaller than the input
        # dimension, so each input side must exceed tile_size / 2 pixels.
        img_padded = F.pad(img_tensor, (0, pad_w, 0, pad_h), 'reflect')
        # Pre-allocate the upscaled output at the padded resolution.
        output_padded = img_padded.new_zeros(
            b, c, img_padded.shape[2] * self.scale, img_padded.shape[3] * self.scale)
        
        with torch.no_grad():
            for i in range(num_tiles_h):
                for j in range(num_tiles_w):
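                    # Tile bounds in the source image; the padded window below
                    # adds tile_pad pixels of context on each side where available.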
                    h_start, h_end = i * tile_size, (i + 1) * tile_size
                    w_start, w_end = j * tile_size, (j + 1) * tile_size
                    h_start_pad, h_end_pad = max(0, h_start - tile_pad), min(img_padded.shape[2], h_end + tile_pad)
                    w_start_pad, w_end_pad = max(0, w_start - tile_pad), min(img_padded.shape[3], w_end + tile_pad)
                    
                    tile_input = img_padded[:, :, h_start_pad:h_end_pad, w_start_pad:w_end_pad]
                    tile_output = self.model(tile_input)
                    
                    # Map the tile back to output coordinates and crop the
                    # context margin off the model's prediction.
                    out_h_start, out_h_end = h_start * self.scale, h_end * self.scale
                    out_w_start, out_w_end = w_start * self.scale, w_end * self.scale
                    cut_h_start = (h_start - h_start_pad) * self.scale
                    cut_h_end = cut_h_start + tile_size * self.scale
                    cut_w_start = (w_start - w_start_pad) * self.scale
                    cut_w_end = cut_w_start + tile_size * self.scale
                    
                    output_padded[:, :, out_h_start:out_h_end, out_w_start:out_w_end] = tile_output[:, :, cut_h_start:cut_h_end, cut_w_start:cut_w_end]
                    
        return output_padded[:, :, :h * self.scale, :w * self.scale]

    def process(self, img):
        """Main inference function."""
        # Pre-processing
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img_tensor = torch.from_numpy(img_rgb.transpose(2, 0, 1)).float() / 255.0
        img_tensor = img_tensor.unsqueeze(0).to(self.device)
        
        # Inference
        output_tensor = self._tile_inference(img_tensor)
            
        # Post-processing
        return tensor2img(output_tensor, rgb2bgr=True)
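

# --- Minimal usage sketch. The input/output paths below are illustrative
# assumptions, not part of the repository layout; any BGR image loadable by
# OpenCV will do. ---
if __name__ == '__main__':
    upsampler = MaIR_Upsampler('MaIR-SRx4')
    img = cv2.imread('inputs/sample.png', cv2.IMREAD_COLOR)  # hypothetical path
    if img is None:
        raise FileNotFoundError('Place a test image at inputs/sample.png')
    result = upsampler.process(img)
    os.makedirs('outputs', exist_ok=True)
    cv2.imwrite('outputs/sample_x4.png', result)  # hypothetical output path
    print(f'Upscaled {img.shape[1]}x{img.shape[0]} -> {result.shape[1]}x{result.shape[0]}')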