import math
import os.path as osp
import sys

import torch
import torch.nn as nn
from omegaconf import OmegaConf
from rich.progress import track

from .talking_head_dit import TalkingHeadDiT_models
from ..schedulers.scheduling_ddim import DDIMScheduler
from ..schedulers.flow_matching import ModelSamplingDiscreteFlow

sys.path.append(osp.dirname(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__))))))

# Map the `noise_scheduler.type` config value to the corresponding scheduler class.
scheduler_map = {
    "ddim": DDIMScheduler,
    "flow_matching": ModelSamplingDiscreteFlow
}

# Lip-related dimension indices in the motion representation.
lip_dims = [18, 19, 20, 36, 37, 38, 42, 43, 44, 51, 52, 53, 57, 58, 59, 60, 61, 62]


class MotionDiffusion(nn.Module):
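    """Audio-conditioned diffusion model for generating talking-head motion sequences.

    Wraps a TalkingHeadDiT backbone together with a noise scheduler (DDIM or
    discrete-flow matching) and provides classifier-free-guided sampling of
    motion clips from audio features, reference keypoints, and an optional
    emotion label.
    """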

    def __init__(self, config, device="cuda", dtype=torch.float32, smo_wsize=3, loss_type="l2"):
        super().__init__()

        self.config = config
        self.smo_wsize = smo_wsize
        print("================================== Init Motion GeneratorV2 ==================================")
        print(OmegaConf.to_yaml(self.config))

        motion_gen_config = config.motion_generator
        motion_gen_params = motion_gen_config.params

        audio_proj_config = config.audio_projector
        audio_proj_params = audio_proj_config.params

        scheduler_config = config.noise_scheduler
        scheduler_params = scheduler_config.params

        self.device = device

        # Motion denoiser backbone. The input dimension is doubled because the
        # previous-motion condition is concatenated with the noisy motion along
        # the feature axis (see `sample_subclip`).
        self.talking_head_dit = TalkingHeadDiT_models[config.model_name](
            input_dim = motion_gen_params.input_dim * 2,
            output_dim = motion_gen_params.output_dim,
            seq_len = motion_gen_params.n_pred_frames,
            audio_unit_len = audio_proj_params.sequence_length,
            audio_blocks = audio_proj_params.blocks,
            audio_dim = audio_proj_params.audio_feat_dim,
            audio_tokens = audio_proj_params.context_tokens,
            audio_embedder_type = audio_proj_params.audio_embedder_type,
            audio_cond_dim = audio_proj_params.audio_cond_dim,
            norm_type = motion_gen_params.norm_type,
            qk_norm = motion_gen_params.qk_norm,
            exp_dim = motion_gen_params.exp_dim
        )
        self.input_dim = motion_gen_params.input_dim
        self.exp_dim = motion_gen_params.exp_dim

        self.audio_feat_dim = audio_proj_params.audio_feat_dim
        self.audio_seq_len = audio_proj_params.sequence_length
        self.audio_blocks = audio_proj_params.blocks
        # Each frame is conditioned on a window of `sequence_length` audio frames
        # centred on it; `indices` holds the relative offsets within that window.
        self.audio_margin = (audio_proj_params.sequence_length - 1) // 2
        self.indices = (
            torch.arange(2 * self.audio_margin + 1) - self.audio_margin
        ).unsqueeze(0)

        self.n_prev_frames = motion_gen_params.n_prev_frames
        self.n_pred_frames = motion_gen_params.n_pred_frames

        # Noise scheduler used for sampling, selected by the config type.
        self.scheduler = scheduler_map[scheduler_config.type](
            num_train_timesteps = scheduler_params.num_train_timesteps,
            beta_start = scheduler_params.beta_start,
            beta_end = scheduler_params.beta_end,
            beta_schedule = scheduler_params.mode,
            prediction_type = scheduler_config.sample_mode,
            time_shifting = scheduler_params.time_shifting,
        )
        self.scheduler_type = scheduler_config.type
        self.eta = scheduler_params.eta
        self.scheduler.set_timesteps(scheduler_params.num_inference_steps, device=self.device)
        self.timesteps = self.scheduler.timesteps
        print(f"time steps: {self.timesteps}")

        self.sample_mode = scheduler_config.sample_mode
        assert self.sample_mode in ["noise", "sample"], f"Unknown sample mode {self.sample_mode}, should be noise or sample"

        self.audio_drop_ratio = config.train.audio_drop_ratio
        self.pre_drop_ratio = config.train.pre_drop_ratio

        # Learnable null embedding substituted when the audio condition is dropped
        # (classifier-free guidance).
        self.null_audio_feat = nn.Parameter(
            torch.randn(1, 1, 1, 1, self.audio_feat_dim, device=self.device, dtype=dtype),
            requires_grad=True
        )

        # Learnable null embedding substituted when the previous-motion condition is dropped.
        self.null_motion_feat = nn.Parameter(
            torch.randn(1, 1, self.input_dim, device=self.device, dtype=dtype),
            requires_grad=True
        )

        # Consecutive sub-clips overlap by `overlap_len` frames and are cross-faded
        # with linearly increasing weights.
        self.overlap_len = min(16, self.n_pred_frames - 16)
        self.fuse_alpha = torch.arange(self.overlap_len, device=self.device, dtype=dtype).reshape(1, -1, 1) / self.overlap_len

        self.dtype = dtype
        self.loss_type = loss_type

        total_params = sum(p.numel() for p in self.parameters())
        print('Number of parameters: %.4fM' % (total_params / 1e6))
        print("================================== Init Motion GeneratorV2: Done ==================================")

    def _smooth(self, motion):
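        """Apply a centred moving average of window `smo_wsize` to the pose part of
        `motion` (dimensions from `exp_dim` onwards), leaving the expression
        coefficients untouched. Operates in place and returns `motion`.
        """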
        if self.smo_wsize <= 1:
            return motion
        new_motion = motion.clone()
        n = motion.shape[1]
        half_k = self.smo_wsize // 2
        for i in range(n):
            ss = max(0, i - half_k)
            ee = min(n, i + half_k + 1)
            motion[:, i, self.exp_dim:] = torch.mean(new_motion[:, ss:ee, self.exp_dim:], dim=1)

        return motion

    def _fuse(self, prev_motion, cur_motion):
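        """Cross-fade the last `overlap_len` frames of `prev_motion` with the first
        `overlap_len` frames of `cur_motion` using the linear blending weights in
        `fuse_alpha`. Modifies and returns `prev_motion`.
        """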
        r1 = prev_motion[:, -self.overlap_len:]
        r2 = cur_motion[:, :self.overlap_len]
        r_fuse = r1 * (1 - self.fuse_alpha) + r2 * self.fuse_alpha

        prev_motion[:, -self.overlap_len:] = r_fuse
        return prev_motion

    @torch.no_grad()
    def sample_subclip(
        self,
        audio,
        ref_kp,
        prev_motion,
        emo=None,
        cfg_scale=1.15,
        init_latents=None,
        dynamic_threshold=None
    ):
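        """Denoise a single sub-clip of `n_pred_frames` motion frames.

        Args:
            audio: batched per-frame audio feature windows for the sub-clip.
            ref_kp: reference keypoints, used as an additional condition.
            prev_motion: previous-motion condition, concatenated with the noisy
                latents along the feature dimension.
            emo: optional emotion class indices.
            cfg_scale: classifier-free guidance scale; values > 1 enable guidance.
            init_latents: optional initial noise; drawn from a standard normal if None.
            dynamic_threshold: optional (ratio, min, max) tuple used to clamp the
                model prediction at each step.

        Returns:
            Denoised motion latents of shape (batch_size, n_pred_frames, input_dim).
        """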

        batch_size = audio.shape[0]
        audio = audio.to(self.device)
        if audio.ndim == 4:
            audio = audio.unsqueeze(2)

        ref_kp = ref_kp.view(batch_size, 1, -1)

        # For classifier-free guidance, stack the unconditional (null-audio) batch
        # on top of the conditional one so both share a single forward pass.
        if cfg_scale > 1:
            uncond_audio = self.null_audio_feat.expand(
                batch_size, self.n_pred_frames, self.audio_seq_len, self.audio_blocks, -1
            )
            audio = torch.cat([uncond_audio, audio], dim=0)
            ref_kp = torch.cat([ref_kp] * 2, dim=0)
            if emo is not None:
                uncond_emo = torch.Tensor([self.talking_head_dit.num_emo_class]).long().to(self.device)
                emo = torch.cat([uncond_emo, emo], dim=0)
        ref_kp = ref_kp.repeat(1, audio.shape[1], 1)

        if init_latents is None:
            latents = torch.randn((batch_size, self.n_pred_frames, self.input_dim)).to(self.device)
        else:
            latents = init_latents

        prev_motion = prev_motion.expand_as(latents).to(dtype=self.dtype)
        latents = latents.to(dtype=self.dtype)
        audio = audio.to(dtype=self.dtype)
        ref_kp = ref_kp.to(dtype=self.dtype)
        for t in track(self.timesteps, description='🚀Denoising', total=len(self.timesteps)):
            motion_in = torch.cat([prev_motion, latents], dim=-1)
            step_in = torch.tensor([t] * batch_size, device=self.device, dtype=self.dtype)
            if cfg_scale > 1:
                motion_in = torch.cat([motion_in] * 2, dim=0)
                step_in = torch.cat([step_in] * 2, dim=0)

            pred = self.talking_head_dit(
                motion = motion_in,
                times = step_in,
                audio = audio,
                emo = emo,
                audio_cond = ref_kp
            )

            # Optional dynamic thresholding: clamp the prediction to the dt_ratio
            # quantile of its absolute values, bounded to [dt_min, dt_max].
            if dynamic_threshold:
                dt_ratio, dt_min, dt_max = dynamic_threshold
                abs_results = pred.reshape(pred.shape[0], -1).abs()
                s = torch.quantile(abs_results, dt_ratio, dim=1)
                s = torch.clamp(s, min=dt_min, max=dt_max)
                s = s[..., None, None]
                pred = torch.clamp(pred, min=-s, max=s)

            # Combine the unconditional and conditional predictions.
            if cfg_scale > 1:
                uncond_pred, cond_pred = pred.chunk(2, dim=0)
                pred = uncond_pred + cfg_scale * (cond_pred - uncond_pred)

            latents = self.scheduler.step(pred, t, latents, eta=self.eta, return_dict=False)[0]
        # Clear the backbone's feature bank before returning.
        self.talking_head_dit.bank = []
        return latents

    @torch.no_grad()
    def sample(self, audio, ref_kp, prev_motion, cfg_scale=1.15, audio_pad_mode="zero", emo=None, dynamic_threshold=None):
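        """Generate motion for a full audio clip.

        The clip is split into overlapping sub-clips of `n_pred_frames`, each
        sub-clip is denoised with `sample_subclip`, and neighbouring segments are
        cross-faded over `overlap_len` frames before being concatenated.

        Args:
            audio: per-frame audio features, with `clip_len` frames along dim 0.
            ref_kp: reference keypoints.
            prev_motion: initial previous-motion condition.
            cfg_scale: classifier-free guidance scale.
            audio_pad_mode: 'zero' or 'replicate' padding for the clip tail.
            emo: optional emotion class id.
            dynamic_threshold: optional (ratio, min, max) tuple, see `sample_subclip`.

        Returns:
            The generated motion sequence as a float tensor.
        """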
        clip_len = audio.shape[0]
        stride = self.n_pred_frames - self.overlap_len
        # Number of overlapping sub-clips needed to cover the whole clip.
        if clip_len <= self.n_pred_frames:
            n_subdivision = 1
        else:
            n_subdivision = math.ceil((clip_len - self.n_pred_frames) / stride) + 1

        # Frames to append so that the sub-clips tile the clip exactly.
        n_padding_frames = self.n_pred_frames + stride * (n_subdivision - 1) - clip_len
        if n_padding_frames > 0:
            if audio_pad_mode == 'zero':
                padding_value = torch.zeros_like(audio[-1:])
            elif audio_pad_mode == 'replicate':
                padding_value = audio[-1:]
            else:
                raise ValueError(f'Unknown pad mode: {audio_pad_mode}')
            pad_frames = [padding_value] * n_padding_frames
        else:
            pad_frames = []
        # Pad the clip tail and add `audio_margin` context frames on both ends so
        # every frame keeps a full (2 * audio_margin + 1) audio window.
        audio = torch.cat(
            [audio[:1]] * self.audio_margin
            + [audio] + pad_frames
            + [audio[-1:]] * self.audio_margin,
            dim=0
        )

        # Gather, for every frame, its window of surrounding audio features.
        center_indices = torch.arange(
            self.audio_margin,
            audio.shape[0] - self.audio_margin
        ).unsqueeze(1) + self.indices
        audio_tensor = audio[center_indices]

        motion_lst = []
        init_latents = None

        if emo is not None:
            emo = torch.Tensor([emo]).long().to(self.device)

        start_idx = 0
        for i in range(n_subdivision):
            print(f"Sample subclip {i+1}/{n_subdivision}")
            end_idx = start_idx + self.n_pred_frames
            audio_segment = audio_tensor[start_idx: end_idx].unsqueeze(0)
            start_idx += stride

            motion_segment = self.sample_subclip(
                audio = audio_segment,
                ref_kp = ref_kp,
                prev_motion = prev_motion,
                emo = emo,
                cfg_scale = cfg_scale,
                init_latents = init_latents,
                dynamic_threshold = dynamic_threshold
            )

            motion_segment = self._smooth(motion_segment)

            # Condition the next sub-clip on the frame just before its start
            # (local index stride - 1 of the current segment).
            prev_motion = motion_segment[:, stride-1:stride].clone()

            motion_coef = motion_segment
            # Drop the frames that only exist because of audio padding.
            if i == n_subdivision - 1 and n_padding_frames > 0:
                motion_coef = motion_coef[:, :-n_padding_frames]

            if len(motion_lst) > 0:
                # Cross-fade the overlap with the previous segment, then keep only
                # the non-overlapping part of the current one.
                motion_lst[-1] = self._fuse(motion_lst[-1], motion_coef)
                motion_lst.append(motion_coef[:, self.overlap_len:])
            else:
                motion_lst.append(motion_coef)

        motion = torch.cat(motion_lst, dim=1)

        motion = self._smooth(motion)
        motion = motion.squeeze()
        return motion.float()