import os
import random

import cv2
import numpy
import gradio as gr
import spaces
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact

# --------------------
# Globals (CPU-only data; do NOT touch CUDA here)
# --------------------
last_file = None
DEVICE = "cpu"     # set in gpu_startup()
USE_HALF = False   # set in gpu_startup()

# Caches for already-initialised objects.
# UPSAMPLER_CACHE key: (model_name, denoise_strength or None, DEVICE, USE_HALF)
UPSAMPLER_CACHE = {}
# GFPGAN_FACE_ENHANCER key: (outscale, DEVICE, USE_HALF)
GFPGAN_FACE_ENHANCER = {}

# Name of the only model whose weights are blended according to denoise strength.
_GENERAL_MODEL = 'realesr-general-x4v3'
_GENERAL_WDN_MODEL = 'realesr-general-wdn-x4v3'


# --------------------
# ZeroGPU: request the GPU right at startup
# --------------------
@spaces.GPU
def gpu_startup():
    """Detect CUDA and record the device / precision in module globals.

    Called once from the ``__main__`` guard before the UI is built. This is
    the only place at startup that imports torch and queries CUDA; it sets
    ``DEVICE`` to ``"cuda"`` or ``"cpu"`` and ``USE_HALF`` to whether CUDA is
    available.
    """
    global DEVICE, USE_HALF
    import torch

    has_cuda = torch.cuda.is_available()
    DEVICE = "cuda" if has_cuda else "cpu"
    # Half precision is only considered safe when CUDA is present.
    USE_HALF = bool(has_cuda)
    print(f"[startup] CUDA available: {has_cuda}, device={DEVICE}, half={USE_HALF}")


# --------------------
# Utils
# --------------------
def rnd_string(x):
    """Return a random string of length ``x`` drawn from ``[a-z_0-9]``."""
    chars = "abcdefghijklmnopqrstuvwxyz_0123456789"
    return "".join(random.choice(chars) for _ in range(x))


def image_properties(img):
    """Describe the size and colour mode of ``img``; return None if it is falsy.

    Only reports what the image object itself exposes (``size`` and ``mode``);
    the alpha channel is not inspected.
    """
    if not img:
        return None
    return f"Resolution: Width: {img.size[0]}, Height: {img.size[1]} | Color Mode: {img.mode}"


def reset():
    """Delete the most recent output file (best effort) and clear both images.

    Returns:
        Two ``gr.update(value=None)`` objects, one per image component wired
        to this callback.
    """
    global last_file
    if last_file:
        try:
            print(f"Deleting {last_file} ...")
            os.remove(last_file)
        except OSError as e:
            # Best effort: a missing or locked file must not break the UI reset.
            print("Delete error:", e)
        finally:
            last_file = None
    return gr.update(value=None), gr.update(value=None)


# --------------------
# Model builder (never calls CUDA itself; everything depends on DEVICE/USE_HALF)
# --------------------
def get_model_and_paths(model_name, denoise_strength):
    """Build the network architecture and locate (or download) its weights.

    Args:
        model_name: One of the supported Real-ESRGAN model names.
        denoise_strength: Blend factor, only used for ``realesr-general-x4v3``.

    Returns:
        ``(model, netscale, model_path, dni_weight)``. ``model_path`` is a
        single path, or ``[main_path, wdn_path]`` when ``dni_weight`` is set.
        ``dni_weight`` is ``None`` or ``[denoise_strength, 1 - denoise_strength]``.

    Raises:
        ValueError: If ``model_name`` is not supported.
    """
    if model_name in ('RealESRGAN_x4plus', 'RealESRNet_x4plus'):
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
        if model_name == 'RealESRGAN_x4plus':
            file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
        else:
            file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
    elif model_name == 'RealESRGAN_x4plus_anime_6B':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
    elif model_name == 'RealESRGAN_x2plus':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
        netscale = 2
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
    elif model_name == _GENERAL_MODEL:
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
        netscale = 4
        # Order matters: the wdn weights come first, the main weights last.
        file_url = [
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
        ]
    else:
        raise ValueError(f"Unsupported model: {model_name}")

    # Download the weights if they are not present yet.
    model_path = os.path.join('weights', model_name + '.pth')
    if not os.path.isfile(model_path):
        root_dir = os.path.dirname(os.path.abspath(__file__))
        for url in file_url:
            # The main model is always the last URL, so model_path ends on it.
            model_path = load_file_from_url(
                url=url, model_dir=os.path.join(root_dir, 'weights'), progress=True, file_name=None
            )

    # DNI (weight interpolation) applies to realesr-general-x4v3 only.
    dni_weight = None
    if model_name == _GENERAL_MODEL and denoise_strength != 1:
        # Swap only the file name, so a directory that happens to contain the
        # model name is left untouched.
        weights_dir, main_file = os.path.split(model_path)
        wdn_model_path = os.path.join(weights_dir, main_file.replace(_GENERAL_MODEL, _GENERAL_WDN_MODEL))
        if not os.path.isfile(wdn_model_path):
            # The main weights were already present but the wdn weights are not.
            wdn_model_path = load_file_from_url(
                url=file_url[0], model_dir=weights_dir, progress=True, file_name=None
            )
        model_path = [model_path, wdn_model_path]
        dni_weight = [denoise_strength, 1 - denoise_strength]

    return model, netscale, model_path, dni_weight


