|
import os |
|
import random |
|
import cv2 |
|
import numpy |
|
import gradio as gr |
|
import spaces |
|
|
|
from basicsr.archs.rrdbnet_arch import RRDBNet |
|
from basicsr.utils.download_util import load_file_from_url |
|
from realesrgan import RealESRGANer |
|
from realesrgan.archs.srvgg_arch import SRVGGNetCompact |
|
|
|
|
|
|
|
|
|
# Path of the most recently written output image; reset() deletes it.
last_file = None

# Compute settings. These CPU-safe defaults are overwritten by gpu_startup()
# after CUDA availability has been probed.
DEVICE = "cpu"
USE_HALF = False

# Lazily-filled model caches. Their keys include DEVICE and USE_HALF so that a
# model built under one compute configuration is not reused under another.
UPSAMPLER_CACHE = dict()
GFPGAN_FACE_ENHANCER = dict()
|
|
|
|
|
|
|
|
|
@spaces.GPU
def gpu_startup():
    """Probe CUDA once and record the compute settings in module globals.

    This function runs under ``@spaces.GPU`` so that, on a ZeroGPU Space, it is
    the only place at startup that touches torch/CUDA.

    Side effects:
        Sets ``DEVICE`` to "cuda" or "cpu". Sets ``USE_HALF`` to True only
        when CUDA is available.

    NOTE(review): on ZeroGPU, ``@spaces.GPU`` functions may run in a separate
    worker process. Confirm that these global assignments are actually visible
    to later calls in the main process.
    """
    global DEVICE, USE_HALF
    import torch  # imported here so torch/CUDA is only touched inside the GPU context

    cuda_ok = torch.cuda.is_available()
    USE_HALF = bool(cuda_ok)
    DEVICE = "cuda" if cuda_ok else "cpu"

    print(f"[startup] CUDA available: {cuda_ok}, device={DEVICE}, half={USE_HALF}")
|
|
|
|
|
|
|
|
|
def rnd_string(x):
    """Return a random string of length ``x``.

    Characters are drawn from lowercase ASCII letters, '_' and digits.
    """
    alphabet = "abcdefghijklmnopqrstuvwxyz_0123456789"
    picks = [random.choice(alphabet) for _ in range(x)]
    return "".join(picks)
|
|
|
def image_properties(img):
    """Describe an image's resolution and colour mode for the UI text box.

    Args:
        img: a PIL-style image (anything with ``.size`` and ``.mode``), or a
            falsy value such as ``None`` when the input is cleared.

    Returns:
        A one-line description string, or ``None`` when ``img`` is falsy.
    """
    if not img:
        return None
    dims = img.size
    width, height = dims[0], dims[1]
    return f"Resolution: Width: {width}, Height: {height} | Color Mode: {img.mode}"
|
|
|
def reset():
    """Best-effort delete of the last output file, then clear both image widgets.

    Deletion errors are printed, not raised. ``last_file`` is set to ``None``
    whenever a deletion was attempted.

    Returns:
        Two ``gr.update(value=None)`` objects, for the output and input images.
    """
    global last_file

    if last_file:
        doomed = last_file
        try:
            print(f"Deleting {doomed} ...")
            os.remove(doomed)
        except Exception as exc:
            print("Delete error:", exc)
        finally:
            last_file = None

    clear_output = gr.update(value=None)
    clear_input = gr.update(value=None)
    return clear_output, clear_input
|
|
|
|
|
|
|
|
|
def get_model_and_paths(model_name, denoise_strength):
    """Build the network for ``model_name`` and locate (or download) its weights.

    Args:
        model_name: one of 'RealESRGAN_x4plus', 'RealESRNet_x4plus',
            'RealESRGAN_x4plus_anime_6B', 'RealESRGAN_x2plus',
            'realesr-general-x4v3'.
        denoise_strength: only used for 'realesr-general-x4v3'. Any value other
            than 1 blends the regular weights with the "wdn" weights.

    Returns:
        ``(model, netscale, model_path, dni_weight)``:
        - ``model_path`` is a single path, or ``[x4v3_path, wdn_path]`` when
          blending is used.
        - ``dni_weight`` is ``None`` or ``[denoise_strength, 1 - denoise_strength]``.

    Raises:
        ValueError: if ``model_name`` is not supported.
    """
    if model_name in ('RealESRGAN_x4plus', 'RealESRNet_x4plus'):
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
        if model_name == 'RealESRGAN_x4plus':
            file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
        else:
            file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
    elif model_name == 'RealESRGAN_x4plus_anime_6B':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
    elif model_name == 'RealESRGAN_x2plus':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
        netscale = 2
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
    elif model_name == 'realesr-general-x4v3':
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
        netscale = 4
        # Order matters: the wdn companion comes first and the model's own
        # weights come last, so the download loop below leaves model_path on
        # the latter.
        file_url = [
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
        ]
    else:
        raise ValueError(f"Unsupported model: {model_name}")

    download_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'weights')

    # Prefer a pre-installed copy under ./weights. Otherwise fetch every file
    # the model needs into the script-relative weights directory.
    model_path = os.path.join('weights', model_name + '.pth')
    if not os.path.isfile(model_path):
        for url in file_url:
            model_path = load_file_from_url(url=url, model_dir=download_dir,
                                            progress=True, file_name=None)

    dni_weight = None
    if model_name == 'realesr-general-x4v3' and denoise_strength != 1:
        # Build the sibling path from the directory only. A str.replace over the
        # whole path would also rewrite any parent directory whose name happens
        # to contain 'realesr-general-x4v3'.
        wdn_model_path = os.path.join(os.path.dirname(model_path), 'realesr-general-wdn-x4v3.pth')
        if not os.path.isfile(wdn_model_path):
            # The main weights were found locally but the wdn companion was not,
            # so the loop above never fetched it. file_url[0] is the wdn URL.
            wdn_model_path = load_file_from_url(url=file_url[0], model_dir=download_dir,
                                                progress=True, file_name=None)
        model_path = [model_path, wdn_model_path]
        dni_weight = [denoise_strength, 1 - denoise_strength]

    return model, netscale, model_path, dni_weight
