import gc
import os
import torch
from safetensors.torch import load_file
from .clip import CLIPModel
from .t5 import T5EncoderModel
from .transformer import WanModel
from .vae import WanVAE


def download_model(model_id):
    # Resolve a Hugging Face repo id to a local snapshot path if it is not already a local directory.
    if not os.path.exists(model_id):
        from huggingface_hub import snapshot_download
        model_id = snapshot_download(repo_id=model_id)
    return model_id


def get_vae(model_path, device="cuda", weight_dtype=torch.float32) -> WanVAE:
    # Load the Wan VAE on the target device, frozen and in eval mode.
    vae = WanVAE(model_path).to(device).to(weight_dtype)
    vae.vae.requires_grad_(False)
    vae.vae.eval()
    gc.collect()
    torch.cuda.empty_cache()
    return vae


def get_transformer(model_path, device="cuda", weight_dtype=torch.bfloat16) -> WanModel:
    # Build the transformer from its config, then load every .safetensors shard found in the model directory.
    config_path = os.path.join(model_path, "config.json")
    transformer = WanModel.from_config(config_path).to(weight_dtype).to(device)

    for file in os.listdir(model_path):
        if file.endswith(".safetensors"):
            file_path = os.path.join(model_path, file)
            state_dict = load_file(file_path)
            # strict=False: each shard contains only a subset of the model's parameters.
            transformer.load_state_dict(state_dict, strict=False)
            del state_dict
            gc.collect()
            torch.cuda.empty_cache()

    transformer.requires_grad_(False)
    transformer.eval()
    gc.collect()
    torch.cuda.empty_cache()
    return transformer


def get_text_encoder(model_path, device="cuda", weight_dtype=torch.bfloat16) -> T5EncoderModel:
    # Load the UMT5-XXL text encoder and its tokenizer, frozen and in eval mode.
    t5_model = os.path.join(model_path, "models_t5_umt5-xxl-enc-bf16.pth")
    tokenizer_path = os.path.join(model_path, "google", "umt5-xxl")
    text_encoder = T5EncoderModel(checkpoint_path=t5_model, tokenizer_path=tokenizer_path).to(device).to(weight_dtype)
    text_encoder.requires_grad_(False)
    text_encoder.eval()
    gc.collect()
    torch.cuda.empty_cache()
    return text_encoder


def get_image_encoder(model_path, device="cuda", weight_dtype=torch.bfloat16) -> CLIPModel:
    # Load the CLIP image encoder (XLM-RoBERTa-Large text tower + ViT-Huge/14), frozen and in eval mode.
    checkpoint_path = os.path.join(model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth")
    tokenizer_path = os.path.join(model_path, "xlm-roberta-large")
    image_enc = CLIPModel(checkpoint_path, tokenizer_path).to(weight_dtype).to(device)
    image_enc.requires_grad_(False)
    image_enc.eval()
    gc.collect()
    torch.cuda.empty_cache()
    return image_enc
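

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original module: shows how the
    # loaders above might be composed. The repo id below is an assumption; point
    # download_model at whichever Wan checkpoint repo or local directory you use.
    model_dir = download_model("Wan-AI/Wan2.1-I2V-14B-480P")
    vae = get_vae(model_dir)
    text_encoder = get_text_encoder(model_dir)
    image_encoder = get_image_encoder(model_dir)
    transformer = get_transformer(model_dir)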