# Uploaded by oKen38461, commit ac7cda5:
#   "初回コミットに基づくファイルの追加" ("Add files based on the initial commit").
# Original file size: 5.23 kB.
# Latent Motion Diffusion Model
import torch
import torch.nn as nn
from .lmdm_modules.model import MotionDecoder
from .lmdm_modules.utils import extract, make_beta_schedule
class LMDM(nn.Module):
    """Latent motion diffusion model wrapping a ``MotionDecoder`` denoiser.

    The denoiser predicts the clean sample ``x0`` directly. The matching noise
    estimate is derived from it (see ``predict_noise_from_start``).
    ``ddim_sample`` uses a ``n_timestep``-step (1000) cosine beta schedule and
    the DDIM update, and produces a tensor of shape
    ``(1, seq_frames, motion_feat_dim)``.
    """

    def __init__(
        self,
        motion_feat_dim=265,
        audio_feat_dim=1024 + 35,
        seq_frames=80,
        checkpoint='',
        device='cuda',
        clip_denoised=False,  # clip denoised (-1,1)
        multi_cond_frame=False,
    ):
        """Build the denoiser and register the diffusion schedule buffers.

        Args:
            motion_feat_dim: per-frame feature size of the sampled sequence
                (passed to MotionDecoder as ``nfeats``).
            audio_feat_dim: conditioning feature size (passed to MotionDecoder
                as ``cond_feature_dim``).
            seq_frames: number of frames in each sampled sequence.
            checkpoint: NOT used by this class. Weights are loaded with
                ``load_model``. The argument is kept for interface
                compatibility.
            device: device on which sampling noise and timestep tensors are
                created.
            clip_denoised: if True, the predicted ``x0`` is clamped to [-1, 1].
            multi_cond_frame: forwarded unchanged to MotionDecoder.
        """
        super().__init__()
        self.motion_feat_dim = motion_feat_dim
        self.audio_feat_dim = audio_feat_dim
        self.seq_frames = seq_frames
        self.device = device
        self.n_timestep = 1000
        self.clip_denoised = clip_denoised
        # Passed to MotionDecoder.guided_forward. Presumably a
        # classifier-free-guidance scale; confirm against MotionDecoder.
        self.guidance_weight = 2
        self.model = MotionDecoder(
            nfeats=motion_feat_dim,
            seq_len=seq_frames,
            latent_dim=512,
            ff_size=1024,
            num_layers=8,
            num_heads=8,
            dropout=0.1,
            cond_feature_dim=audio_feat_dim,
            multi_cond_frame=multi_cond_frame,
        )
        self.init_diff()
        # Step count the sampling schedule was last built for (see setup()).
        self.sampling_timesteps = None

    def init_diff(self):
        """Register the cumulative-alpha buffers of the cosine noise schedule."""
        n_timestep = self.n_timestep
        betas = torch.Tensor(
            make_beta_schedule(schedule="cosine", n_timestep=n_timestep)
        )
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        self.register_buffer("alphas_cumprod", alphas_cumprod)
        # sqrt(1/abar - 1) and sqrt(1/(1 - abar)): coefficients used by
        # predict_noise_from_start.
        self.register_buffer(
            "sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1)
        )
        self.register_buffer(
            "sqrt_recip1m_alphas_cumprod",
            torch.sqrt(1.0 / (1.0 - alphas_cumprod)),
        )

    def predict_noise_from_start(self, x_t, t, x0):
        """Return the noise implied by noisy sample ``x_t`` and clean estimate ``x0``.

        Computes ``x_t / sqrt(1 - abar_t) - x0 / sqrt(1/abar_t - 1)``. This is
        the same as solving ``x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps``
        for ``eps``.
        """
        # ``extract`` (lmdm_modules.utils) is assumed to gather the coefficient
        # at timestep t and reshape it to broadcast against x_t.shape.
        a = extract(self.sqrt_recip1m_alphas_cumprod, t, x_t.shape)
        b = extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        return (a * x_t - x0 / b)

    def maybe_clip(self, x):
        """Clamp ``x`` to [-1, 1] when ``clip_denoised`` is set; otherwise return it unchanged."""
        if self.clip_denoised:
            return torch.clamp(x, min=-1., max=1.)
        return x

    def model_predictions(self, x, cond_frame, cond, t):
        """Run the guided denoiser at timestep ``t``.

        Returns:
            ``(pred_noise, x_start)``: the derived noise estimate and the
            (optionally clipped) clean-sample prediction.
        """
        x_start = self.model.guided_forward(x, cond_frame, cond, t, self.guidance_weight)
        x_start = self.maybe_clip(x_start)
        pred_noise = self.predict_noise_from_start(x, t, x_start)
        return pred_noise, x_start

    @torch.no_grad()
    def forward(self, x, cond_frame, cond, time_cond):
        """Gradient-free single denoising call; returns ``(pred_noise, x_start)``."""
        pred_noise, x_start = self.model_predictions(x, cond_frame, cond, time_cond)
        return pred_noise, x_start

    def load_model(self, ckpt_path):
        """Load ``checkpoint["model_state_dict"]`` into the denoiser and switch to eval mode.

        SECURITY NOTE: depending on the installed PyTorch version,
        ``torch.load`` may unpickle arbitrary Python objects. Only load
        checkpoints from trusted sources.

        Returns:
            ``self``, so calls can be chained.
        """
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.eval()
        return self

    def setup(self, sampling_timesteps=50):
        """Precompute the DDIM schedule for ``sampling_timesteps`` steps (cached).

        Builds ``time_pairs`` (current/next timestep, ending with
        ``(t, -1)``) and one timestep tensor per pair. For every pair except
        the last, it also builds ``sqrt(abar_next)``, ``sigma``, ``c`` and a
        noise tensor. The noise is drawn here, once, so every ``ddim_sample``
        call with the same step count reuses the same per-step noise.

        The state is published, and the step count recorded, only after
        everything is built. A failure partway through therefore leaves the
        previous schedule intact instead of a half-built one marked as valid.

        Raises:
            ValueError: if ``sampling_timesteps`` is None or less than 1.
        """
        if sampling_timesteps is None or sampling_timesteps < 1:
            raise ValueError(
                f"sampling_timesteps must be >= 1, got {sampling_timesteps!r}"
            )
        if self.sampling_timesteps == sampling_timesteps:
            return
        total_timesteps = self.n_timestep
        device = self.device
        eta = 1
        shape = (1, self.seq_frames, self.motion_feat_dim)

        # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
        times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1)
        times = list(reversed(times.int().tolist()))
        # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
        time_pairs = list(zip(times[:-1], times[1:]))

        time_cond_list = []
        alpha_next_sqrt_list = []
        sigma_list = []
        c_list = []
        noise_list = []
        for time, time_next in time_pairs:
            time_cond_list.append(
                torch.full((1,), time, device=device, dtype=torch.long)
            )
            if time_next < 0:
                # Final step: ddim_sample returns x_start directly.
                continue
            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]
            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()
            noise = torch.randn(shape, device=device)
            alpha_next_sqrt_list.append(alpha_next.sqrt())
            sigma_list.append(sigma)
            c_list.append(c)
            noise_list.append(noise)

        self.time_pairs = time_pairs
        self.time_cond_list = time_cond_list
        self.alpha_next_sqrt_list = alpha_next_sqrt_list
        self.sigma_list = sigma_list
        self.c_list = c_list
        self.noise_list = noise_list
        self.sampling_timesteps = sampling_timesteps

    @torch.no_grad()
    def ddim_sample(self, kp_cond, aud_cond, sampling_timesteps):
        """Sample a motion sequence with DDIM, conditioned on ``kp_cond``/``aud_cond``.

        Starts from fresh Gaussian noise of shape
        ``(1, seq_frames, motion_feat_dim)``. Each non-final step applies
        ``x = sqrt(abar_next) * x0 + c * eps + sigma * z``. The final step
        returns the predicted ``x0``.

        Returns:
            The sampled sequence (pred_kp_seq).
        """
        self.setup(sampling_timesteps)
        shape = (1, self.seq_frames, self.motion_feat_dim)
        x = torch.randn(shape, device=self.device)
        # Only the last pair has time_next < 0, so the enumerate index lines up
        # with both time_cond_list (one entry per pair) and the per-step
        # coefficient lists (one entry per non-final pair).
        for i, (_, time_next) in enumerate(self.time_pairs):
            pred_noise, x_start = self.model_predictions(
                x, kp_cond, aud_cond, self.time_cond_list[i]
            )
            if time_next < 0:
                x = x_start
                continue
            x = (
                x_start * self.alpha_next_sqrt_list[i]
                + self.c_list[i] * pred_noise
                + self.sigma_list[i] * self.noise_list[i]
            )
        return x  # pred_kp_seq