Spaces:
Running
Running
File size: 1,479 Bytes
9fd1204 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
import torch
from diffusers import AutoencoderKL, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
from transformers import AutoTokenizer, GlmModel
from finetrainers.models.cogview4 import CogView4ModelSpecification
class DummyCogView4ModelSpecification(CogView4ModelSpecification):
    """CogView4 model specification that loads every component from one fixed checkpoint.

    All sub-models are fetched from ``hf-internal-testing/tiny-random-cogview4``.
    NOTE(review): the "Dummy" prefix and the repo id suggest this is intended
    for tests -- confirm against the callers.
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments unchanged to the parent specification."""
        super().__init__(**kwargs)

    def load_condition_models(self):
        """Load the text encoder and the tokenizer.

        The text encoder is loaded with ``self.text_encoder_dtype``.

        Returns:
            A dict with the keys ``"text_encoder"`` and ``"tokenizer"``.
        """
        repo_id = "hf-internal-testing/tiny-random-cogview4"
        glm_encoder = GlmModel.from_pretrained(
            repo_id, subfolder="text_encoder", torch_dtype=self.text_encoder_dtype
        )
        glm_tokenizer = AutoTokenizer.from_pretrained(
            repo_id, subfolder="tokenizer", trust_remote_code=True
        )
        return dict(text_encoder=glm_encoder, tokenizer=glm_tokenizer)

    def load_latent_models(self):
        """Load the VAE and remember its config on the instance.

        Side effects:
            Reseeds the global torch RNG with 0 and sets ``self.vae_config``.

        Returns:
            A dict with the single key ``"vae"``.
        """
        repo_id = "hf-internal-testing/tiny-random-cogview4"
        # Seed before loading; presumably for reproducibility -- TODO confirm.
        torch.manual_seed(0)
        autoencoder = AutoencoderKL.from_pretrained(
            repo_id, subfolder="vae", torch_dtype=self.vae_dtype
        )
        self.vae_config = autoencoder.config
        return dict(vae=autoencoder)

    def load_diffusion_models(self):
        """Load the transformer and create a default-configured scheduler.

        Side effects:
            Reseeds the global torch RNG with 0.

        Returns:
            A dict with the keys ``"transformer"`` and ``"scheduler"``.
        """
        repo_id = "hf-internal-testing/tiny-random-cogview4"
        # Seed before loading; presumably for reproducibility -- TODO confirm.
        torch.manual_seed(0)
        denoiser = CogView4Transformer2DModel.from_pretrained(
            repo_id, subfolder="transformer", torch_dtype=self.transformer_dtype
        )
        # The scheduler is built with its default arguments, not loaded from the repo.
        flow_scheduler = FlowMatchEulerDiscreteScheduler()
        return dict(transformer=denoiser, scheduler=flow_scheduler)
|