import torch
from diffusers import AutoencoderKL, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
from transformers import AutoTokenizer, GlmConfig, GlmModel
from finetrainers.models.cogview4 import CogView4ControlModelSpecification
from finetrainers.models.utils import _expand_linear_with_zeroed_weights
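

# Dummy specification for CogView4 control training tests: every component is replaced by a
# tiny, randomly initialized model so the finetrainers pipeline can be exercised quickly
# without downloading the full pretrained checkpoints (only the GLM tokenizer is fetched
# from the Hub).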
class DummyCogView4ControlModelSpecification(CogView4ControlModelSpecification):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # This needs to be updated for the test to work correctly.
        # TODO(aryan): it will not be needed if we hosted the dummy model so that the correct config could be loaded
        # with ModelSpecification::_load_configs
        self.transformer_config.in_channels = 4

    def load_condition_models(self):
        text_encoder_config = GlmConfig(
            hidden_size=32, intermediate_size=8, num_hidden_layers=2, num_attention_heads=4, head_dim=8
        )
        text_encoder = GlmModel(text_encoder_config).to(self.text_encoder_dtype)
        # TODO(aryan): try to not rely on trust_remote_code by creating dummy tokenizer
        tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat", trust_remote_code=True)
        return {"text_encoder": text_encoder, "tokenizer": tokenizer}

    def load_latent_models(self):
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
            sample_size=128,
        ).to(self.vae_dtype)
        return {"vae": vae}

    def load_diffusion_models(self, new_in_features: int):
        torch.manual_seed(0)
        transformer = CogView4Transformer2DModel(
            patch_size=2,
            in_channels=4,
            num_layers=2,
            attention_head_dim=4,
            num_attention_heads=4,
            out_channels=4,
            text_embed_dim=32,
            time_embed_dim=8,
            condition_dim=4,
        ).to(self.transformer_dtype)
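
        # The patch embedding projects flattened (patch_size x patch_size) patches, so each
        # extra input channel adds patch_size**2 features to the linear projection. The new
        # columns are zero-initialized (presumably so the added control channels do not
        # perturb the original mapping at initialization).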
        actual_new_in_features = new_in_features * transformer.config.patch_size**2
        transformer.patch_embed.proj = _expand_linear_with_zeroed_weights(
            transformer.patch_embed.proj, new_in_features=actual_new_in_features
        )
        transformer.register_to_config(in_channels=new_in_features)

        scheduler = FlowMatchEulerDiscreteScheduler()
        return {"transformer": transformer, "scheduler": scheduler}