multimodalart's picture
Upload 247 files
7758cff verified
raw
history blame
67.7 kB
"""
Motion feature extractor
"""
import os
import os.path as osp
import sys
import pickle
from omegaconf import OmegaConf
import torch
from PIL import Image
import numpy as np
import cv2
import imageio
import pickle
import time
from decord import VideoReader # must after import torch
from rich.progress import track
sys.path.append(osp.dirname(osp.dirname(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__)))))))
from src.datasets.preprocess.extract_features.face_segmentation import build_face_parser, get_face_mask, vis_parsing_maps
from src.thirdparty.liveportrait.src.utils.helper import load_model, concat_feat
from src.thirdparty.liveportrait.src.utils.io import load_image_rgb, resize_to_limit, load_video
from src.thirdparty.liveportrait.src.utils.video import get_fps, images2video, add_audio_to_video
from src.thirdparty.liveportrait.src.utils.camera import headpose_pred_to_degree, get_rotation_matrix
from src.thirdparty.liveportrait.src.utils.cropper import Cropper
from src.thirdparty.liveportrait.src.utils.crop import prepare_paste_back, paste_back, paste_back_with_face_mask
from src.thirdparty.liveportrait.src.utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
from src.thirdparty.liveportrait.src.utils.helper import mkdir, basename, dct2device, is_image, calc_motion_multiplier
from src.utils.filter import smooth as ksmooth
from src.utils.filter import smooth_
from skimage.metrics import peak_signal_noise_ratio
import warnings
def psnr(imgs1, imgs2):
    """Score two image sequences pair by pair with PSNR.

    Pairs are formed with ``zip``, so any surplus items in the longer
    sequence are ignored. Every pair is scored by
    ``peak_signal_noise_ratio`` with ``data_range=255``.

    Returns:
        A list holding one PSNR value per pair.
    """
    return [
        peak_signal_noise_ratio(first, second, data_range=255)
        for first, second in zip(imgs1, imgs2)
    ]
def suffix(filename):
    """Return the text after the last dot, e.g. ``a.jpg -> jpg``.

    An empty string is returned when *filename* contains no dot at all.
    """
    dot = filename.rfind(".")
    return "" if dot < 0 else filename[dot + 1:]
def dump(wfp, obj):
    """Serialize *obj* to *wfp*, choosing the format from the file suffix.

    ``.npy`` files are written with ``np.save`` and ``.pkl`` files with
    ``pickle.dump``. The parent directory is created first when it does
    not exist yet.

    Args:
        wfp: destination file path.
        obj: object to store.

    Raises:
        Exception: if the suffix is neither ``npy`` nor ``pkl``.
    """
    wd = osp.split(wfp)[0]
    if wd != "" and not osp.exists(wd):
        mkdir(wd)
    _suffix = suffix(wfp)
    if _suffix == "npy":
        np.save(wfp, obj)
    elif _suffix == "pkl":
        # Use a context manager so the handle is flushed and closed
        # deterministically instead of whenever it is garbage collected.
        with open(wfp, "wb") as f:
            pickle.dump(obj, f)
    else:
        raise Exception("Unknown type: {}".format(_suffix))
def load(fp):
    """Load an object from *fp*, choosing the reader from the file suffix.

    ``.npy`` files are read with ``np.load`` and ``.pkl`` files with
    ``pickle.load``.

    Args:
        fp: path of the file to read.

    Returns:
        The deserialized object.

    Raises:
        Exception: if the suffix is neither ``npy`` nor ``pkl``.
    """
    suffix_ = suffix(fp)
    if suffix_ == "npy":
        return np.load(fp)
    elif suffix_ == "pkl":
        # NOTE(review): unpickling executes arbitrary code; only use this on
        # trusted files.
        with open(fp, "rb") as f:
            return pickle.load(f)
    else:
        # Report the offending suffix, not the `suffix` function object.
        raise Exception(f"Unknown type: {suffix_}")
def remove_suffix(filepath):
    """Join the directory of *filepath* with ``basename(filepath)``.

    Example from the original docstring: ``a/b/c.jpg -> a/b/c`` (this
    relies on the imported ``basename`` helper dropping the extension).
    """
    parent = osp.dirname(filepath)
    stem = basename(filepath)
    return osp.join(parent, stem)
class MotionProcesser(object):
def __init__(self, cfg_path, device_id=0) -> None:
    """Read the driving config and load every model and asset it names.

    Args:
        cfg_path: path of a config file readable by ``OmegaConf.load``.
        device_id: CUDA device index; models are loaded on
            ``cuda:{device_id}`` and the same index is passed to the
            cropper and the face parser.
    """
    cfg = OmegaConf.load(cfg_path)
    device = f"cuda:{device_id}"
    print(f"Load cfg from {osp.realpath(cfg_path)} done.")
    print(f"=============================== Driven CFG ===============================")
    print(OmegaConf.to_yaml(cfg))
    print(f"=============================== ========== ===============================")
    models_config = OmegaConf.load(cfg.models_config)

    def _load(ckpt_path, model_type):
        # Every sub-network shares the same model config and device.
        return load_model(ckpt_path, models_config, device, model_type)

    # 1. appearance feature extractor
    self.appearance_feature_extractor = _load(
        cfg.appearance_feature_extractor_path, 'appearance_feature_extractor'
    )
    print(f'1. Load appearance_feature_extractor from {osp.realpath(cfg.appearance_feature_extractor_path)} done.')

    # 2. motion extractor
    self.motion_extractor = _load(cfg.motion_extractor_path, 'motion_extractor')
    print(f'2. Load motion_extractor from {osp.realpath(cfg.motion_extractor_path)} done.')

    # 3. stitching / retargeting module (optional: skipped when the path is
    #    unset or does not exist on disk)
    self.stitching_retargeting_module = None
    stitching_path = cfg.stitching_retargeting_module_path
    if stitching_path is not None and osp.exists(stitching_path):
        self.stitching_retargeting_module = _load(stitching_path, 'stitching_retargeting_module')
        print(f'3. Load stitching_retargeting_module from {osp.realpath(cfg.stitching_retargeting_module_path)} done.')

    # 4. motion warper
    self.warping_module = _load(cfg.warping_module_path, 'warping_module')
    print(f"4. Load warping_module from {osp.realpath(cfg.warping_module_path)} done.")

    # 5. decoder
    self.spade_generator = _load(cfg.spade_generator_path, 'spade_generator')
    print(f"Load generator from {osp.realpath(cfg.spade_generator_path)} done.")

    # Optionally compile the two heaviest modules for inference.
    self.compile = cfg.flag_do_torch_compile
    if self.compile:
        # Suppress errors and fall back to eager execution
        torch._dynamo.config.suppress_errors = True
        self.warping_module = torch.compile(self.warping_module, mode='max-autotune')
        self.spade_generator = torch.compile(self.spade_generator, mode='max-autotune')

    self.cfg = cfg
    self.models_config = models_config
    self.device = device

    # 6. cropper
    self.cropper = Cropper(
        crop_cfg=OmegaConf.load(cfg.crop_cfg),
        image_type="human_face",
        device_id=device_id,
    )

    # 7. crop mask, read as a 3-channel image
    self.mask_crop = cv2.imread(cfg.mask_crop, cv2.IMREAD_COLOR)

    # 8. lip array (pickled)
    with open(cfg.lip_array, 'rb') as f:
        self.lip_array = pickle.load(f)

    # 9. face parser
    self.face_parser, self.to_tensor = build_face_parser(
        weight_path=cfg.face_parser_weight_path,
        resnet_weight_path=cfg.resnet_weight_path,
        device_id=device_id,
    )
def inference_ctx(self):
ctx = torch.autocast(device_type=self.device[:4], dtype=torch.float16,
enabled=self.cfg.flag_use_half_precision)
return ctx
@torch.no_grad()
def extract_feature_3d(self, x: torch.Tensor) -> torch.Tensor:
""" get the appearance feature of the image by F
x: Bx3xHxW, normalized to 0~1
"""
with self.inference_ctx():
feature_3d = self.appearance_feature_extractor(x)
return feature_3d.float()
@torch.no_grad()
def get_kp_info(self, x: torch.Tensor, **kwargs) -> dict:
""" get the implicit keypoint information
x: Bx3xHxW, normalized to 0~1
flag_refine_info: whether to trandform the pose to degrees and the dimention of the reshape
return: A dict contains keys: 'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp'
"""
with self.inference_ctx():
kp_info = self.motion_extractor(x)
if self.cfg.flag_use_half_precision:
# float the dict
for k, v in kp_info.items():
if isinstance(v, torch.Tensor):
kp_info[k] = v.float()
return kp_info
@torch.no_grad()
def refine_kp(self, kp_info):
    """Normalise the layout of a keypoint-info dict in place and return it.

    'pitch', 'yaw' and 'roll' are passed through
    ``headpose_pred_to_degree`` and given a trailing axis (Bx1 per the
    original comments); 'exp' and, when present, 'kp' are reshaped to
    BxNx3.
    """
    batch = kp_info['exp'].shape[0]
    for angle in ('pitch', 'yaw', 'roll'):
        kp_info[angle] = headpose_pred_to_degree(kp_info[angle])[:, None]
    kp_info['exp'] = kp_info['exp'].reshape(batch, -1, 3)
    if 'kp' in kp_info:
        kp_info['kp'] = kp_info['kp'].reshape(batch, -1, 3)
    return kp_info
@torch.no_grad()
def transform_keypoint(self, kp_info: dict):
    """Apply pose, expression, scale and translation to the keypoints.

    Implements ``s * (kp @ R + exp) + t`` where only the x/y components
    of ``t`` are added. 'kp' may be given flattened (Bx(3*N)) or as
    BxNx3; the result is BxNx3.
    """
    kp, exp = kp_info['kp'], kp_info['exp']
    angles = (kp_info['pitch'], kp_info['yaw'], kp_info['roll'])
    shift, scale = kp_info['t'], kp_info['scale']
    pitch, yaw, roll = (headpose_pred_to_degree(angle) for angle in angles)

    batch = kp.shape[0]
    # Flattened input stores three coordinates per keypoint.
    num_kp = kp.shape[1] // 3 if kp.ndim == 2 else kp.shape[1]

    rot_mat = get_rotation_matrix(pitch, yaw, roll)  # (bs, 3, 3)
    kp_transformed = kp.view(batch, num_kp, 3) @ rot_mat + exp.view(batch, num_kp, 3)
    kp_transformed *= scale[..., None]  # (bs, k, 3) * (bs, 1, 1)
    kp_transformed[:, :, 0:2] += shift[:, None, 0:2]  # tx, ty only; tz dropped
    return kp_transformed
