Spaces:
Running
Running
VideoModelStudio
/
docs
/finetrainers-src-codebase
/tests
/models
/cogview4
/control_specification.py
import torch | |
from diffusers import AutoencoderKL, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler | |
from transformers import AutoTokenizer, GlmConfig, GlmModel | |
from finetrainers.models.cogview4 import CogView4ControlModelSpecification | |
from finetrainers.models.utils import _expand_linear_with_zeroed_weights | |
class DummyCogView4ControlModelSpecification(CogView4ControlModelSpecification):
    """Miniature CogView4 control model specification for tests.

    Overrides the component loaders of ``CogView4ControlModelSpecification`` so
    that they build tiny, randomly initialized models instead of loading real
    checkpoints. The only externally fetched component is the GLM-4 tokenizer.
    """

    # Latent channel count shared by the dummy VAE (`latent_channels`), the dummy
    # transformer (`in_channels` / `out_channels`) and the `transformer_config`
    # override in `__init__`. Kept in one place so these values cannot drift apart.
    _LATENT_CHANNELS = 4

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # This needs to be updated for the test to work correctly.
        # TODO(aryan): it will not be needed if we hosted the dummy model so that the correct config could be loaded
        # with ModelSpecification::_load_configs
        self.transformer_config.in_channels = self._LATENT_CHANNELS

    def load_condition_models(self):
        """Build the dummy text encoder and load the GLM-4 tokenizer.

        Returns:
            dict with ``"text_encoder"`` (a 2-layer ``GlmModel`` cast to
            ``self.text_encoder_dtype``) and ``"tokenizer"``.
        """
        # Seed before random init, like the other loaders, so the text encoder's
        # weights are reproducible rather than dependent on prior RNG usage.
        torch.manual_seed(0)
        text_encoder_config = GlmConfig(
            hidden_size=32, intermediate_size=8, num_hidden_layers=2, num_attention_heads=4, head_dim=8
        )
        text_encoder = GlmModel(text_encoder_config).to(self.text_encoder_dtype)
        # NOTE(review): loading by Hub id presumably needs network access or a warm local HF cache.
        # TODO(aryan): try to not rely on trust_remote_code by creating dummy tokenizer
        tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat", trust_remote_code=True)
        return {"text_encoder": text_encoder, "tokenizer": tokenizer}

    def load_latent_models(self):
        """Build a tiny, seeded ``AutoencoderKL`` cast to ``self.vae_dtype``.

        Returns:
            dict with a single ``"vae"`` entry.
        """
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=self._LATENT_CHANNELS,
            sample_size=128,
        ).to(self.vae_dtype)
        return {"vae": vae}

    def load_diffusion_models(self, new_in_features: int):
        """Build the dummy transformer with a widened patch projection, plus the scheduler.

        Args:
            new_in_features: Input channel count the control transformer should
                accept. It is multiplied by ``patch_size**2`` to size the patch
                projection and recorded as the config's ``in_channels``.

        Returns:
            dict with ``"transformer"`` and ``"scheduler"`` entries.
        """
        torch.manual_seed(0)
        transformer = CogView4Transformer2DModel(
            patch_size=2,
            in_channels=self._LATENT_CHANNELS,
            num_layers=2,
            attention_head_dim=4,
            num_attention_heads=4,
            out_channels=self._LATENT_CHANNELS,
            text_embed_dim=32,  # same value as the dummy GlmConfig hidden_size in load_condition_models
            time_embed_dim=8,
            condition_dim=4,
        ).to(self.transformer_dtype)

        # The projection's input width scales with patch_size**2 (presumably one
        # value per channel per pixel of a flattened patch).
        actual_new_in_features = new_in_features * transformer.config.patch_size**2
        transformer.patch_embed.proj = _expand_linear_with_zeroed_weights(
            transformer.patch_embed.proj, new_in_features=actual_new_in_features
        )
        # Keep the stored config in sync with the widened projection.
        transformer.register_to_config(in_channels=new_in_features)

        scheduler = FlowMatchEulerDiscreteScheduler()
        return {"transformer": transformer, "scheduler": scheduler}