Upscaler / app.py
tuan2308's picture
Update app.py
68cea4c verified
import os
import random
import cv2
import numpy
import gradio as gr
import spaces
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
# --------------------
# Global (CPU-only data; do NOT touch CUDA here)
# --------------------
# Path of the most recently written output image; reset() deletes it.
last_file = None
DEVICE = "cpu" # set in gpu_startup()
USE_HALF = False # set in gpu_startup()
# Caches for upsamplers / face enhancers that have already been initialised
UPSAMPLER_CACHE = {} # key: (model_name, denoise_strength, DEVICE, USE_HALF)
GFPGAN_FACE_ENHANCER = {} # key: (outscale, DEVICE, USE_HALF)
# --------------------
# ZeroGPU: allocate the GPU right at startup
# --------------------
@spaces.GPU
def gpu_startup():
    """Probe for CUDA and record the result in the module-level settings.

    Meant to be called once when the Space boots. ``torch`` is imported
    inside this function rather than at module import time. Afterwards
    ``DEVICE`` is ``"cuda"`` or ``"cpu"`` and ``USE_HALF`` mirrors whether
    CUDA was found.
    """
    global DEVICE, USE_HALF
    import torch

    has_cuda = torch.cuda.is_available()
    # Half precision is only switched on together with CUDA.
    if has_cuda:
        DEVICE, USE_HALF = "cuda", True
    else:
        DEVICE, USE_HALF = "cpu", False
    print(f"[startup] CUDA available: {has_cuda}, device={DEVICE}, half={USE_HALF}")
# --------------------
# Utils
# --------------------
def rnd_string(x):
    """Return a random string of ``x`` characters.

    Characters are drawn one at a time from lowercase ASCII letters, the
    underscore and the digits 0-9. A non-positive ``x`` yields ``""``.
    """
    alphabet = "abcdefghijklmnopqrstuvwxyz_0123456789"
    picked = []
    for _ in range(x):
        picked.append(random.choice(alphabet))
    return "".join(picked)
def image_properties(img):
    """Describe an image's resolution and colour mode as a one-line string.

    Args:
        img: An object exposing ``size`` (width, height) and ``mode``,
            or a falsy value when no image is loaded.

    Returns:
        The summary string, or ``None`` when ``img`` is falsy.
    """
    if not img:
        return None
    # Report only what the image itself states; alpha is not inspected.
    width = img.size[0]
    height = img.size[1]
    return f"Resolution: Width: {width}, Height: {height} | Color Mode: {img.mode}"
def reset():
    """Delete the last output file (best effort) and clear both image widgets.

    Returns:
        Two ``gr.update(value=None)`` objects, one per image component wired
        to the reset button.
    """
    global last_file
    if last_file:
        stale_path = last_file
        try:
            print(f"Deleting {stale_path} ...")
            os.remove(stale_path)
        except Exception as e:
            # Deletion is best effort: report the problem and carry on.
            print("Delete error:", e)
        finally:
            # Forget the path whether or not the removal succeeded.
            last_file = None
    clear_first = gr.update(value=None)
    clear_second = gr.update(value=None)
    return clear_first, clear_second
# --------------------
# Model builder (no CUDA calls outside startup; everything depends on DEVICE/USE_HALF)
# --------------------
def get_model_and_paths(model_name, denoise_strength):
    """Build the network for ``model_name`` and resolve its weight file(s).

    Weights are first looked up at ``weights/<model_name>.pth`` relative to
    the working directory. When that file is missing, every URL registered
    for the model is passed to ``load_file_from_url`` with the ``weights``
    folder next to this script as the target directory.

    Args:
        model_name: Identifier of a supported Real-ESRGAN model.
        denoise_strength: Only consulted for ``realesr-general-x4v3``; any
            value other than 1 selects the two-checkpoint blend.

    Returns:
        ``(model, netscale, model_path, dni_weight)``. ``model_path`` is one
        path and ``dni_weight`` is ``None``, except in the blended
        ``realesr-general-x4v3`` case where they are
        ``[main_path, wdn_path]`` and
        ``[denoise_strength, 1 - denoise_strength]``.

    Raises:
        ValueError: If ``model_name`` is not a supported model.
    """

    def _rrdb(num_block, scale):
        # All RRDBNet variants differ only in block count and scale.
        return RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=num_block, num_grow_ch=32, scale=scale)

    # model name -> (lazy model factory, network scale, weight URLs)
    specs = {
        'RealESRGAN_x4plus': (
            lambda: _rrdb(23, 4),
            4,
            ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'],
        ),
        'RealESRNet_x4plus': (
            lambda: _rrdb(23, 4),
            4,
            ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth'],
        ),
        'RealESRGAN_x4plus_anime_6B': (
            lambda: _rrdb(6, 4),
            4,
            ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth'],
        ),
        'RealESRGAN_x2plus': (
            lambda: _rrdb(23, 2),
            2,
            ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth'],
        ),
        'realesr-general-x4v3': (
            lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu'),
            4,
            [
                'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
                'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
            ],
        ),
    }
    try:
        build_model, netscale, file_url = specs[model_name]
    except (KeyError, TypeError):
        raise ValueError(f"Unsupported model: {model_name}") from None
    model = build_model()

    # Fetch the weights only when the expected local file is absent.
    model_path = os.path.join('weights', model_name + '.pth')
    if not os.path.isfile(model_path):
        weights_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'weights')
        for url in file_url:
            # model_path is overwritten on every iteration, so the last URL's file wins.
            model_path = load_file_from_url(url=url, model_dir=weights_dir, progress=True, file_name=None)

    # Only the general-x4v3 model pairs its checkpoint with the "wdn" variant.
    blend_needed = model_name == 'realesr-general-x4v3' and denoise_strength != 1
    if not blend_needed:
        return model, netscale, model_path, None
    wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
    return model, netscale, [model_path, wdn_model_path], [denoise_strength, 1 - denoise_strength]
