import torch import torch.nn as nn import numbers from .modules import RMSNorm, SelfAttention, CrossAttention, Mlp,MMdual_attention,MMsingle_attention,MMfour_attention from einops import rearrange, repeat def modulate(x, shift, scale): return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) def _basic_init(module): if isinstance(module, nn.Linear): torch.nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.constant_(module.bias, 0) ################################################################################# # Core DiT Model # ################################################################################# class DiTBlock(nn.Module): """ A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning, contains CrossAttention. """ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): super().__init__() norm_type = block_kwargs.get("norm_type", "rms_norm") assert norm_type in ["layer_norm", "rms_norm"] make_norm_layer = ( nn.LayerNorm if norm_type == "layer_norm" else RMSNorm ) self.norm1 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6) self.attn1 = SelfAttention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) self.norm2 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6) mlp_hidden_dim = int(hidden_size * mlp_ratio) approx_gelu = lambda: nn.GELU(approximate="tanh") self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) self.adaLN_modulation = nn.Sequential( nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True) ) def initialize_weights(self): self.apply(_basic_init) # Zero-out adaLN modulation layers in DiT blocks: nn.init.constant_(self.adaLN_modulation[-1].weight, 0) nn.init.constant_(self.adaLN_modulation[-1].bias, 0) def forward(self, x, c,mask=None,freqs_cis=None): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) x = x + gate_msa.unsqueeze(1) * self.attn1(modulate(self.norm1(x), shift_msa, scale_msa),mask,freqs_cis) x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) return x class MMSingleStreamBlock(nn.Module): ''' A multimodal dit block with seperate modulation ''' def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): super().__init__() norm_type = block_kwargs.get("norm_type", "rms_norm") assert norm_type in ["layer_norm", "rms_norm"] make_norm_layer = ( nn.LayerNorm if norm_type == "layer_norm" else RMSNorm ) self.norm1 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6) self.attn1 = MMsingle_attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) self.norm2 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6) # self.attn2 = CrossAttention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) self.norm3 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6) self.norm4 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6) mlp_hidden_dim = int(hidden_size * mlp_ratio) approx_gelu = lambda: nn.GELU(approximate="tanh") self.qkv_xs = nn.Linear(hidden_size, hidden_size * 3+mlp_hidden_dim, bias=True) # self.xs_mlp = Mlp(in_features=hidden_size+mlp_hidden_dim, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) self.linear2 = nn.Linear( hidden_size + mlp_hidden_dim, hidden_size, ) self.mlp_act = approx_gelu() self.adaLN_modulation_xs = nn.Sequential( nn.SiLU(), nn.Linear(hidden_size, 3* hidden_size, bias=True) ) self.hidden_size=hidden_size self.mlp_hidden_dim=mlp_hidden_dim def initialize_weights(self): self.apply(_basic_init) # Zero-out adaLN modulation layers in DiT blocks: nn.init.constant_(self.adaLN_modulation_xs[-1].weight, 0) nn.init.constant_(self.adaLN_modulation_xs[-1].bias, 0) def forward(self,seq_len, x, c,mask=None,freqs_cis=None,freqs_cis2=None,causal=False): shift_msa_xs, scale_msa_xs, gate_msa_xs = self.adaLN_modulation_xs(c).chunk(3, dim=1) # Prepare for attention x_mod=modulate(self.norm1(x), shift_msa_xs, scale_msa_xs) qkv, mlp = torch.split( self.qkv_xs(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1 ) att1= self.attn1(seq_len,qkv,mask,causal=causal,freqs_cis=freqs_cis,freqs_cis2=freqs_cis2) output=self.linear2(torch.cat((att1, self.mlp_act(mlp)), 2)) x=x+gate_msa_xs.unsqueeze(1)*output return x class MMfourStreamBlock(nn.Module): ''' A multimodal dit block with seperate modulation ''' def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): super().__init__() norm_type = block_kwargs.get("norm_type", "rms_norm") assert norm_type in ["layer_norm", "rms_norm"] make_norm_layer = ( nn.LayerNorm if norm_type == "layer_norm" else RMSNorm ) self.norm1 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6) self.attn1 = MMfour_attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) self.norm2 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6) # self.attn2 = CrossAttention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) self.norm3 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6) self.norm4 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6) self.norm5 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6) self.norm6 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6) self.norm7 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6) self.norm8 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6) mlp_hidden_dim = int(hidden_size * mlp_ratio) approx_gelu = lambda: nn.GELU(approximate="tanh") self.xs_mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) self.audio_mlp1 = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) self.audio_mlp2 = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) self.audio_mlp3 = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) self.adaLN_modulation_xs = nn.Sequential( nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True) ) self.adaLN_modulation_audio1 = nn.Sequential( nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True) ) self.adaLN_modulation_audio2 = nn.Sequential( nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)) self.adaLN_modulation_audio3 = nn.Sequential( nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)) def initialize_weights(self): self.apply(_basic_init) # Zero-out adaLN modulation layers in DiT blocks: nn.init.constant_(self.adaLN_modulation_xs[-1].weight, 0) nn.init.constant_(self.adaLN_modulation_xs[-1].bias, 0) nn.init.constant_(self.adaLN_modulation_audio1[-1].weight, 0) nn.init.constant_(self.adaLN_modulation_audio1[-1].bias, 0) nn.init.constant_(self.adaLN_modulation_audio2[-1].weight, 0) nn.init.constant_(self.adaLN_modulation_audio2[-1].bias, 0) nn.init.constant_(self.adaLN_modulation_audio3[-1].weight, 0) nn.init.constant_(self.adaLN_modulation_audio3[-1].bias, 0) def forward(self, x, c, y1,y2,y3,mask=None,freqs_cis=None,freqs_cis2=None,causal=False): shift_msa_xs, scale_msa_xs, gate_msa_xs, shift_mlp_xs, scale_mlp_xs, gate_mlp_xs = self.adaLN_modulation_xs(c).chunk(6, dim=1) shift_mca_audio1, scale_mca_audio1, gate_mca_audio1, shift_mlp_audio1, scale_mlp_audio1, gate_mlp_audio1 = self.adaLN_modulation_audio1(c).chunk(6, dim=1) shift_mca_audio2, scale_mca_audio2, gate_mca_audio2, shift_mlp_audio2, scale_mlp_audio2, gate_mlp_audio2 = self.adaLN_modulation_audio2(c).chunk(6, dim=1) shift_mca_audio3, scale_mca_audio3, gate_mca_audio3, shift_mlp_audio3, scale_mlp_audio3, gate_mlp_audio3= self.adaLN_modulation_audio3(c).chunk(6, dim=1) # Prepare for attention att1,att2,att3,att4= self.attn1( modulate(self.norm1(x), shift_msa_xs, scale_msa_xs), modulate(self.norm2(y1), shift_mca_audio1, scale_mca_audio1), modulate(self.norm3(y2), shift_mca_audio2, scale_mca_audio2), modulate(self.norm4(y3), shift_mca_audio3, scale_mca_audio3), mask,causal=causal,freqs_cis=freqs_cis,freqs_cis2=freqs_cis2) x=x+gate_msa_xs.unsqueeze(1)*att1 y1=y1+gate_mca_audio1.unsqueeze(1)*att2 y2=y2+gate_mca_audio2.unsqueeze(1)*att3 y3=y3+gate_mca_audio3.unsqueeze(1)*att4 x = x + gate_mlp_xs.unsqueeze(1) * self.xs_mlp(modulate(self.norm5(x), shift_mlp_xs, scale_mlp_xs)) y1 = y1 + gate_mlp_audio1.unsqueeze(1) * self.audio_mlp1(modulate(self.norm6(y1), shift_mlp_audio1, scale_mlp_audio1)) y2 = y2 + gate_mlp_audio2.unsqueeze(1) * self.audio_mlp2(modulate(self.norm7(y2), shift_mlp_audio2, scale_mlp_audio2)) y3 = y3 + gate_mlp_audio3.unsqueeze(1) * self.audio_mlp3(modulate(self.norm8(y3), shift_mlp_audio3, scale_mlp_audio3)) return x,y1,y2,y3 class MMDoubleStreamBlock(nn.Module): ''' A multimodal dit block with seperate modulation ''' def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): super().__init__() norm_type = block_kwargs.get("norm_type", "rms_norm") assert norm_type in ["layer_norm", "rms_norm"] make_norm_layer = ( nn.LayerNorm if norm_type == "layer_norm" else RMSNorm ) self.norm1 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6) self.attn1 = MMdual_attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) self.norm2 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6) # self.attn2 = CrossAttention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) self.norm3 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6) self.norm4 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6) mlp_hidden_dim = int(hidden_size * mlp_ratio) approx_gelu = lambda: nn.GELU(approximate="tanh") self.xs_mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) self.audio_mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) self.adaLN_modulation_xs = nn.Sequential( nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True) ) self.adaLN_modulation_audio = nn.Sequential( nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True) ) def initialize_weights(self): self.apply(_basic_init) # Zero-out adaLN modulation layers in DiT blocks: nn.init.constant_(self.adaLN_modulation_xs[-1].weight, 0) nn.init.constant_(self.adaLN_modulation_xs[-1].bias, 0) nn.init.constant_(self.adaLN_modulation_audio[-1].weight, 0) nn.init.constant_(self.adaLN_modulation_audio[-1].bias, 0) def forward(self, seq_len,x, c, y,mask=None,freqs_cis=None,freqs_cis2=None,causal=False): shift_msa_xs, scale_msa_xs, gate_msa_xs, shift_mlp_xs, scale_mlp_xs, gate_mlp_xs = self.adaLN_modulation_xs(c).chunk(6, dim=1) shift_mca_audio, scale_mca_audio, gate_mca_audio, shift_mlp_audio, scale_mlp_audio, gate_mlp_audio = self.adaLN_modulation_audio(c).chunk(6, dim=1) # Prepare for attention att1,att2 = self.attn1(seq_len,modulate(self.norm1(x), shift_msa_xs, scale_msa_xs),modulate(self.norm2(y), shift_mca_audio, scale_mca_audio),mask,causal=causal,freqs_cis=freqs_cis,freqs_cis2=freqs_cis2) x=x+gate_msa_xs.unsqueeze(1)*att1 y=y+gate_mca_audio.unsqueeze(1)*att2 x = x + gate_mlp_xs.unsqueeze(1) * self.xs_mlp(modulate(self.norm3(x), shift_mlp_xs, scale_mlp_xs)) y = y + gate_mlp_audio.unsqueeze(1) * self.audio_mlp(modulate(self.norm4(y), shift_mlp_audio, scale_mlp_audio)) return x,y class CrossDiTBlock(nn.Module): """ A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning, contains CrossAttention. """ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): super().__init__() norm_type = block_kwargs.get("norm_type", "rms_norm") assert norm_type in ["layer_norm", "rms_norm"] make_norm_layer = ( nn.LayerNorm if norm_type == "layer_norm" else RMSNorm ) self.norm1 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6) self.attn1 = SelfAttention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) self.norm2 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6) self.attn2 = CrossAttention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) self.norm3 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6) mlp_hidden_dim = int(hidden_size * mlp_ratio) approx_gelu = lambda: nn.GELU(approximate="tanh") self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) self.adaLN_modulation = nn.Sequential( nn.SiLU(), nn.Linear(hidden_size, 9 * hidden_size, bias=True) ) def initialize_weights(self): self.apply(_basic_init) # Zero-out adaLN modulation layers in DiT blocks: nn.init.constant_(self.adaLN_modulation[-1].weight, 0) nn.init.constant_(self.adaLN_modulation[-1].bias, 0) def forward(self, x, c, y,mask=None): shift_msa, scale_msa, gate_msa, shift_mca, scale_mca, gate_mca, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(9, dim=1) x = x + gate_msa.unsqueeze(1) * self.attn1(modulate(self.norm1(x), shift_msa, scale_msa),mask) x = x + gate_mca.unsqueeze(1) * self.attn2(modulate(self.norm2(x), shift_mca, scale_mca), y,mask) x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm3(x), shift_mlp, scale_mlp)) return x class SelfBlock(nn.Module): """ A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning, contains CrossAttention. """ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): super().__init__() norm_type = block_kwargs.get("norm_type", "rms_norm") assert norm_type in ["layer_norm", "rms_norm"] make_norm_layer = ( nn.LayerNorm if norm_type == "layer_norm" else RMSNorm ) self.norm2 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6) self.attn2 = SelfAttention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) def initialize_weights(self): self.apply(_basic_init) # Zero-out adaLN modulation layers in DiT blocks: def forward(self, x, y,mask=None): x = x + self.attn2(self.norm2(x),mask) return x class CrossBlock(nn.Module): """ A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning, contains CrossAttention. """ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): super().__init__() norm_type = block_kwargs.get("norm_type", "rms_norm") assert norm_type in ["layer_norm", "rms_norm"] make_norm_layer = ( nn.LayerNorm if norm_type == "layer_norm" else RMSNorm ) self.norm2 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6) self.attn2 = CrossAttention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) def initialize_weights(self): self.apply(_basic_init) # Zero-out adaLN modulation layers in DiT blocks: def forward(self, x, y,mask=None): x = x + self.attn2(self.norm2(x), y,mask) return x class FinalLayer(nn.Module): """ The final layer of DiT. """ def __init__(self, hidden_size, out_channels, norm_type="rms_norm"): super().__init__() assert norm_type in ["layer_norm", "rms_norm"] make_norm_layer = ( nn.LayerNorm if norm_type == "layer_norm" else RMSNorm ) self.norm_final = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6) self.linear = nn.Linear(hidden_size, out_channels, bias=True) self.adaLN_modulation = nn.Sequential( nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True) ) def initialize_weights(self): self.apply(_basic_init) # Zero-out output layers: nn.init.constant_(self.adaLN_modulation[-1].weight, 0) nn.init.constant_(self.adaLN_modulation[-1].bias, 0) nn.init.constant_(self.linear.weight, 0) nn.init.constant_(self.linear.bias, 0) def forward(self, x, c): shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) x = modulate(self.norm_final(x), shift, scale) x = self.linear(x) return x