Spaces:
Runtime error
Runtime error
import copy | |
import random | |
import numpy as np | |
from scipy.special import softmax | |
from ..models.stitch_network import StitchNetwork | |
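"""
# __init__
stitch_network_cfg = {
    "model_path": "",
    "device": "cuda",
}

# __call__
kwargs (all optional; defaults can also come from setup(overall_ctrl_info=...)):
    fade_alpha      weight kept on the driving motion when fading; (1 - fade_alpha)
                    goes to the fade target (only used when fade_type is "d0" or "s")
    fade_out_keys   keys to fade; defaults to the fade_out_keys given to setup()
    delta_pitch     offset added to pitch, in degrees
    delta_yaw       offset added to yaw, in degrees
    delta_roll      offset added to roll, in degrees
    alpha_pitch     multiplier applied to pitch
    alpha_yaw       multiplier applied to yaw
    alpha_roll      multiplier applied to roll
    delta_exp       offset added to the expression vector
    vad_alpha       if < 1, blends the expression toward the source expression
"""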
def ctrl_motion(x_d_info, **kwargs):
    """Apply optional user pose/expression controls to ``x_d_info`` in place.

    Recognised keyword arguments:
        delta_pitch / delta_yaw / delta_roll: added to the angle after it is
            converted with ``bin66_to_degree``.
        alpha_pitch / alpha_yaw / alpha_roll: multiplied into the stored value
            as-is (no degree conversion).
        delta_exp: added to ``x_d_info["exp"]``.

    All offsets are applied before any multipliers. Returns the same dict.
    """
    axes = ("pitch", "yaw", "roll")

    # Offsets: convert to degrees first, then shift.
    for axis in axes:
        offset_key = "delta_" + axis
        if offset_key in kwargs:
            x_d_info[axis] = bin66_to_degree(x_d_info[axis]) + kwargs[offset_key]

    # Multipliers. NOTE(review): no bin66_to_degree here, so if the value is
    # still 66-bin logits this rescales the logits (changing the softmax
    # sharpness) rather than the angle itself -- confirm this is intended.
    for axis in axes:
        scale_key = "alpha_" + axis
        if scale_key in kwargs:
            x_d_info[axis] = x_d_info[axis] * kwargs[scale_key]

    if "delta_exp" in kwargs:
        x_d_info["exp"] = x_d_info["exp"] + kwargs["delta_exp"]

    return x_d_info
def fade(x_d_info, dst, alpha, keys=None):
    """Linearly blend entries of ``x_d_info`` toward ``dst`` in place.

    Each selected key becomes ``x_d_info[k] * alpha + dst[k] * (1 - alpha)``.
    The ``"kp"`` entry is never blended. ``keys`` defaults to every key of
    ``x_d_info``. Returns the same dict.
    """
    targets = x_d_info.keys() if keys is None else keys
    for name in targets:
        if name != 'kp':
            x_d_info[name] = x_d_info[name] * alpha + dst[name] * (1 - alpha)
    return x_d_info
def ctrl_vad(x_d_info, dst, alpha):
    """Blend the driving expression toward ``dst["exp"]`` in place.

    ``x_d_info["exp"]`` becomes ``exp * alpha + dst_exp * (1 - alpha)``:
    alpha == 1 keeps the driving expression, alpha == 0 copies ``dst``'s.
    Returns the same dict.

    The blend covers every expression coefficient.
    NOTE(review): a lip-only mask (rows 6, 12, 14, 17, 19, 20 of a (21, 3)
    layout) used to be built here but was never applied. If lip-only gating
    was intended, it needs to be wired into the formula below.
    """
    x_d_info["exp"] = x_d_info["exp"] * alpha + dst["exp"] * (1 - alpha)
    return x_d_info
def _mix_s_d_info( | |
x_s_info, | |
x_d_info, | |
use_d_keys=("exp", "pitch", "yaw", "roll", "t"), | |
d0=None, | |
): | |
if d0 is not None: | |
if isinstance(use_d_keys, dict): | |
x_d_info = { | |
k: x_s_info[k] + (v - d0[k]) * use_d_keys.get(k, 1) | |
for k, v in x_d_info.items() | |
} | |
else: | |
x_d_info = {k: x_s_info[k] + (v - d0[k]) for k, v in x_d_info.items()} | |
for k, v in x_s_info.items(): | |
if k not in x_d_info or k not in use_d_keys: | |
x_d_info[k] = v | |
if isinstance(use_d_keys, dict) and d0 is None: | |
for k, alpha in use_d_keys.items(): | |
x_d_info[k] *= alpha | |
return x_d_info | |
def _set_eye_blink_idx(N, blink_n=15, open_n=-1): | |
""" | |
open_n: | |
-1: no blink | |
0: random open_n | |
>0: fix open_n | |
list: loop open_n | |
""" | |
OPEN_MIN = 60 | |
OPEN_MAX = 100 | |
idx = [0] * N | |
if isinstance(open_n, int): | |
if open_n < 0: # no blink | |
return idx | |
elif open_n > 0: # fix open_n | |
open_ns = [open_n] | |
else: # open_n == 0: # random open_n, 60-100 | |
open_ns = [] | |
elif isinstance(open_n, list): | |
open_ns = open_n # loop open_n | |
else: | |
raise ValueError() | |
blink_idx = list(range(blink_n)) | |
start_n = open_ns[0] if open_ns else random.randint(OPEN_MIN, OPEN_MAX) | |
end_n = open_ns[-1] if open_ns else random.randint(OPEN_MIN, OPEN_MAX) | |
max_i = N - max(end_n, blink_n) | |
cur_i = start_n | |
cur_n_i = 1 | |
while cur_i < max_i: | |
idx[cur_i : cur_i + blink_n] = blink_idx | |
if open_ns: | |
cur_n = open_ns[cur_n_i % len(open_ns)] | |
cur_n_i += 1 | |
else: | |
cur_n = random.randint(OPEN_MIN, OPEN_MAX) | |
cur_i = cur_i + blink_n + cur_n | |
return idx | |
def _fix_exp_for_x_d_info(x_d_info, x_s_info, delta_eye=None, drive_eye=True): | |
_eye = [11, 13, 15, 16, 18] | |
_lip = [6, 12, 14, 17, 19, 20] | |
alpha = np.zeros((21, 3), dtype=x_d_info["exp"].dtype) | |
alpha[_lip] = 1 | |
if delta_eye is None and drive_eye: # use d eye | |
alpha[_eye] = 1 | |
alpha = alpha.reshape(1, -1) | |
x_d_info["exp"] = x_d_info["exp"] * alpha + x_s_info["exp"] * (1 - alpha) | |
if delta_eye is not None and drive_eye: | |
alpha = np.zeros((21, 3), dtype=x_d_info["exp"].dtype) | |
alpha[_eye] = 1 | |
alpha = alpha.reshape(1, -1) | |
x_d_info["exp"] = (delta_eye + x_s_info["exp"]) * alpha + x_d_info["exp"] * ( | |
1 - alpha | |
) | |
return x_d_info | |
def _fix_exp_for_x_d_info_v2(x_d_info, x_s_info, delta_eye, a1, a2, a3): | |
x_d_info["exp"] = x_d_info["exp"] * a1 + x_s_info["exp"] * a2 + delta_eye * a3 | |
return x_d_info | |
def bin66_to_degree(pred):
    """Convert 66-bin angle logits to degrees; pass anything else through.

    If ``pred`` has at least two dims and 66 entries along axis 1, a softmax
    over that axis gives bin probabilities. Their expectation over bin
    indices 0..65 is mapped to degrees as ``3 * E[bin] - 97.5``, which drops
    axis 1. Any other input is returned unchanged (presumably it is already
    in degrees).
    """
    is_binned = pred.ndim > 1 and pred.shape[1] == 66
    if not is_binned:
        return pred
    bins = np.arange(66).astype(np.float32)
    probs = softmax(pred, axis=1)
    return np.sum(probs * bins, axis=1) * 3 - 97.5
def _eye_delta(exp, dx=0, dy=0): | |
if dx > 0: | |
exp[0, 33] += dx * 0.0007 | |
exp[0, 45] += dx * 0.001 | |
else: | |
exp[0, 33] += dx * 0.001 | |
exp[0, 45] += dx * 0.0007 | |
exp[0, 34] += dy * -0.001 | |
exp[0, 46] += dy * -0.001 | |
return exp | |
def _fix_gaze(pose_s, x_d_info):
    """Nudge the eye coefficients of ``x_d_info['exp']`` by the head-pose change.

    ``pose_s`` is ``(yaw_s, pitch_s)`` for the source, in degrees. The driving
    yaw/pitch are converted with ``bin66_to_degree`` and must hold a single
    value each (``.item()``). The differences are scaled by 0.26 (yaw) and
    0.28 (pitch) and passed to ``_eye_delta``. ``x_d_info`` is updated in
    place and returned.
    """
    yaw_s, pitch_s = pose_s
    yaw_change = bin66_to_degree(x_d_info['yaw']).item() - yaw_s
    pitch_change = bin66_to_degree(x_d_info['pitch']).item() - pitch_s
    x_d_info['exp'] = _eye_delta(
        x_d_info['exp'], yaw_change * 0.26, pitch_change * 0.28
    )
    return x_d_info
def get_rotation_matrix(pitch_, yaw_, roll_):
    """Build batched rotation matrices from Euler angles given in degrees.

    Each angle is an array of shape (bs,) or (bs, 1). Returns shape
    (bs, 3, 3), equal to ``(Rz @ Ry @ Rx)`` transposed per batch item, where
    Rx, Ry and Rz rotate about x (pitch), y (yaw) and z (roll).
    """
    def _as_column(angle_deg):
        rad = angle_deg / 180 * np.pi
        return rad[:, None] if rad.ndim == 1 else rad

    x = _as_column(pitch_)
    y = _as_column(yaw_)
    z = _as_column(roll_)

    batch = x.shape[0]
    one = np.ones((batch, 1), dtype=np.float32)
    zero = np.zeros((batch, 1), dtype=np.float32)

    def _matrix(*entries):
        # Row-major 3x3 entries, each (batch, 1), stacked into (batch, 3, 3).
        return np.concatenate(entries, axis=1).reshape(batch, 3, 3)

    cx, sx = np.cos(x), np.sin(x)
    cy, sy = np.cos(y), np.sin(y)
    cz, sz = np.cos(z), np.sin(z)

    rot_x = _matrix(one, zero, zero, zero, cx, -sx, zero, sx, cx)
    rot_y = _matrix(cy, zero, sy, zero, one, zero, -sy, zero, cy)
    rot_z = _matrix(cz, -sz, zero, sz, cz, zero, zero, zero, one)

    return np.transpose(rot_z @ rot_y @ rot_x, (0, 2, 1))
def transform_keypoint(kp_info: dict):
    """Apply pose, expression, scale and x/y translation to implicit keypoints.

    Computes ``scale * (kp @ R + exp) + t``, where ``t`` is added to the x/y
    components only; z is left untranslated.

    Expected ``kp_info`` entries:
        'kp': (bs, K*3) or (bs, K, 3)
        'pitch', 'yaw', 'roll': degrees, or 66-bin logits (see bin66_to_degree)
        'exp': reshapeable to (bs, K, 3)
        'scale': presumably (bs, 1); it is broadcast as (bs, 1, 1)
        't': (bs, >=2); columns 0 and 1 are used

    Returns an array of shape (bs, K, 3).
    """
    kp = kp_info['kp']
    pitch_raw, yaw_raw, roll_raw = kp_info['pitch'], kp_info['yaw'], kp_info['roll']
    shift, expression = kp_info['t'], kp_info['exp']
    scale = kp_info['scale']

    pitch = bin66_to_degree(pitch_raw)
    yaw = bin66_to_degree(yaw_raw)
    roll = bin66_to_degree(roll_raw)

    batch = kp.shape[0]
    # kp may be flattened (bs, K*3) or already (bs, K, 3).
    num_kp = kp.shape[1] // 3 if kp.ndim == 2 else kp.shape[1]

    rot = get_rotation_matrix(pitch, yaw, roll)  # (bs, 3, 3)

    points = np.matmul(kp.reshape(batch, num_kp, 3), rot) + expression.reshape(batch, num_kp, 3)
    points *= scale[..., None]
    points[:, :, 0:2] += shift[:, None, 0:2]  # translate x, y only
    return points
class MotionStitch:
    """Turn per-frame source/driving info dicts into ``(x_s, x_d)`` keypoints.

    Usage: construct once, call ``setup(...)`` to configure a run, then call
    the instance once per driving frame. Most attributes are created in
    ``setup``, so calling the instance before ``setup`` raises AttributeError.
    """

    def __init__(
        self,
        stitch_network_cfg,
    ):
        # stitch_network_cfg is forwarded as keyword arguments to StitchNetwork.
        self.stitch_net = StitchNetwork(**stitch_network_cfg)

    def set_Nd(self, N_d=-1):
        """Update the number of driving frames and rebuild the blink schedule.

        Only for offline use, where the length is known. That lets the
        schedule leave blink-free frames at the start and end. ``N_d == -1``
        means unknown, and a 3000-frame schedule is built instead (it is
        indexed cyclically in ``__call__``). No-op if ``N_d`` is unchanged.
        """
        if N_d == self.N_d:
            return
        self.N_d = N_d
        if self.drive_eye and self.delta_eye_arr is not None:
            N = 3000 if self.N_d == -1 else self.N_d
            self.delta_eye_idx_list = _set_eye_blink_idx(
                N, len(self.delta_eye_arr), self.delta_eye_open_n
            )

    def setup(
        self,
        N_d=-1,
        use_d_keys=None,
        relative_d=True,
        drive_eye=None,  # use d eye or s eye
        delta_eye_arr=None,  # fix eye
        delta_eye_open_n=-1,  # int|list
        fade_out_keys=("exp",),
        fade_type="",  # "" | "d0" | "s"
        flag_stitching=True,
        is_image_flag=True,
        x_s_info=None,
        d0=None,
        ch_info=None,
        overall_ctrl_info=None,
    ):
        """Configure a run and reset per-run state.

        Args:
            N_d: number of driving frames, -1 if unknown (see ``set_Nd``).
            use_d_keys: keys taken from the driving info (tuple), or a dict of
                key -> weight. Defaults to exp+pitch+yaw+roll+t when
                ``is_image_flag`` is true, and exp only otherwise.
            relative_d: if true and no ``d0`` is given, the first driving frame
                seen by ``__call__`` is stored as ``d0``.
            drive_eye: take eye expression from the driving info (true) or keep
                the source's (false). Defaults to ``is_image_flag``.
            delta_eye_arr: optional sequence of eye expression offsets. When
                set (and ``drive_eye`` is true), eye rows become source +
                ``delta_eye_arr[i]``, with ``i`` taken from a blink schedule.
            delta_eye_open_n: open-duration spec for that blink schedule
                (see ``_set_eye_blink_idx``).
            fade_out_keys: keys blended when ``fade_alpha`` is passed.
            fade_type: "" (no fade), "d0" (fade toward the first processed
                driving frame) or "s" (fade toward the source info).
            flag_stitching: refine ``x_d`` with the stitch network.
            is_image_flag: presumably marks a still-image source. When true,
                source-derived values (pose, keypoints, fade target) are
                cached across frames.
            x_s_info: optional source info used to precompute those cached
                values now.
            d0: optional reference driving info. When present, driving values
                are applied as offsets from it onto the source.
            ch_info: optional dict. ``ch_info['x_s_info_lst'][0]['scale']``
                divided by the source scale gives a ratio that scales the
                exp/pitch/yaw/roll weights.
            overall_ctrl_info: default kwargs merged into every ``__call__``.
        """
        self.is_image_flag = is_image_flag
        if use_d_keys is None:
            if self.is_image_flag:
                self.use_d_keys = ("exp", "pitch", "yaw", "roll", "t")
            else:
                self.use_d_keys = ("exp", )
        else:
            self.use_d_keys = use_d_keys

        if drive_eye is None:
            if self.is_image_flag:
                self.drive_eye = True
            else:
                self.drive_eye = False
        else:
            self.drive_eye = drive_eye

        self.N_d = N_d
        self.relative_d = relative_d
        self.delta_eye_arr = delta_eye_arr
        self.delta_eye_open_n = delta_eye_open_n
        self.fade_out_keys = fade_out_keys
        self.fade_type = fade_type
        self.flag_stitching = flag_stitching

        # Expression mixing weights, used as
        #   exp = a1 * driving + a2 * source + a3 * delta_eye
        # over the (21, 3) rows flattened to (1, 63).
        _eye = [11, 13, 15, 16, 18]
        _lip = [6, 12, 14, 17, 19, 20]
        _a1 = np.zeros((21, 3), dtype=np.float32)
        _a1[_lip] = 1
        _a2 = 0
        if self.drive_eye:
            if self.delta_eye_arr is None:
                _a1[_eye] = 1
            else:
                _a2 = np.zeros((21, 3), dtype=np.float32)
                _a2[_eye] = 1
                _a2 = _a2.reshape(1, -1)
        _a1 = _a1.reshape(1, -1)

        self.fix_exp_a1 = _a1 * (1 - _a2)
        self.fix_exp_a2 = (1 - _a1) + _a1 * _a2
        self.fix_exp_a3 = _a2

        if self.drive_eye and self.delta_eye_arr is not None:
            N = 3000 if self.N_d == -1 else self.N_d
            self.delta_eye_idx_list = _set_eye_blink_idx(
                N, len(self.delta_eye_arr), self.delta_eye_open_n
            )

        # Source-derived caches, filled now (image sources) or lazily in __call__.
        self.pose_s = None
        self.x_s = None
        self.fade_dst = None
        if self.is_image_flag and x_s_info is not None:
            yaw_s = bin66_to_degree(x_s_info['yaw']).item()
            pitch_s = bin66_to_degree(x_s_info['pitch']).item()
            self.pose_s = [yaw_s, pitch_s]
            self.x_s = transform_keypoint(x_s_info)

            if self.fade_type == "s":
                self.fade_dst = copy.deepcopy(x_s_info)

        if ch_info is not None:
            self.scale_a = ch_info['x_s_info_lst'][0]['scale'].item()
            if x_s_info is not None:
                self.scale_b = x_s_info['scale'].item()
                self.scale_ratio = self.scale_a / self.scale_b
                self._set_scale_ratio(self.scale_ratio)
            else:
                # Resolved on the first __call__, once the source scale is known.
                self.scale_ratio = None
        else:
            self.scale_ratio = 1

        self.overall_ctrl_info = overall_ctrl_info

        self.d0 = d0
        self.idx = 0

    def _set_scale_ratio(self, scale_ratio=1):
        """Multiply the exp/pitch/yaw/roll weights in ``use_d_keys`` by ``scale_ratio``.

        A tuple ``use_d_keys`` is converted to a dict, with weight 1 for the
        other keys. No-op when the ratio is 1.
        """
        if scale_ratio == 1:
            return
        if isinstance(self.use_d_keys, dict):
            self.use_d_keys = {k: v * (scale_ratio if k in {"exp", "pitch", "yaw", "roll"} else 1) for k, v in self.use_d_keys.items()}
        else:
            self.use_d_keys = {k: scale_ratio if k in {"exp", "pitch", "yaw", "roll"} else 1 for k in self.use_d_keys}

    @staticmethod
    def _merge_kwargs(default_kwargs, run_kwargs):
        """Fill ``run_kwargs`` with any ``default_kwargs`` entries it lacks.

        Per-call values take precedence. ``run_kwargs`` is updated in place
        and returned. ``default_kwargs`` may be None.

        Must be a staticmethod: it is called as ``self._merge_kwargs(a, b)``,
        which would otherwise bind ``self`` as an extra positional argument
        and raise TypeError.
        """
        if default_kwargs is None:
            return run_kwargs
        for k, v in default_kwargs.items():
            run_kwargs.setdefault(k, v)
        return run_kwargs

    def __call__(self, x_s_info, x_d_info, **kwargs):
        """Process one driving frame and return ``(x_s, x_d)`` keypoints.

        Per-call kwargs (merged over ``overall_ctrl_info``):
        - ``vad_alpha``
        - ``delta_pitch``/``delta_yaw``/``delta_roll`` and
          ``alpha_pitch``/``alpha_yaw``/``alpha_roll``
        - ``delta_exp``
        - ``fade_alpha`` and ``fade_out_keys``
        """
        kwargs = self._merge_kwargs(self.overall_ctrl_info, kwargs)

        # Deferred scale ratio (setup had ch_info but no x_s_info).
        if self.scale_ratio is None:
            self.scale_b = x_s_info['scale'].item()
            self.scale_ratio = self.scale_a / self.scale_b
            self._set_scale_ratio(self.scale_ratio)

        # Relative mode: the first driving frame becomes the reference.
        if self.relative_d and self.d0 is None:
            self.d0 = copy.deepcopy(x_d_info)

        x_d_info = _mix_s_d_info(
            x_s_info,
            x_d_info,
            self.use_d_keys,
            self.d0,
        )

        delta_eye = 0
        if self.drive_eye and self.delta_eye_arr is not None:
            # Blink schedule is indexed cyclically by frame counter.
            delta_eye = self.delta_eye_arr[
                self.delta_eye_idx_list[self.idx % len(self.delta_eye_idx_list)]
            ][None]
        x_d_info = _fix_exp_for_x_d_info_v2(
            x_d_info,
            x_s_info,
            delta_eye,
            self.fix_exp_a1,
            self.fix_exp_a2,
            self.fix_exp_a3,
        )

        if kwargs.get("vad_alpha", 1) < 1:
            x_d_info = ctrl_vad(x_d_info, x_s_info, kwargs.get("vad_alpha", 1))

        x_d_info = ctrl_motion(x_d_info, **kwargs)

        if self.fade_type == "d0" and self.fade_dst is None:
            self.fade_dst = copy.deepcopy(x_d_info)

        # fade
        if "fade_alpha" in kwargs and self.fade_type in ["d0", "s"]:
            fade_alpha = kwargs["fade_alpha"]
            fade_keys = kwargs.get("fade_out_keys", self.fade_out_keys)
            if self.fade_type == "d0":
                fade_dst = self.fade_dst
            elif self.fade_type == "s":
                if self.fade_dst is not None:
                    fade_dst = self.fade_dst
                else:
                    fade_dst = copy.deepcopy(x_s_info)
                    if self.is_image_flag:
                        self.fade_dst = fade_dst
            x_d_info = fade(x_d_info, fade_dst, fade_alpha, fade_keys)

        if self.drive_eye:
            if self.pose_s is None:
                yaw_s = bin66_to_degree(x_s_info['yaw']).item()
                pitch_s = bin66_to_degree(x_s_info['pitch']).item()
                self.pose_s = [yaw_s, pitch_s]
            x_d_info = _fix_gaze(self.pose_s, x_d_info)

        if self.x_s is not None:
            x_s = self.x_s
        else:
            x_s = transform_keypoint(x_s_info)
            if self.is_image_flag:
                self.x_s = x_s

        x_d = transform_keypoint(x_d_info)

        if self.flag_stitching:
            x_d = self.stitch_net(x_s, x_d)

        self.idx += 1

        return x_s, x_d