@torch.no_grad()
def stitching(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
""" conduct the stitching
kp_source: Bxnum_kpx3
kp_driving: Bxnum_kpx3
"""
if self.stitching_retargeting_module is not None:
bs, num_kp = kp_source.shape[:2]
kp_driving_new = kp_driving.clone()
# stich
feat_stiching = concat_feat(kp_source, kp_driving_new)
delta = self.stitching_retargeting_module['stitching'](feat_stiching) # Bxnum_kpx3
delta_exp = delta[..., :3*num_kp].reshape(bs, num_kp, 3) # 1x20x3
delta_tx_ty = delta[..., 3*num_kp:3*num_kp+2].reshape(bs, 1, 2) # 1x1x2
kp_driving_new += delta_exp
kp_driving_new[..., :2] += delta_tx_ty
return kp_driving_new
return kp_driving
@torch.no_grad()
def warp_decode(self, feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> dict[str, torch.Tensor]:
""" get the image after the warping of the implicit keypoints
feature_3d: Bx32x16x64x64, feature volume
kp_source: BxNx3
kp_driving: BxNx3
"""
# The line 18 in Algorithm 1: D(W(f_s; x_s, x′_d,i))
with self.inference_ctx():
if self.compile:
# Mark the beginning of a new CUDA Graph step
torch.compiler.cudagraph_mark_step_begin()
# get decoder input
ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
# print(f"=============================================================================")
# for out_key, out_value in ret_dct.items():
# if isinstance(out_value, str) or isinstance(out_value, int) or isinstance(out_value, float):
# print(f"{out_key}: {out_value}")
# elif isinstance(out_value, torch.Tensor):
# print(f"{out_key}: tensor shape {out_value.shape}, min: {torch.min(out_value)}, max: {torch.max(out_value)}, mean: {torch.mean(out_value)}, std: {torch.std(out_value)}")
# else:
# print(f"{out_key}: data type {type(out_value)}")
# decode
ret_dct['out'] = self.spade_generator(feature=ret_dct['out'])
# float the dict
if self.cfg.flag_use_half_precision:
for k, v in ret_dct.items():
if isinstance(v, torch.Tensor):
ret_dct[k] = v.float()
return ret_dct
def parse_output(self, out: torch.Tensor) -> np.ndarray:
""" construct the output as standard
return: 1xHxWx3, uint8
"""
out = np.transpose(out.cpu().numpy(), [0, 2, 3, 1]) # 1x3xHxW -> 1xHxWx3
out = np.clip(out, 0, 1) # clip to 0~1
out = np.clip(out * 255, 0, 255).astype(np.uint8) # 0~1 -> 0~255
return out
@torch.no_grad()
def calc_combined_eye_ratio(self, c_d_eyes_i, source_lmk):
    """Concatenate the source eye-close ratio with one driving ratio.

    The source ratio comes from ``calc_eye_close_ratio(source_lmk[None])``;
    the driving value is ``c_d_eyes_i[0][0]``. Both are placed on
    ``self.device`` and joined along dim 1: [c_s,eyes, c_d,eyes,i].
    """
    source_ratio = torch.from_numpy(calc_eye_close_ratio(source_lmk[None])).float().to(self.device)
    driving_ratio = torch.Tensor([c_d_eyes_i[0][0]]).reshape(1, 1).to(self.device)
    return torch.cat([source_ratio, driving_ratio], dim=1)
@torch.no_grad()
def calc_combined_lip_ratio(self, c_d_lip_i, source_lmk):
    """Concatenate the source lip-close ratio with one driving ratio.

    The source ratio comes from ``calc_lip_close_ratio(source_lmk[None])``;
    the driving value is ``c_d_lip_i[0]``. Both are placed on
    ``self.device`` and joined along dim 1: [c_s,lip, c_d,lip,i] (1x2 per
    the original comment).
    """
    source_ratio = torch.from_numpy(calc_lip_close_ratio(source_lmk[None])).float().to(self.device)
    driving_ratio = torch.Tensor([c_d_lip_i[0]]).to(self.device).reshape(1, 1)  # 1x1
    return torch.cat([source_ratio, driving_ratio], dim=1)
def calc_ratio(self, lmk_lst):
input_eye_ratio_lst = []
input_lip_ratio_lst = []
for lmk in lmk_lst:
# for eyes retargeting
input_eye_ratio_lst.append(calc_eye_close_ratio(lmk[None]))
# for lip retargeting
input_lip_ratio_lst.append(calc_lip_close_ratio(lmk[None]))
return input_eye_ratio_lst, input_lip_ratio_lst
@torch.no_grad()
def retarget_lip(self, kp_source: torch.Tensor, lip_close_ratio: torch.Tensor) -> torch.Tensor:
    """Predict the lip retargeting offset for the source keypoints.

    kp_source: BxNx3
    lip_close_ratio: Bx2
    Return: the 'lip' head output reshaped to (-1, N, 3)
    """
    num_kp = kp_source.shape[1]
    lip_head = self.stitching_retargeting_module['lip']
    return lip_head(concat_feat(kp_source, lip_close_ratio)).reshape(-1, num_kp, 3)
@torch.no_grad()
def retarget_eye(self, kp_source: torch.Tensor, eye_close_ratio: torch.Tensor) -> torch.Tensor:
    """Predict the eye retargeting offset for the source keypoints.

    kp_source: BxNx3
    eye_close_ratio: Bx3
    Return: the 'eye' head output reshaped to (-1, N, 3)
    """
    num_kp = kp_source.shape[1]
    eye_head = self.stitching_retargeting_module['eye']
    return eye_head(concat_feat(kp_source, eye_close_ratio)).reshape(-1, num_kp, 3)
def crop_image(self, img, do_crop=False):
    """Produce a 256x256 image and landmarks for a source image.

    With ``do_crop`` the cropper detects and crops the face; otherwise
    *img* is treated as already cropped and simply resized.

    Returns:
        ``(img_crop_256x256, lmk, crop_info)``; ``crop_info`` is ``None``
        when no cropping was requested.

    Raises:
        Exception: if cropping was requested and the cropper returned
            ``None``.
    """
    if not do_crop:
        lmk = self.cropper.calc_lmk_from_cropped_image(img)
        resized = cv2.resize(img, (256, 256))  # force to resize to 256x256
        return resized, lmk, None

    crop_info = self.cropper.crop_source_image(img, self.cropper.crop_cfg)
    if crop_info is None:
        raise Exception("No face detected in the source image!")
    return crop_info['img_crop_256x256'], crop_info['lmk_crop'], crop_info
def crop_source_video(self, img_lst, do_crop=False):
    """Produce 256x256 frames and landmarks for a source video.

    With ``do_crop`` the cropper's ``frame_crop_lst``, ``lmk_crop_lst``
    and ``M_c2o_lst`` are returned as they are; otherwise the frames are
    treated as already cropped, resized to 256x256, and ``M_c2o_lst`` is
    ``None``.

    Returns:
        ``(img_crop_256x256_lst, lmk_crop_lst, M_c2o_lst)``.
    """
    if not do_crop:
        lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(img_lst)
        # force to resize to 256x256
        resized = [cv2.resize(frame, (256, 256)) for frame in img_lst]
        return resized, lmk_crop_lst, None

    ret_s = self.cropper.crop_source_video(img_lst, self.cropper.crop_cfg)
    print(f'Source video is cropped, {len(ret_s["frame_crop_lst"])} frames are processed.')
    return ret_s['frame_crop_lst'], ret_s['lmk_crop_lst'], ret_s['M_c2o_lst']
def crop_driving_videos(self, img_lst, do_crop=False):
    """Produce 256x256 frames and landmarks for a driving video.

    Args:
        img_lst: list of driving frames.
        do_crop: when True, the cropper crops every frame first; when
            False, the frames are treated as already cropped.

    Returns:
        ``(img_crop_256x256_lst, lmk_crop_lst)``. The frames are resized
        to 256x256 and the landmarks come from the same frames that were
        resized.
    """
    if do_crop:
        ret_d = self.cropper.crop_driving_video(img_lst)
        print(f'Driving video is cropped, {len(ret_d["frame_crop_lst"])} frames are processed.')
        img_crop_lst, lmk_crop_lst = ret_d['frame_crop_lst'], ret_d['lmk_crop_lst']
    else:
        img_crop_lst = img_lst
        lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(img_lst)
    # Resize the frames the landmarks refer to. The previous code resized
    # the uncropped input even when cropping was requested, so the
    # returned images did not match the crop-space landmarks.
    img_crop_256x256_lst = [cv2.resize(frame, (256, 256)) for frame in img_crop_lst]
    return img_crop_256x256_lst, lmk_crop_lst
def prepare_source(self, src_img):
""" construct the input as standard
img: HxWx3, uint8, 256x256
"""
# processing source image to tensor
h, w = src_img.shape[:2]
if h != self.cfg.input_height or w != self.cfg.input_width:
x = cv2.resize(src_img, (self.cfg.input_width, self.cfg.input_height))
else:
x = src_img.copy()
if x.ndim == 3:
x = x[np.newaxis].astype(np.float32) / 255. # HxWx3 -> 1xHxWx3, normalized to 0~1
elif x.ndim == 4:
x = x.astype(np.float32) / 255. # BxHxWx3, normalized to 0~1
else:
raise ValueError(f'img ndim should be 3 or 4: {x.ndim}')
x = np.clip(x, 0, 1) # clip to 0~1
x = torch.from_numpy(x).permute(0, 3, 1, 2) # 1xHxWx3 -> 1x3xHxW
x = x.to(self.device)
# extract features
I_s = x
f_s = self.extract_feature_3d(I_s)
x_s_info = self.get_kp_info(I_s)
return f_s, x_s_info
def process_clips(self, clips):
""" construct the input as standard
clips: NxBxHxWx3, uint8
"""
# resize to 256 x 256
imgs = []
for img in clips:
h, w = img.shape[:2]
if h != self.cfg.input_height or w != self.cfg.input_width:
img = cv2.resize(img, (self.cfg.input_width, self.cfg.input_height))
else:
img = img.copy()
imgs.append(img)
# processing video frames to tensor
if isinstance(imgs, list):
_imgs = np.array(imgs)[..., np.newaxis] # TxHxWx3x1
elif isinstance(imgs, np.ndarray):
_imgs = imgs
else:
raise ValueError(f'imgs type error: {type(imgs)}')
y = _imgs.astype(np.float32) / 255.
y = np.clip(y, 0, 1) # clip to 0~1
y = torch.from_numpy(y).permute(0, 4, 3, 1, 2) # TxHxWx3x1 -> Tx1x3xHxW
y = y.to(self.device)
return y
def prepare_driving_videos(self, vid_frames, feat_type="tensor"):
""" get driving kp infos
vid_frames: image list of HxWx3, uint8
"""
# extract features
total_len = len(vid_frames)
kp_infos = {"pitch": [], "yaw": [], "roll": [], "t": [], "exp": [], "scale": [], "kp": []}
for start_idx in range(0, total_len, self.cfg.batch_size):
frames = vid_frames[start_idx: min(start_idx + self.cfg.batch_size, total_len)]
frames = self.process_clips(frames).squeeze(1)
kp_info = self.get_kp_info(frames)
for k, v in kp_info.items():
kp_infos[k].append(v)
# combine the kp_infos
for k, v in kp_infos.items():
kp_infos[k] = torch.cat(v, dim=0)
if feat_type == "np":
for k, v in kp_infos.items():
kp_infos[k] = v.cpu().numpy()
return kp_infos
def get_driving_template(self, kp_infos, smooth=False, dtype="pt_tensor"):
    """Split batched driving keypoint infos into a per-frame motion template.

    *kp_infos* is refined in place by ``refine_kp``. Each frame becomes a
    dict with 'exp' (1x21x3), 'scale' (1x1), 'R' (1x3x3), 't' (1x3) and
    'pitch'/'yaw'/'roll' (1x1); with ``dtype == "np"`` the values are
    float32 NumPy arrays instead of tensors.

    Returns:
        ``{'n_frames', 'output_fps', 'motion'}`` with ``output_fps`` fixed
        to 25, passed through ``smooth_(..., method="ema")`` when *smooth*
        is set.
    """
    kp_infos = self.refine_kp(kp_infos)
    n_frames = len(kp_infos["exp"])
    as_numpy = dtype == "np"

    motion_list = []
    for idx in range(n_frames):
        pitch, yaw, roll = (kp_infos[name][idx] for name in ("pitch", "yaw", "roll"))
        frame = {
            "exp": kp_infos["exp"][idx].reshape(1, 21, 3),
            "scale": kp_infos["scale"][idx].reshape(1, 1),
            "R": get_rotation_matrix(pitch, yaw, roll).reshape(1, 3, 3),
            "t": kp_infos["t"][idx].reshape(1, 3),
            "pitch": pitch.reshape(1, 1),
            "yaw": yaw.reshape(1, 1),
            "roll": roll.reshape(1, 1),
        }
        if as_numpy:
            frame = {key: value.cpu().numpy().astype(np.float32) for key, value in frame.items()}
        motion_list.append(frame)

    tgt_motion = {'n_frames': n_frames, 'output_fps': 25, 'motion': motion_list}
    if smooth:
        print("Smoothing motion sequence...")
        tgt_motion = smooth_(tgt_motion, method="ema")
    return tgt_motion
@torch.no_grad()
def update_delta_new_eyeball_direction(self, eyeball_direction_x, eyeball_direction_y, delta_new, **kwargs):
if eyeball_direction_x > 0:
delta_new[0, 11, 0] += eyeball_direction_x * 0.0007
delta_new[0, 15, 0] += eyeball_direction_x * 0.001
else:
delta_new[0, 11, 0] += eyeball_direction_x * 0.001
delta_new[0, 15, 0] += eyeball_direction_x * 0.0007
delta_new[0, 11, 1] += eyeball_direction_y * -0.001
delta_new[0, 15, 1] += eyeball_direction_y * -0.001
blink = -eyeball_direction_y / 2.
delta_new[0, 11, 1] += blink * -0.001
delta_new[0, 13, 1] += blink * 0.0003
delta_new[0, 15, 1] += blink * -0.001
delta_new[0, 16, 1] += blink * 0.0003
return delta_new
def driven(self, f_s, x_s_info, s_lmk, c_s_eyes_lst, kp_infos, c_d_eyes_lst=None, c_d_lip_lst=None, smooth=False):
    """Animate the source with a driving keypoint sequence and decode all frames.

    The per-frame driving keypoints are computed in a first loop, then
    decoded in chunks of 100 frames with ``warp_decode``.

    Args:
        f_s: source appearance feature; expanded along dim 0 to the
            number of driving frames before decoding.
        x_s_info: source keypoint-info dict; refined in place by
            ``refine_kp``.
        s_lmk: source landmarks, or None to skip the lip/eye
            pre-normalisation and the retargeting branches.
        c_s_eyes_lst: source eye-close ratios; only element 0 is read,
            and only when ``cfg.flag_source_video_eye_retargeting`` is set.
        kp_infos: batched driving keypoint infos, turned into a per-frame
            template by ``get_driving_template``.
        c_d_eyes_lst: optional per-frame driving eye ratios (eye retargeting).
        c_d_lip_lst: optional per-frame driving lip ratios (lip retargeting).
        smooth: forwarded to ``get_driving_template``.

    Returns:
        The uint8 array produced by ``parse_output`` on the concatenated
        decoder outputs.
    """
    # source kp info
    x_d_i_news=[]
    # NOTE(review): x_ss and f_ss are never used below.
    x_ss=[]
    f_ss=[]
    x_s_info = self.refine_kp(x_s_info)
    R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
    x_s = self.transform_keypoint(x_s_info)
    x_c_s = x_s_info["kp"]
    # driving kp infos
    driving_template_dct = self.get_driving_template(kp_infos, smooth)
    n_frames = driving_template_dct['n_frames']
    # driving params
    flag_normalize_lip = self.cfg.flag_normalize_lip
    flag_relative_motion = self.cfg.flag_relative_motion
    flag_source_video_eye_retargeting = self.cfg.flag_source_video_eye_retargeting
    lip_normalize_threshold = self.cfg.lip_normalize_threshold
    source_video_eye_retargeting_threshold = self.cfg.source_video_eye_retargeting_threshold
    animation_region = self.cfg.animation_region
    driving_option = self.cfg.driving_option
    flag_stitching = self.cfg.flag_stitching
    flag_eye_retargeting = self.cfg.flag_eye_retargeting
    flag_lip_retargeting = self.cfg.flag_lip_retargeting
    driving_multiplier = self.cfg.driving_multiplier
    # NOTE(review): lib_multiplier is only referenced by the commented-out
    # "enhance lip" code inside the loop.
    lib_multiplier = self.cfg.lib_multiplier
    # let lip-open scalar to be 0 at first
    lip_delta_before_animation, eye_delta_before_animation = None, None
    if flag_normalize_lip and flag_relative_motion and s_lmk is not None:
        c_d_lip_before_animation = [0.]
        combined_lip_ratio_tensor_before_animation = self.calc_combined_lip_ratio(c_d_lip_before_animation, s_lmk)
        # Only compute a closing offset when the source lip ratio reaches the threshold.
        if combined_lip_ratio_tensor_before_animation[0][0] >= lip_normalize_threshold:
            lip_delta_before_animation = self.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
    # let eye-open scalar to be the same as the first frame if the latter is eye-open state
    if flag_source_video_eye_retargeting and s_lmk is not None:
        combined_eye_ratio_tensor_frame_zero = c_s_eyes_lst[0]
        c_d_eye_before_animation_frame_zero = [[combined_eye_ratio_tensor_frame_zero[0][:2].mean()]]
        # Below the threshold, fall back to the fixed ratio 0.39.
        if c_d_eye_before_animation_frame_zero[0][0] < source_video_eye_retargeting_threshold:
            c_d_eye_before_animation_frame_zero = [[0.39]]
        combined_eye_ratio_tensor_before_animation = self.calc_combined_eye_ratio(c_d_eye_before_animation_frame_zero, s_lmk)
        eye_delta_before_animation = self.retarget_eye(x_s, combined_eye_ratio_tensor_before_animation)
    # animate
    I_p_lst = []
    for i in range(n_frames):
        x_d_i_info = driving_template_dct['motion'][i]
        x_d_i_info = dct2device(x_d_i_info, self.device)
        # R
        R_d_i = x_d_i_info['R']
        if i == 0: # cache the first frame
            R_d_0 = R_d_i
            # Shallow copy: the values are shared with x_d_i_info.
            x_d_0_info = x_d_i_info.copy()
        # enhance lip
        # if i > 0:
        # for lip_idx in [6, 12, 14, 17, 19, 20]:
        # x_d_i_info['exp'][:, lip_idx, :] = x_d_0_info['exp'][:, lip_idx, :] + (x_d_i_info['exp'][:, lip_idx, :] - x_d_0_info['exp'][:, lip_idx, :]) * lib_multiplier
        # normalize eye_ball, TODO
        # The helper modifies the 'exp' value in place with fixed arguments (0, -5).
        x_d_i_info['exp'] = self.update_delta_new_eyeball_direction(0, -5, x_d_i_info['exp'])
        # debug
        #print(f"frame {i:03d}, src scale {x_s_info['scale']}, 0 scale {x_d_0_info['scale']}, i scale {x_d_i_info['scale']}")
        # delta
        delta_new = x_s_info['exp'].clone()
        if flag_relative_motion:
            # Relative mode: apply the change of the driving frame w.r.t. frame 0 to the source.
            # R
            if animation_region == "all" or animation_region == "pose":
                R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
            else:
                R_new = R_s
            # exp
            if animation_region == "all" or animation_region == "exp":
                delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
            elif animation_region == "lip":
                for lip_idx in [6, 12, 14, 17, 19, 20]:
                    delta_new[:, lip_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, lip_idx, :]
            elif animation_region == "eyes":
                for eyes_idx in [11, 13, 15, 16, 18]:
                    delta_new[:, eyes_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, eyes_idx, :]
            # scale
            if animation_region == "all":
                scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
            else:
                scale_new = x_s_info['scale']
            # translation
            if animation_region == "all" or animation_region == "pose":
                t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
            else:
                t_new = x_s_info['t']
        else:
            # Absolute mode: take the driving values directly.
            # R
            if animation_region == "all" or animation_region == "pose":
                R_new = R_d_i
            else:
                R_new = R_s
            # exp
            if animation_region == "all" or animation_region == "exp":
                EYE_IDX=[1,2,6,11,12,13,14,15,16,17,18,19,20]
                delta_new[:, EYE_IDX, :] = x_d_i_info['exp'][:, EYE_IDX, :]
                # for idx in [1,2,6,11,12,13,14,15,16,17,18,19,20]:
                # delta_new[:, idx, :] = x_d_i_info['exp'][:, idx, :]
                delta_new[:, 3:5, 1] = x_d_i_info['exp'][:, 3:5, 1]
                delta_new[:, 5, 2] = x_d_i_info['exp'][:, 5, 2]
                delta_new[:, 8, 2] = x_d_i_info['exp'][:, 8, 2]
                delta_new[:, 9, 1:] = x_d_i_info['exp'][:, 9, 1:]
            elif animation_region == "lip":
                for lip_idx in [6, 12, 14, 17, 19, 20]:
                    delta_new[:, lip_idx, :] = x_d_i_info['exp'][:, lip_idx, :]
            elif animation_region == "eyes":
                for eyes_idx in [11, 13, 15, 16, 18]:
                    delta_new[:, eyes_idx, :] = x_d_i_info['exp'][:, eyes_idx, :]
            # scale
            scale_new = x_s_info['scale']
            # translation
            if animation_region == "all" or animation_region == "pose":
                t_new = x_d_i_info['t']
            else:
                t_new = x_s_info['t']
        # NOTE(review): fill_ is in place; when t_new is x_s_info['t'] or
        # x_d_i_info['t'] itself, that stored tensor's tz is zeroed as well.
        t_new[..., 2].fill_(0) # zero tz
        x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
        if flag_relative_motion and driving_option == "expression-friendly":
            if i == 0:
                x_d_0_new = x_d_i_new
                motion_multiplier = calc_motion_multiplier(x_s, x_d_0_new)
            x_d_diff = (x_d_i_new - x_d_0_new) * motion_multiplier
            x_d_i_new = x_d_diff + x_s
        # Algorithm 1 in Liveportrait:
        if not flag_stitching and not flag_eye_retargeting and not flag_lip_retargeting:
            # without stitching or retargeting
            if flag_normalize_lip and lip_delta_before_animation is not None:
                x_d_i_new += lip_delta_before_animation
            if flag_source_video_eye_retargeting and eye_delta_before_animation is not None:
                x_d_i_new += eye_delta_before_animation
            else:
                pass
        elif flag_stitching and not flag_eye_retargeting and not flag_lip_retargeting:
            # with stitching and without retargeting
            if flag_normalize_lip and lip_delta_before_animation is not None:
                x_d_i_new = self.stitching(x_s, x_d_i_new) + lip_delta_before_animation
            else:
                x_d_i_new = self.stitching(x_s, x_d_i_new)
            if flag_source_video_eye_retargeting and eye_delta_before_animation is not None:
                x_d_i_new += eye_delta_before_animation
        else:
            # At least one retargeting flag is set.
            eyes_delta, lip_delta = None, None
            if flag_eye_retargeting and s_lmk is not None and c_d_eyes_lst is not None:
                c_d_eyes_i = c_d_eyes_lst[i]
                combined_eye_ratio_tensor = self.calc_combined_eye_ratio(c_d_eyes_i, s_lmk)
                eyes_delta = self.retarget_eye(x_s, combined_eye_ratio_tensor)
            if flag_lip_retargeting and s_lmk is not None and c_d_lip_lst is not None:
                c_d_lip_i = c_d_lip_lst[i]
                combined_lip_ratio_tensor = self.calc_combined_lip_ratio(c_d_lip_i, s_lmk)
                # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
                lip_delta = self.retarget_lip(x_s, combined_lip_ratio_tensor)
            if flag_relative_motion: # use x_s
                x_d_i_new = x_s + \
                    (eyes_delta if eyes_delta is not None else 0) + \
                    (lip_delta if lip_delta is not None else 0)
            else: # use x_d,i
                x_d_i_new = x_d_i_new + \
                    (eyes_delta if eyes_delta is not None else 0) + \
                    (lip_delta if lip_delta is not None else 0)
            if flag_stitching:
                x_d_i_new = self.stitching(x_s, x_d_i_new)
        # Scale the displacement from the source keypoints by driving_multiplier.
        x_d_i_new = x_s + (x_d_i_new - x_s) * driving_multiplier
        x_d_i_news.append(x_d_i_new)
    # Broadcast the single source feature / keypoints to every driving frame.
    f_s_s= f_s.expand(n_frames, *f_s.shape[1:])
    x_s_s = x_s.expand(n_frames, *x_s.shape[1:])
    x_d_i_new = torch.cat(x_d_i_news, dim=0)
    # Decode in chunks of 100 frames.
    for start in range(0, n_frames, 100):
        end = min(start + 100,n_frames)
        # NOTE(review): autocast is enabled for 'cuda' here regardless of
        # cfg.flag_use_half_precision; warp_decode also opens inference_ctx().
        with torch.no_grad(), torch.autocast('cuda'):
            out = self.warp_decode(f_s_s[start:end], x_s_s[start:end], x_d_i_new[start:end])
        I_p_lst.append(out['out'])
    I_p=torch.cat(I_p_lst, dim=0)
    I_p_i = self.parse_output(I_p)
    return I_p_i
def driven_debug(self, f_s, x_s_info, s_lmk, c_s_eyes_lst, driving_template_dct, c_d_eyes_lst=None, c_d_lip_lst=None):
    """Animate the source from a prebuilt driving template, one frame at a time.

    Unlike ``driven``, this variant takes the driving template directly,
    does not apply the eyeball adjustment, and calls ``warp_decode`` once
    per frame.

    Args:
        f_s: source appearance feature passed to ``warp_decode``.
        x_s_info: source keypoint-info dict; refined in place by
            ``refine_kp``.
        s_lmk: source landmarks, or None to skip the lip/eye
            pre-normalisation and the retargeting branches.
        c_s_eyes_lst: source eye-close ratios; only element 0 is read,
            and only when ``cfg.flag_source_video_eye_retargeting`` is set.
        driving_template_dct: dict with 'n_frames' and a 'motion' list of
            per-frame dicts (keys read here: 'R' or 'R_d', 'exp', 'scale', 't').
        c_d_eyes_lst: optional per-frame driving eye ratios (eye retargeting).
        c_d_lip_lst: optional per-frame driving lip ratios (lip retargeting).

    Returns:
        A list with one image per frame: element 0 of ``parse_output``
        applied to each decoder output.
    """
    # source kp info
    x_s_info = self.refine_kp(x_s_info)
    x_c_s = x_s_info["kp"]
    R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
    x_s = self.transform_keypoint(x_s_info)
    n_frames = driving_template_dct['n_frames']
    # driving params
    flag_normalize_lip = self.cfg.flag_normalize_lip
    flag_relative_motion = self.cfg.flag_relative_motion
    flag_source_video_eye_retargeting = self.cfg.flag_source_video_eye_retargeting
    lip_normalize_threshold = self.cfg.lip_normalize_threshold
    source_video_eye_retargeting_threshold = self.cfg.source_video_eye_retargeting_threshold
    animation_region = self.cfg.animation_region
    driving_option = self.cfg.driving_option
    flag_stitching = self.cfg.flag_stitching
    flag_eye_retargeting = self.cfg.flag_eye_retargeting
    flag_lip_retargeting = self.cfg.flag_lip_retargeting
    driving_multiplier = self.cfg.driving_multiplier
    # let lip-open scalar to be 0 at first
    lip_delta_before_animation, eye_delta_before_animation = None, None
    if flag_normalize_lip and flag_relative_motion and s_lmk is not None:
        c_d_lip_before_animation = [0.]
        combined_lip_ratio_tensor_before_animation = self.calc_combined_lip_ratio(c_d_lip_before_animation, s_lmk)
        # Only compute a closing offset when the source lip ratio reaches the threshold.
        if combined_lip_ratio_tensor_before_animation[0][0] >= lip_normalize_threshold:
            lip_delta_before_animation = self.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
    # let eye-open scalar to be the same as the first frame if the latter is eye-open state
    if flag_source_video_eye_retargeting and s_lmk is not None:
        combined_eye_ratio_tensor_frame_zero = c_s_eyes_lst[0]
        c_d_eye_before_animation_frame_zero = [[combined_eye_ratio_tensor_frame_zero[0][:2].mean()]]
        # Below the threshold, fall back to the fixed ratio 0.39.
        if c_d_eye_before_animation_frame_zero[0][0] < source_video_eye_retargeting_threshold:
            c_d_eye_before_animation_frame_zero = [[0.39]]
        combined_eye_ratio_tensor_before_animation = self.calc_combined_eye_ratio(c_d_eye_before_animation_frame_zero, s_lmk)
        eye_delta_before_animation = self.retarget_eye(x_s, combined_eye_ratio_tensor_before_animation)
    # animate
    I_p_lst = []
    for i in range(n_frames):
        x_d_i_info = driving_template_dct['motion'][i]
        x_d_i_info = dct2device(x_d_i_info, self.device)
        # R
        R_d_i = x_d_i_info['R'] if 'R' in x_d_i_info.keys() else x_d_i_info['R_d'] # compatible with previous keys
        if i == 0: # cache the first frame
            R_d_0 = R_d_i
            # Shallow copy: the values are shared with x_d_i_info.
            x_d_0_info = x_d_i_info.copy()
        # debug
        #print(f"frame {i:03d}, src scale {x_s_info['scale']}, 0 scale {x_d_0_info['scale']}, i scale {x_d_i_info['scale']}")
        # delta
        delta_new = x_s_info['exp'].clone()
        if flag_relative_motion:
            # Relative mode: apply the change of the driving frame w.r.t. frame 0 to the source.
            # R
            if animation_region == "all" or animation_region == "pose":
                R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
            else:
                R_new = R_s
            # exp
            if animation_region == "all" or animation_region == "exp":
                delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
            elif animation_region == "lip":
                for lip_idx in [6, 12, 14, 17, 19, 20]:
                    delta_new[:, lip_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, lip_idx, :]
            elif animation_region == "eyes":
                for eyes_idx in [11, 13, 15, 16, 18]:
                    delta_new[:, eyes_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, eyes_idx, :]
            # scale
            if animation_region == "all":
                scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
            else:
                scale_new = x_s_info['scale']
            # translation
            if animation_region == "all" or animation_region == "pose":
                t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
            else:
                t_new = x_s_info['t']
        else:
            # Absolute mode: take the driving values directly.
            # R
            if animation_region == "all" or animation_region == "pose":
                R_new = R_d_i
            else:
                R_new = R_s
            # exp
            if animation_region == "all" or animation_region == "exp":
                for idx in [1,2,6,11,12,13,14,15,16,17,18,19,20]:
                    delta_new[:, idx, :] = x_d_i_info['exp'][:, idx, :]
                delta_new[:, 3:5, 1] = x_d_i_info['exp'][:, 3:5, 1]
                delta_new[:, 5, 2] = x_d_i_info['exp'][:, 5, 2]
                delta_new[:, 8, 2] = x_d_i_info['exp'][:, 8, 2]
                delta_new[:, 9, 1:] = x_d_i_info['exp'][:, 9, 1:]
            elif animation_region == "lip":
                for lip_idx in [6, 12, 14, 17, 19, 20]:
                    delta_new[:, lip_idx, :] = x_d_i_info['exp'][:, lip_idx, :]
            elif animation_region == "eyes":
                for eyes_idx in [11, 13, 15, 16, 18]:
                    delta_new[:, eyes_idx, :] = x_d_i_info['exp'][:, eyes_idx, :]
            # scale
            scale_new = x_s_info['scale']
            # translation
            if animation_region == "all" or animation_region == "pose":
                t_new = x_d_i_info['t']
            else:
                t_new = x_s_info['t']
        # NOTE(review): fill_ is in place; when t_new is x_s_info['t'] or
        # x_d_i_info['t'] itself, that stored tensor's tz is zeroed as well.
        t_new[..., 2].fill_(0) # zero tz
        x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
        if flag_relative_motion and driving_option == "expression-friendly":
            if i == 0:
                x_d_0_new = x_d_i_new
                motion_multiplier = calc_motion_multiplier(x_s, x_d_0_new)
            x_d_diff = (x_d_i_new - x_d_0_new) * motion_multiplier
            x_d_i_new = x_d_diff + x_s
        # Algorithm 1 in Liveportrait:
        if not flag_stitching and not flag_eye_retargeting and not flag_lip_retargeting:
            # without stitching or retargeting
            if flag_normalize_lip and lip_delta_before_animation is not None:
                x_d_i_new += lip_delta_before_animation
            if flag_source_video_eye_retargeting and eye_delta_before_animation is not None:
                x_d_i_new += eye_delta_before_animation
            else:
                pass
        elif flag_stitching and not flag_eye_retargeting and not flag_lip_retargeting:
            # with stitching and without retargeting
            if flag_normalize_lip and lip_delta_before_animation is not None:
                x_d_i_new = self.stitching(x_s, x_d_i_new) + lip_delta_before_animation
            else:
                x_d_i_new = self.stitching(x_s, x_d_i_new)
            if flag_source_video_eye_retargeting and eye_delta_before_animation is not None:
                x_d_i_new += eye_delta_before_animation
        else:
            # At least one retargeting flag is set.
            eyes_delta, lip_delta = None, None
            if flag_eye_retargeting and s_lmk is not None and c_d_eyes_lst is not None:
                c_d_eyes_i = c_d_eyes_lst[i]
                combined_eye_ratio_tensor = self.calc_combined_eye_ratio(c_d_eyes_i, s_lmk)
                eyes_delta = self.retarget_eye(x_s, combined_eye_ratio_tensor)
            if flag_lip_retargeting and s_lmk is not None and c_d_lip_lst is not None:
                c_d_lip_i = c_d_lip_lst[i]
                combined_lip_ratio_tensor = self.calc_combined_lip_ratio(c_d_lip_i, s_lmk)
                # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
                lip_delta = self.retarget_lip(x_s, combined_lip_ratio_tensor)
            if flag_relative_motion: # use x_s
                x_d_i_new = x_s + \
                    (eyes_delta if eyes_delta is not None else 0) + \
                    (lip_delta if lip_delta is not None else 0)
            else: # use x_d,i
                x_d_i_new = x_d_i_new + \
                    (eyes_delta if eyes_delta is not None else 0) + \
                    (lip_delta if lip_delta is not None else 0)
            if flag_stitching:
                x_d_i_new = self.stitching(x_s, x_d_i_new)
        # Scale the displacement from the source keypoints by driving_multiplier.
        x_d_i_new = x_s + (x_d_i_new - x_s) * driving_multiplier
        # Decode this frame immediately and keep the first image of the output.
        out = self.warp_decode(f_s, x_s, x_d_i_new)
        I_p_i = self.parse_output(out['out'])[0]
        I_p_lst.append(I_p_i)
    return I_p_lst
def read_image(self, image_path: str) -> list:
    """Load one RGB image, size-limit it, and return it in a one-item list.

    The image is read with ``load_image_rgb`` and passed to
    ``resize_to_limit`` together with ``cfg.source_max_dim`` and
    ``cfg.source_division``.
    """
    img_rgb = resize_to_limit(
        load_image_rgb(image_path),
        self.cfg.source_max_dim,
        self.cfg.source_division,
    )
    print(f"Load image from {osp.realpath(image_path)} done.")
    return [img_rgb]
def read_video(self, video_path: str, interval=None) -> list:
    """Read frames from a video file and return them as a list.

    Args:
        video_path: path of the video to open with `VideoReader`.
        interval: frame stride. When given, every `interval`-th frame starting
            at frame 0 is returned. When None, only three frames are sampled:
            the first, the middle one and the last.

    Returns:
        A list of frames, one entry per sampled frame.
    """
    # NOTE(review): frames are converted with `.numpy()`, which presumably relies
    # on the reader returning torch tensors (decord torch bridge) - confirm.
    reader = VideoReader(video_path)
    if interval is None:
        n_total = len(reader)
        sampled = [
            reader[0].numpy(),
            reader[n_total // 2].numpy(),
            reader[-1].numpy(),
        ]
    else:
        wanted = np.arange(0, len(reader), interval)
        sampled = reader.get_batch(wanted).numpy()
    reader.seek(0)
    # Iterating the batch along its first axis yields one frame per entry.
    return list(sampled)
def prepare_videos(self, imgs) -> torch.Tensor:
    """Convert video frames into the normalized tensor layout used downstream.

    Args:
        imgs: either a list of T frames (each HxWx3), or an ndarray shaped
            TxHxWx3 or TxHxWx3x1. Values are expected in the 0..255 range.

    Returns:
        A float32 tensor shaped Tx1x3xHxW with values clipped to [0, 1],
        moved to `self.device`.

    Raises:
        ValueError: if `imgs` is neither a list nor an ndarray, or if the
            resulting array is not 5-dimensional.
    """
    if isinstance(imgs, list):
        _imgs = np.array(imgs)[..., np.newaxis]  # TxHxWx3x1
    elif isinstance(imgs, np.ndarray):
        # Accept a plain TxHxWx3 stack the same way a list of frames is accepted.
        _imgs = imgs[..., np.newaxis] if imgs.ndim == 4 else imgs
    else:
        raise ValueError(f'imgs type error: {type(imgs)}')

    if _imgs.ndim != 5:
        # Fail early with a readable message instead of an opaque permute error.
        raise ValueError(f'imgs shape error: expected TxHxWx3(x1), got {_imgs.shape}')

    y = _imgs.astype(np.float32) / 255.
    y = np.clip(y, 0, 1)  # clip to 0~1
    y = torch.from_numpy(y).permute(0, 4, 3, 1, 2)  # TxHxWx3x1 -> Tx1x3xHxW
    y = y.to(self.device)
    return y
def make_motion_template(self, I_lst, c_eyes_lst, c_lip_lst, **kwargs):
    """Build a per-frame motion template from a stack of prepared frames.

    For every frame the keypoint info is extracted (`get_kp_info` followed by
    `refine_kp`), the transformed keypoints and the rotation matrix are
    computed, and everything is stored as float32 numpy arrays.

    Args:
        I_lst: frame stack; its first dimension is the frame count.
        c_eyes_lst: per-frame eye ratios (indexable, entries support `astype`).
        c_lip_lst: per-frame lip ratios (indexable, entries support `astype`).
        **kwargs: `output_fps` is read from here (default 25).

    Returns:
        A dict with keys 'n_frames', 'output_fps', 'motion', 'c_eyes_lst'
        and 'c_lip_lst'.
    """
    def _to_f32(tensor):
        # Move to host memory and store as float32.
        return tensor.cpu().numpy().astype(np.float32)

    total = I_lst.shape[0]
    motion_items, eye_ratios, lip_ratios = [], [], []

    for frame_idx in track(range(total), description='Making motion templates...', total=total):
        # collect s, R, δ and t for inference
        kp_info = self.refine_kp(self.get_kp_info(I_lst[frame_idx]))
        kp_transformed = self.transform_keypoint(kp_info)
        rotation = get_rotation_matrix(kp_info['pitch'], kp_info['yaw'], kp_info['roll'])

        motion_items.append({
            'scale': _to_f32(kp_info['scale']),
            'R': _to_f32(rotation),
            'exp': _to_f32(kp_info['exp']),
            't': _to_f32(kp_info['t']),
            'kp': _to_f32(kp_info['kp']),
            'x_s': _to_f32(kp_transformed),
        })
        eye_ratios.append(c_eyes_lst[frame_idx].astype(np.float32))
        lip_ratios.append(c_lip_lst[frame_idx].astype(np.float32))

    return {
        'n_frames': total,
        'output_fps': kwargs.get('output_fps', 25),
        'motion': motion_items,
        'c_eyes_lst': eye_ratios,
        'c_lip_lst': lip_ratios,
    }
def load_template(self, wfp_template):
    """Load a stored driving motion template and report its FPS.

    The template is read with the module-level `load` helper. Its schema is
    checked up front: the eye and lip ratio lists must be present under either
    the current key names ('c_eyes_lst', 'c_lip_lst') or the previous ones
    ('c_d_eyes_lst', 'c_d_lip_lst'), and 'n_frames' must exist.

    Args:
        wfp_template: path of the template file.

    Returns:
        The loaded template dict, unchanged.

    Raises:
        KeyError: if a required key is missing from the template.
    """
    print(f"Load from template: {wfp_template}, NOT the video, so the cropping video and audio are both NULL.")
    driving_template_dct = load(wfp_template)

    # Validate the required keys explicitly instead of binding the values to
    # locals that are never read afterwards.
    ratio_keys = (('c_eyes_lst', 'c_d_eyes_lst'), ('c_lip_lst', 'c_d_lip_lst'))
    for current_key, previous_key in ratio_keys:  # compatible with previous keys
        if current_key not in driving_template_dct and previous_key not in driving_template_dct:
            raise KeyError(previous_key)
    if 'n_frames' not in driving_template_dct:
        raise KeyError('n_frames')

    # set output_fps
    output_fps = driving_template_dct.get('output_fps', 25)
    print(f'The FPS of template: {output_fps}')
    return driving_template_dct
def reconstruction(self, src_img, dst_imgs, video_path="template"):
    """Drive `src_img` with the motion of `dst_imgs` and return the rendered frames.

    Neither the source nor the driving frames are cropped (`do_crop=False`).
    The source eye ratio list is passed to `self.driven` as None.

    Args:
        src_img: source frame.
        dst_imgs: driving frames.
        video_path: accepted for interface compatibility; not read here.

    Returns:
        Whatever `self.driven` returns for these inputs.
    """
    # source side
    source_crop, source_lmk, _ = self.crop_image(src_img, do_crop=False)
    source_feat, source_kp_info = self.prepare_source(source_crop)

    # driving side
    driving_crops, driving_lmks = self.crop_driving_videos(dst_imgs, do_crop=False)
    driving_eye_ratios, driving_lip_ratios = self.calc_ratio(driving_lmks)
    driving_kp_infos = self.prepare_driving_videos(driving_crops)

    return self.driven(
        source_feat,
        source_kp_info,
        source_lmk,
        None,
        driving_kp_infos,
        driving_eye_ratios,
        driving_lip_ratios,
    )
def save_results(self, results, save_path, audio_path=None):
    """Write frames to a video file and optionally add an audio track.

    The frames are always written to `save_path` at `cfg.output_fps`. When
    `audio_path` is given, a second file named `final_<basename>` is produced
    next to `save_path` with the audio added, and the file at `save_path` is
    then deleted.

    Args:
        results: frames to encode.
        save_path: destination of the video without audio.
        audio_path: optional audio file to add.
    """
    muxed_path = osp.join(osp.dirname(save_path), f'final_{osp.basename(save_path)}')
    images2video(results, wfp=save_path, fps=self.cfg.output_fps)
    if audio_path is None:
        return
    add_audio_to_video(save_path, audio_path, muxed_path)
    os.remove(save_path)
def rec_score(self, video_path: str, interval=None, save_path=None):
    """Score how well a video can be reconstructed from its own first frame.

    The first sampled frame is the source; the remaining sampled frames are
    both the driving input and the reference for PSNR.

    Args:
        video_path: video to evaluate.
        interval: frame stride forwarded to `read_video`.
        save_path: when not None, the reconstructed frames are saved there.

    Returns:
        A dict with the PSNR 'mean' and 'std' over the reconstructed frames.
    """
    frames = self.read_video(video_path, interval=interval)
    source_frame, target_frames = frames[0], frames[1:]

    reconstructed = self.reconstruction(source_frame, target_frames, video_path)
    if save_path is not None:
        self.save_results(reconstructed, save_path)

    scores = np.array(psnr(target_frames, reconstructed))
    return {"mean": np.mean(scores), "std": np.std(scores)}
@torch.no_grad()
def paste_back_by_face_mask(self, result, crop_info, src_img, crop_src_image, use_laplacian=False):
    """
    Paste the rendered crop back into the original image using a face mask.

    A face mask is computed for both the cropped source image and the rendered
    result; the two are summed and clipped to [0, 1], and the first combined
    mask is handed to `paste_back_with_face_mask`.
    """
    def _face_masks(image):
        # Batch of one image on the working device -> face masks.
        batch = self.to_tensor(image).unsqueeze(0).to(self.device)
        return get_face_mask(self.face_parser, batch)

    source_masks = _face_masks(crop_src_image)
    rendered_masks = _face_masks(result)

    # combine masks
    combined = [
        np.clip(source_mask + rendered_mask, 0, 1)
        for source_mask, rendered_mask in zip(source_masks, rendered_masks)
    ]
    return paste_back_with_face_mask(result, crop_info, src_img, combined[0], use_laplacian=use_laplacian)
def driven_by_audio(self, src_img, kp_infos, save_path, audio_path=None, smooth=False):
    """Animate a cropped source image with `kp_infos`, paste back and save.

    Args:
        src_img: source image; its first two shape entries give height and width.
        kp_infos: driving keypoint infos forwarded to `self.driven`.
        save_path: output video path.
        audio_path: optional audio forwarded to `save_results`.
        smooth: forwarded to `self.driven`.
    """
    # prepare source
    source_crop, source_lmk, crop_info = self.crop_image(src_img, do_crop=True)
    source_feat, source_kp_info = self.prepare_source(source_crop)

    crop_to_orig = crop_info['M_c2o']
    height, width = src_img.shape[0], src_img.shape[1]
    blend_mask = prepare_paste_back(self.mask_crop, crop_to_orig, dsize=(width, height))

    # The source eye ratio list is not computed here; None is passed instead.
    animated = self.driven(source_feat, source_kp_info, source_lmk, None, kp_infos, smooth=smooth)

    pasted = []
    for frame_idx in range(animated.shape[0]):
        pasted.append(paste_back(animated[frame_idx], crop_to_orig, src_img, blend_mask))
    self.save_results(pasted, save_path, audio_path)
def mix_kp_infos(self, emo_kp_infos, lip_kp_infos, smooth=False, dtype="pt_tensor"):
    """Merge an emotion driving template with lip motion from a second one.

    The emotion template is always built with `smooth=False`. When
    `lip_kp_infos` is None it is returned as is. Otherwise the expression rows
    6, 12, 14, 17, 19 and 20 of each shared frame are overwritten with the lip
    template's values, and 'n_frames' becomes the shorter of the two lengths.

    Note: the result is a shallow copy, so the 'motion' entries of the emotion
    template are modified in place.

    Args:
        emo_kp_infos: keypoint infos supplying the overall expression.
        lip_kp_infos: keypoint infos supplying the lip rows, or None.
        smooth: smoothing flag used for the lip template only.
        dtype: forwarded to `get_driving_template` for both templates.

    Returns:
        The mixed driving template dict.
    """
    emo_template = self.get_driving_template(emo_kp_infos, smooth=False, dtype=dtype)
    if lip_kp_infos is None:
        return emo_template

    lip_template = self.get_driving_template(lip_kp_infos, smooth=smooth, dtype=dtype)
    mixed = dict(emo_template)
    shared_frames = min(emo_template['n_frames'], lip_template['n_frames'])
    mixed['n_frames'] = shared_frames

    lip_rows = [6, 12, 14, 17, 19, 20]
    for frame_idx in range(shared_frames):
        emo_exp = emo_template['motion'][frame_idx]['exp']
        lip_exp = lip_template['motion'][frame_idx]['exp']
        # Copy all lip rows at once via integer-list indexing.
        emo_exp[:, lip_rows, :] = lip_exp[:, lip_rows, :]
        mixed['motion'][frame_idx]['exp'] = emo_exp
    return mixed
def driven_by_mix(self, src_img, driving_video_path, kp_infos, save_path, audio_path=None, smooth=False):
    """Animate a source image with a driving video mixed with extra lip motion.

    The driving video supplies the overall expression; `kp_infos` supplies the
    lip rows through `mix_kp_infos`. The rendered crops are pasted back into
    `src_img` and saved.

    Args:
        src_img: source image; its first two shape entries give height and width.
        driving_video_path: video read frame by frame (`interval=1`).
        kp_infos: keypoint infos for the lip motion (forwarded to `mix_kp_infos`).
        save_path: output video path.
        audio_path: optional audio forwarded to `save_results`.
        smooth: smoothing flag forwarded to `mix_kp_infos`.
    """
    # prepare source
    src_img_256x256, s_lmk, crop_info = self.crop_image(src_img, do_crop=True)
    c_s_eyes_lst, _ = self.calc_ratio([s_lmk])  # the source lip ratios are not needed here
    f_s, x_s_info = self.prepare_source(src_img_256x256)
    mask_ori_float = prepare_paste_back(self.mask_crop, crop_info['M_c2o'], dsize=(src_img.shape[1], src_img.shape[0]))
    # prepare driving videos
    driving_imgs = self.read_video(driving_video_path, interval=1)
    dst_imgs_256x256, d_lmk_lst = self.crop_driving_videos(driving_imgs, do_crop=True)
    c_d_eyes_lst, c_d_lip_lst = self.calc_ratio(d_lmk_lst)
    emo_kp_infos = self.prepare_driving_videos(dst_imgs_256x256)
    # mix kp_infos
    driving_template_dct = self.mix_kp_infos(emo_kp_infos, kp_infos, smooth=smooth)
    # driven
    results = self.driven_debug(f_s, x_s_info, s_lmk, c_s_eyes_lst, driving_template_dct, c_d_eyes_lst=c_d_eyes_lst, c_d_lip_lst=c_d_lip_lst)
    results = [paste_back(result, crop_info['M_c2o'], src_img, mask_ori_float) for result in results]
    # `results` is a plain list at this point and has no `.shape`; accessing it
    # raised AttributeError before anything was saved. Report the frame count.
    print(len(results))
    self.save_results(results, save_path, audio_path)
def drive_video_by_mix(self, video_path, driving_video_path, kp_infos, save_path, audio_path):
    """Retarget a source video using a driving video mixed with extra lip motion.

    The driving video is read frame by frame, cropped and converted to keypoint
    infos, which are mixed with `kp_infos` (smoothing on, numpy dtype). The
    mixed template is then handed to `video_lip_retargeting`.

    Args:
        video_path: source video to re-animate.
        driving_video_path: video supplying the overall expression.
        kp_infos: keypoint infos supplying the lip rows (may be None).
        save_path: output video path.
        audio_path: audio forwarded to `video_lip_retargeting`.
    """
    # prepare driving videos
    driving_frames = self.read_video(driving_video_path, interval=1)
    driving_crops, _driving_lmks = self.crop_driving_videos(driving_frames, do_crop=True)
    expression_kp_infos = self.prepare_driving_videos(driving_crops)

    # mix kp_infos
    mixed_template = self.mix_kp_infos(expression_kp_infos, kp_infos, smooth=True, dtype="np")

    # driven
    self.video_lip_retargeting(
        video_path,
        None,
        save_path,
        audio_path,
        driving_template_dct=mixed_template,
        retargeting_ragion="exp",
    )
def load_source_video(self, video_info, n_frames=-1):
    """Read frames from a video through imageio's ffmpeg reader.

    Args:
        video_info: source passed straight to `imageio.get_reader`.
        n_frames: maximum number of frames to read; values <= 0 read all frames.

    Returns:
        A list of the frames read, in order.
    """
    reader = imageio.get_reader(video_info, "ffmpeg")
    ret = []
    try:
        for idx, frame_rgb in enumerate(reader):
            if n_frames > 0 and idx >= n_frames:
                break
            ret.append(frame_rgb)
    finally:
        # Release the reader even when decoding fails part-way through.
        reader.close()
    return ret
def video_lip_retargeting(self, video_path, kp_infos, save_path, audio_path, c_d_eyes_lst=None, c_d_lip_lst=None, smooth=False, driving_template_dct=None, retargeting_ragion="exp"):
    """Re-animate a source video frame by frame from a driving motion template and save it.

    Steps:
      0. Load the source video, resize and crop every frame, and build a
         per-frame source motion template (output_fps=25).
      1. Build the driving template from `kp_infos` unless one is supplied.
      2. For each of the first min(source, driving) frames, compose new
         keypoints, optionally stitch/retarget, render, and optionally paste
         the crop back into the full frame.
      3. Save the pasted-back frames if any were produced, else the raw crops.

    As written, `flag_relative_motion` is hard-coded to True and
    `animation_region` to 'lip' (the config reads are commented out), so only
    the relative-motion / lip branches below are taken.

    Args:
        video_path: source video to re-animate.
        kp_infos: driving keypoint infos; only used when `driving_template_dct` is None.
        save_path: output video path (forwarded to `save_results`).
        audio_path: audio forwarded to `save_results`.
        c_d_eyes_lst: optional per-frame driving eye ratios, used only in the
            eye-retargeting branch.
        c_d_lip_lst: optional per-frame driving lip ratios, used only in the
            lip-retargeting branch.
        smooth: forwarded to `get_driving_template` when the template is built here.
        driving_template_dct: ready-made driving template; skips step 1.
        retargeting_ragion: not read anywhere in this method.
    """
    # 0. process source motion template
    source_rgb_lst = load_video(video_path)
    source_rgb_lst = [resize_to_limit(img, self.cfg.source_max_dim, self.cfg.source_division) for img in source_rgb_lst]
    img_crop_256x256_lst, source_lmk_crop_lst, source_M_c2o_lst = self.crop_source_video(source_rgb_lst, do_crop=True)
    c_s_eyes_lst, c_s_lip_lst = self.calc_ratio(source_lmk_crop_lst)
    I_s_lst = self.prepare_videos(img_crop_256x256_lst)
    source_template_dct = self.make_motion_template(I_s_lst, c_s_eyes_lst, c_s_lip_lst, output_fps=25)
    # 1. prepare driving template
    if driving_template_dct is None:
        driving_template_dct = self.get_driving_template(kp_infos, smooth=smooth, dtype="np")
    # 2. driving
    # Only frames available in both templates are processed.
    n_frames = min(source_template_dct['n_frames'], driving_template_dct['n_frames'])
    # driving params
    I_p_lst = []  # rendered crops
    I_p_pstbk_lst = []  # rendered crops pasted back into the full frames
    R_d_0, x_d_0_info = None, None  # first driving frame cache (assigned below, not read afterwards)
    flag_normalize_lip = self.cfg.flag_normalize_lip
    flag_relative_motion = True #self.cfg.flag_relative_motion
    flag_source_video_eye_retargeting = self.cfg.flag_source_video_eye_retargeting
    lip_normalize_threshold = self.cfg.lip_normalize_threshold
    source_video_eye_retargeting_threshold = self.cfg.source_video_eye_retargeting_threshold
    animation_region = 'lip' #self.cfg.animation_region
    driving_option = self.cfg.driving_option  # read but not used below
    flag_stitching = self.cfg.flag_stitching
    flag_eye_retargeting = self.cfg.flag_eye_retargeting
    flag_lip_retargeting = self.cfg.flag_lip_retargeting
    driving_multiplier = self.cfg.driving_multiplier
    driving_smooth_observation_variance = self.cfg.driving_smooth_observation_variance
    key_r = 'R' if 'R' in driving_template_dct['motion'][0].keys() else 'R_d'  # compatible with previous keys
    if flag_relative_motion:
        # Relative expression: source expression plus the driving change w.r.t. driving frame 0.
        x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + driving_template_dct['motion'][i]['exp'] - driving_template_dct['motion'][0]['exp'] for i in range(n_frames)]
        for i in range(n_frames):
            for idx in [6, 12, 14, 17, 19, 20]:
                # lip motion use abs motion
                x_d_exp_lst[i][:, idx, :] = driving_template_dct['motion'][i]['exp'][:, idx, :]
        x_d_exp_lst_smooth = ksmooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, self.device, driving_smooth_observation_variance)
        # Smoothed rotations are only prepared for the pose-animating regions,
        # so `x_d_r_lst_smooth` stays undefined for other regions.
        if animation_region == "all" or animation_region == "pose" or "all" in animation_region:
            x_d_r_lst = [(np.dot(driving_template_dct['motion'][i][key_r], driving_template_dct['motion'][0][key_r].transpose(0, 2, 1))) @ source_template_dct['motion'][i]['R'] for i in range(n_frames)]
            x_d_r_lst_smooth = ksmooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, self.device, driving_smooth_observation_variance)
    else:
        # Absolute motion: use the driving expressions / rotations directly.
        x_d_exp_lst = [driving_template_dct['motion'][i]['exp'] for i in range(n_frames)]
        x_d_exp_lst_smooth = ksmooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, self.device, driving_smooth_observation_variance)
        if animation_region == "all" or animation_region == "pose" or "all" in animation_region:
            x_d_r_lst = [driving_template_dct['motion'][i][key_r] for i in range(n_frames)]
            x_d_r_lst_smooth = ksmooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, self.device, driving_smooth_observation_variance)
    # driving all
    for i in track(range(n_frames), description='🚀Retargeting...', total=n_frames):
        # Per-frame source state.
        x_s_info = source_template_dct['motion'][i]
        x_s_info = dct2device(x_s_info, self.device)
        source_lmk = source_lmk_crop_lst[i]
        img_crop_256x256 = img_crop_256x256_lst[i]  # only referenced by the commented-out paste-back call below
        I_s = I_s_lst[i]
        f_s = self.extract_feature_3d(I_s)
        x_c_s = x_s_info['kp']
        R_s = x_s_info['R']
        x_s =x_s_info['x_s']
        # let lip-open scalar to be 0 at first if the input is a video
        lip_delta_before_animation = None
        if flag_normalize_lip and flag_relative_motion and source_lmk is not None:
            c_d_lip_before_animation = [0.]
            combined_lip_ratio_tensor_before_animation = self.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk)
            # Only correct the lips when the source lip ratio reaches the threshold.
            if combined_lip_ratio_tensor_before_animation[0][0] >= lip_normalize_threshold:
                lip_delta_before_animation = self.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
            else:
                lip_delta_before_animation = None
        # let eye-open scalar to be the same as the first frame if the latter is eye-open state
        eye_delta_before_animation = None
        if flag_source_video_eye_retargeting and source_lmk is not None:
            if i == 0:
                # Computed once on frame 0 and reused for every later frame.
                combined_eye_ratio_tensor_frame_zero = c_s_eyes_lst[0]
                c_d_eye_before_animation_frame_zero = [[combined_eye_ratio_tensor_frame_zero[0][:2].mean()]]
                # Fall back to a fixed ratio when the first frame is below the threshold.
                if c_d_eye_before_animation_frame_zero[0][0] < source_video_eye_retargeting_threshold:
                    c_d_eye_before_animation_frame_zero = [[0.39]]
            combined_eye_ratio_tensor_before_animation = self.calc_combined_eye_ratio(c_d_eye_before_animation_frame_zero, source_lmk)
            eye_delta_before_animation = self.retarget_eye(x_s, combined_eye_ratio_tensor_before_animation)
        if flag_stitching: # prepare for paste back
            mask_ori_float = prepare_paste_back(self.mask_crop, source_M_c2o_lst[i], dsize=(source_rgb_lst[i].shape[1], source_rgb_lst[i].shape[0]))
        # Per-frame driving state.
        x_d_i_info = driving_template_dct['motion'][i]
        x_d_i_info = dct2device(x_d_i_info, self.device)
        R_d_i = x_d_i_info['R'] if 'R' in x_d_i_info.keys() else x_d_i_info['R_d'] # compatible with previous keys
        if i == 0: # cache the first frame
            R_d_0 = R_d_i
            x_d_0_info = x_d_i_info.copy()
        # Start from the source expression and overwrite only the animated entries.
        delta_new = x_s_info['exp'].clone()
        if flag_relative_motion:
            if animation_region == "all" or animation_region == "pose" or "all" in animation_region:
                R_new = x_d_r_lst_smooth[i]
            else:
                R_new = R_s
            # NOTE(review): the smoothed entries are indexed without a leading batch
            # axis here (e.g. `[idx, :]`) - confirm against what `ksmooth` returns.
            if animation_region == "all" or animation_region == "exp":
                for idx in [1,2,6,11,12,13,14,15,16,17,18,19,20]:
                    delta_new[:, idx, :] = x_d_exp_lst_smooth[i][idx, :]
                delta_new[:, 3:5, 1] = x_d_exp_lst_smooth[i][3:5, 1]
                delta_new[:, 5, 2] = x_d_exp_lst_smooth[i][5, 2]
                delta_new[:, 8, 2] = x_d_exp_lst_smooth[i][8, 2]
                delta_new[:, 9, 1:] = x_d_exp_lst_smooth[i][9, 1:]
            elif animation_region == "all_wo_lip" or animation_region == "exp_wo_lip":
                for idx in [1, 2, 11, 13, 15, 16, 18]:
                    delta_new[:, idx, :] = x_d_exp_lst_smooth[i][idx, :]
                delta_new[:, 3:5, 1] = x_d_exp_lst_smooth[i][3:5, 1]
                delta_new[:, 5, 2] = x_d_exp_lst_smooth[i][5, 2]
                delta_new[:, 8, 2] = x_d_exp_lst_smooth[i][8, 2]
                delta_new[:, 9, 1:] = x_d_exp_lst_smooth[i][9, 1:]
            elif animation_region == "lip":
                for lip_idx in [6, 12, 14, 17, 19, 20]:
                    delta_new[:, lip_idx, :] = x_d_exp_lst_smooth[i][lip_idx, :]
            elif animation_region == "eyes":
                for eyes_idx in [11, 13, 15, 16, 18]:
                    delta_new[:, eyes_idx, :] = x_d_exp_lst_smooth[i][eyes_idx, :]
            # Relative motion keeps the source scale and translation.
            scale_new = x_s_info['scale']
            t_new = x_s_info['t']
        else:
            if animation_region == "all" or animation_region == "pose" or "all" in animation_region:
                R_new = x_d_r_lst_smooth[i]
            else:
                R_new = R_s
            if animation_region == "all" or animation_region == "exp":
                for idx in [1,2,6,11,12,13,14,15,16,17,18,19,20]:
                    delta_new[:, idx, :] = x_d_exp_lst_smooth[i][idx, :]
                delta_new[:, 3:5, 1] = x_d_exp_lst_smooth[i][3:5, 1]
                delta_new[:, 5, 2] = x_d_exp_lst_smooth[i][5, 2]
                delta_new[:, 8, 2] = x_d_exp_lst_smooth[i][8, 2]
                delta_new[:, 9, 1:] = x_d_exp_lst_smooth[i][9, 1:]
            elif animation_region == "all_wo_lip" or animation_region == "exp_wo_lip":
                for idx in [1, 2, 11, 13, 15, 16, 18]:
                    delta_new[:, idx, :] = x_d_exp_lst_smooth[i][idx, :]
                delta_new[:, 3:5, 1] = x_d_exp_lst_smooth[i][3:5, 1]
                delta_new[:, 5, 2] = x_d_exp_lst_smooth[i][5, 2]
                delta_new[:, 8, 2] = x_d_exp_lst_smooth[i][8, 2]
                delta_new[:, 9, 1:] = x_d_exp_lst_smooth[i][9, 1:]
            elif animation_region == "lip":
                for lip_idx in [6, 12, 14, 17, 19, 20]:
                    delta_new[:, lip_idx, :] = x_d_exp_lst_smooth[i][lip_idx, :]
            elif animation_region == "eyes":
                for eyes_idx in [11, 13, 15, 16, 18]:
                    delta_new[:, eyes_idx, :] = x_d_exp_lst_smooth[i][eyes_idx, :]
            scale_new = x_s_info['scale']
            # Absolute motion takes the driving translation only when pose is animated.
            if animation_region == "all" or animation_region == "pose" or "all" in animation_region:
                t_new = x_d_i_info['t']
            else:
                t_new = x_s_info['t']
        # In-place: this also modifies the tensor stored in the per-frame info dict.
        t_new[..., 2].fill_(0) # zero tz
        x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
        # Algorithm 1:
        if not flag_stitching and not flag_eye_retargeting and not flag_lip_retargeting:
            # without stitching or retargeting
            if flag_normalize_lip and lip_delta_before_animation is not None:
                x_d_i_new += lip_delta_before_animation
            if flag_source_video_eye_retargeting and eye_delta_before_animation is not None:
                x_d_i_new += eye_delta_before_animation
            else:
                pass
        elif flag_stitching and not flag_eye_retargeting and not flag_lip_retargeting:
            # with stitching and without retargeting
            if flag_normalize_lip and lip_delta_before_animation is not None:
                x_d_i_new = self.stitching(x_s, x_d_i_new) + lip_delta_before_animation
            else:
                x_d_i_new = self.stitching(x_s, x_d_i_new)
            if flag_source_video_eye_retargeting and eye_delta_before_animation is not None:
                x_d_i_new += eye_delta_before_animation
        else:
            # Eye and/or lip retargeting requested; each needs the matching driving ratio list.
            eyes_delta, lip_delta = None, None
            if flag_eye_retargeting and source_lmk is not None and c_d_eyes_lst is not None:
                c_d_eyes_i = c_d_eyes_lst[i]
                combined_eye_ratio_tensor = self.calc_combined_eye_ratio(c_d_eyes_i, source_lmk)
                # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
                eyes_delta = self.retarget_eye(x_s, combined_eye_ratio_tensor)
            if flag_lip_retargeting and source_lmk is not None and c_d_lip_lst is not None:
                c_d_lip_i = c_d_lip_lst[i]
                combined_lip_ratio_tensor = self.calc_combined_lip_ratio(c_d_lip_i, source_lmk)
                # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
                lip_delta = self.retarget_lip(x_s, combined_lip_ratio_tensor)
            if flag_relative_motion: # use x_s
                x_d_i_new = x_s + \
                    (eyes_delta if eyes_delta is not None else 0) + \
                    (lip_delta if lip_delta is not None else 0)
            else: # use x_d,i
                x_d_i_new = x_d_i_new + \
                    (eyes_delta if eyes_delta is not None else 0) + \
                    (lip_delta if lip_delta is not None else 0)
            if flag_stitching:
                x_d_i_new = self.stitching(x_s, x_d_i_new)
        # Scale the displacement from the source keypoints by the driving multiplier.
        x_d_i_new = x_s + (x_d_i_new - x_s) * driving_multiplier
        out = self.warp_decode(f_s, x_s, x_d_i_new)
        I_p_i = self.parse_output(out['out'])[0]
        I_p_lst.append(I_p_i)
        if flag_stitching:
            # TODO: the paste back procedure is slow, considering optimize it using multi-threading or GPU
            #I_p_pstbk = self.paste_back_by_face_mask(I_p_i, source_M_c2o_lst[i], source_rgb_lst[i], img_crop_256x256, use_laplacian=True)
            I_p_pstbk = paste_back(I_p_i, source_M_c2o_lst[i], source_rgb_lst[i], mask_ori_float, use_laplacian=True)
            I_p_pstbk_lst.append(I_p_pstbk)
    # Prefer the pasted-back full frames; fall back to the rendered crops.
    if len(I_p_pstbk_lst) > 0:
        self.save_results(I_p_pstbk_lst, save_path, audio_path)
    else:
        self.save_results(I_p_lst, save_path, audio_path)
