import numpy as np
from ..models.lmdm import LMDM


# Reference ``lmdm_cfg`` (it is forwarded as keyword arguments to ``LMDM``):
#
# lmdm_cfg = {
#     "model_path": "",
#     "device": "cuda",
#     "motion_feat_dim": 265,
#     "audio_feat_dim": 1024+35,
#     "seq_frames": 80,
# }


# Layout of a flat motion vector: (key, shape used in the dict form, width in
# the flat array).  Built once at import time instead of on every conversion,
# because ``Audio2Motion.cvt_fmt`` converts one frame per call.
_LP_MOTION_LAYOUT = (
    ('scale', (1, 1), 1),
    ('pitch', (1, 66), 66),
    ('yaw', (1, 66), 66),
    ('roll', (1, 66), 66),
    ('t', (1, 3), 3),
    ('exp', (1, 63), 63),
    ('kp', (1, 63), 63),
)

# Number of leading channels that ``Audio2Motion._smo`` smooths: the combined
# width of the scale/pitch/yaw/roll/t fields (1 + 66 * 3 + 3 = 202).  The
# 'exp' (and 'kp') channels that follow are left untouched.
_SMO_FEAT_DIM = sum(width for _, _, width in _LP_MOTION_LAYOUT[:5])


def _motion_dict_to_array(motion, ignore_keys):
    """Flatten a motion dict into one 1-D array following ``_LP_MOTION_LAYOUT``.

    Keys that are missing from ``motion`` or listed in ``ignore_keys`` are
    skipped.  The 'scale' entry is stored with 1 subtracted.
    """
    parts = []
    for key, _, width in _LP_MOTION_LAYOUT:
        if key not in motion or key in ignore_keys:
            continue
        value = motion[key].reshape(width)
        if key == 'scale':
            value = value - 1
        parts.append(value)
    return np.concatenate(parts, -1)


def _motion_array_to_dict(arr, ignore_keys):
    """Split a flat 1-D motion array back into a ``{key: array}`` dict.

    Fields are read in ``_LP_MOTION_LAYOUT`` order; keys in ``ignore_keys``
    are skipped without consuming any elements.  Reading stops as soon as the
    array has been fully consumed, so trailing fields that do not fit are
    simply absent from the result.  The 'scale' entry gets 1 added back.
    """
    out = {}
    start = 0
    total = len(arr)
    for key, shape, width in _LP_MOTION_LAYOUT:
        if key in ignore_keys:
            continue
        value = arr[start:start + width].reshape(shape)
        if key == 'scale':
            value = value + 1
        out[key] = value
        start += width
        if start >= total:
            break
    return out


def _cvt_LP_motion_info(inp, mode, ignore_keys=()):
    """Convert motion info between its dict form and its flat-array form.

    Args:
        inp: for ``mode='dic2arr'`` a dict of arrays keyed by the names in
            ``_LP_MOTION_LAYOUT``; for ``mode='arr2dic'`` a 1-D array with at
            least 265 elements.
        mode: ``'dic2arr'`` or ``'arr2dic'``.
        ignore_keys: keys to leave out of the conversion.

    Returns:
        ``'dic2arr'``: a 1-D array of the concatenated fields.
        ``'arr2dic'``: a dict mapping each key to an array of shape
        ``(1, width)``.

    Raises:
        AssertionError: if ``inp`` does not match the requested mode.
        ValueError: if ``mode`` is not one of the two supported values.
    """
    if mode == 'dic2arr':
        assert isinstance(inp, dict)
        return _motion_dict_to_array(inp, ignore_keys)   # (dim)
    if mode == 'arr2dic':
        # 265 = scale + pitch + yaw + roll + t + exp (1 + 66 * 3 + 3 + 63).
        assert inp.shape[0] >= 265, f"{inp.shape}"
        return _motion_array_to_dict(inp, ignore_keys)   # {k: (1, dim)}
    raise ValueError(f"unknown mode {mode!r}; expected 'dic2arr' or 'arr2dic'")


