# MoDA-PLUS/src/models/dit/talking_head_diffusion.py
# -*- coding: utf-8 -*-
import math
import os.path as osp
import sys
import torch
import torch.nn as nn
from omegaconf import OmegaConf
from rich.progress import track
sys.path.append(osp.dirname(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__))))))
from .talking_head_dit import TalkingHeadDiT_models
from ..schedulers.scheduling_ddim import DDIMScheduler
from ..schedulers.flow_matching import ModelSamplingDiscreteFlow
scheduler_map = {
"ddim": DDIMScheduler,
# "ddpm": DiffusionSchedule,
"flow_matching": ModelSamplingDiscreteFlow
}
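# indices of the lip-related dimensions in the motion coefficient vector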
lip_dims = [18, 19, 20, 36, 37, 38, 42, 43, 44, 51, 52, 53, 57, 58, 59, 60, 61, 62]
class MotionDiffusion(nn.Module):
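    """Diffusion-based motion generator for talking-head animation.

    Wraps a TalkingHeadDiT denoiser together with a noise scheduler (DDIM or
    discrete flow matching) and samples motion coefficients (expression plus
    head pose) in overlapping windows, conditioned on audio features, reference
    keypoints and an optional emotion label.
    """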
def __init__(self, config, device="cuda", dtype=torch.float32, smo_wsize=3, loss_type="l2"):
super().__init__()
self.config = config
self.smo_wsize = smo_wsize
print(f"================================== Init Motion GeneratorV2 ==================================")
print(OmegaConf.to_yaml(self.config))
motion_gen_config = config.motion_generator
motion_gen_params = motion_gen_config.params
audio_proj_config = config.audio_projector
audio_proj_params = audio_proj_config.params
scheduler_config = config.noise_scheduler
scheduler_params = scheduler_config.params
self.device = device
# init motion generator
self.talking_head_dit = TalkingHeadDiT_models[config.model_name](
input_dim = motion_gen_params.input_dim * 2,
output_dim = motion_gen_params.output_dim,
seq_len = motion_gen_params.n_pred_frames,
audio_unit_len = audio_proj_params.sequence_length,
audio_blocks = audio_proj_params.blocks,
audio_dim = audio_proj_params.audio_feat_dim,
audio_tokens = audio_proj_params.context_tokens,
audio_embedder_type = audio_proj_params.audio_embedder_type,
audio_cond_dim = audio_proj_params.audio_cond_dim,
norm_type = motion_gen_params.norm_type,
qk_norm = motion_gen_params.qk_norm,
exp_dim = motion_gen_params.exp_dim
)
self.input_dim = motion_gen_params.input_dim
self.exp_dim = motion_gen_params.exp_dim
self.audio_feat_dim = audio_proj_params.audio_feat_dim
self.audio_seq_len = audio_proj_params.sequence_length
self.audio_blocks = audio_proj_params.blocks
self.audio_margin = (audio_proj_params.sequence_length - 1) // 2
self.indices = (
torch.arange(2 * self.audio_margin + 1) - self.audio_margin
        ).unsqueeze(0)  # offsets [-audio_margin, ..., audio_margin], e.g. [-2, -1, 0, 1, 2]; shape 1 x (2 * audio_margin + 1)
self.n_prev_frames = motion_gen_params.n_prev_frames
self.n_pred_frames = motion_gen_params.n_pred_frames
# init diffusion schedule
self.scheduler = scheduler_map[scheduler_config.type](
num_train_timesteps = scheduler_params.num_train_timesteps,
beta_start = scheduler_params.beta_start,
beta_end = scheduler_params.beta_end,
beta_schedule = scheduler_params.mode,
prediction_type = scheduler_config.sample_mode,
time_shifting = scheduler_params.time_shifting,
)
self.scheduler_type = scheduler_config.type
self.eta = scheduler_params.eta
self.scheduler.set_timesteps(scheduler_params.num_inference_steps, device=self.device)
self.timesteps = self.scheduler.timesteps
print(f"time steps: {self.timesteps}")
self.sample_mode = scheduler_config.sample_mode
        assert self.sample_mode in ["noise", "sample"], f"Unknown sample mode {self.sample_mode}, should be noise or sample"
# init other params
self.audio_drop_ratio = config.train.audio_drop_ratio
self.pre_drop_ratio = config.train.pre_drop_ratio
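        # learned null embeddings used when the audio / previous-motion conditions
        # are dropped (classifier-free guidance)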
        self.null_audio_feat = nn.Parameter(
            torch.randn(1, 1, 1, 1, self.audio_feat_dim).to(device=self.device, dtype=dtype),
            requires_grad=True
        )
        self.null_motion_feat = nn.Parameter(
            torch.randn(1, 1, self.input_dim).to(device=self.device, dtype=dtype),
            requires_grad=True
        )
# for segments fusion
self.overlap_len = min(16, self.n_pred_frames - 16)
self.fuse_alpha = torch.arange(self.overlap_len, device=self.device, dtype=dtype).reshape(1, -1, 1) / self.overlap_len
self.dtype = dtype
self.loss_type = loss_type
total_params = sum(p.numel() for p in self.parameters())
        print('Number of parameters: %.4fM' % (total_params / 1e6))
print(f"================================== init Motion GeneratorV2: Done ==================================")
def _smooth(self, motion):
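        """Apply a centered moving average of window ``smo_wsize`` to the head-pose
        dimensions (``exp_dim:``) of ``motion`` in place; expression dimensions are untouched."""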
# motion, B x L x D
if self.smo_wsize <= 1:
return motion
new_motion = motion.clone()
n = motion.shape[1]
half_k = self.smo_wsize // 2
for i in range(n):
ss = max(0, i - half_k)
ee = min(n, i + half_k + 1)
# only smooth head pose motion
motion[:, i, self.exp_dim:] = torch.mean(new_motion[:, ss:ee, self.exp_dim:], dim=1)
return motion
def _fuse(self, prev_motion, cur_motion):
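        """Linearly cross-fade the last ``overlap_len`` frames of ``prev_motion`` with the
        first ``overlap_len`` frames of ``cur_motion`` and return the updated ``prev_motion``."""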
r1 = prev_motion[:, -self.overlap_len:]
r2 = cur_motion[:, :self.overlap_len]
r_fuse = r1 * (1 - self.fuse_alpha) + r2 * self.fuse_alpha
prev_motion[:, -self.overlap_len:] = r_fuse # fuse last
return prev_motion
@torch.no_grad()
def sample_subclip(
self,
audio,
ref_kp,
prev_motion,
emo=None,
cfg_scale=1.15,
init_latents=None,
dynamic_threshold = None
):
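        """Denoise one fixed-length window of ``n_pred_frames`` motion frames.

        Runs the scheduler's timestep loop on the TalkingHeadDiT denoiser,
        conditioned on per-frame audio features, reference keypoints and an
        optional emotion label. When ``cfg_scale > 1`` the unconditional branch
        (learned null audio feature / null emotion class) is batched alongside
        the conditional one for classifier-free guidance. ``dynamic_threshold``
        optionally clamps predictions to a per-sample quantile range given as
        ``(ratio, min, max)``.
        """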
# prepare audio feat
batch_size = audio.shape[0]
audio = audio.to(self.device)
if audio.ndim == 4:
audio = audio.unsqueeze(2)
# reference keypoints
ref_kp = ref_kp.view(batch_size, 1, -1)
# cfg
if cfg_scale > 1:
uncond_audio = self.null_audio_feat.expand(
batch_size, self.n_pred_frames, self.audio_seq_len, self.audio_blocks, -1
)
            audio = torch.cat([uncond_audio, audio], dim=0)
ref_kp = torch.cat([ref_kp] * 2, dim=0)
if emo is not None:
uncond_emo = torch.Tensor([self.talking_head_dit.num_emo_class]).long().to(self.device)
                emo = torch.cat([uncond_emo, emo], dim=0)
ref_kp = ref_kp.repeat(1, audio.shape[1], 1) # B, L, kD
# prepare noisy motion
if init_latents is None:
latents = torch.randn((batch_size, self.n_pred_frames, self.input_dim)).to(self.device)
else:
latents = init_latents
prev_motion = prev_motion.expand_as(latents).to(dtype=self.dtype)
latents = latents.to(dtype=self.dtype)
audio = audio.to(dtype=self.dtype)
ref_kp = ref_kp.to(dtype=self.dtype)
        for t in track(self.timesteps, description='🚀Denoising', total=len(self.timesteps)):
motion_in = torch.cat([prev_motion, latents], dim=-1)
step_in = torch.tensor([t] * batch_size, device=self.device, dtype=self.dtype)
if cfg_scale > 1:
motion_in = torch.cat([motion_in] * 2, dim=0)
step_in = torch.cat([step_in] * 2, dim=0)
# predict
pred = self.talking_head_dit(
motion = motion_in,
times = step_in,
audio = audio,
emo = emo,
audio_cond = ref_kp
)
if dynamic_threshold:
dt_ratio, dt_min, dt_max = dynamic_threshold
                abs_results = pred.reshape(pred.shape[0], -1).abs()
s = torch.quantile(abs_results, dt_ratio, dim=1)
s = torch.clamp(s, min=dt_min, max=dt_max)
s = s[..., None, None]
pred = torch.clamp(pred, min=-s, max=s)
# CFG
if cfg_scale > 1:
# uncond_pred, emo_cond_pred, all_cond_pred = pred.chunk(3, dim=0)
# pred = uncond_pred + 8 * (emo_cond_pred - uncond_pred) + 1.2 * (all_cond_pred - emo_cond_pred)
uncond_pred, cond_pred = pred.chunk(2, dim=0)
pred = uncond_pred + cfg_scale * (cond_pred - uncond_pred)
# Step
latents = self.scheduler.step(pred, t, latents, eta=self.eta, return_dict=False)[0]
self.talking_head_dit.bank=[]
return latents
@torch.no_grad()
    def sample(self, audio, ref_kp, prev_motion, cfg_scale=1.15, audio_pad_mode="zero", emo=None, dynamic_threshold=None):
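        """Generate motion coefficients for an arbitrary-length audio clip.

        Each window of ``n_pred_frames`` frames is denoised with
        :meth:`sample_subclip`, smoothed, cross-faded with the previous window
        over ``overlap_len`` frames and concatenated; a final smoothing pass is
        applied to the full clip. Returns a float tensor of shape (clip_len, motion_dim).
        """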
# prev_motion, B, 1, D
# for inference with any length audio
# crop audio into n_subdivision according to n_pred_frames
clip_len = audio.shape[0]
stride = self.n_pred_frames - self.overlap_len
if clip_len <= self.n_pred_frames:
n_subdivision = 1
else:
n_subdivision = math.ceil((clip_len - self.n_pred_frames) / stride) + 1
# padding
n_padding_frames = self.n_pred_frames + stride * (n_subdivision - 1) - clip_len
if n_padding_frames > 0:
padding_value = 0
if audio_pad_mode == 'zero':
padding_value = torch.zeros_like(audio[-1:])
elif audio_pad_mode == 'replicate':
padding_value = audio[-1:]
else:
raise ValueError(f'Unknown pad mode: {audio_pad_mode}')
audio = torch.cat(
[audio[:1]] * self.audio_margin \
+ [audio] + [padding_value] * n_padding_frames \
+ [audio[-1:]] * self.audio_margin,
dim=0
)
center_indices = torch.arange(
self.audio_margin,
audio.shape[0] - self.audio_margin
).unsqueeze(1) + self.indices
audio_tensor = audio[center_indices] # T, L, b, aD
# add reference keypoints
motion_lst = []
#init_latents = torch.randn((1, self.n_pred_frames, self.motion_dim)).to(device=self.device)
init_latents = None
# emotion label
if emo is not None:
emo = torch.Tensor([emo]).long().to(self.device)
start_idx = 0
for i in range(0, n_subdivision):
print(f"Sample subclip {i+1}/{n_subdivision}")
end_idx = start_idx + self.n_pred_frames
audio_segment = audio_tensor[start_idx: end_idx].unsqueeze(0)
start_idx += stride
# debug
#print(f"scale:")
motion_segment = self.sample_subclip(
audio = audio_segment,
ref_kp = ref_kp,
prev_motion = prev_motion,
emo = emo,
cfg_scale = cfg_scale,
init_latents = init_latents,
dynamic_threshold = dynamic_threshold
)
# smooth
motion_segment = self._smooth(motion_segment)
# update prev motion
prev_motion = motion_segment[:, stride-1:stride].clone()
# save results
motion_coef = motion_segment
if i == n_subdivision - 1 and n_padding_frames > 0:
motion_coef = motion_coef[:, :-n_padding_frames] # delete padded frames
if len(motion_lst) > 0:
# fuse segments
motion_lst[-1] = self._fuse(motion_lst[-1], motion_coef)
motion_lst.append(motion_coef[:, self.overlap_len:])
else:
motion_lst.append(motion_coef)
motion = torch.cat(motion_lst, dim=1)
# smooth for full clip
motion = self._smooth(motion)
motion = motion.squeeze()
return motion.float()
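
if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the config path and tensor shapes
    # below are assumptions and must be adapted to the actual MoDA-PLUS configs
    # and audio feature extractor.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    cfg = OmegaConf.load("configs/motion_generator.yaml")  # hypothetical config path
    model = MotionDiffusion(cfg, device=device).to(device).eval()
    n_frames = 100  # number of video frames to generate
    p = cfg.audio_projector.params
    audio_feat = torch.randn(n_frames, p.blocks, p.audio_feat_dim, device=device)  # per-frame audio features
    ref_kp = torch.randn(1, p.audio_cond_dim, device=device)  # flattened reference keypoints (assumed dim)
    prev_motion = torch.zeros(1, 1, cfg.motion_generator.params.input_dim, device=device)
    motion = model.sample(audio_feat, ref_kp, prev_motion, cfg_scale=1.15)
    print(motion.shape)  # (n_frames, motion dim)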