from typing import Optional

import torch
from diffusers.models.embeddings import get_2d_rotary_pos_embed_lumina
from transformers import PretrainedConfig, PreTrainedModel

from blip3o.model.lumina_nextdit2d import LuminaNextDiT2DModel


class NextDiTCrossAttnConfig(PretrainedConfig):
    """Configuration for NextDiTCrossAttn, a Lumina-NextDiT diffusion
    transformer conditioned on latent embeddings through cross-attention."""

    model_type = "nextdit-crossattn"

    def __init__(
        self,
        input_size: int = 8,
        patch_size: int = 1,
        in_channels: int = 1792,
        dim: int = 1792,
        n_layers: int = 24,
        n_heads: int = 28,
        n_kv_heads: int = 28,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps: float = 1e-5,
        latent_embedding_size: int = 3584,
        learn_sigma: bool = False,
        qk_norm: bool = True,
        _gradient_checkpointing: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.input_size = input_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.dim = dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.multiple_of = multiple_of
        self.ffn_dim_multiplier = ffn_dim_multiplier
        self.norm_eps = norm_eps
        self.learn_sigma = learn_sigma
        self.qk_norm = qk_norm
        self.latent_embedding_size = latent_embedding_size
        self._gradient_checkpointing = _gradient_checkpointing


class NextDiTCrossAttn(PreTrainedModel):
    """Wraps LuminaNextDiT2DModel behind the Hugging Face PreTrainedModel
    interface; conditioning latents are injected via cross-attention."""

    config_class = NextDiTCrossAttnConfig

    def __init__(
        self,
        config: NextDiTCrossAttnConfig,
    ) -> None:
        super().__init__(config)
        assert config.learn_sigma is False, "learn_sigma is not supported in nextdit-crossattn"
        self._gradient_checkpointing = config._gradient_checkpointing

        self.model = LuminaNextDiT2DModel(
            sample_size=config.input_size,
            patch_size=config.patch_size,
            in_channels=config.in_channels,
            hidden_size=config.dim,
            num_layers=config.n_layers,
            num_attention_heads=config.n_heads,
            num_kv_heads=config.n_kv_heads,
            multiple_of=config.multiple_of,
            ffn_dim_multiplier=config.ffn_dim_multiplier,
            norm_eps=config.norm_eps,
            learn_sigma=config.learn_sigma,
            qk_norm=config.qk_norm,
            cross_attention_dim=config.latent_embedding_size,
        )
        if self._gradient_checkpointing:
            self.model.enable_gradient_checkpointing()
        # self.model.requires_grad_(False)

        # Precompute 2D rotary position embeddings (per-head dim) for a
        # 384x384 grid; passed to the DiT as image_rotary_emb on every call.
        self.freqs_cis = get_2d_rotary_pos_embed_lumina(
            config.dim // config.n_heads,
            384,
            384,
        )

    def forward(self, x, timestep, z_latents, **kwargs):
        # z_latents condition the DiT through cross-attention; the all-ones
        # encoder_mask marks every conditioning token as valid.
        model_pred = self.model(
            hidden_states=x,
            timestep=timestep,
            encoder_hidden_states=z_latents,
            encoder_mask=torch.ones((z_latents.shape[0], z_latents.shape[1]), device=z_latents.device),
            image_rotary_emb=self.freqs_cis,
            cross_attention_kwargs=dict(),
        ).sample
        return model_pred
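

# --- Usage sketch (illustration only, not part of the original module) ---
# A minimal, hedged example of constructing the wrapper and running one
# denoising forward pass. The tensor shapes below are assumptions based on the
# default config (8x8 latent grid, 1792 input channels, 3584-dim conditioning
# latents) and the diffusers-style LuminaNextDiT2DModel interface; adjust them
# to whatever the surrounding training / sampling code actually provides.
if __name__ == "__main__":
    config = NextDiTCrossAttnConfig()
    model = NextDiTCrossAttn(config)
    num_params = sum(p.numel() for p in model.parameters())
    print(f"parameters: {num_params / 1e6:.1f}M")

    batch = 2
    x = torch.randn(batch, config.in_channels, config.input_size, config.input_size)  # noisy latents
    timestep = torch.rand(batch)                                                      # per-sample timesteps
    z_latents = torch.randn(batch, 64, config.latent_embedding_size)                  # conditioning tokens (assumed 64)
    with torch.no_grad():
        pred = model(x, timestep, z_latents)
    print(pred.shape)  # expected to match x: (batch, in_channels, input_size, input_size)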