def get_upsampler(model_name, denoise_strength):
    """Return a cached ``RealESRGANer`` for the request, building it on first use.

    Args:
        model_name: Identifier understood by ``get_model_and_paths``.
        denoise_strength: Blend factor; it only influences the
            ``realesr-general-x4v3`` model.

    Returns:
        A ``RealESRGANer`` configured for the current ``DEVICE`` / ``USE_HALF``.

    Raises:
        ValueError: Propagated from ``get_model_and_paths`` for an
            unsupported ``model_name``.
    """
    # get_model_and_paths ignores denoise_strength for every model except
    # realesr-general-x4v3, so keep it out of the key for the others.
    # Otherwise each slider value would build and retain another identical model.
    if model_name == 'realesr-general-x4v3':
        denoise_key = float(denoise_strength)
    else:
        denoise_key = None
    key = (model_name, denoise_key, DEVICE, USE_HALF)
    cached = UPSAMPLER_CACHE.get(key)
    if cached is not None:
        return cached

    model, netscale, model_path, dni_weight = get_model_and_paths(model_name, denoise_strength)

    # Device-dependent configuration:
    # - half=True on GPU, False on CPU
    # - gpu_id=0 on GPU, None on CPU
    half_flag = bool(USE_HALF)
    gpu_id = 0 if DEVICE == "cuda" else None
    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=dni_weight,
        model=model,
        tile=0,
        tile_pad=10,
        pre_pad=10,
        half=half_flag,
        gpu_id=gpu_id,
    )
    UPSAMPLER_CACHE[key] = upsampler
    return upsampler
def get_face_enhancer(upsampler, outscale):
    """Return a cached GFPGAN face enhancer bound to ``upsampler``.

    One enhancer is kept per ``(outscale, DEVICE, USE_HALF)`` combination.
    The background upsampler is not part of that key, so it is re-attached
    on every call. Without this, a cached enhancer would keep using
    whichever upsampler it was first built with, and changing the model or
    denoise strength would have no effect while face enhancement is on.

    Args:
        upsampler: Background upsampler to use for this request.
        outscale: Output scale factor; converted with ``int()``.

    Returns:
        The enhancer object, with ``bg_upsampler`` set to ``upsampler``.
    """
    key = (int(outscale), DEVICE, USE_HALF)
    face_enhancer = GFPGAN_FACE_ENHANCER.get(key)
    if face_enhancer is None:
        # Imported lazily so gfpgan is only loaded when face enhancement is requested.
        from gfpgan import GFPGANer
        face_enhancer = GFPGANer(
            model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
            upscale=int(outscale),
            arch='clean',
            channel_multiplier=2,
            bg_upsampler=upsampler,
        )
        GFPGAN_FACE_ENHANCER[key] = face_enhancer
    # Always point the (possibly cached) enhancer at the current upsampler.
    face_enhancer.bg_upsampler = upsampler
    return face_enhancer
# --------------------
# Inference (marked @spaces.GPU because it may run on the GPU)
# --------------------
def _to_bgr(cv_img):
    """Convert an RGBA, RGB or single-channel array into a 3-channel BGR array."""
    if cv_img.ndim == 3 and cv_img.shape[2] == 4:
        return cv2.cvtColor(cv_img, cv2.COLOR_RGBA2BGR)
    if cv_img.ndim == 3 and cv_img.shape[2] == 3:
        return cv2.cvtColor(cv_img, cv2.COLOR_RGB2BGR)
    return cv2.cvtColor(cv_img, cv2.COLOR_GRAY2BGR)