def get_upsampler(model_name, denoise_strength):
    """Return a cached ``RealESRGANer`` for the current device / precision.

    The denoise strength is part of the cache key only for
    ``realesr-general-x4v3``; the other models ignore it, so they share one
    instance regardless of the slider value.
    """
    denoise_key = float(denoise_strength) if model_name == _GENERAL_MODEL else None
    key = (model_name, denoise_key, DEVICE, USE_HALF)
    if key in UPSAMPLER_CACHE:
        return UPSAMPLER_CACHE[key]

    model, netscale, model_path, dni_weight = get_model_and_paths(model_name, denoise_strength)

    # Device-dependent configuration:
    # - half=True on GPU, False on CPU
    # - gpu_id=0 on GPU, None on CPU
    half_flag = bool(USE_HALF)
    gpu_id = 0 if DEVICE == "cuda" else None

    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=dni_weight,
        model=model,
        tile=0,
        tile_pad=10,
        pre_pad=10,
        half=half_flag,
        gpu_id=gpu_id
    )
    UPSAMPLER_CACHE[key] = upsampler
    return upsampler


def get_face_enhancer(upsampler, outscale):
    """Return a cached ``GFPGANer`` that uses ``upsampler`` for the background.

    Enhancers are cached per ``(outscale, DEVICE, USE_HALF)``. On a cache hit
    the background upsampler is rebound to the one passed in, so switching
    the upscaler model in the UI also takes effect for face enhancement.
    """
    key = (int(outscale), DEVICE, USE_HALF)
    face_enhancer = GFPGAN_FACE_ENHANCER.get(key)
    if face_enhancer is None:
        from gfpgan import GFPGANer

        face_enhancer = GFPGANer(
            model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
            upscale=int(outscale),
            arch='clean',
            channel_multiplier=2,
            bg_upsampler=upsampler
        )
        GFPGAN_FACE_ENHANCER[key] = face_enhancer
    else:
        # The cached instance may still hold the upsampler of another model.
        face_enhancer.bg_upsampler = upsampler
    return face_enhancer


def _to_bgra(cv_img):
    """Convert an RGBA / RGB / single-channel array to a 4-channel BGRA array."""
    if cv_img.ndim == 3 and cv_img.shape[2] == 4:
        # RGBA -> BGRA
        return cv2.cvtColor(cv_img, cv2.COLOR_RGBA2BGRA)
    if cv_img.ndim == 3 and cv_img.shape[2] == 3:
        # RGB -> BGR; an opaque alpha channel is appended below.
        bgr = cv2.cvtColor(cv_img, cv2.COLOR_RGB2BGR)
    else:
        # Single channel (L) -> BGR; an opaque alpha channel is appended below.
        bgr = cv2.cvtColor(cv_img, cv2.COLOR_GRAY2BGR)
    alpha = numpy.full((bgr.shape[0], bgr.shape[1], 1), 255, dtype=bgr.dtype)
    return numpy.concatenate([bgr, alpha], axis=2)


