File size: 10,232 Bytes
d69ad9e 052a51e d69ad9e 052a51e d69ad9e 052a51e 79f94dc 052a51e d69ad9e 052a51e d69ad9e 052a51e 68cea4c 052a51e d69ad9e 052a51e d69ad9e 052a51e d69ad9e 052a51e d69ad9e 052a51e d69ad9e 052a51e d69ad9e 052a51e d69ad9e 052a51e d69ad9e 052a51e d69ad9e 06a3e99 d69ad9e 052a51e d69ad9e 052a51e d69ad9e 68cea4c d69ad9e 68cea4c d69ad9e 052a51e d69ad9e 052a51e d69ad9e 052a51e d69ad9e 052a51e d69ad9e 68cea4c d69ad9e 052a51e d69ad9e 06a3e99 052a51e d69ad9e 052a51e a9dea45 052a51e d69ad9e 68cea4c d69ad9e 052a51e d69ad9e 052a51e d69ad9e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 |
import os
import random
import cv2
import numpy
import gradio as gr
import spaces
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
# --------------------
# Global state (CPU-only data; do NOT touch CUDA here)
# --------------------
# Path of the most recently written output image (set by realesrgan(), removed by reset()).
last_file = None
DEVICE = "cpu" # assigned in gpu_startup()
USE_HALF = False # assigned in gpu_startup()
# Caches for upsamplers / face enhancers that have already been initialised.
UPSAMPLER_CACHE = {} # key: (model_name, denoise_strength, DEVICE, USE_HALF)
GFPGAN_FACE_ENHANCER = {} # key: (outscale, DEVICE, USE_HALF)
# --------------------
# ZeroGPU: request a GPU right at startup
# --------------------
@spaces.GPU
def gpu_startup():
    """Probe for CUDA and record the result in the module-level settings.

    This is the one place in the module that is meant to touch torch/CUDA,
    which is why ``torch`` is imported lazily inside the function body.

    Side effects:
        Sets the globals ``DEVICE`` ("cuda" or "cpu") and ``USE_HALF``
        (True only when CUDA is available) and prints a one-line summary.
    """
    global DEVICE, USE_HALF
    import torch

    has_cuda = torch.cuda.is_available()
    if has_cuda:
        DEVICE = "cuda"
    else:
        DEVICE = "cpu"

    # Half precision is only considered safe when CUDA is present.
    USE_HALF = bool(has_cuda)

    print(f"[startup] CUDA available: {has_cuda}, device={DEVICE}, half={USE_HALF}")
# --------------------
# Utils
# --------------------
def rnd_string(x):
    """Return a random string of ``x`` characters.

    Characters are drawn (with replacement) from lowercase ASCII letters,
    the underscore and the digits 0-9, using the ``random`` module.
    """
    alphabet = "abcdefghijklmnopqrstuvwxyz_0123456789"
    picked = []
    for _ in range(x):
        picked.append(random.choice(alphabet))
    return "".join(picked)
def image_properties(img):
    """Describe an image's resolution and colour mode as a single line of text.

    Args:
        img: an object exposing ``size`` (indexable: width, height) and
            ``mode``; a falsy value means "no image".

    Returns:
        The description string, or ``None`` when ``img`` is falsy.
    """
    if not img:
        return None
    # Report only what the image object itself states; alpha is not inspected.
    return f"Resolution: Width: {img.size[0]}, Height: {img.size[1]} | Color Mode: {img.mode}"
def reset():
    """Delete the last output file (if any) and clear both image widgets.

    Deletion is best-effort: a failure is printed and otherwise ignored, and
    ``last_file`` is reset to ``None`` whether or not the removal succeeded.

    Returns:
        A pair of ``gr.update(value=None)`` objects, one per image component
        wired to this callback.
    """
    global last_file

    if last_file:
        try:
            print(f"Deleting {last_file} ...")
            os.remove(last_file)
        except Exception as err:
            print("Delete error:", err)
        finally:
            # Forget the path even when the file could not be removed.
            last_file = None

    first_cleared = gr.update(value=None)
    second_cleared = gr.update(value=None)
    return first_cleared, second_cleared
