# MoDA-PLUS: src/models/dit/blocks.py
import torch
import torch.nn as nn
import numbers
from .modules import RMSNorm, SelfAttention, CrossAttention, Mlp, MMdual_attention, MMsingle_attention, MMfour_attention
from einops import rearrange, repeat
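# adaLN modulation helper: shift/scale are (B, C) conditioning vectors that are
# broadcast over the sequence dimension of x, which has shape (B, N, C).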
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
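# Xavier-uniform initialization for nn.Linear weights (biases zeroed); applied by
# each block's initialize_weights() before the adaLN projections are re-zeroed.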
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
#################################################################################
# Core DiT Model #
#################################################################################
class DiTBlock(nn.Module):
"""
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning: self-attention followed by an MLP.
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
super().__init__()
norm_type = block_kwargs.get("norm_type", "rms_norm")
assert norm_type in ["layer_norm", "rms_norm"]
make_norm_layer = (
nn.LayerNorm if norm_type == "layer_norm" else RMSNorm
)
self.norm1 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn1 = SelfAttention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
self.norm2 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def initialize_weights(self):
self.apply(_basic_init)
# Zero-out adaLN modulation layers in DiT blocks:
nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
    def forward(self, x, c, mask=None, freqs_cis=None):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
        # Gated residual branches: self-attention, then MLP.
        x = x + gate_msa.unsqueeze(1) * self.attn1(modulate(self.norm1(x), shift_msa, scale_msa), mask, freqs_cis)
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
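# Minimal usage sketch for DiTBlock (assumes SelfAttention from .modules accepts
# mask=None and freqs_cis=None, as the default arguments above imply):
#     block = DiTBlock(hidden_size=512, num_heads=8)
#     block.initialize_weights()
#     x = torch.randn(2, 16, 512)   # (batch, tokens, hidden)
#     c = torch.randn(2, 512)       # per-sample conditioning vector
#     out = block(x, c)             # same shape as x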
class MMSingleStreamBlock(nn.Module):
    ''' A single-stream multimodal DiT block with separate modulation. '''
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
super().__init__()
norm_type = block_kwargs.get("norm_type", "rms_norm")
assert norm_type in ["layer_norm", "rms_norm"]
make_norm_layer = (
nn.LayerNorm if norm_type == "layer_norm" else RMSNorm
)
self.norm1 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn1 = MMsingle_attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
self.norm2 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6)
# self.attn2 = CrossAttention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
self.norm3 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6)
self.norm4 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
approx_gelu = lambda: nn.GELU(approximate="tanh")
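        # Fused projection: one matmul yields both the attention QKV (3 * hidden_size)
        # and the MLP hidden activations (mlp_hidden_dim); linear2 merges the two
        # branches back to hidden_size after attention and the GELU activation.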
        self.qkv_xs = nn.Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim, bias=True)
# self.xs_mlp = Mlp(in_features=hidden_size+mlp_hidden_dim, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
self.linear2 = nn.Linear(
hidden_size + mlp_hidden_dim, hidden_size,
)
self.mlp_act = approx_gelu()
self.adaLN_modulation_xs = nn.Sequential(
nn.SiLU(),
            nn.Linear(hidden_size, 3 * hidden_size, bias=True)
)
        self.hidden_size = hidden_size
        self.mlp_hidden_dim = mlp_hidden_dim
def initialize_weights(self):
self.apply(_basic_init)
# Zero-out adaLN modulation layers in DiT blocks:
nn.init.constant_(self.adaLN_modulation_xs[-1].weight, 0)
nn.init.constant_(self.adaLN_modulation_xs[-1].bias, 0)
    def forward(self, seq_len, x, c, mask=None, freqs_cis=None, freqs_cis2=None, causal=False):
        shift_msa_xs, scale_msa_xs, gate_msa_xs = self.adaLN_modulation_xs(c).chunk(3, dim=1)
        # Prepare for attention
        x_mod = modulate(self.norm1(x), shift_msa_xs, scale_msa_xs)
        # Split the fused projection into QKV (attention) and MLP hidden activations.
        qkv, mlp = torch.split(
            self.qkv_xs(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
        )
        att1 = self.attn1(seq_len, qkv, mask, causal=causal, freqs_cis=freqs_cis, freqs_cis2=freqs_cis2)
        # Merge the attention output with the activated MLP branch and project back to hidden_size.
        output = self.linear2(torch.cat((att1, self.mlp_act(mlp)), 2))
        x = x + gate_msa_xs.unsqueeze(1) * output
return x
class MMfourStreamBlock(nn.Module):
    ''' A four-stream multimodal DiT block with separate modulation for each stream. '''
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
super().__init__()
norm_type = block_kwargs.get("norm_type", "rms_norm")
assert norm_type in ["layer_norm", "rms_norm"]
make_norm_layer = (
nn.LayerNorm if norm_type == "layer_norm" else RMSNorm
)
self.norm1 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn1 = MMfour_attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
self.norm2 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6)
# self.attn2 = CrossAttention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
self.norm3 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6)
self.norm4 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6)
self.norm5 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6)
self.norm6 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6)
self.norm7 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6)
self.norm8 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.xs_mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
self.audio_mlp1 = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
self.audio_mlp2 = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
self.audio_mlp3 = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
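        # One adaLN-Zero projection per stream: each produces shift/scale/gate for
        # the attention branch and for the MLP branch (6 * hidden_size in total).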
self.adaLN_modulation_xs = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
self.adaLN_modulation_audio1 = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
        self.adaLN_modulation_audio2 = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )
        self.adaLN_modulation_audio3 = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )
def initialize_weights(self):
self.apply(_basic_init)
# Zero-out adaLN modulation layers in DiT blocks:
nn.init.constant_(self.adaLN_modulation_xs[-1].weight, 0)
nn.init.constant_(self.adaLN_modulation_xs[-1].bias, 0)
nn.init.constant_(self.adaLN_modulation_audio1[-1].weight, 0)
nn.init.constant_(self.adaLN_modulation_audio1[-1].bias, 0)
nn.init.constant_(self.adaLN_modulation_audio2[-1].weight, 0)
nn.init.constant_(self.adaLN_modulation_audio2[-1].bias, 0)
nn.init.constant_(self.adaLN_modulation_audio3[-1].weight, 0)
nn.init.constant_(self.adaLN_modulation_audio3[-1].bias, 0)
    def forward(self, x, c, y1, y2, y3, mask=None, freqs_cis=None, freqs_cis2=None, causal=False):
        shift_msa_xs, scale_msa_xs, gate_msa_xs, shift_mlp_xs, scale_mlp_xs, gate_mlp_xs = self.adaLN_modulation_xs(c).chunk(6, dim=1)
        shift_mca_audio1, scale_mca_audio1, gate_mca_audio1, shift_mlp_audio1, scale_mlp_audio1, gate_mlp_audio1 = self.adaLN_modulation_audio1(c).chunk(6, dim=1)
        shift_mca_audio2, scale_mca_audio2, gate_mca_audio2, shift_mlp_audio2, scale_mlp_audio2, gate_mlp_audio2 = self.adaLN_modulation_audio2(c).chunk(6, dim=1)
        shift_mca_audio3, scale_mca_audio3, gate_mca_audio3, shift_mlp_audio3, scale_mlp_audio3, gate_mlp_audio3 = self.adaLN_modulation_audio3(c).chunk(6, dim=1)
        # Joint attention over the four modulated streams.
        att1, att2, att3, att4 = self.attn1(
            modulate(self.norm1(x), shift_msa_xs, scale_msa_xs),
            modulate(self.norm2(y1), shift_mca_audio1, scale_mca_audio1),
            modulate(self.norm3(y2), shift_mca_audio2, scale_mca_audio2),
            modulate(self.norm4(y3), shift_mca_audio3, scale_mca_audio3),
            mask, causal=causal, freqs_cis=freqs_cis, freqs_cis2=freqs_cis2)
        # Gated attention residuals, one per stream.
        x = x + gate_msa_xs.unsqueeze(1) * att1
        y1 = y1 + gate_mca_audio1.unsqueeze(1) * att2
        y2 = y2 + gate_mca_audio2.unsqueeze(1) * att3
        y3 = y3 + gate_mca_audio3.unsqueeze(1) * att4
        # Gated MLP residuals, one per stream.
        x = x + gate_mlp_xs.unsqueeze(1) * self.xs_mlp(modulate(self.norm5(x), shift_mlp_xs, scale_mlp_xs))
        y1 = y1 + gate_mlp_audio1.unsqueeze(1) * self.audio_mlp1(modulate(self.norm6(y1), shift_mlp_audio1, scale_mlp_audio1))
        y2 = y2 + gate_mlp_audio2.unsqueeze(1) * self.audio_mlp2(modulate(self.norm7(y2), shift_mlp_audio2, scale_mlp_audio2))
        y3 = y3 + gate_mlp_audio3.unsqueeze(1) * self.audio_mlp3(modulate(self.norm8(y3), shift_mlp_audio3, scale_mlp_audio3))
        return x, y1, y2, y3
class MMDoubleStreamBlock(nn.Module):
    ''' A dual-stream multimodal DiT block with separate modulation for each stream. '''
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
super().__init__()
norm_type = block_kwargs.get("norm_type", "rms_norm")
assert norm_type in ["layer_norm", "rms_norm"]
make_norm_layer = (
nn.LayerNorm if norm_type == "layer_norm" else RMSNorm
)
self.norm1 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn1 = MMdual_attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
self.norm2 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6)
# self.attn2 = CrossAttention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
self.norm3 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6)
self.norm4 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.xs_mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
self.audio_mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
self.adaLN_modulation_xs = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
self.adaLN_modulation_audio = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def initialize_weights(self):
self.apply(_basic_init)
# Zero-out adaLN modulation layers in DiT blocks:
nn.init.constant_(self.adaLN_modulation_xs[-1].weight, 0)
nn.init.constant_(self.adaLN_modulation_xs[-1].bias, 0)
nn.init.constant_(self.adaLN_modulation_audio[-1].weight, 0)
nn.init.constant_(self.adaLN_modulation_audio[-1].bias, 0)
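    # forward(): x is the primary token stream and y the secondary (audio) token stream;
    # c is a per-sample conditioning vector of shape (B, hidden_size); seq_len is
    # passed through to the joint-attention module unchanged.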
    def forward(self, seq_len, x, c, y, mask=None, freqs_cis=None, freqs_cis2=None, causal=False):
        shift_msa_xs, scale_msa_xs, gate_msa_xs, shift_mlp_xs, scale_mlp_xs, gate_mlp_xs = self.adaLN_modulation_xs(c).chunk(6, dim=1)
        shift_mca_audio, scale_mca_audio, gate_mca_audio, shift_mlp_audio, scale_mlp_audio, gate_mlp_audio = self.adaLN_modulation_audio(c).chunk(6, dim=1)
        # Joint attention over the two modulated streams.
        att1, att2 = self.attn1(
            seq_len,
            modulate(self.norm1(x), shift_msa_xs, scale_msa_xs),
            modulate(self.norm2(y), shift_mca_audio, scale_mca_audio),
            mask, causal=causal, freqs_cis=freqs_cis, freqs_cis2=freqs_cis2)
        # Gated attention and MLP residuals for each stream.
        x = x + gate_msa_xs.unsqueeze(1) * att1
        y = y + gate_mca_audio.unsqueeze(1) * att2
        x = x + gate_mlp_xs.unsqueeze(1) * self.xs_mlp(modulate(self.norm3(x), shift_mlp_xs, scale_mlp_xs))
        y = y + gate_mlp_audio.unsqueeze(1) * self.audio_mlp(modulate(self.norm4(y), shift_mlp_audio, scale_mlp_audio))
        return x, y
class CrossDiTBlock(nn.Module):
"""
A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning, contains CrossAttention.
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
super().__init__()
norm_type = block_kwargs.get("norm_type", "rms_norm")
assert norm_type in ["layer_norm", "rms_norm"]
make_norm_layer = (
nn.LayerNorm if norm_type == "layer_norm" else RMSNorm
)
self.norm1 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn1 = SelfAttention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
self.norm2 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn2 = CrossAttention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
self.norm3 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 9 * hidden_size, bias=True)
)
def initialize_weights(self):
self.apply(_basic_init)
# Zero-out adaLN modulation layers in DiT blocks:
nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
    def forward(self, x, c, y, mask=None):
        shift_msa, scale_msa, gate_msa, shift_mca, scale_mca, gate_mca, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(9, dim=1)
        # Gated residual branches: self-attention, cross-attention to y, then MLP.
        x = x + gate_msa.unsqueeze(1) * self.attn1(modulate(self.norm1(x), shift_msa, scale_msa), mask)
        x = x + gate_mca.unsqueeze(1) * self.attn2(modulate(self.norm2(x), shift_mca, scale_mca), y, mask)
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm3(x), shift_mlp, scale_mlp))
return x
class SelfBlock(nn.Module):
"""
    A plain pre-norm self-attention block with a residual connection (no adaLN conditioning).
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
super().__init__()
norm_type = block_kwargs.get("norm_type", "rms_norm")
assert norm_type in ["layer_norm", "rms_norm"]
make_norm_layer = (
nn.LayerNorm if norm_type == "layer_norm" else RMSNorm
)
self.norm2 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn2 = SelfAttention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
    def initialize_weights(self):
        self.apply(_basic_init)
    def forward(self, x, y, mask=None):
        # y is accepted but unused in this block.
        x = x + self.attn2(self.norm2(x), mask)
        return x
class CrossBlock(nn.Module):
"""
    A plain pre-norm cross-attention block with a residual connection (no adaLN conditioning).
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
super().__init__()
norm_type = block_kwargs.get("norm_type", "rms_norm")
assert norm_type in ["layer_norm", "rms_norm"]
make_norm_layer = (
nn.LayerNorm if norm_type == "layer_norm" else RMSNorm
)
self.norm2 = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn2 = CrossAttention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
    def initialize_weights(self):
        self.apply(_basic_init)
    def forward(self, x, y, mask=None):
        x = x + self.attn2(self.norm2(x), y, mask)
return x
class FinalLayer(nn.Module):
"""
The final layer of DiT.
"""
def __init__(self, hidden_size, out_channels, norm_type="rms_norm"):
super().__init__()
assert norm_type in ["layer_norm", "rms_norm"]
make_norm_layer = (
nn.LayerNorm if norm_type == "layer_norm" else RMSNorm
)
self.norm_final = make_norm_layer(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 2 * hidden_size, bias=True)
)
def initialize_weights(self):
self.apply(_basic_init)
# Zero-out output layers:
nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.linear.weight, 0)
nn.init.constant_(self.linear.bias, 0)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
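

if __name__ == "__main__":
    # Minimal smoke test for FinalLayer; run as a module (e.g. `python -m src.models.dit.blocks`)
    # so the relative import of .modules resolves. Assumes the custom RMSNorm preserves the
    # input shape, as its usage throughout this file implies.
    hidden_size, out_channels = 512, 8
    layer = FinalLayer(hidden_size, out_channels)
    layer.initialize_weights()
    x = torch.randn(2, 16, hidden_size)  # (batch, tokens, hidden)
    c = torch.randn(2, hidden_size)      # per-sample conditioning vector
    out = layer(x, c)
    print(out.shape)  # torch.Size([2, 16, 8]); all zeros right after initialize_weights()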