Spaces:
Running
Running
File size: 5,328 Bytes
0334511 a4138cd 0334511 a9c6794 0334511 bfb2718 a9c6794 0fc7db4 a9c6794 0fc7db4 0334511 a9c6794 0334511 a9c6794 0334511 a9c6794 0334511 a9c6794 0334511 a9c6794 f26e1dd a9c6794 0fc7db4 a9c6794 b266889 f26e1dd a9c6794 0334511 a9c6794 f26e1dd 0334511 f26e1dd a9c6794 f26e1dd a9c6794 a4138cd f26e1dd a9c6794 f26e1dd a9c6794 f26e1dd a9c6794 f26e1dd a9c6794 0334511 a9c6794 8a3d969 a9c6794 52da111 9928d25 52da111 a9c6794 8a3d969 a9c6794 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import sys
sys.path.append('CodeFormer')
import os
import cv2
import torch
import torch.nn.functional as F
import gradio as gr
from torchvision.transforms.functional import normalize
from basicsr.utils import imwrite, img2tensor, tensor2img
from basicsr.utils.download_util import load_file_from_url
from facelib.utils.face_restoration_helper import FaceRestoreHelper
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.realesrgan_utils import RealESRGANer
from facelib.utils.misc import is_gray
from basicsr.utils.registry import ARCH_REGISTRY
# Model weight URLs (CodeFormer v0.1.0 release assets).
pretrain_model_url = {
    'codeformer': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
    'detection': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth',
    'parsing': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth',
    'realesrgan': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth',
}

# Local directory for each weight file. The 'codeformer' and 'realesrgan'
# entries must agree with the paths the model-setup code loads from
# ("CodeFormer/weights/CodeFormer/codeformer.pth" and
# "CodeFormer/weights/realesrgan/RealESRGAN_x2plus.pth"). Deriving the
# directory from the lowercase dict key instead would place codeformer.pth in
# ".../codeformer", a different directory from ".../CodeFormer" on
# case-sensitive filesystems, so the large checkpoint would be fetched twice.
pretrain_model_dir = {
    'codeformer': 'CodeFormer/weights/CodeFormer',
    'detection': 'CodeFormer/weights/detection',
    'parsing': 'CodeFormer/weights/parsing',
    'realesrgan': 'CodeFormer/weights/realesrgan',
}

# Download each weight file once, skipping files that are already present.
for key, url in pretrain_model_url.items():
    model_dir = pretrain_model_dir[key]
    file_path = os.path.join(model_dir, url.split('/')[-1])
    if not os.path.exists(file_path):
        load_file_from_url(url=url, model_dir=model_dir, progress=True)
# Helper functions
def imread(img_path):
    """Read an image from disk and return it in RGB channel order.

    Args:
        img_path: Path of the image file to read.

    Returns:
        The image array with channels converted from OpenCV's BGR to RGB.

    Raises:
        FileNotFoundError: If OpenCV cannot read the file. ``cv2.imread``
            signals failure by returning None rather than raising, so the
            check is explicit here instead of failing later inside cvtColor.
    """
    img = cv2.imread(img_path)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
def set_realesrgan():
    """Construct the RealESRGAN x2 upsampler used for background/face upsampling.

    Builds an RRDBNet backbone (23 blocks, 64 features, scale 2) and wraps it
    in a RealESRGANer that loads "RealESRGAN_x2plus.pth". Half precision is
    requested only when CUDA is available. tile=400 / tile_pad=40 are passed
    through to RealESRGANer (presumably tile size and tile padding in pixels).
    """
    use_half = torch.cuda.is_available()
    backbone = RRDBNet(
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_block=23,
        num_grow_ch=32,
        scale=2,
    )
    return RealESRGANer(
        scale=2,
        model_path="CodeFormer/weights/realesrgan/RealESRGAN_x2plus.pth",
        model=backbone,
        tile=400,
        tile_pad=40,
        pre_pad=0,
        half=use_half,
    )
