File size: 10,232 Bytes
d69ad9e
 
052a51e
 
 
 
 
d69ad9e
 
 
 
 
052a51e
 
 
d69ad9e
 
052a51e
 
 
 
 
 
 
 
 
 
79f94dc
052a51e
d69ad9e
052a51e
 
 
 
 
d69ad9e
052a51e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68cea4c
 
052a51e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d69ad9e
 
052a51e
 
 
 
d69ad9e
 
 
052a51e
d69ad9e
 
 
052a51e
d69ad9e
 
 
 
 
 
052a51e
 
d69ad9e
052a51e
d69ad9e
 
 
 
052a51e
 
d69ad9e
052a51e
d69ad9e
 
 
 
 
 
052a51e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d69ad9e
 
 
 
 
06a3e99
d69ad9e
 
052a51e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d69ad9e
052a51e
 
 
 
 
 
 
 
 
 
 
 
 
d69ad9e
68cea4c
d69ad9e
68cea4c
 
 
 
 
 
 
 
 
 
 
 
 
d69ad9e
 
 
052a51e
 
 
 
d69ad9e
052a51e
d69ad9e
052a51e
d69ad9e
052a51e
d69ad9e
68cea4c
 
 
 
 
 
 
 
 
d69ad9e
 
 
 
052a51e
 
 
d69ad9e
06a3e99
052a51e
d69ad9e
 
 
052a51e
 
 
 
 
 
 
 
 
 
 
a9dea45
052a51e
 
 
 
d69ad9e
 
 
 
68cea4c
d69ad9e
 
 
 
 
 
 
 
 
 
052a51e
d69ad9e
 
052a51e
 
d69ad9e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
import os
import random
import cv2
import numpy
import gradio as gr
import spaces

from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact

# --------------------
# Global state (CPU-only data; do NOT touch CUDA here)
# --------------------
# Filename of the most recently written output image, or None when there is none.
last_file = None

DEVICE = "cpu"      # set in gpu_startup()
USE_HALF = False    # set in gpu_startup()

# Caches for upsamplers / face enhancers that have already been initialized
UPSAMPLER_CACHE = {}       # key: (model_name, denoise_strength, DEVICE, USE_HALF)
GFPGAN_FACE_ENHANCER = {}  # key: (outscale, DEVICE, USE_HALF)

# --------------------
# ZeroGPU: request a GPU right at startup
# --------------------
@spaces.GPU
def gpu_startup():
    """
    Detect CUDA and record the result in the module-level DEVICE / USE_HALF.

    Called once from the script entry point, before the UI is built. This is
    the only place that imports torch and queries CUDA directly.

    NOTE(review): if the `spaces.GPU` wrapper runs this function in a separate
    process, the globals assigned here may not be visible to the caller --
    confirm against the `spaces` package behaviour.
    """
    global DEVICE, USE_HALF
    import torch

    cuda_ok = torch.cuda.is_available()
    if cuda_ok:
        DEVICE = "cuda"
    else:
        DEVICE = "cpu"
    # Half precision is only enabled together with CUDA.
    USE_HALF = bool(cuda_ok)

    print(f"[startup] CUDA available: {cuda_ok}, device={DEVICE}, half={USE_HALF}")

# --------------------
# Utils
# --------------------
def rnd_string(x):
    """Return a random string of ``x`` characters (lowercase letters, digits, underscore)."""
    alphabet = "abcdefghijklmnopqrstuvwxyz_0123456789"
    picked = []
    for _ in range(x):
        # One random.choice call per character, in order.
        picked.append(random.choice(alphabet))
    return "".join(picked)

def image_properties(img):
    """Describe an image's resolution and colour mode as a one-line string.

    Reads ``img.size`` (width, height) and ``img.mode``. Returns None when
    ``img`` is falsy (for example, no image loaded).
    """
    if not img:
        return None
    # Report only what the image itself states; no alpha inspection is done.
    width = img.size[0]
    height = img.size[1]
    return f"Resolution: Width: {width}, Height: {height}  |  Color Mode: {img.mode}"

