oKen38461's picture
初回コミットに基づくファイルの追加
ac7cda5
raw
history blame
5.98 kB
import numpy as np
from scipy.special import softmax
import copy
def _get_emo_avg(idx=6):
emo_avg = np.zeros(8, dtype=np.float32)
if isinstance(idx, (list, tuple)):
for i in idx:
emo_avg[i] = 8
else:
emo_avg[idx] = 8
emo_avg = softmax(emo_avg)
#emo_avg = None
# 'Angry', 'Disgust', 'Fear', 'Happy', 'Neutral', 'Sad', 'Surprise', 'Contempt'
return emo_avg
def _mirror_index(index, size):
turn = index // size
res = index % size
if turn % 2 == 0:
return res
else:
return size - res - 1
class ConditionHandler:
"""
aud_feat, emo_seq, eye_seq, sc_seq -> cond_seq
"""
def __init__(
self,
use_emo=True,
use_sc=True,
use_eye_open=True,
use_eye_ball=True,
seq_frames=80,
):
self.use_emo = use_emo
self.use_sc = use_sc
self.use_eye_open = use_eye_open
self.use_eye_ball = use_eye_ball
self.seq_frames = seq_frames
def setup(self, setup_info, emo, eye_f0_mode=False, ch_info=None):
"""
emo: int | [int] | [[int]] | numpy
"""
if ch_info is None:
source_info = copy.deepcopy(setup_info)
else:
source_info = ch_info
self.eye_f0_mode = eye_f0_mode
self.x_s_info_0 = source_info['x_s_info_lst'][0]
if self.use_sc:
self.sc = source_info["sc"] # 63
self.sc_seq = np.stack([self.sc] * self.seq_frames, 0)
if self.use_eye_open:
self.eye_open_lst = np.concatenate(source_info["eye_open_lst"], 0) # [n, 2]
self.num_eye_open = len(self.eye_open_lst)
if self.num_eye_open == 1 or self.eye_f0_mode:
self.eye_open_seq = np.stack([self.eye_open_lst[0]] * self.seq_frames, 0)
else:
self.eye_open_seq = None
if self.use_eye_ball:
self.eye_ball_lst = np.concatenate(source_info["eye_ball_lst"], 0) # [n, 6]
self.num_eye_ball = len(self.eye_ball_lst)
if self.num_eye_ball == 1 or self.eye_f0_mode:
self.eye_ball_seq = np.stack([self.eye_ball_lst[0]] * self.seq_frames, 0)
else:
self.eye_ball_seq = None
if self.use_emo:
self.emo_lst = self._parse_emo_seq(emo)
self.num_emo = len(self.emo_lst)
if self.num_emo == 1:
self.emo_seq = np.concatenate([self.emo_lst] * self.seq_frames, 0)
else:
self.emo_seq = None
@staticmethod
def _parse_emo_seq(emo, seq_len=-1):
if isinstance(emo, np.ndarray) and emo.ndim == 2 and emo.shape[1] == 8:
# emo arr, e.g. real
emo_seq = emo # [m, 8]
elif isinstance(emo, int) and 0 <= emo < 8:
# emo label, e.g. 4
emo_seq = _get_emo_avg(emo).reshape(1, 8) # [1, 8]
elif isinstance(emo, (list, tuple)) and 0 < len(emo) < 8 and isinstance(emo[0], int):
# emo labels, e.g. [3,4]
emo_seq = _get_emo_avg(emo).reshape(1, 8) # [1, 8]
elif isinstance(emo, list) and emo and isinstance(emo[0], (list, tuple)):
# emo label list, e.g. [[4], [3,4], [3],[3,4,5], ...]
emo_seq = np.stack([_get_emo_avg(i) for i in emo], 0) # [m, 8]
else:
raise ValueError(f"Unsupported emo type: {emo}")
if seq_len > 0:
if len(emo_seq) == seq_len:
return emo_seq
elif len(emo_seq) == 1:
return np.concatenate([emo_seq] * seq_len, 0)
elif len(emo_seq) > seq_len:
return emo_seq[:seq_len]
else:
raise ValueError(f"emo len {len(emo_seq)} can not match seq len ({seq_len})")
else:
return emo_seq
def __call__(self, aud_feat, idx, emo=None):
"""
aud_feat: [n, 1024]
idx: int, <0 means pad (first clip buffer)
"""
frame_num = len(aud_feat)
more_cond = [aud_feat]
if self.use_emo:
if emo is not None:
emo_seq = self._parse_emo_seq(emo, frame_num)
elif self.emo_seq is not None and len(self.emo_seq) == frame_num:
emo_seq = self.emo_seq
else:
emo_idx_list = [max(i, 0) % self.num_emo for i in range(idx, idx + frame_num)]
emo_seq = self.emo_lst[emo_idx_list]
more_cond.append(emo_seq)
if self.use_eye_open:
if self.eye_open_seq is not None and len(self.eye_open_seq) == frame_num:
eye_open_seq = self.eye_open_seq
else:
if self.eye_f0_mode:
eye_idx_list = [0] * frame_num
else:
eye_idx_list = [_mirror_index(max(i, 0), self.num_eye_open) for i in range(idx, idx + frame_num)]
eye_open_seq = self.eye_open_lst[eye_idx_list]
more_cond.append(eye_open_seq)
if self.use_eye_ball:
if self.eye_ball_seq is not None and len(self.eye_ball_seq) == frame_num:
eye_ball_seq = self.eye_ball_seq
else:
if self.eye_f0_mode:
eye_idx_list = [0] * frame_num
else:
eye_idx_list = [_mirror_index(max(i, 0), self.num_eye_ball) for i in range(idx, idx + frame_num)]
eye_ball_seq = self.eye_ball_lst[eye_idx_list]
more_cond.append(eye_ball_seq)
if self.use_sc:
if len(self.sc_seq) == frame_num:
sc_seq = self.sc_seq
else:
sc_seq = np.stack([self.sc] * frame_num, 0)
more_cond.append(sc_seq)
if len(more_cond) > 1:
cond_seq = np.concatenate(more_cond, -1) # [n, dim_cond]
else:
cond_seq = aud_feat
return cond_seq