oKen38461's picture
初回コミットに基づくファイルの追加
ac7cda5
raw
history blame
14.2 kB
import copy
import random
import numpy as np
from scipy.special import softmax
from ..models.stitch_network import StitchNetwork
"""
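Usage notes for MotionStitch.

# __init__
stitch_network_cfg = {
    "model_path": "",
    "device": "cuda",
}

# __call__
optional kwargs (defaults may also come from setup(overall_ctrl_info=...)):
    fade_alpha     -- weight kept on the driving motion when fading
                      (1 = no fade, 0 = fully the fade target); only used
                      when setup(fade_type=...) is "d0" or "s"
    fade_out_keys  -- info keys to fade; defaults to setup(fade_out_keys=...)
    vad_alpha      -- when < 1, blends the expression toward the source
    delta_pitch, delta_yaw, delta_roll -- offsets in degrees added to the pose
    alpha_pitch, alpha_yaw, alpha_roll -- multipliers applied to the pose angle
    delta_exp      -- offset added to the expression coefficients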
"""
def ctrl_motion(x_d_info, **kwargs):
    """Apply user pose/expression controls to driving motion info.

    Recognised keyword arguments (all others are ignored):
        delta_pitch / delta_yaw / delta_roll: offset in degrees added to the
            corresponding pose angle.
        alpha_pitch / alpha_yaw / alpha_roll: factor multiplied into the
            corresponding pose angle in degrees.
        delta_exp: offset added to ``x_d_info["exp"]``.

    Pose entries may be 66-bin logits or degrees. They are converted with
    ``bin66_to_degree`` before being adjusted; that call is a no-op for
    input that is not 66-bin.

    All offsets are applied before any multipliers.

    Returns:
        The same ``x_d_info`` dict, updated in place.
    """
    axes = ("pitch", "yaw", "roll")
    # pose + offset
    for axis in axes:
        delta_key = "delta_" + axis
        if delta_key in kwargs:
            x_d_info[axis] = bin66_to_degree(x_d_info[axis]) + kwargs[delta_key]
    # pose * alpha
    for axis in axes:
        alpha_key = "alpha_" + axis
        if alpha_key in kwargs:
            # Convert first: multiplying 66-bin logits would only sharpen or
            # flatten the softmax, not scale the angle itself.
            x_d_info[axis] = bin66_to_degree(x_d_info[axis]) * kwargs[alpha_key]
    # exp + offset
    if "delta_exp" in kwargs:
        x_d_info["exp"] = x_d_info["exp"] + kwargs["delta_exp"]
    return x_d_info
def fade(x_d_info, dst, alpha, keys=None):
    """Linearly blend entries of ``x_d_info`` toward ``dst`` in place.

    Each selected key becomes ``x_d_info[k] * alpha + dst[k] * (1 - alpha)``.
    When ``keys`` is None, every key of ``x_d_info`` is selected. The
    ``'kp'`` entry is never blended.

    Returns:
        The same ``x_d_info`` dict.
    """
    selected = x_d_info.keys() if keys is None else keys
    for name in selected:
        if name != 'kp':
            x_d_info[name] = x_d_info[name] * alpha + dst[name] * (1 - alpha)
    return x_d_info
def ctrl_vad(x_d_info, dst, alpha):
    """Blend the driving expression toward ``dst["exp"]``.

    ``x_d_info["exp"]`` becomes ``exp * alpha + dst["exp"] * (1 - alpha)``.
    So ``alpha=1`` keeps the driving expression and ``alpha=0`` copies
    ``dst``.

    Returns:
        The same ``x_d_info`` dict, updated in place.
    """
    # NOTE(review): the whole expression vector is blended, not only the lip
    # coefficients. A lip-only mask used to be built here but was never
    # applied; it has been dropped without changing behavior.
    exp = x_d_info["exp"]
    exp_dst = dst["exp"]
    x_d_info["exp"] = exp * alpha + exp_dst * (1 - alpha)
    return x_d_info
def _mix_s_d_info(
x_s_info,
x_d_info,
use_d_keys=("exp", "pitch", "yaw", "roll", "t"),
d0=None,
):
if d0 is not None:
if isinstance(use_d_keys, dict):
x_d_info = {
k: x_s_info[k] + (v - d0[k]) * use_d_keys.get(k, 1)
for k, v in x_d_info.items()
}
else:
x_d_info = {k: x_s_info[k] + (v - d0[k]) for k, v in x_d_info.items()}
for k, v in x_s_info.items():
if k not in x_d_info or k not in use_d_keys:
x_d_info[k] = v
if isinstance(use_d_keys, dict) and d0 is None:
for k, alpha in use_d_keys.items():
x_d_info[k] *= alpha
return x_d_info
def _set_eye_blink_idx(N, blink_n=15, open_n=-1):
"""
open_n:
-1: no blink
0: random open_n
>0: fix open_n
list: loop open_n
"""
OPEN_MIN = 60
OPEN_MAX = 100
idx = [0] * N
if isinstance(open_n, int):
if open_n < 0: # no blink
return idx
elif open_n > 0: # fix open_n
open_ns = [open_n]
else: # open_n == 0: # random open_n, 60-100
open_ns = []
elif isinstance(open_n, list):
open_ns = open_n # loop open_n
else:
raise ValueError()
blink_idx = list(range(blink_n))
start_n = open_ns[0] if open_ns else random.randint(OPEN_MIN, OPEN_MAX)
end_n = open_ns[-1] if open_ns else random.randint(OPEN_MIN, OPEN_MAX)
max_i = N - max(end_n, blink_n)
cur_i = start_n
cur_n_i = 1
while cur_i < max_i:
idx[cur_i : cur_i + blink_n] = blink_idx
if open_ns:
cur_n = open_ns[cur_n_i % len(open_ns)]
cur_n_i += 1
else:
cur_n = random.randint(OPEN_MIN, OPEN_MAX)
cur_i = cur_i + blink_n + cur_n
return idx
def _fix_exp_for_x_d_info(x_d_info, x_s_info, delta_eye=None, drive_eye=True):
_eye = [11, 13, 15, 16, 18]
_lip = [6, 12, 14, 17, 19, 20]
alpha = np.zeros((21, 3), dtype=x_d_info["exp"].dtype)
alpha[_lip] = 1
if delta_eye is None and drive_eye: # use d eye
alpha[_eye] = 1
alpha = alpha.reshape(1, -1)
x_d_info["exp"] = x_d_info["exp"] * alpha + x_s_info["exp"] * (1 - alpha)
if delta_eye is not None and drive_eye:
alpha = np.zeros((21, 3), dtype=x_d_info["exp"].dtype)
alpha[_eye] = 1
alpha = alpha.reshape(1, -1)
x_d_info["exp"] = (delta_eye + x_s_info["exp"]) * alpha + x_d_info["exp"] * (
1 - alpha
)
return x_d_info
def _fix_exp_for_x_d_info_v2(x_d_info, x_s_info, delta_eye, a1, a2, a3):
x_d_info["exp"] = x_d_info["exp"] * a1 + x_s_info["exp"] * a2 + delta_eye * a3
return x_d_info
def bin66_to_degree(pred):
    """Convert 66-bin pose logits to degrees; pass anything else through.

    Input is treated as binned when ``pred.ndim > 1`` and ``pred.shape[1] == 66``.
    The degree is the softmax-weighted mean bin index, mapped to degrees by
    ``index * 3 - 97.5``.

    Returns:
        Array reduced over axis 1 for binned input. Otherwise ``pred``
        itself, unchanged.
    """
    is_binned = pred.ndim > 1 and pred.shape[1] == 66
    if not is_binned:
        return pred
    probs = softmax(pred, axis=1)
    bins = np.arange(66).astype(np.float32)
    expected_bin = np.sum(probs * bins, axis=1)
    return expected_bin * 3 - 97.5
def _eye_delta(exp, dx=0, dy=0):
if dx > 0:
exp[0, 33] += dx * 0.0007
exp[0, 45] += dx * 0.001
else:
exp[0, 33] += dx * 0.001
exp[0, 45] += dx * 0.0007
exp[0, 34] += dy * -0.001
exp[0, 46] += dy * -0.001
return exp
def _fix_gaze(pose_s, x_d_info):
    """Shift driving gaze coefficients by the head-pose difference to the source.

    Args:
        pose_s: ``(yaw, pitch)`` of the source, in degrees.
        x_d_info: driving info with ``'yaw'`` and ``'pitch'`` (66-bin logits
            or degrees) and ``'exp'``.

    Returns:
        ``x_d_info``. Its ``'exp'`` array is modified in place by
        ``_eye_delta``.
    """
    src_yaw, src_pitch = pose_s
    drv_yaw = bin66_to_degree(x_d_info['yaw']).item()
    drv_pitch = bin66_to_degree(x_d_info['pitch']).item()
    # Presumably hand-tuned gains mapping degrees of head rotation to gaze offset.
    dx = (drv_yaw - src_yaw) * 0.26
    dy = (drv_pitch - src_pitch) * 0.28
    x_d_info['exp'] = _eye_delta(x_d_info['exp'], dx, dy)
    return x_d_info
def get_rotation_matrix(pitch_, yaw_, roll_):
    """Build batched rotation matrices from Euler angles given in degrees.

    Args:
        pitch_, yaw_, roll_: angles in degrees. Each may be a ``(bs,)`` or
            ``(bs, 1)`` array. Python scalars and 0-d arrays are accepted and
            treated as a batch of one.

    Returns:
        ``(bs, 3, 3)`` array holding ``(Rz(roll) @ Ry(yaw) @ Rx(pitch))^T``.
        The transpose lets row-vector points be rotated as ``points @ R``.
    """
    def _column_radians(angle_deg):
        # Degrees -> radians, shaped (bs, 1) for the column-wise concatenation below.
        rad = np.asarray(angle_deg) / 180 * np.pi
        if rad.ndim == 0:
            rad = np.reshape(rad, (1, 1))
        elif rad.ndim == 1:
            rad = rad[:, None]
        return rad

    x = _column_radians(pitch_)
    y = _column_radians(yaw_)
    z = _column_radians(roll_)

    # calculate the euler matrix
    bs = x.shape[0]
    ones = np.ones((bs, 1), dtype=np.float32)
    zeros = np.zeros((bs, 1), dtype=np.float32)
    rot_x = np.concatenate([
        ones, zeros, zeros,
        zeros, np.cos(x), -np.sin(x),
        zeros, np.sin(x), np.cos(x)
    ], axis=1).reshape(bs, 3, 3)
    rot_y = np.concatenate([
        np.cos(y), zeros, np.sin(y),
        zeros, ones, zeros,
        -np.sin(y), zeros, np.cos(y)
    ], axis=1).reshape(bs, 3, 3)
    rot_z = np.concatenate([
        np.cos(z), -np.sin(z), zeros,
        np.sin(z), np.cos(z), zeros,
        zeros, zeros, ones
    ], axis=1).reshape(bs, 3, 3)
    rot = np.matmul(np.matmul(rot_z, rot_y), rot_x)
    return np.transpose(rot, (0, 2, 1))
def transform_keypoint(kp_info: dict):
    """Apply pose, expression and translation to canonical implicit keypoints.

    Implements Eqn. 2, ``s * (kp @ R + exp) + t``. Only the x/y components
    of ``t`` are applied; z translation is dropped.

    Args:
        kp_info: dict with:
            ``'kp'``: ``(bs, K, 3)`` or flattened ``(bs, K*3)``.
            ``'pitch'`` / ``'yaw'`` / ``'roll'``: 66-bin logits or degrees.
            ``'exp'``: reshapeable to ``(bs, K, 3)``.
            ``'scale'``: broadcast as ``scale[..., None]``.
            ``'t'``: indexed as ``t[:, 0:2]``.

    Returns:
        ``(bs, K, 3)`` transformed keypoints.
    """
    kp = kp_info['kp']
    batch = kp.shape[0]
    # Keypoints may arrive flattened as (bs, K*3).
    n_points = kp.shape[1] // 3 if kp.ndim == 2 else kp.shape[1]
    rotation = get_rotation_matrix(
        bin66_to_degree(kp_info['pitch']),
        bin66_to_degree(kp_info['yaw']),
        bin66_to_degree(kp_info['roll']),
    )  # (bs, 3, 3)
    expression = kp_info['exp'].reshape(batch, n_points, 3)
    out = np.matmul(kp.reshape(batch, n_points, 3), rotation) + expression
    out *= kp_info['scale'][..., None]  # (bs, k, 3) * (bs, 1, 1)
    out[:, :, 0:2] += kp_info['t'][:, None, 0:2]  # tx, ty only
    return out
class MotionStitch:
    """Turn source/driving motion info into stitched keypoints, frame by frame.

    Construct once with ``stitch_network_cfg``, which is forwarded as keyword
    arguments to ``StitchNetwork``. Call :meth:`setup` to configure a source.
    Then call the instance once per driving frame to get ``(x_s, x_d)``.
    """

    def __init__(
        self,
        stitch_network_cfg,
    ):
        self.stitch_net = StitchNetwork(**stitch_network_cfg)
        # Defaults for the state ``set_Nd`` reads, so it does not raise
        # AttributeError when called before ``setup``; ``setup`` overwrites them.
        self.N_d = -1
        self.drive_eye = False
        self.delta_eye_arr = None
        self.delta_eye_open_n = -1

    def _refresh_eye_blink_idx(self):
        """(Re)build ``delta_eye_idx_list`` when a fixed eye clip drives the eyes."""
        if self.drive_eye and self.delta_eye_arr is not None:
            # With an unknown driving length (N_d == -1) schedule 3000 frames;
            # __call__ wraps around the list with a modulo.
            N = 3000 if self.N_d == -1 else self.N_d
            self.delta_eye_idx_list = _set_eye_blink_idx(
                N, len(self.delta_eye_arr), self.delta_eye_open_n
            )

    @staticmethod
    def _yaw_pitch_degrees(x_info):
        """Return ``[yaw, pitch]`` of ``x_info`` in degrees as Python floats."""
        yaw = bin66_to_degree(x_info['yaw']).item()
        pitch = bin66_to_degree(x_info['pitch']).item()
        return [yaw, pitch]

    def set_Nd(self, N_d=-1):
        """Update the driving length and rebuild the blink schedule if it changed.

        Intended for offline use, so the schedule can keep the eyes open at
        the start and end of the clip.
        """
        if N_d == self.N_d:
            return
        self.N_d = N_d
        self._refresh_eye_blink_idx()

    def setup(
        self,
        N_d=-1,
        use_d_keys=None,
        relative_d=True,
        drive_eye=None,  # use d eye or s eye
        delta_eye_arr=None,  # fix eye
        delta_eye_open_n=-1,  # int|list
        fade_out_keys=("exp",),
        fade_type="",  # "" | "d0" | "s"
        flag_stitching=True,
        is_image_flag=True,
        x_s_info=None,
        d0=None,
        ch_info=None,
        overall_ctrl_info=None,
    ):
        """Configure per-source state; must be called before ``__call__``.

        Args:
            N_d: driving length used to size the blink schedule (-1: unknown).
            use_d_keys: keys taken from the driving info (tuple), or a dict of
                per-key weights. Defaults to exp+pose+t when
                ``is_image_flag``, otherwise exp only.
            relative_d: apply driving motion relative to a reference frame
                ``d0`` (captured from the first call when not given).
            drive_eye: eyes follow the driver; defaults to ``is_image_flag``.
            delta_eye_arr: optional clip of eye-expression deltas, indexed by
                the blink schedule.
            delta_eye_open_n: ``open_n`` for ``_set_eye_blink_idx``.
            fade_out_keys: default keys for fading.
            fade_type: "" (no fade), "d0" (first processed driving frame) or
                "s" (source info).
            flag_stitching: run the stitch network on the keypoints.
            is_image_flag: enables caching of source pose, keypoints and fade
                target across calls.
            x_s_info: optional source info to precompute caches from.
            d0: optional reference driving info for relative mode.
            ch_info: optional dict; ``ch_info['x_s_info_lst'][0]['scale']``
                is divided by the source scale to reweight exp/pose keys.
            overall_ctrl_info: default kwargs merged into every call.
        """
        self.is_image_flag = is_image_flag
        if use_d_keys is None:
            if self.is_image_flag:
                self.use_d_keys = ("exp", "pitch", "yaw", "roll", "t")
            else:
                self.use_d_keys = ("exp", )
        else:
            self.use_d_keys = use_d_keys
        self.drive_eye = bool(self.is_image_flag) if drive_eye is None else drive_eye
        self.N_d = N_d
        self.relative_d = relative_d
        self.delta_eye_arr = delta_eye_arr
        self.delta_eye_open_n = delta_eye_open_n
        self.fade_out_keys = fade_out_keys
        self.fade_type = fade_type
        self.flag_stitching = flag_stitching

        # Weights for _fix_exp_for_x_d_info_v2:
        #   exp = d_exp * a1 + s_exp * a2 + delta_eye * a3
        # Lip rows follow the driver.
        # Eye rows follow the driver, become source + delta_eye, or stay on
        # the source, depending on drive_eye / delta_eye_arr.
        # All other rows stay on the source.
        _eye = [11, 13, 15, 16, 18]
        _lip = [6, 12, 14, 17, 19, 20]
        _a1 = np.zeros((21, 3), dtype=np.float32)
        _a1[_lip] = 1
        _a2 = 0
        if self.drive_eye:
            if self.delta_eye_arr is None:
                _a1[_eye] = 1
            else:
                _a2 = np.zeros((21, 3), dtype=np.float32)
                _a2[_eye] = 1
                _a2 = _a2.reshape(1, -1)
        _a1 = _a1.reshape(1, -1)
        self.fix_exp_a1 = _a1 * (1 - _a2)
        self.fix_exp_a2 = (1 - _a1) + _a1 * _a2
        self.fix_exp_a3 = _a2
        self._refresh_eye_blink_idx()

        # Per-source caches: filled here for images when x_s_info is given,
        # otherwise lazily in __call__.
        self.pose_s = None
        self.x_s = None
        self.fade_dst = None
        if self.is_image_flag and x_s_info is not None:
            self.pose_s = self._yaw_pitch_degrees(x_s_info)
            self.x_s = transform_keypoint(x_s_info)
            if self.fade_type == "s":
                self.fade_dst = copy.deepcopy(x_s_info)

        if ch_info is not None:
            self.scale_a = ch_info['x_s_info_lst'][0]['scale'].item()
            if x_s_info is not None:
                self.scale_b = x_s_info['scale'].item()
                self.scale_ratio = self.scale_a / self.scale_b
                self._set_scale_ratio(self.scale_ratio)
            else:
                # Deferred until the first __call__ provides x_s_info.
                self.scale_ratio = None
        else:
            self.scale_ratio = 1
        self.overall_ctrl_info = overall_ctrl_info
        self.d0 = d0
        self.idx = 0

    def _set_scale_ratio(self, scale_ratio=1):
        """Fold ``scale_ratio`` into ``use_d_keys`` as weights for exp/pitch/yaw/roll.

        A tuple of keys is converted into a dict (other keys get weight 1).
        A ratio of exactly 1 leaves ``use_d_keys`` untouched.
        """
        if scale_ratio == 1:
            return
        if isinstance(self.use_d_keys, dict):
            self.use_d_keys = {k: v * (scale_ratio if k in {"exp", "pitch", "yaw", "roll"} else 1) for k, v in self.use_d_keys.items()}
        else:
            self.use_d_keys = {k: scale_ratio if k in {"exp", "pitch", "yaw", "roll"} else 1 for k in self.use_d_keys}

    @staticmethod
    def _merge_kwargs(default_kwargs, run_kwargs):
        """Fill keys missing from ``run_kwargs`` from ``default_kwargs`` (in place)."""
        if default_kwargs is None:
            return run_kwargs
        for k, v in default_kwargs.items():
            if k not in run_kwargs:
                run_kwargs[k] = v
        return run_kwargs

    def __call__(self, x_s_info, x_d_info, **kwargs):
        """Process one driving frame.

        Args:
            x_s_info: source motion info dict.
            x_d_info: driving motion info dict for this frame.
            **kwargs: per-frame controls (fade_alpha, fade_out_keys,
                vad_alpha, delta_*/alpha_* pose controls, delta_exp). Missing
                keys are filled from ``overall_ctrl_info``.

        Returns:
            ``(x_s, x_d)``: transformed source keypoints and driving
            keypoints, the latter passed through the stitch network when
            ``flag_stitching``.
        """
        kwargs = self._merge_kwargs(self.overall_ctrl_info, kwargs)
        if self.scale_ratio is None:
            # Deferred from setup(): ch_info was given without x_s_info.
            self.scale_b = x_s_info['scale'].item()
            self.scale_ratio = self.scale_a / self.scale_b
            self._set_scale_ratio(self.scale_ratio)
        if self.relative_d and self.d0 is None:
            # The first driving frame becomes the reference for relative motion.
            self.d0 = copy.deepcopy(x_d_info)
        x_d_info = _mix_s_d_info(
            x_s_info,
            x_d_info,
            self.use_d_keys,
            self.d0,
        )

        delta_eye = 0
        if self.drive_eye and self.delta_eye_arr is not None:
            # The blink schedule wraps around when frames outnumber entries.
            blink_i = self.delta_eye_idx_list[self.idx % len(self.delta_eye_idx_list)]
            delta_eye = self.delta_eye_arr[blink_i][None]
        x_d_info = _fix_exp_for_x_d_info_v2(
            x_d_info,
            x_s_info,
            delta_eye,
            self.fix_exp_a1,
            self.fix_exp_a2,
            self.fix_exp_a3,
        )

        vad_alpha = kwargs.get("vad_alpha", 1)
        if vad_alpha < 1:
            x_d_info = ctrl_vad(x_d_info, x_s_info, vad_alpha)
        x_d_info = ctrl_motion(x_d_info, **kwargs)

        if self.fade_type == "d0" and self.fade_dst is None:
            # "d0" fades toward the first fully processed driving frame.
            self.fade_dst = copy.deepcopy(x_d_info)
        if "fade_alpha" in kwargs and self.fade_type in ["d0", "s"]:
            fade_alpha = kwargs["fade_alpha"]
            fade_keys = kwargs.get("fade_out_keys", self.fade_out_keys)
            if self.fade_type == "d0":
                fade_dst = self.fade_dst
            elif self.fade_type == "s":
                if self.fade_dst is not None:
                    fade_dst = self.fade_dst
                else:
                    fade_dst = copy.deepcopy(x_s_info)
                    if self.is_image_flag:
                        self.fade_dst = fade_dst
            x_d_info = fade(x_d_info, fade_dst, fade_alpha, fade_keys)

        if self.drive_eye:
            if self.pose_s is None:
                self.pose_s = self._yaw_pitch_degrees(x_s_info)
            x_d_info = _fix_gaze(self.pose_s, x_d_info)

        if self.x_s is not None:
            x_s = self.x_s
        else:
            x_s = transform_keypoint(x_s_info)
            if self.is_image_flag:
                self.x_s = x_s
        x_d = transform_keypoint(x_d_info)
        if self.flag_stitching:
            x_d = self.stitch_net(x_s, x_d)
        self.idx += 1
        return x_s, x_d