|
import os |
|
import gc |
|
import cv2 |
|
import requests |
|
import numpy as np |
|
import gradio as gr |
|
import torch |
|
import traceback |
|
from tqdm import tqdm |
|
from realesrgan.archs.srvgg_arch import SRVGGNetCompact |
|
from gfpgan.utils import GFPGANer |
|
from realesrgan.utils import RealESRGANer |
|
from basicsr.archs.rrdbnet_arch import RRDBNet |
|
|
|
|
|
face_model = { |
|
"GFPGANv1.2.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth", |
|
"GFPGANv1.3.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth", |
|
"GFPGANv1.4.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth", |
|
"RestoreFormer.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth", |
|
"CodeFormer.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth", |
|
} |
|
realesr_model = { |
|
"realesr-general-x4v3.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", |
|
"realesr-animevideov3.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth", |
|
"RealESRGAN_x4plus_anime_6B.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", |
|
"RealESRGAN_x2plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth", |
|
"RealESRNet_x4plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth", |
|
"RealESRGAN_x4plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", |
|
"4x-AnimeSharp.pth": "https://huggingface.co/utnah/esrgan/resolve/main/4x-AnimeSharp.pth?download=true", |
|
} |
|
files_to_download = [ |
|
( "a1.jpg", |
|
"https://thumbs.dreamstime.com/b/tower-bridge-traditional-red-bus-black-white-colors-view-to-tower-bridge-london-black-white-colors-108478942.jpg" ), |
|
( "a2.jpg", |
|
"https://media.istockphoto.com/id/523514029/photo/london-skyline-b-w.jpg?s=612x612&w=0&k=20&c=kJS1BAtfqYeUDaORupj0sBPc1hpzJhBUUqEFfRnHzZ0=" ), |
|
( "a3.jpg", |
|
"https://i.guim.co.uk/img/media/06f614065ed82ca0e917b149a32493c791619854/0_0_3648_2789/master/3648.jpg?width=700&quality=85&auto=format&fit=max&s=05764b507c18a38590090d987c8b6202" ), |
|
( "a4.jpg", |
|
"https://i.pinimg.com/736x/46/96/9e/46969eb94aec2437323464804d27706d--victorian-london-victorian-era.jpg" ), |
|
] |
|
|
|
|
|
os.makedirs("weights", exist_ok=True) |
|
os.makedirs('output', exist_ok=True) |
|
|
|
def download_from_url(output_path, url): |
|
try: |
|
|
|
if os.path.exists(output_path): |
|
print(f"File already exists, skipping download: {output_path}") |
|
return |
|
|
|
print(f"Downloading: {url}") |
|
with requests.get(url, stream=True) as response, open(output_path, "wb") as f: |
|
total_size = int(response.headers.get('content-length', 0)) |
|
with tqdm(total=total_size, unit='B', unit_scale=True) as pbar: |
|
for chunk in response.iter_content(chunk_size=8192): |
|
f.write(chunk) |
|
pbar.update(len(chunk)) |
|
print(f"Download successful: {output_path}") |
|
except requests.RequestException as e: |
|
print(f"Download failed: {url}, Error: {e}") |
|
|
|
|
|
|
|
for output_path, url in files_to_download: |
|
|
|
if os.path.exists(output_path): |
|
print(f"File already exists, skipping download: {output_path}") |
|
continue |
|
|
|
|
|
download_from_url(output_path, url) |
|
|
|
|
|
def inference(img, version, realesr, scale: float): |
|
print(img, version, scale) |
|
try: |
|
img_name = os.path.basename(str(img)) |
|
basename, extension = os.path.splitext(img_name) |
|
img = cv2.imread(img, cv2.IMREAD_UNCHANGED) |
|
if len(img.shape) == 3 and img.shape[2] == 4: |
|
img_mode = 'RGBA' |
|
elif len(img.shape) == 2: |
|
img_mode = None |
|
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) |
|
else: |
|
img_mode = None |
|
|
|
h, w = img.shape[0:2] |
|
if h < 300: |
|
img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4) |
|
|
|
if version: |
|
download_from_url(os.path.join("weights", version), face_model[version]) |
|
if realesr: |
|
download_from_url(os.path.join("weights", realesr), realesr_model[realesr]) |
|
|
|
|
|
if realesr == 'RealESRGAN_x4plus.pth': |
|
netscale = 4 |
|
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=netscale) |
|
elif realesr == 'RealESRNet_x4plus.pth': |
|
netscale = 4 |
|
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=netscale) |
|
elif realesr == 'RealESRGAN_x4plus_anime_6B.pth': |
|
netscale = 4 |
|
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=netscale) |
|
elif realesr == 'RealESRGAN_x2plus.pth': |
|
netscale = 2 |
|
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=netscale) |
|
elif realesr == 'realesr-animevideov3.pth': |
|
netscale = 4 |
|
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=netscale, act_type='prelu') |
|
elif realesr == 'realesr-general-x4v3.pth': |
|
netscale = 4 |
|
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=netscale, act_type='prelu') |
|
|
|
|
|
|
|
|
|
half = True if torch.cuda.is_available() else False |
|
upsampler = RealESRGANer(scale=netscale, model_path=os.path.join("weights", realesr), model=model, tile=0, tile_pad=10, pre_pad=0, half=half) |
|
|
|
face_enhancer = None |
|
if version == 'GFPGANv1.2.pth': |
|
face_enhancer = GFPGANer( |
|
model_path='weights/GFPGANv1.2.pth', upscale=scale, arch='clean', channel_multiplier=2, bg_upsampler=upsampler) |
|
elif version == 'GFPGANv1.3.pth': |
|
face_enhancer = GFPGANer( |
|
model_path='weights/GFPGANv1.3.pth', upscale=scale, arch='clean', channel_multiplier=2, bg_upsampler=upsampler) |
|
elif version == 'GFPGANv1.4.pth': |
|
face_enhancer = GFPGANer( |
|
model_path='weights/GFPGANv1.4.pth', upscale=scale, arch='clean', channel_multiplier=2, bg_upsampler=upsampler) |
|
elif version == 'RestoreFormer.pth': |
|
face_enhancer = GFPGANer( |
|
model_path='weights/RestoreFormer.pth', upscale=scale, arch='RestoreFormer', channel_multiplier=2, bg_upsampler=upsampler) |
|
elif version == 'CodeFormer.pth': |
|
face_enhancer = GFPGANer( |
|
model_path='weights/CodeFormer.pth', upscale=scale, arch='CodeFormer', channel_multiplier=2, bg_upsampler=upsampler) |
|
|
|
files = [] |
|
outputs = [] |
|
try: |
|
if face_enhancer: |
|
cropped_faces, restored_aligned, restored_img = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True) |
|
|
|
if cropped_faces and restored_aligned: |
|
for idx, (cropped_face, restored_face) in enumerate(zip(cropped_faces, restored_aligned)): |
|
|
|
save_crop_path = f"output/{basename}{idx:02d}_cropped_faces.png" |
|
cv2.imwrite(save_crop_path, cropped_face) |
|
|
|
save_restore_path = f"output/{basename}{idx:02d}_restored_faces.png" |
|
cv2.imwrite(save_restore_path, restored_face) |
|
|
|
save_cmp_path = f"output/{basename}{idx:02d}_cmp.png" |
|
cmp_img = np.concatenate((cropped_face, restored_face), axis=1) |
|
cv2.imwrite(save_cmp_path, cmp_img) |
|
|
|
files.append(save_crop_path) |
|
files.append(save_restore_path) |
|
files.append(save_cmp_path) |
|
outputs.append(cv2.cvtColor(cropped_face, cv2.COLOR_BGR2RGB)) |
|
outputs.append(cv2.cvtColor(restored_face, cv2.COLOR_BGR2RGB)) |
|
outputs.append(cv2.cvtColor(cmp_img, cv2.COLOR_BGR2RGB)) |
|
else: |
|
restored_img, _ = upsampler.enhance(img, outscale=scale) |
|
except RuntimeError as error: |
|
print(traceback.format_exc()) |
|
print('Error', error) |
|
finally: |
|
if face_enhancer: |
|
face_enhancer._cleanup() |
|
else: |
|
|
|
torch.cuda.empty_cache() |
|
gc.collect() |
|
|
|
|
|
try: |
|
if scale != 2: |
|
interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4 |
|
h, w = img.shape[0:2] |
|
restored_img = cv2.resize(restored_img, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation) |
|
except Exception as error: |
|
print(traceback.format_exc()) |
|
print("wrong scale input.", error) |
|
|
|
if not extension: |
|
extension = ".png" if img_mode == "RGBA" else ".jpg" |
|
save_path = f"output/{basename}{extension}" |
|
cv2.imwrite(save_path, restored_img) |
|
|
|
restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB) |
|
files.append(save_path) |
|
outputs.append(restored_img) |
|
return outputs, files |
|
except Exception as error: |
|
print(traceback.format_exc()) |
|
print("global exception", error) |
|
return None, None |
|
|
|
|
|
title = "Image Upscaling & Restoration(esp. Face) using GFPGAN Algorithm" |
|
description = r"""Gradio demo for <a href='https://github.com/TencentARC/GFPGAN' target='_blank'><b>GFPGAN: Towards Real-World Blind Face Restoration and Upscalling of the image with a Generative Facial Prior</b></a>.<br> |
|
Practically the algorithm is used to restore your **old photos** or improve **AI-generated faces**.<br> |
|
To use it, simply just upload the concerned image.<br> |
|
""" |
|
article = r""" |
|
[](https://github.com/TencentARC/GFPGAN/releases) |
|
[](https://github.com/TencentARC/GFPGAN) |
|
[](https://arxiv.org/abs/2101.04061) |
|
<center><img src='https://visitor-badge.glitch.me/badge?page_id=dj_face_restoration_GFPGAN' alt='visitor badge'></center> |
|
""" |
|
demo = gr.Interface( |
|
inference, [ |
|
gr.Image(type="filepath", label="Input", format="png"), |
|
gr.Dropdown(["GFPGANv1.2.pth", |
|
"GFPGANv1.3.pth", |
|
"GFPGANv1.4.pth", |
|
"RestoreFormer.pth", |
|
|
|
None], type="value", value='GFPGANv1.4.pth', label='Face Restoration version', info="Face Restoration and RealESR can be freely combined in different ways, or one can be set to \"None\" to use only the other model. Face Restoration is primarily used for face restoration in real-life images, while RealESR serves as a background restoration model."), |
|
gr.Dropdown(["realesr-general-x4v3.pth", |
|
"realesr-animevideov3.pth", |
|
"RealESRGAN_x4plus_anime_6B.pth", |
|
"RealESRGAN_x2plus.pth", |
|
"RealESRNet_x4plus.pth", |
|
"RealESRGAN_x4plus.pth", |
|
|
|
None], type="value", value='realesr-general-x4v3.pth', label='RealESR version'), |
|
gr.Number(label="Rescaling factor", value=2), |
|
|
|
], [ |
|
gr.Gallery(type="numpy", label="Output (The whole image)", format="png"), |
|
gr.File(label="Download the output image") |
|
], |
|
title=title, |
|
description=description, |
|
article=article, |
|
examples=[['a1.jpg', 'GFPGANv1.4.pth', "realesr-general-x4v3.pth", 2], |
|
['a2.jpg', 'GFPGANv1.4.pth', "realesr-general-x4v3.pth", 2], |
|
['a3.jpg', 'GFPGANv1.4.pth', "realesr-general-x4v3.pth", 2], |
|
['a4.jpg', 'GFPGANv1.4.pth', "realesr-general-x4v3.pth", 2]]) |
|
|
|
demo.queue(default_concurrency_limit=4) |
|
demo.launch(inbrowser=True) |