from typing import List, Tuple, Optional, Union, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from einops import rearrange
from einops.layers.torch import Rearrange
from diffusers.models import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config

from .nn.activation_layers import SwiGLU, get_activation_layer
from .nn.attn_layers import apply_rotary_emb, attention
from .nn.embed_layers import TimestepEmbedder, ConditionProjection, PatchEmbed1D
from .nn.mlp_layers import MLP, ConvMLP, FinalLayer1D, ChannelLastConv1d
from .nn.modulate_layers import ModulateDiT, ckpt_wrapper, apply_gate, modulate
from .nn.norm_layers import get_norm_layer
from .nn.posemb_layers import get_nd_rotary_pos_embed


def interleave_two_sequences(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    """Interleave two token sequences along the sequence axis.

    Args:
        x1: Tensor of shape [B, N1, H, C].
        x2: Tensor of shape [B, N2, H, C]. When ``N2 != N1`` it is resampled
            to length ``N1`` with nearest-exact interpolation first.

    Returns:
        Tensor of shape [B, 2 * N1, H, C] laid out as
        ``x1[0], x2[0], x1[1], x2[1], ...`` along dim 1.
    """
    # Validate rank before unpacking so a wrong rank trips this assert
    # instead of an unpacking error.
    assert x1.ndim == x2.ndim == 4
    B, N1, H, C = x1.shape
    B, N2, H, C = x2.shape
    if N1 != N2:
        # F.interpolate resamples the trailing axis, so fold (H, C) into a
        # channel axis and move the sequence axis last.
        x2 = x2.reshape(B, N2, -1).transpose(1, 2)
        x2 = F.interpolate(x2, size=N1, mode="nearest-exact")
        x2 = x2.transpose(1, 2).reshape(B, N1, H, C)
    x = torch.stack((x1, x2), dim=2)
    return x.reshape(B, N1 * 2, H, C)


def decouple_interleaved_two_sequences(
    x: torch.Tensor, len1: int, len2: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Split a sequence produced by ``interleave_two_sequences`` back in two.

    Args:
        x: Tensor of shape [B, 2 * len1, H, C].
        len1: Length of the first sequence; must equal half of ``x.shape[1]``.
        len2: Target length of the second sequence. When it differs from
            ``len1`` the second sequence is resampled with nearest-exact
            interpolation.

    Returns:
        ``(x1, x2)`` with shapes [B, len1, H, C] and [B, len2, H, C].
    """
    B, N, H, C = x.shape
    assert N % 2 == 0 and N // 2 == len1

    x = x.reshape(B, -1, 2, H, C)
    x1 = x[:, :, 0]
    x2 = x[:, :, 1]
    if x2.shape[1] != len2:
        x2 = x2.reshape(B, len1, H * C).transpose(1, 2)
        x2 = F.interpolate(x2, size=len2, mode="nearest-exact")
        x2 = x2.transpose(1, 2).reshape(B, len2, H, C)
    return x1, x2


class TwoStreamCABlock(nn.Module):
    """Audio / visual-condition block with text cross-attention.

    Each forward pass runs three stages on both the audio stream and the
    visual-condition stream:

    1. joint self-attention over the concatenated (v_cond, audio) tokens,
    2. cross-attention with the (v_cond, audio) tokens as queries and the
       text condition as keys/values,
    3. a per-stream MLP.

    Every stage is wrapped in shift/scale/gate modulation produced by
    ``audio_mod`` / ``v_cond_mod``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float,
        mlp_act_type: str = "gelu_tanh",
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        qkv_bias: bool = False,
        attn_mode: str = "torch",
        reverse: bool = False,
        interleaved_audio_visual_rope: bool = False,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.deterministic = False
        self.reverse = reverse
        self.attn_mode = attn_mode
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        head_dim = hidden_size // num_heads
        mlp_hidden_dim = int(hidden_size * mlp_ratio)

        self.interleaved_audio_visual_rope = interleaved_audio_visual_rope

        qk_norm_layer = get_norm_layer(qk_norm_type)

        # Small factories for the layer flavours this block repeats. The
        # attribute assignments below keep the original construction order.
        def make_modulation():
            return ModulateDiT(hidden_size, factor=9, act_layer=get_activation_layer("silu"), **factory_kwargs)

        def make_layer_norm():
            return nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)

        def make_linear(out_features):
            return nn.Linear(hidden_size, out_features, bias=qkv_bias, **factory_kwargs)

        def make_qk_norm():
            if qk_norm:
                return qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            return nn.Identity()

        def make_mlp():
            return MLP(
                hidden_size,
                mlp_hidden_dim,
                act_layer=get_activation_layer(mlp_act_type),
                bias=True,
                **factory_kwargs,
            )

        # Self attention for audio + visual
        self.audio_mod = make_modulation()
        self.audio_norm1 = make_layer_norm()
        self.audio_self_attn_qkv = make_linear(hidden_size * 3)
        self.audio_self_q_norm = make_qk_norm()
        self.audio_self_k_norm = make_qk_norm()
        self.audio_self_proj = make_linear(hidden_size)

        # visual cond
        self.v_cond_mod = make_modulation()
        self.v_cond_norm1 = make_layer_norm()
        self.v_cond_attn_qkv = make_linear(hidden_size * 3)
        self.v_cond_attn_q_norm = make_qk_norm()
        self.v_cond_attn_k_norm = make_qk_norm()
        self.v_cond_self_proj = make_linear(hidden_size)

        self.max_text_len = 100
        self.rope_dim_list = None

        # audio and video norm for cross attention with text
        self.audio_norm2 = make_layer_norm()
        self.v_cond_norm2 = make_layer_norm()

        # Cross attention: (video_audio) as query, text as key/value
        self.audio_cross_q = make_linear(hidden_size)
        self.v_cond_cross_q = make_linear(hidden_size)
        self.text_cross_kv = make_linear(hidden_size * 2)

        self.audio_cross_q_norm = make_qk_norm()
        self.v_cond_cross_q_norm = make_qk_norm()
        self.text_cross_k_norm = make_qk_norm()

        self.audio_cross_proj = make_linear(hidden_size)
        self.v_cond_cross_proj = make_linear(hidden_size)

        # MLPs
        self.audio_norm3 = make_layer_norm()
        self.audio_mlp = make_mlp()

        self.v_cond_norm3 = make_layer_norm()
        self.v_cond_mlp = make_mlp()

    def build_rope_for_text(self, text_len, head_dim, rope_dim_list=None):
        """Build 1-D rotary embedding tables for a sequence of ``text_len`` tokens.

        Despite the name, ``forward`` also uses this for the audio and
        visual-condition queries of the cross-attention.

        Returns:
            ``(freqs_cos, freqs_sin)`` as returned by ``get_nd_rotary_pos_embed``.
        """
        target_ndim = 1  # n-d RoPE
        rope_sizes = [text_len]

        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"

        text_freqs_cos, text_freqs_sin = get_nd_rotary_pos_embed(
            rope_dim_list=rope_dim_list,
            start=rope_sizes,
            theta=10000,
            use_real=True,
            theta_rescale_factor=1.0,
        )
        return text_freqs_cos, text_freqs_sin

    def _cross_attn_rope(self, tokens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (cos, sin) RoPE tables sized to ``tokens.shape[1]`` on ``tokens``' device."""
        head_dim = self.hidden_size // self.num_heads
        freqs_cos, freqs_sin = self.build_rope_for_text(
            tokens.shape[1], head_dim, rope_dim_list=self.rope_dim_list
        )
        return freqs_cos.to(tokens.device), freqs_sin.to(tokens.device)

    def set_attn_mode(self, new_mode):
        """Select the attention backend; only ``"torch"`` is accepted.

        Raises:
            NotImplementedError: If ``new_mode`` is anything other than ``"torch"``.
        """
        if new_mode != "torch":
            raise NotImplementedError(f"Only support 'torch' mode, got {new_mode}.")
        self.attn_mode = new_mode

    def enable_deterministic(self):
        """Set the ``deterministic`` flag that is forwarded to ``attention``."""
        self.deterministic = True

    def disable_deterministic(self):
        """Clear the ``deterministic`` flag that is forwarded to ``attention``."""
        self.deterministic = False

    def forward(
        self,
        audio: torch.Tensor,
        cond: torch.Tensor,
        v_cond: torch.Tensor,
        attn_mask: torch.Tensor,
        vec: torch.Tensor,
        freqs_cis: tuple = None,
        v_freqs_cis: tuple = None,
        sync_vec: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run self-attention, text cross-attention and the MLPs.

        Args:
            audio: Audio tokens, [B, L_audio, hidden_size].
            cond: Text condition tokens; used as key/value source and returned unchanged.
            v_cond: Visual-condition tokens, [B, L_v, hidden_size].
            attn_mask: Mask forwarded to the self-attention call (may be None).
            vec: Modulation input for the visual stream, and for the audio
                stream when ``sync_vec`` is None.
            freqs_cis: Optional (cos, sin) RoPE tables for audio tokens, or for
                the interleaved audio/visual sequence when
                ``interleaved_audio_visual_rope`` is set.
            v_freqs_cis: Optional (cos, sin) RoPE tables for visual tokens;
                ignored in interleaved mode.
            sync_vec: Optional 3-D modulation input that replaces ``vec`` for
                the audio stream.

        Returns:
            ``(audio, cond, v_cond)`` with audio and v_cond updated.
        """
        # Get modulation parameters
        if sync_vec is not None:
            assert sync_vec.ndim == 3
            audio_mod_input = sync_vec
        else:
            audio_mod_input = vec
        (
            audio_mod1_shift,
            audio_mod1_scale,
            audio_mod1_gate,
            audio_mod2_shift,
            audio_mod2_scale,
            audio_mod2_gate,
            audio_mod3_shift,
            audio_mod3_scale,
            audio_mod3_gate,
        ) = self.audio_mod(audio_mod_input).chunk(9, dim=-1)

        (
            v_cond_mod1_shift,
            v_cond_mod1_scale,
            v_cond_mod1_gate,
            v_cond_mod2_shift,
            v_cond_mod2_scale,
            v_cond_mod2_gate,
            v_cond_mod3_shift,
            v_cond_mod3_scale,
            v_cond_mod3_gate,
        ) = self.v_cond_mod(vec).chunk(9, dim=-1)

        # 1. Self Attention for audio + visual
        audio_modulated = self.audio_norm1(audio)
        audio_modulated = modulate(audio_modulated, shift=audio_mod1_shift, scale=audio_mod1_scale)
        audio_qkv = self.audio_self_attn_qkv(audio_modulated)
        audio_q, audio_k, audio_v = rearrange(audio_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
        audio_q = self.audio_self_q_norm(audio_q).to(audio_v)
        audio_k = self.audio_self_k_norm(audio_k).to(audio_v)

        # Prepare visual cond for attention
        v_cond_modulated = self.v_cond_norm1(v_cond)
        v_cond_modulated = modulate(v_cond_modulated, shift=v_cond_mod1_shift, scale=v_cond_mod1_scale)
        v_cond_qkv = self.v_cond_attn_qkv(v_cond_modulated)
        v_cond_q, v_cond_k, v_cond_v = rearrange(v_cond_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
        v_cond_q = self.v_cond_attn_q_norm(v_cond_q).to(v_cond_v)
        v_cond_k = self.v_cond_attn_k_norm(v_cond_k).to(v_cond_v)

        # Apply RoPE if needed for audio and visual
        if freqs_cis is not None:
            if not self.interleaved_audio_visual_rope:
                audio_q, audio_k = apply_rotary_emb(audio_q, audio_k, freqs_cis, head_first=False)
            else:
                # Interleave audio and visual tokens so both receive positions
                # from one shared table, then split them apart again.
                ori_audio_len = audio_q.shape[1]
                ori_v_con_len = v_cond_q.shape[1]
                interleaved_q = interleave_two_sequences(audio_q, v_cond_q)
                interleaved_k = interleave_two_sequences(audio_k, v_cond_k)
                interleaved_q, interleaved_k = apply_rotary_emb(
                    interleaved_q, interleaved_k, freqs_cis, head_first=False
                )
                audio_q, v_cond_q = decouple_interleaved_two_sequences(
                    interleaved_q, ori_audio_len, ori_v_con_len
                )
                audio_k, v_cond_k = decouple_interleaved_two_sequences(
                    interleaved_k, ori_audio_len, ori_v_con_len
                )

        # Apply RoPE to visual if needed and not interleaved
        if v_freqs_cis is not None and not self.interleaved_audio_visual_rope:
            v_cond_q, v_cond_k = apply_rotary_emb(v_cond_q, v_cond_k, v_freqs_cis, head_first=False)

        # Concatenate for self-attention
        q = torch.cat((v_cond_q, audio_q), dim=1)
        k = torch.cat((v_cond_k, audio_k), dim=1)
        v = torch.cat((v_cond_v, audio_v), dim=1)

        # Run self-attention
        attn = attention(q, k, v, mode=self.attn_mode, attn_mask=attn_mask, deterministic=self.deterministic)
        v_cond_attn, audio_attn = torch.split(attn, [v_cond.shape[1], audio.shape[1]], dim=1)

        # Apply self-attention output to audio and v_cond
        audio = audio + apply_gate(self.audio_self_proj(audio_attn), gate=audio_mod1_gate)
        v_cond = v_cond + apply_gate(self.v_cond_self_proj(v_cond_attn), gate=v_cond_mod1_gate)

        # 2. Cross Attention: (v_cond, audio) as query, text as key/value
        # audio, v_cond modulation
        audio_modulated = self.audio_norm2(audio)
        audio_modulated = modulate(audio_modulated, shift=audio_mod2_shift, scale=audio_mod2_scale)
        v_cond_modulated = self.v_cond_norm2(v_cond)
        v_cond_modulated = modulate(v_cond_modulated, shift=v_cond_mod2_shift, scale=v_cond_mod2_scale)

        # Prepare audio query
        audio_q = self.audio_cross_q(audio_modulated)
        audio_q = rearrange(audio_q, "B L (H D) -> B L H D", H=self.num_heads)
        audio_q = self.audio_cross_q_norm(audio_q)

        # Prepare v_cond query
        v_cond_q = self.v_cond_cross_q(v_cond_modulated)
        v_cond_q = rearrange(v_cond_q, "B L (H D) -> B L H D", H=self.num_heads)
        v_cond_q = self.v_cond_cross_q_norm(v_cond_q)

        # Prepare text key/value
        text_kv = self.text_cross_kv(cond)
        text_k, text_v = rearrange(text_kv, "B L (K H D) -> K B L H D", K=2, H=self.num_heads)
        text_k = self.text_cross_k_norm(text_k).to(text_v)

        # Apply RoPE to the (v_cond, audio) queries and the text key. Each
        # stream is rotated on its own with a table sized to its own length;
        # apply_rotary_emb returns a (q, k) pair, so the same tensor is passed
        # twice and only the wanted element is kept.
        audio_q = apply_rotary_emb(audio_q, audio_q, self._cross_attn_rope(audio_q), head_first=False)[0]
        v_cond_q = apply_rotary_emb(v_cond_q, v_cond_q, self._cross_attn_rope(v_cond_q), head_first=False)[0]
        text_k = apply_rotary_emb(text_k, text_k, self._cross_attn_rope(text_k), head_first=False)[1]

        # Concat v_cond and audio for cross-attention
        v_cond_audio_q = torch.cat([v_cond_q, audio_q], dim=1)

        # Run cross-attention
        cross_attn = attention(
            v_cond_audio_q, text_k, text_v, mode=self.attn_mode, deterministic=self.deterministic
        )
        v_cond_cross_attn, audio_cross_attn = torch.split(cross_attn, [v_cond.shape[1], audio.shape[1]], dim=1)

        # Apply cross-attention output
        audio = audio + apply_gate(self.audio_cross_proj(audio_cross_attn), gate=audio_mod2_gate)
        v_cond = v_cond + apply_gate(self.v_cond_cross_proj(v_cond_cross_attn), gate=v_cond_mod2_gate)

        # 3. Apply MLPs
        audio = audio + apply_gate(
            self.audio_mlp(modulate(self.audio_norm3(audio), shift=audio_mod3_shift, scale=audio_mod3_scale)),
            gate=audio_mod3_gate,
        )

        # Apply visual MLP
        v_cond = v_cond + apply_gate(
            self.v_cond_mlp(modulate(self.v_cond_norm3(v_cond), shift=v_cond_mod3_shift, scale=v_cond_mod3_scale)),
            gate=v_cond_mod3_gate,
        )

        return audio, cond, v_cond


class SingleStreamBlock(nn.Module):
    """Single-stream block: modulated self-attention followed by a conv MLP."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float,
        qk_norm_type: str = "rms",
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        # Plain (non-parameter) attributes kept only so the model-level
        # set_attn_mode / enable_deterministic / disable_deterministic helpers
        # can treat every block uniformly. forward() always calls
        # F.scaled_dot_product_attention and does not read them.
        self.attn_mode = "torch"
        self.deterministic = False

        self.hidden_size = hidden_size
        self.num_heads = num_heads

        self.modulation = ModulateDiT(
            hidden_size=hidden_size,
            factor=6,
            act_layer=get_activation_layer("silu"),
            **factory_kwargs,
        )
        self.linear_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True)
        self.linear1 = ChannelLastConv1d(hidden_size, hidden_size, kernel_size=3, padding=1, **factory_kwargs)
        self.linear2 = ConvMLP(hidden_size, hidden_size * mlp_ratio, kernel_size=3, padding=1, **factory_kwargs)
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.q_norm = nn.RMSNorm(hidden_size // num_heads)
        self.k_norm = nn.RMSNorm(hidden_size // num_heads)
        self.rearrange = Rearrange("B L (H D K) -> B H L D K", K=3, H=num_heads)

    def set_attn_mode(self, new_mode):
        """Record the attention mode; only ``"torch"`` is accepted.

        Raises:
            NotImplementedError: If ``new_mode`` is anything other than ``"torch"``.
        """
        if new_mode != "torch":
            raise NotImplementedError(f"Only support 'torch' mode, got {new_mode}.")
        self.attn_mode = new_mode

    def enable_deterministic(self):
        """Set the ``deterministic`` flag (not consulted by ``forward``)."""
        self.deterministic = True

    def disable_deterministic(self):
        """Clear the ``deterministic`` flag (not consulted by ``forward``)."""
        self.deterministic = False

    def forward(self, x: torch.Tensor, cond: torch.Tensor, freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None):
        """Apply modulated self-attention and the conv MLP to ``x``.

        Args:
            x: Input tokens, [B, L, hidden_size].
            cond: Modulation input of shape [B, T, D].
            freqs_cis: (cos, sin) RoPE tables passed to ``apply_rotary_emb``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        assert cond.ndim == 3, "Condition should be in shape of [B, T, D]"
        modulation = self.modulation(cond)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation.chunk(6, dim=-1)
        x_norm1 = self.norm1(x) * (1 + scale_msa) + shift_msa

        qkv = self.linear_qkv(x_norm1)
        # The rearranged tensor is [B, H, L, D, 3]; unbinding the last axis
        # yields q, k, v, each [B, H, L, D].
        q, k, v = self.rearrange(qkv).unbind(dim=-1)

        q = self.q_norm(q)
        k = self.k_norm(k)
        q, k = apply_rotary_emb(q, k, freqs_cis, head_first=True)

        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        out = F.scaled_dot_product_attention(q, k, v)
        out = rearrange(out, 'b h n d -> b n (h d)').contiguous()

        x = x + apply_gate(self.linear1(out), gate=gate_msa)
        x_norm = self.norm2(x) * (1 + scale_mlp) + shift_mlp
        x = x + apply_gate(self.linear2(x_norm), gate=gate_mlp)

        return x


class HunyuanVideoFoley(ModelMixin, ConfigMixin):
    """Diffusion transformer over audio latents conditioned on text, clip and sync features.

    The network is a stack of ``TwoStreamCABlock`` ("triple" blocks) followed
    by a stack of ``SingleStreamBlock`` and a final projection layer.
    """

    @register_to_config
    def __init__(
        self,
        model_config,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        model_args = model_config.model_config.model_kwargs
        self.depth_triple_blocks = model_args.get("depth_triple_blocks", 19)
        self.depth_single_blocks = model_args.get("depth_single_blocks", 38)

        # Gradient checkpoint.
        self.gradient_checkpoint = False
        self.gradient_checkpoint_layers = None
        # Unreachable while gradient_checkpoint is hard-coded to False above;
        # kept so the constraint stays documented if the flag becomes configurable.
        if self.gradient_checkpoint:
            assert self.gradient_checkpoint_layers <= self.depth_triple_blocks + self.depth_single_blocks, (
                f"Gradient checkpoint layers must be less or equal than the depth of the model. "
                f"Got gradient_checkpoint_layers={self.gradient_checkpoint_layers} and depth={self.depth_triple_blocks + self.depth_single_blocks}."
            )

        self.interleaved_audio_visual_rope = model_args.get("interleaved_audio_visual_rope", False)

        # Condition projection. Default to linear projection.
        self.condition_projection = model_args.get("condition_projection", "linear")
        self.condition_dim = model_args.get("condition_dim", None)
        self.use_attention_mask = model_args.get("use_attention_mask", False)

        self.patch_size = model_args.get("patch_size", 1)
        self.visual_in_channels = model_args.get("clip_dim", 768)
        self.audio_vae_latent_dim = model_args.get("audio_vae_latent_dim", 128)
        self.out_channels = self.audio_vae_latent_dim
        self.unpatchify_channels = self.out_channels
        self.reverse = model_args.get("reverse", False)

        self.num_heads = model_args.get("num_heads", 24)
        self.hidden_size = model_args.get("hidden_size", 3072)
        self.rope_dim_list = model_args.get("rope_dim_list", None)
        self.mlp_ratio = model_args.get("mlp_ratio", 4.0)
        self.mlp_act_type = model_args.get("mlp_act_type", "gelu_tanh")

        self.qkv_bias = model_args.get("qkv_bias", True)
        self.qk_norm = model_args.get("qk_norm", True)
        self.qk_norm_type = model_args.get("qk_norm_type", "rms")
        self.attn_mode = model_args.get("attn_mode", "torch")

        self.embedder_type = model_args.get("embedder_type", "default")

        # sync condition things
        self.sync_modulation = model_args.get("sync_modulation", False)
        self.add_sync_feat_to_audio = model_args.get("add_sync_feat_to_audio", False)
        self.sync_feat_dim = model_args.get("sync_feat_dim", 768)
        self.sync_in_ksz = model_args.get("sync_in_ksz", 1)

        # condition tokens length
        self.clip_len = model_args.get("clip_length", 64)
        self.sync_len = model_args.get("sync_length", 192)

        if self.hidden_size % self.num_heads != 0:
            raise ValueError(f"Hidden size {self.hidden_size} must be divisible by num_heads {self.num_heads}")

        # Build audio patchify layer and visual gated linear projection.
        # patch_size is forced to 1 here regardless of the configured value.
        self.patch_size = 1
        self.audio_embedder = PatchEmbed1D(
            self.patch_size, self.audio_vae_latent_dim, self.hidden_size, **factory_kwargs
        )
        self.visual_proj = SwiGLU(self.visual_in_channels, hidden_dim=self.hidden_size, out_dim=self.hidden_size)

        # condition
        if self.condition_projection == "linear":
            self.cond_in = ConditionProjection(
                self.condition_dim, self.hidden_size, get_activation_layer("silu"), **factory_kwargs
            )
        else:
            raise NotImplementedError(f"Unsupported condition_projection: {self.condition_projection}")

        # time modulation
        self.time_in = TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs)

        # visual sync embedder if needed
        if self.sync_in_ksz == 1:
            sync_in_padding = 0
        elif self.sync_in_ksz == 3:
            sync_in_padding = 1
        else:
            raise ValueError(f"sync_in_ksz must be 1 or 3, got {self.sync_in_ksz}")
        if self.sync_modulation or self.add_sync_feat_to_audio:
            self.sync_in = nn.Sequential(
                nn.Linear(self.sync_feat_dim, self.hidden_size),
                nn.SiLU(),
                ConvMLP(self.hidden_size, self.hidden_size * 4, kernel_size=self.sync_in_ksz, padding=sync_in_padding),
            )
            self.sync_pos_emb = nn.Parameter(torch.zeros((1, 1, 8, self.sync_feat_dim)))

        self.triple_blocks = nn.ModuleList(
            [
                TwoStreamCABlock(
                    hidden_size=self.hidden_size,
                    num_heads=self.num_heads,
                    mlp_ratio=self.mlp_ratio,
                    mlp_act_type=self.mlp_act_type,
                    qk_norm=self.qk_norm,
                    qk_norm_type=self.qk_norm_type,
                    qkv_bias=self.qkv_bias,
                    attn_mode=self.attn_mode,
                    reverse=self.reverse,
                    interleaved_audio_visual_rope=self.interleaved_audio_visual_rope,
                    **factory_kwargs,
                )
                for _ in range(self.depth_triple_blocks)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(
                    hidden_size=self.hidden_size,
                    num_heads=self.num_heads,
                    mlp_ratio=self.mlp_ratio,
                    qk_norm_type=self.qk_norm_type,
                    **factory_kwargs,
                )
                for _ in range(self.depth_single_blocks)
            ]
        )

        self.final_layer = FinalLayer1D(
            self.hidden_size, self.patch_size, self.out_channels, get_activation_layer("silu"), **factory_kwargs
        )
        self.unpatchify_channels = self.out_channels

        # Learnable placeholders substituted for dropped visual conditions.
        self.empty_clip_feat = nn.Parameter(torch.zeros(1, self.visual_in_channels), requires_grad=True)
        self.empty_sync_feat = nn.Parameter(torch.zeros(1, self.sync_feat_dim), requires_grad=True)
        nn.init.constant_(self.empty_clip_feat, 0)
        nn.init.constant_(self.empty_sync_feat, 0)

    def get_empty_string_sequence(self, bs=None) -> torch.Tensor:
        """Return ``empty_string_feat``, optionally expanded to a batch of ``bs``.

        NOTE(review): ``empty_string_feat`` is not assigned anywhere in this
        file, so this raises AttributeError unless it is attached elsewhere -
        confirm whether the method is still needed.
        """
        if bs is None:
            return self.empty_string_feat
        return self.empty_string_feat.unsqueeze(0).expand(bs, -1, -1)

    def get_empty_clip_sequence(self, bs=None, len=None) -> torch.Tensor:
        """Return the empty clip feature expanded to ``len`` tokens (default ``clip_len``).

        Shape is [len, C] when ``bs`` is None, otherwise [bs, len, C].
        """
        seq_len = self.clip_len if len is None else len
        if bs is None:
            return self.empty_clip_feat.expand(seq_len, -1)
        return self.empty_clip_feat.unsqueeze(0).expand(bs, seq_len, -1)

    def get_empty_sync_sequence(self, bs=None, len=None) -> torch.Tensor:
        """Return the empty sync feature expanded to ``len`` tokens (default ``sync_len``).

        Shape is [len, C] when ``bs`` is None, otherwise [bs, len, C].
        """
        seq_len = self.sync_len if len is None else len
        if bs is None:
            return self.empty_sync_feat.expand(seq_len, -1)
        return self.empty_sync_feat.unsqueeze(0).expand(bs, seq_len, -1)

    def _build_1d_rope(self, seq_len, freq_scaling=None):
        """Build 1-D RoPE (cos, sin) tables for ``seq_len`` positions.

        ``freq_scaling`` is forwarded to ``get_nd_rotary_pos_embed`` only when
        given, so the default call matches the un-scaled variant exactly.
        """
        target_ndim = 1  # n-d RoPE
        head_dim = self.hidden_size // self.num_heads
        rope_dim_list = self.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"

        extra_kwargs = {} if freq_scaling is None else {"freq_scaling": freq_scaling}
        return get_nd_rotary_pos_embed(
            rope_dim_list=rope_dim_list,
            start=[seq_len],
            theta=10000,
            use_real=True,
            theta_rescale_factor=1.0,
            **extra_kwargs,
        )

    def build_rope_for_audio_visual(self, audio_emb_len, visual_cond_len):
        """Build separate RoPE tables for the audio and the clip tokens.

        The clip table is built with ``freq_scaling = audio_emb_len / visual_cond_len``.

        Returns:
            ``(freqs_cos, freqs_sin, v_freqs_cos, v_freqs_sin)``.
        """
        assert self.patch_size == 1

        # Build RoPE for audio tokens
        freqs_cos, freqs_sin = self._build_1d_rope(audio_emb_len)

        # Build RoPE for clip tokens
        v_freqs_cos, v_freqs_sin = self._build_1d_rope(
            visual_cond_len,
            freq_scaling=1.0 * audio_emb_len / visual_cond_len,
        )
        return freqs_cos, freqs_sin, v_freqs_cos, v_freqs_sin

    def build_rope_for_interleaved_audio_visual(self, total_len):
        """Build one RoPE table covering ``total_len`` interleaved audio/visual tokens.

        Returns:
            ``(freqs_cos, freqs_sin)``.
        """
        assert self.patch_size == 1

        freqs_cos, freqs_sin = self._build_1d_rope(total_len)
        return freqs_cos, freqs_sin

    def set_attn_mode(self, new_mode):
        """Propagate ``new_mode`` to every triple and single block."""
        for block in self.triple_blocks:
            block.set_attn_mode(new_mode)
        for block in self.single_blocks:
            block.set_attn_mode(new_mode)

    def enable_deterministic(self):
        """Set the deterministic flag on every triple and single block."""
        for block in self.triple_blocks:
            block.enable_deterministic()
        for block in self.single_blocks:
            block.enable_deterministic()

    def disable_deterministic(self):
        """Clear the deterministic flag on every triple and single block."""
        for block in self.triple_blocks:
            block.disable_deterministic()
        for block in self.single_blocks:
            block.disable_deterministic()

    def _use_checkpoint(self, global_layer_idx: int) -> bool:
        """Return True when the layer at ``global_layer_idx`` should be checkpointed."""
        return bool(
            self.training
            and self.gradient_checkpoint
            and (
                self.gradient_checkpoint_layers == -1
                or global_layer_idx < self.gradient_checkpoint_layers
            )
        )

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,  # Should be in range(0, 1000).
        clip_feat: Optional[torch.Tensor] = None,
        cond: torch.Tensor = None,
        audio_mask: Optional[torch.Tensor] = None,
        cond_mask: torch.Tensor = None,
        sync_feat: Optional[torch.Tensor] = None,
        drop_visual: Optional[List[bool]] = None,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """Run the model on a batch of audio latents.

        Args:
            x: Audio latents; unpacked as (batch, channels, length).
            t: Timesteps fed to the timestep embedder.
            clip_feat: Visual clip features projected by ``visual_proj``.
            cond: Text condition features projected by ``cond_in``.
            audio_mask: Accepted for interface compatibility; a full mask is
                always rebuilt internally when attention masking is enabled.
            cond_mask: Text padding mask; required when ``use_attention_mask`` is set.
            sync_feat: Sync features; required when ``sync_modulation`` or
                ``add_sync_feat_to_audio`` is enabled. Its length must be a
                multiple of 8.
            drop_visual: Optional per-sample flags; flagged samples get the
                learnable empty clip/sync features instead of the given ones.
            return_dict: When True return ``{"x": audio}``, else the tensor.

        Returns:
            The predicted tensor of shape (batch, out_channels, length), either
            bare or wrapped in a dict under key ``"x"``.

        Raises:
            ValueError: If both ``sync_modulation`` and ``add_sync_feat_to_audio``
                are enabled, a combination the computation below cannot serve.
        """
        if self.sync_modulation and self.add_sync_feat_to_audio:
            # The two modes are computed in mutually exclusive branches below;
            # with both enabled the additive sync feature would never be built.
            raise ValueError(
                "sync_modulation and add_sync_feat_to_audio cannot both be enabled."
            )

        out = {}
        bs, _, ol = x.shape
        tl = ol // self.patch_size

        # Prepare learnable empty conditions for visual condition.
        # Work on copies so the caller's tensors are not overwritten.
        if drop_visual is not None:
            clip_feat = clip_feat.clone()
            clip_feat[drop_visual] = self.get_empty_clip_sequence().to(dtype=clip_feat.dtype)
            if sync_feat is not None:
                sync_feat = sync_feat.clone()
                sync_feat[drop_visual] = self.get_empty_sync_sequence().to(dtype=sync_feat.dtype)

        # ========================= Prepare time & visual modulation =========================
        vec = self.time_in(t)
        sync_vec = None
        sync_feat_for_audio = None
        if self.sync_modulation:
            assert sync_feat is not None and sync_feat.shape[1] % 8 == 0
            sync_feat = sync_feat.view(bs, sync_feat.shape[1] // 8, 8, self.sync_feat_dim) + self.sync_pos_emb
            sync_feat = sync_feat.view(bs, -1, self.sync_feat_dim)  # bs, num_segments * 8, channels
            sync_vec = self.sync_in(sync_feat)  # bs, num_segments * 8, c
            sync_vec = (
                F.interpolate(sync_vec.transpose(1, 2), size=tl, mode="nearest-exact").contiguous().transpose(1, 2)
            )  # bs, tl, c
            sync_vec = sync_vec + vec.unsqueeze(1)
        elif self.add_sync_feat_to_audio:
            assert sync_feat is not None and sync_feat.shape[1] % 8 == 0
            sync_feat = sync_feat.view(bs, sync_feat.shape[1] // 8, 8, self.sync_feat_dim) + self.sync_pos_emb
            sync_feat = sync_feat.view(bs, -1, self.sync_feat_dim)  # bs, num_segments * 8, channels
            sync_feat = self.sync_in(sync_feat)  # bs, num_segments * 8, c
            sync_feat_for_audio = (
                F.interpolate(sync_feat.transpose(1, 2), size=tl, mode="nearest-exact").contiguous().transpose(1, 2)
            )  # bs, tl, c

        # ========================= Get text, audio and video clip embedding =========================
        cond = self.cond_in(cond)
        cond_seq_len = cond.shape[1]

        audio = self.audio_embedder(x)
        audio_seq_len = audio.shape[1]
        v_cond = self.visual_proj(clip_feat)
        v_cond_seq_len = v_cond.shape[1]

        # ========================= Compute attention mask =========================
        attn_mask = None
        if self.use_attention_mask:
            assert cond_mask is not None
            batch_size = audio.shape[0]
            # NOTE(review): the mask spans cond + v_cond + audio tokens, while the
            # blocks' self-attention concatenates only v_cond and audio - confirm
            # the expected mask length.
            seq_len = cond_seq_len + v_cond_seq_len + audio_seq_len

            # get default audio_mask and v_cond_mask
            audio_mask = torch.ones((batch_size, audio_seq_len), dtype=torch.bool, device=audio.device)
            v_cond_mask = torch.ones((batch_size, v_cond_seq_len), dtype=torch.bool, device=audio.device)

            # batch_size x seq_len
            concat_mask = torch.cat([cond_mask, v_cond_mask, audio_mask], dim=1)
            # batch_size x 1 x seq_len x seq_len
            attn_mask_1 = concat_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
            # batch_size x 1 x seq_len x seq_len
            attn_mask_2 = attn_mask_1.transpose(2, 3)
            # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of num_heads
            attn_mask = (attn_mask_1 & attn_mask_2).bool()
            # avoids self-attention weight being NaN for text padding tokens
            attn_mask[:, :, :, 0] = True

        # ========================= Build rope for audio and clip tokens =========================
        if self.interleaved_audio_visual_rope:
            freqs_cos, freqs_sin = self.build_rope_for_interleaved_audio_visual(audio_seq_len * 2)
            v_freqs_cos = v_freqs_sin = None
        else:
            freqs_cos, freqs_sin, v_freqs_cos, v_freqs_sin = self.build_rope_for_audio_visual(
                audio_seq_len, v_cond_seq_len
            )

        # ========================= Pass through DiT blocks =========================
        freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
        v_freqs_cis = (v_freqs_cos, v_freqs_sin) if v_freqs_cos is not None else None

        if self.add_sync_feat_to_audio:
            add_sync_layer = 0
            assert (
                add_sync_layer < self.depth_triple_blocks
            ), f"The layer to add mel_spectrogram feature and sync feature should in the triple_stream_blocks (n: {self.depth_triple_blocks})."

        # Triple-stream blocks
        for layer_num, block in enumerate(self.triple_blocks):
            if self.add_sync_feat_to_audio and layer_num == add_sync_layer:
                audio = audio + sync_feat_for_audio
            triple_block_args = [audio, cond, v_cond, attn_mask, vec, freqs_cis, v_freqs_cis, sync_vec]
            if self._use_checkpoint(layer_num):
                audio, cond, v_cond = torch.utils.checkpoint.checkpoint(
                    ckpt_wrapper(block), *triple_block_args, use_reentrant=False
                )
            else:
                audio, cond, v_cond = block(*triple_block_args)

        x = audio
        if sync_vec is not None:
            # NOTE(review): vec now covers cond + v_cond + audio positions while x
            # holds only the audio tokens - confirm the single-stream blocks
            # expect this.
            vec = vec.unsqueeze(1).repeat(1, cond_seq_len + v_cond_seq_len, 1)
            vec = torch.cat((vec, sync_vec), dim=1)

        # Single-stream blocks always use the audio-length table, including in
        # interleaved mode where freqs_cos/freqs_sin above were built for
        # twice the audio length.
        freqs_cos, freqs_sin, _, _ = self.build_rope_for_audio_visual(audio_seq_len, v_cond_seq_len)
        if self.add_sync_feat_to_audio:
            vec = sync_feat_for_audio + vec.unsqueeze(dim=1)

        num_triple_blocks = len(self.triple_blocks)
        for layer_num, block in enumerate(self.single_blocks):
            single_block_args = [
                x,
                vec,
                (freqs_cos, freqs_sin),
            ]
            if self._use_checkpoint(layer_num + num_triple_blocks):
                x = torch.utils.checkpoint.checkpoint(ckpt_wrapper(block), *single_block_args, use_reentrant=False)
            else:
                x = block(*single_block_args)

        audio = x

        # ========================= Final layer =========================
        if sync_vec is not None:
            vec = sync_vec
        audio = self.final_layer(audio, vec)  # (N, T, patch_size * out_channels)
        audio = self.unpatchify1d(audio, tl)

        if return_dict:
            out["x"] = audio
            return out
        return audio

    def unpatchify1d(self, x, l):
        """Fold patch tokens back into a channel-first sequence.

        Args:
            x: Tensor of shape (N, L, patch_size * C).
            l: Expected number of tokens; must equal ``x.shape[1]``.

        Returns:
            Tensor of shape (N, C, L * patch_size).
        """
        c = self.unpatchify_channels
        p = self.patch_size
        assert l == x.shape[1]

        x = x.reshape(shape=(x.shape[0], l, p, c))
        x = torch.einsum("ntpc->nctp", x)
        audio = x.reshape(shape=(x.shape[0], c, l * p))
        return audio

    def params_count(self):
        """Count parameters of selected attention/MLP layers and of the whole model.

        Returns:
            Dict with keys ``"triple"``, ``"single"``, ``"total"`` and
            ``"attn+mlp"`` (the sum of ``"triple"`` and ``"single"``).
        """

        def numel(module):
            return sum(p.numel() for p in module.parameters())

        triple_layer_names = (
            "audio_cross_q",
            "v_cond_cross_q",
            "text_cross_kv",
            "audio_self_attn_qkv",
            "v_cond_attn_qkv",
            "audio_mlp",
            "audio_self_proj",
            "v_cond_self_proj",
            "v_cond_mlp",
        )
        counts = {
            "triple": sum(
                numel(getattr(block, name))
                for block in self.triple_blocks
                for name in triple_layer_names
            ),
            "single": sum(
                numel(block.linear1) + numel(block.linear2)
                for block in self.single_blocks
            ),
            "total": sum(p.numel() for p in self.parameters()),
        }
        counts["attn+mlp"] = counts["triple"] + counts["single"]
        return counts