File size: 1,894 Bytes
357c94c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
from pathlib import Path
from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D
from ..constants import VAE_PATH, PRECISION_TO_TYPE

def load_vae(vae_type,
             vae_precision=None,
             sample_size=None,
             vae_path=None,
             logger=None,
             device=None
             ):
    """Build a 3D causal VAE, load its checkpoint and freeze it for inference.

    Args:
        vae_type: Model identifier made of three dash-separated fields whose
            first field (the compression spec) is exactly three characters
            long. Also used as the key into ``VAE_PATH`` when ``vae_path``
            is not given.
        vae_precision: Optional key into ``PRECISION_TO_TYPE``; when given,
            the model is cast to the mapped dtype.
        sample_size: Optional value forwarded to
            ``AutoencoderKLCausal3D.from_config`` as ``sample_size``. Falsy
            values are ignored and the config is used unchanged.
        vae_path: Directory holding the model config and a
            ``pytorch_model.pt`` checkpoint. Defaults to
            ``VAE_PATH[vae_type]``.
        logger: Optional logger; progress is reported through ``logger.info``.
        device: Optional device the model is moved to after loading.

    Returns:
        A tuple ``(vae, vae_path, spatial_compression_ratio,
        time_compression_ratio)`` where the two ratios are read from
        ``vae.config``.

    Raises:
        ValueError: If ``vae_type`` does not have the expected shape.
    """
    if vae_path is None:
        vae_path = VAE_PATH[vae_type]

    # Validate the identifier up front so a malformed value yields the
    # descriptive message below instead of an opaque tuple-unpacking error.
    fields = vae_type.split("-")
    if len(fields) != 3 or len(fields[0]) != 3:
        raise ValueError(f"Invalid VAE model: {vae_type}. Must be 3D VAE in the format of '???-*'.")

    if logger is not None:
        logger.info(f"Loading 3D VAE model ({vae_type}) from: {vae_path}")

    config = AutoencoderKLCausal3D.load_config(vae_path)
    if sample_size:
        vae = AutoencoderKLCausal3D.from_config(config, sample_size=sample_size)
    else:
        vae = AutoencoderKLCausal3D.from_config(config)

    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from trusted sources; consider weights_only=True where the installed
    # torch version supports it.
    ckpt = torch.load(Path(vae_path) / "pytorch_model.pt", map_location=vae.device)
    if "state_dict" in ckpt:
        ckpt = ckpt["state_dict"]

    # Drop only a *leading* "vae." from each key. A plain str.replace would
    # also rewrite keys that merely contain "vae." somewhere in the middle.
    prefix = "vae."
    vae_ckpt = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in ckpt.items()
    }
    vae.load_state_dict(vae_ckpt)

    spatial_compression_ratio = vae.config.spatial_compression_ratio
    time_compression_ratio = vae.config.time_compression_ratio

    if vae_precision is not None:
        vae = vae.to(dtype=PRECISION_TO_TYPE[vae_precision])

    vae.requires_grad_(False)

    if logger is not None:
        logger.info(f"VAE to dtype: {vae.dtype}")

    if device is not None:
        vae = vae.to(device)

    # Set vae to eval mode, even though its dropout rate is 0.
    vae.eval()

    return vae, vae_path, spatial_compression_ratio, time_compression_ratio