File size: 5,328 Bytes
0334511
 
 
 
 
 
 
 
 
 
 
 
 
 
a4138cd
0334511
 
a9c6794
0334511
 
 
 
 
 
bfb2718
 
 
 
 
a9c6794
 
 
 
 
0fc7db4
a9c6794
0fc7db4
 
 
 
0334511
 
a9c6794
 
0334511
a9c6794
 
0334511
 
 
a9c6794
0334511
 
 
a9c6794
 
0334511
 
 
 
 
 
 
 
a9c6794
 
 
f26e1dd
 
a9c6794
 
0fc7db4
a9c6794
 
b266889
f26e1dd
a9c6794
 
0334511
a9c6794
f26e1dd
 
 
 
 
 
 
0334511
f26e1dd
a9c6794
f26e1dd
 
a9c6794
a4138cd
f26e1dd
 
 
a9c6794
 
 
 
f26e1dd
a9c6794
 
 
 
f26e1dd
a9c6794
 
 
f26e1dd
a9c6794
 
0334511
a9c6794
8a3d969
a9c6794
 
52da111
9928d25
52da111
 
 
a9c6794
 
 
 
 
8a3d969
a9c6794
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import sys
sys.path.append('CodeFormer')
import os
import cv2
import torch
import torch.nn.functional as F
import gradio as gr

from torchvision.transforms.functional import normalize
from basicsr.utils import imwrite, img2tensor, tensor2img
from basicsr.utils.download_util import load_file_from_url
from facelib.utils.face_restoration_helper import FaceRestoreHelper
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.realesrgan_utils import RealESRGANer
from facelib.utils.misc import is_gray
from basicsr.utils.registry import ARCH_REGISTRY

# Model weight URLs, keyed by model name.
pretrain_model_url = {
    'codeformer': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
    'detection': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth',
    'parsing': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth',
    'realesrgan': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth'
}

# Local directory that each model's weight file is stored in.
# The CodeFormer checkpoint must land in 'CodeFormer/weights/CodeFormer'
# (capital C) because that is where the model-setup code loads it from.
# Deriving the directory from the lowercase key would instead produce
# 'CodeFormer/weights/codeformer': on case-sensitive filesystems that is a
# second, never-used copy of a large checkpoint.
# 'realesrgan' matches the model_path used by set_realesrgan().
pretrain_model_dir = {
    'codeformer': 'CodeFormer/weights/CodeFormer',
    'detection': 'CodeFormer/weights/detection',
    'parsing': 'CodeFormer/weights/parsing',
    'realesrgan': 'CodeFormer/weights/realesrgan',
}

# Download each weight file exactly once, skipping files already on disk.
for key, url in pretrain_model_url.items():
    model_dir = pretrain_model_dir[key]
    file_path = os.path.join(model_dir, url.split('/')[-1])
    if not os.path.exists(file_path):
        load_file_from_url(url=url, model_dir=model_dir, progress=True)

# Helper functions
def imread(img_path):
    """Load the image at *img_path* and return it in RGB channel order.

    OpenCV decodes images as BGR, so the result is converted to RGB before
    being returned. If the file cannot be read, ``cv2.imread`` yields ``None``
    and the colour conversion step raises.
    """
    bgr_image = cv2.imread(img_path)
    return cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)

def set_realesrgan():
    """Build the 2x Real-ESRGAN upsampler used for background and face upscaling.

    Half precision is enabled only when a CUDA device is available. The
    image is processed in 400-pixel tiles with 40 pixels of tile padding.
    """
    use_half = torch.cuda.is_available()
    rrdb = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
    return RealESRGANer(
        scale=2,
        model_path="CodeFormer/weights/realesrgan/RealESRGAN_x2plus.pth",
        model=rrdb,
        tile=400,
        tile_pad=40,
        pre_pad=0,
        half=use_half,
    )

# Model setup: shared Real-ESRGAN upsampler, compute device, and the CodeFormer
# network loaded with its EMA weights.
upsampler = set_realesrgan()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
codeformer_net = ARCH_REGISTRY.get("CodeFormer")(
    dim_embd=512, codebook_size=1024, n_head=8, n_layers=9,
    connect_list=["32", "64", "128", "256"]
).to(device)
ckpt_path = "CodeFormer/weights/CodeFormer/codeformer.pth"
# map_location remaps the stored tensors onto the selected device, so a
# checkpoint saved from a GPU still loads on a CPU-only host. Without it,
# torch.load raises when CUDA is unavailable.
checkpoint = torch.load(ckpt_path, map_location=device)["params_ema"]
codeformer_net.load_state_dict(checkpoint)
# Switch modules such as dropout/batch-norm to evaluation behaviour.
codeformer_net.eval()