# --------------------
# Model builder (never calls CUDA outside of startup; everything depends on DEVICE/USE_HALF)
# --------------------
def get_model_and_paths(model_name, denoise_strength):
    """Build the network for ``model_name`` and resolve its weight file(s).

    Args:
        model_name: one of 'RealESRGAN_x4plus', 'RealESRNet_x4plus',
            'RealESRGAN_x4plus_anime_6B', 'RealESRGAN_x2plus' or
            'realesr-general-x4v3'.
        denoise_strength: only consulted for 'realesr-general-x4v3'; any value
            other than 1 enables interpolation with the "wdn" weights.

    Returns:
        ``(model, netscale, model_path, dni_weight)``. ``model_path`` is a
        single path, or ``[path, wdn_path]`` when interpolation is enabled, in
        which case ``dni_weight`` is ``[denoise_strength, 1 - denoise_strength]``;
        otherwise ``dni_weight`` is ``None``.

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
    """
    if model_name == 'RealESRGAN_x4plus':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
    elif model_name == 'RealESRNet_x4plus':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
    elif model_name == 'RealESRGAN_x4plus_anime_6B':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
    elif model_name == 'RealESRGAN_x2plus':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
        netscale = 2
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
    elif model_name == 'realesr-general-x4v3':
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
        netscale = 4
        # Order matters: the loop below keeps the path of the LAST download,
        # which must be the plain (non-wdn) weights.
        file_url = [
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
        ]
    else:
        raise ValueError(f"Unsupported model: {model_name}")

    # Fetch the weights when they are not present yet.
    # NOTE(review): the existence check is relative to the current working
    # directory, while downloads go to "weights" next to this file - confirm
    # the two are meant to coincide.
    model_path = os.path.join('weights', model_name + '.pth')
    if not os.path.isfile(model_path):
        ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
        weights_dir = os.path.join(ROOT_DIR, 'weights')
        for url in file_url:
            model_path = load_file_from_url(url=url, model_dir=weights_dir,
                                            progress=True, file_name=None)

    # Weight interpolation (DNI) applies to the general-x4v3 model only.
    if model_name != 'realesr-general-x4v3' or denoise_strength == 1:
        return model, netscale, model_path, None

    wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
    paired_paths = [model_path, wdn_model_path]
    blend = [denoise_strength, 1 - denoise_strength]
    return model, netscale, paired_paths, blend
def get_upsampler(model_name, denoise_strength):
    """Return the ``RealESRGANer`` for this model / denoise setting, creating it on first use.

    Instances are memoised in ``UPSAMPLER_CACHE`` under
    ``(model_name, float(denoise_strength), DEVICE, USE_HALF)``, so the current
    device and precision settings are part of the cache identity.
    """
    cache_key = (model_name, float(denoise_strength), DEVICE, USE_HALF)
    try:
        return UPSAMPLER_CACHE[cache_key]
    except KeyError:
        pass  # not built yet - fall through and construct it

    model, netscale, model_path, dni_weight = get_model_and_paths(model_name, denoise_strength)

    # Device-dependent configuration:
    #   half   -> follows USE_HALF
    #   gpu_id -> 0 when DEVICE is "cuda", otherwise None
    on_cuda = DEVICE == "cuda"
    built = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=dni_weight,
        model=model,
        tile=0,
        tile_pad=10,
        pre_pad=10,
        half=bool(USE_HALF),
        gpu_id=0 if on_cuda else None
    )
    UPSAMPLER_CACHE[cache_key] = built
    return built
def get_face_enhancer(upsampler, outscale):
    """Return a GFPGAN face enhancer that uses ``upsampler`` for the background.

    Enhancers are cached in ``GFPGAN_FACE_ENHANCER`` under
    ``(int(outscale), DEVICE, USE_HALF)`` so the GFPGAN weights are loaded only
    once per output scale.

    Args:
        upsampler: the background upsampler to attach (``bg_upsampler``).
        outscale: upscaling factor; truncated to ``int``.

    Returns:
        The (possibly cached) ``GFPGANer`` instance, with its ``bg_upsampler``
        pointing at ``upsampler``.
    """
    key = (int(outscale), DEVICE, USE_HALF)
    face_enhancer = GFPGAN_FACE_ENHANCER.get(key)

    if face_enhancer is None:
        # Imported lazily so gfpgan is only loaded when face enhancement is used.
        from gfpgan import GFPGANer

        face_enhancer = GFPGANer(
            model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
            upscale=int(outscale),
            arch='clean',
            channel_multiplier=2,
            bg_upsampler=upsampler
        )
        GFPGAN_FACE_ENHANCER[key] = face_enhancer
    else:
        # The cache key does not include the upsampler, so a cached enhancer may
        # still reference the upsampler of an earlier request (different model
        # or denoise strength). Rebind it so the background is processed with
        # the upsampler the caller actually asked for.
        face_enhancer.bg_upsampler = upsampler

    return face_enhancer
