# James Zhou
# [init]
# 9867d34
from typing import List, Tuple, Optional, Union, Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
from diffusers.models import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
from .nn.activation_layers import SwiGLU, get_activation_layer
from .nn.attn_layers import apply_rotary_emb, attention
from .nn.embed_layers import TimestepEmbedder, ConditionProjection, PatchEmbed1D
from .nn.mlp_layers import MLP, ConvMLP, FinalLayer1D, ChannelLastConv1d
from .nn.modulate_layers import ModulateDiT, ckpt_wrapper, apply_gate, modulate
from .nn.norm_layers import get_norm_layer
from .nn.posemb_layers import get_nd_rotary_pos_embed
def interleave_two_sequences(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    """Interleave two token sequences along the sequence axis.

    When the sequence lengths differ, ``x2`` is first resampled along the
    sequence axis (``nearest-exact``) to the length of ``x1``. The result
    alternates ``x1[:, 0], x2[:, 0], x1[:, 1], x2[:, 1], ...``.

    Args:
        x1: Tensor of shape ``[B, N1, H, C]``.
        x2: Tensor of shape ``[B, N2, H, C]``.

    Returns:
        Tensor of shape ``[B, 2 * N1, H, C]``.

    Raises:
        ValueError: If either input is not 4-D, or if the batch / head /
            channel dimensions of the two inputs disagree.
    """
    # Validate rank before unpacking: unpacking a non-4-D shape would otherwise
    # fail first with an opaque "values to unpack" error.
    if x1.ndim != 4 or x2.ndim != 4:
        raise ValueError(
            f"Expected two 4-D tensors [B, N, H, C], got shapes {tuple(x1.shape)} and {tuple(x2.shape)}."
        )
    B, N1, H, C = x1.shape
    B2, N2, H2, C2 = x2.shape
    if (B, H, C) != (B2, H2, C2):
        raise ValueError(
            f"Batch/head/channel dims must match, got {tuple(x1.shape)} and {tuple(x2.shape)}."
        )
    if N1 != N2:
        # F.interpolate resamples the last axis of a [B, features, length] tensor.
        x2 = x2.reshape(B, N2, H * C).transpose(1, 2)
        x2 = F.interpolate(x2, size=N1, mode="nearest-exact")
        x2 = x2.transpose(1, 2).reshape(B, N1, H, C)
    # Stacking on a new axis right after the sequence axis and then flattening
    # those two axes yields the alternating order x1[0], x2[0], x1[1], x2[1], ...
    return torch.stack((x1, x2), dim=2).reshape(B, 2 * N1, H, C)
def decouple_interleaved_two_sequences(x: torch.Tensor, len1: int, len2: int):
    """Split an interleaved sequence back into its two source sequences.

    Inverse of ``interleave_two_sequences``: even sequence positions form the
    first output and odd positions form the second. When ``len2`` differs from
    ``len1``, the second output is resampled (``nearest-exact``) to ``len2``.

    Args:
        x: Interleaved tensor of shape ``[B, 2 * len1, H, C]``.
        len1: Length of the first sequence (must equal half of ``x.shape[1]``).
        len2: Desired length of the second sequence.

    Returns:
        ``(x1, x2)`` with shapes ``[B, len1, H, C]`` and ``[B, len2, H, C]``.
    """
    batch, total, heads, channels = x.shape
    assert total % 2 == 0 and total // 2 == len1
    # Group consecutive (first, second) token pairs, then peel them apart.
    pairs = x.reshape(batch, -1, 2, heads, channels)
    first, second = pairs.unbind(dim=2)
    if second.shape[1] == len2:
        return first, second
    # F.interpolate resamples the last axis of a [B, features, length] tensor.
    flat = second.view(batch, len1, heads * channels).transpose(1, 2)
    flat = F.interpolate(flat, size=len2, mode="nearest-exact")
    second = flat.transpose(1, 2).view(batch, len2, heads, channels)
    return first, second
class TwoStreamCABlock(nn.Module):
    """Two-stream (audio / visual) transformer block with text cross-attention.

    Every forward pass applies three gated residual sub-layers to both the
    audio and the visual (``v_cond``) streams. Shift/scale/gate values for each
    sub-layer come from ``ModulateDiT`` layers (9 chunks per stream):

    1. Joint self-attention over the concatenation ``[v_cond, audio]``, with
       separate q/k/v projections and q/k norms per stream and optional RoPE.
    2. Cross-attention in which ``[v_cond, audio]`` queries attend to the text
       condition ``cond`` (keys/values).
    3. A separate MLP per stream.

    The text stream ``cond`` is only read and is returned unchanged.
    NOTE: attribute names are checkpoint state-dict keys, and the order in which
    submodules are created fixes the parameter-initialisation order.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float,
        mlp_act_type: str = "gelu_tanh",
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        qkv_bias: bool = False,
        attn_mode: str = "torch",
        reverse: bool = False,
        interleaved_audio_visual_rope: bool = False,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.deterministic = False
        self.reverse = reverse  # stored only; not read anywhere in this class
        self.attn_mode = attn_mode
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        head_dim = hidden_size // num_heads
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # When True, audio and visual tokens share one RoPE table over an interleaved sequence.
        self.interleaved_audio_visual_rope = interleaved_audio_visual_rope

        # Self attention for audio + visual
        # factor=9: shift/scale/gate for each of the 3 sub-layers (see forward).
        self.audio_mod = ModulateDiT(hidden_size, factor=9, act_layer=get_activation_layer("silu"), **factory_kwargs)
        self.audio_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.audio_self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.audio_self_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.audio_self_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.audio_self_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)

        # visual cond
        self.v_cond_mod = ModulateDiT(hidden_size, factor=9, act_layer=get_activation_layer("silu"), **factory_kwargs)
        self.v_cond_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.v_cond_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
        self.v_cond_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.v_cond_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.v_cond_self_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)

        self.max_text_len = 100  # stored only; not read anywhere in this class
        # None -> build_rope_for_text uses the whole head_dim for 1-D RoPE.
        self.rope_dim_list = None

        # audio and video norm for cross attention with text
        self.audio_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.v_cond_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)

        # Cross attention: (video_audio) as query, text as key/value
        self.audio_cross_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
        self.v_cond_cross_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
        self.text_cross_kv = nn.Linear(hidden_size, hidden_size * 2, bias=qkv_bias, **factory_kwargs)

        self.audio_cross_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.v_cond_cross_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.text_cross_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.audio_cross_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
        self.v_cond_cross_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)

        # MLPs
        self.audio_norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.audio_mlp = MLP(
            hidden_size, mlp_hidden_dim, act_layer=get_activation_layer(mlp_act_type), bias=True, **factory_kwargs
        )

        self.v_cond_norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.v_cond_mlp = MLP(
            hidden_size, mlp_hidden_dim, act_layer=get_activation_layer(mlp_act_type), bias=True, **factory_kwargs
        )

    def build_rope_for_text(self, text_len, head_dim, rope_dim_list=None):
        """Build 1-D RoPE tables (cos, sin) for positions ``0 .. text_len - 1``.

        Despite the name, this is also used for the audio and visual queries of the
        cross-attention (see ``forward``).

        Args:
            text_len: Number of positions.
            head_dim: Attention head dimension; ``rope_dim_list`` must sum to it.
            rope_dim_list: Per-axis rotary dims; defaults to ``[head_dim]``.

        Returns:
            ``(freqs_cos, freqs_sin)`` as returned by ``get_nd_rotary_pos_embed``.
        """
        target_ndim = 1  # n-d RoPE
        rope_sizes = [text_len]

        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"

        text_freqs_cos, text_freqs_sin = get_nd_rotary_pos_embed(
            rope_dim_list=rope_dim_list,
            start=rope_sizes,
            theta=10000,
            use_real=True,
            theta_rescale_factor=1.0,
        )
        return text_freqs_cos, text_freqs_sin

    def set_attn_mode(self, new_mode):
        """Set the attention backend; only ``"torch"`` is supported."""
        if new_mode != "torch":
            raise NotImplementedError(f"Only support 'torch' mode, got {new_mode}.")
        self.attn_mode = new_mode

    def enable_deterministic(self):
        """Pass ``deterministic=True`` to subsequent ``attention`` calls."""
        self.deterministic = True

    def disable_deterministic(self):
        """Pass ``deterministic=False`` to subsequent ``attention`` calls."""
        self.deterministic = False

    def forward(
        self,
        audio: torch.Tensor,
        cond: torch.Tensor,
        v_cond: torch.Tensor,
        attn_mask: torch.Tensor,
        vec: torch.Tensor,
        freqs_cis: tuple = None,
        v_freqs_cis: tuple = None,
        sync_vec: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run self-attention, text cross-attention and MLP on both streams.

        Args:
            audio: Audio tokens ``[B, L_audio, hidden_size]``.
            cond: Text tokens ``[B, L_text, hidden_size]`` (keys/values of the cross-attention).
            v_cond: Visual tokens ``[B, L_visual, hidden_size]``.
            attn_mask: Mask forwarded to the joint self-attention only; the
                cross-attention receives no mask.
            vec: Modulation input for the visual stream, and for the audio
                stream when ``sync_vec`` is None.
            freqs_cis: RoPE ``(cos, sin)`` for the audio tokens, or for the
                interleaved audio/visual sequence when
                ``interleaved_audio_visual_rope`` is set.
            v_freqs_cis: RoPE ``(cos, sin)`` for the visual tokens; only used
                when ``interleaved_audio_visual_rope`` is False.
            sync_vec: Optional 3-D modulation input for the audio stream; it
                replaces ``vec`` for ``audio_mod``.

        Returns:
            ``(audio, cond, v_cond)``; ``cond`` is returned unchanged.
        """
        # Get modulation parameters: 3 sub-layers x (shift, scale, gate) per stream.
        if sync_vec is not None:
            assert sync_vec.ndim == 3
            (audio_mod1_shift, audio_mod1_scale, audio_mod1_gate,
             audio_mod2_shift, audio_mod2_scale, audio_mod2_gate,
             audio_mod3_shift, audio_mod3_scale, audio_mod3_gate,
             ) = self.audio_mod(sync_vec).chunk(9, dim=-1)
        else:
            (audio_mod1_shift, audio_mod1_scale, audio_mod1_gate,
             audio_mod2_shift, audio_mod2_scale, audio_mod2_gate,
             audio_mod3_shift, audio_mod3_scale, audio_mod3_gate,
             ) = self.audio_mod(vec).chunk(9, dim=-1)

        (
            v_cond_mod1_shift,
            v_cond_mod1_scale,
            v_cond_mod1_gate,
            v_cond_mod2_shift,
            v_cond_mod2_scale,
            v_cond_mod2_gate,
            v_cond_mod3_shift,
            v_cond_mod3_scale,
            v_cond_mod3_gate,
        ) = self.v_cond_mod(vec).chunk(9, dim=-1)

        # 1. Self Attention for audio + visual
        audio_modulated = self.audio_norm1(audio)
        audio_modulated = modulate(audio_modulated, shift=audio_mod1_shift, scale=audio_mod1_scale)
        audio_qkv = self.audio_self_attn_qkv(audio_modulated)
        audio_q, audio_k, audio_v = rearrange(audio_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
        # Cast normalised q/k back to v's dtype/device so attention sees matching inputs.
        audio_q = self.audio_self_q_norm(audio_q).to(audio_v)
        audio_k = self.audio_self_k_norm(audio_k).to(audio_v)

        # Prepare visual cond for attention
        v_cond_modulated = self.v_cond_norm1(v_cond)
        v_cond_modulated = modulate(v_cond_modulated, shift=v_cond_mod1_shift, scale=v_cond_mod1_scale)
        v_cond_qkv = self.v_cond_attn_qkv(v_cond_modulated)
        v_cond_q, v_cond_k, v_cond_v = rearrange(v_cond_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
        v_cond_q = self.v_cond_attn_q_norm(v_cond_q).to(v_cond_v)
        v_cond_k = self.v_cond_attn_k_norm(v_cond_k).to(v_cond_v)

        # Apply RoPE if needed for audio and visual
        if freqs_cis is not None:
            if not self.interleaved_audio_visual_rope:
                audio_qq, audio_kk = apply_rotary_emb(audio_q, audio_k, freqs_cis, head_first=False)
                audio_q, audio_k = audio_qq, audio_kk
            else:
                # Interleave audio and (length-matched) visual tokens, rotate them with one
                # shared table, then split back and restore the visual length.
                ori_audio_len = audio_q.shape[1]
                ori_v_con_len = v_cond_q.shape[1]
                interleaved_audio_visual_q = interleave_two_sequences(audio_q, v_cond_q)
                interleaved_audio_visual_k = interleave_two_sequences(audio_k, v_cond_k)
                interleaved_audio_visual_qq, interleaved_audio_visual_kk = apply_rotary_emb(
                    interleaved_audio_visual_q, interleaved_audio_visual_k, freqs_cis, head_first=False
                )
                audio_qq, v_cond_qq = decouple_interleaved_two_sequences(
                    interleaved_audio_visual_qq, ori_audio_len, ori_v_con_len
                )
                audio_kk, v_cond_kk = decouple_interleaved_two_sequences(
                    interleaved_audio_visual_kk, ori_audio_len, ori_v_con_len
                )
                audio_q, audio_k = audio_qq, audio_kk
                v_cond_q, v_cond_k = v_cond_qq, v_cond_kk

        # Apply RoPE to visual if needed and not interleaved
        if v_freqs_cis is not None and not self.interleaved_audio_visual_rope:
            v_cond_qq, v_cond_kk = apply_rotary_emb(v_cond_q, v_cond_k, v_freqs_cis, head_first=False)
            v_cond_q, v_cond_k = v_cond_qq, v_cond_kk

        # Concatenate for self-attention (visual tokens first, then audio)
        q = torch.cat((v_cond_q, audio_q), dim=1)
        k = torch.cat((v_cond_k, audio_k), dim=1)
        v = torch.cat((v_cond_v, audio_v), dim=1)

        # Run self-attention
        attn = attention(q, k, v, mode=self.attn_mode, attn_mask=attn_mask, deterministic=self.deterministic)
        # Split in the same [v_cond, audio] order used for the concatenation above.
        v_cond_attn, audio_attn = torch.split(attn, [v_cond.shape[1], audio.shape[1]], dim=1)

        # Apply self-attention output to audio and v_cond
        audio = audio + apply_gate(self.audio_self_proj(audio_attn), gate=audio_mod1_gate)
        v_cond = v_cond + apply_gate(self.v_cond_self_proj(v_cond_attn), gate=v_cond_mod1_gate)

        # 2. Cross Attention: (v_cond, audio) as query, text as key/value
        # audio, v_cond modulation
        audio_modulated = self.audio_norm2(audio)
        audio_modulated = modulate(audio_modulated, shift=audio_mod2_shift, scale=audio_mod2_scale)
        v_cond_modulated = self.v_cond_norm2(v_cond)
        v_cond_modulated = modulate(v_cond_modulated, shift=v_cond_mod2_shift, scale=v_cond_mod2_scale)

        # Prepare audio query
        audio_q = self.audio_cross_q(audio_modulated)
        audio_q = rearrange(audio_q, "B L (H D) -> B L H D", H=self.num_heads)
        # NOTE(review): unlike the self-attention path and text_k below, the cross-attention
        # queries are not cast back with .to(...) after normalisation; if the norm layer
        # changes dtype, q and text_k dtypes could diverge — confirm.
        audio_q = self.audio_cross_q_norm(audio_q)

        # Prepare v_cond query
        v_cond_q = self.v_cond_cross_q(v_cond_modulated)
        v_cond_q = rearrange(v_cond_q, "B L (H D) -> B L H D", H=self.num_heads)
        v_cond_q = self.v_cond_cross_q_norm(v_cond_q)

        # Prepare text key/value
        text_kv = self.text_cross_kv(cond)
        text_k, text_v = rearrange(text_kv, "B L (K H D) -> K B L H D", K=2, H=self.num_heads)
        text_k = self.text_cross_k_norm(text_k).to(text_v)

        # Apply RoPE to (v_cond, audio) query and text key if needed
        # Each query stream and the text keys get their own 0-based positions. The RoPE
        # tables are rebuilt on every call and copied to the tensors' device.
        head_dim = self.hidden_size // self.num_heads

        audio_cross_freqs_cos, audio_cross_freqs_sin = self.build_rope_for_text(audio_q.shape[1], head_dim, rope_dim_list=self.rope_dim_list)
        audio_cross_freqs_cis = (audio_cross_freqs_cos.to(audio_q.device), audio_cross_freqs_sin.to(audio_q.device))
        # apply_rotary_emb rotates a (q, k) pair; the query is passed twice and only the
        # rotated first element is kept.
        audio_q = apply_rotary_emb(audio_q, audio_q, audio_cross_freqs_cis, head_first=False)[0]

        v_cond_cross_freqs_cos, v_cond_cross_freqs_sin = self.build_rope_for_text(v_cond_q.shape[1], head_dim, rope_dim_list=self.rope_dim_list)
        v_cond_cross_freqs_cis = (v_cond_cross_freqs_cos.to(v_cond_q.device), v_cond_cross_freqs_sin.to(v_cond_q.device))
        v_cond_q = apply_rotary_emb(v_cond_q, v_cond_q, v_cond_cross_freqs_cis, head_first=False)[0]

        text_len = text_k.shape[1]

        text_freqs_cos, text_freqs_sin = self.build_rope_for_text(text_len, head_dim,
                                                                  rope_dim_list=self.rope_dim_list)
        text_freqs_cis = (text_freqs_cos.to(text_k.device), text_freqs_sin.to(text_k.device))
        # Same trick for keys: keep the rotated second element.
        text_k = apply_rotary_emb(text_k, text_k, text_freqs_cis, head_first=False)[1]

        # Concat v_cond and audio for cross-attention
        v_cond_audio_q = torch.cat([v_cond_q, audio_q], dim=1)

        # Run cross-attention (no mask is applied to the text keys here)
        cross_attn = attention(v_cond_audio_q, text_k, text_v, mode=self.attn_mode, deterministic=self.deterministic)
        v_cond_cross_attn, audio_cross_attn = torch.split(cross_attn, [v_cond.shape[1], audio.shape[1]], dim=1)

        # Apply cross-attention output
        audio = audio + apply_gate(self.audio_cross_proj(audio_cross_attn), gate=audio_mod2_gate)
        v_cond = v_cond + apply_gate(self.v_cond_cross_proj(v_cond_cross_attn), gate=v_cond_mod2_gate)

        # 3. Apply MLPs
        audio = audio + apply_gate(
            self.audio_mlp(modulate(self.audio_norm3(audio), shift=audio_mod3_shift, scale=audio_mod3_scale)),
            gate=audio_mod3_gate,
        )

        # Apply visual MLP
        v_cond = v_cond + apply_gate(
            self.v_cond_mlp(modulate(self.v_cond_norm3(v_cond), shift=v_cond_mod3_shift, scale=v_cond_mod3_scale)),
            gate=v_cond_mod3_gate,
        )

        return audio, cond, v_cond
class SingleStreamBlock(nn.Module):
    """Audio-only transformer block applied after the two-stream blocks.

    Modulation values (shift/scale/gate for attention and for the MLP) are
    predicted from ``cond`` by ``ModulateDiT``. Attention uses RMS-normalised
    q/k with optional rotary embeddings and ``F.scaled_dot_product_attention``.
    The attention output projection (``linear1``) and the MLP (``linear2``) are
    kernel-size-3 1-D convolutions over the sequence axis.

    NOTE: submodule attribute names are checkpoint state-dict keys, and their
    creation order fixes the parameter-initialisation order; keep both stable.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float,
        qk_norm_type: str = "rms",
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        # Same control surface as TwoStreamCABlock, so the parent model can toggle
        # these on every block. Attention here always uses PyTorch SDPA.
        self.attn_mode = "torch"
        self.deterministic = False
        # NOTE: qk_norm_type is accepted for signature parity but is not used;
        # q/k normalisation is always nn.RMSNorm.
        head_dim = hidden_size // num_heads

        self.modulation = ModulateDiT(
            hidden_size=hidden_size,
            factor=6,
            act_layer=get_activation_layer("silu"),
            **factory_kwargs,
        )
        # Every layer honours the requested device/dtype (previously only some did).
        self.linear_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True, **factory_kwargs)
        self.linear1 = ChannelLastConv1d(hidden_size, hidden_size, kernel_size=3, padding=1, **factory_kwargs)
        # int(): a float mlp_ratio (e.g. the 4.0 default) must not yield a float channel count.
        self.linear2 = ConvMLP(hidden_size, int(hidden_size * mlp_ratio), kernel_size=3, padding=1, **factory_kwargs)
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, **factory_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, **factory_kwargs)
        self.q_norm = nn.RMSNorm(head_dim, **factory_kwargs)
        self.k_norm = nn.RMSNorm(head_dim, **factory_kwargs)
        # Splits the fused qkv projection into [B, H, L, D, 3].
        self.rearrange = Rearrange("B L (H D K) -> B H L D K", K=3, H=num_heads)

    def set_attn_mode(self, new_mode):
        """Set the attention backend; only ``"torch"`` (SDPA) is supported."""
        if new_mode != "torch":
            raise NotImplementedError(f"Only support 'torch' mode, got {new_mode}.")
        self.attn_mode = new_mode

    def enable_deterministic(self):
        """Record the deterministic flag (stored for API parity; SDPA here does not read it)."""
        self.deterministic = True

    def disable_deterministic(self):
        """Clear the deterministic flag (stored for API parity; SDPA here does not read it)."""
        self.deterministic = False

    def forward(
        self,
        x: torch.Tensor,
        cond: torch.Tensor,
        freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Apply gated attention and a gated conv-MLP to ``x``.

        Args:
            x: Tokens ``[B, L, hidden_size]``.
            cond: 3-D modulation input ``[B, T, D]``; T must broadcast against L.
            freqs_cis: RoPE ``(cos, sin)`` tables; RoPE is skipped when None.

        Returns:
            Updated tokens with the same shape as ``x``.
        """
        assert cond.ndim == 3, "Condition should be in shape of [B, T, D]"
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modulation(cond).chunk(6, dim=-1)

        # --- attention ---
        x_norm1 = self.norm1(x) * (1 + scale_msa) + shift_msa
        qkv = self.linear_qkv(x_norm1)
        q, k, v = (part.squeeze(-1) for part in self.rearrange(qkv).chunk(3, dim=-1))
        q = self.q_norm(q)
        k = self.k_norm(k)
        if freqs_cis is not None:
            q, k = apply_rotary_emb(q, k, freqs_cis, head_first=True)
        out = F.scaled_dot_product_attention(q.contiguous(), k.contiguous(), v.contiguous())
        out = rearrange(out, "b h n d -> b n (h d)").contiguous()
        x = x + apply_gate(self.linear1(out), gate=gate_msa)

        # --- MLP ---
        x_norm = self.norm2(x) * (1 + scale_mlp) + shift_mlp
        x = x + apply_gate(self.linear2(x_norm), gate=gate_mlp)
        return x
class HunyuanVideoFoley(ModelMixin, ConfigMixin):
    """Audio-latent DiT conditioned on text, visual (CLIP) and sync features.

    Audio latents are patch-embedded into tokens. The tokens then pass through
    ``depth_triple_blocks`` ``TwoStreamCABlock`` layers (joint audio/visual
    self-attention plus text cross-attention) and ``depth_single_blocks``
    audio-only ``SingleStreamBlock`` layers. Finally ``final_layer`` projects
    them back to ``out_channels``.

    Hyper-parameters are read from ``model_config.model_config.model_kwargs``.
    NOTE: submodule attribute names are checkpoint state-dict keys, and their
    creation order fixes the parameter-initialisation order; keep both stable.
    """

    @register_to_config
    def __init__(
        self,
        model_config,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        model_args = model_config.model_config.model_kwargs
        self.depth_triple_blocks = model_args.get("depth_triple_blocks", 19)
        self.depth_single_blocks = model_args.get("depth_single_blocks", 38)

        # Gradient checkpoint (disabled here; the check below only runs if it is enabled).
        self.gradient_checkpoint = False
        self.gradient_checkpoint_layers = None
        if self.gradient_checkpoint:
            assert self.gradient_checkpoint_layers <= self.depth_triple_blocks + self.depth_single_blocks, (
                f"Gradient checkpoint layers must be less or equal than the depth of the model. "
                f"Got gradient_checkpoint_layers={self.gradient_checkpoint_layers} and depth={self.depth_triple_blocks + self.depth_single_blocks}."
            )

        self.interleaved_audio_visual_rope = model_args.get("interleaved_audio_visual_rope", False)

        # Condition projection. Default to linear projection.
        self.condition_projection = model_args.get("condition_projection", "linear")
        self.condition_dim = model_args.get("condition_dim", None)

        self.use_attention_mask = model_args.get("use_attention_mask", False)

        self.patch_size = model_args.get("patch_size", 1)
        self.visual_in_channels = model_args.get("clip_dim", 768)
        self.audio_vae_latent_dim = model_args.get("audio_vae_latent_dim", 128)
        self.out_channels = self.audio_vae_latent_dim
        self.unpatchify_channels = self.out_channels
        self.reverse = model_args.get("reverse", False)

        self.num_heads = model_args.get("num_heads", 24)
        self.hidden_size = model_args.get("hidden_size", 3072)
        self.rope_dim_list = model_args.get("rope_dim_list", None)
        self.mlp_ratio = model_args.get("mlp_ratio", 4.0)
        self.mlp_act_type = model_args.get("mlp_act_type", "gelu_tanh")

        self.qkv_bias = model_args.get("qkv_bias", True)
        self.qk_norm = model_args.get("qk_norm", True)
        self.qk_norm_type = model_args.get("qk_norm_type", "rms")
        self.attn_mode = model_args.get("attn_mode", "torch")

        self.embedder_type = model_args.get("embedder_type", "default")

        # sync condition things
        self.sync_modulation = model_args.get("sync_modulation", False)
        self.add_sync_feat_to_audio = model_args.get("add_sync_feat_to_audio", False)
        self.sync_feat_dim = model_args.get("sync_feat_dim", 768)
        self.sync_in_ksz = model_args.get("sync_in_ksz", 1)

        # condition tokens length
        self.clip_len = model_args.get("clip_length", 64)
        self.sync_len = model_args.get("sync_length", 192)

        if self.hidden_size % self.num_heads != 0:
            raise ValueError(f"Hidden size {self.hidden_size} must be divisible by num_heads {self.num_heads}")

        # Build audio patchify layer and visual gated linear projection.
        # Only patch_size == 1 is supported (the RoPE builders assert it), so the
        # configured "patch_size" value is overridden here.
        self.patch_size = 1
        self.audio_embedder = PatchEmbed1D(self.patch_size, self.audio_vae_latent_dim, self.hidden_size, **factory_kwargs)
        self.visual_proj = SwiGLU(self.visual_in_channels, hidden_dim=self.hidden_size, out_dim=self.hidden_size)

        # condition
        if self.condition_projection == "linear":
            self.cond_in = ConditionProjection(
                self.condition_dim, self.hidden_size, get_activation_layer("silu"), **factory_kwargs
            )
        else:
            raise NotImplementedError(f"Unsupported condition_projection: {self.condition_projection}")

        # time modulation
        self.time_in = TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs)

        # visual sync embedder if needed ("same" padding for the supported kernel sizes)
        if self.sync_in_ksz == 1:
            sync_in_padding = 0
        elif self.sync_in_ksz == 3:
            sync_in_padding = 1
        else:
            raise ValueError(f"sync_in_ksz must be 1 or 3, got {self.sync_in_ksz}.")
        if self.sync_modulation or self.add_sync_feat_to_audio:
            self.sync_in = nn.Sequential(
                nn.Linear(self.sync_feat_dim, self.hidden_size),
                nn.SiLU(),
                ConvMLP(self.hidden_size, self.hidden_size * 4, kernel_size=self.sync_in_ksz, padding=sync_in_padding),
            )
        # Learnable embedding added to each of the 8 positions inside a sync segment.
        self.sync_pos_emb = nn.Parameter(torch.zeros((1, 1, 8, self.sync_feat_dim)))

        self.triple_blocks = nn.ModuleList(
            [
                TwoStreamCABlock(
                    hidden_size=self.hidden_size,
                    num_heads=self.num_heads,
                    mlp_ratio=self.mlp_ratio,
                    mlp_act_type=self.mlp_act_type,
                    qk_norm=self.qk_norm,
                    qk_norm_type=self.qk_norm_type,
                    qkv_bias=self.qkv_bias,
                    attn_mode=self.attn_mode,
                    reverse=self.reverse,
                    interleaved_audio_visual_rope=self.interleaved_audio_visual_rope,
                    **factory_kwargs,
                )
                for _ in range(self.depth_triple_blocks)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(
                    hidden_size=self.hidden_size,
                    num_heads=self.num_heads,
                    mlp_ratio=self.mlp_ratio,
                    qk_norm_type=self.qk_norm_type,
                    **factory_kwargs,
                )
                for _ in range(self.depth_single_blocks)
            ]
        )

        self.final_layer = FinalLayer1D(
            self.hidden_size, self.patch_size, self.out_channels, get_activation_layer("silu"), **factory_kwargs
        )
        self.unpatchify_channels = self.out_channels

        # Learnable "no visual input" features substituted for dropped samples in forward().
        self.empty_clip_feat = nn.Parameter(torch.zeros(1, self.visual_in_channels), requires_grad=True)
        self.empty_sync_feat = nn.Parameter(torch.zeros(1, self.sync_feat_dim), requires_grad=True)
        nn.init.constant_(self.empty_clip_feat, 0)
        nn.init.constant_(self.empty_sync_feat, 0)

    def get_empty_string_sequence(self, bs=None) -> torch.Tensor:
        """Return ``empty_string_feat``, optionally expanded to batch size ``bs``.

        NOTE(review): ``empty_string_feat`` is not created in ``__init__``; unless it
        is assigned elsewhere, calling this raises AttributeError.
        """
        if bs is None:
            return self.empty_string_feat
        return self.empty_string_feat.unsqueeze(0).expand(bs, -1, -1)

    def get_empty_clip_sequence(self, bs=None, len=None) -> torch.Tensor:
        """Return the learnable empty CLIP feature expanded to ``len`` tokens (default ``clip_len``).

        Shape is ``[len, clip_dim]``, or ``[bs, len, clip_dim]`` when ``bs`` is given.
        (``len`` shadows the builtin but is kept for keyword compatibility.)
        """
        len = len if len is not None else self.clip_len
        if bs is None:
            return self.empty_clip_feat.expand(len, -1)  # 15s
        return self.empty_clip_feat.unsqueeze(0).expand(bs, len, -1)  # 15s

    def get_empty_sync_sequence(self, bs=None, len=None) -> torch.Tensor:
        """Return the learnable empty sync feature expanded to ``len`` tokens (default ``sync_len``).

        Shape is ``[len, sync_feat_dim]``, or ``[bs, len, sync_feat_dim]`` when ``bs`` is given.
        """
        len = len if len is not None else self.sync_len
        if bs is None:
            return self.empty_sync_feat.expand(len, -1)
        return self.empty_sync_feat.unsqueeze(0).expand(bs, len, -1)

    def _resolve_rope_dim_list(self) -> List[int]:
        """Return the 1-D RoPE dim list: the configured one, or ``[head_dim]`` by default."""
        head_dim = self.hidden_size // self.num_heads
        rope_dim_list = self.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim]  # 1-D RoPE uses the whole head dim
        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
        return rope_dim_list

    def build_rope_for_audio_visual(self, audio_emb_len, visual_cond_len):
        """Build separate 1-D RoPE tables for the audio and visual token sequences.

        The visual table is built with ``freq_scaling = audio_emb_len / visual_cond_len``
        so that both streams span a comparable positional range.

        Returns:
            ``(freqs_cos, freqs_sin, v_freqs_cos, v_freqs_sin)``.
        """
        assert self.patch_size == 1
        rope_dim_list = self._resolve_rope_dim_list()

        # ============================ RoPE for audio tokens ============================
        freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
            rope_dim_list=rope_dim_list,
            start=[audio_emb_len],
            theta=10000,
            use_real=True,
            theta_rescale_factor=1.0,
        )

        # ============================ RoPE for clip tokens =============================
        v_freqs_cos, v_freqs_sin = get_nd_rotary_pos_embed(
            rope_dim_list=rope_dim_list,
            start=[visual_cond_len],
            theta=10000,
            use_real=True,
            theta_rescale_factor=1.0,
            freq_scaling=1.0 * audio_emb_len / visual_cond_len,
        )
        return freqs_cos, freqs_sin, v_freqs_cos, v_freqs_sin

    def build_rope_for_interleaved_audio_visual(self, total_len):
        """Build one 1-D RoPE table for an interleaved audio/visual sequence of ``total_len`` tokens."""
        assert self.patch_size == 1
        freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
            rope_dim_list=self._resolve_rope_dim_list(),
            start=[total_len],
            theta=10000,
            use_real=True,
            theta_rescale_factor=1.0,
        )
        return freqs_cos, freqs_sin

    def _for_each_block(self, method_name: str, *args) -> None:
        """Call ``method_name(*args)`` on every triple block, then every single block, that defines it.

        Blocks without the hook are skipped instead of raising AttributeError.
        ``SingleStreamBlock`` always attends through PyTorch SDPA, so skipping it is
        harmless when it lacks these hooks.
        """
        for block in (*self.triple_blocks, *self.single_blocks):
            method = getattr(block, method_name, None)
            if method is not None:
                method(*args)

    def set_attn_mode(self, new_mode):
        """Set the attention backend on every block that supports it."""
        self._for_each_block("set_attn_mode", new_mode)

    def enable_deterministic(self):
        """Enable deterministic attention on every block that supports it."""
        self._for_each_block("enable_deterministic")

    def disable_deterministic(self):
        """Disable deterministic attention on every block that supports it."""
        self._for_each_block("disable_deterministic")

    def _embed_sync_feat(self, sync_feat: torch.Tensor, bs: int, tl: int) -> torch.Tensor:
        """Embed sync features with ``sync_in`` and resample them to ``tl`` audio tokens.

        ``sync_feat`` is viewed as segments of 8 tokens, and ``sync_pos_emb`` is added
        per in-segment position before ``sync_in``.

        Returns:
            Tensor ``[bs, tl, C]`` where C is the output width of ``sync_in``.
        """
        assert sync_feat is not None and sync_feat.shape[1] % 8 == 0
        sync_feat = sync_feat.view(bs, sync_feat.shape[1] // 8, 8, self.sync_feat_dim) + self.sync_pos_emb
        sync_feat = sync_feat.view(bs, -1, self.sync_feat_dim)  # bs, num_segments * 8, channels
        sync_emb = self.sync_in(sync_feat)  # bs, num_segments * 8, c
        # Nearest-exact resampling along the token axis to match the audio length.
        return F.interpolate(sync_emb.transpose(1, 2), size=tl, mode="nearest-exact").contiguous().transpose(1, 2)

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,  # Should be in range(0, 1000).
        clip_feat: Optional[torch.Tensor] = None,
        cond: torch.Tensor = None,
        audio_mask: Optional[torch.Tensor] = None,
        cond_mask: torch.Tensor = None,
        sync_feat: Optional[torch.Tensor] = None,
        drop_visual: Optional[List[bool]] = None,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """Run the model on audio latents ``x`` at timestep ``t``.

        Args:
            x: Audio latents, unpacked as ``(bs, channels, ol)``.
            t: Timesteps, expected in ``range(0, 1000)``.
            clip_feat: Visual features, projected by ``visual_proj``.
            cond: Text condition features, projected by ``cond_in``.
            audio_mask: Not read; an all-ones audio mask is built internally when
                ``use_attention_mask`` is enabled.
            cond_mask: Boolean text mask; required when ``use_attention_mask`` is set.
            sync_feat: Sync features ``(bs, 8 * k, sync_feat_dim)``; required when
                ``sync_modulation`` or ``add_sync_feat_to_audio`` is set.
            drop_visual: Boolean per-sample mask. Selected samples get the learnable
                empty clip/sync features. The caller's tensors are not modified.
            return_dict: If True, return ``{"x": output}``; otherwise return the tensor.

        Returns:
            Output of shape ``(bs, out_channels, tl)``.

        Raises:
            ValueError: If both ``sync_modulation`` and ``add_sync_feat_to_audio``
                are enabled; the two sync paths are mutually exclusive.
        """
        if self.sync_modulation and self.add_sync_feat_to_audio:
            raise ValueError("sync_modulation and add_sync_feat_to_audio are mutually exclusive.")

        out = {}
        bs, _, ol = x.shape
        tl = ol // self.patch_size

        # Prepare learnable empty conditions for visual condition. Clone first so the
        # caller's tensors are not modified in place, and size the empty features to
        # the actual sequence lengths.
        if drop_visual is not None:
            clip_feat = clip_feat.clone()
            clip_feat[drop_visual] = self.get_empty_clip_sequence(len=clip_feat.shape[1]).to(dtype=clip_feat.dtype)
            if sync_feat is not None:
                sync_feat = sync_feat.clone()
                sync_feat[drop_visual] = self.get_empty_sync_sequence(len=sync_feat.shape[1]).to(
                    dtype=sync_feat.dtype
                )

        # ========================= Prepare time & visual modulation =========================
        vec = self.time_in(t)
        sync_vec = None
        sync_feat_for_audio = None
        if self.sync_modulation:
            sync_vec = self._embed_sync_feat(sync_feat, bs, tl) + vec.unsqueeze(1)  # bs, tl, c
        elif self.add_sync_feat_to_audio:
            sync_feat_for_audio = self._embed_sync_feat(sync_feat, bs, tl)  # bs, tl, c

        # ========================= Get text, audio and video clip embedding =========================
        cond = self.cond_in(cond)
        cond_seq_len = cond.shape[1]

        audio = self.audio_embedder(x)
        audio_seq_len = audio.shape[1]
        v_cond = self.visual_proj(clip_feat)
        v_cond_seq_len = v_cond.shape[1]

        # ========================= Compute attention mask =========================
        # NOTE(review): this mask spans [cond, v_cond, audio], but TwoStreamCABlock uses
        # it only for its self-attention over [v_cond, audio], and its text
        # cross-attention receives no mask. With use_attention_mask enabled the sizes
        # look inconsistent — confirm before relying on it.
        attn_mask = None
        if self.use_attention_mask:
            assert cond_mask is not None
            batch_size = audio.shape[0]
            seq_len = cond_seq_len + v_cond_seq_len + audio_seq_len

            # get default audio_mask and v_cond_mask
            audio_mask = torch.ones((batch_size, audio_seq_len), dtype=torch.bool, device=audio.device)
            v_cond_mask = torch.ones((batch_size, v_cond_seq_len), dtype=torch.bool, device=audio.device)

            # batch_size x seq_len
            concat_mask = torch.cat([cond_mask, v_cond_mask, audio_mask], dim=1)
            # batch_size x 1 x seq_len x seq_len
            attn_mask_1 = concat_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
            # batch_size x 1 x seq_len x seq_len
            attn_mask_2 = attn_mask_1.transpose(2, 3)
            # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of num_heads
            attn_mask = (attn_mask_1 & attn_mask_2).bool()
            # avoids self-attention weight being NaN for text padding tokens
            attn_mask[:, :, :, 0] = True

        # ========================= Build rope for audio and clip tokens =========================
        if self.interleaved_audio_visual_rope:
            # interleave_two_sequences doubles the audio length.
            freqs_cos, freqs_sin = self.build_rope_for_interleaved_audio_visual(audio_seq_len * 2)
            v_freqs_cos = v_freqs_sin = None
        else:
            freqs_cos, freqs_sin, v_freqs_cos, v_freqs_sin = self.build_rope_for_audio_visual(
                audio_seq_len, v_cond_seq_len
            )

        # ========================= Pass through DiT blocks =========================
        freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
        v_freqs_cis = (v_freqs_cos, v_freqs_sin) if v_freqs_cos is not None else None

        if self.add_sync_feat_to_audio:
            add_sync_layer = 0
            assert (
                add_sync_layer < self.depth_triple_blocks
            ), f"The layer to add mel_spectrogram feature and sync feature should in the triple_stream_blocks (n: {self.depth_triple_blocks})."

        # Triple-stream blocks
        for layer_num, block in enumerate(self.triple_blocks):
            if self.add_sync_feat_to_audio and layer_num == add_sync_layer:
                audio = audio + sync_feat_for_audio
            triple_block_args = [audio, cond, v_cond, attn_mask, vec, freqs_cis, v_freqs_cis, sync_vec]
            if (
                self.training
                and self.gradient_checkpoint
                and (self.gradient_checkpoint_layers == -1 or layer_num < self.gradient_checkpoint_layers)
            ):
                audio, cond, v_cond = torch.utils.checkpoint.checkpoint(
                    ckpt_wrapper(block), *triple_block_args, use_reentrant=False
                )
            else:
                audio, cond, v_cond = block(*triple_block_args)

        x = audio
        # Single-stream blocks run on the audio tokens only, so their modulation must be
        # per audio token. sync_vec is already [bs, tl, c] and includes the timestep
        # embedding. (The previous code concatenated cond/visual-length copies of vec,
        # which cannot broadcast against the audio-only x.)
        if sync_vec is not None:
            vec = sync_vec

        # Audio-only RoPE for the single-stream blocks (also in interleaved mode).
        freqs_cos, freqs_sin, _, _ = self.build_rope_for_audio_visual(audio_seq_len, v_cond_seq_len)
        if self.add_sync_feat_to_audio:
            vec = sync_feat_for_audio + vec.unsqueeze(dim=1)

        for layer_num, block in enumerate(self.single_blocks):
            single_block_args = [x, vec, (freqs_cos, freqs_sin)]
            if (
                self.training
                and self.gradient_checkpoint
                and (
                    self.gradient_checkpoint_layers == -1
                    or layer_num + len(self.triple_blocks) < self.gradient_checkpoint_layers
                )
            ):
                x = torch.utils.checkpoint.checkpoint(ckpt_wrapper(block), *single_block_args, use_reentrant=False)
            else:
                x = block(*single_block_args)

        # ========================= Final layer =========================
        # vec already equals sync_vec when sync modulation is active.
        audio = self.final_layer(x, vec)  # (N, T, patch_size * out_channels)
        audio = self.unpatchify1d(audio, tl)

        if return_dict:
            out["x"] = audio
            return out
        return audio

    def unpatchify1d(self, x, l):
        """Convert tokens ``(N, L, patch_size * C)`` into a signal ``(N, C, L * patch_size)``."""
        c = self.unpatchify_channels
        p = self.patch_size

        assert l == x.shape[1]
        x = x.reshape(shape=(x.shape[0], l, p, c))
        x = torch.einsum("ntpc->nctp", x)
        audio = x.reshape(shape=(x.shape[0], c, l * p))
        return audio

    def params_count(self):
        """Count parameters in the attention and MLP projections of each block type, plus the model total.

        Returns:
            Dict with keys ``"triple"``, ``"single"``, ``"total"`` and ``"attn+mlp"``
            (``triple + single``).
        """
        counts = {
            "triple": sum(
                [
                    sum(p.numel() for p in block.audio_cross_q.parameters())
                    + sum(p.numel() for p in block.v_cond_cross_q.parameters())
                    + sum(p.numel() for p in block.text_cross_kv.parameters())
                    + sum(p.numel() for p in block.audio_self_attn_qkv.parameters())
                    + sum(p.numel() for p in block.v_cond_attn_qkv.parameters())
                    + sum(p.numel() for p in block.audio_mlp.parameters())
                    + sum(p.numel() for p in block.audio_self_proj.parameters())
                    + sum(p.numel() for p in block.v_cond_self_proj.parameters())
                    + sum(p.numel() for p in block.v_cond_mlp.parameters())
                    for block in self.triple_blocks
                ]
            ),
            # Includes the fused qkv projection, which the triple count also includes for its blocks.
            "single": sum(
                [
                    sum(p.numel() for p in block.linear_qkv.parameters())
                    + sum(p.numel() for p in block.linear1.parameters())
                    + sum(p.numel() for p in block.linear2.parameters())
                    for block in self.single_blocks
                ]
            ),
            "total": sum(p.numel() for p in self.parameters()),
        }
        counts["attn+mlp"] = counts["triple"] + counts["single"]
        return counts