from typing import List, Tuple, Optional, Union, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
from diffusers.models import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config

from .nn.activation_layers import SwiGLU, get_activation_layer
from .nn.attn_layers import apply_rotary_emb, attention
from .nn.embed_layers import TimestepEmbedder, ConditionProjection, PatchEmbed1D
from .nn.mlp_layers import MLP, ConvMLP, FinalLayer1D, ChannelLastConv1d
from .nn.modulate_layers import ModulateDiT, ckpt_wrapper, apply_gate, modulate
from .nn.norm_layers import get_norm_layer
from .nn.posemb_layers import get_nd_rotary_pos_embed


def interleave_two_sequences(x1: torch.Tensor, x2: torch.Tensor):
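    """Interleave two [B, N, H, C] sequences along the token dimension.

    If the lengths differ, ``x2`` is resampled to ``x1``'s length with
    nearest-exact interpolation first, so the output is always [B, 2 * N1, H, C].

    Illustrative shapes (not from the original source):
        x1: [2, 64, 8, 96], x2: [2, 48, 8, 96] -> out: [2, 128, 8, 96]
    """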
    assert x1.ndim == x2.ndim == 4
    B, N1, H, C = x1.shape
    _, N2, _, _ = x2.shape

    if N1 != N2:
        # Resample x2 along the token dimension to match x1's length.
        x2 = x2.view(B, N2, -1).transpose(1, 2)
        x2 = F.interpolate(x2, size=N1, mode="nearest-exact")
        x2 = x2.transpose(1, 2).view(B, N1, H, C)
    x = torch.stack((x1, x2), dim=2)
    x = x.reshape(B, N1 * 2, H, C)
    return x


def decouple_interleaved_two_sequences(x: torch.Tensor, len1: int, len2: int):
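    """Invert ``interleave_two_sequences``: split [B, 2 * len1, H, C] back into
    a [B, len1, H, C] sequence and a [B, len2, H, C] sequence, resampling the
    second back to ``len2`` if it was stretched during interleaving.
    """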
|
B, N, H, C = x.shape |
|
assert N % 2 == 0 and N // 2 == len1 |
|
|
|
x = x.reshape(B, -1, 2, H, C) |
|
x1 = x[:, :, 0] |
|
x2 = x[:, :, 1] |
|
if x2.shape[1] != len2: |
|
x2 = x2.view(B, len1, H * C).transpose(1, 2) |
|
x2 = F.interpolate(x2, size=(len2), mode="nearest-exact") |
|
x2 = x2.transpose(1, 2).view(B, len2, H, C) |
|
return x1, x2 |


class TwoStreamCABlock(nn.Module):
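    """Dual-stream DiT block with text cross-attention.

    Audio and visual-condition tokens run in parallel streams that (1) share a
    joint self-attention, (2) cross-attend to the text condition, and (3) pass
    through per-stream MLPs, each stage gated by adaLN modulation.
    """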

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float,
        mlp_act_type: str = "gelu_tanh",
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        qkv_bias: bool = False,
        attn_mode: str = "torch",
        reverse: bool = False,
        interleaved_audio_visual_rope: bool = False,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.deterministic = False
        self.reverse = reverse
        self.attn_mode = attn_mode
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        head_dim = hidden_size // num_heads
        mlp_hidden_dim = int(hidden_size * mlp_ratio)

        self.interleaved_audio_visual_rope = interleaved_audio_visual_rope

        # Audio stream: adaLN modulation (factor=9: shift/scale/gate for three
        # stages) and self-attention QKV.
        self.audio_mod = ModulateDiT(hidden_size, factor=9, act_layer=get_activation_layer("silu"), **factory_kwargs)
        self.audio_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.audio_self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.audio_self_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.audio_self_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.audio_self_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)

        # Visual-condition stream: mirrors the audio stream.
        self.v_cond_mod = ModulateDiT(hidden_size, factor=9, act_layer=get_activation_layer("silu"), **factory_kwargs)
        self.v_cond_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.v_cond_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
        self.v_cond_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.v_cond_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.v_cond_self_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)

        self.max_text_len = 100
        self.rope_dim_list = None

        self.audio_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.v_cond_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)

        # Cross-attention: audio and visual queries attend to text keys/values.
        self.audio_cross_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
        self.v_cond_cross_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
        self.text_cross_kv = nn.Linear(hidden_size, hidden_size * 2, bias=qkv_bias, **factory_kwargs)

        self.audio_cross_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.v_cond_cross_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.text_cross_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.audio_cross_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
        self.v_cond_cross_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)

        # Per-stream MLPs.
        self.audio_norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.audio_mlp = MLP(
            hidden_size, mlp_hidden_dim, act_layer=get_activation_layer(mlp_act_type), bias=True, **factory_kwargs
        )

        self.v_cond_norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.v_cond_mlp = MLP(
            hidden_size, mlp_hidden_dim, act_layer=get_activation_layer(mlp_act_type), bias=True, **factory_kwargs
        )

    def build_rope_for_text(self, text_len, head_dim, rope_dim_list=None):
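        """Build real-valued 1D RoPE (cos, sin) tables over positions
        [0, text_len) for ``head_dim`` channels. Used for both the text keys
        and the audio/visual queries in cross-attention.
        """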
        target_ndim = 1
        rope_sizes = [text_len]

        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal the head_dim of the attention layer"

        text_freqs_cos, text_freqs_sin = get_nd_rotary_pos_embed(
            rope_dim_list=rope_dim_list,
            start=rope_sizes,
            theta=10000,
            use_real=True,
            theta_rescale_factor=1.0,
        )
        return text_freqs_cos, text_freqs_sin

    def set_attn_mode(self, new_mode):
        if new_mode != "torch":
            raise NotImplementedError(f"Only support 'torch' mode, got {new_mode}.")
        self.attn_mode = new_mode

    def enable_deterministic(self):
        self.deterministic = True

    def disable_deterministic(self):
        self.deterministic = False

    def forward(
        self,
        audio: torch.Tensor,
        cond: torch.Tensor,
        v_cond: torch.Tensor,
        attn_mask: torch.Tensor,
        vec: torch.Tensor,
        freqs_cis: tuple = None,
        v_freqs_cis: tuple = None,
        sync_vec: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
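        """One block pass: joint self-attention, text cross-attention, MLPs.

        Shapes below are illustrative conventions, not from the original source:
            audio: [B, L_a, D] audio tokens; v_cond: [B, L_v, D] visual tokens;
            cond: [B, L_t, D] text tokens; vec: [B, D] timestep embedding;
            sync_vec: optional [B, L_a, D]; when given it replaces ``vec`` for
                the audio stream's modulation.
            freqs_cis / v_freqs_cis: (cos, sin) RoPE tables for audio / visual.

        Returns updated (audio, cond, v_cond); ``cond`` passes through unchanged.
        """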
        if sync_vec is not None:
            # Per-token (frame-aligned) modulation for the audio stream.
            assert sync_vec.ndim == 3
            (
                audio_mod1_shift, audio_mod1_scale, audio_mod1_gate,
                audio_mod2_shift, audio_mod2_scale, audio_mod2_gate,
                audio_mod3_shift, audio_mod3_scale, audio_mod3_gate,
            ) = self.audio_mod(sync_vec).chunk(9, dim=-1)
        else:
            (
                audio_mod1_shift, audio_mod1_scale, audio_mod1_gate,
                audio_mod2_shift, audio_mod2_scale, audio_mod2_gate,
                audio_mod3_shift, audio_mod3_scale, audio_mod3_gate,
            ) = self.audio_mod(vec).chunk(9, dim=-1)

        (
            v_cond_mod1_shift, v_cond_mod1_scale, v_cond_mod1_gate,
            v_cond_mod2_shift, v_cond_mod2_scale, v_cond_mod2_gate,
            v_cond_mod3_shift, v_cond_mod3_scale, v_cond_mod3_gate,
        ) = self.v_cond_mod(vec).chunk(9, dim=-1)

        # Self-attention: prepare audio QKV.
        audio_modulated = self.audio_norm1(audio)
        audio_modulated = modulate(audio_modulated, shift=audio_mod1_shift, scale=audio_mod1_scale)
        audio_qkv = self.audio_self_attn_qkv(audio_modulated)
        audio_q, audio_k, audio_v = rearrange(audio_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
        audio_q = self.audio_self_q_norm(audio_q).to(audio_v)
        audio_k = self.audio_self_k_norm(audio_k).to(audio_v)

        # Self-attention: prepare visual-condition QKV.
        v_cond_modulated = self.v_cond_norm1(v_cond)
        v_cond_modulated = modulate(v_cond_modulated, shift=v_cond_mod1_shift, scale=v_cond_mod1_scale)
        v_cond_qkv = self.v_cond_attn_qkv(v_cond_modulated)
        v_cond_q, v_cond_k, v_cond_v = rearrange(v_cond_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
        v_cond_q = self.v_cond_attn_q_norm(v_cond_q).to(v_cond_v)
        v_cond_k = self.v_cond_attn_k_norm(v_cond_k).to(v_cond_v)

        # Apply RoPE. In interleaved mode, audio and visual tokens share one
        # position grid by interleaving the two sequences before rotation.
        if freqs_cis is not None:
            if not self.interleaved_audio_visual_rope:
                audio_q, audio_k = apply_rotary_emb(audio_q, audio_k, freqs_cis, head_first=False)
            else:
                ori_audio_len = audio_q.shape[1]
                ori_v_cond_len = v_cond_q.shape[1]
                interleaved_q = interleave_two_sequences(audio_q, v_cond_q)
                interleaved_k = interleave_two_sequences(audio_k, v_cond_k)
                interleaved_q, interleaved_k = apply_rotary_emb(
                    interleaved_q, interleaved_k, freqs_cis, head_first=False
                )
                audio_q, v_cond_q = decouple_interleaved_two_sequences(
                    interleaved_q, ori_audio_len, ori_v_cond_len
                )
                audio_k, v_cond_k = decouple_interleaved_two_sequences(
                    interleaved_k, ori_audio_len, ori_v_cond_len
                )

        if v_freqs_cis is not None and not self.interleaved_audio_visual_rope:
            v_cond_q, v_cond_k = apply_rotary_emb(v_cond_q, v_cond_k, v_freqs_cis, head_first=False)

        # Joint self-attention over [visual tokens; audio tokens].
        q = torch.cat((v_cond_q, audio_q), dim=1)
        k = torch.cat((v_cond_k, audio_k), dim=1)
        v = torch.cat((v_cond_v, audio_v), dim=1)

        attn = attention(q, k, v, mode=self.attn_mode, attn_mask=attn_mask, deterministic=self.deterministic)
        v_cond_attn, audio_attn = torch.split(attn, [v_cond.shape[1], audio.shape[1]], dim=1)

        audio = audio + apply_gate(self.audio_self_proj(audio_attn), gate=audio_mod1_gate)
        v_cond = v_cond + apply_gate(self.v_cond_self_proj(v_cond_attn), gate=v_cond_mod1_gate)

        # Cross-attention with the text condition.
        audio_modulated = self.audio_norm2(audio)
        audio_modulated = modulate(audio_modulated, shift=audio_mod2_shift, scale=audio_mod2_scale)
        v_cond_modulated = self.v_cond_norm2(v_cond)
        v_cond_modulated = modulate(v_cond_modulated, shift=v_cond_mod2_shift, scale=v_cond_mod2_scale)

        audio_q = self.audio_cross_q(audio_modulated)
        audio_q = rearrange(audio_q, "B L (H D) -> B L H D", H=self.num_heads)
        audio_q = self.audio_cross_q_norm(audio_q)

        v_cond_q = self.v_cond_cross_q(v_cond_modulated)
        v_cond_q = rearrange(v_cond_q, "B L (H D) -> B L H D", H=self.num_heads)
        v_cond_q = self.v_cond_cross_q_norm(v_cond_q)

        text_kv = self.text_cross_kv(cond)
        text_k, text_v = rearrange(text_kv, "B L (K H D) -> K B L H D", K=2, H=self.num_heads)
        text_k = self.text_cross_k_norm(text_k).to(text_v)

        # RoPE for cross-attention: queries and keys are rotated independently,
        # so apply_rotary_emb is called with the same tensor twice and only the
        # relevant output (index 0 for q, index 1 for k) is kept.
        head_dim = self.hidden_size // self.num_heads
        audio_cross_freqs_cos, audio_cross_freqs_sin = self.build_rope_for_text(
            audio_q.shape[1], head_dim, rope_dim_list=self.rope_dim_list
        )
        audio_cross_freqs_cis = (audio_cross_freqs_cos.to(audio_q.device), audio_cross_freqs_sin.to(audio_q.device))
        audio_q = apply_rotary_emb(audio_q, audio_q, audio_cross_freqs_cis, head_first=False)[0]

        v_cond_cross_freqs_cos, v_cond_cross_freqs_sin = self.build_rope_for_text(
            v_cond_q.shape[1], head_dim, rope_dim_list=self.rope_dim_list
        )
        v_cond_cross_freqs_cis = (v_cond_cross_freqs_cos.to(v_cond_q.device), v_cond_cross_freqs_sin.to(v_cond_q.device))
        v_cond_q = apply_rotary_emb(v_cond_q, v_cond_q, v_cond_cross_freqs_cis, head_first=False)[0]

        text_len = text_k.shape[1]
        text_freqs_cos, text_freqs_sin = self.build_rope_for_text(text_len, head_dim, rope_dim_list=self.rope_dim_list)
        text_freqs_cis = (text_freqs_cos.to(text_k.device), text_freqs_sin.to(text_k.device))
        text_k = apply_rotary_emb(text_k, text_k, text_freqs_cis, head_first=False)[1]

        v_cond_audio_q = torch.cat([v_cond_q, audio_q], dim=1)

        cross_attn = attention(v_cond_audio_q, text_k, text_v, mode=self.attn_mode, deterministic=self.deterministic)
        v_cond_cross_attn, audio_cross_attn = torch.split(cross_attn, [v_cond.shape[1], audio.shape[1]], dim=1)

        audio = audio + apply_gate(self.audio_cross_proj(audio_cross_attn), gate=audio_mod2_gate)
        v_cond = v_cond + apply_gate(self.v_cond_cross_proj(v_cond_cross_attn), gate=v_cond_mod2_gate)

        # Per-stream MLPs with the third set of shift/scale/gate.
        audio = audio + apply_gate(
            self.audio_mlp(modulate(self.audio_norm3(audio), shift=audio_mod3_shift, scale=audio_mod3_scale)),
            gate=audio_mod3_gate,
        )
        v_cond = v_cond + apply_gate(
            self.v_cond_mlp(modulate(self.v_cond_norm3(v_cond), shift=v_cond_mod3_shift, scale=v_cond_mod3_scale)),
            gate=v_cond_mod3_gate,
        )

        return audio, cond, v_cond


class SingleStreamBlock(nn.Module):
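    """Single-stream block over audio tokens: fused-QKV self-attention with
    RoPE, followed by a convolutional MLP, both gated by adaLN modulation
    computed from a per-token condition.
    """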

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float,
        qk_norm_type: str = "rms",  # NOTE: currently unused; RMSNorm is hard-coded below.
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.hidden_size = hidden_size
        self.num_heads = num_heads

        self.modulation = ModulateDiT(
            hidden_size=hidden_size,
            factor=6,
            act_layer=get_activation_layer("silu"),
            **factory_kwargs,
        )
        self.linear_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True, **factory_kwargs)
        self.linear1 = ChannelLastConv1d(hidden_size, hidden_size, kernel_size=3, padding=1, **factory_kwargs)
        self.linear2 = ConvMLP(hidden_size, int(hidden_size * mlp_ratio), kernel_size=3, padding=1, **factory_kwargs)
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.q_norm = nn.RMSNorm(hidden_size // num_heads)
        self.k_norm = nn.RMSNorm(hidden_size // num_heads)
        self.rearrange = Rearrange("B L (H D K) -> B H L D K", K=3, H=num_heads)

    def forward(self, x: torch.Tensor, cond: torch.Tensor, freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None):
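        """Apply the block to ``x`` [B, L, D] with condition ``cond`` [B, T, D]
        (T must equal L or broadcast against it). Unlike TwoStreamCABlock,
        attention here runs head-first ([B, H, L, D]), hence ``head_first=True``
        when applying RoPE.
        """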
        assert cond.ndim == 3, "Condition should be in shape of [B, T, D]"
        modulation = self.modulation(cond)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation.chunk(6, dim=-1)
        x_norm1 = self.norm1(x) * (1 + scale_msa) + shift_msa

        # Fused QKV: [B, L, 3 * D] -> three [B, H, L, D] tensors.
        qkv = self.linear_qkv(x_norm1)
        q, k, v = self.rearrange(qkv).chunk(3, dim=-1)
        q = q.squeeze(-1)
        k = k.squeeze(-1)
        v = v.squeeze(-1)

        q = self.q_norm(q)
        k = self.k_norm(k)
        q, k = apply_rotary_emb(q, k, freqs_cis, head_first=True)

        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        out = F.scaled_dot_product_attention(q, k, v)
        out = rearrange(out, "b h n d -> b n (h d)").contiguous()

        x = x + apply_gate(self.linear1(out), gate=gate_msa)
        x_norm = self.norm2(x) * (1 + scale_mlp) + shift_mlp
        x = x + apply_gate(self.linear2(x_norm), gate=gate_mlp)

        return x


class HunyuanVideoFoley(ModelMixin, ConfigMixin):
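    """Audio DiT for video-to-audio generation, conditioned on text, CLIP
    visual features, and frame-level sync features.

    The network stacks ``depth_triple_blocks`` TwoStreamCABlocks (joint
    audio/visual attention with text cross-attention) followed by
    ``depth_single_blocks`` SingleStreamBlocks over the audio tokens alone.
    """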

    @register_to_config
    def __init__(
        self,
        model_config,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        model_args = model_config.model_config.model_kwargs
        self.depth_triple_blocks = model_args.get("depth_triple_blocks", 19)
        self.depth_single_blocks = model_args.get("depth_single_blocks", 38)

        self.gradient_checkpoint = False
        self.gradient_checkpoint_layers = None
        if self.gradient_checkpoint:
            assert self.gradient_checkpoint_layers <= self.depth_triple_blocks + self.depth_single_blocks, (
                f"Gradient checkpoint layers must be less than or equal to the depth of the model. "
                f"Got gradient_checkpoint_layers={self.gradient_checkpoint_layers} and depth={self.depth_triple_blocks + self.depth_single_blocks}."
            )

        self.interleaved_audio_visual_rope = model_args.get("interleaved_audio_visual_rope", False)

        # Condition projection.
        self.condition_projection = model_args.get("condition_projection", "linear")
        self.condition_dim = model_args.get("condition_dim", None)
        self.use_attention_mask = model_args.get("use_attention_mask", False)

        self.patch_size = model_args.get("patch_size", 1)
        self.visual_in_channels = model_args.get("clip_dim", 768)
        self.audio_vae_latent_dim = model_args.get("audio_vae_latent_dim", 128)
        self.out_channels = self.audio_vae_latent_dim
        self.unpatchify_channels = self.out_channels
        self.reverse = model_args.get("reverse", False)

        self.num_heads = model_args.get("num_heads", 24)
        self.hidden_size = model_args.get("hidden_size", 3072)
        self.rope_dim_list = model_args.get("rope_dim_list", None)
        self.mlp_ratio = model_args.get("mlp_ratio", 4.0)
        self.mlp_act_type = model_args.get("mlp_act_type", "gelu_tanh")

        self.qkv_bias = model_args.get("qkv_bias", True)
        self.qk_norm = model_args.get("qk_norm", True)
        self.qk_norm_type = model_args.get("qk_norm_type", "rms")
        self.attn_mode = model_args.get("attn_mode", "torch")

        self.embedder_type = model_args.get("embedder_type", "default")

        # Synchronization-feature options.
        self.sync_modulation = model_args.get("sync_modulation", False)
        self.add_sync_feat_to_audio = model_args.get("add_sync_feat_to_audio", False)
        self.sync_feat_dim = model_args.get("sync_feat_dim", 768)
        self.sync_in_ksz = model_args.get("sync_in_ksz", 1)

        self.clip_len = model_args.get("clip_length", 64)
        self.sync_len = model_args.get("sync_length", 192)

        if self.hidden_size % self.num_heads != 0:
            raise ValueError(f"Hidden size {self.hidden_size} must be divisible by num_heads {self.num_heads}")

        # Only patch_size == 1 is supported; override whatever the config said.
        self.patch_size = 1
        self.audio_embedder = PatchEmbed1D(self.patch_size, self.audio_vae_latent_dim, self.hidden_size, **factory_kwargs)
        self.visual_proj = SwiGLU(self.visual_in_channels, hidden_dim=self.hidden_size, out_dim=self.hidden_size)

        if self.condition_projection == "linear":
            self.cond_in = ConditionProjection(
                self.condition_dim, self.hidden_size, get_activation_layer("silu"), **factory_kwargs
            )
        else:
            raise NotImplementedError(f"Unsupported condition_projection: {self.condition_projection}")

        self.time_in = TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs)

        if self.sync_in_ksz == 1:
            sync_in_padding = 0
        elif self.sync_in_ksz == 3:
            sync_in_padding = 1
        else:
            raise ValueError(f"Unsupported sync_in_ksz: {self.sync_in_ksz} (expected 1 or 3)")
        if self.sync_modulation or self.add_sync_feat_to_audio:
            self.sync_in = nn.Sequential(
                nn.Linear(self.sync_feat_dim, self.hidden_size),
                nn.SiLU(),
                ConvMLP(self.hidden_size, self.hidden_size * 4, kernel_size=self.sync_in_ksz, padding=sync_in_padding),
            )
            # Learnable position embedding over the 8 frames inside each sync segment.
            self.sync_pos_emb = nn.Parameter(torch.zeros((1, 1, 8, self.sync_feat_dim)))

        self.triple_blocks = nn.ModuleList(
            [
                TwoStreamCABlock(
                    hidden_size=self.hidden_size,
                    num_heads=self.num_heads,
                    mlp_ratio=self.mlp_ratio,
                    mlp_act_type=self.mlp_act_type,
                    qk_norm=self.qk_norm,
                    qk_norm_type=self.qk_norm_type,
                    qkv_bias=self.qkv_bias,
                    attn_mode=self.attn_mode,
                    reverse=self.reverse,
                    interleaved_audio_visual_rope=self.interleaved_audio_visual_rope,
                    **factory_kwargs,
                )
                for _ in range(self.depth_triple_blocks)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(
                    hidden_size=self.hidden_size,
                    num_heads=self.num_heads,
                    mlp_ratio=self.mlp_ratio,
                    qk_norm_type=self.qk_norm_type,
                    **factory_kwargs,
                )
                for _ in range(self.depth_single_blocks)
            ]
        )

        self.final_layer = FinalLayer1D(
            self.hidden_size, self.patch_size, self.out_channels, get_activation_layer("silu"), **factory_kwargs
        )
        self.unpatchify_channels = self.out_channels

        # Learned "empty" features used when visual conditioning is dropped (CFG).
        self.empty_clip_feat = nn.Parameter(torch.zeros(1, self.visual_in_channels), requires_grad=True)
        self.empty_sync_feat = nn.Parameter(torch.zeros(1, self.sync_feat_dim), requires_grad=True)
        nn.init.constant_(self.empty_clip_feat, 0)
        nn.init.constant_(self.empty_sync_feat, 0)

    def get_empty_string_sequence(self, bs=None) -> torch.Tensor:
        # NOTE: relies on ``self.empty_string_feat`` being provided elsewhere
        # (e.g., by a checkpoint or subclass); it is not created in __init__.
        if bs is None:
            return self.empty_string_feat
        else:
            return self.empty_string_feat.unsqueeze(0).expand(bs, -1, -1)

    def get_empty_clip_sequence(self, bs=None, len=None) -> torch.Tensor:
        len = len if len is not None else self.clip_len
        if bs is None:
            return self.empty_clip_feat.expand(len, -1)
        else:
            return self.empty_clip_feat.unsqueeze(0).expand(bs, len, -1)

    def get_empty_sync_sequence(self, bs=None, len=None) -> torch.Tensor:
        len = len if len is not None else self.sync_len
        if bs is None:
            return self.empty_sync_feat.expand(len, -1)
        else:
            return self.empty_sync_feat.unsqueeze(0).expand(bs, len, -1)

    def build_rope_for_audio_visual(self, audio_emb_len, visual_cond_len):
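        """Build separate 1D RoPE tables for audio tokens and visual tokens.

        The visual table is frequency-scaled by ``audio_emb_len / visual_cond_len``
        so visual positions advance at the same temporal rate as audio positions,
        keeping the two streams aligned on a shared timeline.
        """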
        assert self.patch_size == 1

        # Audio RoPE.
        target_ndim = 1
        rope_sizes = [audio_emb_len]
        head_dim = self.hidden_size // self.num_heads
        rope_dim_list = self.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal the head_dim of the attention layer"
        freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
            rope_dim_list=rope_dim_list,
            start=rope_sizes,
            theta=10000,
            use_real=True,
            theta_rescale_factor=1.0,
        )

        # Visual RoPE, frequency-scaled to the audio timeline.
        rope_sizes = [visual_cond_len]
        v_freqs_cos, v_freqs_sin = get_nd_rotary_pos_embed(
            rope_dim_list=rope_dim_list,
            start=rope_sizes,
            theta=10000,
            use_real=True,
            theta_rescale_factor=1.0,
            freq_scaling=1.0 * audio_emb_len / visual_cond_len,
        )
        return freqs_cos, freqs_sin, v_freqs_cos, v_freqs_sin

    def build_rope_for_interleaved_audio_visual(self, total_len):
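        """Build one shared 1D RoPE table over the interleaved audio/visual sequence."""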
        assert self.patch_size == 1

        target_ndim = 1
        rope_sizes = [total_len]
        head_dim = self.hidden_size // self.num_heads
        rope_dim_list = self.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal the head_dim of the attention layer"
        freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
            rope_dim_list=rope_dim_list,
            start=rope_sizes,
            theta=10000,
            use_real=True,
            theta_rescale_factor=1.0,
        )
        return freqs_cos, freqs_sin

    def set_attn_mode(self, new_mode):
        for block in self.triple_blocks:
            block.set_attn_mode(new_mode)
        for block in self.single_blocks:
            # SingleStreamBlock calls F.scaled_dot_product_attention directly
            # and defines no set_attn_mode; guard to avoid AttributeError.
            if hasattr(block, "set_attn_mode"):
                block.set_attn_mode(new_mode)

    def enable_deterministic(self):
        for block in self.triple_blocks:
            block.enable_deterministic()
        for block in self.single_blocks:
            if hasattr(block, "enable_deterministic"):
                block.enable_deterministic()

    def disable_deterministic(self):
        for block in self.triple_blocks:
            block.disable_deterministic()
        for block in self.single_blocks:
            if hasattr(block, "disable_deterministic"):
                block.disable_deterministic()

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        clip_feat: Optional[torch.Tensor] = None,
        cond: Optional[torch.Tensor] = None,
        audio_mask: Optional[torch.Tensor] = None,
        cond_mask: Optional[torch.Tensor] = None,
        sync_feat: Optional[torch.Tensor] = None,
        drop_visual: Optional[List[bool]] = None,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
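        """Run the denoising network for one diffusion step.

        Shapes below are illustrative conventions, not from the original source:
            x: [B, C_audio, L] noisy audio latents; t: [B] timesteps.
            clip_feat: [B, L_v, clip_dim] visual features.
            cond: [B, L_t, condition_dim] text features; cond_mask: [B, L_t].
            sync_feat: [B, L_s, sync_feat_dim] sync features, L_s divisible by 8.
            drop_visual: per-sample flags; True swaps visual/sync features for
                the learned empty embeddings (classifier-free guidance).
            audio_mask: currently unused; the audio mask is recomputed internally.

        Returns ``{"x": prediction}`` if ``return_dict`` else the tensor itself.
        """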
        out = {}
        audio = x
        bs, _, ol = x.shape
        tl = ol // self.patch_size

        # Classifier-free guidance: drop visual conditioning for flagged samples.
        if drop_visual is not None:
            clip_feat[drop_visual] = self.get_empty_clip_sequence().to(dtype=clip_feat.dtype)
            sync_feat[drop_visual] = self.get_empty_sync_sequence().to(dtype=sync_feat.dtype)

        # Timestep embedding, optionally fused with per-frame sync features.
        vec = self.time_in(t)
        sync_vec = None
        if self.sync_modulation:
            assert sync_feat is not None and sync_feat.shape[1] % 8 == 0
            sync_feat = sync_feat.view(bs, sync_feat.shape[1] // 8, 8, self.sync_feat_dim) + self.sync_pos_emb
            sync_feat = sync_feat.view(bs, -1, self.sync_feat_dim)
            sync_vec = self.sync_in(sync_feat)
            sync_vec = (
                F.interpolate(sync_vec.transpose(1, 2), size=tl, mode="nearest-exact").contiguous().transpose(1, 2)
            )
            sync_vec = sync_vec + vec.unsqueeze(1)
        elif self.add_sync_feat_to_audio:
            assert sync_feat is not None and sync_feat.shape[1] % 8 == 0
            sync_feat = sync_feat.view(bs, sync_feat.shape[1] // 8, 8, self.sync_feat_dim) + self.sync_pos_emb
            sync_feat = sync_feat.view(bs, -1, self.sync_feat_dim)
            sync_feat = self.sync_in(sync_feat)
            add_sync_feat_to_audio = (
                F.interpolate(sync_feat.transpose(1, 2), size=tl, mode="nearest-exact").contiguous().transpose(1, 2)
            )

        # Project conditions and audio latents into the hidden space.
        cond = self.cond_in(cond)
        cond_seq_len = cond.shape[1]

        audio = self.audio_embedder(x)
        audio_seq_len = audio.shape[1]
        v_cond = self.visual_proj(clip_feat)
        v_cond_seq_len = v_cond.shape[1]

        # Build a [B, 1, S, S] boolean mask over [text; visual; audio] tokens:
        # a pair may attend iff both tokens are valid.
        attn_mask = None
        if self.use_attention_mask:
            assert cond_mask is not None
            batch_size = audio.shape[0]
            seq_len = cond_seq_len + v_cond_seq_len + audio_seq_len

            audio_mask = torch.ones((batch_size, audio_seq_len), dtype=torch.bool, device=audio.device)
            v_cond_mask = torch.ones((batch_size, v_cond_seq_len), dtype=torch.bool, device=audio.device)

            concat_mask = torch.cat([cond_mask, v_cond_mask, audio_mask], dim=1)
            attn_mask_1 = concat_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
            attn_mask_2 = attn_mask_1.transpose(2, 3)
            attn_mask = (attn_mask_1 & attn_mask_2).bool()
            # Keep the first key position visible so no query row is fully masked.
            attn_mask[:, :, :, 0] = True

        # Rotary position embeddings for the triple-stream blocks.
        if self.interleaved_audio_visual_rope:
            freqs_cos, freqs_sin = self.build_rope_for_interleaved_audio_visual(audio_seq_len * 2)
            v_freqs_cos = v_freqs_sin = None
        else:
            freqs_cos, freqs_sin, v_freqs_cos, v_freqs_sin = self.build_rope_for_audio_visual(
                audio_seq_len, v_cond_seq_len
            )

        freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
        v_freqs_cis = (v_freqs_cos, v_freqs_sin) if v_freqs_cos is not None else None

        if self.add_sync_feat_to_audio:
            add_sync_layer = 0
            assert (
                add_sync_layer < self.depth_triple_blocks
            ), f"The layer that adds the sync feature to the audio stream must lie within the triple-stream blocks (n: {self.depth_triple_blocks})."

        for layer_num, block in enumerate(self.triple_blocks):
            if self.add_sync_feat_to_audio and layer_num == add_sync_layer:
                audio = audio + add_sync_feat_to_audio
            triple_block_args = [audio, cond, v_cond, attn_mask, vec, freqs_cis, v_freqs_cis, sync_vec]
            if (
                self.training
                and self.gradient_checkpoint
                and (self.gradient_checkpoint_layers == -1 or layer_num < self.gradient_checkpoint_layers)
            ):
                audio, cond, v_cond = torch.utils.checkpoint.checkpoint(
                    ckpt_wrapper(block), *triple_block_args, use_reentrant=False
                )
            else:
                audio, cond, v_cond = block(*triple_block_args)

        x = audio
        if sync_vec is not None:
            vec = vec.unsqueeze(1).repeat(1, cond_seq_len + v_cond_seq_len, 1)
            vec = torch.cat((vec, sync_vec), dim=1)

        # Audio-only RoPE for the single-stream blocks.
        freqs_cos, freqs_sin, _, _ = self.build_rope_for_audio_visual(audio_seq_len, v_cond_seq_len)
        if self.add_sync_feat_to_audio:
            vec = add_sync_feat_to_audio + vec.unsqueeze(dim=1)
        if len(self.single_blocks) > 0:
            for layer_num, block in enumerate(self.single_blocks):
                single_block_args = [
                    x,
                    vec,
                    (freqs_cos, freqs_sin),
                ]
                if (
                    self.training
                    and self.gradient_checkpoint
                    and (
                        self.gradient_checkpoint_layers == -1
                        or layer_num + len(self.triple_blocks) < self.gradient_checkpoint_layers
                    )
                ):
                    x = torch.utils.checkpoint.checkpoint(ckpt_wrapper(block), *single_block_args, use_reentrant=False)
                else:
                    x = block(*single_block_args)

        audio = x

        # Final projection back to the audio latent space.
        if sync_vec is not None:
            vec = sync_vec
        audio = self.final_layer(audio, vec)
        audio = self.unpatchify1d(audio, tl)

        if return_dict:
            out["x"] = audio
            return out
        return audio

    def unpatchify1d(self, x, l):
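        """Fold patch tokens back into an audio latent tensor.

        x: [N, T, patch_size * C] -> [N, C, T * patch_size].
        Illustrative: with patch_size=1, [2, 345, 128] -> [2, 128, 345].
        """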
        c = self.unpatchify_channels
        p = self.patch_size
        assert l == x.shape[1]

        x = x.reshape(shape=(x.shape[0], l, p, c))
        x = torch.einsum("ntpc->nctp", x)
        audio = x.reshape(shape=(x.shape[0], c, l * p))
        return audio

    def params_count(self):
        counts = {
            "triple": sum(
                sum(p.numel() for p in block.audio_cross_q.parameters())
                + sum(p.numel() for p in block.v_cond_cross_q.parameters())
                + sum(p.numel() for p in block.text_cross_kv.parameters())
                + sum(p.numel() for p in block.audio_self_attn_qkv.parameters())
                + sum(p.numel() for p in block.v_cond_attn_qkv.parameters())
                + sum(p.numel() for p in block.audio_mlp.parameters())
                + sum(p.numel() for p in block.audio_self_proj.parameters())
                + sum(p.numel() for p in block.v_cond_self_proj.parameters())
                + sum(p.numel() for p in block.v_cond_mlp.parameters())
                for block in self.triple_blocks
            ),
            "single": sum(
                sum(p.numel() for p in block.linear1.parameters())
                + sum(p.numel() for p in block.linear2.parameters())
                for block in self.single_blocks
            ),
            "total": sum(p.numel() for p in self.parameters()),
        }

        counts["attn+mlp"] = counts["triple"] + counts["single"]
        return counts
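

# Minimal shape round-trip check for the interleave helpers above. This is an
# illustrative sketch, not part of the original source; the shapes are made up,
# and the module must be run in package context so the relative imports resolve.
if __name__ == "__main__":
    _x1 = torch.randn(2, 64, 8, 96)  # [B, N1, H, C]
    _x2 = torch.randn(2, 48, 8, 96)  # [B, N2, H, C], N2 != N1
    _interleaved = interleave_two_sequences(_x1, _x2)
    assert _interleaved.shape == (2, 128, 8, 96)
    _y1, _y2 = decouple_interleaved_two_sequences(_interleaved, len1=64, len2=48)
    assert _y1.shape == _x1.shape and _y2.shape == _x2.shape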