# --------------------
# Inference (decorated with @spaces.GPU because it may run on the GPU)
# --------------------
@spaces.GPU
def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
    """Upscale / restore ``img`` with Real-ESRGAN and save the result as a JPEG.

    Args:
        img: input image (the UI passes a PIL image); a falsy value means
            "nothing to do".
        model_name: name of the Real-ESRGAN model, forwarded to ``get_upsampler``.
        denoise_strength: denoise setting, forwarded to ``get_upsampler``.
        face_enhance: when truthy, run the image through the GFPGAN face
            enhancer (which uses the Real-ESRGAN upsampler for the background).
        outscale: upscaling factor; truncated to ``int``.

    Returns:
        The path of the JPEG that was written, or ``None`` when there was no
        input, inference raised ``RuntimeError``, or the file could not be
        written.

    Side effects:
        Writes ``output_<random>.jpg`` into the current working directory and
        stores its path in the global ``last_file``.
    """
    global last_file

    if not img:
        return None

    upsampler = get_upsampler(model_name, denoise_strength)

    # PIL -> OpenCV, always as 3-channel BGR. The result is stored as JPEG,
    # which has no alpha channel, so any alpha would be thrown away after
    # inference anyway. Feeding a 4-channel image makes RealESRGANer upsample
    # the alpha plane as well (an extra network pass with its default
    # alpha_upsampler), which is pure overhead here.
    cv_img = numpy.array(img)
    if cv_img.ndim == 3 and cv_img.shape[2] == 4:
        img_bgr = cv2.cvtColor(cv_img, cv2.COLOR_RGBA2BGR)
    elif cv_img.ndim == 3 and cv_img.shape[2] == 3:
        img_bgr = cv2.cvtColor(cv_img, cv2.COLOR_RGB2BGR)
    else:
        # Single-channel (mode "L") input.
        img_bgr = cv2.cvtColor(cv_img, cv2.COLOR_GRAY2BGR)

    try:
        if face_enhance:
            face_enhancer = get_face_enhancer(upsampler, outscale)
            _, _, output = face_enhancer.enhance(
                img_bgr, has_aligned=False, only_center_face=False, paste_back=True
            )
        else:
            output, _ = upsampler.enhance(img_bgr, outscale=int(outscale))
    except RuntimeError as error:
        # Typically an out-of-memory condition; report it and give up.
        print('Error', error)
        return None

    # Defensive: guarantee a 3-channel image before the JPEG write.
    if output.ndim == 3 and output.shape[2] == 4:
        output_to_save = cv2.cvtColor(output, cv2.COLOR_BGRA2BGR)
    elif output.ndim == 3 and output.shape[2] == 3:
        output_to_save = output
    else:
        output_to_save = cv2.cvtColor(output, cv2.COLOR_GRAY2BGR)

    out_filename = f"output_{rnd_string(8)}.jpg"
    # cv2.imwrite signals failure through its return value rather than raising;
    # do not hand back a path to a file that was never written.
    if not cv2.imwrite(out_filename, output_to_save):
        print('Error', f"could not write {out_filename}")
        return None

    last_file = out_filename
    return out_filename
# --------------------
# UI
# --------------------
def main():
    """Build the Gradio Blocks UI, wire up its callbacks and launch the server.

    The server listens on 0.0.0.0:7860.
    """
    # NOTE(review): SOURCE carried no indentation; the container nesting below
    # is reconstructed from the order of the statements - confirm the layout.
    available_models = [
        "RealESRGAN_x4plus",
        "RealESRNet_x4plus",
        "RealESRGAN_x4plus_anime_6B",
        "RealESRGAN_x2plus",
        "realesr-general-x4v3",
    ]

    with gr.Blocks(title="Real-ESRGAN Gradio Demo", theme="ParityError/Interstellar") as demo:
        gr.Markdown("## Image Upscaler")

        # Option controls.
        with gr.Accordion("Upscaling option"):
            with gr.Row():
                model_name = gr.Dropdown(
                    label="Upscaler model",
                    choices=available_models,
                    value="RealESRGAN_x4plus_anime_6B",
                    show_label=True,
                )
                denoise_strength = gr.Slider(
                    label="Denoise Strength",
                    minimum=0,
                    maximum=1,
                    step=0.1,
                    value=0.5,
                )
                outscale = gr.Slider(
                    label="Resolution upscale",
                    minimum=1,
                    maximum=6,
                    step=1,
                    value=4,
                    show_label=True,
                )
                face_enhance = gr.Checkbox(label="Face Enhancement (GFPGAN)")

        # Input (with its properties read-out) next to the output.
        with gr.Row():
            with gr.Group():
                input_image = gr.Image(label="Input Image", type="pil", image_mode="RGBA")
                input_image_properties = gr.Textbox(label="Image Properties", max_lines=1)
            output_image = gr.Image(label="Output Image", image_mode="RGB")

        # Action buttons.
        with gr.Row():
            reset_btn = gr.Button("Remove images")
            restore_btn = gr.Button("Upscale")

        # Event wiring.
        input_image.change(
            fn=image_properties,
            inputs=input_image,
            outputs=input_image_properties,
        )
        restore_btn.click(
            fn=realesrgan,
            inputs=[input_image, model_name, denoise_strength, face_enhance, outscale],
            outputs=output_image,
        )
        reset_btn.click(fn=reset, inputs=[], outputs=[output_image, input_image])

    demo.launch(server_name="0.0.0.0", server_port=7860)
if __name__ == "__main__":
    # Run the startup hook first so ZeroGPU allocates a GPU as soon as the
    # Space boots, then build and launch the UI.
    gpu_startup()
    main()