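"""Loader utilities for the Wan video generation models.

These helpers resolve a checkpoint directory (downloading it from the
Hugging Face Hub when given a repo id) and build the VAE, diffusion
transformer, T5 text encoder, and CLIP image encoder, each frozen in
eval mode on the requested device and dtype.
"""
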
import gc
import os

import torch
from safetensors.torch import load_file

from .clip import CLIPModel
from .t5 import T5EncoderModel
from .transformer import WanModel
from .vae import WanVAE


def download_model(model_id: str) -> str:
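    """Return a local checkpoint directory, downloading it if needed.

    If ``model_id`` is not an existing local path, it is treated as a
    Hugging Face Hub repo id and fetched with ``snapshot_download``.
    """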
    if not os.path.exists(model_id):
        from huggingface_hub import snapshot_download

        model_id = snapshot_download(repo_id=model_id)
    return model_id


def get_vae(model_path, device="cuda", weight_dtype=torch.float32) -> WanVAE:
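    """Load the Wan VAE on ``device`` in ``weight_dtype``, frozen in eval mode."""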
    vae = WanVAE(model_path).to(device).to(weight_dtype)
    vae.vae.requires_grad_(False)
    vae.vae.eval()
    gc.collect()
    torch.cuda.empty_cache()
    return vae


def get_transformer(model_path, device="cuda", weight_dtype=torch.bfloat16) -> WanModel:
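    """Instantiate the Wan diffusion transformer from ``config.json`` and load
    its (possibly sharded) ``.safetensors`` weights, frozen in eval mode.
    """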
    config_path = os.path.join(model_path, "config.json")
    transformer = WanModel.from_config(config_path).to(weight_dtype).to(device)

    # Weights may be sharded across several .safetensors files; load each shard
    # with strict=False so partial state dicts are accepted, freeing memory
    # between shards.
    for file in os.listdir(model_path):
        if file.endswith(".safetensors"):
            file_path = os.path.join(model_path, file)
            state_dict = load_file(file_path)
            transformer.load_state_dict(state_dict, strict=False)
            del state_dict
            gc.collect()
            torch.cuda.empty_cache()

    transformer.requires_grad_(False)
    transformer.eval()
    gc.collect()
    torch.cuda.empty_cache()
    return transformer


def get_text_encoder(model_path, device="cuda", weight_dtype=torch.bfloat16) -> T5EncoderModel:
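    """Load the umT5-XXL text encoder and its tokenizer, frozen in eval mode."""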
    t5_model = os.path.join(model_path, "models_t5_umt5-xxl-enc-bf16.pth")
    tokenizer_path = os.path.join(model_path, "google", "umt5-xxl")
    text_encoder = T5EncoderModel(checkpoint_path=t5_model, tokenizer_path=tokenizer_path).to(device).to(weight_dtype)
    text_encoder.requires_grad_(False)
    text_encoder.eval()
    gc.collect()
    torch.cuda.empty_cache()
    return text_encoder


def get_image_encoder(model_path, device="cuda", weight_dtype=torch.bfloat16) -> CLIPModel:
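    """Load the open-CLIP XLM-RoBERTa-large / ViT-huge image encoder, frozen in eval mode."""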
    checkpoint_path = os.path.join(model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth")
    tokenizer_path = os.path.join(model_path, "xlm-roberta-large")
    image_enc = CLIPModel(checkpoint_path, tokenizer_path).to(weight_dtype).to(device)
    image_enc.requires_grad_(False)
    image_enc.eval()
    gc.collect()
    torch.cuda.empty_cache()
    return image_enc