from typing import Optional

import torch
from diffusers.models.embeddings import get_2d_rotary_pos_embed_lumina
from transformers import PretrainedConfig, PreTrainedModel

from blip3o.model.lumina_nextdit2d import LuminaNextDiT2DModel


class NextDiTCrossAttnConfig(PretrainedConfig):
    """Configuration for the NextDiT diffusion head conditioned via cross-attention."""

    model_type = "nextdit-crossattn"
    def __init__(
        self,
        input_size: int = 8,
        patch_size: int = 1,
        in_channels: int = 1792,
        dim: int = 1792,
        n_layers: int = 24,
        n_heads: int = 28,
        n_kv_heads: int = 28,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps: float = 1e-5,
        latent_embedding_size: int = 3584,
        learn_sigma: bool = False,
        qk_norm: bool = True,
        _gradient_checkpointing: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.input_size = input_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.dim = dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.multiple_of = multiple_of
        self.ffn_dim_multiplier = ffn_dim_multiplier
        self.norm_eps = norm_eps
        self.learn_sigma = learn_sigma
        self.qk_norm = qk_norm
        self.latent_embedding_size = latent_embedding_size
        self._gradient_checkpointing = _gradient_checkpointing


class NextDiTCrossAttn(PreTrainedModel):
    """HF PreTrainedModel wrapper around LuminaNextDiT2DModel with cross-attention conditioning."""

    config_class = NextDiTCrossAttnConfig

    def __init__(
        self,
        config: NextDiTCrossAttnConfig,
    ) -> None:
        super().__init__(config)
        assert config.learn_sigma is False, "learn_sigma is not supported in nextdit-crossattn"
        self._gradient_checkpointing = config._gradient_checkpointing

        self.model = LuminaNextDiT2DModel(
            sample_size=config.input_size,
            patch_size=config.patch_size,
            in_channels=config.in_channels,
            hidden_size=config.dim,
            num_layers=config.n_layers,
            num_attention_heads=config.n_heads,
            num_kv_heads=config.n_kv_heads,
            multiple_of=config.multiple_of,
            ffn_dim_multiplier=config.ffn_dim_multiplier,
            norm_eps=config.norm_eps,
            learn_sigma=config.learn_sigma,
            qk_norm=config.qk_norm,
            cross_attention_dim=config.latent_embedding_size,
        )

        if self._gradient_checkpointing:
            self.model.enable_gradient_checkpointing()

        # self.model.requires_grad_(False)

        # Precompute 2D rotary position embeddings (head dim = dim // n_heads)
        # for a grid of up to 384 x 384 positions.
        self.freqs_cis = get_2d_rotary_pos_embed_lumina(
            config.dim // config.n_heads,
            384,
            384,
        )

    def forward(self, x, timestep, z_latents, **kwargs):
        # x: noised latents, timestep: diffusion timestep(s),
        # z_latents: conditioning tokens injected through cross-attention.
        model_pred = self.model(
            hidden_states=x,
            timestep=timestep,
            encoder_hidden_states=z_latents,
            # Every conditioning token is attended to, so the encoder mask is all ones.
            encoder_mask=torch.ones((z_latents.shape[0], z_latents.shape[1]), device=z_latents.device),
            image_rotary_emb=self.freqs_cis,
            cross_attention_kwargs=dict(),
        ).sample
        return model_pred
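

# Minimal usage sketch (not part of the training pipeline): the batch size, the
# number of conditioning tokens, and the reduced n_layers below are illustrative
# assumptions chosen only to keep a quick smoke test small.
if __name__ == "__main__":
    config = NextDiTCrossAttnConfig(n_layers=2, _gradient_checkpointing=False)
    model = NextDiTCrossAttn(config)

    batch = 2
    # Noised latents: (batch, in_channels, input_size, input_size).
    x = torch.randn(batch, config.in_channels, config.input_size, config.input_size)
    # One diffusion timestep per sample.
    timestep = torch.rand(batch)
    # Conditioning tokens: (batch, num_tokens, latent_embedding_size); 64 tokens assumed.
    z_latents = torch.randn(batch, 64, config.latent_embedding_size)

    with torch.no_grad():
        pred = model(x, timestep, z_latents)
    print(pred.shape)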