File size: 4,611 Bytes
3e165b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import os

import cv2
import torch
from gfpgan import GFPGANer
from tqdm import tqdm

from visualizr import logger
from visualizr.face_sr.videoio import load_video_to_cv2


class GeneratorWithLen:
    """Pair an iterator with a caller-supplied length so ``len()`` works on it.

    Adapted from https://stackoverflow.com/a/7460929

    Iterating hands back the wrapped object itself (expected to be an iterator,
    e.g. a generator), so a generator-backed instance can be consumed only once.
    """

    def __init__(self, gen, length):
        self.gen, self.length = gen, length

    def __len__(self):
        # The reported length is exactly what the caller passed in; it is not
        # checked against how many items the iterator really produces.
        return self.length

    def __iter__(self):
        return self.gen


def enhancer_list(images, method="gfpgan", bg_upsampler="realesrgan"):
    """Enhance every frame eagerly and return the results as a list.

    Arguments are forwarded unchanged to ``enhancer_generator_no_len``. Unlike
    the generator variants, all enhanced frames are held in memory at once.
    """
    return list(
        enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler)
    )


def enhancer_generator_with_len(images, method="gfpgan", bg_upsampler="realesrgan"):
    """Provide a generator with a __len__ method so that it can be passed to
    functions that call len().

    Args:
        images: Path to a video file, or a sized sequence of RGB frames.
        method: Face-restoration model name, forwarded to
            ``enhancer_generator_no_len``.
        bg_upsampler: Background upsampler name, forwarded likewise.

    Returns:
        A ``GeneratorWithLen`` that lazily yields one enhanced frame per input
        frame and whose ``len()`` is the number of input frames.
    """
    # os.path.isfile raises TypeError for non-path inputs such as a list of
    # frames, so only probe the filesystem when `images` is path-like.
    if isinstance(images, (str, bytes, os.PathLike)) and os.path.isfile(images):
        # Decode the video here so its frame count is known for len().
        images = load_video_to_cv2(images)

    gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler)
    return GeneratorWithLen(gen, len(images))


def enhancer_generator_no_len(images, method="gfpgan", bg_upsampler="realesrgan"):
    """Provide a generator function so that all of the enhanced images don't need
    to be stored in memory at the same time. This can save tons of RAM compared to
    the enhancer function."""
    if method not in ["gfpgan", "RestoreFormer", "codeformer"]:
        raise ValueError(f"Wrong model version {method}.")
    logger.info("face enhancer....")
    if not isinstance(images, list) and os.path.isfile(
        images
    ):  # handle video to images
        images = load_video_to_cv2(images)

    # ------------------------ set up GFPGAN restorer ------------------------
    match method:
        case "gfpgan":
            arch = "clean"
            channel_multiplier = 2
            model_name = "GFPGANv1.4"
            url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
        case "RestoreFormer":
            arch = "RestoreFormer"
            channel_multiplier = 2
            model_name = "RestoreFormer"
            url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth"
        case "codeformer":
            arch = "CodeFormer"
            channel_multiplier = 2
            model_name = "CodeFormer"
            url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
    # ------------------------ set up background upsampler ------------------------
    if bg_upsampler == "realesrgan":
        if not torch.cuda.is_available():  # CPU
            import warnings

            warnings.warn(
                "The unoptimized RealESRGAN is slow on CPU. We do not use it. "
                "If you really want to use it, please modify the corresponding codes."
            )
            bg_upsampler = None
        else:
            from basicsr.archs.rrdbnet_arch import RRDBNet
            from realesrgan import RealESRGANer

            model = RRDBNet(
                num_in_ch=3,
                num_out_ch=3,
                num_feat=64,
                num_block=23,
                num_grow_ch=32,
                scale=2,
            )
            bg_upsampler = RealESRGANer(
                scale=2,
                model_path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
                model=model,
                tile=400,
                tile_pad=10,
                pre_pad=0,
                half=True,
            )  # need to set False in CPU mode
    else:
        bg_upsampler = None

    # determine model paths
    model_path = os.path.join("gfpgan/weights", model_name + ".pth")

    if not os.path.isfile(model_path):
        model_path = os.path.join("checkpoints", model_name + ".pth")

    if not os.path.isfile(model_path):
        # download pre-trained models from url
        model_path = url

    restorer = GFPGANer(
        model_path=model_path,
        upscale=2,
        arch=arch,
        channel_multiplier=channel_multiplier,
        bg_upsampler=bg_upsampler,
    )

    # ------------------------ restore ------------------------
    for idx in tqdm(range(len(images)), "Face Enhancer:"):
        img = cv2.cvtColor(images[idx], cv2.COLOR_RGB2BGR)

        # restore faces and background if necessary
        cropped_faces, restored_faces, r_img = restorer.enhance(
            img, has_aligned=False, only_center_face=False, paste_back=True
        )

        r_img = cv2.cvtColor(r_img, cv2.COLOR_BGR2RGB)
        yield r_img