File size: 2,827 Bytes
ac7cda5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import pickle
import numpy as np


def load_pkl(pkl):
    """Deserialize and return the object stored in the pickle file at *pkl*.

    The file is opened in binary mode and closed before returning.
    Note: unpickling can execute arbitrary code, so only pass trusted files.
    """
    with open(pkl, mode="rb") as handle:
        obj = pickle.load(handle)
    return obj
    

def parse_cfg(cfg_pkl, data_root, replace_cfg=None):
    """Load a pickled pipeline config and split it into per-component configs.

    Args:
        cfg_pkl: Path to the pickled config. The loaded object is indexed
            with the keys "base_cfg", "audio2motion_cfg" and "default_kwargs".
        data_root: Base path joined in front of any model path that does not
            already point at an existing file.
        replace_cfg: Optional debug overrides. When it is a dict, every value
            that is itself a dict is written key-by-key into the matching
            top-level section of the loaded config; non-dict values are
            ignored.

    Returns:
        A list with, in order: avatar_registrar_cfg, condition_handler_cfg,
        lmdm_cfg, stitch_network_cfg, warp_network_cfg, decoder_cfg,
        wav2feat_cfg, default_kwargs. Model paths inside the returned
        configs have been resolved against ``data_root``.
    """

    def _resolve(path):
        # Keep paths that already name an existing file; otherwise treat
        # them as relative to data_root.
        return path if os.path.isfile(path) else os.path.join(data_root, path)

    def _pick(src, keys):
        # Shallow sub-dict of src restricted to keys, in the given order.
        return {key: src[key] for key in keys}

    cfg = load_pkl(cfg_pkl)

    # Debug-only overrides: patch individual entries of top-level sections.
    if isinstance(replace_cfg, dict):
        for section, overrides in replace_cfg.items():
            if isinstance(overrides, dict):
                for key, value in overrides.items():
                    cfg[section][key] = value

    base_cfg = cfg["base_cfg"]
    audio2motion_cfg = cfg["audio2motion_cfg"]
    default_kwargs = cfg["default_kwargs"]

    # Resolve model paths in place. landmark478_cfg carries several optional
    # paths (only present, non-empty ones are resolved); every other section
    # is expected to carry a "model_path" entry.
    landmark478_keys = ("task_path", "blaze_face_model_path", "face_mesh_model_path")
    for name in base_cfg:
        sub_cfg = base_cfg[name]
        if name != "landmark478_cfg":
            sub_cfg["model_path"] = _resolve(sub_cfg["model_path"])
            continue
        for key in landmark478_keys:
            if key in sub_cfg and sub_cfg[key]:
                sub_cfg[key] = _resolve(sub_cfg[key])

    audio2motion_cfg["model_path"] = _resolve(audio2motion_cfg["model_path"])

    avatar_registrar_cfg = _pick(
        base_cfg,
        (
            "insightface_det_cfg",
            "landmark106_cfg",
            "landmark203_cfg",
            "landmark478_cfg",
            "appearance_extractor_cfg",
            "motion_extractor_cfg",
        ),
    )

    stitch_network_cfg = base_cfg["stitch_network_cfg"]
    warp_network_cfg = base_cfg["warp_network_cfg"]
    decoder_cfg = base_cfg["decoder_cfg"]

    condition_handler_cfg = _pick(
        audio2motion_cfg,
        ("use_emo", "use_sc", "use_eye_open", "use_eye_ball", "seq_frames"),
    )

    # Built after the path fix-up above, so "model_path" is already resolved.
    lmdm_cfg = _pick(
        audio2motion_cfg,
        ("model_path", "device", "motion_feat_dim", "audio_feat_dim", "seq_frames"),
    )

    # Anything other than "hubert" selects the WavLM feature config.
    w2f_type = audio2motion_cfg["w2f_type"]
    w2f_section = "hubert_cfg" if w2f_type == "hubert" else "wavlm_cfg"
    wav2feat_cfg = {"w2f_cfg": base_cfg[w2f_section], "w2f_type": w2f_type}

    return [
        avatar_registrar_cfg,
        condition_handler_cfg,
        lmdm_cfg,
        stitch_network_cfg,
        warp_network_cfg,
        decoder_cfg,
        wav2feat_cfg,
        default_kwargs,
    ]


def print_cfg(**kwargs):
    """Print one diagnostic line per keyword argument.

    Each line starts with the argument name and the type of its value. What
    follows depends on the argument:
    - "ch_info": nothing more.
    - "ctrl_info": its length.
    - numpy arrays: their shape.
    - anything else: the value itself.
    """
    for name, value in kwargs.items():
        if name == "ch_info":
            fields = (name, type(value))
        elif name == "ctrl_info":
            fields = (name, type(value), len(value))
        elif isinstance(value, np.ndarray):
            fields = (name, type(value), value.shape)
        else:
            fields = (name, type(value), value)
        print(*fields)