import torch
from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXTransformer3DModel
from transformers import AutoTokenizer, T5EncoderModel

from finetrainers.models.cogvideox import CogVideoXModelSpecification


class DummyCogVideoXModelSpecification(CogVideoXModelSpecification):
    """CogVideoX model specification with every component swapped for a tiny,
    randomly initialized stand-in, so tests run fast and without downloading
    full checkpoints."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def load_condition_models(self):
        # Tiny random T5 from hf-internal-testing instead of the full CogVideoX text encoder.
        text_encoder = T5EncoderModel.from_pretrained(
            "hf-internal-testing/tiny-random-t5", torch_dtype=self.text_encoder_dtype
        )
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
        return {"text_encoder": text_encoder, "tokenizer": tokenizer}

    def load_latent_models(self):
        # Seed so the randomly initialized VAE weights are reproducible across runs.
        torch.manual_seed(0)
        vae = AutoencoderKLCogVideoX(
            in_channels=3,
            out_channels=3,
            down_block_types=(
                "CogVideoXDownBlock3D",
                "CogVideoXDownBlock3D",
                "CogVideoXDownBlock3D",
                "CogVideoXDownBlock3D",
            ),
            up_block_types=(
                "CogVideoXUpBlock3D",
                "CogVideoXUpBlock3D",
                "CogVideoXUpBlock3D",
                "CogVideoXUpBlock3D",
            ),
            block_out_channels=(8, 8, 8, 8),
            latent_channels=4,
            layers_per_block=1,
            norm_num_groups=2,
            temporal_compression_ratio=4,
        )
        # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
        # Calling .to(dtype) here overrides things like _keep_in_fp32_modules.
        vae.to(self.vae_dtype)
        self.vae_config = vae.config
        return {"vae": vae}

    def load_diffusion_models(self):
        torch.manual_seed(0)
        transformer = CogVideoXTransformer3DModel(
            num_attention_heads=4,
            attention_head_dim=16,
            in_channels=4,  # matches the VAE's latent_channels
            out_channels=4,
            time_embed_dim=2,
            text_embed_dim=32,
            num_layers=2,
            sample_width=24,
            sample_height=24,
            sample_frames=9,
            patch_size=2,
            temporal_compression_ratio=4,
            max_text_seq_length=16,
            use_rotary_positional_embeddings=True,
        )
        # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
        # Calling .to(dtype) here overrides things like _keep_in_fp32_modules.
        transformer.to(self.transformer_dtype)
        self.transformer_config = transformer.config
        scheduler = CogVideoXDDIMScheduler()
        return {"transformer": transformer, "scheduler": scheduler}