# Model setup: shared upsampler, target device, and the CodeFormer network.
upsampler = set_realesrgan()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
codeformer_net = ARCH_REGISTRY.get("CodeFormer")(
    dim_embd=512,
    codebook_size=1024,
    n_head=8,
    n_layers=9,
    connect_list=["32", "64", "128", "256"],
).to(device)
ckpt_path = "CodeFormer/weights/CodeFormer/codeformer.pth"
# Deserialize onto the CPU first so a checkpoint saved with GPU tensors still
# loads on a CPU-only host; load_state_dict then copies the weights into
# codeformer_net, which already lives on `device`.
checkpoint = torch.load(ckpt_path, map_location='cpu')["params_ema"]
codeformer_net.load_state_dict(checkpoint)
codeformer_net.eval()
# Directory that inference() writes its result image into.
os.makedirs('output', exist_ok=True)
# Inference function
def inference(image, face_align=True, background_enhance=True, face_upsample=True, upscale=2, codeformer_fidelity=0.5):
    """Restore the faces in an image with CodeFormer.

    Args:
        image: Path of the input image (the Gradio input uses type="filepath").
        face_align: If True, detect and align faces and paste the restored
            faces back into the image. If False, the whole input is treated
            as one already-aligned face, resized to 512x512 and restored.
        background_enhance: Upsample the background with the RealESRGAN
            upsampler before pasting faces back.
        face_upsample: Pass the RealESRGAN upsampler to the paste-back step.
        upscale: Output scale factor; truncated to int and clamped to [1, 4].
        codeformer_fidelity: Value passed as CodeFormer's ``w`` argument
            (the UI slider supplies 0..1).

    Returns:
        The restored image as an RGB array, or None if the input is missing
        or an error occurs (the error and traceback are printed). On success
        the BGR result is also written to 'output/out.png'.
    """
    import traceback

    if image is None:
        print('Error during inference: no input image provided')
        return None
    try:
        only_center_face = False
        detection_model = "retinaface_resnet50"

        img = cv2.imread(str(image), cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread reports unreadable files by returning None.
            print(f'Error during inference: could not read image {image!r}')
            return None
        has_aligned = not face_align
        upscale = min(max(1, int(upscale)), 4)

        face_helper = FaceRestoreHelper(
            upscale, face_size=512, crop_ratio=(1, 1), det_model=detection_model,
            save_ext="png", use_parse=True, device=device,
        )
        bg_upsampler = upsampler if background_enhance else None
        face_upsampler = upsampler if face_upsample else None

        if has_aligned:
            # The input is already a face crop: skip detection and alignment.
            img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
            face_helper.is_gray = is_gray(img, threshold=5)
            face_helper.cropped_faces = [img]
        else:
            face_helper.read_image(img)
            face_helper.get_face_landmarks_5(
                only_center_face=only_center_face, resize=640, eye_dist_threshold=5
            )
            face_helper.align_warp_face()

        # Restore every cropped face with CodeFormer.
        restored_face = None
        for cropped_face in face_helper.cropped_faces:
            cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
            normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
            cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
            with torch.no_grad():
                output = codeformer_net(cropped_face_t, w=codeformer_fidelity, adain=True)[0]
            restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)).astype("uint8")
            face_helper.add_restored_face(restored_face, cropped_face)

        if has_aligned:
            # read_image() was never called, so there is no input image or
            # affine transform to paste back into; the restored crop is the
            # result.
            restored_img = restored_face
        else:
            bg_img = bg_upsampler.enhance(img, outscale=upscale)[0] if bg_upsampler is not None else None
            # Compute the inverse of each alignment transform; paste-back uses
            # these to map the restored faces onto the (upsampled) image.
            face_helper.get_inverse_affine(None)
            restored_img = face_helper.paste_faces_to_input_image(
                upsample_img=bg_img,
                face_upsampler=face_upsampler,
            )

        save_path = 'output/out.png'
        imwrite(restored_img, save_path)
        return cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
    except Exception as error:
        # Top-level UI boundary: report the failure and show an empty output
        # instead of crashing the app.
        print('Error during inference:', error)
        traceback.print_exc()
        return None
# Gradio Interface: wires inference() to the UI controls, in positional order
# image, face_align, background_enhance, face_upsample, upscale, fidelity.
input_components = [
    gr.Image(type="filepath", label="Input"),
    gr.Checkbox(value=True, label="Pre_Face_Align"),
    gr.Checkbox(value=True, label="Background_Enhance"),
    gr.Checkbox(value=True, label="Face_Upsample"),
    gr.Number(value=2, label="Rescaling_Factor (up to 4)"),
    gr.Slider(0, 1, value=0.5, step=0.01, label='Codeformer_Fidelity'),
]
output_component = gr.Image(type="numpy", label="Output")

demo = gr.Interface(
    fn=inference,
    inputs=input_components,
    outputs=output_component,
    title="CodeFormer: Robust Face Restoration and Enhancement Network",
)

# Debug mode is enabled only when the DEBUG environment variable is exactly "1".
# NOTE(review): share=True requests a public share link; confirm this is
# intended for local runs.
debug_mode = os.getenv('DEBUG') == '1'
demo.launch(debug=debug_mode, share=True)
|