# Functions to load models from config.
import os
import re

import torch
from diffusers import AutoencoderTiny, BitsAndBytesConfig
from huggingface_hub import login

from src.flair.pipelines import sd3

# Log in explicitly when an HF token is provided via the environment;
# os.environ.get avoids a KeyError if the variable is unset.
hf_token = os.environ.get("hf_token")
if hf_token:
    login(token=hf_token)

@torch.no_grad()
def load_sd3(config, device):
    """Load an SD3.x pipeline and pre-compute prompt embeddings from `config`."""
    if isinstance(device, list):
        device = device[0]
    # Optionally quantize the weights to 4-bit NF4 to cut VRAM usage.
    if config["quantize"]:
        nf4_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    else:
        nf4_config = None
    if config["model"] == "SD3.5-large":
        pipe = sd3.SD3Wrapper.from_pretrained(
            "stabilityai/stable-diffusion-3.5-large",
            torch_dtype=torch.bfloat16,
            quantization_config=nf4_config,
        )
    elif config["model"] == "SD3.5-large-turbo":
        pipe = sd3.SD3Wrapper.from_pretrained(
            "stabilityai/stable-diffusion-3.5-large-turbo",
            torch_dtype=torch.bfloat16,
            quantization_config=nf4_config,
        )
    else:
        # Any other model name falls back to the medium checkpoint (loaded in float16).
        pipe = sd3.SD3Wrapper.from_pretrained(
            "stabilityai/stable-diffusion-3.5-medium",
            torch_dtype=torch.float16,
            quantization_config=nf4_config,
        )
    # Optionally swap in the tiny autoencoder for faster, lower-memory decoding.
    if config["use_tiny_ae"]:
        pipe.vae = AutoencoderTiny.from_pretrained(
            "madebyollin/taesd3", torch_dtype=torch.float16
        )
    # Pre-encode every configured prompt so sampling can run from cached embeddings.
    inp_kwargs_list = []
    prompts = config["prompt"]
    pipe._guidance_scale = config["guidance"]
    pipe._joint_attention_kwargs = {"ip_adapter_image_embeds": None}
    # Move the three text encoders to the device once, outside the loop,
    # to avoid redundant transfers when several prompts are configured.
    pipe.text_encoder.to(device).to(torch.bfloat16)
    pipe.text_encoder_2.to(device).to(torch.bfloat16)
    pipe.text_encoder_3.to(device).to(torch.bfloat16)
    for prompt in prompts:
        print(f"Generating prompt embeddings for: {prompt}")
        # encode
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = pipe.encode_prompt(
            prompt=prompt,
            prompt_2=prompt,
            prompt_3=prompt,
            negative_prompt=config["negative_prompt"],
            negative_prompt_2=config["negative_prompt"],
            negative_prompt_3=config["negative_prompt"],
            do_classifier_free_guidance=pipe.do_classifier_free_guidance,
            prompt_embeds=None,
            negative_prompt_embeds=None,
            pooled_prompt_embeds=None,
            negative_pooled_prompt_embeds=None,
            device=device,
            clip_skip=None,
            num_images_per_prompt=1,
            max_sequence_length=256,
            lora_scale=None,
        )
        if pipe.do_classifier_free_guidance:
            # diffusers convention: unconditional (negative) embeddings first,
            # then the conditional ones, concatenated along the batch axis.
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            pooled_prompt_embeds = torch.cat(
                [negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0
            )
        inp_kwargs = {
            "prompt_embeds": prompt_embeds,
            "pooled_prompt_embeds": pooled_prompt_embeds,
            "guidance": config["guidance"],
        }
        inp_kwargs_list.append(inp_kwargs)
    pipe.vae.to(device).to(torch.bfloat16)
    pipe.transformer.to(device).to(torch.bfloat16)
    return pipe, inp_kwargs_list
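
# A minimal sketch of the config that load_sd3 reads. The key names come from
# the accesses above; the values are illustrative assumptions, not project
# defaults.
EXAMPLE_SD3_CONFIG = {
    "model": "SD3.5-large",  # or "SD3.5-large-turbo"; other names load the medium model
    "quantize": True,  # load the weights in 4-bit NF4
    "use_tiny_ae": False,  # set True to decode with madebyollin/taesd3
    "prompt": ["a photo of an astronaut riding a horse on the moon"],
    "negative_prompt": "",
    "guidance": 7.0,  # classifier-free guidance scale
}
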
def load_model(config, device=["cuda"]):
    """Dispatch to the matching loader based on config["model"]."""
    # re.match anchors at the start of the string, so this catches every
    # "SD3..." model name.
    if re.match(r"SD3", config["model"]):
        return load_sd3(config, device)
    raise ValueError(f"Unknown model type {config['model']}")
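
# Minimal usage sketch: assumes a single CUDA device and the illustrative
# EXAMPLE_SD3_CONFIG above. `pipe` is an sd3.SD3Wrapper, so the exact sampling
# call depends on that wrapper's interface and is not shown here.
if __name__ == "__main__":
    pipe, inp_kwargs_list = load_model(EXAMPLE_SD3_CONFIG, device=["cuda"])
    print(
        f"Loaded {EXAMPLE_SD3_CONFIG['model']} with "
        f"{len(inp_kwargs_list)} pre-encoded prompt(s)"
    )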