File size: 1,220 Bytes
baa8e90 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
import torch
from diffusers import AutoencoderKL
def get_vae(version, file_path=None, fp16=False):
    """Load an ``AutoencoderKL`` VAE from a local file or a default Hugging Face repo.

    Args:
        version: Model family, one of ``"v1"``, ``"v2"`` or ``"xl"``.
        file_path: Optional path to a single-file VAE checkpoint. When truthy it
            takes precedence over downloading from Hugging Face.
        fp16: Request a float16 VAE. Only honoured for Hugging Face loads; a
            ``file_path`` load is always done without ``torch_dtype``. For
            ``"xl"`` this selects the ``madebyollin/sdxl-vae-fp16-fix`` repo.

    Returns:
        The loaded ``AutoencoderKL`` instance.

    Raises:
        SystemExit: With code 1 when ``version`` is not recognised, after a
            best-effort "press Enter" pause on stdin.
    """
    # Value passed as ``image_size`` when loading a single-file checkpoint.
    image_sizes = {"v1": 512, "v2": 768, "xl": 1024}
    if version not in image_sizes:
        import sys

        try:
            # Best-effort pause before exiting; if stdin cannot be read (EOF,
            # closed or captured stream) skip it so we still exit with code 1
            # instead of leaking EOFError/OSError.
            input(f"Invalid VAE version {version!r}. Press Enter to exit")
        except (EOFError, OSError):
            pass
        # sys.exit rather than the site-provided exit(), which may be absent.
        sys.exit(1)

    if file_path:
        # fp16 is deliberately not applied to single-file loads (see docstring).
        return AutoencoderKL.from_single_file(
            file_path,
            image_size=image_sizes[version],
        )

    if version == "xl":
        if fp16:
            # Separate repo for fp16 SDXL; presumably an fp16-safe VAE variant,
            # per its name — confirm against the model card.
            return AutoencoderKL.from_pretrained(
                "madebyollin/sdxl-vae-fp16-fix",
                torch_dtype=torch.float16,
            )
        return AutoencoderKL.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            subfolder="vae",
        )

    # v1 / v2: the VAE lives in the "vae" subfolder of the base model repo.
    repo_ids = {
        "v1": "runwayml/stable-diffusion-v1-5",
        "v2": "stabilityai/stable-diffusion-2-1",
    }
    return AutoencoderKL.from_pretrained(
        repo_ids[version],
        subfolder="vae",
        torch_dtype=torch.float16 if fp16 else torch.float32,
    )
|