@spaces.GPU
def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
    """Upscale/restore an image with Real-ESRGAN and save the result as a JPG.

    Args:
        img: Input PIL image, or a falsy value when nothing is loaded.
        model_name: Model identifier passed to ``get_upsampler``.
        denoise_strength: Denoise blend factor passed to ``get_upsampler``.
        face_enhance: When truthy, run the GFPGAN face enhancer (which uses
            the upsampler for the background) instead of the plain upsampler.
        outscale: Output scale factor; converted with ``int()``.

    Returns:
        The filename of the written JPG, or ``None`` when there is no input
        or inference raised ``RuntimeError``.
    """
    global last_file
    if not img:
        return None
    upsampler = get_upsampler(model_name, denoise_strength)

    # The result is always written as a 3-channel JPG, so alpha is dropped
    # before inference. A 4-channel input makes RealESRGANer.enhance run an
    # additional pass for the alpha plane, whose result was then thrown away.
    img_bgr = _to_bgr(numpy.array(img))

    try:
        if face_enhance:
            face_enhancer = get_face_enhancer(upsampler, outscale)
            _, _, output = face_enhancer.enhance(
                img_bgr, has_aligned=False, only_center_face=False, paste_back=True
            )
        else:
            output, _ = upsampler.enhance(img_bgr, outscale=int(outscale))
    except RuntimeError as error:
        # Inference failures are reported and swallowed; the upsampler is
        # created with tile=0 and tiling is not retried here.
        print('Error', error)
        return None

    out_filename = f"output_{rnd_string(8)}.jpg"
    # Make sure the array has exactly 3 channels before writing the JPG.
    if output.ndim == 3 and output.shape[2] == 4:
        output_to_save = cv2.cvtColor(output, cv2.COLOR_BGRA2BGR)
    elif output.ndim == 3 and output.shape[2] == 3:
        output_to_save = output
    else:
        output_to_save = cv2.cvtColor(output, cv2.COLOR_GRAY2BGR)
    cv2.imwrite(out_filename, output_to_save)
    # Remember the file so reset() can delete it later.
    last_file = out_filename
    return out_filename
# --------------------
# UI
# --------------------
def main():
    """Assemble the Gradio Blocks interface, wire its events and launch it.

    The app is served on ``0.0.0.0`` port ``7860``.
    """
    model_choices = [
        "RealESRGAN_x4plus",
        "RealESRNet_x4plus",
        "RealESRGAN_x4plus_anime_6B",
        "RealESRGAN_x2plus",
        "realesr-general-x4v3",
    ]

    with gr.Blocks(title="Real-ESRGAN Gradio Demo", theme="ParityError/Interstellar") as demo:
        gr.Markdown("## Image Upscaler")

        # Upscaling controls: one row inside the accordion.
        with gr.Accordion("Upscaling option"), gr.Row():
            model_name = gr.Dropdown(
                label="Upscaler model",
                choices=model_choices,
                value="RealESRGAN_x4plus_anime_6B",
                show_label=True,
            )
            denoise_strength = gr.Slider(
                label="Denoise Strength", minimum=0, maximum=1, step=0.1, value=0.5
            )
            outscale = gr.Slider(
                label="Resolution upscale", minimum=1, maximum=6, step=1, value=4, show_label=True
            )
            face_enhance = gr.Checkbox(label="Face Enhancement (GFPGAN)")

        # Input (with its properties line) next to the output.
        with gr.Row():
            with gr.Group():
                input_image = gr.Image(label="Input Image", type="pil", image_mode="RGBA")
                input_image_properties = gr.Textbox(label="Image Properties", max_lines=1)
            output_image = gr.Image(label="Output Image", image_mode="RGB")

        with gr.Row():
            reset_btn = gr.Button("Remove images")
            restore_btn = gr.Button("Upscale")

        # Event wiring.
        input_image.change(
            fn=image_properties,
            inputs=input_image,
            outputs=input_image_properties,
        )
        upscale_inputs = [input_image, model_name, denoise_strength, face_enhance, outscale]
        restore_btn.click(fn=realesrgan, inputs=upscale_inputs, outputs=output_image)
        reset_btn.click(fn=reset, inputs=[], outputs=[output_image, input_image])

    demo.launch(server_name="0.0.0.0", server_port=7860)
if __name__ == "__main__":
    # Call the startup hook first so ZeroGPU allocates the GPU as soon as the Space boots.
    gpu_startup()
    main()