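# GFPGAN / app.py
# (Header captured from the Hugging Face file viewer: last commit 3d30604,
#  "Updated app.py", by luthrabhuvan, marked verified; file size 4.61 kB.
#  "raw", "history" and "blame" were viewer links.)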
import functools
import os
import urllib.request
import uuid
import zipfile

import cv2
import gradio as gr
import torch
import numpy as np
from basicsr.archs.srvgg_arch import SRVGGNetCompact
from gfpgan.utils import GFPGANer
from realesrgan.utils import RealESRGANer
# ——— download weights if missing ———
# (file name, URL) of every checkpoint the app may load, stored in the current
# directory. realesr-general-x4v3.pth is used by the background upsampler at
# start-up; the GFPGAN / RestoreFormer files are picked per request by version.
WEIGHT_FILES = [
    ('realesr-general-x4v3.pth', 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'),
    ('GFPGANv1.2.pth', 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth'),
    ('GFPGANv1.3.pth', 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'),
    ('GFPGANv1.4.pth', 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'),
    ('RestoreFormer.pth', 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth'),
]


def _download_if_missing(fname, url):
    """Download *url* to *fname* unless *fname* already exists.

    The data is written to ``<fname>.part`` first and renamed into place only
    after the transfer completes. An interrupted download therefore never
    leaves a truncated checkpoint that would pass the ``os.path.exists``
    check on the next start.

    This is best-effort: on a network or file error, a message is printed,
    the partial file is removed and the function returns normally. A
    missing checkpoint surfaces later, when something tries to load it.
    """
    if os.path.exists(fname):
        return
    tmp_path = fname + '.part'
    try:
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, fname)  # atomic rename on the same filesystem
    except OSError as e:  # URLError / HTTPError / ContentTooShortError are OSErrors
        print(f"Failed to download {fname} from {url}: {e}")
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


for fname, url in WEIGHT_FILES:
    _download_if_missing(fname, url)
# ——— background upsampler ———
# Real-ESRGAN "general x4 v3" weights on a compact SRVGG network (4x upscale,
# PReLU activations). The RealESRGANer wrapper is later passed to GFPGANer as
# its `bg_upsampler`.
model = SRVGGNetCompact(
    num_in_ch=3,
    num_out_ch=3,
    num_feat=64,
    num_conv=32,
    upscale=4,
    act_type='prelu',
)
upsampler = RealESRGANer(
    scale=4,
    model=model,
    model_path='realesr-general-x4v3.pth',
    # tile / tile_pad presumably enable tiled inference to limit memory on
    # large images — confirm against the RealESRGANer documentation.
    tile=256,
    tile_pad=10,
    pre_pad=0,
    # Half precision is requested only when a CUDA device is available.
    half=torch.cuda.is_available(),
)

# Restored images and ZIP archives are written below ./output.
os.makedirs('output', exist_ok=True)
@functools.lru_cache(maxsize=1)
def _get_face_enhancer(model_fname, arch):
    """Return a GFPGANer for *model_fname* / *arch*, reusing the last one built.

    Building a GFPGANer loads the weights from ``model_path``. The most
    recently used instance is therefore cached, so a batch processed with one
    model version builds it only once. ``maxsize=1`` keeps at most one
    restoration model alive.

    NOTE(review): the instance is shared between calls, just like the
    module-level ``upsampler`` it wraps. This assumes images are not
    processed concurrently (Gradio's queue runs one job at a time by
    default) — revisit if concurrency is raised.
    """
    return GFPGANer(
        model_path=model_fname,
        upscale=2,
        arch=arch,
        channel_multiplier=2,
        bg_upsampler=upsampler,
    )


def process_single_image(img_path, version, scale):
    """Restore the faces in one image file and write the result under ./output.

    Args:
        img_path: Path of the input image. It is read with OpenCV
            ``IMREAD_UNCHANGED``, so any alpha channel is kept.
        version: Model choice, e.g. ``'v1.4'`` or ``'RestoreFormer'``.
            Names starting with ``'v'`` map to ``GFPGAN<version>.pth`` with
            arch ``'clean'``. Anything else maps to ``<version>.pth`` with
            ``arch=version``.
        scale: Output size relative to the working image, i.e. the input
            after it is normalised so its longer side is within 512..2048 px.
            The restorer itself produces 2x, so any other value triggers a
            final resize. ``None`` (a cleared number box) is treated as 2.

    Returns:
        Path of the written image, or ``None`` in these cases:
        - the input could not be read;
        - restoration raised a RuntimeError or produced no result;
        - the output file could not be written.
    """
    img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
    if img is None:
        return None

    # handle alpha & grayscale: 4-channel inputs are written as PNG rather than
    # JPEG; single-channel inputs are expanded to 3-channel BGR.
    img_mode = 'RGBA' if (img.ndim == 3 and img.shape[2] == 4) else None
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    # Normalise the working size so the longer side lies in [min_size, max_size].
    # max(1, ...) keeps very thin images from collapsing to a 0-px dimension.
    min_size, max_size = 512, 2048
    h, w = img.shape[:2]
    longest = max(h, w)
    if longest < min_size or longest > max_size:
        target = min_size if longest < min_size else max_size
        factor = target / longest
        img = cv2.resize(img, (max(1, int(w * factor)), max(1, int(h * factor))))

    # map version → filename & arch
    if version.startswith('v'):
        model_fname = f'GFPGAN{version}.pth'
        arch = 'clean'
    else:
        model_fname = f'{version}.pth'
        arch = version

    face_enhancer = _get_face_enhancer(model_fname, arch)
    try:
        _, _, restored = face_enhancer.enhance(
            img, has_aligned=False, only_center_face=False, paste_back=True
        )
    except RuntimeError as e:
        print("GFPGAN error:", e)
        return None
    if restored is None:
        print(f"Restoration failed for {img_path}")
        return None

    # sanitize output to avoid black rectangles
    restored = np.nan_to_num(restored, nan=0.0, posinf=255.0, neginf=0.0)
    restored = np.clip(restored, 0, 255).astype(np.uint8)

    # Rescale if needed. `restored` is already 2x the working image
    # (upscale=2), so resizing it by scale/2 gives an image `scale` times the
    # working size. The target is computed from restored's own dimensions;
    # using the working image's size here would be off by a factor of 2 and
    # jump discontinuously at scale == 2.
    if scale is None:
        scale = 2
    if scale != 2:
        interp = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4  # shrink vs enlarge
        rh, rw = restored.shape[:2]
        new_size = (max(1, int(rw * scale / 2)), max(1, int(rh * scale / 2)))
        restored = cv2.resize(restored, new_size, interpolation=interp)

    ext = 'png' if img_mode == 'RGBA' else 'jpg'
    base = os.path.splitext(os.path.basename(img_path))[0]
    # One directory per result. Inputs sharing a base name (a.jpg and a.png,
    # or the same name from two folders) then no longer overwrite each other,
    # while the output file name itself stays "<base>_restored.<ext>".
    out_dir = os.path.join('output', uuid.uuid4().hex)
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, f"{base}_restored.{ext}")
    if not cv2.imwrite(out_path, restored):
        print(f"Could not write {out_path}")
        return None
    return out_path
def inference(img_files, version, scale):
    """Restore every uploaded image and bundle the results into one ZIP.

    Args:
        img_files: Uploads from the Gradio ``File`` component. Each is either
            an object exposing ``.name`` (its file path) or a plain path
            string. The value is None or empty when nothing was uploaded.
        version: Model choice, forwarded to ``process_single_image``.
        scale: Rescaling factor, forwarded to ``process_single_image``.

    Returns:
        Path of a newly created ZIP under ./output holding each restored
        image. Images that fail to load or restore are skipped.

    Raises:
        gr.Error: If no image was uploaded or none could be restored. The
            user then sees a message instead of downloading an empty archive.
    """
    if not img_files:
        raise gr.Error("Please upload at least one image.")

    saved = []
    for p in img_files:
        out = process_single_image(getattr(p, 'name', p), version, scale)
        if out:
            saved.append(out)
    if not saved:
        raise gr.Error("None of the uploaded images could be restored.")

    zip_name = f"output/restored_{uuid.uuid4().hex}.zip"
    used_names = set()
    with zipfile.ZipFile(zip_name, 'w') as z:
        for fpath in saved:
            name = os.path.basename(fpath)
            stem, ext = os.path.splitext(name)
            arcname, n = name, 1
            # Inputs with the same file name would otherwise produce duplicate
            # archive members, and extracting the ZIP would keep only one.
            while arcname in used_names:
                arcname = f"{stem}_{n}{ext}"
                n += 1
            used_names.add(arcname)
            z.write(fpath, arcname=arcname)
    return zip_name
# ——— Gradio interface ———
title = "GFPGAN Multi‑Image Face Restoration + ZIP Download"
description = "Upload multiple images; restored outputs will be bundled into a ZIP for download."

# Input widgets, listed in the positional order of
# inference(img_files, version, scale).
_input_components = [
    gr.File(label="Input Images", file_types=["image"], file_count="multiple"),
    gr.Radio(
        choices=['v1.2', 'v1.3', 'v1.4', 'RestoreFormer'],
        label='GFPGAN Version',
        value='v1.4',
    ),
    gr.Number(value=2, label="Rescaling Factor"),
]
# Single output: the ZIP path returned by inference.
_output_component = gr.File(label="Download Restored ZIP")

demo = gr.Interface(
    fn=inference,
    inputs=_input_components,
    outputs=_output_component,
    title=title,
    description=description,
)
# Launch the app only when run as a script; importing the module has no side
# effect here. (The previous `_name_` / "_main_" spelling raised NameError.)
if __name__ == "__main__":
    demo.queue().launch(share=True)