# NOTE(review): removed non-Python page-scrape residue ("Spaces:" / "Runtime error" x2)
# Latent Motion Diffusion Model
import torch
import torch.nn as nn

from .lmdm_modules.model import MotionDecoder
from .lmdm_modules.utils import extract, make_beta_schedule
class LMDM(nn.Module):
    """Latent Motion Diffusion Model: a DDIM sampler around ``MotionDecoder``.

    The decoder predicts the clean sequence x0 directly (not the noise);
    ``predict_noise_from_start`` recovers the DDIM noise estimate from that
    prediction. ``ddim_sample`` denoises Gaussian noise into a motion-feature
    sequence conditioned on a condition frame (``kp_cond``, presumably
    keypoints — confirm against caller) and audio features.

    Args:
        motion_feat_dim: per-frame motion feature size.
        audio_feat_dim: per-frame audio conditioning feature size.
        seq_frames: number of frames in one generated sequence.
        checkpoint: unused here; weights are loaded via ``load_model``.
        device: device on which sampling tensors are created.
        clip_denoised: if True, clamp the predicted x0 into (-1, 1).
        multi_cond_frame: forwarded to ``MotionDecoder``.
    """

    def __init__(
        self,
        motion_feat_dim=265,
        audio_feat_dim=1024 + 35,
        seq_frames=80,
        checkpoint='',
        device='cuda',
        clip_denoised=False,  # clip denoised (-1, 1)
        multi_cond_frame=False,
    ):
        super().__init__()
        self.motion_feat_dim = motion_feat_dim
        self.audio_feat_dim = audio_feat_dim
        self.seq_frames = seq_frames
        self.device = device
        self.n_timestep = 1000  # length of the training diffusion chain
        self.clip_denoised = clip_denoised
        self.guidance_weight = 2  # classifier-free guidance weight for guided_forward

        self.model = MotionDecoder(
            nfeats=motion_feat_dim,
            seq_len=seq_frames,
            latent_dim=512,
            ff_size=1024,
            num_layers=8,
            num_heads=8,
            dropout=0.1,
            cond_feature_dim=audio_feat_dim,
            multi_cond_frame=multi_cond_frame,
        )

        self.init_diff()
        # Set by setup(); doubles as a memo key so repeated calls with the
        # same step count reuse the cached per-step tensors.
        self.sampling_timesteps = None

    def init_diff(self):
        """Precompute cosine-schedule diffusion constants as module buffers."""
        n_timestep = self.n_timestep
        betas = torch.Tensor(
            make_beta_schedule(schedule="cosine", n_timestep=n_timestep)
        )
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)

        self.register_buffer("alphas_cumprod", alphas_cumprod)
        # sqrt(1/acp - 1) and 1/sqrt(1 - acp): the two factors used to invert
        # the forward process in predict_noise_from_start.
        self.register_buffer(
            "sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1)
        )
        self.register_buffer(
            "sqrt_recip1m_alphas_cumprod",
            torch.sqrt(1.0 / (1.0 - alphas_cumprod)),
        )

    def predict_noise_from_start(self, x_t, t, x0):
        """Recover the noise estimate eps from (x_t, predicted x0) at step t.

        Algebraically equivalent to the textbook form
        ``(sqrt(1/acp_t) * x_t - x0) / sqrt(1/acp_t - 1)``.
        """
        a = extract(self.sqrt_recip1m_alphas_cumprod, t, x_t.shape)
        b = extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        return a * x_t - x0 / b

    def maybe_clip(self, x):
        """Clamp x to [-1, 1] when ``clip_denoised`` is set; else pass through."""
        if self.clip_denoised:
            return torch.clamp(x, min=-1., max=1.)
        return x

    def model_predictions(self, x, cond_frame, cond, t):
        """One guided denoising step: return ``(pred_noise, pred_x0)``."""
        weight = self.guidance_weight
        x_start = self.model.guided_forward(x, cond_frame, cond, t, weight)
        x_start = self.maybe_clip(x_start)
        pred_noise = self.predict_noise_from_start(x, t, x_start)
        return pred_noise, x_start

    def forward(self, x, cond_frame, cond, time_cond):
        """nn.Module entry point; same contract as ``model_predictions``."""
        return self.model_predictions(x, cond_frame, cond, time_cond)

    def load_model(self, ckpt_path):
        """Load decoder weights from ``ckpt_path`` and switch to eval mode."""
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.eval()
        return self

    def setup(self, sampling_timesteps=50):
        """Precompute per-step DDIM constants, memoized on the step count.

        NOTE(review): the per-step noise tensors are drawn once here and
        reused by every subsequent ``ddim_sample`` call, so sampling is
        deterministic for a given setup — this looks intentional (real-time
        use); confirm before changing.
        """
        if self.sampling_timesteps == sampling_timesteps:
            return  # already prepared for this step count
        self.sampling_timesteps = sampling_timesteps

        total_timesteps = self.n_timestep
        device = self.device
        eta = 1  # eta = 1: DDPM-like stochasticity in the DDIM update
        shape = (1, self.seq_frames, self.motion_feat_dim)

        # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
        times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1)
        times = list(reversed(times.int().tolist()))
        # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
        self.time_pairs = list(zip(times[:-1], times[1:]))

        self.time_cond_list = []
        self.alpha_next_sqrt_list = []
        self.sigma_list = []
        self.c_list = []
        self.noise_list = []
        for time, time_next in self.time_pairs:
            time_cond = torch.full((1,), time, device=device, dtype=torch.long)
            self.time_cond_list.append(time_cond)
            if time_next < 0:
                # Final pair (0, -1): ddim_sample returns x_start directly
                # there, so no update constants are needed. Skipping keeps the
                # four short lists index-aligned with the non-final pairs.
                continue
            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]
            # Standard DDIM sigma_t and the coefficient on the noise estimate.
            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()
            noise = torch.randn(shape, device=device)
            self.alpha_next_sqrt_list.append(alpha_next.sqrt())
            self.sigma_list.append(sigma)
            self.c_list.append(c)
            self.noise_list.append(noise)

    def ddim_sample(self, kp_cond, aud_cond, sampling_timesteps):
        """Run DDIM sampling and return the predicted motion sequence.

        Returns a tensor of shape (1, seq_frames, motion_feat_dim).
        """
        self.setup(sampling_timesteps)

        cond_frame = kp_cond
        cond = aud_cond
        shape = (1, self.seq_frames, self.motion_feat_dim)

        x = torch.randn(shape, device=self.device)
        x_start = None
        i = 0
        for _, time_next in self.time_pairs:
            time_cond = self.time_cond_list[i]
            pred_noise, x_start = self.model_predictions(x, cond_frame, cond, time_cond)
            if time_next < 0:
                # Last step: the model's x0 prediction is the final sample.
                x = x_start
                continue
            # x_{t_next} = sqrt(acp_next) * x0 + c * eps + sigma * noise,
            # with constants (and the fixed noise) cached by setup().
            alpha_next_sqrt = self.alpha_next_sqrt_list[i]
            c = self.c_list[i]
            sigma = self.sigma_list[i]
            noise = self.noise_list[i]
            x = x_start * alpha_next_sqrt + c * pred_noise + sigma * noise
            i += 1
        return x  # pred_kp_seq