diff --git "a/stable_diffusion_model.py" "b/stable_diffusion_model.py"
deleted file mode 100644
--- "a/stable_diffusion_model.py"
+++ /dev/null
@@ -1,2745 +0,0 @@
-import copy
-import gc
-import json
-import random
-import shutil
-import typing
-from typing import Union, List, Literal, Iterator
-import sys
-import os
-from collections import OrderedDict
-import copy
-import yaml
-from PIL import Image
-from diffusers.pipelines.pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_1024_BIN, ASPECT_RATIO_512_BIN, \
-    ASPECT_RATIO_2048_BIN, ASPECT_RATIO_256_BIN
-from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
-from safetensors.torch import save_file, load_file
-from torch import autocast
-from torch.nn import Parameter
-from torch.utils.checkpoint import checkpoint
-from tqdm import tqdm
-from torchvision.transforms import Resize, transforms
-
-from toolkit.assistant_lora import load_assistant_lora_from_path
-from toolkit.clip_vision_adapter import ClipVisionAdapter
-from toolkit.custom_adapter import CustomAdapter
-from toolkit.ip_adapter import IPAdapter
-from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
-    convert_vae_state_dict, load_vae
-from toolkit import train_tools
-from toolkit.config_modules import ModelConfig, GenerateImageConfig
-from toolkit.metadata import get_meta_for_safetensors
-from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
-from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds
-from toolkit.reference_adapter import ReferenceAdapter
-from toolkit.sampler import get_sampler
-from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
-from toolkit.saving import save_ldm_model_from_diffusers, get_ldm_state_dict_from_diffusers
-from toolkit.sd_device_states_presets import empty_preset
-from toolkit.train_tools import get_torch_dtype, apply_noise_offset
-from einops import rearrange, repeat
-import torch
-from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
-    StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline, FluxWithCFGPipeline
-from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
-    StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, PixArtTransformer2DModel, \
-    StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
-    StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, \
-    StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline, AuraFlowPipeline, AuraFlowTransformer2DModel, FluxPipeline, \
-    FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel
-import diffusers
-from diffusers import \
-    AutoencoderKL, \
-    UNet2DConditionModel
-from diffusers import PixArtAlphaPipeline, DPMSolverMultistepScheduler, PixArtSigmaPipeline
-from transformers import T5EncoderModel, BitsAndBytesConfig, UMT5EncoderModel, T5TokenizerFast
-from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
-
-from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
-from huggingface_hub import hf_hub_download
-
-from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
-    from toolkit.lora_special import LoRASpecialNetwork
-
-# tell it to shut up
-diffusers.logging.set_verbosity(diffusers.logging.ERROR)
-
-SD_PREFIX_VAE = "vae"
-SD_PREFIX_UNET = "unet"
-SD_PREFIX_REFINER_UNET = "refiner_unet"
-SD_PREFIX_TEXT_ENCODER = "te"
-
-SD_PREFIX_TEXT_ENCODER1 = "te0"
-SD_PREFIX_TEXT_ENCODER2 = "te1"
-
-# prefixed diffusers keys
-DO_NOT_TRAIN_WEIGHTS = [
-    "unet_time_embedding.linear_1.bias",
-    "unet_time_embedding.linear_1.weight",
-    "unet_time_embedding.linear_2.bias",
-    "unet_time_embedding.linear_2.weight",
-    "refiner_unet_time_embedding.linear_1.bias",
-    "refiner_unet_time_embedding.linear_1.weight",
-    "refiner_unet_time_embedding.linear_2.bias",
-    "refiner_unet_time_embedding.linear_2.weight",
-]
-
-DeviceStatePreset = Literal['cache_latents', 'generate']
-
-
-class BlankNetwork:
-
-    def __init__(self):
-        self.multiplier = 1.0
-        self.is_active = True
-        self.is_merged_in = False
-        self.can_merge_in = False
-
-    def __enter__(self):
-        self.is_active = True
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        self.is_active = False
-
-
-def flush():
-    torch.cuda.empty_cache()
-    gc.collect()
-
-
-UNET_IN_CHANNELS = 4  # Stable Diffusion の in_channels は 4 で固定。XLも同じ。
-# VAE_SCALE_FACTOR = 8  # 2 ** (len(vae.config.block_out_channels) - 1) = 8
-
-
-class StableDiffusion:
-
-    def __init__(
-            self,
-            device,
-            model_config: ModelConfig,
-            dtype='fp16',
-            custom_pipeline=None,
-            noise_scheduler=None,
-            quantize_device=None,
-    ):
-        self.custom_pipeline = custom_pipeline
-        self.device = device
-        self.dtype = dtype
-        self.torch_dtype = get_torch_dtype(dtype)
-        self.device_torch = torch.device(self.device)
-
-        self.vae_device_torch = torch.device(self.device) if model_config.vae_device is None else torch.device(
-            model_config.vae_device)
-        self.vae_torch_dtype = get_torch_dtype(model_config.vae_dtype)
-
-        self.te_device_torch = torch.device(self.device) if model_config.te_device is None else torch.device(
-            model_config.te_device)
-        self.te_torch_dtype = get_torch_dtype(model_config.te_dtype)
-
-        self.model_config = model_config
-        self.prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon"
-
-        self.device_state = None
-
-        self.pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline', 'PixArtAlphaPipeline']
-        self.vae: Union[None, 'AutoencoderKL']
-        self.unet: Union[None, 'UNet2DConditionModel']
-        self.text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]]
-        self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']]
-        self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler
-
-        self.refiner_unet: Union[None, 'UNet2DConditionModel'] = None
-        self.assistant_lora: Union[None, 'LoRASpecialNetwork'] = None
-
-        # sdxl stuff
-        self.logit_scale = None
-        self.ckppt_info = None
-        self.is_loaded = False
-
-        # to hold network if there is one
-        self.network = None
-        self.adapter: Union['ControlNetModel', 'T2IAdapter', 'IPAdapter', 'ReferenceAdapter', None] = None
-        self.is_xl = model_config.is_xl
-        self.is_v2 = model_config.is_v2
-        self.is_ssd = model_config.is_ssd
-        self.is_v3 = model_config.is_v3
-        self.is_vega = model_config.is_vega
-        self.is_pixart = model_config.is_pixart
-        self.is_auraflow = model_config.is_auraflow
-        self.is_flux = model_config.is_flux
-
-        self.use_text_encoder_1 = model_config.use_text_encoder_1
-        self.use_text_encoder_2 = model_config.use_text_encoder_2
-
-        self.config_file = None
-
-        self.is_flow_matching = False
-        if self.is_flux or self.is_v3 or self.is_auraflow or isinstance(self.noise_scheduler,
CustomFlowMatchEulerDiscreteScheduler): - self.is_flow_matching = True - - self.quantize_device = quantize_device if quantize_device is not None else self.device - self.low_vram = self.model_config.low_vram - - # merge in and preview active with -1 weight - self.invert_assistant_lora = False - - def load_model(self): - if self.is_loaded: - return - dtype = get_torch_dtype(self.dtype) - - # move the betas alphas and alphas_cumprod to device. Sometimed they get stuck on cpu, not sure why - # self.noise_scheduler.betas = self.noise_scheduler.betas.to(self.device_torch) - # self.noise_scheduler.alphas = self.noise_scheduler.alphas.to(self.device_torch) - # self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(self.device_torch) - - model_path = self.model_config.name_or_path - if 'civitai.com' in self.model_config.name_or_path: - # load is a civit ai model, use the loader. - from toolkit.civitai import get_model_path_from_url - model_path = get_model_path_from_url(self.model_config.name_or_path) - - load_args = {} - if self.noise_scheduler: - load_args['scheduler'] = self.noise_scheduler - - if self.model_config.vae_path is not None: - load_args['vae'] = load_vae(self.model_config.vae_path, dtype) - if self.model_config.is_xl or self.model_config.is_ssd or self.model_config.is_vega: - if self.custom_pipeline is not None: - pipln = self.custom_pipeline - else: - pipln = StableDiffusionXLPipeline - # pipln = StableDiffusionKDiffusionXLPipeline - - # see if path exists - if not os.path.exists(model_path) or os.path.isdir(model_path): - # try to load with default diffusers - pipe = pipln.from_pretrained( - model_path, - dtype=dtype, - device=self.device_torch, - # variant="fp16", - use_safetensors=True, - **load_args - ) - else: - pipe = pipln.from_single_file( - model_path, - device=self.device_torch, - torch_dtype=self.torch_dtype, - ) - - if 'vae' in load_args and load_args['vae'] is not None: - pipe.vae = load_args['vae'] - flush() - - text_encoders = [pipe.text_encoder, pipe.text_encoder_2] - tokenizer = [pipe.tokenizer, pipe.tokenizer_2] - for text_encoder in text_encoders: - text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype) - text_encoder.requires_grad_(False) - text_encoder.eval() - text_encoder = text_encoders - - pipe.vae = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype) - - if self.model_config.experimental_xl: - print("Experimental XL mode enabled") - print("Loading and injecting alt weights") - # load the mismatched weight and force it in - raw_state_dict = load_file(model_path) - replacement_weight = raw_state_dict['conditioner.embedders.1.model.text_projection'].clone() - del raw_state_dict - # get state dict for for 2nd text encoder - te1_state_dict = text_encoders[1].state_dict() - # replace weight with mismatched weight - te1_state_dict['text_projection.weight'] = replacement_weight.to(self.device_torch, dtype=dtype) - flush() - print("Injecting alt weights") - elif self.model_config.is_v3: - if self.custom_pipeline is not None: - pipln = self.custom_pipeline - else: - pipln = StableDiffusion3Pipeline - - print("Loading SD3 model") - # assume it is the large model - base_model_path = "stabilityai/stable-diffusion-3.5-large" - print("Loading transformer") - subfolder = 'transformer' - transformer_path = model_path - # check if HF_DATASETS_OFFLINE or TRANSFORMERS_OFFLINE is set - if os.path.exists(transformer_path): - subfolder = None - transformer_path = os.path.join(transformer_path, 'transformer') - # check if the path is a 
full checkpoint. - te_folder_path = os.path.join(model_path, 'text_encoder') - # if we have the te, this folder is a full checkpoint, use it as the base - if os.path.exists(te_folder_path): - base_model_path = model_path - else: - # is remote use whatever path we were given - base_model_path = model_path - - transformer = SD3Transformer2DModel.from_pretrained( - transformer_path, - subfolder=subfolder, - torch_dtype=dtype, - ) - if not self.low_vram: - # for low v ram, we leave it on the cpu. Quantizes slower, but allows training on primary gpu - transformer.to(torch.device(self.quantize_device), dtype=dtype) - flush() - - if self.model_config.lora_path is not None: - raise ValueError("LoRA is not supported for SD3 models currently") - - if self.model_config.quantize: - quantization_type = qfloat8 - print("Quantizing transformer") - quantize(transformer, weights=quantization_type) - freeze(transformer) - transformer.to(self.device_torch) - else: - transformer.to(self.device_torch, dtype=dtype) - - scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler") - print("Loading vae") - vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype) - flush() - - print("Loading t5") - tokenizer_3 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_3", torch_dtype=dtype) - text_encoder_3 = T5EncoderModel.from_pretrained( - base_model_path, - subfolder="text_encoder_3", - torch_dtype=dtype - ) - - text_encoder_3.to(self.device_torch, dtype=dtype) - flush() - - if self.model_config.quantize: - print("Quantizing T5") - quantize(text_encoder_3, weights=qfloat8) - freeze(text_encoder_3) - flush() - - - # see if path exists - if not os.path.exists(model_path) or os.path.isdir(model_path): - try: - # try to load with default diffusers - pipe = pipln.from_pretrained( - base_model_path, - dtype=dtype, - device=self.device_torch, - tokenizer_3=tokenizer_3, - text_encoder_3=text_encoder_3, - transformer=transformer, - # variant="fp16", - use_safetensors=True, - repo_type="model", - ignore_patterns=["*.md", "*..gitattributes"], - **load_args - ) - except Exception as e: - print(f"Error loading from pretrained: {e}") - raise e - - else: - pipe = pipln.from_single_file( - model_path, - transformer=transformer, - device=self.device_torch, - torch_dtype=self.torch_dtype, - tokenizer_3=tokenizer_3, - text_encoder_3=text_encoder_3, - **load_args - ) - - flush() - - text_encoders = [pipe.text_encoder, pipe.text_encoder_2, pipe.text_encoder_3] - tokenizer = [pipe.tokenizer, pipe.tokenizer_2, pipe.tokenizer_3] - # replace the to function with a no-op since it throws an error instead of a warning - # text_encoders[2].to = lambda *args, **kwargs: None - for text_encoder in text_encoders: - text_encoder.to(self.device_torch, dtype=dtype) - text_encoder.requires_grad_(False) - text_encoder.eval() - text_encoder = text_encoders - - - elif self.model_config.is_pixart: - te_kwargs = {} - # handle quantization of TE - te_is_quantized = False - if self.model_config.text_encoder_bits == 8: - te_kwargs['load_in_8bit'] = True - te_kwargs['device_map'] = "auto" - te_is_quantized = True - elif self.model_config.text_encoder_bits == 4: - te_kwargs['load_in_4bit'] = True - te_kwargs['device_map'] = "auto" - te_is_quantized = True - - main_model_path = "PixArt-alpha/PixArt-XL-2-1024-MS" - if self.model_config.is_pixart_sigma: - main_model_path = "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers" - - main_model_path = model_path - - # load the TE in 
8bit mode - text_encoder = T5EncoderModel.from_pretrained( - main_model_path, - subfolder="text_encoder", - torch_dtype=self.torch_dtype, - **te_kwargs - ) - - # load the transformer - subfolder = "transformer" - # check if it is just the unet - if os.path.exists(model_path) and not os.path.exists(os.path.join(model_path, subfolder)): - subfolder = None - - if te_is_quantized: - # replace the to function with a no-op since it throws an error instead of a warning - text_encoder.to = lambda *args, **kwargs: None - - text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype) - - if self.model_config.is_pixart_sigma: - # load the transformer only from the save - transformer = Transformer2DModel.from_pretrained( - model_path if self.model_config.unet_path is None else self.model_config.unet_path, - torch_dtype=self.torch_dtype, - subfolder='transformer' - ) - pipe: PixArtSigmaPipeline = PixArtSigmaPipeline.from_pretrained( - main_model_path, - transformer=transformer, - text_encoder=text_encoder, - dtype=dtype, - device=self.device_torch, - **load_args - ) - - else: - - # load the transformer only from the save - transformer = Transformer2DModel.from_pretrained(model_path, torch_dtype=self.torch_dtype, - subfolder=subfolder) - pipe: PixArtAlphaPipeline = PixArtAlphaPipeline.from_pretrained( - main_model_path, - transformer=transformer, - text_encoder=text_encoder, - dtype=dtype, - device=self.device_torch, - **load_args - ).to(self.device_torch) - - if self.model_config.unet_sample_size is not None: - pipe.transformer.config.sample_size = self.model_config.unet_sample_size - pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype) - - flush() - # text_encoder = pipe.text_encoder - # text_encoder.to(self.device_torch, dtype=dtype) - text_encoder.requires_grad_(False) - text_encoder.eval() - pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype) - tokenizer = pipe.tokenizer - - pipe.vae = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype) - if self.noise_scheduler is None: - self.noise_scheduler = pipe.scheduler - - - elif self.model_config.is_auraflow: - te_kwargs = {} - # handle quantization of TE - te_is_quantized = False - if self.model_config.text_encoder_bits == 8: - te_kwargs['load_in_8bit'] = True - te_kwargs['device_map'] = "auto" - te_is_quantized = True - elif self.model_config.text_encoder_bits == 4: - te_kwargs['load_in_4bit'] = True - te_kwargs['device_map'] = "auto" - te_is_quantized = True - - main_model_path = model_path - - # load the TE in 8bit mode - text_encoder = UMT5EncoderModel.from_pretrained( - main_model_path, - subfolder="text_encoder", - torch_dtype=self.torch_dtype, - **te_kwargs - ) - - # load the transformer - subfolder = "transformer" - # check if it is just the unet - if os.path.exists(model_path) and not os.path.exists(os.path.join(model_path, subfolder)): - subfolder = None - - if te_is_quantized: - # replace the to function with a no-op since it throws an error instead of a warning - text_encoder.to = lambda *args, **kwargs: None - - # load the transformer only from the save - transformer = AuraFlowTransformer2DModel.from_pretrained( - model_path if self.model_config.unet_path is None else self.model_config.unet_path, - torch_dtype=self.torch_dtype, - subfolder='transformer' - ) - pipe: AuraFlowPipeline = AuraFlowPipeline.from_pretrained( - main_model_path, - transformer=transformer, - text_encoder=text_encoder, - dtype=dtype, - device=self.device_torch, - **load_args - ) - - pipe.transformer = 
pipe.transformer.to(self.device_torch, dtype=dtype) - - # patch auraflow so it can handle other aspect ratios - # patch_auraflow_pos_embed(pipe.transformer.pos_embed) - - flush() - # text_encoder = pipe.text_encoder - # text_encoder.to(self.device_torch, dtype=dtype) - text_encoder.requires_grad_(False) - text_encoder.eval() - pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype) - tokenizer = pipe.tokenizer - - elif self.model_config.is_flux: - print("Loading Flux model") - base_model_path = "black-forest-labs/FLUX.1-schnell" - print("Loading transformer") - subfolder = 'transformer' - transformer_path = model_path - local_files_only = False - # check if HF_DATASETS_OFFLINE or TRANSFORMERS_OFFLINE is set - if os.path.exists(transformer_path): - subfolder = None - transformer_path = os.path.join(transformer_path, 'transformer') - # check if the path is a full checkpoint. - te_folder_path = os.path.join(model_path, 'text_encoder') - # if we have the te, this folder is a full checkpoint, use it as the base - if os.path.exists(te_folder_path): - base_model_path = model_path - - transformer = FluxTransformer2DModel.from_pretrained( - transformer_path, - subfolder=subfolder, - torch_dtype=dtype, - # low_cpu_mem_usage=False, - # device_map=None - ) - if not self.low_vram: - # for low v ram, we leave it on the cpu. Quantizes slower, but allows training on primary gpu - transformer.to(torch.device(self.quantize_device), dtype=dtype) - flush() - - if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None: - if self.model_config.inference_lora_path is not None and self.model_config.assistant_lora_path is not None: - raise ValueError("Cannot load both assistant lora and inference lora at the same time") - - if self.model_config.lora_path: - raise ValueError("Cannot load both assistant lora and lora at the same time") - - if not self.is_flux: - raise ValueError("Assistant/ inference lora is only supported for flux models currently") - - load_lora_path = self.model_config.inference_lora_path - if load_lora_path is None: - load_lora_path = self.model_config.assistant_lora_path - - if os.path.isdir(load_lora_path): - load_lora_path = os.path.join( - load_lora_path, "pytorch_lora_weights.safetensors" - ) - elif not os.path.exists(load_lora_path): - print(f"Grabbing lora from the hub: {load_lora_path}") - new_lora_path = hf_hub_download( - load_lora_path, - filename="pytorch_lora_weights.safetensors" - ) - # replace the path - load_lora_path = new_lora_path - - if self.model_config.inference_lora_path is not None: - self.model_config.inference_lora_path = new_lora_path - if self.model_config.assistant_lora_path is not None: - self.model_config.assistant_lora_path = new_lora_path - - if self.model_config.assistant_lora_path is not None: - # for flux, we assume it is flux schnell. We cannot merge in the assistant lora and unmerge it on - # quantized weights so it had to process unmerged (slow). 
Since schnell samples in just 4 steps - # it is better to merge it in now, and sample slowly later, otherwise training is slowed in half - # so we will merge in now and sample with -1 weight later - self.invert_assistant_lora = True - # trigger it to get merged in - self.model_config.lora_path = self.model_config.assistant_lora_path - - if self.model_config.lora_path is not None: - print("Fusing in LoRA") - # need the pipe for peft - pipe: FluxPipeline = FluxPipeline( - scheduler=None, - text_encoder=None, - tokenizer=None, - text_encoder_2=None, - tokenizer_2=None, - vae=None, - transformer=transformer, - ) - if self.low_vram: - # we cannot fuse the loras all at once without ooming in lowvram mode, so we have to do it in parts - # we can do it on the cpu but it takes about 5-10 mins vs seconds on the gpu - # we are going to separate it into the two transformer blocks one at a time - - lora_state_dict = load_file(self.model_config.lora_path) - single_transformer_lora = {} - single_block_key = "transformer.single_transformer_blocks." - double_transformer_lora = {} - double_block_key = "transformer.transformer_blocks." - for key, value in lora_state_dict.items(): - if single_block_key in key: - single_transformer_lora[key] = value - elif double_block_key in key: - double_transformer_lora[key] = value - else: - raise ValueError(f"Unknown lora key: {key}. Cannot load this lora in low vram mode") - - # double blocks - transformer.transformer_blocks = transformer.transformer_blocks.to( - torch.device(self.quantize_device), dtype=dtype - ) - pipe.load_lora_weights(double_transformer_lora, adapter_name=f"lora1_double") - pipe.fuse_lora() - pipe.unload_lora_weights() - transformer.transformer_blocks = transformer.transformer_blocks.to( - 'cpu', dtype=dtype - ) - - # single blocks - transformer.single_transformer_blocks = transformer.single_transformer_blocks.to( - torch.device(self.quantize_device), dtype=dtype - ) - pipe.load_lora_weights(single_transformer_lora, adapter_name=f"lora1_single") - pipe.fuse_lora() - pipe.unload_lora_weights() - transformer.single_transformer_blocks = transformer.single_transformer_blocks.to( - 'cpu', dtype=dtype - ) - - # cleanup - del single_transformer_lora - del double_transformer_lora - del lora_state_dict - flush() - - else: - # need the pipe to do this unfortunately for now - # we have to fuse in the weights before quantizing - pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1") - pipe.fuse_lora() - # unfortunately, not an easier way with peft - pipe.unload_lora_weights() - flush() - - if self.model_config.quantize: - quantization_type = qfloat8 - print("Quantizing transformer") - quantize(transformer, weights=quantization_type) - freeze(transformer) - transformer.to(self.device_torch) - else: - transformer.to(self.device_torch, dtype=dtype) - - flush() - - scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler") - print("Loading vae") - vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype) - flush() - - print("Loading t5") - tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_2", torch_dtype=dtype) - text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2", - torch_dtype=dtype) - - text_encoder_2.to(self.device_torch, dtype=dtype) - flush() - - print("Quantizing T5") - quantize(text_encoder_2, weights=qfloat8) - freeze(text_encoder_2) - flush() - - print("Loading clip") - text_encoder = 
CLIPTextModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype) - tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype) - text_encoder.to(self.device_torch, dtype=dtype) - - print("making pipe") - pipe: FluxPipeline = FluxPipeline( - scheduler=scheduler, - text_encoder=text_encoder, - tokenizer=tokenizer, - text_encoder_2=None, - tokenizer_2=tokenizer_2, - vae=vae, - transformer=None, - ) - pipe.text_encoder_2 = text_encoder_2 - pipe.transformer = transformer - - print("preparing") - - text_encoder = [pipe.text_encoder, pipe.text_encoder_2] - tokenizer = [pipe.tokenizer, pipe.tokenizer_2] - - pipe.transformer = pipe.transformer.to(self.device_torch) - - flush() - text_encoder[0].to(self.device_torch) - text_encoder[0].requires_grad_(False) - text_encoder[0].eval() - text_encoder[1].to(self.device_torch) - text_encoder[1].requires_grad_(False) - text_encoder[1].eval() - pipe.transformer = pipe.transformer.to(self.device_torch) - flush() - else: - if self.custom_pipeline is not None: - pipln = self.custom_pipeline - else: - pipln = StableDiffusionPipeline - - if self.model_config.text_encoder_bits < 16: - # this is only supported for T5 models for now - te_kwargs = {} - # handle quantization of TE - te_is_quantized = False - if self.model_config.text_encoder_bits == 8: - te_kwargs['load_in_8bit'] = True - te_kwargs['device_map'] = "auto" - te_is_quantized = True - elif self.model_config.text_encoder_bits == 4: - te_kwargs['load_in_4bit'] = True - te_kwargs['device_map'] = "auto" - te_is_quantized = True - - text_encoder = T5EncoderModel.from_pretrained( - model_path, - subfolder="text_encoder", - torch_dtype=self.te_torch_dtype, - **te_kwargs - ) - # replace the to function with a no-op since it throws an error instead of a warning - text_encoder.to = lambda *args, **kwargs: None - - load_args['text_encoder'] = text_encoder - - # see if path exists - if not os.path.exists(model_path) or os.path.isdir(model_path): - # try to load with default diffusers - pipe = pipln.from_pretrained( - model_path, - dtype=dtype, - device=self.device_torch, - load_safety_checker=False, - requires_safety_checker=False, - safety_checker=None, - # variant="fp16", - trust_remote_code=True, - **load_args - ) - else: - pipe = pipln.from_single_file( - model_path, - dtype=dtype, - device=self.device_torch, - load_safety_checker=False, - requires_safety_checker=False, - torch_dtype=self.torch_dtype, - safety_checker=None, - trust_remote_code=True, - **load_args - ) - flush() - - pipe.register_to_config(requires_safety_checker=False) - text_encoder = pipe.text_encoder - text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype) - text_encoder.requires_grad_(False) - text_encoder.eval() - tokenizer = pipe.tokenizer - - # scheduler doesn't get set sometimes, so we set it here - pipe.scheduler = self.noise_scheduler - - # add hacks to unet to help training - # pipe.unet = prepare_unet_for_training(pipe.unet) - - if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux: - # pixart and sd3 dont use a unet - self.unet = pipe.transformer - else: - self.unet: 'UNet2DConditionModel' = pipe.unet - self.vae: 'AutoencoderKL' = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype) - self.vae.eval() - self.vae.requires_grad_(False) - VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1) - self.vae_scale_factor = VAE_SCALE_FACTOR - self.unet.to(self.device_torch, dtype=dtype) - self.unet.requires_grad_(False) - 
self.unet.eval() - - # load any loras we have - if self.model_config.lora_path is not None and not self.is_flux: - pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1") - pipe.fuse_lora() - # unfortunately, not an easier way with peft - pipe.unload_lora_weights() - - self.tokenizer = tokenizer - self.text_encoder = text_encoder - self.pipeline = pipe - self.load_refiner() - self.is_loaded = True - - if self.model_config.assistant_lora_path is not None: - print("Loading assistant lora") - self.assistant_lora: 'LoRASpecialNetwork' = load_assistant_lora_from_path( - self.model_config.assistant_lora_path, self) - - if self.invert_assistant_lora: - # invert and disable during training - self.assistant_lora.multiplier = -1.0 - self.assistant_lora.is_active = False - - if self.model_config.inference_lora_path is not None: - print("Loading inference lora") - self.assistant_lora: 'LoRASpecialNetwork' = load_assistant_lora_from_path( - self.model_config.inference_lora_path, self) - # disable during training - self.assistant_lora.is_active = False - - if self.is_pixart and self.vae_scale_factor == 16: - # TODO make our own pipeline? - # we generate an image 2x larger, so we need to copy the sizes from larger ones down - # ASPECT_RATIO_1024_BIN, ASPECT_RATIO_512_BIN, ASPECT_RATIO_2048_BIN, ASPECT_RATIO_256_BIN - for key in ASPECT_RATIO_256_BIN.keys(): - ASPECT_RATIO_256_BIN[key] = [ASPECT_RATIO_256_BIN[key][0] * 2, ASPECT_RATIO_256_BIN[key][1] * 2] - for key in ASPECT_RATIO_512_BIN.keys(): - ASPECT_RATIO_512_BIN[key] = [ASPECT_RATIO_512_BIN[key][0] * 2, ASPECT_RATIO_512_BIN[key][1] * 2] - for key in ASPECT_RATIO_1024_BIN.keys(): - ASPECT_RATIO_1024_BIN[key] = [ASPECT_RATIO_1024_BIN[key][0] * 2, ASPECT_RATIO_1024_BIN[key][1] * 2] - for key in ASPECT_RATIO_2048_BIN.keys(): - ASPECT_RATIO_2048_BIN[key] = [ASPECT_RATIO_2048_BIN[key][0] * 2, ASPECT_RATIO_2048_BIN[key][1] * 2] - - def te_train(self): - if isinstance(self.text_encoder, list): - for te in self.text_encoder: - te.train() - else: - self.text_encoder.train() - - def te_eval(self): - if isinstance(self.text_encoder, list): - for te in self.text_encoder: - te.eval() - else: - self.text_encoder.eval() - - def load_refiner(self): - # for now, we are just going to rely on the TE from the base model - # which is TE2 for SDXL and TE for SD (no refiner currently) - # and completely ignore a TE that may or may not be packaged with the refiner - if self.model_config.refiner_name_or_path is not None: - refiner_config_path = os.path.join(ORIG_CONFIGS_ROOT, 'sd_xl_refiner.yaml') - # load the refiner model - dtype = get_torch_dtype(self.dtype) - model_path = self.model_config.refiner_name_or_path - if not os.path.exists(model_path) or os.path.isdir(model_path): - # TODO only load unet?? 
- refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained( - model_path, - dtype=dtype, - device=self.device_torch, - # variant="fp16", - use_safetensors=True, - ).to(self.device_torch) - else: - refiner = StableDiffusionXLImg2ImgPipeline.from_single_file( - model_path, - dtype=dtype, - device=self.device_torch, - torch_dtype=self.torch_dtype, - original_config_file=refiner_config_path, - ).to(self.device_torch) - - self.refiner_unet = refiner.unet - del refiner - flush() - - @torch.no_grad() - def generate_images( - self, - image_configs: List[GenerateImageConfig], - sampler=None, - pipeline: Union[None, StableDiffusionPipeline, StableDiffusionXLPipeline] = None, - ): - merge_multiplier = 1.0 - flush() - # if using assistant, unfuse it - if self.model_config.assistant_lora_path is not None: - print("Unloading assistant lora") - if self.invert_assistant_lora: - self.assistant_lora.is_active = True - # move weights on to the device - self.assistant_lora.force_to(self.device_torch, self.torch_dtype) - else: - self.assistant_lora.is_active = False - - if self.model_config.inference_lora_path is not None: - print("Loading inference lora") - self.assistant_lora.is_active = True - # move weights on to the device - self.assistant_lora.force_to(self.device_torch, self.torch_dtype) - - if self.network is not None: - self.network.eval() - network = self.network - # check if we have the same network weight for all samples. If we do, we can merge in th - # the network to drastically speed up inference - unique_network_weights = set([x.network_multiplier for x in image_configs]) - if len(unique_network_weights) == 1 and self.network.can_merge_in: - can_merge_in = True - merge_multiplier = unique_network_weights.pop() - network.merge_in(merge_weight=merge_multiplier) - else: - network = BlankNetwork() - - self.save_device_state() - self.set_device_state_preset('generate') - - # save current seed state for training - rng_state = torch.get_rng_state() - cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None - - if pipeline is None: - noise_scheduler = self.noise_scheduler - if sampler is not None: - if sampler.startswith("sample_"): # sample_dpmpp_2m - # using ksampler - noise_scheduler = get_sampler( - 'lms', { - "prediction_type": self.prediction_type, - }) - else: - noise_scheduler = get_sampler( - sampler, - { - "prediction_type": self.prediction_type, - }, - 'sd' if not self.is_pixart else 'pixart' - ) - - try: - noise_scheduler = noise_scheduler.to(self.device_torch, self.torch_dtype) - except: - pass - - if sampler.startswith("sample_") and self.is_xl: - # using kdiffusion - Pipe = StableDiffusionKDiffusionXLPipeline - elif self.is_xl: - Pipe = StableDiffusionXLPipeline - elif self.is_v3: - Pipe = StableDiffusion3Pipeline - else: - Pipe = StableDiffusionPipeline - - extra_args = {} - if self.adapter is not None: - if isinstance(self.adapter, T2IAdapter): - if self.is_xl: - Pipe = StableDiffusionXLAdapterPipeline - else: - Pipe = StableDiffusionAdapterPipeline - extra_args['adapter'] = self.adapter - elif isinstance(self.adapter, ControlNetModel): - if self.is_xl: - Pipe = StableDiffusionXLControlNetPipeline - else: - Pipe = StableDiffusionControlNetPipeline - extra_args['controlnet'] = self.adapter - elif isinstance(self.adapter, ReferenceAdapter): - # pass the noise scheduler to the adapter - self.adapter.noise_scheduler = noise_scheduler - else: - if self.is_xl: - extra_args['add_watermarker'] = False - - # TODO add clip skip - if self.is_xl: - pipeline = Pipe( - 
vae=self.vae, - unet=self.unet, - text_encoder=self.text_encoder[0], - text_encoder_2=self.text_encoder[1], - tokenizer=self.tokenizer[0], - tokenizer_2=self.tokenizer[1], - scheduler=noise_scheduler, - **extra_args - ).to(self.device_torch) - pipeline.watermark = None - elif self.is_flux: - if self.model_config.use_flux_cfg: - pipeline = FluxWithCFGPipeline( - vae=self.vae, - transformer=self.unet, - text_encoder=self.text_encoder[0], - text_encoder_2=self.text_encoder[1], - tokenizer=self.tokenizer[0], - tokenizer_2=self.tokenizer[1], - scheduler=noise_scheduler, - **extra_args - ) - - else: - pipeline = FluxPipeline( - vae=self.vae, - transformer=self.unet, - text_encoder=self.text_encoder[0], - text_encoder_2=self.text_encoder[1], - tokenizer=self.tokenizer[0], - tokenizer_2=self.tokenizer[1], - scheduler=noise_scheduler, - **extra_args - ) - pipeline.watermark = None - elif self.is_v3: - pipeline = Pipe( - vae=self.vae, - transformer=self.unet, - text_encoder=self.text_encoder[0], - text_encoder_2=self.text_encoder[1], - text_encoder_3=self.text_encoder[2], - tokenizer=self.tokenizer[0], - tokenizer_2=self.tokenizer[1], - tokenizer_3=self.tokenizer[2], - scheduler=noise_scheduler, - **extra_args - ) - elif self.is_pixart: - pipeline = PixArtSigmaPipeline( - vae=self.vae, - transformer=self.unet, - text_encoder=self.text_encoder, - tokenizer=self.tokenizer, - scheduler=noise_scheduler, - **extra_args - ) - - elif self.is_auraflow: - pipeline = AuraFlowPipeline( - vae=self.vae, - transformer=self.unet, - text_encoder=self.text_encoder, - tokenizer=self.tokenizer, - scheduler=noise_scheduler, - **extra_args - ) - - else: - pipeline = Pipe( - vae=self.vae, - unet=self.unet, - text_encoder=self.text_encoder, - tokenizer=self.tokenizer, - scheduler=noise_scheduler, - safety_checker=None, - feature_extractor=None, - requires_safety_checker=False, - **extra_args - ) - flush() - # disable progress bar - pipeline.set_progress_bar_config(disable=True) - - if sampler.startswith("sample_"): - pipeline.set_scheduler(sampler) - - refiner_pipeline = None - if self.refiner_unet: - # build refiner pipeline - refiner_pipeline = StableDiffusionXLImg2ImgPipeline( - vae=pipeline.vae, - unet=self.refiner_unet, - text_encoder=None, - text_encoder_2=pipeline.text_encoder_2, - tokenizer=None, - tokenizer_2=pipeline.tokenizer_2, - scheduler=pipeline.scheduler, - add_watermarker=False, - requires_aesthetics_score=True, - ).to(self.device_torch) - # refiner_pipeline.register_to_config(requires_aesthetics_score=False) - refiner_pipeline.watermark = None - refiner_pipeline.set_progress_bar_config(disable=True) - flush() - - start_multiplier = 1.0 - if self.network is not None: - start_multiplier = self.network.multiplier - - # pipeline.to(self.device_torch) - - with network: - with torch.no_grad(): - if self.network is not None: - assert self.network.is_active - - for i in tqdm(range(len(image_configs)), desc=f"Generating Images", leave=False): - gen_config = image_configs[i] - - extra = {} - validation_image = None - if self.adapter is not None and gen_config.adapter_image_path is not None: - validation_image = Image.open(gen_config.adapter_image_path).convert("RGB") - if isinstance(self.adapter, T2IAdapter): - # not sure why this is double?? 
- validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2)) - extra['image'] = validation_image - extra['adapter_conditioning_scale'] = gen_config.adapter_conditioning_scale - if isinstance(self.adapter, ControlNetModel): - validation_image = validation_image.resize((gen_config.width, gen_config.height)) - extra['image'] = validation_image - extra['controlnet_conditioning_scale'] = gen_config.adapter_conditioning_scale - if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter): - transform = transforms.Compose([ - transforms.ToTensor(), - ]) - validation_image = transform(validation_image) - if isinstance(self.adapter, CustomAdapter): - # todo allow loading multiple - transform = transforms.Compose([ - transforms.ToTensor(), - ]) - validation_image = transform(validation_image) - self.adapter.num_images = 1 - if isinstance(self.adapter, ReferenceAdapter): - # need -1 to 1 - validation_image = transforms.ToTensor()(validation_image) - validation_image = validation_image * 2.0 - 1.0 - validation_image = validation_image.unsqueeze(0) - self.adapter.set_reference_images(validation_image) - - if self.network is not None: - self.network.multiplier = gen_config.network_multiplier - torch.manual_seed(gen_config.seed) - torch.cuda.manual_seed(gen_config.seed) - - if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter) \ - and gen_config.adapter_image_path is not None: - # run through the adapter to saturate the embeds - conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image) - self.adapter(conditional_clip_embeds) - - if self.adapter is not None and isinstance(self.adapter, CustomAdapter): - # handle condition the prompts - gen_config.prompt = self.adapter.condition_prompt( - gen_config.prompt, - is_unconditional=False, - ) - gen_config.prompt_2 = gen_config.prompt - gen_config.negative_prompt = self.adapter.condition_prompt( - gen_config.negative_prompt, - is_unconditional=True, - ) - gen_config.negative_prompt_2 = gen_config.negative_prompt - - if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and validation_image is not None: - self.adapter.trigger_pre_te( - tensors_0_1=validation_image, - is_training=False, - has_been_preprocessed=False, - quad_count=4 - ) - - # encode the prompt ourselves so we can do fun stuff with embeddings - if isinstance(self.adapter, CustomAdapter): - self.adapter.is_unconditional_run = False - conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True) - - if isinstance(self.adapter, CustomAdapter): - self.adapter.is_unconditional_run = True - unconditional_embeds = self.encode_prompt( - gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True - ) - if isinstance(self.adapter, CustomAdapter): - self.adapter.is_unconditional_run = False - - # allow any manipulations to take place to embeddings - gen_config.post_process_embeddings( - conditional_embeds, - unconditional_embeds, - ) - - if self.adapter is not None and isinstance(self.adapter, IPAdapter) \ - and gen_config.adapter_image_path is not None: - # apply the image projection - conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image) - unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image, - True) - conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds, is_unconditional=False) - unconditional_embeds = 
self.adapter(unconditional_embeds, unconditional_clip_embeds, is_unconditional=True) - - if self.adapter is not None and isinstance(self.adapter, - CustomAdapter) and validation_image is not None: - conditional_embeds = self.adapter.condition_encoded_embeds( - tensors_0_1=validation_image, - prompt_embeds=conditional_embeds, - is_training=False, - has_been_preprocessed=False, - is_generating_samples=True, - ) - unconditional_embeds = self.adapter.condition_encoded_embeds( - tensors_0_1=validation_image, - prompt_embeds=unconditional_embeds, - is_training=False, - has_been_preprocessed=False, - is_unconditional=True, - is_generating_samples=True, - ) - - if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and len( - gen_config.extra_values) > 0: - extra_values = torch.tensor([gen_config.extra_values], device=self.device_torch, - dtype=self.torch_dtype) - # apply extra values to the embeddings - self.adapter.add_extra_values(extra_values, is_unconditional=False) - self.adapter.add_extra_values(torch.zeros_like(extra_values), is_unconditional=True) - pass # todo remove, for debugging - - if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0: - # if we have a refiner loaded, set the denoising end at the refiner start - extra['denoising_end'] = gen_config.refiner_start_at - extra['output_type'] = 'latent' - if not self.is_xl: - raise ValueError("Refiner is only supported for XL models") - - conditional_embeds = conditional_embeds.to(self.device_torch, dtype=self.unet.dtype) - unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=self.unet.dtype) - - if self.is_xl: - # fix guidance rescale for sdxl - # was trained on 0.7 (I believe) - - grs = gen_config.guidance_rescale - # if grs is None or grs < 0.00001: - # grs = 0.7 - # grs = 0.0 - - if sampler.startswith("sample_"): - extra['use_karras_sigmas'] = True - extra = { - **extra, - **gen_config.extra_kwargs, - } - - img = pipeline( - # prompt=gen_config.prompt, - # prompt_2=gen_config.prompt_2, - prompt_embeds=conditional_embeds.text_embeds, - pooled_prompt_embeds=conditional_embeds.pooled_embeds, - negative_prompt_embeds=unconditional_embeds.text_embeds, - negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds, - # negative_prompt=gen_config.negative_prompt, - # negative_prompt_2=gen_config.negative_prompt_2, - height=gen_config.height, - width=gen_config.width, - num_inference_steps=gen_config.num_inference_steps, - guidance_scale=gen_config.guidance_scale, - guidance_rescale=grs, - latents=gen_config.latents, - **extra - ).images[0] - elif self.is_v3: - img = pipeline( - prompt_embeds=conditional_embeds.text_embeds, - pooled_prompt_embeds=conditional_embeds.pooled_embeds, - negative_prompt_embeds=unconditional_embeds.text_embeds, - negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds, - height=gen_config.height, - width=gen_config.width, - num_inference_steps=gen_config.num_inference_steps, - guidance_scale=gen_config.guidance_scale, - latents=gen_config.latents, - **extra - ).images[0] - elif self.is_flux: - if self.model_config.use_flux_cfg: - img = pipeline( - prompt_embeds=conditional_embeds.text_embeds, - pooled_prompt_embeds=conditional_embeds.pooled_embeds, - negative_prompt_embeds=unconditional_embeds.text_embeds, - negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds, - height=gen_config.height, - width=gen_config.width, - num_inference_steps=gen_config.num_inference_steps, - guidance_scale=gen_config.guidance_scale, - latents=gen_config.latents, 
- **extra - ).images[0] - else: - img = pipeline( - prompt_embeds=conditional_embeds.text_embeds, - pooled_prompt_embeds=conditional_embeds.pooled_embeds, - # negative_prompt_embeds=unconditional_embeds.text_embeds, - # negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds, - height=gen_config.height, - width=gen_config.width, - num_inference_steps=gen_config.num_inference_steps, - guidance_scale=gen_config.guidance_scale, - latents=gen_config.latents, - **extra - ).images[0] - elif self.is_pixart: - # needs attention masks for some reason - img = pipeline( - prompt=None, - prompt_embeds=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype), - prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch, - dtype=self.unet.dtype), - negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch, - dtype=self.unet.dtype), - negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch, - dtype=self.unet.dtype), - negative_prompt=None, - # negative_prompt=gen_config.negative_prompt, - height=gen_config.height, - width=gen_config.width, - num_inference_steps=gen_config.num_inference_steps, - guidance_scale=gen_config.guidance_scale, - latents=gen_config.latents, - **extra - ).images[0] - elif self.is_auraflow: - pipeline: AuraFlowPipeline = pipeline - - img = pipeline( - prompt=None, - prompt_embeds=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype), - prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch, - dtype=self.unet.dtype), - negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch, - dtype=self.unet.dtype), - negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch, - dtype=self.unet.dtype), - negative_prompt=None, - # negative_prompt=gen_config.negative_prompt, - height=gen_config.height, - width=gen_config.width, - num_inference_steps=gen_config.num_inference_steps, - guidance_scale=gen_config.guidance_scale, - latents=gen_config.latents, - **extra - ).images[0] - else: - img = pipeline( - # prompt=gen_config.prompt, - prompt_embeds=conditional_embeds.text_embeds, - negative_prompt_embeds=unconditional_embeds.text_embeds, - # negative_prompt=gen_config.negative_prompt, - height=gen_config.height, - width=gen_config.width, - num_inference_steps=gen_config.num_inference_steps, - guidance_scale=gen_config.guidance_scale, - latents=gen_config.latents, - **extra - ).images[0] - - if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0: - # slide off just the last 1280 on the last dim as refiner does not use first text encoder - # todo, should we just use the Text encoder for the refiner? 
Fine tuned versions will differ - refiner_text_embeds = conditional_embeds.text_embeds[:, :, -1280:] - refiner_unconditional_text_embeds = unconditional_embeds.text_embeds[:, :, -1280:] - # run through refiner - img = refiner_pipeline( - # prompt=gen_config.prompt, - # prompt_2=gen_config.prompt_2, - - # slice these as it does not use both text encoders - # height=gen_config.height, - # width=gen_config.width, - prompt_embeds=refiner_text_embeds, - pooled_prompt_embeds=conditional_embeds.pooled_embeds, - negative_prompt_embeds=refiner_unconditional_text_embeds, - negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds, - num_inference_steps=gen_config.num_inference_steps, - guidance_scale=gen_config.guidance_scale, - guidance_rescale=grs, - denoising_start=gen_config.refiner_start_at, - denoising_end=gen_config.num_inference_steps, - image=img.unsqueeze(0) - ).images[0] - - gen_config.save_image(img, i) - gen_config.log_image(img, i) - - if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter): - self.adapter.clear_memory() - - # clear pipeline and cache to reduce vram usage - del pipeline - if refiner_pipeline is not None: - del refiner_pipeline - torch.cuda.empty_cache() - - # restore training state - torch.set_rng_state(rng_state) - if cuda_rng_state is not None: - torch.cuda.set_rng_state(cuda_rng_state) - - self.restore_device_state() - if self.network is not None: - self.network.train() - self.network.multiplier = start_multiplier - - self.unet.to(self.device_torch, dtype=self.torch_dtype) - if network.is_merged_in: - network.merge_out(merge_multiplier) - # self.tokenizer.to(original_device_dict['tokenizer']) - - # refuse loras - if self.model_config.assistant_lora_path is not None: - print("Loading assistant lora") - if self.invert_assistant_lora: - self.assistant_lora.is_active = False - # move weights off the device - self.assistant_lora.force_to('cpu', self.torch_dtype) - else: - self.assistant_lora.is_active = True - - if self.model_config.inference_lora_path is not None: - print("Unloading inference lora") - self.assistant_lora.is_active = False - # move weights off the device - self.assistant_lora.force_to('cpu', self.torch_dtype) - - flush() - - def get_latent_noise( - self, - height=None, - width=None, - pixel_height=None, - pixel_width=None, - batch_size=1, - noise_offset=0.0, - ): - VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1) - if height is None and pixel_height is None: - raise ValueError("height or pixel_height must be specified") - if width is None and pixel_width is None: - raise ValueError("width or pixel_width must be specified") - if height is None: - height = pixel_height // VAE_SCALE_FACTOR - if width is None: - width = pixel_width // VAE_SCALE_FACTOR - - num_channels = self.unet.config['in_channels'] - if self.is_flux: - # has 64 channels in for some reason - num_channels = 16 - noise = torch.randn( - ( - batch_size, - num_channels, - height, - width, - ), - device=self.unet.device, - ) - noise = apply_noise_offset(noise, noise_offset) - return noise - - def get_time_ids_from_latents(self, latents: torch.Tensor, requires_aesthetic_score=False): - VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1) - if self.is_xl: - bs, ch, h, w = list(latents.shape) - - height = h * VAE_SCALE_FACTOR - width = w * VAE_SCALE_FACTOR - - dtype = latents.dtype - # just do it without any cropping nonsense - target_size = (height, width) - original_size = (height, width) - crops_coords_top_left = (0, 0) - if 
requires_aesthetic_score: - # refiner - # https://huggingface.co/papers/2307.01952 - aesthetic_score = 6.0 # simulate one - add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) - else: - add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_time_ids = torch.tensor([add_time_ids]) - add_time_ids = add_time_ids.to(latents.device, dtype=dtype) - - batch_time_ids = torch.cat( - [add_time_ids for _ in range(bs)] - ) - return batch_time_ids - else: - return None - - def add_noise( - self, - original_samples: torch.FloatTensor, - noise: torch.FloatTensor, - timesteps: torch.IntTensor - ) -> torch.FloatTensor: - original_samples_chunks = torch.chunk(original_samples, original_samples.shape[0], dim=0) - noise_chunks = torch.chunk(noise, noise.shape[0], dim=0) - timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0) - - if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks): - timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks) - - noisy_latents_chunks = [] - - for idx in range(original_samples.shape[0]): - noisy_latents = self.noise_scheduler.add_noise(original_samples_chunks[idx], noise_chunks[idx], - timesteps_chunks[idx]) - noisy_latents_chunks.append(noisy_latents) - - noisy_latents = torch.cat(noisy_latents_chunks, dim=0) - return noisy_latents - - def predict_noise( - self, - latents: torch.Tensor, - text_embeddings: Union[PromptEmbeds, None] = None, - timestep: Union[int, torch.Tensor] = 1, - guidance_scale=7.5, - guidance_rescale=0, - add_time_ids=None, - conditional_embeddings: Union[PromptEmbeds, None] = None, - unconditional_embeddings: Union[PromptEmbeds, None] = None, - is_input_scaled=False, - detach_unconditional=False, - rescale_cfg=None, - return_conditional_pred=False, - guidance_embedding_scale=1.0, - **kwargs, - ): - conditional_pred = None - # get the embeddings - if text_embeddings is None and conditional_embeddings is None: - raise ValueError("Either text_embeddings or conditional_embeddings must be specified") - if text_embeddings is None and unconditional_embeddings is not None: - text_embeddings = concat_prompt_embeds([ - unconditional_embeddings, # negative embedding - conditional_embeddings, # positive embedding - ]) - elif text_embeddings is None and conditional_embeddings is not None: - # not doing cfg - text_embeddings = conditional_embeddings - - # CFG is comparing neg and positive, if we have concatenated embeddings - # then we are doing it, otherwise we are not and takes half the time. 
- do_classifier_free_guidance = True - - # check if batch size of embeddings matches batch size of latents - if latents.shape[0] == text_embeddings.text_embeds.shape[0]: - do_classifier_free_guidance = False - elif latents.shape[0] * 2 != text_embeddings.text_embeds.shape[0]: - raise ValueError("Batch size of latents must be the same or half the batch size of text embeddings") - latents = latents.to(self.device_torch) - text_embeddings = text_embeddings.to(self.device_torch) - timestep = timestep.to(self.device_torch) - - # if timestep is zero dim, unsqueeze it - if len(timestep.shape) == 0: - timestep = timestep.unsqueeze(0) - - # if we only have 1 timestep, we can just use the same timestep for all - if timestep.shape[0] == 1 and latents.shape[0] > 1: - # check if it is rank 1 or 2 - if len(timestep.shape) == 1: - timestep = timestep.repeat(latents.shape[0]) - else: - timestep = timestep.repeat(latents.shape[0], 0) - - # handle t2i adapters - if 'down_intrablock_additional_residuals' in kwargs: - # go through each item and concat if doing cfg and it doesnt have the same shape - for idx, item in enumerate(kwargs['down_intrablock_additional_residuals']): - if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]: - kwargs['down_intrablock_additional_residuals'][idx] = torch.cat([item] * 2, dim=0) - - # handle controlnet - if 'down_block_additional_residuals' in kwargs and 'mid_block_additional_residual' in kwargs: - # go through each item and concat if doing cfg and it doesnt have the same shape - for idx, item in enumerate(kwargs['down_block_additional_residuals']): - if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]: - kwargs['down_block_additional_residuals'][idx] = torch.cat([item] * 2, dim=0) - for idx, item in enumerate(kwargs['mid_block_additional_residual']): - if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]: - kwargs['mid_block_additional_residual'][idx] = torch.cat([item] * 2, dim=0) - - def scale_model_input(model_input, timestep_tensor): - if is_input_scaled: - return model_input - mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0) - timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0) - out_chunks = [] - # unsqueeze if timestep is zero dim - for idx in range(model_input.shape[0]): - # if scheduler has step_index - if hasattr(self.noise_scheduler, '_step_index'): - self.noise_scheduler._step_index = None - out_chunks.append( - self.noise_scheduler.scale_model_input(mi_chunks[idx], timestep_chunks[idx]) - ) - return torch.cat(out_chunks, dim=0) - - if self.is_xl: - with torch.no_grad(): - # 16, 6 for bs of 4 - if add_time_ids is None: - add_time_ids = self.get_time_ids_from_latents(latents) - - if do_classifier_free_guidance: - # todo check this with larget batches - add_time_ids = torch.cat([add_time_ids] * 2) - - if do_classifier_free_guidance: - latent_model_input = torch.cat([latents] * 2) - timestep = torch.cat([timestep] * 2) - else: - latent_model_input = latents - - latent_model_input = scale_model_input(latent_model_input, timestep) - - added_cond_kwargs = { - # todo can we zero here the second text encoder? or match a blank string? - "text_embeds": text_embeddings.pooled_embeds, - "time_ids": add_time_ids, - } - - if self.model_config.refiner_name_or_path is not None: - # we have the refiner on the second half of everything. 
Do Both - if do_classifier_free_guidance: - raise ValueError("Refiner is not supported with classifier free guidance") - - if self.unet.training: - input_chunks = torch.chunk(latent_model_input, 2, dim=0) - timestep_chunks = torch.chunk(timestep, 2, dim=0) - added_cond_kwargs_chunked = { - "text_embeds": torch.chunk(text_embeddings.pooled_embeds, 2, dim=0), - "time_ids": torch.chunk(add_time_ids, 2, dim=0), - } - text_embeds_chunks = torch.chunk(text_embeddings.text_embeds, 2, dim=0) - - # predict the noise residual - base_pred = self.unet( - input_chunks[0], - timestep_chunks[0], - encoder_hidden_states=text_embeds_chunks[0], - added_cond_kwargs={ - "text_embeds": added_cond_kwargs_chunked['text_embeds'][0], - "time_ids": added_cond_kwargs_chunked['time_ids'][0], - }, - **kwargs, - ).sample - - refiner_pred = self.refiner_unet( - input_chunks[1], - timestep_chunks[1], - encoder_hidden_states=text_embeds_chunks[1][:, :, -1280:], - # just use the first second text encoder - added_cond_kwargs={ - "text_embeds": added_cond_kwargs_chunked['text_embeds'][1], - # "time_ids": added_cond_kwargs_chunked['time_ids'][1], - "time_ids": self.get_time_ids_from_latents(input_chunks[1], requires_aesthetic_score=True), - }, - **kwargs, - ).sample - - noise_pred = torch.cat([base_pred, refiner_pred], dim=0) - else: - noise_pred = self.refiner_unet( - latent_model_input, - timestep, - encoder_hidden_states=text_embeddings.text_embeds[:, :, -1280:], - # just use the first second text encoder - added_cond_kwargs={ - "text_embeds": text_embeddings.pooled_embeds, - "time_ids": self.get_time_ids_from_latents(latent_model_input, - requires_aesthetic_score=True), - }, - **kwargs, - ).sample - - else: - - # predict the noise residual - noise_pred = self.unet( - latent_model_input.to(self.device_torch, self.torch_dtype), - timestep, - encoder_hidden_states=text_embeddings.text_embeds, - added_cond_kwargs=added_cond_kwargs, - **kwargs, - ).sample - - conditional_pred = noise_pred - - if do_classifier_free_guidance: - # perform guidance - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - conditional_pred = noise_pred_text - noise_pred = noise_pred_uncond + guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) - - # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775 - if guidance_rescale > 0.0: - # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) - - else: - with torch.no_grad(): - if do_classifier_free_guidance: - # if we are doing classifier free guidance, need to double up - latent_model_input = torch.cat([latents] * 2, dim=0) - timestep = torch.cat([timestep] * 2) - else: - latent_model_input = latents - - latent_model_input = scale_model_input(latent_model_input, timestep) - - # check if we need to concat timesteps - if isinstance(timestep, torch.Tensor) and len(timestep.shape) > 1: - ts_bs = timestep.shape[0] - if ts_bs != latent_model_input.shape[0]: - if ts_bs == 1: - timestep = torch.cat([timestep] * latent_model_input.shape[0]) - elif ts_bs * 2 == latent_model_input.shape[0]: - timestep = torch.cat([timestep] * 2, dim=0) - else: - raise ValueError( - f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}") - - # predict the noise residual - if self.is_pixart: - VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1) - batch_size, ch, h, w = list(latents.shape) - - height = h * VAE_SCALE_FACTOR - width = w * VAE_SCALE_FACTOR - - if self.pipeline.transformer.config.sample_size == 256: - aspect_ratio_bin = ASPECT_RATIO_2048_BIN - elif self.pipeline.transformer.config.sample_size == 128: - aspect_ratio_bin = ASPECT_RATIO_1024_BIN - elif self.pipeline.transformer.config.sample_size == 64: - aspect_ratio_bin = ASPECT_RATIO_512_BIN - elif self.pipeline.transformer.config.sample_size == 32: - aspect_ratio_bin = ASPECT_RATIO_256_BIN - else: - raise ValueError(f"Invalid sample size: {self.pipeline.transformer.config.sample_size}") - orig_height, orig_width = height, width - height, width = self.pipeline.image_processor.classify_height_width_bin(height, width, - ratios=aspect_ratio_bin) - - added_cond_kwargs = {"resolution": None, "aspect_ratio": None} - if self.unet.config.sample_size == 128 or ( - self.vae_scale_factor == 16 and self.unet.config.sample_size == 64): - resolution = torch.tensor([height, width]).repeat(batch_size, 1) - aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size, 1) - resolution = resolution.to(dtype=text_embeddings.text_embeds.dtype, device=self.device_torch) - aspect_ratio = aspect_ratio.to(dtype=text_embeddings.text_embeds.dtype, device=self.device_torch) - - if do_classifier_free_guidance: - resolution = torch.cat([resolution, resolution], dim=0) - aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0) - - added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio} - - noise_pred = self.unet( - latent_model_input.to(self.device_torch, self.torch_dtype), - encoder_hidden_states=text_embeddings.text_embeds, - encoder_attention_mask=text_embeddings.attention_mask, - timestep=timestep, - added_cond_kwargs=added_cond_kwargs, - return_dict=False, - **kwargs - )[0] - - # learned sigma - if self.unet.config.out_channels // 2 == self.unet.config.in_channels: - noise_pred = noise_pred.chunk(2, dim=1)[0] - else: - noise_pred = noise_pred - else: - if self.unet.device != self.device_torch: - self.unet.to(self.device_torch) - if self.unet.dtype != self.torch_dtype: - self.unet = self.unet.to(dtype=self.torch_dtype) - if self.is_flux: - with torch.no_grad(): - - bs, c, h, w = latent_model_input.shape - latent_model_input_packed = rearrange( - latent_model_input, - "b c (h ph) (w pw) -> b (h w) (c ph pw)", - ph=2, - pw=2 - ) - - img_ids = torch.zeros(h // 
2, w // 2, 3) - img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] - img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] - img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs).to(self.device_torch) - - txt_ids = torch.zeros(bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch) - - # # handle guidance - if self.unet.config.guidance_embeds: - if isinstance(guidance_scale, list): - guidance = torch.tensor(guidance_scale, device=self.device_torch) - else: - guidance = torch.tensor([guidance_scale], device=self.device_torch) - guidance = guidance.expand(latents.shape[0]) - else: - guidance = None - - cast_dtype = self.unet.dtype - # with torch.amp.autocast(device_type='cuda', dtype=cast_dtype): - noise_pred = self.unet( - hidden_states=latent_model_input_packed.to(self.device_torch, cast_dtype), # [1, 4096, 64] - # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) - # todo make sure this doesnt change - timestep=timestep / 1000, # timestep is 1000 scale - encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, cast_dtype), - # [1, 512, 4096] - pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, cast_dtype), # [1, 768] - txt_ids=txt_ids, # [1, 512, 3] - img_ids=img_ids, # [1, 4096, 3] - guidance=guidance, - return_dict=False, - **kwargs, - )[0] - - if isinstance(noise_pred, QTensor): - noise_pred = noise_pred.dequantize() - - noise_pred = rearrange( - noise_pred, - "b (h w) (c ph pw) -> b c (h ph) (w pw)", - h=latent_model_input.shape[2] // 2, - w=latent_model_input.shape[3] // 2, - ph=2, - pw=2, - c=latent_model_input.shape[1], - ) - elif self.is_v3: - noise_pred = self.unet( - hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype), - timestep=timestep, - encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype), - pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, self.torch_dtype), - **kwargs, - ).sample - if isinstance(noise_pred, QTensor): - noise_pred = noise_pred.dequantize() - elif self.is_auraflow: - # aura use timestep value between 0 and 1, with t=1 as noise and t=0 as the image - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - t = torch.tensor([timestep / 1000]).expand(latent_model_input.shape[0]) - t = t.to(self.device_torch, self.torch_dtype) - - noise_pred = self.unet( - latent_model_input, - encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype), - timestep=t, - return_dict=False, - )[0] - else: - noise_pred = self.unet( - latent_model_input.to(self.device_torch, self.torch_dtype), - timestep=timestep, - encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype), - **kwargs, - ).sample - - conditional_pred = noise_pred - - if do_classifier_free_guidance: - # perform guidance - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2, dim=0) - conditional_pred = noise_pred_text - if detach_unconditional: - noise_pred_uncond = noise_pred_uncond.detach() - noise_pred = noise_pred_uncond + guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) - if rescale_cfg is not None and rescale_cfg != guidance_scale: - with torch.no_grad(): - # do cfg at the target rescale so we can match it - target_pred_mean_std = noise_pred_uncond + rescale_cfg * ( - noise_pred_text - noise_pred_uncond - ) - target_mean = 
target_pred_mean_std.mean([1, 2, 3], keepdim=True).detach() - target_std = target_pred_mean_std.std([1, 2, 3], keepdim=True).detach() - - pred_mean = noise_pred.mean([1, 2, 3], keepdim=True).detach() - pred_std = noise_pred.std([1, 2, 3], keepdim=True).detach() - - # match the mean and std - noise_pred = (noise_pred - pred_mean) / pred_std - noise_pred = (noise_pred * target_std) + target_mean - - # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775 - if guidance_rescale > 0.0: - # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) - - if return_conditional_pred: - return noise_pred, conditional_pred - return noise_pred - - def step_scheduler(self, model_input, latent_input, timestep_tensor, noise_scheduler=None): - if noise_scheduler is None: - noise_scheduler = self.noise_scheduler - # // sometimes they are on the wrong device, no idea why - if isinstance(noise_scheduler, DDPMScheduler) or isinstance(noise_scheduler, LCMScheduler): - try: - noise_scheduler.betas = noise_scheduler.betas.to(self.device_torch) - noise_scheduler.alphas = noise_scheduler.alphas.to(self.device_torch) - noise_scheduler.alphas_cumprod = noise_scheduler.alphas_cumprod.to(self.device_torch) - except Exception as e: - pass - - mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0) - latent_chunks = torch.chunk(latent_input, latent_input.shape[0], dim=0) - timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0) - out_chunks = [] - if len(timestep_chunks) == 1 and len(mi_chunks) > 1: - # expand timestep to match - timestep_chunks = timestep_chunks * len(mi_chunks) - - for idx in range(model_input.shape[0]): - # Reset it so it is unique for the - if hasattr(noise_scheduler, '_step_index'): - noise_scheduler._step_index = None - if hasattr(noise_scheduler, 'is_scale_input_called'): - noise_scheduler.is_scale_input_called = True - out_chunks.append( - noise_scheduler.step(mi_chunks[idx], timestep_chunks[idx], latent_chunks[idx], return_dict=False)[ - 0] - ) - return torch.cat(out_chunks, dim=0) - - # ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746 - def diffuse_some_steps( - self, - latents: torch.FloatTensor, - text_embeddings: PromptEmbeds, - total_timesteps: int = 1000, - start_timesteps=0, - guidance_scale=1, - add_time_ids=None, - bleed_ratio: float = 0.5, - bleed_latents: torch.FloatTensor = None, - is_input_scaled=False, - return_first_prediction=False, - **kwargs, - ): - timesteps_to_run = self.noise_scheduler.timesteps[start_timesteps:total_timesteps] - - first_prediction = None - - for timestep in tqdm(timesteps_to_run, leave=False): - timestep = timestep.unsqueeze_(0) - noise_pred, conditional_pred = self.predict_noise( - latents, - text_embeddings, - timestep, - guidance_scale=guidance_scale, - add_time_ids=add_time_ids, - is_input_scaled=is_input_scaled, - return_conditional_pred=True, - **kwargs, - ) - # some schedulers need to run separately, so do that. 
(euler for example) - - if return_first_prediction and first_prediction is None: - first_prediction = conditional_pred - - latents = self.step_scheduler(noise_pred, latents, timestep) - - # if not last step, and bleeding, bleed in some latents - if bleed_latents is not None and timestep != self.noise_scheduler.timesteps[-1]: - latents = (latents * (1 - bleed_ratio)) + (bleed_latents * bleed_ratio) - - # only skip first scaling - is_input_scaled = False - - # return latents_steps - if return_first_prediction: - return latents, first_prediction - return latents - - def encode_prompt( - self, - prompt, - prompt2=None, - num_images_per_prompt=1, - force_all=False, - long_prompts=False, - max_length=None, - dropout_prob=0.0, - ) -> PromptEmbeds: - # sd1.5 embeddings are (bs, 77, 768) - prompt = prompt - # if it is not a list, make it one - if not isinstance(prompt, list): - prompt = [prompt] - - if prompt2 is not None and not isinstance(prompt2, list): - prompt2 = [prompt2] - if self.is_xl: - # todo make this a config - # 50% chance to use an encoder anyway even if it is disabled - # allows the other TE to compensate for the disabled one - # use_encoder_1 = self.use_text_encoder_1 or force_all or random.random() > 0.5 - # use_encoder_2 = self.use_text_encoder_2 or force_all or random.random() > 0.5 - use_encoder_1 = True - use_encoder_2 = True - - return PromptEmbeds( - train_tools.encode_prompts_xl( - self.tokenizer, - self.text_encoder, - prompt, - prompt2, - num_images_per_prompt=num_images_per_prompt, - use_text_encoder_1=use_encoder_1, - use_text_encoder_2=use_encoder_2, - truncate=not long_prompts, - max_length=max_length, - dropout_prob=dropout_prob, - ) - ) - if self.is_v3: - return PromptEmbeds( - train_tools.encode_prompts_sd3( - self.tokenizer, - self.text_encoder, - prompt, - num_images_per_prompt=num_images_per_prompt, - truncate=not long_prompts, - max_length=max_length, - dropout_prob=dropout_prob, - pipeline=self.pipeline, - ) - ) - elif self.is_pixart: - embeds, attention_mask = train_tools.encode_prompts_pixart( - self.tokenizer, - self.text_encoder, - prompt, - truncate=not long_prompts, - max_length=300 if self.model_config.is_pixart_sigma else 120, - dropout_prob=dropout_prob - ) - return PromptEmbeds( - embeds, - attention_mask=attention_mask, - ) - elif self.is_auraflow: - embeds, attention_mask = train_tools.encode_prompts_auraflow( - self.tokenizer, - self.text_encoder, - prompt, - truncate=not long_prompts, - max_length=256, - dropout_prob=dropout_prob - ) - return PromptEmbeds( - embeds, - attention_mask=attention_mask, # not used - ) - elif self.is_flux: - prompt_embeds, pooled_prompt_embeds = train_tools.encode_prompts_flux( - self.tokenizer, # list - self.text_encoder, # list - prompt, - truncate=not long_prompts, - max_length=512, - dropout_prob=dropout_prob, - attn_mask=self.model_config.attn_masking - ) - pe = PromptEmbeds( - prompt_embeds - ) - pe.pooled_embeds = pooled_prompt_embeds - return pe - - - elif isinstance(self.text_encoder, T5EncoderModel): - embeds, attention_mask = train_tools.encode_prompts_pixart( - self.tokenizer, - self.text_encoder, - prompt, - truncate=not long_prompts, - max_length=256, - dropout_prob=dropout_prob - ) - - # just mask the attention mask - prompt_attention_mask = attention_mask.unsqueeze(-1).expand(embeds.shape) - embeds = embeds * prompt_attention_mask.to(dtype=embeds.dtype, device=embeds.device) - return PromptEmbeds( - embeds, - - # do we want attn mask here? 
- # attention_mask=attention_mask, - ) - else: - return PromptEmbeds( - train_tools.encode_prompts( - self.tokenizer, - self.text_encoder, - prompt, - truncate=not long_prompts, - max_length=max_length, - dropout_prob=dropout_prob - ) - ) - - @torch.no_grad() - def encode_images( - self, - image_list: List[torch.Tensor], - device=None, - dtype=None - ): - if device is None: - device = self.vae_device_torch - if dtype is None: - dtype = self.vae_torch_dtype - - latent_list = [] - # Move to vae to device if on cpu - if self.vae.device == 'cpu': - self.vae.to(device) - self.vae.eval() - self.vae.requires_grad_(False) - # move to device and dtype - image_list = [image.to(device, dtype=dtype) for image in image_list] - - VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1) - - # resize images if not divisible by 8 - for i in range(len(image_list)): - image = image_list[i] - if image.shape[1] % VAE_SCALE_FACTOR != 0 or image.shape[2] % VAE_SCALE_FACTOR != 0: - image_list[i] = Resize((image.shape[1] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR, - image.shape[2] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR))(image) - - images = torch.stack(image_list) - if isinstance(self.vae, AutoencoderTiny): - latents = self.vae.encode(images, return_dict=False)[0] - else: - latents = self.vae.encode(images).latent_dist.sample() - shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0 - - # flux ref https://github.com/black-forest-labs/flux/blob/c23ae247225daba30fbd56058d247cc1b1fc20a3/src/flux/modules/autoencoder.py#L303 - # z = self.scale_factor * (z - self.shift_factor) - latents = self.vae.config['scaling_factor'] * (latents - shift) - latents = latents.to(device, dtype=dtype) - - return latents - - def decode_latents( - self, - latents: torch.Tensor, - device=None, - dtype=None - ): - if device is None: - device = self.device - if dtype is None: - dtype = self.torch_dtype - - # Move to vae to device if on cpu - if self.vae.device == 'cpu': - self.vae.to(self.device) - latents = latents.to(device, dtype=dtype) - latents = (latents / self.vae.config['scaling_factor']) + self.vae.config['shift_factor'] - images = self.vae.decode(latents).sample - images = images.to(device, dtype=dtype) - - return images - - def encode_image_prompt_pairs( - self, - prompt_list: List[str], - image_list: List[torch.Tensor], - device=None, - dtype=None - ): - # todo check image types and expand and rescale as needed - # device and dtype are for outputs - if device is None: - device = self.device - if dtype is None: - dtype = self.torch_dtype - - embedding_list = [] - latent_list = [] - # embed the prompts - for prompt in prompt_list: - embedding = self.encode_prompt(prompt).to(self.device_torch, dtype=dtype) - embedding_list.append(embedding) - - return embedding_list, latent_list - - def get_weight_by_name(self, name): - # weights begin with te{te_num}_ for text encoder - # weights begin with unet_ for unet_ - if name.startswith('te'): - key = name[4:] - # text encoder - te_num = int(name[2]) - if isinstance(self.text_encoder, list): - return self.text_encoder[te_num].state_dict()[key] - else: - return self.text_encoder.state_dict()[key] - elif name.startswith('unet'): - key = name[5:] - # unet - return self.unet.state_dict()[key] - - raise ValueError(f"Unknown weight name: {name}") - - def inject_trigger_into_prompt(self, prompt, trigger=None, to_replace_list=None, add_if_not_present=False): - return inject_trigger_into_prompt( - prompt, - trigger=trigger, - 
to_replace_list=to_replace_list, - add_if_not_present=add_if_not_present, - ) - - def state_dict(self, vae=True, text_encoder=True, unet=True): - state_dict = OrderedDict() - if vae: - for k, v in self.vae.state_dict().items(): - new_key = k if k.startswith(f"{SD_PREFIX_VAE}") else f"{SD_PREFIX_VAE}_{k}" - state_dict[new_key] = v - if text_encoder: - if isinstance(self.text_encoder, list): - for i, encoder in enumerate(self.text_encoder): - for k, v in encoder.state_dict().items(): - new_key = k if k.startswith( - f"{SD_PREFIX_TEXT_ENCODER}{i}_") else f"{SD_PREFIX_TEXT_ENCODER}{i}_{k}" - state_dict[new_key] = v - else: - for k, v in self.text_encoder.state_dict().items(): - new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER}_") else f"{SD_PREFIX_TEXT_ENCODER}_{k}" - state_dict[new_key] = v - if unet: - for k, v in self.unet.state_dict().items(): - new_key = k if k.startswith(f"{SD_PREFIX_UNET}_") else f"{SD_PREFIX_UNET}_{k}" - state_dict[new_key] = v - return state_dict - - def named_parameters(self, vae=True, text_encoder=True, unet=True, refiner=False, state_dict_keys=False) -> \ - OrderedDict[ - str, Parameter]: - named_params: OrderedDict[str, Parameter] = OrderedDict() - if vae: - for name, param in self.vae.named_parameters(recurse=True, prefix=f"{SD_PREFIX_VAE}"): - named_params[name] = param - if text_encoder: - if isinstance(self.text_encoder, list): - for i, encoder in enumerate(self.text_encoder): - if self.is_xl and not self.model_config.use_text_encoder_1 and i == 0: - # dont add these params - continue - if self.is_xl and not self.model_config.use_text_encoder_2 and i == 1: - # dont add these params - continue - - for name, param in encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}{i}"): - named_params[name] = param - else: - for name, param in self.text_encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}"): - named_params[name] = param - if unet: - if self.is_flux: - # Just train the middle 2 blocks of each transformer block - # block_list = [] - # num_transformer_blocks = 2 - # start_block = len(self.unet.transformer_blocks) // 2 - (num_transformer_blocks // 2) - # for i in range(num_transformer_blocks): - # block_list.append(self.unet.transformer_blocks[start_block + i]) - # - # num_single_transformer_blocks = 4 - # start_block = len(self.unet.single_transformer_blocks) // 2 - (num_single_transformer_blocks // 2) - # for i in range(num_single_transformer_blocks): - # block_list.append(self.unet.single_transformer_blocks[start_block + i]) - # - # for block in block_list: - # for name, param in block.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"): - # named_params[name] = param - - # train the guidance embedding - # if self.unet.config.guidance_embeds: - # transformer: FluxTransformer2DModel = self.unet - # for name, param in transformer.time_text_embed.named_parameters(recurse=True, - # prefix=f"{SD_PREFIX_UNET}"): - # named_params[name] = param - - for name, param in self.unet.transformer_blocks.named_parameters(recurse=True, - prefix=f"{SD_PREFIX_UNET}"): - named_params[name] = param - for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True, - prefix=f"{SD_PREFIX_UNET}"): - named_params[name] = param - else: - for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"): - named_params[name] = param - - if refiner: - for name, param in self.refiner_unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_REFINER_UNET}"): - named_params[name] = param - - # 
convert to state dict keys, jsut replace . with _ on keys - if state_dict_keys: - new_named_params = OrderedDict() - for k, v in named_params.items(): - # replace only the first . with an _ - new_key = k.replace('.', '_', 1) - new_named_params[new_key] = v - named_params = new_named_params - - return named_params - - def save_refiner(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16')): - - # load the full refiner since we only train unet - if self.model_config.refiner_name_or_path is None: - raise ValueError("Refiner must be specified to save it") - refiner_config_path = os.path.join(ORIG_CONFIGS_ROOT, 'sd_xl_refiner.yaml') - # load the refiner model - dtype = get_torch_dtype(self.dtype) - model_path = self.model_config._original_refiner_name_or_path - if not os.path.exists(model_path) or os.path.isdir(model_path): - # TODO only load unet?? - refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained( - model_path, - dtype=dtype, - device='cpu', - # variant="fp16", - use_safetensors=True, - ) - else: - refiner = StableDiffusionXLImg2ImgPipeline.from_single_file( - model_path, - dtype=dtype, - device='cpu', - torch_dtype=self.torch_dtype, - original_config_file=refiner_config_path, - ) - # replace original unet - refiner.unet = self.refiner_unet - flush() - - diffusers_state_dict = OrderedDict() - for k, v in refiner.vae.state_dict().items(): - new_key = k if k.startswith(f"{SD_PREFIX_VAE}") else f"{SD_PREFIX_VAE}_{k}" - diffusers_state_dict[new_key] = v - for k, v in refiner.text_encoder_2.state_dict().items(): - new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER2}_") else f"{SD_PREFIX_TEXT_ENCODER2}_{k}" - diffusers_state_dict[new_key] = v - for k, v in refiner.unet.state_dict().items(): - new_key = k if k.startswith(f"{SD_PREFIX_UNET}_") else f"{SD_PREFIX_UNET}_{k}" - diffusers_state_dict[new_key] = v - - converted_state_dict = get_ldm_state_dict_from_diffusers( - diffusers_state_dict, - 'sdxl_refiner', - device='cpu', - dtype=save_dtype - ) - - # make sure parent folder exists - os.makedirs(os.path.dirname(output_file), exist_ok=True) - save_file(converted_state_dict, output_file, metadata=meta) - - if self.config_file is not None: - output_path_no_ext = os.path.splitext(output_file)[0] - output_config_path = f"{output_path_no_ext}.yaml" - shutil.copyfile(self.config_file, output_config_path) - - def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None): - version_string = '1' - if self.is_v2: - version_string = '2' - if self.is_xl: - version_string = 'sdxl' - if self.is_ssd: - # overwrite sdxl because both wil be true here - version_string = 'ssd' - if self.is_ssd and self.is_vega: - version_string = 'vega' - # if output file does not end in .safetensors, then it is a directory and we are - # saving in diffusers format - if not output_file.endswith('.safetensors'): - # diffusers - # if self.is_pixart: - # self.unet.save_pretrained( - # save_directory=output_file, - # safe_serialization=True, - # ) - # else: - if self.is_flux: - # only save the unet - transformer: FluxTransformer2DModel = self.unet - transformer.save_pretrained( - save_directory=os.path.join(output_file, 'transformer'), - safe_serialization=True, - ) - else: - - self.pipeline.save_pretrained( - save_directory=output_file, - safe_serialization=True, - ) - # save out meta config - meta_path = os.path.join(output_file, 'aitk_meta.yaml') - with open(meta_path, 'w') as f: - yaml.dump(meta, f) - - else: - save_ldm_model_from_diffusers( - sd=self, - 
output_file=output_file, - meta=meta, - save_dtype=save_dtype, - sd_version=version_string, - ) - if self.config_file is not None: - output_path_no_ext = os.path.splitext(output_file)[0] - output_config_path = f"{output_path_no_ext}.yaml" - shutil.copyfile(self.config_file, output_config_path) - - def prepare_optimizer_params( - self, - unet=False, - text_encoder=False, - text_encoder_lr=None, - unet_lr=None, - refiner_lr=None, - refiner=False, - default_lr=1e-6, - ): - # todo maybe only get locon ones? - # not all items are saved, to make it match, we need to match out save mappings - # and not train anything not mapped. Also add learning rate - version = 'sd1' - if self.is_xl: - version = 'sdxl' - if self.is_v2: - version = 'sd2' - mapping_filename = f"stable_diffusion_{version}.json" - mapping_path = os.path.join(KEYMAPS_ROOT, mapping_filename) - with open(mapping_path, 'r') as f: - mapping = json.load(f) - ldm_diffusers_keymap = mapping['ldm_diffusers_keymap'] - - trainable_parameters = [] - - # we use state dict to find params - - if unet: - named_params = self.named_parameters(vae=False, unet=unet, text_encoder=False, state_dict_keys=True) - unet_lr = unet_lr if unet_lr is not None else default_lr - params = [] - if self.is_pixart or self.is_auraflow or self.is_flux: - for param in named_params.values(): - if param.requires_grad: - params.append(param) - else: - for key, diffusers_key in ldm_diffusers_keymap.items(): - if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS: - if named_params[diffusers_key].requires_grad: - params.append(named_params[diffusers_key]) - param_data = {"params": params, "lr": unet_lr} - trainable_parameters.append(param_data) - print(f"Found {len(params)} trainable parameter in unet") - - if text_encoder: - named_params = self.named_parameters(vae=False, unet=False, text_encoder=text_encoder, state_dict_keys=True) - text_encoder_lr = text_encoder_lr if text_encoder_lr is not None else default_lr - params = [] - for key, diffusers_key in ldm_diffusers_keymap.items(): - if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS: - if named_params[diffusers_key].requires_grad: - params.append(named_params[diffusers_key]) - param_data = {"params": params, "lr": text_encoder_lr} - trainable_parameters.append(param_data) - - print(f"Found {len(params)} trainable parameter in text encoder") - - if refiner: - named_params = self.named_parameters(vae=False, unet=False, text_encoder=False, refiner=True, - state_dict_keys=True) - refiner_lr = refiner_lr if refiner_lr is not None else default_lr - params = [] - for key, diffusers_key in ldm_diffusers_keymap.items(): - diffusers_key = f"refiner_{diffusers_key}" - if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS: - if named_params[diffusers_key].requires_grad: - params.append(named_params[diffusers_key]) - param_data = {"params": params, "lr": refiner_lr} - trainable_parameters.append(param_data) - - print(f"Found {len(params)} trainable parameter in refiner") - - return trainable_parameters - - def save_device_state(self): - # saves the current device state for all modules - # this is useful for when we want to alter the state and restore it - if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux: - unet_has_grad = self.unet.proj_out.weight.requires_grad - else: - unet_has_grad = self.unet.conv_in.weight.requires_grad - - self.device_state = { - **empty_preset, - 'vae': { - 'training': self.vae.training, - 'device': 
self.vae.device,
-            },
-            'unet': {
-                'training': self.unet.training,
-                'device': self.unet.device,
-                'requires_grad': unet_has_grad,
-            },
-        }
-        if isinstance(self.text_encoder, list):
-            self.device_state['text_encoder']: List[dict] = []
-            for encoder in self.text_encoder:
-                try:
-                    te_has_grad = encoder.text_model.final_layer_norm.weight.requires_grad
-                except Exception:
-                    te_has_grad = encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
-                self.device_state['text_encoder'].append({
-                    'training': encoder.training,
-                    'device': encoder.device,
-                    # todo there has to be a better way to do this
-                    'requires_grad': te_has_grad
-                })
-        else:
-            if isinstance(self.text_encoder, T5EncoderModel) or isinstance(self.text_encoder, UMT5EncoderModel):
-                te_has_grad = self.text_encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
-            else:
-                te_has_grad = self.text_encoder.text_model.final_layer_norm.weight.requires_grad
-
-            self.device_state['text_encoder'] = {
-                'training': self.text_encoder.training,
-                'device': self.text_encoder.device,
-                'requires_grad': te_has_grad
-            }
-        if self.adapter is not None:
-            if isinstance(self.adapter, IPAdapter):
-                requires_grad = self.adapter.image_proj_model.training
-                adapter_device = self.unet.device
-            elif isinstance(self.adapter, T2IAdapter):
-                requires_grad = self.adapter.adapter.conv_in.weight.requires_grad
-                adapter_device = self.adapter.device
-            elif isinstance(self.adapter, ControlNetModel):
-                requires_grad = self.adapter.conv_in.training
-                adapter_device = self.adapter.device
-            elif isinstance(self.adapter, ClipVisionAdapter):
-                requires_grad = self.adapter.embedder.training
-                adapter_device = self.adapter.device
-            elif isinstance(self.adapter, CustomAdapter):
-                requires_grad = self.adapter.training
-                adapter_device = self.adapter.device
-            elif isinstance(self.adapter, ReferenceAdapter):
-                # todo update this!!
- requires_grad = True - adapter_device = self.adapter.device - else: - raise ValueError(f"Unknown adapter type: {type(self.adapter)}") - self.device_state['adapter'] = { - 'training': self.adapter.training, - 'device': adapter_device, - 'requires_grad': requires_grad, - } - - if self.refiner_unet is not None: - self.device_state['refiner_unet'] = { - 'training': self.refiner_unet.training, - 'device': self.refiner_unet.device, - 'requires_grad': self.refiner_unet.conv_in.weight.requires_grad, - } - - def restore_device_state(self): - # restores the device state for all modules - # this is useful for when we want to alter the state and restore it - if self.device_state is None: - return - self.set_device_state(self.device_state) - self.device_state = None - - def set_device_state(self, state): - if state['vae']['training']: - self.vae.train() - else: - self.vae.eval() - self.vae.to(state['vae']['device']) - if state['unet']['training']: - self.unet.train() - else: - self.unet.eval() - self.unet.to(state['unet']['device']) - if state['unet']['requires_grad']: - self.unet.requires_grad_(True) - else: - self.unet.requires_grad_(False) - if isinstance(self.text_encoder, list): - for i, encoder in enumerate(self.text_encoder): - if isinstance(state['text_encoder'], list): - if state['text_encoder'][i]['training']: - encoder.train() - else: - encoder.eval() - encoder.to(state['text_encoder'][i]['device']) - encoder.requires_grad_(state['text_encoder'][i]['requires_grad']) - else: - if state['text_encoder']['training']: - encoder.train() - else: - encoder.eval() - encoder.to(state['text_encoder']['device']) - encoder.requires_grad_(state['text_encoder']['requires_grad']) - else: - if state['text_encoder']['training']: - self.text_encoder.train() - else: - self.text_encoder.eval() - self.text_encoder.to(state['text_encoder']['device']) - self.text_encoder.requires_grad_(state['text_encoder']['requires_grad']) - - if self.adapter is not None: - self.adapter.to(state['adapter']['device']) - self.adapter.requires_grad_(state['adapter']['requires_grad']) - if state['adapter']['training']: - self.adapter.train() - else: - self.adapter.eval() - - if self.refiner_unet is not None: - self.refiner_unet.to(state['refiner_unet']['device']) - self.refiner_unet.requires_grad_(state['refiner_unet']['requires_grad']) - if state['refiner_unet']['training']: - self.refiner_unet.train() - else: - self.refiner_unet.eval() - flush() - - def set_device_state_preset(self, device_state_preset: DeviceStatePreset): - # sets a preset for device state - - # save current state first - self.save_device_state() - - active_modules = [] - training_modules = [] - if device_state_preset in ['cache_latents']: - active_modules = ['vae'] - if device_state_preset in ['cache_clip']: - active_modules = ['clip'] - if device_state_preset in ['generate']: - active_modules = ['vae', 'unet', 'text_encoder', 'adapter', 'refiner_unet'] - - state = copy.deepcopy(empty_preset) - # vae - state['vae'] = { - 'training': 'vae' in training_modules, - 'device': self.vae_device_torch if 'vae' in active_modules else 'cpu', - 'requires_grad': 'vae' in training_modules, - } - - # unet - state['unet'] = { - 'training': 'unet' in training_modules, - 'device': self.device_torch if 'unet' in active_modules else 'cpu', - 'requires_grad': 'unet' in training_modules, - } - - if self.refiner_unet is not None: - state['refiner_unet'] = { - 'training': 'refiner_unet' in training_modules, - 'device': self.device_torch if 'refiner_unet' in active_modules else 'cpu', 
- 'requires_grad': 'refiner_unet' in training_modules, - } - - # text encoder - if isinstance(self.text_encoder, list): - state['text_encoder'] = [] - for i, encoder in enumerate(self.text_encoder): - state['text_encoder'].append({ - 'training': 'text_encoder' in training_modules, - 'device': self.te_device_torch if 'text_encoder' in active_modules else 'cpu', - 'requires_grad': 'text_encoder' in training_modules, - }) - else: - state['text_encoder'] = { - 'training': 'text_encoder' in training_modules, - 'device': self.te_device_torch if 'text_encoder' in active_modules else 'cpu', - 'requires_grad': 'text_encoder' in training_modules, - } - - if self.adapter is not None: - state['adapter'] = { - 'training': 'adapter' in training_modules, - 'device': self.device_torch if 'adapter' in active_modules else 'cpu', - 'requires_grad': 'adapter' in training_modules, - } - - self.set_device_state(state) - - def text_encoder_to(self, *args, **kwargs): - if isinstance(self.text_encoder, list): - for encoder in self.text_encoder: - encoder.to(*args, **kwargs) - else: - self.text_encoder.to(*args, **kwargs)
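
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original file): a minimal
# text-to-image loop wired together from the helpers above
# (set_device_state_preset, encode_prompt, diffuse_some_steps, decode_latents,
# restore_device_state). It assumes `sd` is an already-loaded SD 1.x-style
# StableDiffusion instance (4 latent channels, VAE scale factor 8) and that
# concat_prompt_embeds from toolkit.prompt_utils stacks [uncond, cond] along
# the batch dimension; `size`, `num_steps`, and the prompts are placeholders.
# Note: decode_latents above expects the VAE config to define shift_factor,
# so for VAEs without one you may need to fall back to sd.vae.decode directly.
import torch
from torchvision.transforms import ToPILImage
from toolkit.prompt_utils import concat_prompt_embeds


def sample_image(sd: 'StableDiffusion', prompt: str, negative_prompt: str = "",
                 size: int = 512, num_steps: int = 30, guidance_scale: float = 7.5):
    # move vae / unet / text encoder(s) to their active devices for generation
    sd.set_device_state_preset('generate')
    sd.noise_scheduler.set_timesteps(num_steps, device=sd.device_torch)

    # conditional + unconditional embeddings, concatenated for classifier-free guidance
    cond = sd.encode_prompt(prompt)
    uncond = sd.encode_prompt(negative_prompt)
    embeds = concat_prompt_embeds([uncond, cond])

    # start from pure noise in latent space: (bs, 4, size // 8, size // 8)
    latents = torch.randn(1, 4, size // 8, size // 8,
                          device=sd.device_torch, dtype=sd.torch_dtype)
    latents = latents * sd.noise_scheduler.init_noise_sigma

    # run the denoising loop defined above
    latents = sd.diffuse_some_steps(
        latents, embeds,
        total_timesteps=num_steps,
        guidance_scale=guidance_scale,
    )

    # decode back to pixel space and convert [-1, 1] -> [0, 1] for PIL
    image = sd.decode_latents(latents)[0]
    image = (image / 2 + 0.5).clamp(0, 1).float().cpu()

    sd.restore_device_state()
    return ToPILImage()(image)
# ---------------------------------------------------------------------------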