File size: 2,616 Bytes
9fd1204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
from diffusers import AutoencoderKL, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
from transformers import AutoTokenizer, GlmConfig, GlmModel

from finetrainers.models.cogview4 import CogView4ControlModelSpecification
from finetrainers.models.utils import _expand_linear_with_zeroed_weights


class DummyCogView4ControlModelSpecification(CogView4ControlModelSpecification):
    """Test double for ``CogView4ControlModelSpecification`` built from tiny components.

    The text encoder, VAE and transformer are constructed in-process from small,
    hard-coded configurations instead of pretrained checkpoints. Only the
    tokenizer is loaded with ``from_pretrained`` (``"THUDM/glm-4-9b-chat"``).
    Every loader seeds the global torch RNG with 0 before creating randomly
    initialized modules so that repeated calls produce the same weights.
    """

    def __init__(self, **kwargs):
        """Forward ``kwargs`` to the parent and patch the transformer config for the dummy model."""
        super().__init__(**kwargs)

        # This needs to be updated for the test to work correctly.
        # TODO(aryan): it will not be needed if we hosted the dummy model so that the correct config could be loaded
        # with ModelSpecification::_load_configs
        self.transformer_config.in_channels = 4

    def load_condition_models(self):
        """Build the dummy text encoder and load the tokenizer.

        Returns:
            A dict with keys ``"text_encoder"`` (a 2-layer ``GlmModel`` cast to
            ``self.text_encoder_dtype``) and ``"tokenizer"``.
        """
        # Seed before constructing the randomly initialized encoder, as the other
        # loaders in this class do, so the dummy weights are reproducible.
        torch.manual_seed(0)
        text_encoder_config = GlmConfig(
            hidden_size=32, intermediate_size=8, num_hidden_layers=2, num_attention_heads=4, head_dim=8
        )
        text_encoder = GlmModel(text_encoder_config).to(self.text_encoder_dtype)
        # TODO(aryan): try to not rely on trust_remote_code by creating dummy tokenizer
        tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat", trust_remote_code=True)
        return {"text_encoder": text_encoder, "tokenizer": tokenizer}

    def load_latent_models(self):
        """Build the dummy VAE.

        Returns:
            A dict with key ``"vae"`` holding a two-block ``AutoencoderKL`` with
            4 latent channels, cast to ``self.vae_dtype``.
        """
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
            sample_size=128,
        ).to(self.vae_dtype)
        return {"vae": vae}

    def load_diffusion_models(self, new_in_features: int):
        """Build the dummy transformer, widen its patch embedding, and create the scheduler.

        Args:
            new_in_features: Value recorded as the transformer's ``in_channels``
                after the patch-embedding projection has been expanded.

        Returns:
            A dict with keys ``"transformer"`` (cast to ``self.transformer_dtype``)
            and ``"scheduler"`` (a default ``FlowMatchEulerDiscreteScheduler``).
        """
        torch.manual_seed(0)
        transformer = CogView4Transformer2DModel(
            patch_size=2,
            in_channels=4,
            num_layers=2,
            attention_head_dim=4,
            num_attention_heads=4,
            out_channels=4,
            text_embed_dim=32,
            time_embed_dim=8,
            condition_dim=4,
        ).to(self.transformer_dtype)

        # The projection's input width is scaled by patch_size**2 relative to the
        # channel count, so the requested channel count is scaled the same way.
        actual_new_in_features = new_in_features * transformer.config.patch_size**2
        # NOTE(review): judging by the helper's name, the newly added input weights
        # are zero-initialized -- confirm against finetrainers.models.utils.
        transformer.patch_embed.proj = _expand_linear_with_zeroed_weights(
            transformer.patch_embed.proj, new_in_features=actual_new_in_features
        )
        # Keep the stored config in sync with the widened projection.
        transformer.register_to_config(in_channels=new_in_features)

        scheduler = FlowMatchEulerDiscreteScheduler()

        return {"transformer": transformer, "scheduler": scheduler}