# Source: Hugging Face repo file (commit ac7cda5, uploader oKen38461).
# Original commit message: "初回コミットに基づくファイルの追加"
# ("file added based on the initial commit").
import numpy as np
from ..models.lmdm import LMDM
"""
lmdm_cfg = {
"model_path": "",
"device": "cuda",
"motion_feat_dim": 265,
"audio_feat_dim": 1024+35,
"seq_frames": 80,
}
"""
def _cvt_LP_motion_info(inp, mode, ignore_keys=()):
ks_shape_map = [
['scale', (1, 1), 1],
['pitch', (1, 66), 66],
['yaw', (1, 66), 66],
['roll', (1, 66), 66],
['t', (1, 3), 3],
['exp', (1, 63), 63],
['kp', (1, 63), 63],
]
def _dic2arr(_dic):
arr = []
for k, _, ds in ks_shape_map:
if k not in _dic or k in ignore_keys:
continue
v = _dic[k].reshape(ds)
if k == 'scale':
v = v - 1
arr.append(v)
arr = np.concatenate(arr, -1) # (133)
return arr
def _arr2dic(_arr):
dic = {}
s = 0
for k, ds, ss in ks_shape_map:
if k in ignore_keys:
continue
v = _arr[s:s + ss].reshape(ds)
if k == 'scale':
v = v + 1
dic[k] = v
s += ss
if s >= len(_arr):
break
return dic
if mode == 'dic2arr':
assert isinstance(inp, dict)
return _dic2arr(inp) # (dim)
elif mode == 'arr2dic':
assert inp.shape[0] >= 265, f"{inp.shape}"
return _arr2dic(inp) # {k: (1, dim)}
else:
raise ValueError()
class Audio2Motion:
    """Streams audio-conditioned motion latents from an LMDM diffusion model.

    Motion is generated clip by clip (``seq_frames`` frames per call), with
    consecutive clips sharing ``overlap_v2`` frames.  Overlaps are linearly
    cross-faded and lightly smoothed so the concatenated sequence has no
    seams.  The instance is stateful: call :meth:`setup` once per sequence,
    then :meth:`__call__` once per audio clip, and finally :meth:`cvt_fmt`
    to convert the result into per-frame motion dicts.
    """
    def __init__(
        self,
        lmdm_cfg,
    ):
        # lmdm_cfg is unpacked as keyword arguments for LMDM; see the
        # example config dict quoted at the top of this file.
        self.lmdm = LMDM(**lmdm_cfg)
    def setup(
        self,
        x_s_info,
        overlap_v2=10,
        fix_kp_cond=0,
        fix_kp_cond_dim=None,
        sampling_timesteps=50,
        online_mode=False,
        v_min_max_for_clip=None,
        smo_k_d=3,
    ):
        """Reset per-sequence state before streaming a new sequence.

        Args:
            x_s_info: source-frame motion dict consumed by
                _cvt_LP_motion_info (its 'kp' entry is ignored here).
            overlap_v2: number of frames shared between consecutive clips.
            fix_kp_cond: 0 = always condition the next clip on the last
                generated frame; N > 0 = every N clips, reset the condition
                back to the source keypoints.
            fix_kp_cond_dim: optional (start, end) dim slice that keeps
                tracking the generated motion even across a reset.
            sampling_timesteps: diffusion sampling steps passed to the LMDM.
            online_mode: True = cross-fade over min(overlap, valid_clip_len)
                frames (low-latency streaming); False = fade over the whole
                overlap.
            v_min_max_for_clip: optional (min, max) array pair used by
                cvt_fmt to clip each output dimension.
            smo_k_d: temporal moving-average window size (<= 1 disables).
        """
        self.smo_k_d = smo_k_d
        self.overlap_v2 = overlap_v2
        self.seq_frames = self.lmdm.seq_frames
        # New (non-overlapping) frames contributed by each subsequent clip.
        self.valid_clip_len = self.seq_frames - self.overlap_v2
        # for fuse
        self.online_mode = online_mode
        if self.online_mode:
            self.fuse_length = min(self.overlap_v2, self.valid_clip_len)
        else:
            self.fuse_length = self.overlap_v2
        # Linear cross-fade ramp 0 -> 1 over fuse_length frames, shaped
        # (1, fuse_length, 1) to broadcast over [batch, time, dim].
        self.fuse_alpha = np.arange(self.fuse_length, dtype=np.float32).reshape(1, -1, 1) / self.fuse_length
        self.fix_kp_cond = fix_kp_cond
        self.fix_kp_cond_dim = fix_kp_cond_dim
        self.sampling_timesteps = sampling_timesteps
        self.v_min_max_for_clip = v_min_max_for_clip
        if self.v_min_max_for_clip is not None:
            self.v_min = self.v_min_max_for_clip[0][None]  # [dim, 1]
            self.v_max = self.v_min_max_for_clip[1][None]
        # Flattened source keypoints, kept both as the initial condition and
        # as the reset target used by _update_kp_cond.
        kp_source = _cvt_LP_motion_info(x_s_info, mode='dic2arr', ignore_keys={'kp'})[None]
        self.s_kp_cond = kp_source.copy().reshape(1, -1)
        self.kp_cond = self.s_kp_cond.copy()
        self.lmdm.setup(sampling_timesteps)
        # Number of clips generated so far (drives the periodic reset).
        self.clip_idx = 0
    def _fuse(self, res_kp_seq, pred_kp_seq):
        """Cross-fade pred_kp_seq onto the tail of res_kp_seq and append.

        Returns res_kp_seq extended by valid_clip_len frames; the last
        fuse_length existing frames are blended (via fuse_alpha) with the
        matching frames of the new clip.
        """
        ## ========================
        ## offline fuse mode
        ## last clip: -------
        ## fuse part: *****
        ## curr clip: -------
        ## output: ^^
        #
        ## online fuse mode
        ## last clip: -------
        ## fuse part: **
        ## curr clip: -------
        ## output: ^^
        ## ========================
        fuse_r1_s = res_kp_seq.shape[1] - self.fuse_length
        fuse_r1_e = res_kp_seq.shape[1]
        fuse_r2_s = self.seq_frames - self.valid_clip_len - self.fuse_length
        fuse_r2_e = self.seq_frames - self.valid_clip_len
        r1 = res_kp_seq[:, fuse_r1_s:fuse_r1_e]  # [1, fuse_len, dim]
        r2 = pred_kp_seq[:, fuse_r2_s: fuse_r2_e]  # [1, fuse_len, dim]
        # Blend: weight shifts linearly from the old clip to the new one.
        r_fuse = r1 * (1 - self.fuse_alpha) + r2 * self.fuse_alpha
        res_kp_seq[:, fuse_r1_s:fuse_r1_e] = r_fuse  # fuse last
        res_kp_seq = np.concatenate([res_kp_seq, pred_kp_seq[:, fuse_r2_e:]], 1)  # len(res_kp_seq) + valid_clip_len
        return res_kp_seq
    def _update_kp_cond(self, res_kp_seq, idx):
        # Choose the conditioning vector for the next clip.
        if self.fix_kp_cond == 0:  # never reset: always follow the last generated frame
            self.kp_cond = res_kp_seq[:, idx-1]
        elif self.fix_kp_cond > 0:
            if self.clip_idx % self.fix_kp_cond == 0:  # periodic reset
                self.kp_cond = self.s_kp_cond.copy()  # reset all dims to the source keypoints
                if self.fix_kp_cond_dim is not None:
                    # Keep the selected dim range tracking the generated motion.
                    ds, de = self.fix_kp_cond_dim
                    self.kp_cond[:, ds:de] = res_kp_seq[:, idx-1, ds:de]
            else:
                self.kp_cond = res_kp_seq[:, idx-1]
    def _smo(self, res_kp_seq, s, e):
        """In-place centered moving average over frames s..e-1.

        Only the first 202 channels are smoothed (presumably the
        scale/pose/translation part of the flat vector -- TODO confirm
        against the layout in _cvt_LP_motion_info); the tail is left
        untouched.  Returns the same array object.
        """
        if self.smo_k_d <= 1:
            return res_kp_seq
        # Average over an unmodified copy so the window is not affected by
        # frames already smoothed earlier in this pass.
        new_res_kp_seq = res_kp_seq.copy()
        n = res_kp_seq.shape[1]
        half_k = self.smo_k_d // 2
        for i in range(s, e):
            # Clamp the window at the sequence boundaries.
            ss = max(0, i - half_k)
            ee = min(n, i + half_k + 1)
            res_kp_seq[:, i, :202] = np.mean(new_res_kp_seq[:, ss:ee, :202], axis=1)
        return res_kp_seq
    def __call__(self, aud_cond, res_kp_seq=None):
        """Generate one clip and append it to the running sequence.

        aud_cond: (1, seq_frames, dim)
        res_kp_seq: running motion sequence from previous calls, or None
            for the first clip.

        Returns the extended motion sequence [1, n_frames, dim].
        """
        pred_kp_seq = self.lmdm(self.kp_cond, aud_cond, self.sampling_timesteps)
        if res_kp_seq is None:
            # First clip: keep it whole and smooth the full range.
            res_kp_seq = pred_kp_seq  # [1, seq_frames, dim]
            res_kp_seq = self._smo(res_kp_seq, 0, res_kp_seq.shape[1])
        else:
            res_kp_seq = self._fuse(res_kp_seq, pred_kp_seq)  # len(res_kp_seq) + valid_clip_len
            # Re-smooth only the frames around the newly fused seam.
            res_kp_seq = self._smo(res_kp_seq, res_kp_seq.shape[1] - self.valid_clip_len - self.fuse_length, res_kp_seq.shape[1] - self.valid_clip_len + 1)
        self.clip_idx += 1
        # Condition the next clip on the frame just before the overlap region.
        idx = res_kp_seq.shape[1] - self.overlap_v2
        self._update_kp_cond(res_kp_seq, idx)
        return res_kp_seq
    def cvt_fmt(self, res_kp_seq):
        """Convert a [1, n, dim] motion array into a list of per-frame dicts.

        Each dim is first clipped to v_min_max_for_clip when that was given
        to setup(); frames are then decoded via _cvt_LP_motion_info.
        """
        # res_kp_seq: [1, n, dim]
        if self.v_min_max_for_clip is not None:
            tmp_res_kp_seq = np.clip(res_kp_seq[0], self.v_min, self.v_max)
        else:
            tmp_res_kp_seq = res_kp_seq[0]
        x_d_info_list = []
        for i in range(tmp_res_kp_seq.shape[0]):
            x_d_info = _cvt_LP_motion_info(tmp_res_kp_seq[i], 'arr2dic')  # {k: (1, dim)}
            x_d_info_list.append(x_d_info)
        return x_d_info_list