# Functions to load models from config.
import os
import re

import torch
from diffusers import AutoencoderTiny, BitsAndBytesConfig
from huggingface_hub import login

from src.flair.pipelines import sd3

# Explicit Hugging Face Hub login, using the token stored in the environment.
login(token=os.environ["hf_token"])

def load_sd3(config, device):
    # The pipeline runs on a single device; take the first one if a list is given.
    if isinstance(device, list):
        device = device[0]

    # Optionally quantize the weights to 4-bit NF4 via bitsandbytes.
    if config["quantize"]:
        nf4_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    else:
        nf4_config = None

    if config["model"] == "SD3.5-large":
        pipe = sd3.SD3Wrapper.from_pretrained(
            "stabilityai/stable-diffusion-3.5-large",
            torch_dtype=torch.bfloat16,
            quantization_config=nf4_config,
        )
    elif config["model"] == "SD3.5-large-turbo":
        pipe = sd3.SD3Wrapper.from_pretrained(
            "stabilityai/stable-diffusion-3.5-large-turbo",
            torch_dtype=torch.bfloat16,
            quantization_config=nf4_config,
        )
    else:
        # Fall back to the medium checkpoint for any other SD3 model name.
        pipe = sd3.SD3Wrapper.from_pretrained(
            "stabilityai/stable-diffusion-3.5-medium",
            torch_dtype=torch.float16,
            quantization_config=nf4_config,
        )

    # Optionally swap in the tiny autoencoder for faster decoding.
    if config["use_tiny_ae"]:
        pipe.vae = AutoencoderTiny.from_pretrained(
            "madebyollin/taesd3", torch_dtype=torch.float16
        )
    # Encode every prompt up front and cache the embeddings.
    inp_kwargs_list = []
    prompts = config["prompt"]
    # Set the private guidance scale so the do_classifier_free_guidance
    # property reflects the configured guidance value.
    pipe._guidance_scale = config["guidance"]
    pipe._joint_attention_kwargs = {"ip_adapter_image_embeds": None}
    for prompt in prompts:
        print(f"Generating prompt embeddings for: {prompt}")
        # Move all three text encoders to the target device in bfloat16.
        pipe.text_encoder.to(device).to(torch.bfloat16)
        pipe.text_encoder_2.to(device).to(torch.bfloat16)
        pipe.text_encoder_3.to(device).to(torch.bfloat16)
        # Encode: the same text is fed to all three encoders.
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = pipe.encode_prompt(
            prompt=prompt,
            prompt_2=prompt,
            prompt_3=prompt,
            negative_prompt=config["negative_prompt"],
            negative_prompt_2=config["negative_prompt"],
            negative_prompt_3=config["negative_prompt"],
            do_classifier_free_guidance=pipe.do_classifier_free_guidance,
            prompt_embeds=None,
            negative_prompt_embeds=None,
            pooled_prompt_embeds=None,
            negative_pooled_prompt_embeds=None,
            device=device,
            clip_skip=None,
            num_images_per_prompt=1,
            max_sequence_length=256,
            lora_scale=None,
        )
        # For classifier-free guidance, stack negative and positive embeddings
        # so both branches run in a single batched forward pass.
        if pipe.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            pooled_prompt_embeds = torch.cat(
                [negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0
            )
        inp_kwargs = {
            "prompt_embeds": prompt_embeds,
            "pooled_prompt_embeds": pooled_prompt_embeds,
            "guidance": config["guidance"],
        }
        inp_kwargs_list.append(inp_kwargs)

    # Move the modules needed for sampling to the device in bfloat16.
    pipe.vae.to(device).to(torch.bfloat16)
    pipe.transformer.to(device).to(torch.bfloat16)
    return pipe, inp_kwargs_list

def load_model(config, device=["cuda"]):
    # Dispatch on the model-name prefix; all supported names start with "SD3".
    if re.match(r"SD3", config["model"]):
        return load_sd3(config, device)
    else:
        raise ValueError(f"Unknown model type {config['model']}")
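
# A minimal usage sketch, not part of the loader itself. The config keys
# mirror the ones read above ("model", "quantize", "use_tiny_ae", "prompt",
# "negative_prompt", "guidance"); the prompt text and guidance scale are
# illustrative placeholders.
if __name__ == "__main__":
    example_config = {
        "model": "SD3.5-large-turbo",
        "quantize": True,
        "use_tiny_ae": True,
        "prompt": ["a photo of an astronaut riding a horse"],
        "negative_prompt": "",
        "guidance": 7.0,
    }
    pipe, inp_kwargs_list = load_model(example_config, device=["cuda"])
    print(f"Loaded {example_config['model']} with {len(inp_kwargs_list)} cached prompt embedding(s)")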