File size: 4,614 Bytes
d0fef57
3d30604
 
d0fef57
d561a66
d0fef57
3d30604
e6ac7d7
 
 
d0fef57
3d30604
112d8be
 
3d30604
 
 
 
 
 
 
 
 
 
d0fef57
3d30604
 
 
 
 
d0fef57
 
 
3d30604
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76cb1c1
d0fef57
3d30604
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b261577
3d30604
 
 
 
 
76cb1c1
3d30604
d0fef57
 
3d30604
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import functools
import os
import urllib.request
import uuid
import zipfile

import cv2
import gradio as gr
import numpy as np
import torch
from basicsr.archs.srvgg_arch import SRVGGNetCompact
from gfpgan.utils import GFPGANer
from realesrgan.utils import RealESRGANer

# --- download weights if missing ---
# (local filename, release URL) for every checkpoint the app may load.
_WEIGHT_URLS = [
    ('realesr-general-x4v3.pth', 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'),
    ('GFPGANv1.2.pth',           'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth'),
    ('GFPGANv1.3.pth',           'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'),
    ('GFPGANv1.4.pth',           'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'),
    ('RestoreFormer.pth',        'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth'),
]


def _download_if_missing(fname, url):
    """Download *url* into *fname* in the current directory unless it already exists.

    The data is first written to ``<fname>.part`` and renamed only once the
    transfer has completed. An interrupted download therefore never leaves a
    truncated checkpoint that the existence check would accept on the next
    start. Failures are reported and skipped (best effort), so one
    unreachable optional checkpoint does not stop the app from starting.
    """
    if os.path.exists(fname):
        return
    tmp_path = fname + '.part'
    try:
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, fname)
    except OSError as e:  # URLError / HTTPError / ContentTooShortError are OSErrors
        print(f"Could not download {fname} from {url}: {e}")
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


for _fname, _url in _WEIGHT_URLS:
    _download_if_missing(_fname, _url)

# --- background upsampler ---
# Real-ESRGAN x4 model used by GFPGANer to upscale the non-face background.
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
upsampler = RealESRGANer(
    scale=4, model_path='realesr-general-x4v3.pth', model=model,
    tile=256, tile_pad=10, pre_pad=0,
    half=torch.cuda.is_available()  # fp16 only when a CUDA device is present
)

# All restored images and ZIP archives are written here.
os.makedirs('output', exist_ok=True)

def _resolve_model(version):
    """Map a UI version label to its ``(weights filename, GFPGANer arch)`` pair."""
    if version.startswith('v'):
        return f'GFPGAN{version}.pth', 'clean'
    return f'{version}.pth', version


# NOTE(review): the cached enhancer is shared between calls; this assumes
# requests are processed one at a time (Gradio queue concurrency 1) — confirm.
@functools.lru_cache(maxsize=1)
def _get_face_enhancer(version):
    """Build the GFPGANer for *version*, reusing it for consecutive calls.

    Constructing a GFPGANer loads its weights, so the most recently used
    enhancer is kept instead of being rebuilt for every image.
    """
    model_fname, arch = _resolve_model(version)
    return GFPGANer(
        model_path=model_fname,
        upscale=2,
        arch=arch,
        channel_multiplier=2,
        bg_upsampler=upsampler,
    )


def _fit_to_size_range(img, min_size=512, max_size=2048):
    """Resize *img* so that its longer side lies within ``[min_size, max_size]``."""
    h, w = img.shape[:2]
    longest = max(h, w)
    if min_size <= longest <= max_size:
        return img
    factor = (min_size if longest < min_size else max_size) / longest
    # Clamp to 1 px so extreme aspect ratios never produce a 0-sized target.
    new_size = (max(1, int(w * factor)), max(1, int(h * factor)))
    return cv2.resize(img, new_size)


def _rescale_output(restored, scale):
    """Resize *restored* so the final image is ``scale`` x the enhancer input.

    The enhancer is built with ``upscale=2``, so *restored* is taken to be 2x
    the enhancer input. Resizing it by ``scale / 2`` yields ``scale`` x the
    input, and ``scale == 2`` leaves it untouched.
    """
    if scale == 2:
        return restored
    rh, rw = restored.shape[:2]
    ratio = scale / 2
    interp = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
    new_size = (max(1, int(rw * ratio)), max(1, int(rh * ratio)))
    return cv2.resize(restored, new_size, interpolation=interp)