class Audio2Motion:
    """Generate a motion sequence clip by clip from audio conditions.

    Each call runs the wrapped ``LMDM`` model on one window of audio features,
    cross-fades the new clip with the tail of the sequence produced so far,
    smooths the seam, and updates the condition used for the next clip.
    """

    def __init__(
        self,
        lmdm_cfg,
    ):
        """Build the underlying model; ``lmdm_cfg`` is passed as ``LMDM(**lmdm_cfg)``."""
        self.lmdm = LMDM(**lmdm_cfg)

    def setup(
        self,
        x_s_info,
        overlap_v2=10,
        fix_kp_cond=0,
        fix_kp_cond_dim=None,
        sampling_timesteps=50,
        online_mode=False,
        v_min_max_for_clip=None,
        smo_k_d=3,
    ):
        """Reset the per-sequence state before generating a new sequence.

        Args:
            x_s_info: source motion dict; flattened (without 'kp') into the
                initial condition.
            overlap_v2: number of frames consecutive clips overlap by.
            fix_kp_cond: ``0`` makes the condition always follow the latest
                result; ``N > 0`` resets it to the source condition every
                ``N`` clips; a negative value leaves it unchanged after setup.
            fix_kp_cond_dim: optional ``(start, end)`` channel range that is
                still taken from the latest result when the condition is reset.
            sampling_timesteps: forwarded to ``lmdm.setup`` and to every model
                call.
            online_mode: if true, the cross-fade length is capped at the
                number of new frames per clip.
            v_min_max_for_clip: optional pair of arrays (min, max) used by
                ``cvt_fmt`` to clip the output values.
            smo_k_d: smoothing window size; values ``<= 1`` disable smoothing.
        """
        self.smo_k_d = smo_k_d
        self.overlap_v2 = overlap_v2
        self.seq_frames = self.lmdm.seq_frames
        # Number of new frames each clip after the first contributes.
        self.valid_clip_len = self.seq_frames - self.overlap_v2

        # for fuse
        self.online_mode = online_mode
        if self.online_mode:
            self.fuse_length = min(self.overlap_v2, self.valid_clip_len)
        else:
            self.fuse_length = self.overlap_v2
        # Cross-fade weights 0, 1/L, ..., (L-1)/L, shaped (1, L, 1) so they
        # broadcast over a [1, L, dim] slice.
        self.fuse_alpha = np.arange(self.fuse_length, dtype=np.float32).reshape(1, -1, 1) / self.fuse_length

        self.fix_kp_cond = fix_kp_cond
        self.fix_kp_cond_dim = fix_kp_cond_dim
        self.sampling_timesteps = sampling_timesteps

        self.v_min_max_for_clip = v_min_max_for_clip
        if self.v_min_max_for_clip is not None:
            # ``[None]`` prepends an axis so the bounds broadcast over frames.
            self.v_min = self.v_min_max_for_clip[0][None]
            self.v_max = self.v_min_max_for_clip[1][None]

        kp_source = _cvt_LP_motion_info(x_s_info, mode='dic2arr', ignore_keys={'kp'})[None]
        self.s_kp_cond = kp_source.copy().reshape(1, -1)
        self.kp_cond = self.s_kp_cond.copy()

        self.lmdm.setup(sampling_timesteps)

        self.clip_idx = 0

    def _fuse(self, res_kp_seq, pred_kp_seq):
        """Cross-fade ``pred_kp_seq`` into the tail of ``res_kp_seq`` and append the rest.

        The last ``fuse_length`` frames of ``res_kp_seq`` are blended with the
        ``fuse_length`` frames of ``pred_kp_seq`` that end at the overlap
        boundary; the weight moves linearly from the previous clip to the
        current one.  The frames of ``pred_kp_seq`` after the overlap are then
        appended.

        Offline mode (``fuse_length == overlap_v2``), whole overlap is blended:
            last clip:  -------
            fuse part:    *****
            curr clip:    -------

        Online mode (``fuse_length == min(overlap_v2, valid_clip_len)``), only
        the end of the overlap is blended:
            last clip:  -------
            fuse part:       **
            curr clip:    -------

        Note: the blended frames are written into ``res_kp_seq`` in place
        before the new, longer array is built and returned.
        """
        fuse_r1_s = res_kp_seq.shape[1] - self.fuse_length
        fuse_r1_e = res_kp_seq.shape[1]
        fuse_r2_s = self.seq_frames - self.valid_clip_len - self.fuse_length
        fuse_r2_e = self.seq_frames - self.valid_clip_len

        r1 = res_kp_seq[:, fuse_r1_s:fuse_r1_e]     # [1, fuse_len, dim]
        r2 = pred_kp_seq[:, fuse_r2_s: fuse_r2_e]     # [1, fuse_len, dim]
        r_fuse = r1 * (1 - self.fuse_alpha) + r2 * self.fuse_alpha

        res_kp_seq[:, fuse_r1_s:fuse_r1_e] = r_fuse    # fuse last
        res_kp_seq = np.concatenate([res_kp_seq, pred_kp_seq[:, fuse_r2_e:]], 1)  # len(res_kp_seq) + valid_clip_len

        return res_kp_seq

    def _update_kp_cond(self, res_kp_seq, idx):
        """Choose the condition for the next clip from frame ``idx - 1`` of the result.

        Note: when the condition is taken from ``res_kp_seq`` it is a slice of
        that array, not a copy.
        """
        if self.fix_kp_cond == 0:  # never reset: always follow the latest result
            self.kp_cond = res_kp_seq[:, idx-1]
        elif self.fix_kp_cond > 0:
            if self.clip_idx % self.fix_kp_cond == 0:  # periodic reset
                self.kp_cond = self.s_kp_cond.copy()  # reset every channel to the source
                if self.fix_kp_cond_dim is not None:
                    # ...except this channel range, which keeps following the result.
                    ds, de = self.fix_kp_cond_dim
                    self.kp_cond[:, ds:de] = res_kp_seq[:, idx-1, ds:de]
            else:
                self.kp_cond = res_kp_seq[:, idx-1]

    def _smo(self, res_kp_seq, s, e):
        """Smooth frames ``[s, e)`` with a centred moving average, in place.

        Only the first ``_SMO_FEAT_DIM`` channels are smoothed.  Averages are
        computed from an unmodified snapshot so already-smoothed frames do not
        feed into their neighbours.  The window is truncated at the sequence
        boundaries.  Returns ``res_kp_seq`` itself.
        """
        if self.smo_k_d <= 1:
            return res_kp_seq

        snapshot = res_kp_seq.copy()
        n = res_kp_seq.shape[1]
        half_k = self.smo_k_d // 2
        for i in range(s, e):
            ss = max(0, i - half_k)
            ee = min(n, i + half_k + 1)
            res_kp_seq[:, i, :_SMO_FEAT_DIM] = np.mean(snapshot[:, ss:ee, :_SMO_FEAT_DIM], axis=1)
        return res_kp_seq

    def __call__(self, aud_cond, res_kp_seq=None):
        """Generate the next clip and merge it into the running sequence.

        Args:
            aud_cond: audio condition for one clip, (1, seq_frames, dim).
            res_kp_seq: sequence produced so far, or ``None`` for the first
                clip.  When given, its tail is modified in place by the
                cross-fade.

        Returns:
            The updated sequence: the first clip as-is (smoothed), or the
            previous sequence extended by ``valid_clip_len`` frames.
        """
        pred_kp_seq = self.lmdm(self.kp_cond, aud_cond, self.sampling_timesteps)
        if res_kp_seq is None:
            res_kp_seq = pred_kp_seq   # [1, seq_frames, dim]
            res_kp_seq = self._smo(res_kp_seq, 0, res_kp_seq.shape[1])
        else:
            res_kp_seq = self._fuse(res_kp_seq, pred_kp_seq)  # len(res_kp_seq) + valid_clip_len
            # Smooth the cross-faded frames plus the first newly appended one.
            seam_end = res_kp_seq.shape[1] - self.valid_clip_len
            res_kp_seq = self._smo(res_kp_seq, seam_end - self.fuse_length, seam_end + 1)

        self.clip_idx += 1

        idx = res_kp_seq.shape[1] - self.overlap_v2
        self._update_kp_cond(res_kp_seq, idx)

        return res_kp_seq

    def cvt_fmt(self, res_kp_seq):
        """Convert a ``[1, n, dim]`` sequence into a list of ``n`` per-frame motion dicts.

        Values are clipped to ``[v_min, v_max]`` first when clipping bounds
        were given to ``setup``.
        """
        # res_kp_seq: [1, n, dim]
        if self.v_min_max_for_clip is not None:
            tmp_res_kp_seq = np.clip(res_kp_seq[0], self.v_min, self.v_max)
        else:
            tmp_res_kp_seq = res_kp_seq[0]

        # One dict per frame: {k: (1, dim)}
        return [_cvt_LP_motion_info(frame, 'arr2dic') for frame in tmp_res_kp_seq]