@torch.no_grad()
def video_reconstruction_test(self, video_tensor, xs, save_path):
    """Re-render a video from its first frame and per-frame keypoints, then save it.

    Only the first batch element of both inputs is used. The first frame
    (resized to 256x256) supplies the appearance features, and every frame's
    keypoints are stitched against the first frame's keypoints before decoding.

    Args:
        video_tensor: (1, F, C, H, W) video in [-1, 1].
        xs: (1, F, 63) keypoints, reshaped to 21x3 per frame.
        save_path: output video path (saved without audio).
    """
    # [-1, 1] -> [0, 1], then 1xTx3xHxW -> Tx1x3xHxW
    unit_range = torch.clip(video_tensor[0:1] * 0.5 + 0.5, 0, 1)
    frames = unit_range.permute(1, 0, 2, 3, 4).to(self.device)

    # 1xTx63 -> Tx1x63 -> Tx1x21x3
    keypoints = xs[0:1].permute(1, 0, 2).reshape(-1, 1, 21, 3).to(self.device)

    anchor_kp = keypoints[0]
    anchor_img = torch.nn.functional.interpolate(frames[0], size=(256, 256), mode='bilinear')
    anchor_feat = self.extract_feature_3d(anchor_img)

    rendered = []
    for frame_idx in range(frames.shape[0]):
        driven_kp = self.stitching(anchor_kp, keypoints[frame_idx])
        decoded = self.warp_decode(anchor_feat, anchor_kp, driven_kp)
        rendered.append(self.parse_output(decoded['out'])[0])

    self.save_results(rendered, save_path, audio_path=None)
