# Hugging Face Spaces page status when this file was captured: "Runtime error".
import functools
import os
import shutil
import urllib.request
import uuid
import zipfile

import cv2
import gradio as gr
import torch
import numpy as np
from basicsr.archs.srvgg_arch import SRVGGNetCompact
from gfpgan.utils import GFPGANer
from realesrgan.utils import RealESRGANer
# --- download weights if missing ---
# (filename, URL) for every checkpoint this app loads: the Real-ESRGAN
# background upsampler first, then each selectable face-restoration model.
_WEIGHT_URLS = [
    ('realesr-general-x4v3.pth', 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'),
    ('GFPGANv1.2.pth', 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth'),
    ('GFPGANv1.3.pth', 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'),
    ('GFPGANv1.4.pth', 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'),
    ('RestoreFormer.pth', 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth'),
]


def _download_if_missing(fname, url, timeout=60):
    """Download *url* to *fname* unless *fname* already exists.

    The payload is streamed into ``<fname>.part`` and renamed to *fname* only
    after the transfer completes. An interrupted download therefore never
    leaves a truncated checkpoint that later runs would mistake for a good one.

    Failures are printed and swallowed (best effort, as the previous ``wget``
    calls were). A missing checkpoint surfaces when its model is loaded.

    Args:
        fname: Destination path.
        url: Source URL.
        timeout: Socket timeout in seconds for each network operation.

    Returns:
        True if *fname* exists afterwards, False if the download failed.
    """
    if os.path.exists(fname):
        return True
    tmp_path = fname + '.part'
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp, \
                open(tmp_path, 'wb') as fh:
            shutil.copyfileobj(resp, fh)
        # Atomic on the same filesystem: *fname* is either absent or complete.
        os.replace(tmp_path, fname)
    except OSError as exc:  # URLError/HTTPError/timeouts all derive from OSError
        print(f"Could not download {fname} from {url}: {exc}")
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        return False
    return True


for _fname, _url in _WEIGHT_URLS:
    _download_if_missing(_fname, _url)
# --- background upsampler ---
# Real-ESRGAN "general x4v3" network. The RealESRGANer built from it is passed
# to GFPGANer as bg_upsampler, so it handles everything outside the faces.
model = SRVGGNetCompact(
    num_in_ch=3,
    num_out_ch=3,
    num_feat=64,
    num_conv=32,
    upscale=4,
    act_type='prelu',
)

# Half precision is requested only when a CUDA device is available.
_half_precision = torch.cuda.is_available()
upsampler = RealESRGANer(
    scale=4,
    model_path='realesr-general-x4v3.pth',
    model=model,
    tile=256,
    tile_pad=10,
    pre_pad=0,
    half=_half_precision,
)

# All restored images and ZIP archives are written under this directory.
os.makedirs('output', exist_ok=True)
def _fit_to_size_range(img, min_size=512, max_size=2048):
    """Resize *img* so its longer side lies within [min_size, max_size].

    Images already in range are returned unchanged. Upscaling uses
    INTER_LANCZOS4 and downscaling INTER_AREA, matching the final-resize
    convention in ``process_single_image``. Each output side is at least 1 px,
    so very elongated images cannot collapse to a zero-sized dimension.
    """
    h, w = img.shape[:2]
    longest = max(h, w)
    if longest < min_size:
        factor = min_size / longest
        interp = cv2.INTER_LANCZOS4
    elif longest > max_size:
        factor = max_size / longest
        interp = cv2.INTER_AREA
    else:
        return img
    new_size = (max(1, int(w * factor)), max(1, int(h * factor)))
    return cv2.resize(img, new_size, interpolation=interp)


@functools.lru_cache(maxsize=1)
def _get_face_enhancer(version):
    """Return a GFPGANer for *version*, reusing it across calls.

    GFPGANer is constructed from the checkpoint at ``model_path``. Building a
    new one for every image in a batch reloads the same weights each time.
    The cache keeps only the most recently used version. Construction errors
    are not cached, so a missing file is retried on the next call.
    """
    # Map version -> checkpoint filename and architecture.
    if version.startswith('v'):
        model_fname = f'GFPGAN{version}.pth'
        arch = 'clean'
    else:
        model_fname = f'{version}.pth'
        arch = version
    return GFPGANer(
        model_path=model_fname,
        upscale=2,
        arch=arch,
        channel_multiplier=2,
        bg_upsampler=upsampler,
    )


def _to_uint8(img):
    """Convert an enhancer output array to a finite uint8 image.

    uint16 results are first rescaled from [0, 65535] to [0, 255]. Clipping
    them directly would saturate almost every pixel to white. NaN/inf values
    are replaced before clipping to avoid black rectangles in the output.
    """
    if img.dtype == np.uint16:
        img = img.astype(np.float32) / 257.0  # 65535 / 257 == 255
    img = np.nan_to_num(img, nan=0.0, posinf=255.0, neginf=0.0)
    return np.clip(img, 0, 255).astype(np.uint8)


def process_single_image(img_path, version, scale):
    """Restore faces in one image file and save the result under ``output/``.

    The input is first resized so its longer side is within 512..2048 px. It
    is then restored by GFPGAN, which upscales 2x. Finally the result is
    resized so it is *scale* times the size-normalised input.

    Args:
        img_path: Path of the input image (any format ``cv2.imread`` reads).
        version: Face model: 'v1.2', 'v1.3', 'v1.4' or 'RestoreFormer'.
        scale: Output size relative to the size-normalised input. 2 keeps
            the enhancer's native 2x output.

    Returns:
        Path of the written image, or None if the image could not be read,
        restoration failed, or the output could not be written.

    Raises:
        ValueError: If *scale* is None or not a positive number.
    """
    # `not scale > 0` also rejects NaN, which would otherwise fail inside cv2.
    if scale is None or not scale > 0:
        raise ValueError(f"scale must be a positive number, got {scale!r}")

    img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
    if img is None:
        return None
    # 4-channel inputs are saved as PNG (JPEG cannot store an alpha channel).
    img_mode = 'RGBA' if (img.ndim == 3 and img.shape[2] == 4) else None
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    img = _fit_to_size_range(img)

    face_enhancer = _get_face_enhancer(version)
    try:
        _, _, restored = face_enhancer.enhance(
            img, has_aligned=False, only_center_face=False, paste_back=True)
    except RuntimeError as e:
        print("GFPGAN error:", e)
        return None
    if restored is None:
        print(f"Restoration failed for {img_path}")
        return None

    restored = _to_uint8(restored)

    if scale != 2:
        # The enhancer output is 2x the input, so target = restored * scale / 2
        # (= input * scale). scale < 2 therefore means shrinking the restored
        # image, which is why INTER_AREA is chosen in that case.
        interp = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
        h, w = restored.shape[:2]
        target = (max(1, int(w * scale / 2)), max(1, int(h * scale / 2)))
        restored = cv2.resize(restored, target, interpolation=interp)

    ext = 'png' if img_mode == 'RGBA' else 'jpg'
    base = os.path.splitext(os.path.basename(img_path))[0]
    out_path = os.path.join('output', f"{base}_restored.{ext}")
    if not cv2.imwrite(out_path, restored):
        print(f"Could not write {out_path}")
        return None
    return out_path
def _unique_arcname(name, taken):
    """Return *name*, or *name* with a ``_<n>`` suffix, so it is not in *taken*.

    The chosen name is added to *taken*.
    """
    stem, ext = os.path.splitext(name)
    candidate, n = name, 1
    while candidate in taken:
        candidate = f"{stem}_{n}{ext}"
        n += 1
    taken.add(candidate)
    return candidate


def inference(img_files, version, scale):
    """Restore every uploaded image and bundle the results into one ZIP.

    Args:
        img_files: Uploaded files from ``gr.File(file_count="multiple")``.
            Each is either an object with a ``.name`` path attribute or a
            plain path. None or empty means nothing was uploaded.
        version: Face model name, forwarded to ``process_single_image``.
        scale: Rescaling factor, forwarded to ``process_single_image``.

    Returns:
        Path of a newly created ``output/restored_<hex>.zip`` containing at
        least one restored image.

    Raises:
        gr.Error: If nothing was uploaded, *scale* is not positive, or no
            image could be restored. Gradio shows the message to the user
            instead of serving an empty archive.
    """
    if not img_files:
        raise gr.Error("Please upload at least one image.")
    if scale is None or not scale > 0:
        raise gr.Error("Rescaling factor must be a positive number.")

    zip_name = f"output/restored_{uuid.uuid4().hex}.zip"
    taken = set()
    written = 0
    with zipfile.ZipFile(zip_name, 'w') as z:
        for p in img_files:
            img_path = p.name if hasattr(p, 'name') else p
            try:
                out = process_single_image(img_path, version, scale)
            except Exception as exc:  # one bad file must not abort the whole batch
                print(f"Failed to process {img_path}: {exc}")
                continue
            if not out:
                continue
            # Archive immediately: a later upload with the same basename
            # writes to the same output path and would overwrite this file.
            z.write(out, arcname=_unique_arcname(os.path.basename(out), taken))
            written += 1

    if not written:
        os.remove(zip_name)
        raise gr.Error("No images could be restored.")
    return zip_name
# --- Gradio interface ---
# Inputs map positionally onto inference(img_files, version, scale).
title = "GFPGAN Multi-Image Face Restoration + ZIP Download"
description = "Upload multiple images; restored outputs will be bundled into a ZIP for download."
demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.File(file_count="multiple", file_types=["image"], label="Input Images"),
        gr.Radio(['v1.2', 'v1.3', 'v1.4', 'RestoreFormer'], value='v1.4', label='GFPGAN Version'),
        gr.Number(label="Rescaling Factor", value=2),
    ],
    outputs=gr.File(label="Download Restored ZIP"),
    title=title,
    description=description,
)
# Launch the app only when this file is run as a script, not when imported.
# (`_name_` / `"_main_"` raised NameError at import time.)
if __name__ == "__main__":
    demo.queue().launch(share=True)