def process_single_image(img_path, version, scale):
    """Restore the faces in one image file and save the result under ``output/``.

    Args:
        img_path: path of the image to read with OpenCV.
        version: one of 'v1.2', 'v1.3', 'v1.4' (GFPGAN 'clean' arch) or
            'RestoreFormer'; selects the weights file and architecture.
        scale: final size relative to the (size-normalized) enhancer input.
            2 keeps the enhancer's native output. None or a non-positive
            value falls back to 2.

    Returns:
        Path of the written image, or None if the image could not be read,
        restored, or saved.
    """
    if scale is None or scale <= 0:
        print(f"Invalid rescaling factor {scale!r}; using 2")
        scale = 2

    img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
    if img is None:
        print(f"Could not read {img_path}")
        return None

    # Remember an alpha channel so the result is saved losslessly as PNG;
    # expand grayscale to 3-channel BGR for the models.
    img_mode = 'RGBA' if (img.ndim == 3 and img.shape[2] == 4) else None
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    img = _fit_to_size_range(img)

    try:
        face_enhancer = _get_face_enhancer(version)
        _, _, restored = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
    except RuntimeError as e:
        print("GFPGAN error:", e)
        return None

    if restored is None:
        print(f"Restoration failed for {img_path}")
        return None

    # Sanitize output (NaN/inf -> valid range) to avoid black rectangles.
    restored = np.nan_to_num(restored, nan=0.0, posinf=255.0, neginf=0.0)
    restored = np.clip(restored, 0, 255).astype(np.uint8)

    restored = _rescale_output(restored, scale)

    ext = 'png' if img_mode == 'RGBA' else 'jpg'
    base = os.path.splitext(os.path.basename(img_path))[0]
    out_path = os.path.join('output', f"{base}_restored.{ext}")
    # imwrite reports failure via its return value rather than raising.
    if not cv2.imwrite(out_path, restored):
        print(f"Could not write {out_path}")
        return None
    return out_path

def _unique_arcname(name, used):
    """Return *name*, suffixed with ``_N`` if needed so it is not in *used*; record it."""
    stem, ext = os.path.splitext(name)
    candidate, n = name, 1
    while candidate in used:
        candidate = f"{stem}_{n}{ext}"
        n += 1
    used.add(candidate)
    return candidate


def inference(img_files, version, scale):
    """Restore every uploaded image and bundle the results into one ZIP.

    Args:
        img_files: list of upload objects (anything with a ``.name`` path
            attribute) or plain path strings; None is treated as empty.
        version: GFPGAN version label forwarded to ``process_single_image``.
        scale: rescaling factor forwarded to ``process_single_image``.

    Returns:
        Path of a newly created ZIP archive under ``output/``. Images that
        fail to process are skipped, so the archive may be empty.
    """
    zip_name = os.path.join('output', f"restored_{uuid.uuid4().hex}.zip")
    used_names = set()
    with zipfile.ZipFile(zip_name, 'w') as z:
        for p in img_files or []:
            img_path = p.name if hasattr(p, 'name') else p
            out = process_single_image(img_path, version, scale)
            if not out:
                continue
            # Archive immediately. Another upload with the same base name
            # writes to the same output path and would otherwise overwrite
            # this result before it is zipped.
            arcname = _unique_arcname(os.path.basename(out), used_names)
            z.write(out, arcname=arcname)

    return zip_name

# --- Gradio interface ---
title = "GFPGAN Multi‑Image Face Restoration + ZIP Download"
description = "Upload multiple images; restored outputs will be bundled into a ZIP for download."

# Inputs map positionally onto inference(img_files, version, scale).
demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.File(file_count="multiple", file_types=["image"], label="Input Images"),
        gr.Radio(['v1.2', 'v1.3', 'v1.4', 'RestoreFormer'], value='v1.4', label='GFPGAN Version'),
        gr.Number(label="Rescaling Factor", value=2)
    ],
    outputs=gr.File(label="Download Restored ZIP"),
    title=title,
    description=description,
)

# Launch only when executed as a script (the original `_name_ == "_main_"`
# raised NameError because the dunder underscores were missing).
if __name__ == "__main__":
    demo.queue().launch(share=True)