Spaces:
Runtime error
Runtime error
File size: 2,827 Bytes
ac7cda5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
import os
import pickle
import numpy as np
def load_pkl(pkl):
    """Load and return the object stored in a pickle file.

    Args:
        pkl: Path of the pickle file to read (opened in binary mode).

    Returns:
        Whatever object ``pickle.load`` reconstructs from the file.

    NOTE(security): ``pickle.load`` can execute arbitrary code while
    unpickling, so only call this on files from a trusted source.
    """
    with open(pkl, "rb") as f:
        return pickle.load(f)
def parse_cfg(cfg_pkl, data_root, replace_cfg=None):
    """Load a pickled config and split it into per-component configs.

    Args:
        cfg_pkl: Path of the pickle file holding the config dict. It is read
            through ``load_pkl``; the dict must contain the keys
            ``"base_cfg"``, ``"audio2motion_cfg"`` and ``"default_kwargs"``.
        data_root: Directory that path entries are joined onto when the
            stored value is not itself an existing file.
        replace_cfg: Optional two-level dict ``{section: {key: value}}`` of
            overrides written into the loaded config before anything else
            is read. Entries whose value is not a dict are ignored.

    Returns:
        A list of eight items, in this order: ``avatar_registrar_cfg``,
        ``condition_handler_cfg``, ``lmdm_cfg``, ``stitch_network_cfg``,
        ``warp_network_cfg``, ``decoder_cfg``, ``wav2feat_cfg``,
        ``default_kwargs``.

    Raises:
        KeyError: If an expected section or key is missing from the config
            (including a ``replace_cfg`` section absent from the config).
    """

    def _check_path(p):
        # Keep paths that already point at a file; otherwise treat the value
        # as relative to ``data_root``.
        if os.path.isfile(p):
            return p
        return os.path.join(data_root, p)

    cfg = load_pkl(cfg_pkl)

    # ---
    # replace cfg for debug
    if isinstance(replace_cfg, dict):
        for section, overrides in replace_cfg.items():
            if not isinstance(overrides, dict):
                continue
            for key, value in overrides.items():
                cfg[section][key] = value
    # ---

    base_cfg = cfg["base_cfg"]
    audio2motion_cfg = cfg["audio2motion_cfg"]
    default_kwargs = cfg["default_kwargs"]

    # Resolve every model path stored in the base config.
    for name in base_cfg:
        sub_cfg = base_cfg[name]
        if name == "landmark478_cfg":
            # This section has several optional path entries instead of a
            # single "model_path"; only non-empty ones are resolved.
            for path_key in ("task_path", "blaze_face_model_path", "face_mesh_model_path"):
                if path_key in sub_cfg and sub_cfg[path_key]:
                    sub_cfg[path_key] = _check_path(sub_cfg[path_key])
        else:
            sub_cfg["model_path"] = _check_path(sub_cfg["model_path"])

    audio2motion_cfg["model_path"] = _check_path(audio2motion_cfg["model_path"])

    avatar_registrar_cfg = {
        k: base_cfg[k]
        for k in [
            "insightface_det_cfg",
            "landmark106_cfg",
            "landmark203_cfg",
            "landmark478_cfg",
            "appearance_extractor_cfg",
            "motion_extractor_cfg",
        ]
    }

    stitch_network_cfg = base_cfg["stitch_network_cfg"]
    warp_network_cfg = base_cfg["warp_network_cfg"]
    decoder_cfg = base_cfg["decoder_cfg"]

    condition_handler_cfg = {
        k: audio2motion_cfg[k]
        for k in [
            "use_emo",
            "use_sc",
            "use_eye_open",
            "use_eye_ball",
            "seq_frames",
        ]
    }

    lmdm_cfg = {
        k: audio2motion_cfg[k]
        for k in [
            "model_path",
            "device",
            "motion_feat_dim",
            "audio_feat_dim",
            "seq_frames",
        ]
    }

    # Any w2f_type other than "hubert" falls back to the wavlm config.
    w2f_type = audio2motion_cfg["w2f_type"]
    wav2feat_cfg = {
        "w2f_cfg": base_cfg["hubert_cfg"] if w2f_type == "hubert" else base_cfg["wavlm_cfg"],
        "w2f_type": w2f_type,
    }

    return [
        avatar_registrar_cfg,
        condition_handler_cfg,
        lmdm_cfg,
        stitch_network_cfg,
        warp_network_cfg,
        decoder_cfg,
        wav2feat_cfg,
        default_kwargs,
    ]
def print_cfg(**kwargs):
    """Print a one-line summary of each keyword argument to stdout.

    The amount of detail depends on the argument:

    * ``ch_info``: name and type only.
    * ``ctrl_info``: name, type and ``len()`` of the value.
    * ``numpy.ndarray`` values: name, type and shape.
    * anything else: name, type and the value itself.

    Args:
        **kwargs: Arbitrary named values to summarise.
    """
    for k, v in kwargs.items():
        if k == "ch_info":
            print(k, type(v))
        elif k == "ctrl_info":
            print(k, type(v), len(v))
        elif isinstance(v, np.ndarray):
            # Print the shape rather than the full array contents.
            print(k, type(v), v.shape)
        else:
            print(k, type(v), v)
|