|
|
|
def get_upsampler(model_name, denoise_strength):
    """Return a cached RealESRGANer for ``model_name`` under the current DEVICE/USE_HALF.

    Args:
        model_name: model identifier accepted by get_model_and_paths.
        denoise_strength: weight-blend factor; only 'realesr-general-x4v3'
            uses it.

    Returns:
        A RealESRGANer built with tile=0, tile_pad=10, pre_pad=10. Half
        precision follows USE_HALF, and gpu_id is 0 on CUDA, else None.

    Raises:
        ValueError: propagated from get_model_and_paths for unknown models.
    """
    # Only realesr-general-x4v3 consumes denoise_strength, to pick its blended
    # weights. For every other model the value is left out of the key, so that
    # each slider position does not build and hold its own copy of identical
    # weights.
    denoise_key = float(denoise_strength) if model_name == 'realesr-general-x4v3' else None
    key = (model_name, denoise_key, DEVICE, USE_HALF)

    cached = UPSAMPLER_CACHE.get(key)
    if cached is not None:
        return cached

    model, netscale, model_path, dni_weight = get_model_and_paths(model_name, denoise_strength)

    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=dni_weight,
        model=model,
        tile=0,
        tile_pad=10,
        pre_pad=10,
        half=bool(USE_HALF),
        gpu_id=0 if DEVICE == "cuda" else None,
    )
    UPSAMPLER_CACHE[key] = upsampler
    return upsampler
|
|
|
def get_face_enhancer(upsampler, outscale):
    """Return a cached GFPGAN face enhancer that uses ``upsampler`` for the background.

    Enhancers are cached per ``(int(outscale), DEVICE, USE_HALF)``. The
    background upsampler is not part of that key, so on a cache hit the cached
    instance is re-pointed at the caller's ``upsampler``. Without this, a later
    model switch would keep upscaling backgrounds with whichever upsampler
    first created the entry.

    Args:
        upsampler: the RealESRGANer used to upscale the non-face background.
        outscale: final upscale factor, passed to GFPGANer as ``upscale``.

    Returns:
        A ``gfpgan.GFPGANer`` instance.
    """
    key = (int(outscale), DEVICE, USE_HALF)
    face_enhancer = GFPGAN_FACE_ENHANCER.get(key)

    if face_enhancer is None:
        # Imported lazily: gfpgan is only needed when face enhancement is requested.
        from gfpgan import GFPGANer

        face_enhancer = GFPGANer(
            model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
            upscale=int(outscale),
            arch='clean',
            channel_multiplier=2,
            bg_upsampler=upsampler,
        )
        GFPGAN_FACE_ENHANCER[key] = face_enhancer
    else:
        # NOTE(review): this relies on GFPGANer storing the constructor's
        # bg_upsampler as `self.bg_upsampler` and reading it inside enhance().
        # Confirm against the installed gfpgan version.
        face_enhancer.bg_upsampler = upsampler

    return face_enhancer
|
|
|
|
|
|
|
|
|
def _pil_to_bgr(img):
    """Convert a PIL image into a 3-channel BGR numpy array for OpenCV/Real-ESRGAN.

    Any alpha channel is dropped. The result is saved as JPEG (no alpha), and
    the alpha would be stripped before saving anyway. Feeding a 4-channel
    array would only make the upscaler process the alpha plane too.
    NOTE(review): that extra work assumes RealESRGANer's default
    alpha_upsampler; confirm against the installed realesrgan.
    """
    arr = numpy.array(img)
    if arr.ndim == 3 and arr.shape[2] == 4:
        return cv2.cvtColor(arr, cv2.COLOR_RGBA2BGR)
    if arr.ndim == 3 and arr.shape[2] == 3:
        return cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
    if arr.ndim == 3:
        # 1- or 2-channel input (e.g. luminance + alpha): keep the first plane only.
        arr = numpy.ascontiguousarray(arr[:, :, 0])
    return cv2.cvtColor(arr, cv2.COLOR_GRAY2BGR)


