|
import torch |
|
from diffusers import AutoencoderKL |
|
|
|
def get_vae(version, file_path=None, fp16=False):
    """Load an ``AutoencoderKL`` VAE from a local file or the default HF repo.

    Args:
        version: Model family; one of ``"v1"``, ``"v2"`` or ``"xl"``.
        file_path: Optional path to a single-file VAE checkpoint. When truthy
            it takes precedence over the Hugging Face repo download.
        fp16: Request ``torch.float16`` weights. Only honoured for Hugging
            Face downloads: single-file loads pass no dtype, so the library
            default applies. For ``"xl"`` with ``fp16=True`` the separate
            ``madebyollin/sdxl-vae-fp16-fix`` repo is used.

    Returns:
        The loaded ``AutoencoderKL`` instance.

    Raises:
        ValueError: If ``version`` is not one of the supported values.
    """
    supported = ("v1", "v2", "xl")
    if version not in supported:
        # Raise instead of prompting with input() and calling exit():
        # input() raises EOFError/OSError when stdin is closed or captured
        # (non-interactive runs), and a loader helper should not block on
        # the terminal or terminate the whole process.
        raise ValueError(
            f"Invalid VAE version {version!r}; expected one of {supported}"
        )

    if file_path:
        # image_size per model family, forwarded to from_single_file.
        image_size = {"v1": 512, "v2": 768, "xl": 1024}[version]
        # No torch_dtype here: fp16 is deliberately ignored for single files.
        return AutoencoderKL.from_single_file(file_path, image_size=image_size)

    dtype = torch.float16 if fp16 else torch.float32

    if version == "v1":
        return AutoencoderKL.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            subfolder="vae",
            torch_dtype=dtype,
        )

    if version == "v2":
        return AutoencoderKL.from_pretrained(
            "stabilityai/stable-diffusion-2-1",
            subfolder="vae",
            torch_dtype=dtype,
        )

    # version == "xl" from here on.
    if fp16:
        # NOTE(review): a dedicated repo is used for fp16, presumably because
        # the stock SDXL VAE is not fp16-safe (per the repo name) - confirm.
        return AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
        )

    return AutoencoderKL.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        subfolder="vae",
    )
|
|