def reset():
    """Delete the last output file (best effort) and clear both image widgets.

    Returns two ``gr.update(value=None)`` objects, one per output component
    wired to this handler.
    """
    global last_file
    stale_path = last_file
    if stale_path:
        try:
            print(f"Deleting {stale_path} ...")
            os.remove(stale_path)
        except Exception as err:
            # Deletion is best effort: report the problem and carry on.
            print("Delete error:", err)
        finally:
            # Forget the file whether or not the delete succeeded.
            last_file = None
    cleared_first = gr.update(value=None)
    cleared_second = gr.update(value=None)
    return cleared_first, cleared_second

# --------------------
# Model builder (no CUDA calls outside startup; everything depends on DEVICE/USE_HALF)
# --------------------
def get_model_and_paths(model_name, denoise_strength):
    """Build the network architecture and resolve its weight file(s).

    Args:
        model_name: One of 'RealESRGAN_x4plus', 'RealESRNet_x4plus',
            'RealESRGAN_x4plus_anime_6B', 'RealESRGAN_x2plus' or
            'realesr-general-x4v3'.
        denoise_strength: Only used by 'realesr-general-x4v3'. When it is not
            1, the regular and the "wdn" weights are both returned together
            with blend weights ``[denoise_strength, 1 - denoise_strength]``.

    Returns:
        A tuple ``(model, netscale, model_path, dni_weight)``. ``model_path``
        is a single path, or a two-element list ``[regular, wdn]`` when
        ``dni_weight`` is not None.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    if model_name in ('RealESRGAN_x4plus', 'RealESRNet_x4plus'):
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
        if model_name == 'RealESRGAN_x4plus':
            file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
        else:
            file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
    elif model_name == 'RealESRGAN_x4plus_anime_6B':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
    elif model_name == 'RealESRGAN_x2plus':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
        netscale = 2
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
    elif model_name == 'realesr-general-x4v3':
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
        netscale = 4
        file_url = [
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
        ]
    else:
        raise ValueError(f"Unsupported model: {model_name}")

    # Download the weights if they are not already present. The last URL in
    # file_url is the model's own weights, so model_path ends up pointing there.
    model_path = os.path.join('weights', model_name + '.pth')
    if not os.path.isfile(model_path):
        ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
        for url in file_url:
            model_path = load_file_from_url(url=url, model_dir=os.path.join(ROOT_DIR, 'weights'),
                                            progress=True, file_name=None)

    # DNI weight blending (only for realesr-general-x4v3)
    dni_weight = None
    if model_name == 'realesr-general-x4v3' and denoise_strength != 1:
        # Derive the wdn filename from the basename only, so that directory
        # components of the path are never rewritten.
        weights_dir, weights_file = os.path.split(model_path)
        wdn_model_path = os.path.join(
            weights_dir, weights_file.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3'))
        if not os.path.isfile(wdn_model_path):
            # The download loop above is skipped when the main weights already
            # exist, so the wdn companion may still be missing: fetch it now.
            wdn_url = next(url for url in file_url if 'wdn' in url)
            wdn_model_path = load_file_from_url(url=wdn_url, model_dir=weights_dir,
                                                progress=True, file_name=None)
        model_path = [model_path, wdn_model_path]
        dni_weight = [denoise_strength, 1 - denoise_strength]

    return model, netscale, model_path, dni_weight

def get_upsampler(model_name, denoise_strength):
    """Return a RealESRGANer for the current DEVICE / USE_HALF, building it once.

    Instances are kept in UPSAMPLER_CACHE. The denoise strength is part of the
    cache key only for 'realesr-general-x4v3', the one model for which
    get_model_and_paths() uses it. For every other model the key holds None
    in that position, so moving the slider does not build duplicate
    upsamplers.

    Args:
        model_name: Name of the upscaler model (see get_model_and_paths()).
        denoise_strength: Denoise strength; forwarded to get_model_and_paths().

    Returns:
        The cached or newly created RealESRGANer.

    Raises:
        ValueError: Propagated from get_model_and_paths() for unknown models.
    """
    if model_name == 'realesr-general-x4v3':
        key_strength = float(denoise_strength)
    else:
        key_strength = None
    key = (model_name, key_strength, DEVICE, USE_HALF)

    cached = UPSAMPLER_CACHE.get(key)
    if cached is not None:
        return cached

    model, netscale, model_path, dni_weight = get_model_and_paths(model_name, denoise_strength)

    # Device-dependent configuration:
    # - half=True on GPU; False on CPU
    # - gpu_id=0 on GPU; None on CPU
    half_flag = bool(USE_HALF)
    gpu_id = 0 if DEVICE == "cuda" else None

    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=dni_weight,
        model=model,
        tile=0,
        tile_pad=10,
        pre_pad=10,
        half=half_flag,
        gpu_id=gpu_id
    )
    UPSAMPLER_CACHE[key] = upsampler
    return upsampler

def get_face_enhancer(upsampler, outscale):
    """Return a GFPGANer for ``outscale`` whose background upsampler is ``upsampler``.

    One GFPGANer is cached per (outscale, DEVICE, USE_HALF). The cache key does
    not include the upsampler, so a cached instance is re-pointed at the
    upsampler passed in. Without this it would keep using the background
    upsampler of whichever model was selected when it was first created.

    Args:
        upsampler: Background upsampler to use for non-face regions.
        outscale: Upscale factor; converted to int.

    Returns:
        The cached or newly created GFPGANer.
    """
    key = (int(outscale), DEVICE, USE_HALF)
    face_enhancer = GFPGAN_FACE_ENHANCER.get(key)
    if face_enhancer is None:
        # Imported lazily so gfpgan is only loaded when face enhancement is used.
        from gfpgan import GFPGANer
        face_enhancer = GFPGANer(
            model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
            upscale=int(outscale),
            arch='clean',
            channel_multiplier=2,
            bg_upsampler=upsampler
        )
        GFPGAN_FACE_ENHANCER[key] = face_enhancer
    else:
        # GFPGANer keeps its background upsampler on the `bg_upsampler`
        # attribute; rebind it so the currently selected model is used.
        face_enhancer.bg_upsampler = upsampler
    return face_enhancer

# --------------------
# Inference (marked @spaces.GPU because it may run on the GPU)
# --------------------
@spaces.GPU
def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
    """Upscale/restore an image with Real-ESRGAN and save the result as a JPEG.

    Args:
        img: Input image (the UI supplies a PIL image); a falsy value means
            there is nothing to process.
        model_name: Name of the upscaler model, forwarded to get_upsampler().
        denoise_strength: Denoise strength, forwarded to get_upsampler().
        face_enhance: When truthy, run the image through the GFPGAN face
            enhancer, which uses the upsampler for the background.
        outscale: Output scale factor; converted to int before use.

    Returns:
        The filename of the written JPEG, or None when there is no input, the
        enhancement raises RuntimeError, or the file cannot be written.
    """
    global last_file

    if not img:
        return None

    upsampler = get_upsampler(model_name, denoise_strength)

    # PIL -> OpenCV BGR. The result is always written as a JPEG, which cannot
    # store alpha, so any alpha channel is dropped here instead of being
    # carried through the enhancer and discarded afterwards.
    # (RealESRGANer is understood to upscale an alpha channel in a separate
    # pass -- see its `alpha_upsampler` option -- so this also avoids that work.)
    cv_img = numpy.array(img)
    if cv_img.ndim == 3 and cv_img.shape[2] == 4:
        img_bgr = cv2.cvtColor(cv_img, cv2.COLOR_RGBA2BGR)
    elif cv_img.ndim == 3 and cv_img.shape[2] == 3:
        img_bgr = cv2.cvtColor(cv_img, cv2.COLOR_RGB2BGR)
    else:
        # Single-channel (L) input
        img_bgr = cv2.cvtColor(cv_img, cv2.COLOR_GRAY2BGR)

    try:
        if face_enhance:
            face_enhancer = get_face_enhancer(upsampler, outscale)
            _, _, output = face_enhancer.enhance(
                img_bgr, has_aligned=False, only_center_face=False, paste_back=True
            )
        else:
            output, _ = upsampler.enhance(img_bgr, outscale=int(outscale))
    except RuntimeError as error:
        # Typically an out-of-memory condition; report it and produce no output.
        print('Error', error)
        return None

    # Make sure the image has exactly 3 channels before saving as JPEG.
    if output.ndim == 3 and output.shape[2] == 4:
        output_to_save = cv2.cvtColor(output, cv2.COLOR_BGRA2BGR)
    elif output.ndim == 3 and output.shape[2] == 3:
        output_to_save = output
    else:
        output_to_save = cv2.cvtColor(output, cv2.COLOR_GRAY2BGR)

    out_filename = f"output_{rnd_string(8)}.jpg"
    # cv2.imwrite reports failure through its return value; do not hand back
    # (or remember) a file that was never written.
    if not cv2.imwrite(out_filename, output_to_save):
        print(f"Failed to write {out_filename}")
        return None

    last_file = out_filename
    return out_filename

# --------------------
# UI
# --------------------
def main():
    """Build the Gradio UI, wire its events, and launch the server.

    The layout is an options accordion (model, denoise strength, output scale,
    face enhancement), an input/output image row, and a button row. The app is
    served on 0.0.0.0, port 7860.
    """
    available_models = [
        "RealESRGAN_x4plus",
        "RealESRNet_x4plus",
        "RealESRGAN_x4plus_anime_6B",
        "RealESRGAN_x2plus",
        "realesr-general-x4v3",
    ]

    with gr.Blocks(title="Real-ESRGAN Gradio Demo", theme="ParityError/Interstellar") as demo:
        gr.Markdown("## Image Upscaler")

        # Option widgets, all on one row inside the accordion.
        with gr.Accordion("Upscaling option"), gr.Row():
            model_picker = gr.Dropdown(
                label="Upscaler model",
                choices=available_models,
                value="RealESRGAN_x4plus_anime_6B",
                show_label=True,
            )
            denoise_slider = gr.Slider(
                label="Denoise Strength", minimum=0, maximum=1, step=0.1, value=0.5
            )
            scale_slider = gr.Slider(
                label="Resolution upscale", minimum=1, maximum=6, step=1, value=4, show_label=True
            )
            face_checkbox = gr.Checkbox(label="Face Enhancement (GFPGAN)")

        # Input image (with its properties line) next to the output image.
        with gr.Row():
            with gr.Group():
                source_image = gr.Image(label="Input Image", type="pil", image_mode="RGBA")
                source_info = gr.Textbox(label="Image Properties", max_lines=1)
            result_image = gr.Image(label="Output Image", image_mode="RGB")

        with gr.Row():
            clear_button = gr.Button("Remove images")
            upscale_button = gr.Button("Upscale")

        # Event wiring
        source_image.change(fn=image_properties, inputs=source_image, outputs=source_info)
        upscale_button.click(
            fn=realesrgan,
            inputs=[source_image, model_picker, denoise_slider, face_checkbox, scale_slider],
            outputs=result_image,
        )
        clear_button.click(fn=reset, inputs=[], outputs=[result_image, source_image])

    demo.launch(server_name="0.0.0.0", server_port=7860)

if __name__ == "__main__":
    # Call the startup function first so ZeroGPU allocates a GPU as soon as the Space boots
    gpu_startup()
    main()