# -*- coding: utf-8 -*-
import math
import os.path as osp
import sys

import torch
import torch.nn as nn
from omegaconf import OmegaConf
from rich.progress import track

from .talking_head_dit import TalkingHeadDiT_models
from ..schedulers.scheduling_ddim import DDIMScheduler
from ..schedulers.flow_matching import ModelSamplingDiscreteFlow

# extend sys.path so project-level modules can be imported by absolute path
sys.path.append(osp.dirname(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__))))))
scheduler_map = {
    "ddim": DDIMScheduler,
    # "ddpm": DiffusionSchedule,
    "flow_matching": ModelSamplingDiscreteFlow
}
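
# lip-related coefficient indices (inferred from the name); not referenced within this module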
lip_dims = [18, 19, 20, 36, 37, 38, 42, 43, 44, 51, 52, 53, 57, 58, 59, 60, 61, 62]

class MotionDiffusion(nn.Module):
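    """
    Sliding-window diffusion sampler for audio-driven talking-head motion.

    Wraps a TalkingHeadDiT backbone and a noise scheduler (DDIM or discrete flow
    matching) and generates motion coefficients for audio of arbitrary length by
    denoising overlapping sub-clips, smoothing them, and cross-fading the overlaps.
    """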
    def __init__(self, config, device="cuda", dtype=torch.float32, smo_wsize=3, loss_type="l2"):
        super().__init__()

        self.config = config
        self.smo_wsize = smo_wsize
        print(f"================================== Init Motion GeneratorV2 ==================================")
        print(OmegaConf.to_yaml(self.config))
        
        motion_gen_config = config.motion_generator
        motion_gen_params = motion_gen_config.params
        
        audio_proj_config = config.audio_projector
        audio_proj_params = audio_proj_config.params
        
        scheduler_config = config.noise_scheduler
        scheduler_params = scheduler_config.params

        self.device = device

        # init motion generator
        self.talking_head_dit = TalkingHeadDiT_models[config.model_name](
            input_dim           = motion_gen_params.input_dim * 2,
            output_dim          = motion_gen_params.output_dim,
            seq_len             = motion_gen_params.n_pred_frames,
            audio_unit_len      = audio_proj_params.sequence_length,
            audio_blocks        = audio_proj_params.blocks,
            audio_dim           = audio_proj_params.audio_feat_dim,
            audio_tokens        = audio_proj_params.context_tokens,
            audio_embedder_type = audio_proj_params.audio_embedder_type,
            audio_cond_dim      = audio_proj_params.audio_cond_dim,
            norm_type           = motion_gen_params.norm_type,
            qk_norm             = motion_gen_params.qk_norm,
            exp_dim             = motion_gen_params.exp_dim
        )
        self.input_dim = motion_gen_params.input_dim
        self.exp_dim = motion_gen_params.exp_dim

        self.audio_feat_dim = audio_proj_params.audio_feat_dim 
        self.audio_seq_len = audio_proj_params.sequence_length
        self.audio_blocks = audio_proj_params.blocks
        self.audio_margin = (audio_proj_params.sequence_length - 1) // 2
        self.indices = (
            torch.arange(2 * self.audio_margin + 1) - self.audio_margin
        ).unsqueeze(0)  # relative offsets around each frame, e.g. [-2, -1, 0, 1, 2] when audio_margin == 2; shape 1 x (2*audio_margin+1)
        
        self.n_prev_frames = motion_gen_params.n_prev_frames
        self.n_pred_frames = motion_gen_params.n_pred_frames
        
        # init diffusion schedule
        self.scheduler = scheduler_map[scheduler_config.type](
            num_train_timesteps = scheduler_params.num_train_timesteps,
            beta_start          = scheduler_params.beta_start, 
            beta_end            = scheduler_params.beta_end, 
            beta_schedule       = scheduler_params.mode,
            prediction_type     = scheduler_config.sample_mode,
            time_shifting       = scheduler_params.time_shifting,
        )
        self.scheduler_type = scheduler_config.type
        self.eta = scheduler_params.eta
        self.scheduler.set_timesteps(scheduler_params.num_inference_steps, device=self.device)
        self.timesteps = self.scheduler.timesteps
        print(f"time steps: {self.timesteps}")
        
        self.sample_mode = scheduler_config.sample_mode
        assert self.sample_mode in ["noise", "sample"], f"Unknown sample mode {self.sample_mode}, should be noise or sample"

        # init other params
        self.audio_drop_ratio = config.train.audio_drop_ratio
        self.pre_drop_ratio = config.train.pre_drop_ratio

        # learnable null embeddings for classifier-free guidance; create them directly
        # on the target device/dtype so they stay registered as nn.Parameters
        # (calling .to() on an nn.Parameter returns a plain, unregistered tensor)
        self.null_audio_feat = nn.Parameter(
            torch.randn(1, 1, 1, 1, self.audio_feat_dim, device=self.device, dtype=dtype),
            requires_grad=True
        )

        self.null_motion_feat = nn.Parameter(
            torch.randn(1, 1, self.input_dim, device=self.device, dtype=dtype),
            requires_grad=True
        )
        
        # for segments fusion
        self.overlap_len = min(16, self.n_pred_frames - 16)
        self.fuse_alpha = torch.arange(self.overlap_len, device=self.device, dtype=dtype).reshape(1, -1, 1) / self.overlap_len

        self.dtype = dtype
        self.loss_type = loss_type

        total_params = sum(p.numel() for p in self.parameters())
        print('Number of parameters: %.4fM' % (total_params / 1e6))
        print(f"================================== init Motion GeneratorV2: Done ==================================")
        
    def _smooth(self, motion):
        # motion, B x L x D
        if self.smo_wsize <= 1:
            return motion
        new_motion = motion.clone()
        n = motion.shape[1]
        half_k = self.smo_wsize // 2
        for i in range(n):
            ss = max(0, i - half_k)
            ee = min(n, i + half_k + 1)
            # only smooth head pose motion
            motion[:, i, self.exp_dim:] = torch.mean(new_motion[:, ss:ee, self.exp_dim:], dim=1)
            
        return motion

    def _fuse(self, prev_motion, cur_motion):
        # linearly cross-fade the overlapping frames of two consecutive segments
        r1 = prev_motion[:, -self.overlap_len:]
        r2 = cur_motion[:, :self.overlap_len]
        r_fuse = r1 * (1 - self.fuse_alpha) + r2 * self.fuse_alpha

        prev_motion[:, -self.overlap_len:] = r_fuse    # overwrite the tail of the previous segment
        return prev_motion
    
    @torch.no_grad()
    def sample_subclip(
        self, 
        audio, 
        ref_kp,
        prev_motion,
        emo=None,
        cfg_scale=1.15, 
        init_latents=None,
        dynamic_threshold=None
    ):
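        """
        Denoise a single fixed-length window of motion latents.

        audio:       (B, n_pred_frames, audio_seq_len, audio_blocks, audio_feat_dim) audio features
        ref_kp:      reference keypoints, reshaped to (B, 1, -1) and repeated over time
        prev_motion: previous-motion context, concatenated with the noisy latents at every step
        Returns motion latents of shape (B, n_pred_frames, input_dim).

        When cfg_scale > 1, unconditional branches (learned null audio features and,
        if given, a null emotion label) are batched with the conditional ones for
        classifier-free guidance.
        """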
        # prepare audio feat
        batch_size = audio.shape[0]
        audio = audio.to(self.device)
        if audio.ndim == 4:
            audio = audio.unsqueeze(2)
        
        # reference keypoints
        ref_kp = ref_kp.view(batch_size, 1, -1)
        
        # cfg
        if cfg_scale > 1:
            uncond_audio = self.null_audio_feat.expand(
                batch_size, self.n_pred_frames, self.audio_seq_len, self.audio_blocks, -1
            )
            audio = torch.cat([uncond_audio, audio], dim=0)
            ref_kp = torch.cat([ref_kp] * 2, dim=0)
            if emo is not None:
                uncond_emo = torch.Tensor([self.talking_head_dit.num_emo_class]).long().to(self.device)
                emo = torch.cat([uncond_emo, emo], dim=0)
        ref_kp = ref_kp.repeat(1, audio.shape[1], 1)  # B, L, kD

        # prepare noisy motion
        if init_latents is None:
            latents = torch.randn((batch_size, self.n_pred_frames, self.input_dim)).to(self.device)
        else:
            latents = init_latents
        
        # move conditioning tensors onto the sampling device/dtype
        prev_motion = prev_motion.expand_as(latents).to(device=self.device, dtype=self.dtype)
        latents = latents.to(dtype=self.dtype)
        audio = audio.to(dtype=self.dtype)
        ref_kp = ref_kp.to(device=self.device, dtype=self.dtype)
        for t in track(self.timesteps, description='🚀Denoising', total=len(self.timesteps)):
            motion_in = torch.cat([prev_motion, latents], dim=-1)
            step_in = torch.tensor([t] * batch_size, device=self.device, dtype=self.dtype)
            if cfg_scale > 1:
                motion_in = torch.cat([motion_in] * 2, dim=0)
                step_in = torch.cat([step_in] * 2, dim=0)
            # predict
            pred = self.talking_head_dit(
                motion     = motion_in, 
                times      = step_in,
                audio      = audio,
                emo        = emo,
                audio_cond = ref_kp
            )

            if dynamic_threshold:
                # dynamic thresholding: clamp each sample's prediction to a quantile of its magnitudes
                dt_ratio, dt_min, dt_max = dynamic_threshold
                abs_results = pred.reshape(pred.shape[0], -1).abs()
                s = torch.quantile(abs_results, dt_ratio, dim=1)
                s = torch.clamp(s, min=dt_min, max=dt_max)
                s = s[..., None, None]
                pred = torch.clamp(pred, min=-s, max=s)

            # CFG
            if cfg_scale > 1:
                # uncond_pred, emo_cond_pred, all_cond_pred = pred.chunk(3, dim=0)
                # pred = uncond_pred + 8 * (emo_cond_pred - uncond_pred) + 1.2 * (all_cond_pred - emo_cond_pred)
                uncond_pred, cond_pred = pred.chunk(2, dim=0)
                pred = uncond_pred + cfg_scale * (cond_pred - uncond_pred)
            # Step
            latents = self.scheduler.step(pred, t, latents, eta=self.eta, return_dict=False)[0]
        self.talking_head_dit.bank = []   # reset the backbone's bank between subclips
        return latents
            
    @torch.no_grad()
    def sample(self, audio, ref_kp, prev_motion, cfg_scale=1.15, audio_pad_mode="zero", emo=None, dynamic_threshold=None):
        # prev_motion: B x 1 x D
        # Inference for audio of arbitrary length: crop the audio into n_subdivision
        # overlapping windows of n_pred_frames frames, denoise each window, then
        # cross-fade the overlaps when concatenating the results.
        clip_len = audio.shape[0]
        stride = self.n_pred_frames - self.overlap_len
        if clip_len <= self.n_pred_frames:
            n_subdivision = 1
        else:
            n_subdivision = math.ceil((clip_len - self.n_pred_frames) / stride) + 1
        
        # padding
        n_padding_frames = self.n_pred_frames + stride * (n_subdivision - 1) - clip_len
        if n_padding_frames > 0:
            padding_value = 0
            if audio_pad_mode == 'zero':
                padding_value = torch.zeros_like(audio[-1:])
            elif audio_pad_mode == 'replicate':
                padding_value = audio[-1:]
            else:
                raise ValueError(f'Unknown pad mode: {audio_pad_mode}')
            audio = torch.cat(
                [audio[:1]] * self.audio_margin \
                + [audio] + [padding_value] * n_padding_frames \
                + [audio[-1:]] * self.audio_margin, 
                dim=0
            )
        
        center_indices = torch.arange(
            self.audio_margin,
            audio.shape[0] - self.audio_margin
        ).unsqueeze(1) + self.indices
        audio_tensor = audio[center_indices]   # T, L, b, aD

        # add reference keypoints
        motion_lst = []
        init_latents = None
        # emotion label
        if emo is not None:
            emo = torch.Tensor([emo]).long().to(self.device)
        start_idx = 0
        for i in range(0, n_subdivision):
            print(f"Sample subclip {i+1}/{n_subdivision}")
            end_idx = start_idx + self.n_pred_frames
            audio_segment = audio_tensor[start_idx: end_idx].unsqueeze(0)
            start_idx += stride

            motion_segment = self.sample_subclip(
                audio             = audio_segment, 
                ref_kp            = ref_kp,
                prev_motion       = prev_motion,
                emo               = emo,
                cfg_scale         = cfg_scale,
                init_latents      = init_latents,
                dynamic_threshold = dynamic_threshold
            )
            # smooth the generated segment
            motion_segment = self._smooth(motion_segment)
            # update prev motion
            prev_motion = motion_segment[:, stride-1:stride].clone()

            # save results
            motion_coef = motion_segment
            if i == n_subdivision - 1 and n_padding_frames > 0:
                motion_coef = motion_coef[:, :-n_padding_frames]  # delete padded frames
            
            if len(motion_lst) > 0:
                # fuse segments
                motion_lst[-1] = self._fuse(motion_lst[-1], motion_coef)
                motion_lst.append(motion_coef[:, self.overlap_len:])
            else:
                motion_lst.append(motion_coef)
                
        motion = torch.cat(motion_lst, dim=1)
        # smooth for full clip
        motion = self._smooth(motion)
        motion = motion.squeeze()
        return motion.float()
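

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not part of this module).
# A minimal, hypothetical way to drive MotionDiffusion end to end. The config
# path and the shapes/contents of `audio`, `ref_kp` and `prev_motion` are
# assumptions -- adapt them to the project's actual config and feature pipeline.
#
#   from omegaconf import OmegaConf
#   config = OmegaConf.load("configs/motion_diffusion.yaml")   # hypothetical path
#   model = MotionDiffusion(config, device="cuda", dtype=torch.float32).eval()
#
#   # audio:       per-frame audio features for the whole clip (T frames)
#   # ref_kp:      reference keypoints of the source image, one row per batch item
#   # prev_motion: (1, 1, input_dim) motion context, e.g. zeros for a neutral start
#   motion = model.sample(audio, ref_kp, prev_motion, cfg_scale=1.15, emo=None)
#   # motion: (T, input_dim) float32 motion coefficients
# ---------------------------------------------------------------------------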