File size: 4,688 Bytes
ac7cda5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import numpy as np
import cv2

from ..aux_models.insightface_det import InsightFaceDet
from ..aux_models.insightface_landmark106 import Landmark106
from ..aux_models.landmark203 import Landmark203
from ..aux_models.mediapipe_landmark478 import Landmark478
from ..models.appearance_extractor import AppearanceExtractor
from ..models.motion_extractor import MotionExtractor

from ..utils.crop import crop_image
from ..utils.eye_info import EyeAttrUtilsByMP


"""
insightface_det_cfg = {
    "model_path": "",
    "device": "cuda",
    "force_ori_type": False,
}
landmark106_cfg = {
    "model_path": "",
    "device": "cuda",
    "force_ori_type": False,
}
landmark203_cfg = {
    "model_path": "",
    "device": "cuda",
    "force_ori_type": False,
}
landmark478_cfg = {
    "blaze_face_model_path": "", 
    "face_mesh_model_path": "", 
    "device": "cuda",
    "force_ori_type": False,
    "task_path": "",
}
appearance_extractor_cfg = {
    "model_path": "",
    "device": "cuda",
}
motion_extractor_cfg = {
    "model_path": "",
    "device": "cuda",
}
"""


class Source2Info:
    """Turn one source frame into the inputs needed by the rest of the pipeline.

    For a given image the instance (1) locates and crops the face,
    (2) measures eye attributes on the crop, and (3) runs the motion and
    appearance extractors on a 256x256 version of the crop.
    """

    def __init__(
        self,
        insightface_det_cfg,
        landmark106_cfg,
        landmark203_cfg,
        landmark478_cfg,
        appearance_extractor_cfg,
        motion_extractor_cfg,
    ):
        """Build every sub-model from its config dict.

        Each ``*_cfg`` argument is a dict that is unpacked as keyword
        arguments into the constructor of the corresponding model.
        """
        self.insightface_det = InsightFaceDet(**insightface_det_cfg)
        self.landmark106 = Landmark106(**landmark106_cfg)
        self.landmark203 = Landmark203(**landmark203_cfg)
        self.landmark478 = Landmark478(**landmark478_cfg)

        self.appearance_extractor = AppearanceExtractor(**appearance_extractor_cfg)
        self.motion_extractor = MotionExtractor(**motion_extractor_cfg)

    @staticmethod
    def _largest_face_box(det):
        """Return the detection row with the largest box area, or None.

        Columns 0..3 of ``det`` are used as box corners; the area is computed
        as ``(col2 - col0) * (col3 - col1)``.
        NOTE(review): presumably rows are (x1, y1, x2, y2, score) - confirm
        against InsightFaceDet.
        """
        # Check emptiness before touching columns so that an empty result
        # never reaches the indexing below.
        if det is None or len(det) == 0:
            return None
        areas = (det[:, 2] - det[:, 0]) * (det[:, 3] - det[:, 1])
        # Only the single largest box is needed, so argmax replaces a full sort.
        return det[int(np.argmax(areas))]

    def _crop(self, img, last_lmk=None, **kwargs):
        """Crop the face region out of ``img``.

        Pipeline: detection -> landmark106 -> landmark203 -> final crop.
        When ``last_lmk`` is given (video tracking), detection and
        landmark106 are skipped and ``last_lmk`` seeds the first crop instead.

        Args:
            img: input image (rgb uint8 according to ``__call__``).
            last_lmk: lmk203 of the previous frame, or None for the first
                frame / a still image.
            **kwargs: optional overrides ``crop_scale`` (2.3),
                ``crop_vx_ratio`` (0), ``crop_vy_ratio`` (-0.125) and
                ``crop_flag_do_rot`` (True) for the final crop.

        Returns:
            ``(img_crop, M_c2o, lmk203)``, or ``None`` when no face was
            detected.
        """
        if last_lmk is None:  # det for first frame or image
            det, _ = self.insightface_det(img)
            box = self._largest_face_box(det)
            if box is None:
                return None
            lmk_for_track = self.landmark106(img, box)  # 106
        else:  # track for video frames
            lmk_for_track = last_lmk  # 203

        # First, coarse crop: only used to feed the 203-point landmark model.
        crop_dct = crop_image(
            img,
            lmk_for_track,
            dsize=self.landmark203.dsize,
            scale=1.5,
            vy_ratio=-0.1,
            pt_crop_flag=False,
        )
        lmk203 = self.landmark203(crop_dct["img_crop"], crop_dct["M_c2o"])

        # Second, final crop at 512, driven by the refined landmarks and the
        # caller-tunable crop parameters.
        ret_dct = crop_image(
            img,
            lmk203,
            dsize=512,
            scale=kwargs.get("crop_scale", 2.3),
            vx_ratio=kwargs.get("crop_vx_ratio", 0),
            vy_ratio=kwargs.get("crop_vy_ratio", -0.125),
            flag_do_rot=kwargs.get("crop_flag_do_rot", True),
            pt_crop_flag=False,
        )

        img_crop = ret_dct["img_crop"]
        M_c2o = ret_dct["M_c2o"]

        return img_crop, M_c2o, lmk203

    @staticmethod
    def _img_crop_to_bchw256(img_crop):
        """Resize the crop to 256x256 and convert it to float32 BCHW in [0, 1].

        The image is divided by 255, given a leading batch axis of size 1 and
        transposed from (B, H, W, C) to (B, C, H, W).
        """
        rgb_256 = cv2.resize(img_crop, (256, 256), interpolation=cv2.INTER_AREA)
        rgb_256_bchw = (rgb_256.astype(np.float32) / 255.0)[None].transpose(0, 3, 1, 2)
        return rgb_256_bchw

    def _get_kp_info(self, img):
        """Run the motion extractor on ``img`` (rgb_256_bchw_norm01)."""
        kp_info = self.motion_extractor(img)
        return kp_info

    def _get_f3d(self, img):
        """Run the appearance extractor on ``img`` (rgb_256_bchw_norm01)."""
        fs = self.appearance_extractor(img)
        return fs

    def _get_eye_info(self, img):
        """Compute eye-openness and eyeball-movement attributes.

        Args:
            img: rgb uint8 image.

        Returns:
            ``[lr_open, lr_ball]`` reshaped to ``(-1, 2)`` and ``(-1, 6)``.
        """
        lmk478 = self.landmark478(img)  # [1, 478, 3]
        attr = EyeAttrUtilsByMP(lmk478)
        lr_open = attr.LR_open().reshape(-1, 2)  # [1, 2]
        lr_ball = attr.LR_ball_move().reshape(-1, 6)  # [1, 3, 2] -> [1, 6]
        return [lr_open, lr_ball]

    def __call__(self, img, last_lmk=None, **kwargs):
        """Extract the full source-info dict for one frame.

        img: rgb, uint8
        last_lmk: last frame lmk203, for video tracking
        kwargs: optional crop cfg
            crop_scale: 2.3
            crop_vx_ratio: 0
            crop_vy_ratio: -0.125
            crop_flag_do_rot: True

        Returns:
            dict with keys ``x_s_info``, ``f_s``, ``M_c2o``, ``eye_open``,
            ``eye_ball`` and ``lmk203``.

        Raises:
            ValueError: if no face is detected in ``img``.
        """
        crop_ret = self._crop(img, last_lmk=last_lmk, **kwargs)
        if crop_ret is None:
            # _crop signals "no detection" with None; fail with a clear
            # message instead of an opaque tuple-unpacking TypeError.
            raise ValueError("No face detected in the input image.")
        img_crop, M_c2o, lmk203 = crop_ret

        eye_open, eye_ball = self._get_eye_info(img_crop)

        rgb_256_bchw = self._img_crop_to_bchw256(img_crop)
        kp_info = self._get_kp_info(rgb_256_bchw)
        fs = self._get_f3d(rgb_256_bchw)

        source_info = {
            "x_s_info": kp_info,
            "f_s": fs,
            "M_c2o": M_c2o,
            "eye_open": eye_open,  # [1, 2]
            "eye_ball": eye_ball,  # [1, 6]
            "lmk203": lmk203,  # for track
        }
        return source_info