# --------------------
# Inference (decorated with @spaces.GPU because it may run on the GPU)
# --------------------
@spaces.GPU
def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
    """Restore / upscale ``img`` with Real-ESRGAN and save the result as JPG.

    Args:
        img: Input image (the UI supplies a PIL image); falsy means no input.
        model_name: Upscaler model name, see ``get_model_and_paths``.
        denoise_strength: Denoise blend factor for ``realesr-general-x4v3``.
        face_enhance: When true, run GFPGAN with the upsampler as background.
        outscale: Final upscale factor; it is truncated to ``int``.

    Returns:
        The output file name, or ``None`` when there is no input or the
        enhancement raised ``RuntimeError``.
    """
    global last_file

    if not img:
        return None

    upsampler = get_upsampler(model_name, denoise_strength)

    # PIL -> cv2 layout. Alpha is kept here and dropped right before saving.
    img_bgra = _to_bgra(numpy.array(img))

    try:
        if face_enhance:
            face_enhancer = get_face_enhancer(upsampler, outscale)
            _, _, output = face_enhancer.enhance(
                img_bgra, has_aligned=False, only_center_face=False, paste_back=True
            )
        else:
            output, _ = upsampler.enhance(img_bgra, outscale=int(outscale))
    except RuntimeError as error:
        # NOTE(review): presumably mostly CUDA out-of-memory; a smaller tile
        # size would be the usual remedy, but no retry is attempted here.
        print('Error', error)
        return None

    out_filename = f"output_{rnd_string(8)}.jpg"
    # JPG has no alpha channel: make sure exactly 3 channels are written.
    if output.ndim == 3 and output.shape[2] == 4:
        output_to_save = cv2.cvtColor(output, cv2.COLOR_BGRA2BGR)
    elif output.ndim == 3 and output.shape[2] == 3:
        output_to_save = output
    else:
        output_to_save = cv2.cvtColor(output, cv2.COLOR_GRAY2BGR)
    cv2.imwrite(out_filename, output_to_save)

    # Remembered so that reset() can delete it later.
    last_file = out_filename
    return out_filename


# --------------------
# UI
# --------------------
def main():
    """Build the Gradio Blocks UI, wire the callbacks and launch the server."""
    with gr.Blocks(title="Real-ESRGAN Gradio Demo", theme="ParityError/Interstellar") as demo:
        gr.Markdown("## Image Upscaler")

        with gr.Accordion("Upscaling option"):
            with gr.Row():
                model_name = gr.Dropdown(
                    label="Upscaler model",
                    choices=[
                        "RealESRGAN_x4plus",
                        "RealESRNet_x4plus",
                        "RealESRGAN_x4plus_anime_6B",
                        "RealESRGAN_x2plus",
                        "realesr-general-x4v3",
                    ],
                    value="RealESRGAN_x4plus_anime_6B",
                    show_label=True
                )
                denoise_strength = gr.Slider(label="Denoise Strength", minimum=0, maximum=1, step=0.1, value=0.5)
                outscale = gr.Slider(label="Resolution upscale", minimum=1, maximum=6, step=1, value=4, show_label=True)
                face_enhance = gr.Checkbox(label="Face Enhancement (GFPGAN)")

        with gr.Row():
            with gr.Group():
                input_image = gr.Image(label="Input Image", type="pil", image_mode="RGBA")
                input_image_properties = gr.Textbox(label="Image Properties", max_lines=1)
            output_image = gr.Image(label="Output Image", image_mode="RGB")

        with gr.Row():
            reset_btn = gr.Button("Remove images")
            restore_btn = gr.Button("Upscale")

        # Event listeners must be registered inside the Blocks context.
        input_image.change(fn=image_properties, inputs=input_image, outputs=input_image_properties)
        restore_btn.click(
            fn=realesrgan,
            inputs=[input_image, model_name, denoise_strength, face_enhance, outscale],
            outputs=output_image
        )
        reset_btn.click(fn=reset, inputs=[], outputs=[output_image, input_image])

    demo.launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    # Call the startup hook so ZeroGPU allocates the GPU as soon as the Space boots.
    gpu_startup()
    main()