# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import os
import re
from typing import Any, Callable, Dict, List, Optional, Union

import safetensors
import torch
import torch.nn.functional as F
from torch import Tensor
from torchvision.transforms import ToPILImage

from diffusers.pipelines import FluxPipeline
from diffusers.utils import logging
from diffusers.loaders import TextualInversionLoaderMixin
from diffusers.pipelines.flux.pipeline_flux import FluxLoraLoaderMixin
from diffusers.models.transformers.transformer_flux import (
    USE_PEFT_BACKEND,
    scale_lora_layers,
    unscale_lora_layers,
    logger,
)
from peft import LoraConfig, set_peft_model_state_dict
from peft.tuners.tuners_utils import BaseTunerLayer
from transformers import CLIPProcessor, CLIPModel, CLIPVisionModelWithProjection, CLIPVisionModel

from src.adapters.mod_adapters import CLIPModAdapter

# The quanto backend is imported lazily inside `quantization` so the module can be
# imported without `optimum.quanto` installed.
# from optimum.quanto import (
#     freeze, quantize, QTensor, qfloat8, qint8, qint4, qint2,
# )


def encode_vae_images(pipeline: FluxPipeline, images: Tensor):
    """Encode images with the FLUX VAE and pack them into latent tokens plus image ids."""
    images = pipeline.image_processor.preprocess(images)
    images = images.to(pipeline.device).to(pipeline.dtype)
    images = pipeline.vae.encode(images).latent_dist.sample()
    images = (
        images - pipeline.vae.config.shift_factor
    ) * pipeline.vae.config.scaling_factor
    images_tokens = pipeline._pack_latents(images, *images.shape)
    images_ids = pipeline._prepare_latent_image_ids(
        images.shape[0],
        images.shape[2],
        images.shape[3],
        pipeline.device,
        pipeline.dtype,
    )
    if images_tokens.shape[1] != images_ids.shape[0]:
        # Fallback for diffusers versions where `_prepare_latent_image_ids`
        # expects the packed (half) latent resolution.
        images_ids = pipeline._prepare_latent_image_ids(
            images.shape[0],
            images.shape[2] // 2,
            images.shape[3] // 2,
            pipeline.device,
            pipeline.dtype,
        )
    return images_tokens, images_ids


def decode_vae_images(pipeline: FluxPipeline, latents: Tensor, height, width, output_type: Optional[str] = "pil"):
    """Unpack latent tokens and decode them back into images with the FLUX VAE."""
    latents = pipeline._unpack_latents(latents, height, width, pipeline.vae_scale_factor)
    latents = (latents / pipeline.vae.config.scaling_factor) + pipeline.vae.config.shift_factor
    image = pipeline.vae.decode(latents, return_dict=False)[0]
    return pipeline.image_processor.postprocess(image, output_type=output_type)
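

# Hedged usage sketch (added for illustration, not part of the original module): a round
# trip through the two VAE helpers above. `pipeline` is assumed to be a loaded
# FluxPipeline and `images` a batch accepted by its image processor; `height`/`width`
# are the pixel dimensions of the original images.
def _example_vae_roundtrip(pipeline: FluxPipeline, images: Tensor, height: int, width: int):
    # Pack the images into latent tokens plus their positional ids ...
    image_tokens, image_ids = encode_vae_images(pipeline, images)
    # ... and decode the packed tokens back into PIL images.
    return decode_vae_images(pipeline, image_tokens, height, width, output_type="pil")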


def _get_clip_prompt_embeds(
    self,
    prompt: Union[str, List[str]],
    num_images_per_prompt: int = 1,
    device: Optional[torch.device] = None,
):
    device = device or self._execution_device

    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if isinstance(self, TextualInversionLoaderMixin):
        prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

    text_inputs = self.tokenizer(
        prompt,
        padding="max_length",
        max_length=self.tokenizer_max_length,
        truncation=True,
        return_overflowing_tokens=False,
        return_length=False,
        return_tensors="pt",
    )

    text_input_ids = text_inputs.input_ids
    prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)

    # Use pooled output of CLIPTextModel
    prompt_embeds = prompt_embeds.pooler_output
    prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)

    return prompt_embeds


def encode_prompt_with_clip_t5(
    self,
    prompt: Union[str, List[str]],
    prompt_2: Union[str, List[str]],
    device: Optional[torch.device] = None,
    num_images_per_prompt: int = 1,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    max_sequence_length: int = 512,
    lora_scale: Optional[float] = None,
):
    r"""
    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
            used in all text-encoders.
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
            If not provided, pooled text embeddings will be generated from `prompt` input argument.
        lora_scale (`float`, *optional*):
            A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
    """
    device = device or self._execution_device

    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
        self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if self.text_encoder is not None and USE_PEFT_BACKEND:
            scale_lora_layers(self.text_encoder, lora_scale)
        if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
            scale_lora_layers(self.text_encoder_2, lora_scale)

    prompt = [prompt] if isinstance(prompt, str) else prompt

    if prompt_embeds is None:
        prompt_2 = prompt_2 or prompt
        prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

        # We only use the pooled prompt output from the CLIPTextModel
        pooled_prompt_embeds = _get_clip_prompt_embeds(
            self=self,
            prompt=prompt,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
        )
        if self.text_encoder_2 is not None:
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt_2,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )

    if self.text_encoder is not None:
        if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder, lora_scale)

    if self.text_encoder_2 is not None:
        if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder_2, lora_scale)

    dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
    if self.text_encoder_2 is not None:
        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
    else:
        text_ids = None

    return prompt_embeds, pooled_prompt_embeds, text_ids


def prepare_text_input(
    pipeline: FluxPipeline,
    prompts,
    max_sequence_length=512,
):
    # Turn off warnings (CLIP overflow)
    logger.setLevel(logging.ERROR)
    (
        t5_prompt_embeds,
        pooled_prompt_embeds,
        text_ids,
    ) = encode_prompt_with_clip_t5(
        self=pipeline,
        prompt=prompts,
        prompt_2=None,
        prompt_embeds=None,
        pooled_prompt_embeds=None,
        device=pipeline.device,
        num_images_per_prompt=1,
        max_sequence_length=max_sequence_length,
        lora_scale=None,
    )
    # Turn on warnings
    logger.setLevel(logging.WARNING)
    return t5_prompt_embeds, pooled_prompt_embeds, text_ids


def prepare_t5_input(
    pipeline: FluxPipeline,
    prompts,
    max_sequence_length=512,
):
    # Kept for API compatibility: currently identical to `prepare_text_input`.
    # Turn off warnings (CLIP overflow)
    logger.setLevel(logging.ERROR)
    (
        t5_prompt_embeds,
        pooled_prompt_embeds,
        text_ids,
    ) = encode_prompt_with_clip_t5(
        self=pipeline,
        prompt=prompts,
        prompt_2=None,
        prompt_embeds=None,
        pooled_prompt_embeds=None,
        device=pipeline.device,
        num_images_per_prompt=1,
        max_sequence_length=max_sequence_length,
        lora_scale=None,
    )
    # Turn on warnings
    logger.setLevel(logging.WARNING)
    return t5_prompt_embeds, pooled_prompt_embeds, text_ids
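

# Hedged usage sketch (added for illustration, not part of the original module): the
# typical way to call the prompt-encoding helpers above. The example prompt is arbitrary.
def _example_prepare_prompts(pipe: FluxPipeline):
    prompts = ["a photo of a cat wearing a tiny hat"]
    t5_embeds, pooled_embeds, text_ids = prepare_text_input(pipe, prompts, max_sequence_length=512)
    # t5_embeds holds the per-token T5 embeddings, pooled_embeds the pooled CLIP
    # embedding, and text_ids the (seq_len, 3) text position ids expected by FLUX.
    return t5_embeds, pooled_embeds, text_ids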


def tokenize_t5_prompt(pipe, input_prompt, max_length, **kwargs):
    return pipe.tokenizer_2(
        input_prompt,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_length=False,
        return_overflowing_tokens=False,
        return_tensors="pt",
        **kwargs,
    )


def clear_attn_maps(transformer):
    # Drop any attention maps recorded on the double- and single-stream blocks.
    for i, block in enumerate(transformer.transformer_blocks):
        if hasattr(block.attn, "attn_maps"):
            del block.attn.attn_maps
            del block.attn.timestep
    for i, block in enumerate(transformer.single_transformer_blocks):
        if hasattr(block.attn, "cond2latents"):
            del block.attn.cond2latents


def gather_attn_maps(transformer, clear=False):
    # Collect text->image and image->text attention maps recorded by the attention
    # processors, grouped by timestep and block name.
    t2i_attn_maps = {}
    i2t_attn_maps = {}
    for i, block in enumerate(transformer.transformer_blocks):
        name = f"block_{i}"
        if hasattr(block.attn, "attn_maps"):
            attention_maps = block.attn.attn_maps
            timesteps = block.attn.timestep  # (B,)
            for (timestep, (t2i_attn_map, i2t_attn_map)) in zip(timesteps, attention_maps):
                timestep = str(timestep.item())

                t2i_attn_maps[timestep] = t2i_attn_maps.get(timestep, dict())
                t2i_attn_maps[timestep][name] = t2i_attn_maps[timestep].get(name, [])
                t2i_attn_maps[timestep][name].append(t2i_attn_map.cpu())

                i2t_attn_maps[timestep] = i2t_attn_maps.get(timestep, dict())
                i2t_attn_maps[timestep][name] = i2t_attn_maps[timestep].get(name, [])
                i2t_attn_maps[timestep][name].append(i2t_attn_map.cpu())

            if clear:
                del block.attn.attn_maps

    for timestep in t2i_attn_maps:
        for name in t2i_attn_maps[timestep]:
            t2i_attn_maps[timestep][name] = torch.cat(t2i_attn_maps[timestep][name], dim=0)
            i2t_attn_maps[timestep][name] = torch.cat(i2t_attn_maps[timestep][name], dim=0)

    return t2i_attn_maps, i2t_attn_maps


def process_token(token, startofword):
    # Turn CLIP/T5 sub-word tokens into readable file-name fragments,
    # marking word starts with '<' and continuations with '-'.
    if '</w>' in token:
        token = token.replace('</w>', '')
        if startofword:
            token = '<' + token + '>'
        else:
            token = '-' + token + '>'
            startofword = True
    elif token not in ['<|startoftext|>', '<|endoftext|>']:
        if startofword:
            token = '<' + token + '-'
            startofword = False
        else:
            token = '-' + token + '-'
    return token, startofword


def save_attention_image(attn_map, tokens, batch_dir, to_pil):
    # Save one grayscale attention image per (non-padding) token.
    startofword = True
    for i, (token, a) in enumerate(zip(tokens, attn_map[:len(tokens)])):
        token, startofword = process_token(token, startofword)
        token = token.replace("/", "-")
        if token == '-<pad>-':
            continue
        a = a.to(torch.float32)
        a = a / a.max() * 255 / 256
        to_pil(a).save(os.path.join(batch_dir, f'{i}-{token}.png'))


def save_attention_maps(attn_maps, pipe, prompts, base_dir='attn_maps'):
    to_pil = ToPILImage()

    token_ids = tokenize_t5_prompt(pipe, prompts, 512).input_ids  # (B, 512)
    token_ids = [x for x in token_ids]
    total_tokens = [pipe.tokenizer_2.convert_ids_to_tokens(token_id) for token_id in token_ids]

    os.makedirs(base_dir, exist_ok=True)

    total_attn_map_shape = (256, 256)
    total_attn_map_number = 0

    # Build a zero accumulator with the target layout:
    # (B, 24, H, W, 512) -> (B, H, W, 512) -> (B, 512, H, W)
    print(attn_maps.keys())
    total_attn_map = list(list(attn_maps.values())[0].values())[0].sum(1)
    total_attn_map = total_attn_map.permute(0, 3, 1, 2)
    total_attn_map = torch.zeros_like(total_attn_map)
    total_attn_map = F.interpolate(total_attn_map, size=total_attn_map_shape, mode='bilinear', align_corners=False)

    for timestep, layers in attn_maps.items():
        timestep_dir = os.path.join(base_dir, f'{timestep}')
        os.makedirs(timestep_dir, exist_ok=True)

        for layer, attn_map in layers.items():
            layer_dir = os.path.join(timestep_dir, f'{layer}')
            os.makedirs(layer_dir, exist_ok=True)

            attn_map = attn_map.sum(1).squeeze(1).permute(0, 3, 1, 2)

            resized_attn_map = F.interpolate(attn_map, size=total_attn_map_shape, mode='bilinear', align_corners=False)
            total_attn_map += resized_attn_map
            total_attn_map_number += 1

            for batch, (attn_map, tokens) in enumerate(zip(resized_attn_map, total_tokens)):
                save_attention_image(attn_map, tokens, layer_dir, to_pil)

            # for batch, (tokens, attn) in enumerate(zip(total_tokens, attn_map)):
            #     batch_dir = os.path.join(layer_dir, f'batch-{batch}')
            #     os.makedirs(batch_dir, exist_ok=True)
            #     save_attention_image(attn, tokens, batch_dir, to_pil)

    total_attn_map /= total_attn_map_number
    for batch, (attn_map, tokens) in enumerate(zip(total_attn_map, total_tokens)):
        batch_dir = os.path.join(base_dir, f'batch-{batch}')
        os.makedirs(batch_dir, exist_ok=True)
        save_attention_image(attn_map, tokens, batch_dir, to_pil)
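

# Hedged usage sketch (added for illustration, not part of the original module): chaining
# the gather/save helpers after a denoising run, assuming the attention processors have
# recorded `attn_maps` on the transformer blocks.
def _example_dump_attention_maps(pipe: FluxPipeline, prompts, out_dir="attn_maps"):
    t2i_maps, _i2t_maps = gather_attn_maps(pipe.transformer, clear=True)
    save_attention_maps(t2i_maps, pipe, prompts, base_dir=out_dir)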


def gather_cond2latents(transformer, clear=False):
    # Collect condition->latent attention maps recorded on the single-stream blocks,
    # grouped by timestep and block name.
    c2l_attn_maps = {}
    # for i, block in enumerate(transformer.transformer_blocks):
    for i, block in enumerate(transformer.single_transformer_blocks):
        name = f"block_{i}"
        if hasattr(block.attn, "cond2latents"):
            attention_maps = block.attn.cond2latents
            timesteps = block.attn.cond_timesteps  # (B,)
            for (timestep, c2l_attn_map) in zip(timesteps, attention_maps):
                timestep = str(timestep.item())

                c2l_attn_maps[timestep] = c2l_attn_maps.get(timestep, dict())
                c2l_attn_maps[timestep][name] = c2l_attn_maps[timestep].get(name, [])
                c2l_attn_maps[timestep][name].append(c2l_attn_map.cpu())

            if clear:
                # del block.attn.attn_maps
                del block.attn.cond2latents
                del block.attn.cond_timesteps

    for timestep in c2l_attn_maps:
        for name in c2l_attn_maps[timestep]:
            c2l_attn_maps[timestep][name] = torch.cat(c2l_attn_maps[timestep][name], dim=0)

    return c2l_attn_maps


def save_cond2latent_image(attn_map, batch_dir, to_pil):
    for i, a in enumerate(attn_map):  # (N, H, W)
        a = a.to(torch.float32)
        a = a / a.max() * 255 / 256
        to_pil(a).save(os.path.join(batch_dir, f'{i}.png'))


def save_cond2latent(attn_maps, base_dir='attn_maps'):
    to_pil = ToPILImage()

    os.makedirs(base_dir, exist_ok=True)

    total_attn_map_shape = (256, 256)
    total_attn_map_number = 0

    # Build a zero accumulator: (N, H, W) -> (1, N, H, W)
    total_attn_map = list(list(attn_maps.values())[0].values())[0].unsqueeze(0)
    total_attn_map = torch.zeros_like(total_attn_map)
    total_attn_map = F.interpolate(total_attn_map, size=total_attn_map_shape, mode='bilinear', align_corners=False)

    for timestep, layers in attn_maps.items():
        cur_ts_attn_map = torch.zeros_like(total_attn_map)
        cur_ts_attn_map_number = 0

        timestep_dir = os.path.join(base_dir, f'{timestep}')
        os.makedirs(timestep_dir, exist_ok=True)

        for layer, attn_map in layers.items():
            # layer_dir = os.path.join(timestep_dir, f'{layer}')
            # os.makedirs(layer_dir, exist_ok=True)

            attn_map = attn_map.unsqueeze(0)  # (1, N, H, W)
            resized_attn_map = F.interpolate(attn_map, size=total_attn_map_shape, mode='bilinear', align_corners=False)
            cur_ts_attn_map += resized_attn_map
            cur_ts_attn_map_number += 1

        for batch, attn_map in enumerate(cur_ts_attn_map / cur_ts_attn_map_number):
            save_cond2latent_image(attn_map, timestep_dir, to_pil)

        total_attn_map += cur_ts_attn_map
        total_attn_map_number += cur_ts_attn_map_number

    total_attn_map /= total_attn_map_number
    for batch, attn_map in enumerate(total_attn_map):
        batch_dir = os.path.join(base_dir, f'batch-{batch}')
        os.makedirs(batch_dir, exist_ok=True)
        save_cond2latent_image(attn_map, batch_dir, to_pil)


def quantization(pipe, qtype):
    if qtype != "None" and qtype != "":
        if qtype.endswith("quanto"):
            # Imported lazily so the module still loads when optimum.quanto is unavailable.
            from optimum.quanto import freeze, quantize, qfloat8, qint8, qint4, qint2

            if qtype == "int2-quanto":
                quant_level = qint2
            elif qtype == "int4-quanto":
                quant_level = qint4
            elif qtype == "int8-quanto":
                quant_level = qint8
            elif qtype == "fp8-quanto":
                quant_level = qfloat8
            else:
                raise ValueError(f"Invalid quantisation level: {qtype}")

            extra_quanto_args = {}
            extra_quanto_args["exclude"] = [
                "*.norm",
                "*.norm1",
                "*.norm2",
                "*.norm2_context",
                "proj_out",
                "x_embedder",
                "norm_out",
                "context_embedder",
            ]
            try:
                quantize(pipe.transformer, weights=quant_level, **extra_quanto_args)
                quantize(pipe.text_encoder_2, weights=quant_level, **extra_quanto_args)
                print("[Quantization] Start freezing")
                freeze(pipe.transformer)
                freeze(pipe.text_encoder_2)
                print("[Quantization] Finished")
            except Exception as e:
                if "out of memory" in str(e).lower():
                    print(
                        "GPU ran out of memory during quantisation. Use --quantize_via=cpu to use the slower CPU method."
                    )
                raise e
        else:
            assert qtype == "fp8-ao"
            from torchao.float8 import convert_to_float8_training, Float8LinearConfig

            def module_filter_fn(mod: torch.nn.Module, fqn: str):
                # don't convert the output module
                if fqn == "proj_out":
                    return False
                # don't convert linear modules with weight dimensions not divisible by 16
                if isinstance(mod, torch.nn.Linear):
                    if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
                        return False
                return True

            convert_to_float8_training(
                pipe.transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
            )
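

# Hedged usage sketch (added for illustration, not part of the original module): applying
# `quantization` to a freshly loaded pipeline. The "*-quanto" levels require
# `optimum.quanto`; "fp8-ao" requires `torchao`.
def _example_quantize_pipeline(pipe: FluxPipeline, qtype: str = "int8-quanto"):
    quantization(pipe, qtype)
    return pipe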


class CustomFluxPipeline:
    """Thin wrapper around FluxPipeline that manages its LoRA and modulation adapters."""

    def __init__(
        self,
        config,
        device="cuda",
        ckpt_root=None,
        ckpt_root_condition=None,
        torch_dtype=torch.bfloat16,
    ):
        print("[CustomFluxPipeline] Loading FLUX Pipeline")
        if config["model"].get("dit_quant", "None") == "int8-quanto":
            self.pipe = FluxPipeline.from_pretrained(
                "diffusers/FLUX.1-dev-torchao-int8",
                torch_dtype=torch_dtype,
                use_safetensors=False,
            ).to(device)
        else:
            self.pipe = FluxPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-dev",
                torch_dtype=torch_dtype,
            ).to(device)
        # self.pipe.enable_sequential_cpu_offload()

        self.config = config
        self.device = device
        self.dtype = torch_dtype

        # if config["model"].get("dit_quant", "None") != "None":
        #     quantization(self.pipe, config["model"]["dit_quant"])

        self.modulation_adapters = []
        self.pipe.modulation_adapters = []

        try:
            if config["model"]["modulation"]["use_clip"]:
                load_clip(self, config, torch_dtype, device, None, is_training=False)
        except Exception as e:
            print(e)

        if config["model"]["use_dit_lora"] or config["model"]["use_condition_dblock_lora"] or config["model"]["use_condition_sblock_lora"]:
            if ckpt_root_condition is None and (config["model"]["use_condition_dblock_lora"] or config["model"]["use_condition_sblock_lora"]):
                ckpt_root_condition = ckpt_root
            load_dit_lora(self, self.pipe, config, torch_dtype, device, f"{ckpt_root}", f"{ckpt_root_condition}", is_training=False)

    def add_modulation_adapter(self, modulation_adapter):
        self.modulation_adapters.append(modulation_adapter)
        self.pipe.modulation_adapters.append(modulation_adapter)

    def clear_modulation_adapters(self):
        self.modulation_adapters = []
        self.pipe.modulation_adapters = []
        torch.cuda.empty_cache()


def load_clip(self, config, torch_dtype, device, ckpt_dir=None, is_training=False):
    # Attach a CLIP vision encoder (and its processor) to the wrapped pipeline.
    model_path = os.getenv("CLIP_MODEL_PATH", "openai/clip-vit-large-patch14")
    clip_model = CLIPVisionModelWithProjection.from_pretrained(model_path).to(device, dtype=torch_dtype)
    clip_processor = CLIPProcessor.from_pretrained(model_path)
    self.pipe.clip_model = clip_model
    self.pipe.clip_processor = clip_processor


def load_dit_lora(self, pipe, config, torch_dtype, device, ckpt_dir=None, condition_ckpt_dir=None, is_training=False):
    if not config["model"]["use_condition_dblock_lora"] and not config["model"]["use_condition_sblock_lora"] and not config["model"]["use_dit_lora"]:
        print("[load_dit_lora] no dit lora, no condition lora")
        return []

    adapter_names = ["default", "condition"]

    if condition_ckpt_dir is None:
        condition_ckpt_dir = ckpt_dir

    if not config["model"]["use_condition_dblock_lora"] and not config["model"]["use_condition_sblock_lora"]:
        print("[load_dit_lora] no condition lora")
        adapter_names.pop(1)
    elif condition_ckpt_dir is not None and os.path.exists(os.path.join(condition_ckpt_dir, "pytorch_lora_weights_condition.safetensors")):
        assert "condition" in adapter_names
        print(f"[load_dit_lora] load condition lora from {condition_ckpt_dir}")
        pipe.transformer.load_lora_adapter(condition_ckpt_dir, use_safetensors=True, adapter_name="condition", weight_name="pytorch_lora_weights_condition.safetensors")  # TODO: check if they are trainable
    else:
        assert is_training
        assert "condition" in adapter_names
        print("[load_dit_lora] init new condition lora")
        pipe.transformer.add_adapter(LoraConfig(**config["model"]["condition_lora_config"]), adapter_name="condition")

    if not config["model"]["use_dit_lora"]:
        print("[load_dit_lora] no dit lora")
        adapter_names.pop(0)
    elif ckpt_dir is not None and os.path.exists(os.path.join(ckpt_dir, "pytorch_lora_weights.safetensors")):
        assert "default" in adapter_names
        print(f"[load_dit_lora] load dit lora from {ckpt_dir}")
        lora_file = os.path.join(ckpt_dir, "pytorch_lora_weights.safetensors")
        lora_state_dict = safetensors.torch.load_file(lora_file, device="cpu")

        single_lora_pattern = "(.*single_transformer_blocks\\.[0-9]+\\.norm\\.linear|.*single_transformer_blocks\\.[0-9]+\\.proj_mlp|.*single_transformer_blocks\\.[0-9]+\\.proj_out|.*single_transformer_blocks\\.[0-9]+\\.attn.to_k|.*single_transformer_blocks\\.[0-9]+\\.attn.to_q|.*single_transformer_blocks\\.[0-9]+\\.attn.to_v|.*single_transformer_blocks\\.[0-9]+\\.attn.to_out)"
        latent_lora_pattern = "(.*(?<!single_)transformer_blocks\\.[0-9]+\\.norm1\\.linear|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_k|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_q|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_v|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_out\\.0|.*(?<!single_)transformer_blocks\\.[0-9]+\\.ff\\.net\\.2)"
        use_pretrained_dit_single_lora = config["model"].get("use_pretrained_dit_single_lora", True)
        use_pretrained_dit_latent_lora = config["model"].get("use_pretrained_dit_latent_lora", True)
        if not use_pretrained_dit_single_lora or not use_pretrained_dit_latent_lora:
            # Optionally drop the pretrained single-stream and/or double-stream (latent)
            # LoRA weights before loading the remaining state dict.
            lora_state_dict_keys = list(lora_state_dict.keys())
            for layer_name in lora_state_dict_keys:
                if not use_pretrained_dit_single_lora:
                    if re.search(single_lora_pattern, layer_name):
                        del lora_state_dict[layer_name]
                if not use_pretrained_dit_latent_lora:
                    if re.search(latent_lora_pattern, layer_name):
                        del lora_state_dict[layer_name]
            pipe.transformer.add_adapter(LoraConfig(**config["model"]["dit_lora_config"]), adapter_name="default")
            set_peft_model_state_dict(pipe.transformer, lora_state_dict, adapter_name="default")
        else:
            pipe.transformer.load_lora_adapter(ckpt_dir, use_safetensors=True, adapter_name="default", weight_name="pytorch_lora_weights.safetensors")  # TODO: check if they are trainable
    else:
        assert is_training
        assert "default" in adapter_names
        print("[load_dit_lora] init new dit lora")
        pipe.transformer.add_adapter(LoraConfig(**config["model"]["dit_lora_config"]), adapter_name="default")

    assert len(adapter_names) <= 2 and len(adapter_names) > 0
    for name, module in pipe.transformer.named_modules():
        if isinstance(module, BaseTunerLayer):
            module.set_adapter(adapter_names)

    if "default" in adapter_names:
        assert config["model"]["use_dit_lora"]
    if "condition" in adapter_names:
        assert config["model"]["use_condition_dblock_lora"] or config["model"]["use_condition_sblock_lora"]

    lora_layers = list(filter(
        lambda p: p[1].requires_grad, pipe.transformer.named_parameters()
    ))
    lora_layers = [l[1] for l in lora_layers]
    return lora_layers


def load_modulation_adapter(self, config, torch_dtype, device, ckpt_dir=None, is_training=False):
    adapter_type = config["model"]["modulation"]["adapter_type"]

    if ckpt_dir is not None and os.path.exists(ckpt_dir):
        print(f"loading modulation adapter from {ckpt_dir}")
        modulation_adapter = CLIPModAdapter.from_pretrained(
            ckpt_dir, subfolder="modulation_adapter", strict=False,
            low_cpu_mem_usage=False, device_map=None,
        ).to(device)
    else:
        print("Init new modulation adapter")
        adapter_layers = config["model"]["modulation"]["adapter_layers"]
        adapter_width = config["model"]["modulation"]["adapter_width"]
        pblock_adapter_layers = config["model"]["modulation"]["per_block_adapter_layers"]
        pblock_adapter_width = config["model"]["modulation"]["per_block_adapter_width"]
        pblock_adapter_single_blocks = config["model"]["modulation"]["per_block_adapter_single_blocks"]
        use_text_mod = config["model"]["modulation"]["use_text_mod"]
        use_img_mod = config["model"]["modulation"]["use_img_mod"]
        out_dim = config["model"]["modulation"]["out_dim"]
        if adapter_type == "clip_adapter":
            modulation_adapter = CLIPModAdapter(
                out_dim=out_dim,
                width=adapter_width,
                pblock_width=pblock_adapter_width,
                layers=adapter_layers,
                pblock_layers=pblock_adapter_layers,
                heads=8,
                input_text_dim=4096,
                input_image_dim=1024,
                pblock_single_blocks=pblock_adapter_single_blocks,
            )
        else:
            raise NotImplementedError()

    if is_training:
        modulation_adapter.train()
        try:
            modulation_adapter.enable_gradient_checkpointing()
        except Exception as e:
            print(e)
        if not config["model"]["modulation"]["use_perblock_adapter"]:
            try:
                modulation_adapter.net2.requires_grad_(False)
            except Exception as e:
                print(e)
    else:
        modulation_adapter.requires_grad_(False)

    modulation_adapter.to(device, dtype=torch_dtype)
    return modulation_adapter


def load_ckpt(self, ckpt_dir, is_training=False):
    if self.config["model"]["use_dit_lora"]:
        # Replace any previously loaded "subject" LoRA with the one from this checkpoint.
        self.pipe.transformer.delete_adapters(["subject"])
        lora_path = f"{ckpt_dir}/pytorch_lora_weights.safetensors"
        print(f"Loading DIT Lora from {lora_path}")
        self.pipe.load_lora_weights(lora_path, adapter_name="subject")
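

# Hedged usage sketch (added for illustration, not part of the original module): wiring the
# pieces together. The `config` dict is assumed to follow the key layout read above
# ("model" -> "modulation", "use_dit_lora", ...), and `ckpt_root` is a placeholder path.
def _example_build_pipeline(config, ckpt_root):
    wrapper = CustomFluxPipeline(config, device="cuda", ckpt_root=ckpt_root, torch_dtype=torch.bfloat16)
    modulation_adapter = load_modulation_adapter(wrapper, config, torch.bfloat16, "cuda", ckpt_dir=ckpt_root, is_training=False)
    wrapper.add_modulation_adapter(modulation_adapter)
    return wrapper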