# Directory that inference() writes its result image into.
os.makedirs('output', exist_ok=True)

# Inference function
def inference(image, face_align=True, background_enhance=True, face_upsample=True, upscale=2, codeformer_fidelity=0.5):
    """Restore the faces in one image with CodeFormer and return the result as RGB.

    Args:
        image: Path to the input image (the Gradio input uses ``type="filepath"``).
        face_align: If True, detect, align and crop faces, restore them, and
            paste them back into the (optionally upscaled) full image. If False,
            the input is treated as an already-aligned face crop: it is resized
            to 512x512, restored, and returned directly.
        background_enhance: Upscale the background with the shared Real-ESRGAN
            upsampler (detection mode only).
        face_upsample: Upscale the restored faces with Real-ESRGAN while pasting
            them back (detection mode only).
        upscale: Output scale factor; coerced to ``int`` and clamped to [1, 4].
        codeformer_fidelity: Fidelity weight passed to CodeFormer as ``w``.

    Returns:
        The restored image as an RGB array (also written to ``output/out.png``),
        or ``None`` if anything fails; the error and traceback are printed.
    """
    try:
        only_center_face = False
        detection_model = "retinaface_resnet50"

        # Load image (OpenCV gives BGR) and sanitise parameters.
        img = cv2.imread(str(image), cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread reports a missing/undecodable file by returning None.
            raise ValueError(f'could not read input image {image!r}')
        has_aligned = not face_align
        upscale = min(max(1, int(upscale)), 4)

        face_helper = FaceRestoreHelper(
            upscale, face_size=512, crop_ratio=(1, 1), det_model=detection_model,
            save_ext="png", use_parse=True, device=device
        )

        bg_upsampler = upsampler if background_enhance else None
        face_upsampler = upsampler if face_upsample else None

        if has_aligned:
            # The input already is a face crop: feed it straight to CodeFormer.
            img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
            face_helper.is_gray = is_gray(img, threshold=5)
            face_helper.cropped_faces = [img]
        else:
            face_helper.read_image(img)
            face_helper.get_face_landmarks_5(only_center_face=only_center_face, resize=640, eye_dist_threshold=5)
            face_helper.align_warp_face()

        restored_face = None
        for cropped_face in face_helper.cropped_faces:
            # BGR uint8 crop -> RGB float tensor normalised to [-1, 1].
            cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
            normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
            cropped_face_t = cropped_face_t.unsqueeze(0).to(device)

            with torch.no_grad():
                output = codeformer_net(cropped_face_t, w=codeformer_fidelity, adain=True)[0]
                restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
            face_helper.add_restored_face(restored_face.astype("uint8"), cropped_face)

        if has_aligned:
            # No source image or affine transforms exist in this mode, so there
            # is nothing to paste back into; the restored crop is the result.
            restored_img = restored_face.astype("uint8")
        else:
            bg_img = bg_upsampler.enhance(img, outscale=upscale)[0] if bg_upsampler else None
            # Paste-back needs one inverse affine matrix per restored face.
            # Without this call the helper has none, and pasting fails whenever
            # at least one face was detected.
            face_helper.get_inverse_affine(None)
            restored_img = face_helper.paste_faces_to_input_image(
                upsample_img=bg_img,
                face_upsampler=face_upsampler
            )

        save_path = 'output/out.png'
        imwrite(restored_img, save_path)
        # The pipeline works in BGR; Gradio's numpy output expects RGB.
        return cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
    except Exception as error:
        # Top-level UI boundary: report the failure and return an empty output
        # instead of crashing the app.
        import traceback
        print('Error during inference:', error)
        traceback.print_exc()
        return None

# Gradio Interface: one input component per inference() parameter, in the
# same positional order as the function signature.
input_components = [
    gr.Image(type="filepath", label="Input"),
    gr.Checkbox(value=True, label="Pre_Face_Align"),
    gr.Checkbox(value=True, label="Background_Enhance"),
    gr.Checkbox(value=True, label="Face_Upsample"),
    gr.Number(value=2, label="Rescaling_Factor (up to 4)"),
    gr.Slider(0, 1, value=0.5, step=0.01, label='Codeformer_Fidelity'),
]
demo = gr.Interface(
    fn=inference,
    inputs=input_components,
    outputs=gr.Image(type="numpy", label="Output"),
    title="CodeFormer: Robust Face Restoration and Enhancement Network",
)

# Debug mode is opt-in via the DEBUG=1 environment variable; share=True asks
# Gradio for a public share link.
debug_mode = os.getenv('DEBUG') == '1'
demo.launch(debug=debug_mode, share=True)