import os import pickle import numpy as np def load_pkl(pkl): with open(pkl, "rb") as f: return pickle.load(f) def parse_cfg(cfg_pkl, data_root, replace_cfg=None): def _check_path(p): if os.path.isfile(p): return p else: return os.path.join(data_root, p) cfg = load_pkl(cfg_pkl) # --- # replace cfg for debug if isinstance(replace_cfg, dict): for k, v in replace_cfg.items(): if not isinstance(v, dict): continue for kk, vv in v.items(): cfg[k][kk] = vv # --- base_cfg = cfg["base_cfg"] audio2motion_cfg = cfg["audio2motion_cfg"] default_kwargs = cfg["default_kwargs"] for k in base_cfg: if k == "landmark478_cfg": for kk in ["task_path", "blaze_face_model_path", "face_mesh_model_path"]: if kk in base_cfg[k] and base_cfg[k][kk]: base_cfg[k][kk] = _check_path(base_cfg[k][kk]) else: base_cfg[k]["model_path"] = _check_path(base_cfg[k]["model_path"]) audio2motion_cfg["model_path"] = _check_path(audio2motion_cfg["model_path"]) avatar_registrar_cfg = { k: base_cfg[k] for k in [ "insightface_det_cfg", "landmark106_cfg", "landmark203_cfg", "landmark478_cfg", "appearance_extractor_cfg", "motion_extractor_cfg", ] } stitch_network_cfg = base_cfg["stitch_network_cfg"] warp_network_cfg = base_cfg["warp_network_cfg"] decoder_cfg = base_cfg["decoder_cfg"] condition_handler_cfg = { k: audio2motion_cfg[k] for k in [ "use_emo", "use_sc", "use_eye_open", "use_eye_ball", "seq_frames", ] } lmdm_cfg = { k: audio2motion_cfg[k] for k in [ "model_path", "device", "motion_feat_dim", "audio_feat_dim", "seq_frames", ] } w2f_type = audio2motion_cfg["w2f_type"] wav2feat_cfg = { "w2f_cfg": base_cfg["hubert_cfg"] if w2f_type == "hubert" else base_cfg["wavlm_cfg"], "w2f_type": w2f_type, } return [ avatar_registrar_cfg, condition_handler_cfg, lmdm_cfg, stitch_network_cfg, warp_network_cfg, decoder_cfg, wav2feat_cfg, default_kwargs, ] def print_cfg(**kwargs): for k, v in kwargs.items(): if k == "ch_info": print(k, type(v)) elif k == "ctrl_info": print(k, type(v), len(v)) else: if isinstance(v, np.ndarray): print(k, type(v), v.shape) else: print(k, type(v), v)