# blip3o/model/nextdit_crossattn.py

from typing import Optional

import torch
from diffusers.models.embeddings import get_2d_rotary_pos_embed_lumina
from transformers import PretrainedConfig, PreTrainedModel

from blip3o.model.lumina_nextdit2d import LuminaNextDiT2DModel


class NextDiTCrossAttnConfig(PretrainedConfig):
    """Config for the NextDiT cross-attention image head.

    The defaults describe an 8x8 grid of 1792-channel latents denoised by a
    24-layer, 28-head DiT that is conditioned, through cross-attention, on a
    sequence of 3584-dimensional embeddings.
    """

    model_type = "nextdit-crossattn"

def __init__(
self,
input_size: int = 8,
patch_size: int = 1,
in_channels: int = 1792,
dim: int = 1792,
n_layers: int = 24,
n_heads: int = 28,
n_kv_heads: int = 28,
multiple_of: int = 256,
ffn_dim_multiplier: Optional[float] = None,
norm_eps: float = 1e-5,
latent_embedding_size: int = 3584,
learn_sigma: bool = False,
qk_norm: bool = True,
_gradient_checkpointing: bool = True,
**kwargs,
):
super().__init__(**kwargs)
self.input_size = input_size
self.patch_size = patch_size
self.in_channels = in_channels
self.dim = dim
self.n_layers = n_layers
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.multiple_of = multiple_of
self.ffn_dim_multiplier = ffn_dim_multiplier
self.norm_eps = norm_eps
self.learn_sigma = learn_sigma
self.qk_norm = qk_norm
self.latent_embedding_size = latent_embedding_size
self._gradient_checkpointing = _gradient_checkpointing


class NextDiTCrossAttn(PreTrainedModel):
    """HF PreTrainedModel wrapper around the repo's LuminaNextDiT2DModel.

    Noisy image latents enter as hidden_states; the conditioning sequence
    (z_latents) is injected through cross-attention.
    """

    config_class = NextDiTCrossAttnConfig

def __init__(
self,
config: NextDiTCrossAttnConfig,
) -> None:
super().__init__(config)
assert config.learn_sigma is False, "learn_sigma is not supported in nextdit-crossattn"
self._gradient_checkpointing = config._gradient_checkpointing
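        # The backbone: the repo's Lumina-Next DiT. latent_embedding_size is
        # wired in as cross_attention_dim, the width of the conditioning tokens.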
self.model = LuminaNextDiT2DModel(
sample_size=config.input_size,
patch_size=config.patch_size,
in_channels=config.in_channels,
hidden_size=config.dim,
num_layers=config.n_layers,
num_attention_heads=config.n_heads,
num_kv_heads=config.n_kv_heads,
multiple_of=config.multiple_of,
ffn_dim_multiplier=config.ffn_dim_multiplier,
norm_eps=config.norm_eps,
learn_sigma=config.learn_sigma,
qk_norm=config.qk_norm,
cross_attention_dim=config.latent_embedding_size,
)
if self._gradient_checkpointing:
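            # Trade compute for memory: recompute block activations during
            # the backward pass instead of caching them.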
self.model.enable_gradient_checkpointing()
# self.model.requires_grad_(False)
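        # Precompute 2D rotary position embeddings (RoPE) once for a generous
        # 384x384 position grid (head dim = dim // n_heads); the backbone is
        # expected to slice these down to the actual latent grid.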
self.freqs_cis = get_2d_rotary_pos_embed_lumina(
config.dim // config.n_heads,
384,
384,
)

    def forward(self, x, timestep, z_latents, **kwargs):
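        # x:         noisy image latents, (B, in_channels, H, W)
        # timestep:  per-sample diffusion timestep, (B,)
        # z_latents: conditioning embeddings, (B, seq_len, latent_embedding_size)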
model_pred = self.model(
hidden_states=x,
timestep=timestep,
encoder_hidden_states=z_latents,
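            # All-ones mask: every conditioning token is attended to
            # (no padding is assumed here).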
encoder_mask=torch.ones((z_latents.shape[0], z_latents.shape[1]), device=z_latents.device),
image_rotary_emb=self.freqs_cis,
cross_attention_kwargs=dict(),
).sample
return model_pred
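

# A minimal smoke-test sketch (not part of the original file): assuming the
# blip3o package and diffusers are installed, build the default config and run
# one forward pass on dummy tensors. The conditioning length of 64 and the
# continuous timesteps are illustrative choices, not the training convention.
if __name__ == "__main__":
    config = NextDiTCrossAttnConfig(_gradient_checkpointing=False)
    model = NextDiTCrossAttn(config)

    batch = 2
    x = torch.randn(batch, config.in_channels, config.input_size, config.input_size)
    timestep = torch.rand(batch)  # hypothetical timesteps in [0, 1)
    z_latents = torch.randn(batch, 64, config.latent_embedding_size)

    with torch.no_grad():
        pred = model(x, timestep, z_latents)
    print(pred.shape)  # expected: (2, 1792, 8, 8) with the defaults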