def _as_bgr(output):
    """Normalise an enhancer output (BGRA, BGR or grayscale) to 3-channel BGR for saving."""
    if output.ndim == 3 and output.shape[2] == 4:
        return cv2.cvtColor(output, cv2.COLOR_BGRA2BGR)
    if output.ndim == 3 and output.shape[2] == 3:
        return output
    return cv2.cvtColor(output, cv2.COLOR_GRAY2BGR)


@spaces.GPU
def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
    """Upscale ``img`` with Real-ESRGAN (optionally with GFPGAN face restoration).

    Args:
        img: PIL image from the Gradio input, or ``None``.
        model_name: model identifier forwarded to get_upsampler.
        denoise_strength: forwarded to get_upsampler.
        face_enhance: if truthy, run GFPGAN, which uses the upsampler for the
            background.
        outscale: final upscale factor (cast to ``int``).

    Returns:
        Path of the written JPEG, which is also stored in ``last_file``.
        Returns ``None`` when there is no input, when inference raises
        ``RuntimeError``, or when the file cannot be written.
        Earlier outputs are not deleted here; only the newest path is remembered.
    """
    global last_file

    if not img:
        return None

    upsampler = get_upsampler(model_name, denoise_strength)
    img_bgr = _pil_to_bgr(img)

    try:
        if face_enhance:
            face_enhancer = get_face_enhancer(upsampler, outscale)
            _, _, output = face_enhancer.enhance(
                img_bgr, has_aligned=False, only_center_face=False, paste_back=True
            )
        else:
            output, _ = upsampler.enhance(img_bgr, outscale=int(outscale))
    except RuntimeError as error:
        # Presumably inference failures such as CUDA out-of-memory, which torch
        # reports as RuntimeError subclasses.
        print('Error', error)
        return None

    out_filename = f"output_{rnd_string(8)}.jpg"
    if not cv2.imwrite(out_filename, _as_bgr(output)):
        # Don't hand the UI a path to a file that was never written.
        print('Error', f"could not write {out_filename}")
        return None

    last_file = out_filename
    return out_filename
|
|
|
|
|
|
|
|
|
def main():
    """Build the Gradio Blocks UI for the upscaler and launch it on 0.0.0.0:7860."""
    upscaler_models = [
        "RealESRGAN_x4plus",
        "RealESRNet_x4plus",
        "RealESRGAN_x4plus_anime_6B",
        "RealESRGAN_x2plus",
        "realesr-general-x4v3",
    ]

    with gr.Blocks(title="Real-ESRGAN Gradio Demo", theme="ParityError/Interstellar") as demo:
        gr.Markdown("## Image Upscaler")

        # Upscaling options, laid out in one row inside a collapsible accordion.
        with gr.Accordion("Upscaling option"):
            with gr.Row():
                model_dropdown = gr.Dropdown(
                    label="Upscaler model",
                    choices=upscaler_models,
                    value="RealESRGAN_x4plus_anime_6B",
                    show_label=True,
                )
                denoise_slider = gr.Slider(
                    label="Denoise Strength", minimum=0, maximum=1, step=0.1, value=0.5
                )
                scale_slider = gr.Slider(
                    label="Resolution upscale", minimum=1, maximum=6, step=1, value=4, show_label=True
                )
                face_checkbox = gr.Checkbox(label="Face Enhancement (GFPGAN)")

        # Input image plus its properties on the left, result on the right.
        with gr.Row():
            with gr.Group():
                source_image = gr.Image(label="Input Image", type="pil", image_mode="RGBA")
                source_info = gr.Textbox(label="Image Properties", max_lines=1)
            result_image = gr.Image(label="Output Image", image_mode="RGB")

        with gr.Row():
            clear_button = gr.Button("Remove images")
            upscale_button = gr.Button("Upscale")

        # Event wiring.
        source_image.change(fn=image_properties, inputs=source_image, outputs=source_info)
        upscale_button.click(
            fn=realesrgan,
            inputs=[source_image, model_dropdown, denoise_slider, face_checkbox, scale_slider],
            outputs=result_image,
        )
        clear_button.click(fn=reset, inputs=[], outputs=[result_image, source_image])

    demo.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
if __name__ == "__main__":
    # Probe CUDA first (sets DEVICE / USE_HALF), then build and launch the UI.
    gpu_startup()
    main()