@torch.no_grad()
def self_driven(self, image_tensor, xs, save_path, length):
    """Animate a single image with a keypoint sequence and save the video.

    Only the first batch element of both inputs is used. The image (resized to
    256x256) supplies the appearance features; the first keypoint set is the
    source pose, and every keypoint set is stitched against it before decoding.

    Args:
        image_tensor: (1, 3, H, W) image in [-1, 1].
        xs: (1, T, 63) keypoints, reshaped to 21x3 per frame.
        save_path: output video path (saved without audio).
        length: expected number of rendered frames; checked with an assert.
    """
    # [-1, 1] -> [0, 1], 1x3xHxW
    image = torch.clip(image_tensor[0:1] * 0.5 + 0.5, 0, 1).to(self.device)
    source_img = torch.nn.functional.interpolate(image, size=(256, 256), mode='bilinear')

    # 1xTx63 -> Tx1x63 -> Tx1x21x3
    keypoints = xs[0:1].permute(1, 0, 2).reshape(-1, 1, 21, 3).to(self.device)
    source_kp = keypoints[0]
    source_feat = self.extract_feature_3d(source_img)

    def _render(frame_kp):
        # Stitch against the source keypoints, decode, and keep the first parsed image.
        driven_kp = self.stitching(source_kp, frame_kp)
        decoded = self.warp_decode(source_feat, source_kp, driven_kp)
        return self.parse_output(decoded['out'])[0]

    result_lst = [_render(keypoints[frame_idx]) for frame_idx in range(keypoints.shape[0])]
    assert len(result_lst) == length, f"length of result_lst is {len(result_lst)}, but length is {length}"
    self.save_results(result_lst, save_path, audio_path=None)