File size: 3,538 Bytes
90a9dd3
 
 
 
 
 
 
 
 
 
 
 
a7169e0
fe482e1
 
90a9dd3
fe482e1
 
90a9dd3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# functions to load models from config
import numpy as np
import torch
import re
import os
import math
from diffusers import (
    BitsAndBytesConfig
)
from diffusers import AutoencoderTiny


from src.flair.pipelines import sd3
import os
from huggingface_hub import login

# Explicit Hugging Face Hub login, performed only when a token is supplied via
# the "hf_token" environment variable. Without it the explicit login is skipped
# (huggingface_hub may still use cached credentials) instead of the module
# failing at import time with a KeyError.
_hf_token = os.environ.get("hf_token")
if _hf_token:
    login(token=_hf_token)


# Hugging Face Hub repo ids for the SD3 variants with dedicated checkpoints;
# any other model name falls back to SD3.5-medium.
_SD3_CHECKPOINTS = {
    "SD3.5-large": "stabilityai/stable-diffusion-3.5-large",
    "SD3.5-large-turbo": "stabilityai/stable-diffusion-3.5-large-turbo",
}
_SD3_DEFAULT_CHECKPOINT = "stabilityai/stable-diffusion-3.5-medium"


def _resolve_sd3_checkpoint(model_name):
    """Return the Hub repo id for ``model_name``, defaulting to SD3.5-medium."""
    return _SD3_CHECKPOINTS.get(model_name, _SD3_DEFAULT_CHECKPOINT)


def _as_prompt_list(prompts):
    """Return ``prompts`` as a list of prompt strings.

    A bare string is wrapped in a one-element list; iterating it directly would
    treat every character as a separate prompt.
    """
    if isinstance(prompts, str):
        return [prompts]
    return list(prompts)


@torch.no_grad()
def load_sd3(config, device):
    """Load a Stable Diffusion 3.x pipeline and pre-encode the configured prompts.

    Args:
        config: Mapping with the keys ``"model"`` (``"SD3.5-large"`` or
            ``"SD3.5-large-turbo"``; any other value loads SD3.5-medium),
            ``"quantize"`` (if truthy, a 4-bit NF4 ``BitsAndBytesConfig`` is
            passed as ``quantization_config``), ``"use_tiny_ae"`` (if truthy,
            the VAE is replaced by ``madebyollin/taesd3``), ``"prompt"`` (a
            prompt string or an iterable of prompt strings),
            ``"negative_prompt"``, ``"guidance"``, and optionally
            ``"max_sequence_length"`` (default 256, forwarded to
            ``encode_prompt``).
        device: Target device, or a list of devices of which only the first
            is used.

    Returns:
        ``(pipe, inp_kwargs_list)``. ``inp_kwargs_list`` holds one dict per
        prompt with keys ``"prompt_embeds"``, ``"pooled_prompt_embeds"`` and
        ``"guidance"``. When ``pipe.do_classifier_free_guidance`` is true, the
        negative embeddings are concatenated in front of the positive ones
        along dim 0.
    """
    if isinstance(device, list):
        device = device[0]

    nf4_config = None
    if config["quantize"]:
        nf4_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )

    # Every component is cast to bfloat16 further down, so load directly in
    # bfloat16 rather than via float16. Going through float16 adds a second
    # rounding step and cannot represent magnitudes beyond ~65504.
    pipe = sd3.SD3Wrapper.from_pretrained(
        _resolve_sd3_checkpoint(config["model"]),
        torch_dtype=torch.bfloat16,
        quantization_config=nf4_config,
    )
    # Optionally swap in the tiny autoencoder.
    if config["use_tiny_ae"]:
        pipe.vae = AutoencoderTiny.from_pretrained(
            "madebyollin/taesd3", torch_dtype=torch.bfloat16
        )

    prompts = _as_prompt_list(config["prompt"])
    max_sequence_length = config.get("max_sequence_length", 256)
    # Pipeline state consulted during prompt encoding. Presumably
    # do_classifier_free_guidance derives from _guidance_scale (diffusers
    # convention); confirm against SD3Wrapper.
    pipe._guidance_scale = config["guidance"]
    pipe._joint_attention_kwargs = {"ip_adapter_image_embeds": None}

    # Move/cast the text encoders once, not on every prompt. The guard keeps
    # them where they are when there is nothing to encode.
    if prompts:
        for text_encoder in (pipe.text_encoder, pipe.text_encoder_2, pipe.text_encoder_3):
            text_encoder.to(device).to(torch.bfloat16)

    inp_kwargs_list = []
    for prompt in prompts:
        print(f"Generating prompt embeddings for: {prompt}")
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = pipe.encode_prompt(
            prompt=prompt,
            prompt_2=prompt,
            prompt_3=prompt,
            negative_prompt=config["negative_prompt"],
            negative_prompt_2=config["negative_prompt"],
            negative_prompt_3=config["negative_prompt"],
            do_classifier_free_guidance=pipe.do_classifier_free_guidance,
            prompt_embeds=None,
            negative_prompt_embeds=None,
            pooled_prompt_embeds=None,
            negative_pooled_prompt_embeds=None,
            device=device,
            clip_skip=None,
            num_images_per_prompt=1,
            max_sequence_length=max_sequence_length,
            lora_scale=None,
        )

        # With CFG, stack [negative, positive] along the batch dimension.
        if pipe.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            pooled_prompt_embeds = torch.cat(
                [negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0
            )
        inp_kwargs_list.append(
            {
                "prompt_embeds": prompt_embeds,
                "pooled_prompt_embeds": pooled_prompt_embeds,
                "guidance": config["guidance"],
            }
        )

    pipe.vae.to(device).to(torch.bfloat16)
    pipe.transformer.to(device).to(torch.bfloat16)

    return pipe, inp_kwargs_list

def load_model(config, device=None):
    """Load the model selected by ``config["model"]``.

    Args:
        config: Mapping whose ``"model"`` entry names the model. Names starting
            with ``"SD3"`` (e.g. ``"SD3.5-large"``) are delegated to
            ``load_sd3`` together with the rest of ``config``.
        device: Device or list of devices forwarded to the loader; defaults
            to ``["cuda"]``.

    Returns:
        Whatever the selected loader returns (for SD3: ``(pipe, inp_kwargs_list)``).

    Raises:
        ValueError: If ``config["model"]`` is not a supported model name.
    """
    # None sentinel instead of a shared mutable list default.
    if device is None:
        device = ["cuda"]
    # Require the literal "SD3" prefix. The regex r"SD3*" means "SD followed by
    # zero or more 3s" and would also route names like "SDXL" here.
    if config["model"].startswith("SD3"):
        return load_sd3(config, device)
    raise ValueError(f"Unknown model type {config['model']}")