|
""" |
|
Motion feature extractor |
|
""" |
|
import os |
|
import os.path as osp |
|
import sys |
|
import pickle |
|
from omegaconf import OmegaConf |
|
|
|
import torch |
|
|
|
from PIL import Image |
|
import numpy as np |
|
import cv2 |
|
import imageio |
|
|
import time |
|
from decord import VideoReader |
|
|
|
from rich.progress import track |
|
|
|
|
|
|
|
|
|
# Make the repository root (five directory levels up) importable as `src.*`.
sys.path.append(osp.dirname(osp.dirname(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__)))))))
|
from src.datasets.preprocess.extract_features.face_segmentation import build_face_parser, get_face_mask, vis_parsing_maps |
|
from src.thirdparty.liveportrait.src.utils.helper import load_model, concat_feat |
|
from src.thirdparty.liveportrait.src.utils.io import load_image_rgb, resize_to_limit, load_video |
|
from src.thirdparty.liveportrait.src.utils.video import get_fps, images2video, add_audio_to_video |
|
from src.thirdparty.liveportrait.src.utils.camera import headpose_pred_to_degree, get_rotation_matrix |
|
|
|
from src.thirdparty.liveportrait.src.utils.cropper import Cropper |
|
from src.thirdparty.liveportrait.src.utils.crop import prepare_paste_back, paste_back, paste_back_with_face_mask |
|
from src.thirdparty.liveportrait.src.utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio |
|
from src.thirdparty.liveportrait.src.utils.helper import mkdir, basename, dct2device, is_image, calc_motion_multiplier |
|
from src.utils.filter import smooth as ksmooth |
|
from src.utils.filter import smooth_ |
|
|
|
from skimage.metrics import peak_signal_noise_ratio |
|
import warnings |
|
|
|
|
|
def psnr(imgs1, imgs2):
    """Per-frame PSNR between two uint8 image sequences."""
    psnrs = []
    for img1, img2 in zip(imgs1, imgs2):
        psnrs.append(peak_signal_noise_ratio(img1, img2, data_range=255))
    return psnrs
|
|
|
|
|
def suffix(filename): |
|
"""a.jpg -> jpg""" |
|
pos = filename.rfind(".") |
|
if pos == -1: |
|
return "" |
|
return filename[pos + 1:] |
|
|
|
def dump(wfp, obj):
    wd = osp.split(wfp)[0]
    if wd != "" and not osp.exists(wd):
        mkdir(wd)

    _suffix = suffix(wfp)
    if _suffix == "npy":
        np.save(wfp, obj)
    elif _suffix == "pkl":
        with open(wfp, "wb") as f:
            pickle.dump(obj, f)
    else:
        raise Exception("Unknown type: {}".format(_suffix))
|
|
|
def load(fp):
    suffix_ = suffix(fp)

    if suffix_ == "npy":
        return np.load(fp)
    elif suffix_ == "pkl":
        with open(fp, "rb") as f:
            return pickle.load(f)
    else:
        raise Exception(f"Unknown type: {suffix_}")
|
|
|
|
|
def remove_suffix(filepath): |
|
"""a/b/c.jpg -> a/b/c""" |
|
return osp.join(osp.dirname(filepath), basename(filepath)) |
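

# Usage sketch for the IO helpers above (paths are hypothetical):
#   dump("outputs/template.pkl", {"n_frames": 1, "motion": []})
#   template = load("outputs/template.pkl")
#   remove_suffix("a/b/c.jpg")  # -> "a/b/c"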
|
|
|
|
|
class MotionProcesser(object): |
|
def __init__(self, cfg_path, device_id=0) -> None: |
|
device = f"cuda:{device_id}" |
|
cfg = OmegaConf.load(cfg_path) |
|
print(f"Load cfg from {osp.realpath(cfg_path)} done.") |
|
print(f"=============================== Driven CFG ===============================") |
|
print(OmegaConf.to_yaml(cfg)) |
|
print(f"=============================== ========== ===============================") |
|
models_config = OmegaConf.load(cfg.models_config) |
|
|
|
|
|
self.appearance_feature_extractor = load_model( |
|
cfg.appearance_feature_extractor_path, |
|
models_config, |
|
device, |
|
'appearance_feature_extractor' |
|
) |
|
print(f'1. Load appearance_feature_extractor from {osp.realpath(cfg.appearance_feature_extractor_path)} done.') |
|
|
|
|
|
self.motion_extractor = load_model( |
|
cfg.motion_extractor_path, |
|
models_config, |
|
device, |
|
'motion_extractor' |
|
) |
|
print(f'2. Load motion_extractor from {osp.realpath(cfg.motion_extractor_path)} done.') |
|
|
|
|
|
if cfg.stitching_retargeting_module_path is not None and osp.exists(cfg.stitching_retargeting_module_path): |
|
self.stitching_retargeting_module = load_model( |
|
cfg.stitching_retargeting_module_path, |
|
models_config, |
|
device, |
|
'stitching_retargeting_module' |
|
) |
|
print(f'3. Load stitching_retargeting_module from {osp.realpath(cfg.stitching_retargeting_module_path)} done.') |
|
else: |
|
self.stitching_retargeting_module = None |
|
|
|
|
|
self.warping_module = load_model( |
|
cfg.warping_module_path, |
|
models_config, |
|
device, |
|
'warping_module' |
|
) |
|
print(f"4. Load warping_module from {osp.realpath(cfg.warping_module_path)} done.") |
|
|
|
|
|
self.spade_generator = load_model( |
|
cfg.spade_generator_path, |
|
models_config, |
|
device, |
|
'spade_generator' |
|
) |
|
print(f"Load generator from {osp.realpath(cfg.spade_generator_path)} done.") |
|
|
|
|
|
self.compile = cfg.flag_do_torch_compile |
|
if self.compile: |
|
torch._dynamo.config.suppress_errors = True |
|
self.warping_module = torch.compile(self.warping_module, mode='max-autotune') |
|
self.spade_generator = torch.compile(self.spade_generator, mode='max-autotune') |
|
|
|
|
|
crop_cfg = OmegaConf.load(cfg.crop_cfg) |
|
self.cropper = Cropper(crop_cfg=crop_cfg, image_type="human_face", device_id=device_id) |
|
|
|
self.cfg = cfg |
|
self.models_config = models_config |
|
self.device = device |
|
|
|
|
|
|
|
self.mask_crop = cv2.imread(cfg.mask_crop, cv2.IMREAD_COLOR) |
|
|
|
with open(cfg.lip_array, 'rb') as f: |
|
self.lip_array = pickle.load(f) |
|
|
|
|
|
self.face_parser, self.to_tensor = build_face_parser(weight_path=cfg.face_parser_weight_path, resnet_weight_path=cfg.resnet_weight_path, device_id=device_id) |
|
|
|
def inference_ctx(self): |
|
ctx = torch.autocast(device_type=self.device[:4], dtype=torch.float16, |
|
enabled=self.cfg.flag_use_half_precision) |
|
return ctx |
|
|
|
@torch.no_grad() |
|
def extract_feature_3d(self, x: torch.Tensor) -> torch.Tensor: |
|
""" get the appearance feature of the image by F |
|
x: Bx3xHxW, normalized to 0~1 |
|
""" |
|
with self.inference_ctx(): |
|
feature_3d = self.appearance_feature_extractor(x) |
|
|
|
return feature_3d.float() |
|
|
|
@torch.no_grad() |
|
def get_kp_info(self, x: torch.Tensor, **kwargs) -> dict: |
|
""" get the implicit keypoint information |
|
x: Bx3xHxW, normalized to 0~1 |
|
flag_refine_info: whether to trandform the pose to degrees and the dimention of the reshape |
|
return: A dict contains keys: 'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp' |
|
""" |
|
with self.inference_ctx(): |
|
kp_info = self.motion_extractor(x) |
|
|
|
if self.cfg.flag_use_half_precision: |
|
|
|
for k, v in kp_info.items(): |
|
if isinstance(v, torch.Tensor): |
|
kp_info[k] = v.float() |
|
|
|
return kp_info |
|
|
|
@torch.no_grad() |
|
def refine_kp(self, kp_info): |
|
bs = kp_info['exp'].shape[0] |
|
kp_info['pitch'] = headpose_pred_to_degree(kp_info['pitch'])[:, None] |
|
kp_info['yaw'] = headpose_pred_to_degree(kp_info['yaw'])[:, None] |
|
kp_info['roll'] = headpose_pred_to_degree(kp_info['roll'])[:, None] |
|
kp_info['exp'] = kp_info['exp'].reshape(bs, -1, 3) |
|
if 'kp' in kp_info.keys(): |
|
kp_info['kp'] = kp_info['kp'].reshape(bs, -1, 3) |
|
|
|
return kp_info |
|
|
|
@torch.no_grad() |
|
def transform_keypoint(self, kp_info: dict): |
|
""" |
|
transform the implicit keypoints with the pose, shift, and expression deformation |
|
kp: BxNx3 |
|
""" |
|
kp = kp_info['kp'] |
|
pitch, yaw, roll = kp_info['pitch'], kp_info['yaw'], kp_info['roll'] |
|
|
|
t, exp = kp_info['t'], kp_info['exp'] |
|
scale = kp_info['scale'] |
|
|
|
pitch = headpose_pred_to_degree(pitch) |
|
yaw = headpose_pred_to_degree(yaw) |
|
roll = headpose_pred_to_degree(roll) |
|
|
|
bs = kp.shape[0] |
|
if kp.ndim == 2: |
|
num_kp = kp.shape[1] // 3 |
|
else: |
|
num_kp = kp.shape[1] |
|
|
|
rot_mat = get_rotation_matrix(pitch, yaw, roll) |
|
|
|
|
|
        # x_transformed = scale * (x_c @ R + exp), then shift x/y by t (t_z is dropped)
        kp_transformed = kp.view(bs, num_kp, 3) @ rot_mat + exp.view(bs, num_kp, 3)
        kp_transformed *= scale[..., None]
        kp_transformed[:, :, 0:2] += t[:, None, 0:2]
|
|
|
return kp_transformed |
|
|
|
@torch.no_grad() |
|
def stitching(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor: |
|
""" conduct the stitching |
|
kp_source: Bxnum_kpx3 |
|
kp_driving: Bxnum_kpx3 |
|
""" |
|
|
|
if self.stitching_retargeting_module is not None: |
|
bs, num_kp = kp_source.shape[:2] |
|
kp_driving_new = kp_driving.clone() |
|
|
|
feat_stiching = concat_feat(kp_source, kp_driving_new) |
|
delta = self.stitching_retargeting_module['stitching'](feat_stiching) |
|
|
|
            # delta layout: the first 3*num_kp values are per-keypoint offsets,
            # the following two are a global x/y translation
            delta_exp = delta[..., :3*num_kp].reshape(bs, num_kp, 3)
            delta_tx_ty = delta[..., 3*num_kp:3*num_kp+2].reshape(bs, 1, 2)

            kp_driving_new += delta_exp
            kp_driving_new[..., :2] += delta_tx_ty
|
|
|
return kp_driving_new |
|
|
|
return kp_driving |
|
|
|
@torch.no_grad() |
|
def warp_decode(self, feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> dict[str, torch.Tensor]: |
|
""" get the image after the warping of the implicit keypoints |
|
feature_3d: Bx32x16x64x64, feature volume |
|
kp_source: BxNx3 |
|
kp_driving: BxNx3 |
|
""" |
|
|
|
        with self.inference_ctx():
            if self.compile:
                # mark a cudagraph step boundary for torch.compile(mode='max-autotune')
                torch.compiler.cudagraph_mark_step_begin()
            # warp the source feature volume towards the driving keypoints
            ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
            # decode the warped feature volume into an image
            ret_dct['out'] = self.spade_generator(feature=ret_dct['out'])
|
|
|
|
|
if self.cfg.flag_use_half_precision: |
|
for k, v in ret_dct.items(): |
|
if isinstance(v, torch.Tensor): |
|
ret_dct[k] = v.float() |
|
|
|
return ret_dct |
|
|
|
def parse_output(self, out: torch.Tensor) -> np.ndarray: |
|
""" construct the output as standard |
|
return: 1xHxWx3, uint8 |
|
""" |
|
out = np.transpose(out.cpu().numpy(), [0, 2, 3, 1]) |
|
out = np.clip(out, 0, 1) |
|
out = np.clip(out * 255, 0, 255).astype(np.uint8) |
|
|
|
return out |
|
|
|
@torch.no_grad() |
|
def calc_combined_eye_ratio(self, c_d_eyes_i, source_lmk): |
|
c_s_eyes = calc_eye_close_ratio(source_lmk[None]) |
|
c_s_eyes_tensor = torch.from_numpy(c_s_eyes).float().to(self.device) |
|
c_d_eyes_i_tensor = torch.Tensor([c_d_eyes_i[0][0]]).reshape(1, 1).to(self.device) |
|
|
|
combined_eye_ratio_tensor = torch.cat([c_s_eyes_tensor, c_d_eyes_i_tensor], dim=1) |
|
return combined_eye_ratio_tensor |
|
|
|
@torch.no_grad() |
|
def calc_combined_lip_ratio(self, c_d_lip_i, source_lmk): |
|
c_s_lip = calc_lip_close_ratio(source_lmk[None]) |
|
c_s_lip_tensor = torch.from_numpy(c_s_lip).float().to(self.device) |
|
c_d_lip_i_tensor = torch.Tensor([c_d_lip_i[0]]).to(self.device).reshape(1, 1) |
|
|
|
combined_lip_ratio_tensor = torch.cat([c_s_lip_tensor, c_d_lip_i_tensor], dim=1) |
|
return combined_lip_ratio_tensor |
|
|
|
def calc_ratio(self, lmk_lst): |
|
input_eye_ratio_lst = [] |
|
input_lip_ratio_lst = [] |
|
for lmk in lmk_lst: |
|
|
|
input_eye_ratio_lst.append(calc_eye_close_ratio(lmk[None])) |
|
|
|
input_lip_ratio_lst.append(calc_lip_close_ratio(lmk[None])) |
|
return input_eye_ratio_lst, input_lip_ratio_lst |
|
|
|
@torch.no_grad() |
|
def retarget_lip(self, kp_source: torch.Tensor, lip_close_ratio: torch.Tensor) -> torch.Tensor: |
|
""" |
|
kp_source: BxNx3 |
|
lip_close_ratio: Bx2 |
|
Return: Bx(3*num_kp) |
|
""" |
|
feat_lip = concat_feat(kp_source, lip_close_ratio) |
|
|
|
delta = self.stitching_retargeting_module['lip'](feat_lip) |
|
|
|
return delta.reshape(-1, kp_source.shape[1], 3) |
|
|
|
@torch.no_grad() |
|
def retarget_eye(self, kp_source: torch.Tensor, eye_close_ratio: torch.Tensor) -> torch.Tensor: |
|
""" |
|
kp_source: BxNx3 |
|
eye_close_ratio: Bx3 |
|
Return: Bx(3*num_kp) |
|
""" |
|
feat_eye = concat_feat(kp_source, eye_close_ratio) |
|
|
|
delta = self.stitching_retargeting_module['eye'](feat_eye) |
|
|
|
return delta.reshape(-1, kp_source.shape[1], 3) |
|
|
|
def crop_image(self, img, do_crop=False): |
|
|
|
if do_crop: |
|
crop_info = self.cropper.crop_source_image(img, self.cropper.crop_cfg) |
|
if crop_info is None: |
|
raise Exception("No face detected in the source image!") |
|
lmk = crop_info['lmk_crop'] |
|
img_crop_256x256 = crop_info['img_crop_256x256'] |
|
else: |
|
crop_info = None |
|
lmk = self.cropper.calc_lmk_from_cropped_image(img) |
|
img_crop_256x256 = cv2.resize(img, (256, 256)) |
|
return img_crop_256x256, lmk, crop_info |
|
|
|
def crop_source_video(self, img_lst, do_crop=False): |
|
if do_crop: |
|
ret_s = self.cropper.crop_source_video(img_lst, self.cropper.crop_cfg) |
|
print(f'Source video is cropped, {len(ret_s["frame_crop_lst"])} frames are processed.') |
|
img_crop_256x256_lst, lmk_crop_lst, M_c2o_lst = ret_s['frame_crop_lst'], ret_s['lmk_crop_lst'], ret_s['M_c2o_lst'] |
|
else: |
|
M_c2o_lst = None |
|
lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(img_lst) |
|
img_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in img_lst] |
|
return img_crop_256x256_lst, lmk_crop_lst, M_c2o_lst |
|
|
|
def crop_driving_videos(self, img_lst, do_crop=False): |
|
        if do_crop:
            ret_d = self.cropper.crop_driving_video(img_lst)
            print(f'Driving video is cropped, {len(ret_d["frame_crop_lst"])} frames are processed.')
            img_crop_lst, lmk_crop_lst = ret_d['frame_crop_lst'], ret_d['lmk_crop_lst']
            # resize the cropped frames (not the originals) to 256x256
            img_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in img_crop_lst]
|
else: |
|
lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(img_lst) |
|
img_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in img_lst] |
|
return img_crop_256x256_lst, lmk_crop_lst |
|
|
|
def prepare_source(self, src_img): |
|
""" construct the input as standard |
|
img: HxWx3, uint8, 256x256 |
|
""" |
|
|
|
h, w = src_img.shape[:2] |
|
if h != self.cfg.input_height or w != self.cfg.input_width: |
|
x = cv2.resize(src_img, (self.cfg.input_width, self.cfg.input_height)) |
|
else: |
|
x = src_img.copy() |
|
|
|
if x.ndim == 3: |
|
x = x[np.newaxis].astype(np.float32) / 255. |
|
elif x.ndim == 4: |
|
x = x.astype(np.float32) / 255. |
|
else: |
|
raise ValueError(f'img ndim should be 3 or 4: {x.ndim}') |
|
|
|
x = np.clip(x, 0, 1) |
|
x = torch.from_numpy(x).permute(0, 3, 1, 2) |
|
x = x.to(self.device) |
|
|
|
|
|
I_s = x |
|
f_s = self.extract_feature_3d(I_s) |
|
x_s_info = self.get_kp_info(I_s) |
|
|
|
return f_s, x_s_info |
|
|
|
def process_clips(self, clips): |
|
""" construct the input as standard |
|
clips: NxBxHxWx3, uint8 |
|
""" |
|
|
|
imgs = [] |
|
for img in clips: |
|
h, w = img.shape[:2] |
|
if h != self.cfg.input_height or w != self.cfg.input_width: |
|
img = cv2.resize(img, (self.cfg.input_width, self.cfg.input_height)) |
|
else: |
|
img = img.copy() |
|
imgs.append(img) |
|
|
|
|
|
if isinstance(imgs, list): |
|
_imgs = np.array(imgs)[..., np.newaxis] |
|
elif isinstance(imgs, np.ndarray): |
|
_imgs = imgs |
|
else: |
|
raise ValueError(f'imgs type error: {type(imgs)}') |
|
|
|
y = _imgs.astype(np.float32) / 255. |
|
y = np.clip(y, 0, 1) |
|
y = torch.from_numpy(y).permute(0, 4, 3, 1, 2) |
|
y = y.to(self.device) |
|
|
|
return y |
|
|
|
def prepare_driving_videos(self, vid_frames, feat_type="tensor"): |
|
""" get driving kp infos |
|
vid_frames: image list of HxWx3, uint8 |
|
""" |
|
|
|
total_len = len(vid_frames) |
|
kp_infos = {"pitch": [], "yaw": [], "roll": [], "t": [], "exp": [], "scale": [], "kp": []} |
|
for start_idx in range(0, total_len, self.cfg.batch_size): |
|
frames = vid_frames[start_idx: min(start_idx + self.cfg.batch_size, total_len)] |
|
frames = self.process_clips(frames).squeeze(1) |
|
kp_info = self.get_kp_info(frames) |
|
|
|
for k, v in kp_info.items(): |
|
kp_infos[k].append(v) |
|
|
|
|
|
for k, v in kp_infos.items(): |
|
kp_infos[k] = torch.cat(v, dim=0) |
|
|
|
if feat_type == "np": |
|
for k, v in kp_infos.items(): |
|
kp_infos[k] = v.cpu().numpy() |
|
|
|
return kp_infos |
|
|
|
def get_driving_template(self, kp_infos, smooth=False, dtype="pt_tensor"): |
|
kp_infos = self.refine_kp(kp_infos) |
|
motion_list = [] |
|
n_frames = len(kp_infos["exp"]) |
|
for idx in range(n_frames): |
|
exp = kp_infos["exp"][idx] |
|
scale = kp_infos["scale"][idx] |
|
t = kp_infos["t"][idx] |
|
pitch = kp_infos["pitch"][idx] |
|
yaw = kp_infos["yaw"][idx] |
|
roll = kp_infos["roll"][idx] |
|
|
|
R = get_rotation_matrix(pitch, yaw, roll) |
|
R = R.reshape(1, 3, 3) |
|
|
|
            exp = exp.reshape(1, 21, 3)  # 21 implicit keypoints
|
scale = scale.reshape(1, 1) |
|
t = t.reshape(1, 3) |
|
pitch = pitch.reshape(1, 1) |
|
yaw = yaw.reshape(1, 1) |
|
roll = roll.reshape(1, 1) |
|
|
|
if dtype == "np": |
|
R = R.cpu().numpy().astype(np.float32) |
|
exp = exp.cpu().numpy().astype(np.float32) |
|
scale = scale.cpu().numpy().astype(np.float32) |
|
t = t.cpu().numpy().astype(np.float32) |
|
pitch = pitch.cpu().numpy().astype(np.float32) |
|
yaw = yaw.cpu().numpy().astype(np.float32) |
|
roll = roll.cpu().numpy().astype(np.float32) |
|
|
|
motion_list.append( |
|
{"exp": exp, "scale": scale, "R": R, "t": t, "pitch": pitch, "yaw": yaw, "roll": roll} |
|
) |
|
tgt_motion = {'n_frames': n_frames, 'output_fps': 25, 'motion': motion_list} |
|
|
|
if smooth: |
|
print("Smoothing motion sequence...") |
|
tgt_motion = smooth_(tgt_motion, method="ema") |
|
return tgt_motion |
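
    # Template layout produced above (per-frame shapes, hypothetical 2-frame example):
    #   {'n_frames': 2, 'output_fps': 25,
    #    'motion': [{'exp': (1, 21, 3), 'scale': (1, 1), 'R': (1, 3, 3),
    #                't': (1, 3), 'pitch': (1, 1), 'yaw': (1, 1), 'roll': (1, 1)}, ...]}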
|
|
|
@torch.no_grad() |
|
    def update_delta_new_eyeball_direction(self, eyeball_direction_x, eyeball_direction_y, delta_new, **kwargs):
        """Shift the gaze by nudging eye keypoints 11/15, with a small blink-style compensation on 13/16."""
|
if eyeball_direction_x > 0: |
|
delta_new[0, 11, 0] += eyeball_direction_x * 0.0007 |
|
delta_new[0, 15, 0] += eyeball_direction_x * 0.001 |
|
else: |
|
delta_new[0, 11, 0] += eyeball_direction_x * 0.001 |
|
delta_new[0, 15, 0] += eyeball_direction_x * 0.0007 |
|
|
|
delta_new[0, 11, 1] += eyeball_direction_y * -0.001 |
|
delta_new[0, 15, 1] += eyeball_direction_y * -0.001 |
|
blink = -eyeball_direction_y / 2. |
|
|
|
delta_new[0, 11, 1] += blink * -0.001 |
|
delta_new[0, 13, 1] += blink * 0.0003 |
|
delta_new[0, 15, 1] += blink * -0.001 |
|
delta_new[0, 16, 1] += blink * 0.0003 |
|
|
|
return delta_new |
|
|
|
def driven(self, f_s, x_s_info, s_lmk, c_s_eyes_lst, kp_infos, c_d_eyes_lst=None, c_d_lip_lst=None, smooth=False): |
|
|
|
        """Drive the source (f_s, x_s_info) with driving kp_infos; returns rendered frames as NxHxWx3 uint8."""
        x_d_i_news = []
|
x_s_info = self.refine_kp(x_s_info) |
|
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll']) |
|
x_s = self.transform_keypoint(x_s_info) |
|
x_c_s = x_s_info["kp"] |
|
|
|
|
|
driving_template_dct = self.get_driving_template(kp_infos, smooth) |
|
n_frames = driving_template_dct['n_frames'] |
|
|
|
|
|
flag_normalize_lip = self.cfg.flag_normalize_lip |
|
flag_relative_motion = self.cfg.flag_relative_motion |
|
flag_source_video_eye_retargeting = self.cfg.flag_source_video_eye_retargeting |
|
lip_normalize_threshold = self.cfg.lip_normalize_threshold |
|
source_video_eye_retargeting_threshold = self.cfg.source_video_eye_retargeting_threshold |
|
animation_region = self.cfg.animation_region |
|
driving_option = self.cfg.driving_option |
|
flag_stitching = self.cfg.flag_stitching |
|
flag_eye_retargeting = self.cfg.flag_eye_retargeting |
|
flag_lip_retargeting = self.cfg.flag_lip_retargeting |
|
        driving_multiplier = self.cfg.driving_multiplier
|
|
|
|
|
lip_delta_before_animation, eye_delta_before_animation = None, None |
|
if flag_normalize_lip and flag_relative_motion and s_lmk is not None: |
|
c_d_lip_before_animation = [0.] |
|
combined_lip_ratio_tensor_before_animation = self.calc_combined_lip_ratio(c_d_lip_before_animation, s_lmk) |
|
if combined_lip_ratio_tensor_before_animation[0][0] >= lip_normalize_threshold: |
|
lip_delta_before_animation = self.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation) |
|
|
|
|
|
if flag_source_video_eye_retargeting and s_lmk is not None: |
|
combined_eye_ratio_tensor_frame_zero = c_s_eyes_lst[0] |
|
c_d_eye_before_animation_frame_zero = [[combined_eye_ratio_tensor_frame_zero[0][:2].mean()]] |
|
if c_d_eye_before_animation_frame_zero[0][0] < source_video_eye_retargeting_threshold: |
|
c_d_eye_before_animation_frame_zero = [[0.39]] |
|
combined_eye_ratio_tensor_before_animation = self.calc_combined_eye_ratio(c_d_eye_before_animation_frame_zero, s_lmk) |
|
eye_delta_before_animation = self.retarget_eye(x_s, combined_eye_ratio_tensor_before_animation) |
|
|
|
|
|
I_p_lst = [] |
|
for i in range(n_frames): |
|
x_d_i_info = driving_template_dct['motion'][i] |
|
x_d_i_info = dct2device(x_d_i_info, self.device) |
|
|
|
R_d_i = x_d_i_info['R'] |
|
if i == 0: |
|
R_d_0 = R_d_i |
|
x_d_0_info = x_d_i_info.copy() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
            # hard-coded gaze offset (x=0, y=-5) applied to every driving frame
            x_d_i_info['exp'] = self.update_delta_new_eyeball_direction(0, -5, x_d_i_info['exp'])
|
|
|
|
|
|
|
|
|
delta_new = x_s_info['exp'].clone() |
|
if flag_relative_motion: |
|
|
|
if animation_region == "all" or animation_region == "pose": |
|
R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s |
|
else: |
|
R_new = R_s |
|
|
|
|
|
if animation_region == "all" or animation_region == "exp": |
|
delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']) |
|
elif animation_region == "lip": |
|
for lip_idx in [6, 12, 14, 17, 19, 20]: |
|
delta_new[:, lip_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, lip_idx, :] |
|
elif animation_region == "eyes": |
|
for eyes_idx in [11, 13, 15, 16, 18]: |
|
delta_new[:, eyes_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, eyes_idx, :] |
|
|
|
|
|
if animation_region == "all": |
|
scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale']) |
|
else: |
|
scale_new = x_s_info['scale'] |
|
|
|
|
|
if animation_region == "all" or animation_region == "pose": |
|
t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t']) |
|
else: |
|
t_new = x_s_info['t'] |
|
else: |
|
|
|
if animation_region == "all" or animation_region == "pose": |
|
R_new = R_d_i |
|
else: |
|
R_new = R_s |
|
|
|
|
|
if animation_region == "all" or animation_region == "exp": |
|
                    EXP_IDX = [1, 2, 6, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]  # eye + lip keypoints
                    delta_new[:, EXP_IDX, :] = x_d_i_info['exp'][:, EXP_IDX, :]
|
|
|
|
|
delta_new[:, 3:5, 1] = x_d_i_info['exp'][:, 3:5, 1] |
|
delta_new[:, 5, 2] = x_d_i_info['exp'][:, 5, 2] |
|
delta_new[:, 8, 2] = x_d_i_info['exp'][:, 8, 2] |
|
delta_new[:, 9, 1:] = x_d_i_info['exp'][:, 9, 1:] |
|
elif animation_region == "lip": |
|
for lip_idx in [6, 12, 14, 17, 19, 20]: |
|
delta_new[:, lip_idx, :] = x_d_i_info['exp'][:, lip_idx, :] |
|
elif animation_region == "eyes": |
|
for eyes_idx in [11, 13, 15, 16, 18]: |
|
delta_new[:, eyes_idx, :] = x_d_i_info['exp'][:, eyes_idx, :] |
|
|
|
|
|
scale_new = x_s_info['scale'] |
|
|
|
|
|
if animation_region == "all" or animation_region == "pose": |
|
t_new = x_d_i_info['t'] |
|
else: |
|
t_new = x_s_info['t'] |
|
|
|
            t_new[..., 2].fill_(0)  # zero out the z-translation
|
|
|
x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new |
|
|
|
if flag_relative_motion and driving_option == "expression-friendly": |
|
if i == 0: |
|
x_d_0_new = x_d_i_new |
|
motion_multiplier = calc_motion_multiplier(x_s, x_d_0_new) |
|
x_d_diff = (x_d_i_new - x_d_0_new) * motion_multiplier |
|
x_d_i_new = x_d_diff + x_s |
|
|
|
|
|
if not flag_stitching and not flag_eye_retargeting and not flag_lip_retargeting: |
|
|
|
if flag_normalize_lip and lip_delta_before_animation is not None: |
|
x_d_i_new += lip_delta_before_animation |
|
if flag_source_video_eye_retargeting and eye_delta_before_animation is not None: |
|
x_d_i_new += eye_delta_before_animation |
|
else: |
|
pass |
|
elif flag_stitching and not flag_eye_retargeting and not flag_lip_retargeting: |
|
|
|
if flag_normalize_lip and lip_delta_before_animation is not None: |
|
x_d_i_new = self.stitching(x_s, x_d_i_new) + lip_delta_before_animation |
|
else: |
|
x_d_i_new = self.stitching(x_s, x_d_i_new) |
|
if flag_source_video_eye_retargeting and eye_delta_before_animation is not None: |
|
x_d_i_new += eye_delta_before_animation |
|
else: |
|
eyes_delta, lip_delta = None, None |
|
if flag_eye_retargeting and s_lmk is not None and c_d_eyes_lst is not None: |
|
c_d_eyes_i = c_d_eyes_lst[i] |
|
combined_eye_ratio_tensor = self.calc_combined_eye_ratio(c_d_eyes_i, s_lmk) |
|
eyes_delta = self.retarget_eye(x_s, combined_eye_ratio_tensor) |
|
|
|
if flag_lip_retargeting and s_lmk is not None and c_d_lip_lst is not None: |
|
c_d_lip_i = c_d_lip_lst[i] |
|
combined_lip_ratio_tensor = self.calc_combined_lip_ratio(c_d_lip_i, s_lmk) |
|
|
|
lip_delta = self.retarget_lip(x_s, combined_lip_ratio_tensor) |
|
|
|
if flag_relative_motion: |
|
x_d_i_new = x_s + \ |
|
(eyes_delta if eyes_delta is not None else 0) + \ |
|
(lip_delta if lip_delta is not None else 0) |
|
else: |
|
x_d_i_new = x_d_i_new + \ |
|
(eyes_delta if eyes_delta is not None else 0) + \ |
|
(lip_delta if lip_delta is not None else 0) |
|
|
|
if flag_stitching: |
|
x_d_i_new = self.stitching(x_s, x_d_i_new) |
|
|
|
            x_d_i_new = x_s + (x_d_i_new - x_s) * driving_multiplier
            x_d_i_news.append(x_d_i_new)

        # batch the warp/decode over all frames in chunks of 100
        f_s_s = f_s.expand(n_frames, *f_s.shape[1:])
        x_s_s = x_s.expand(n_frames, *x_s.shape[1:])
        x_d_i_new = torch.cat(x_d_i_news, dim=0)
        for start in range(0, n_frames, 100):
            end = min(start + 100, n_frames)
            with torch.no_grad(), torch.autocast('cuda'):
                out = self.warp_decode(f_s_s[start:end], x_s_s[start:end], x_d_i_new[start:end])
            I_p_lst.append(out['out'])
        I_p = torch.cat(I_p_lst, dim=0)
        I_p_i = self.parse_output(I_p)
        return I_p_i
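
    # `driven` data flow (shapes follow LivePortrait's 21 implicit keypoints):
    #   f_s: 1x32x16x64x64 feature volume, x_s: 1x21x3 transformed source keypoints.
    #   Per frame: x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new,
    #   optionally stitched and blended via driving_multiplier, then warp-decoded.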
|
|
|
def driven_debug(self, f_s, x_s_info, s_lmk, c_s_eyes_lst, driving_template_dct, c_d_eyes_lst=None, c_d_lip_lst=None): |
|
|
|
x_s_info = self.refine_kp(x_s_info) |
|
x_c_s = x_s_info["kp"] |
|
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll']) |
|
x_s = self.transform_keypoint(x_s_info) |
|
|
|
n_frames = driving_template_dct['n_frames'] |
|
|
|
|
|
flag_normalize_lip = self.cfg.flag_normalize_lip |
|
flag_relative_motion = self.cfg.flag_relative_motion |
|
flag_source_video_eye_retargeting = self.cfg.flag_source_video_eye_retargeting |
|
lip_normalize_threshold = self.cfg.lip_normalize_threshold |
|
source_video_eye_retargeting_threshold = self.cfg.source_video_eye_retargeting_threshold |
|
animation_region = self.cfg.animation_region |
|
driving_option = self.cfg.driving_option |
|
flag_stitching = self.cfg.flag_stitching |
|
flag_eye_retargeting = self.cfg.flag_eye_retargeting |
|
flag_lip_retargeting = self.cfg.flag_lip_retargeting |
|
driving_multiplier = self.cfg.driving_multiplier |
|
|
|
|
|
lip_delta_before_animation, eye_delta_before_animation = None, None |
|
if flag_normalize_lip and flag_relative_motion and s_lmk is not None: |
|
c_d_lip_before_animation = [0.] |
|
combined_lip_ratio_tensor_before_animation = self.calc_combined_lip_ratio(c_d_lip_before_animation, s_lmk) |
|
if combined_lip_ratio_tensor_before_animation[0][0] >= lip_normalize_threshold: |
|
lip_delta_before_animation = self.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation) |
|
|
|
|
|
if flag_source_video_eye_retargeting and s_lmk is not None: |
|
combined_eye_ratio_tensor_frame_zero = c_s_eyes_lst[0] |
|
c_d_eye_before_animation_frame_zero = [[combined_eye_ratio_tensor_frame_zero[0][:2].mean()]] |
|
if c_d_eye_before_animation_frame_zero[0][0] < source_video_eye_retargeting_threshold: |
|
c_d_eye_before_animation_frame_zero = [[0.39]] |
|
combined_eye_ratio_tensor_before_animation = self.calc_combined_eye_ratio(c_d_eye_before_animation_frame_zero, s_lmk) |
|
eye_delta_before_animation = self.retarget_eye(x_s, combined_eye_ratio_tensor_before_animation) |
|
|
|
|
|
I_p_lst = [] |
|
for i in range(n_frames): |
|
x_d_i_info = driving_template_dct['motion'][i] |
|
x_d_i_info = dct2device(x_d_i_info, self.device) |
|
|
|
R_d_i = x_d_i_info['R'] if 'R' in x_d_i_info.keys() else x_d_i_info['R_d'] |
|
if i == 0: |
|
R_d_0 = R_d_i |
|
x_d_0_info = x_d_i_info.copy() |
|
|
|
|
|
|
|
|
|
delta_new = x_s_info['exp'].clone() |
|
if flag_relative_motion: |
|
|
|
if animation_region == "all" or animation_region == "pose": |
|
R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s |
|
else: |
|
R_new = R_s |
|
|
|
|
|
if animation_region == "all" or animation_region == "exp": |
|
delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']) |
|
elif animation_region == "lip": |
|
for lip_idx in [6, 12, 14, 17, 19, 20]: |
|
delta_new[:, lip_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, lip_idx, :] |
|
elif animation_region == "eyes": |
|
for eyes_idx in [11, 13, 15, 16, 18]: |
|
delta_new[:, eyes_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, eyes_idx, :] |
|
|
|
|
|
if animation_region == "all": |
|
scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale']) |
|
else: |
|
scale_new = x_s_info['scale'] |
|
|
|
|
|
if animation_region == "all" or animation_region == "pose": |
|
t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t']) |
|
else: |
|
t_new = x_s_info['t'] |
|
else: |
|
|
|
if animation_region == "all" or animation_region == "pose": |
|
R_new = R_d_i |
|
else: |
|
R_new = R_s |
|
|
|
|
|
if animation_region == "all" or animation_region == "exp": |
|
for idx in [1,2,6,11,12,13,14,15,16,17,18,19,20]: |
|
delta_new[:, idx, :] = x_d_i_info['exp'][:, idx, :] |
|
delta_new[:, 3:5, 1] = x_d_i_info['exp'][:, 3:5, 1] |
|
delta_new[:, 5, 2] = x_d_i_info['exp'][:, 5, 2] |
|
delta_new[:, 8, 2] = x_d_i_info['exp'][:, 8, 2] |
|
delta_new[:, 9, 1:] = x_d_i_info['exp'][:, 9, 1:] |
|
elif animation_region == "lip": |
|
for lip_idx in [6, 12, 14, 17, 19, 20]: |
|
delta_new[:, lip_idx, :] = x_d_i_info['exp'][:, lip_idx, :] |
|
elif animation_region == "eyes": |
|
for eyes_idx in [11, 13, 15, 16, 18]: |
|
delta_new[:, eyes_idx, :] = x_d_i_info['exp'][:, eyes_idx, :] |
|
|
|
|
|
scale_new = x_s_info['scale'] |
|
|
|
|
|
if animation_region == "all" or animation_region == "pose": |
|
t_new = x_d_i_info['t'] |
|
else: |
|
t_new = x_s_info['t'] |
|
|
|
t_new[..., 2].fill_(0) |
|
|
|
x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new |
|
|
|
if flag_relative_motion and driving_option == "expression-friendly": |
|
if i == 0: |
|
x_d_0_new = x_d_i_new |
|
motion_multiplier = calc_motion_multiplier(x_s, x_d_0_new) |
|
x_d_diff = (x_d_i_new - x_d_0_new) * motion_multiplier |
|
x_d_i_new = x_d_diff + x_s |
|
|
|
|
|
if not flag_stitching and not flag_eye_retargeting and not flag_lip_retargeting: |
|
|
|
if flag_normalize_lip and lip_delta_before_animation is not None: |
|
x_d_i_new += lip_delta_before_animation |
|
if flag_source_video_eye_retargeting and eye_delta_before_animation is not None: |
|
x_d_i_new += eye_delta_before_animation |
|
else: |
|
pass |
|
elif flag_stitching and not flag_eye_retargeting and not flag_lip_retargeting: |
|
|
|
if flag_normalize_lip and lip_delta_before_animation is not None: |
|
x_d_i_new = self.stitching(x_s, x_d_i_new) + lip_delta_before_animation |
|
else: |
|
x_d_i_new = self.stitching(x_s, x_d_i_new) |
|
if flag_source_video_eye_retargeting and eye_delta_before_animation is not None: |
|
x_d_i_new += eye_delta_before_animation |
|
else: |
|
eyes_delta, lip_delta = None, None |
|
if flag_eye_retargeting and s_lmk is not None and c_d_eyes_lst is not None: |
|
c_d_eyes_i = c_d_eyes_lst[i] |
|
combined_eye_ratio_tensor = self.calc_combined_eye_ratio(c_d_eyes_i, s_lmk) |
|
eyes_delta = self.retarget_eye(x_s, combined_eye_ratio_tensor) |
|
|
|
if flag_lip_retargeting and s_lmk is not None and c_d_lip_lst is not None: |
|
c_d_lip_i = c_d_lip_lst[i] |
|
combined_lip_ratio_tensor = self.calc_combined_lip_ratio(c_d_lip_i, s_lmk) |
|
|
|
lip_delta = self.retarget_lip(x_s, combined_lip_ratio_tensor) |
|
|
|
if flag_relative_motion: |
|
x_d_i_new = x_s + \ |
|
(eyes_delta if eyes_delta is not None else 0) + \ |
|
(lip_delta if lip_delta is not None else 0) |
|
else: |
|
x_d_i_new = x_d_i_new + \ |
|
(eyes_delta if eyes_delta is not None else 0) + \ |
|
(lip_delta if lip_delta is not None else 0) |
|
|
|
if flag_stitching: |
|
x_d_i_new = self.stitching(x_s, x_d_i_new) |
|
|
|
x_d_i_new = x_s + (x_d_i_new - x_s) * driving_multiplier |
|
out = self.warp_decode(f_s, x_s, x_d_i_new) |
|
I_p_i = self.parse_output(out['out'])[0] |
|
I_p_lst.append(I_p_i) |
|
|
|
return I_p_lst |
|
|
|
def read_image(self, image_path: str) -> list: |
|
img_rgb = load_image_rgb(image_path) |
|
img_rgb = resize_to_limit(img_rgb, self.cfg.source_max_dim, self.cfg.source_division) |
|
source_rgb_list = [img_rgb] |
|
print(f"Load image from {osp.realpath(image_path)} done.") |
|
return source_rgb_list |
|
|
|
    def read_video(self, video_path: str, interval=None) -> list:
        vr = VideoReader(video_path)
        if interval is not None:
            video_frames = vr.get_batch(np.arange(0, len(vr), interval)).asnumpy()
        else:
            # probe only the first, middle, and last frames
            video_frames = [vr[0].asnumpy(), vr[len(vr) // 2].asnumpy(), vr[-1].asnumpy()]
        vr.seek(0)

        driving_rgb_list = []
        for video_frame in video_frames:
            driving_rgb_list.append(video_frame)

        return driving_rgb_list
|
|
|
def prepare_videos(self, imgs) -> torch.Tensor: |
|
""" construct the input as standard |
|
imgs: NxBxHxWx3, uint8 |
|
""" |
|
if isinstance(imgs, list): |
|
_imgs = np.array(imgs)[..., np.newaxis] |
|
elif isinstance(imgs, np.ndarray): |
|
_imgs = imgs |
|
else: |
|
raise ValueError(f'imgs type error: {type(imgs)}') |
|
|
|
y = _imgs.astype(np.float32) / 255. |
|
y = np.clip(y, 0, 1) |
|
y = torch.from_numpy(y).permute(0, 4, 3, 1, 2) |
|
y = y.to(self.device) |
|
|
|
return y |
|
|
|
def make_motion_template(self, I_lst, c_eyes_lst, c_lip_lst, **kwargs): |
|
n_frames = I_lst.shape[0] |
|
template_dct = { |
|
'n_frames': n_frames, |
|
'output_fps': kwargs.get('output_fps', 25), |
|
'motion': [], |
|
'c_eyes_lst': [], |
|
'c_lip_lst': [], |
|
} |
|
|
|
for i in track(range(n_frames), description='Making motion templates...', total=n_frames): |
|
|
|
I_i = I_lst[i] |
|
x_i_info = self.refine_kp(self.get_kp_info(I_i)) |
|
x_s = self.transform_keypoint(x_i_info) |
|
R_i = get_rotation_matrix(x_i_info['pitch'], x_i_info['yaw'], x_i_info['roll']) |
|
|
|
item_dct = { |
|
'scale': x_i_info['scale'].cpu().numpy().astype(np.float32), |
|
'R': R_i.cpu().numpy().astype(np.float32), |
|
'exp': x_i_info['exp'].cpu().numpy().astype(np.float32), |
|
't': x_i_info['t'].cpu().numpy().astype(np.float32), |
|
'kp': x_i_info['kp'].cpu().numpy().astype(np.float32), |
|
'x_s': x_s.cpu().numpy().astype(np.float32), |
|
} |
|
|
|
template_dct['motion'].append(item_dct) |
|
|
|
c_eyes = c_eyes_lst[i].astype(np.float32) |
|
template_dct['c_eyes_lst'].append(c_eyes) |
|
|
|
c_lip = c_lip_lst[i].astype(np.float32) |
|
template_dct['c_lip_lst'].append(c_lip) |
|
|
|
return template_dct |
|
|
|
def load_template(self, wfp_template): |
|
print(f"Load from template: {wfp_template}, NOT the video, so the cropping video and audio are both NULL.") |
|
driving_template_dct = load(wfp_template) |
|
c_d_eyes_lst = driving_template_dct['c_eyes_lst'] if 'c_eyes_lst' in driving_template_dct.keys() else driving_template_dct['c_d_eyes_lst'] |
|
c_d_lip_lst = driving_template_dct['c_lip_lst'] if 'c_lip_lst' in driving_template_dct.keys() else driving_template_dct['c_d_lip_lst'] |
|
        driving_n_frames = driving_template_dct['n_frames']
        flag_is_driving_video = driving_n_frames > 1  # a template may hold a single frame
|
|
|
|
|
output_fps = driving_template_dct.get('output_fps', 25) |
|
print(f'The FPS of template: {output_fps}') |
|
return driving_template_dct |
|
|
|
def reconstruction(self, src_img, dst_imgs, video_path="template"): |
|
|
|
src_img_256x256, s_lmk, _ = self.crop_image(src_img, do_crop=False) |
|
|
|
c_s_eyes_lst = None |
|
f_s, x_s_info = self.prepare_source(src_img_256x256) |
|
|
|
|
|
dst_imgs_256x256, d_lmk_lst = self.crop_driving_videos(dst_imgs, do_crop=False) |
|
c_d_eyes_lst, c_d_lip_lst = self.calc_ratio(d_lmk_lst) |
|
kp_infos = self.prepare_driving_videos(dst_imgs_256x256) |
|
|
|
|
|
recs = self.driven(f_s, x_s_info, s_lmk, c_s_eyes_lst, kp_infos, c_d_eyes_lst, c_d_lip_lst) |
|
return recs |
|
|
|
def save_results(self, results, save_path, audio_path=None): |
|
save_dir = osp.dirname(save_path) |
|
save_name = osp.basename(save_path) |
|
final_video = osp.join(save_dir, f'final_{save_name}') |
|
|
|
images2video(results, wfp=save_path, fps=self.cfg.output_fps) |
|
|
|
if audio_path is not None: |
|
add_audio_to_video(save_path, audio_path, final_video) |
|
os.remove(save_path) |
|
|
|
def rec_score(self, video_path: str, interval=None, save_path=None): |
|
video_frames = self.read_video(video_path, interval=interval) |
|
|
|
recs = self.reconstruction(video_frames[0], video_frames[1:], video_path) |
|
if save_path is not None: |
|
self.save_results(recs, save_path) |
|
|
|
psnrs = psnr(video_frames[1:], recs) |
|
psnrs_np = np.array(psnrs) |
|
psnr_mean, psnr_std = np.mean(psnrs_np), np.std(psnrs_np) |
|
rec_score = {"mean": psnr_mean, "std": psnr_std} |
|
return rec_score |
|
|
|
@torch.no_grad() |
|
def paste_back_by_face_mask(self, result, crop_info, src_img, crop_src_image, use_laplacian=False): |
|
""" |
|
paste back the result to the original image with face mask |
|
""" |
|
|
|
crop_src_tensor = self.to_tensor(crop_src_image).unsqueeze(0).to(self.device) |
|
src_msks = get_face_mask(self.face_parser, crop_src_tensor) |
|
result_tensor = self.to_tensor(result).unsqueeze(0).to(self.device) |
|
result_msks = get_face_mask(self.face_parser, result_tensor) |
|
|
|
masks = [] |
|
for src_msk, result_msk in zip(src_msks, result_msks): |
|
mask = np.clip(src_msk + result_msk, 0, 1) |
|
masks.append(mask) |
|
result = paste_back_with_face_mask(result, crop_info, src_img, masks[0], use_laplacian=use_laplacian) |
|
return result |
|
|
|
def driven_by_audio(self, src_img, kp_infos, save_path, audio_path=None, smooth=False): |
|
|
|
|
|
src_img_256x256, s_lmk, crop_info = self.crop_image(src_img, do_crop=True) |
|
|
|
c_s_eyes_lst = None |
|
f_s, x_s_info = self.prepare_source(src_img_256x256) |
|
|
|
mask_ori_float = prepare_paste_back(self.mask_crop, crop_info['M_c2o'], dsize=(src_img.shape[1], src_img.shape[0])) |
|
|
|
|
|
results = self.driven(f_s, x_s_info, s_lmk, c_s_eyes_lst, kp_infos, smooth=smooth) |
|
        frames = results.shape[0]
        results = [paste_back(results[i], crop_info['M_c2o'], src_img, mask_ori_float) for i in range(frames)]
|
self.save_results(results, save_path, audio_path) |
|
def mix_kp_infos(self, emo_kp_infos, lip_kp_infos, smooth=False, dtype="pt_tensor"): |
|
driving_emo_template_dct = self.get_driving_template(emo_kp_infos, smooth=False, dtype=dtype) |
|
if lip_kp_infos is not None: |
|
driving_lip_template_dct = self.get_driving_template(lip_kp_infos, smooth=smooth, dtype=dtype) |
|
driving_template_dct = {**driving_emo_template_dct} |
|
n_frames = min(driving_emo_template_dct['n_frames'], driving_lip_template_dct['n_frames']) |
|
driving_template_dct['n_frames'] = n_frames |
|
            for i in range(n_frames):
                emo_motion = driving_emo_template_dct['motion'][i]['exp']
                lip_motion = driving_lip_template_dct['motion'][i]['exp']
                # transplant the lip keypoints from the lip template into the emotion template
                for lip_idx in [6, 12, 14, 17, 19, 20]:
                    emo_motion[:, lip_idx, :] = lip_motion[:, lip_idx, :]
                driving_template_dct['motion'][i]['exp'] = emo_motion
|
else: |
|
driving_template_dct = driving_emo_template_dct |
|
|
|
return driving_template_dct |
|
|
|
def driven_by_mix(self, src_img, driving_video_path, kp_infos, save_path, audio_path=None, smooth=False): |
|
|
|
src_img_256x256, s_lmk, crop_info = self.crop_image(src_img, do_crop=True) |
|
c_s_eyes_lst, c_s_lip_lst = self.calc_ratio([s_lmk]) |
|
f_s, x_s_info = self.prepare_source(src_img_256x256) |
|
mask_ori_float = prepare_paste_back(self.mask_crop, crop_info['M_c2o'], dsize=(src_img.shape[1], src_img.shape[0])) |
|
|
|
driving_imgs = self.read_video(driving_video_path, interval=1) |
|
dst_imgs_256x256, d_lmk_lst = self.crop_driving_videos(driving_imgs, do_crop=True) |
|
c_d_eyes_lst, c_d_lip_lst = self.calc_ratio(d_lmk_lst) |
|
emo_kp_infos = self.prepare_driving_videos(dst_imgs_256x256) |
|
|
|
driving_template_dct = self.mix_kp_infos(emo_kp_infos, kp_infos, smooth=smooth) |
|
|
|
results = self.driven_debug(f_s, x_s_info, s_lmk, c_s_eyes_lst, driving_template_dct, c_d_eyes_lst=c_d_eyes_lst, c_d_lip_lst=c_d_lip_lst) |
|
results = [paste_back(result, crop_info['M_c2o'], src_img, mask_ori_float) for result in results] |
|
        print(f"Rendered {len(results)} frames.")
|
self.save_results(results, save_path, audio_path) |
|
|
|
def drive_video_by_mix(self, video_path, driving_video_path, kp_infos, save_path, audio_path): |
|
|
|
driving_imgs = self.read_video(driving_video_path, interval=1) |
|
dst_imgs_256x256, d_lmk_lst = self.crop_driving_videos(driving_imgs, do_crop=True) |
|
emo_kp_infos = self.prepare_driving_videos(dst_imgs_256x256) |
|
|
|
|
|
driving_template_dct = self.mix_kp_infos(emo_kp_infos, kp_infos, smooth=True, dtype="np") |
|
|
|
self.video_lip_retargeting( |
|
video_path, None, |
|
save_path, audio_path, |
|
            driving_template_dct=driving_template_dct, retargeting_region="exp"
|
) |
|
|
|
def load_source_video(self, video_info, n_frames=-1): |
|
reader = imageio.get_reader(video_info, "ffmpeg") |
|
|
|
ret = [] |
|
for idx, frame_rgb in enumerate(reader): |
|
if n_frames > 0 and idx >= n_frames: |
|
break |
|
ret.append(frame_rgb) |
|
|
|
reader.close() |
|
|
|
return ret |
|
|
|
    def video_lip_retargeting(self, video_path, kp_infos, save_path, audio_path, c_d_eyes_lst=None, c_d_lip_lst=None, smooth=False, driving_template_dct=None, retargeting_region="exp"):
|
|
|
source_rgb_lst = load_video(video_path) |
|
source_rgb_lst = [resize_to_limit(img, self.cfg.source_max_dim, self.cfg.source_division) for img in source_rgb_lst] |
|
img_crop_256x256_lst, source_lmk_crop_lst, source_M_c2o_lst = self.crop_source_video(source_rgb_lst, do_crop=True) |
|
c_s_eyes_lst, c_s_lip_lst = self.calc_ratio(source_lmk_crop_lst) |
|
I_s_lst = self.prepare_videos(img_crop_256x256_lst) |
|
source_template_dct = self.make_motion_template(I_s_lst, c_s_eyes_lst, c_s_lip_lst, output_fps=25) |
|
|
|
if driving_template_dct is None: |
|
driving_template_dct = self.get_driving_template(kp_infos, smooth=smooth, dtype="np") |
|
|
|
n_frames = min(source_template_dct['n_frames'], driving_template_dct['n_frames']) |
|
|
|
I_p_lst = [] |
|
I_p_pstbk_lst = [] |
|
R_d_0, x_d_0_info = None, None |
|
flag_normalize_lip = self.cfg.flag_normalize_lip |
|
        flag_relative_motion = True  # forced on for video lip retargeting
|
flag_source_video_eye_retargeting = self.cfg.flag_source_video_eye_retargeting |
|
lip_normalize_threshold = self.cfg.lip_normalize_threshold |
|
source_video_eye_retargeting_threshold = self.cfg.source_video_eye_retargeting_threshold |
|
        animation_region = 'lip'  # note: hard-coded; the retargeting_region argument is currently unused
|
driving_option = self.cfg.driving_option |
|
flag_stitching = self.cfg.flag_stitching |
|
flag_eye_retargeting = self.cfg.flag_eye_retargeting |
|
flag_lip_retargeting = self.cfg.flag_lip_retargeting |
|
driving_multiplier = self.cfg.driving_multiplier |
|
driving_smooth_observation_variance = self.cfg.driving_smooth_observation_variance |
|
|
|
key_r = 'R' if 'R' in driving_template_dct['motion'][0].keys() else 'R_d' |
|
if flag_relative_motion: |
|
x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + driving_template_dct['motion'][i]['exp'] - driving_template_dct['motion'][0]['exp'] for i in range(n_frames)] |
|
for i in range(n_frames): |
|
for idx in [6, 12, 14, 17, 19, 20]: |
|
|
|
x_d_exp_lst[i][:, idx, :] = driving_template_dct['motion'][i]['exp'][:, idx, :] |
|
x_d_exp_lst_smooth = ksmooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, self.device, driving_smooth_observation_variance) |
|
|
|
if animation_region == "all" or animation_region == "pose" or "all" in animation_region: |
|
x_d_r_lst = [(np.dot(driving_template_dct['motion'][i][key_r], driving_template_dct['motion'][0][key_r].transpose(0, 2, 1))) @ source_template_dct['motion'][i]['R'] for i in range(n_frames)] |
|
x_d_r_lst_smooth = ksmooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, self.device, driving_smooth_observation_variance) |
|
else: |
|
x_d_exp_lst = [driving_template_dct['motion'][i]['exp'] for i in range(n_frames)] |
|
x_d_exp_lst_smooth = ksmooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, self.device, driving_smooth_observation_variance) |
|
|
|
if animation_region == "all" or animation_region == "pose" or "all" in animation_region: |
|
x_d_r_lst = [driving_template_dct['motion'][i][key_r] for i in range(n_frames)] |
|
x_d_r_lst_smooth = ksmooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, self.device, driving_smooth_observation_variance) |
|
|
|
|
|
for i in track(range(n_frames), description='🚀Retargeting...', total=n_frames): |
|
x_s_info = source_template_dct['motion'][i] |
|
x_s_info = dct2device(x_s_info, self.device) |
|
|
|
source_lmk = source_lmk_crop_lst[i] |
|
img_crop_256x256 = img_crop_256x256_lst[i] |
|
I_s = I_s_lst[i] |
|
f_s = self.extract_feature_3d(I_s) |
|
|
|
            x_c_s = x_s_info['kp']
            R_s = x_s_info['R']
            x_s = x_s_info['x_s']
|
|
|
|
|
lip_delta_before_animation = None |
|
if flag_normalize_lip and flag_relative_motion and source_lmk is not None: |
|
c_d_lip_before_animation = [0.] |
|
combined_lip_ratio_tensor_before_animation = self.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk) |
|
                if combined_lip_ratio_tensor_before_animation[0][0] >= lip_normalize_threshold:
                    lip_delta_before_animation = self.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
|
|
|
|
|
eye_delta_before_animation = None |
|
if flag_source_video_eye_retargeting and source_lmk is not None: |
|
if i == 0: |
|
combined_eye_ratio_tensor_frame_zero = c_s_eyes_lst[0] |
|
c_d_eye_before_animation_frame_zero = [[combined_eye_ratio_tensor_frame_zero[0][:2].mean()]] |
|
if c_d_eye_before_animation_frame_zero[0][0] < source_video_eye_retargeting_threshold: |
|
c_d_eye_before_animation_frame_zero = [[0.39]] |
|
combined_eye_ratio_tensor_before_animation = self.calc_combined_eye_ratio(c_d_eye_before_animation_frame_zero, source_lmk) |
|
eye_delta_before_animation = self.retarget_eye(x_s, combined_eye_ratio_tensor_before_animation) |
|
|
|
if flag_stitching: |
|
mask_ori_float = prepare_paste_back(self.mask_crop, source_M_c2o_lst[i], dsize=(source_rgb_lst[i].shape[1], source_rgb_lst[i].shape[0])) |
|
|
|
x_d_i_info = driving_template_dct['motion'][i] |
|
x_d_i_info = dct2device(x_d_i_info, self.device) |
|
R_d_i = x_d_i_info['R'] if 'R' in x_d_i_info.keys() else x_d_i_info['R_d'] |
|
|
|
if i == 0: |
|
R_d_0 = R_d_i |
|
x_d_0_info = x_d_i_info.copy() |
|
|
|
delta_new = x_s_info['exp'].clone() |
|
if flag_relative_motion: |
|
if animation_region == "all" or animation_region == "pose" or "all" in animation_region: |
|
R_new = x_d_r_lst_smooth[i] |
|
else: |
|
R_new = R_s |
|
if animation_region == "all" or animation_region == "exp": |
|
for idx in [1,2,6,11,12,13,14,15,16,17,18,19,20]: |
|
delta_new[:, idx, :] = x_d_exp_lst_smooth[i][idx, :] |
|
delta_new[:, 3:5, 1] = x_d_exp_lst_smooth[i][3:5, 1] |
|
delta_new[:, 5, 2] = x_d_exp_lst_smooth[i][5, 2] |
|
delta_new[:, 8, 2] = x_d_exp_lst_smooth[i][8, 2] |
|
delta_new[:, 9, 1:] = x_d_exp_lst_smooth[i][9, 1:] |
|
elif animation_region == "all_wo_lip" or animation_region == "exp_wo_lip": |
|
for idx in [1, 2, 11, 13, 15, 16, 18]: |
|
delta_new[:, idx, :] = x_d_exp_lst_smooth[i][idx, :] |
|
delta_new[:, 3:5, 1] = x_d_exp_lst_smooth[i][3:5, 1] |
|
delta_new[:, 5, 2] = x_d_exp_lst_smooth[i][5, 2] |
|
delta_new[:, 8, 2] = x_d_exp_lst_smooth[i][8, 2] |
|
delta_new[:, 9, 1:] = x_d_exp_lst_smooth[i][9, 1:] |
|
elif animation_region == "lip": |
|
for lip_idx in [6, 12, 14, 17, 19, 20]: |
|
delta_new[:, lip_idx, :] = x_d_exp_lst_smooth[i][lip_idx, :] |
|
elif animation_region == "eyes": |
|
for eyes_idx in [11, 13, 15, 16, 18]: |
|
delta_new[:, eyes_idx, :] = x_d_exp_lst_smooth[i][eyes_idx, :] |
|
|
|
scale_new = x_s_info['scale'] |
|
t_new = x_s_info['t'] |
|
else: |
|
if animation_region == "all" or animation_region == "pose" or "all" in animation_region: |
|
R_new = x_d_r_lst_smooth[i] |
|
else: |
|
R_new = R_s |
|
if animation_region == "all" or animation_region == "exp": |
|
for idx in [1,2,6,11,12,13,14,15,16,17,18,19,20]: |
|
delta_new[:, idx, :] = x_d_exp_lst_smooth[i][idx, :] |
|
delta_new[:, 3:5, 1] = x_d_exp_lst_smooth[i][3:5, 1] |
|
delta_new[:, 5, 2] = x_d_exp_lst_smooth[i][5, 2] |
|
delta_new[:, 8, 2] = x_d_exp_lst_smooth[i][8, 2] |
|
delta_new[:, 9, 1:] = x_d_exp_lst_smooth[i][9, 1:] |
|
elif animation_region == "all_wo_lip" or animation_region == "exp_wo_lip": |
|
for idx in [1, 2, 11, 13, 15, 16, 18]: |
|
delta_new[:, idx, :] = x_d_exp_lst_smooth[i][idx, :] |
|
delta_new[:, 3:5, 1] = x_d_exp_lst_smooth[i][3:5, 1] |
|
delta_new[:, 5, 2] = x_d_exp_lst_smooth[i][5, 2] |
|
delta_new[:, 8, 2] = x_d_exp_lst_smooth[i][8, 2] |
|
delta_new[:, 9, 1:] = x_d_exp_lst_smooth[i][9, 1:] |
|
elif animation_region == "lip": |
|
for lip_idx in [6, 12, 14, 17, 19, 20]: |
|
delta_new[:, lip_idx, :] = x_d_exp_lst_smooth[i][lip_idx, :] |
|
elif animation_region == "eyes": |
|
for eyes_idx in [11, 13, 15, 16, 18]: |
|
delta_new[:, eyes_idx, :] = x_d_exp_lst_smooth[i][eyes_idx, :] |
|
scale_new = x_s_info['scale'] |
|
if animation_region == "all" or animation_region == "pose" or "all" in animation_region: |
|
t_new = x_d_i_info['t'] |
|
else: |
|
t_new = x_s_info['t'] |
|
|
|
t_new[..., 2].fill_(0) |
|
x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new |
|
|
|
|
|
if not flag_stitching and not flag_eye_retargeting and not flag_lip_retargeting: |
|
|
|
if flag_normalize_lip and lip_delta_before_animation is not None: |
|
x_d_i_new += lip_delta_before_animation |
|
if flag_source_video_eye_retargeting and eye_delta_before_animation is not None: |
|
x_d_i_new += eye_delta_before_animation |
|
else: |
|
pass |
|
elif flag_stitching and not flag_eye_retargeting and not flag_lip_retargeting: |
|
|
|
if flag_normalize_lip and lip_delta_before_animation is not None: |
|
x_d_i_new = self.stitching(x_s, x_d_i_new) + lip_delta_before_animation |
|
else: |
|
x_d_i_new = self.stitching(x_s, x_d_i_new) |
|
if flag_source_video_eye_retargeting and eye_delta_before_animation is not None: |
|
x_d_i_new += eye_delta_before_animation |
|
else: |
|
eyes_delta, lip_delta = None, None |
|
if flag_eye_retargeting and source_lmk is not None and c_d_eyes_lst is not None: |
|
c_d_eyes_i = c_d_eyes_lst[i] |
|
combined_eye_ratio_tensor = self.calc_combined_eye_ratio(c_d_eyes_i, source_lmk) |
|
|
|
eyes_delta = self.retarget_eye(x_s, combined_eye_ratio_tensor) |
|
if flag_lip_retargeting and source_lmk is not None and c_d_lip_lst is not None: |
|
c_d_lip_i = c_d_lip_lst[i] |
|
combined_lip_ratio_tensor = self.calc_combined_lip_ratio(c_d_lip_i, source_lmk) |
|
|
|
lip_delta = self.retarget_lip(x_s, combined_lip_ratio_tensor) |
|
|
|
if flag_relative_motion: |
|
x_d_i_new = x_s + \ |
|
(eyes_delta if eyes_delta is not None else 0) + \ |
|
(lip_delta if lip_delta is not None else 0) |
|
else: |
|
x_d_i_new = x_d_i_new + \ |
|
(eyes_delta if eyes_delta is not None else 0) + \ |
|
(lip_delta if lip_delta is not None else 0) |
|
|
|
if flag_stitching: |
|
x_d_i_new = self.stitching(x_s, x_d_i_new) |
|
|
|
x_d_i_new = x_s + (x_d_i_new - x_s) * driving_multiplier |
|
out = self.warp_decode(f_s, x_s, x_d_i_new) |
|
I_p_i = self.parse_output(out['out'])[0] |
|
I_p_lst.append(I_p_i) |
|
|
|
if flag_stitching: |
|
|
|
|
|
I_p_pstbk = paste_back(I_p_i, source_M_c2o_lst[i], source_rgb_lst[i], mask_ori_float, use_laplacian=True) |
|
I_p_pstbk_lst.append(I_p_pstbk) |
|
|
|
if len(I_p_pstbk_lst) > 0: |
|
self.save_results(I_p_pstbk_lst, save_path, audio_path) |
|
else: |
|
self.save_results(I_p_lst, save_path, audio_path) |
|
|
|
@torch.no_grad() |
|
def video_reconstruction_test(self, video_tensor, xs, save_path): |
|
|
|
|
|
result_lst = [] |
|
|
|
video_tensor = video_tensor[0:1] * 0.5 + 0.5 |
|
video_tensor = torch.clip(video_tensor, 0, 1) |
|
video_tensor = video_tensor.permute(1, 0, 2, 3, 4) |
|
video = video_tensor.to(self.device) |
|
xs = xs[0:1].permute(1, 0, 2) |
|
xs = xs.reshape(-1, 1, 21, 3) |
|
xs = xs.to(self.device) |
|
|
|
x_s_0 = xs[0] |
|
I_s_0 = torch.nn.functional.interpolate(video[0], size=(256, 256), mode='bilinear') |
|
f_s_0 = self.extract_feature_3d(I_s_0) |
|
|
|
for i in range(video_tensor.shape[0]): |
|
|
|
|
|
x_s = self.stitching(x_s_0, xs[i]) |
|
out = self.warp_decode(f_s_0, x_s_0, x_s) |
|
I_p_i = self.parse_output(out['out'])[0] |
|
result_lst.append(I_p_i) |
|
|
|
|
|
|
|
|
|
self.save_results(result_lst, save_path, audio_path=None) |
|
|
|
|
|
@torch.no_grad() |
|
def self_driven(self, image_tensor, xs, save_path, length): |
|
result_lst = [] |
|
image_tensor = image_tensor[0:1] * 0.5 + 0.5 |
|
image_tensor = torch.clip(image_tensor, 0, 1) |
|
image = image_tensor.to(self.device) |
|
I_s_0 = torch.nn.functional.interpolate(image, size=(256, 256), mode='bilinear') |
|
|
|
xs = xs[0:1].permute(1, 0, 2) |
|
xs = xs.reshape(-1, 1, 21, 3) |
|
xs = xs.to(self.device) |
|
|
|
x_s_0 = xs[0] |
|
f_s_0 = self.extract_feature_3d(I_s_0) |
|
|
|
for i in range(xs.shape[0]): |
|
x_d = self.stitching(x_s_0, xs[i]) |
|
out = self.warp_decode(f_s_0, x_s_0, x_d) |
|
I_p_i = self.parse_output(out['out'])[0] |
|
result_lst.append(I_p_i) |
|
|
|
assert len(result_lst) == length, f"length of result_lst is {len(result_lst)}, but length is {length}" |
|
|
|
self.save_results(result_lst, save_path, audio_path=None) |
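

# ---------------------------------------------------------------------------
# Usage sketch (a minimal, hypothetical example: the config and media paths
# below are placeholders, not files shipped with the repo).
if __name__ == "__main__":
    processor = MotionProcesser(cfg_path="configs/driven.yaml", device_id=0)

    # Self-reconstruction sanity check: drive the first frame of a video with
    # every 5th frame and report the PSNR of the reconstruction.
    score = processor.rec_score("assets/examples/driving.mp4", interval=5,
                                save_path="outputs/rec.mp4")
    print(f"PSNR mean={score['mean']:.2f}, std={score['std']:.2f}")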
|
|
|
|
|
|