WIP
- .gitignore +2 -0
- app.py +113 -11
- lora_diffusion/FOR-cloneofsimo-LoRA +6 -0
- lora_diffusion/__init__.py +5 -0
- lora_diffusion/cli_lora_add.py +187 -0
- lora_diffusion/cli_lora_pti.py +1040 -0
- lora_diffusion/cli_pt_to_safetensors.py +85 -0
- lora_diffusion/cli_svd.py +146 -0
- lora_diffusion/dataset.py +311 -0
- lora_diffusion/lora.py +1110 -0
- lora_diffusion/lora_manager.py +144 -0
- lora_diffusion/preprocess_files.py +327 -0
- lora_diffusion/safe_open.py +68 -0
- lora_diffusion/to_ckpt_v2.py +232 -0
- lora_diffusion/utils.py +214 -0
- lora_diffusion/xformers_utils.py +70 -0
- requirements.txt +3 -0
- train_dreambooth_cloneofsimo_lora.py +1008 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@

*/__pycache__/
*/*.pyc
app.py
CHANGED
@@ -1,16 +1,118 @@

Old version (removed lines are marked with "-"; the body of the removed function is not preserved in this view):

 import gradio as gr
 import pandas as pd

-def load_csv(...):
-    ...
-
-iface = gr.Interface(
-    fn=load_csv,
-    inputs="file",
-    outputs="dataframe",
-    title="CSV Loader",
-    description="Load a CSV file and display its contents.",
-)

-...
New version:

import gradio as gr
import shutil
import zipfile
import tensorflow as tf
import pandas as pd
import pathlib
import PIL.Image
import os
import subprocess

def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
    w, h = image.size
    if w == h:
        return image
    elif w > h:
        new_image = PIL.Image.new(image.mode, (w, w), (0, 0, 0))
        new_image.paste(image, (0, (w - h) // 2))
        return new_image
    else:
        new_image = PIL.Image.new(image.mode, (h, h), (0, 0, 0))
        new_image.paste(image, ((h - w) // 2, 0))
        return new_image


class ModelTrainer:
    def __init__(self):
        self.training_pictures = []
        self.training_model = None

    def unzip_file(self, zip_file_path):
        with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
            extracted_path = zip_file_path.replace('.zip', '')
            zip_ref.extractall(extracted_path)
            file_names = zip_ref.namelist()
            for file_name in file_names:
                if file_name.endswith(('.jpeg', '.jpg', '.png')):
                    self.training_pictures.append(f'{extracted_path}/{file_name}')

    def train(self, pretrained_model_name_or_path: str, instance_images: list | None):
        output_model_name = 'a-xyz-model'
        resolution = 512
        repo_dir = pathlib.Path(__file__).parent
        subdirs = ['train-instance', 'train-class', 'experiments']
        dir_paths = []

        for subdir in subdirs:
            dir_path = repo_dir / subdir / output_model_name
            dir_paths.append(dir_path)
            shutil.rmtree(dir_path, ignore_errors=True)
            os.makedirs(dir_path, exist_ok=True)

        instance_data_dir, class_data_dir, output_dir = dir_paths

        for i, temp_path in enumerate(instance_images):
            image = PIL.Image.open(temp_path.name)
            image = pad_image(image)
            image = image.resize((resolution, resolution))
            image = image.convert('RGB')
            out_path = instance_data_dir / f'{i:03d}.jpg'
            image.save(out_path, format='JPEG', quality=100)

        command = [
            'python', '-u',
            'train_dreambooth_cloneofsimo_lora.py',
            '--pretrained_model_name_or_path', pretrained_model_name_or_path,
            '--instance_data_dir', instance_data_dir,
            '--class_data_dir', class_data_dir,
            '--resolution', '768',
            '--output_dir', output_dir,
            '--instance_prompt', 'a photo of a pwsm dog',
            '--with_prior_preservation',
            '--class_prompt', 'a dog',
            '--prior_loss_weight', '1.0',
            '--num_class_images', '100',
            '--learning_rate', '0.0004',
            '--train_batch_size', '1',
            '--sample_batch_size', '1',
            '--max_train_steps', '400',
            '--gradient_accumulation_steps', '1',
            '--gradient_checkpointing',
            '--train_text_encoder',
            '--learning_rate_text', '5e-6',
            '--save_steps', '100',
            '--seed', '1337',
            '--lr_scheduler', 'constant',
            '--lr_warmup_steps', '0'
        ]

        result = subprocess.run(command)
        return result

    def generate_picture(self, row):
        num_of_training_steps, learning_rate, checkpoint_steps, abc = row
        return f'Picture generated for num_of_training_steps: {num_of_training_steps}, learning_rate: {learning_rate}, checkpoint_steps: {checkpoint_steps}'

    def generate_pictures(self, csv_input):
        csv = pd.read_csv(csv_input.name)
        result = []
        for index, row in csv.iterrows():
            result.append(self.generate_picture(row))
        return "\n".join(str(item) for item in result)


loader = ModelTrainer()

with gr.Blocks() as demo:
    with gr.Box():
        instance_images = gr.Files(label='Instance images')
        pretrained_model_name_or_path = gr.inputs.Textbox(lines=1, label='pretrained_model_name_or_path', default='stabilityai/stable-diffusion-2-1')
        output_message = gr.Markdown()
        train_button = gr.Button('Train')
        train_button.click(fn=loader.train, inputs=[pretrained_model_name_or_path, instance_images], outputs=[output_message])
    with gr.Box():
        csv_input = gr.inputs.File(label='CSV File')
        output_message2 = gr.Markdown()
        generate_button = gr.Button('Generate Pictures from CSV')
        generate_button.click(fn=loader.generate_pictures, inputs=[csv_input], outputs=[output_message2])

demo.launch()
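For reference, a quick sketch of what the pad_image helper above does before the training images are written out: a non-square upload is letterboxed onto a black square canvas and only then resized to the 512x512 training resolution. This is illustrative only (the 640x480 test image is made up here) and assumes pad_image from app.py above is in scope:

import PIL.Image

# hypothetical 640x480 input; pad_image pastes it onto a 640x640 black canvas
img = PIL.Image.new('RGB', (640, 480), (255, 0, 0))
square = pad_image(img)
assert square.size == (640, 640)

# ModelTrainer.train() then resizes to the training resolution and saves as JPEG
ready = square.resize((512, 512)).convert('RGB')
print(ready.size)  # (512, 512)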
lora_diffusion/FOR-cloneofsimo-LoRA
ADDED
@@ -0,0 +1,6 @@

The 'lora_diffusion' library in this subdirectory is required by the
'train_dreambooth_cloneofsimo_lora.py' script and is the underlying library of the
https://github.com/cloneofsimo/lora project.

The 'train_dreambooth_cloneofsimo_lora.py' script, in turn, is merely a renamed copy
of 'training_scripts/train_lora_dreambooth.py' from that same project.
lora_diffusion/__init__.py
ADDED
@@ -0,0 +1,5 @@

from .lora import *
from .dataset import *
from .utils import *
from .preprocess_files import *
from .lora_manager import *
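Because __init__.py re-exports every submodule with star imports, downstream code pulls these helpers straight from the lora_diffusion package (as cli_lora_add.py and cli_lora_pti.py below do). A minimal, illustrative sketch of loading a trained LoRA into a Stable Diffusion pipeline this way; the output path is hypothetical, and the exact behaviour of patch_pipe and tune_lora_scale is assumed from their call sites in this commit rather than verified:

import torch
from diffusers import StableDiffusionPipeline
from lora_diffusion import patch_pipe, tune_lora_scale  # re-exported from .lora

pipe = StableDiffusionPipeline.from_pretrained(
    'stabilityai/stable-diffusion-2-1', torch_dtype=torch.float16
).to('cuda')

# patch_pipe(pipeline, path) is how cli_lora_add.py applies a .safetensors LoRA;
# 'experiments/a-xyz-model/final_lora.safetensors' is a hypothetical output path
patch_pipe(pipe, 'experiments/a-xyz-model/final_lora.safetensors')
tune_lora_scale(pipe.unet, 0.8)  # assumed to rescale the injected LoRA weights

image = pipe('a photo of a pwsm dog').images[0]  # instance prompt from app.py
image.save('sample.png')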
lora_diffusion/cli_lora_add.py
ADDED
@@ -0,0 +1,187 @@

from typing import Literal, Union, Dict
import os
import shutil
import fire
from diffusers import StableDiffusionPipeline
from safetensors.torch import safe_open, save_file

import torch
from .lora import (
    tune_lora_scale,
    patch_pipe,
    collapse_lora,
    monkeypatch_remove_lora,
)
from .lora_manager import lora_join
from .to_ckpt_v2 import convert_to_ckpt


def _text_lora_path(path: str) -> str:
    assert path.endswith(".pt"), "Only .pt files are supported"
    return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])


def add(
    path_1: str,
    path_2: str,
    output_path: str,
    alpha_1: float = 0.5,
    alpha_2: float = 0.5,
    mode: Literal[
        "lpl",
        "upl",
        "upl-ckpt-v2",
    ] = "lpl",
    with_text_lora: bool = False,
):
    print("Lora Add, mode " + mode)
    if mode == "lpl":
        if path_1.endswith(".pt") and path_2.endswith(".pt"):
            for _path_1, _path_2, opt in [(path_1, path_2, "unet")] + (
                [(_text_lora_path(path_1), _text_lora_path(path_2), "text_encoder")]
                if with_text_lora
                else []
            ):
                print("Loading", _path_1, _path_2)
                out_list = []
                if opt == "text_encoder":
                    if not os.path.exists(_path_1):
                        print(f"No text encoder found in {_path_1}, skipping...")
                        continue
                    if not os.path.exists(_path_2):
                        print(f"No text encoder found in {_path_1}, skipping...")
                        continue

                l1 = torch.load(_path_1)
                l2 = torch.load(_path_2)

                l1pairs = zip(l1[::2], l1[1::2])
                l2pairs = zip(l2[::2], l2[1::2])

                for (x1, y1), (x2, y2) in zip(l1pairs, l2pairs):
                    # print("Merging", x1.shape, y1.shape, x2.shape, y2.shape)
                    x1.data = alpha_1 * x1.data + alpha_2 * x2.data
                    y1.data = alpha_1 * y1.data + alpha_2 * y2.data

                    out_list.append(x1)
                    out_list.append(y1)

                if opt == "unet":

                    print("Saving merged UNET to", output_path)
                    torch.save(out_list, output_path)

                elif opt == "text_encoder":
                    print("Saving merged text encoder to", _text_lora_path(output_path))
                    torch.save(
                        out_list,
                        _text_lora_path(output_path),
                    )

        elif path_1.endswith(".safetensors") and path_2.endswith(".safetensors"):
            safeloras_1 = safe_open(path_1, framework="pt", device="cpu")
            safeloras_2 = safe_open(path_2, framework="pt", device="cpu")

            metadata = dict(safeloras_1.metadata())
            metadata.update(dict(safeloras_2.metadata()))

            ret_tensor = {}

            for keys in set(list(safeloras_1.keys()) + list(safeloras_2.keys())):
                if keys.startswith("text_encoder") or keys.startswith("unet"):

                    tens1 = safeloras_1.get_tensor(keys)
                    tens2 = safeloras_2.get_tensor(keys)

                    tens = alpha_1 * tens1 + alpha_2 * tens2
                    ret_tensor[keys] = tens
                else:
                    if keys in safeloras_1.keys():

                        tens1 = safeloras_1.get_tensor(keys)
                    else:
                        tens1 = safeloras_2.get_tensor(keys)

                    ret_tensor[keys] = tens1

            save_file(ret_tensor, output_path, metadata)

    elif mode == "upl":

        print(
            f"Merging UNET/CLIP from {path_1} with LoRA from {path_2} to {output_path}. Merging ratio : {alpha_1}."
        )

        loaded_pipeline = StableDiffusionPipeline.from_pretrained(
            path_1,
        ).to("cpu")

        patch_pipe(loaded_pipeline, path_2)

        collapse_lora(loaded_pipeline.unet, alpha_1)
        collapse_lora(loaded_pipeline.text_encoder, alpha_1)

        monkeypatch_remove_lora(loaded_pipeline.unet)
        monkeypatch_remove_lora(loaded_pipeline.text_encoder)

        loaded_pipeline.save_pretrained(output_path)

    elif mode == "upl-ckpt-v2":

        assert output_path.endswith(".ckpt"), "Only .ckpt files are supported"
        name = os.path.basename(output_path)[0:-5]

        print(
            f"You will be using {name} as the token in A1111 webui. Make sure {name} is unique enough token."
        )

        loaded_pipeline = StableDiffusionPipeline.from_pretrained(
            path_1,
        ).to("cpu")

        tok_dict = patch_pipe(loaded_pipeline, path_2, patch_ti=False)

        collapse_lora(loaded_pipeline.unet, alpha_1)
        collapse_lora(loaded_pipeline.text_encoder, alpha_1)

        monkeypatch_remove_lora(loaded_pipeline.unet)
        monkeypatch_remove_lora(loaded_pipeline.text_encoder)

        _tmp_output = output_path + ".tmp"

        loaded_pipeline.save_pretrained(_tmp_output)
        convert_to_ckpt(_tmp_output, output_path, as_half=True)
        # remove the tmp_output folder
        shutil.rmtree(_tmp_output)

        keys = sorted(tok_dict.keys())
        tok_catted = torch.stack([tok_dict[k] for k in keys])
        ret = {
            "string_to_token": {"*": torch.tensor(265)},
            "string_to_param": {"*": tok_catted},
            "name": name,
        }

        torch.save(ret, output_path[:-5] + ".pt")
        print(
            f"Textual embedding saved as {output_path[:-5]}.pt, put it in the embedding folder and use it as {name} in A1111 repo, "
        )
    elif mode == "ljl":
        print("Using Join mode : alpha will not have an effect here.")
        assert path_1.endswith(".safetensors") and path_2.endswith(
            ".safetensors"
        ), "Only .safetensors files are supported"

        safeloras_1 = safe_open(path_1, framework="pt", device="cpu")
        safeloras_2 = safe_open(path_2, framework="pt", device="cpu")

        total_tensor, total_metadata, _, _ = lora_join([safeloras_1, safeloras_2])
        save_file(total_tensor, output_path, total_metadata)

    else:
        print("Unknown mode", mode)
        raise ValueError(f"Unknown mode {mode}")


def main():
    fire.Fire(add)
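The "lpl" branch above boils down to a weighted sum of matching LoRA tensors: each up/down pair from the first file is blended with the corresponding pair from the second using alpha_1 and alpha_2. A self-contained sketch of that arithmetic on toy tensors (the shapes are illustrative only):

import torch

alpha_1, alpha_2 = 0.5, 0.5

# toy stand-ins for one (up, down) LoRA pair from each input file
up_1, down_1 = torch.randn(320, 4), torch.randn(4, 320)
up_2, down_2 = torch.randn(320, 4), torch.randn(4, 320)

# same update rule as in add(): elementwise weighted sum, pair by pair
merged_up = alpha_1 * up_1 + alpha_2 * up_2
merged_down = alpha_1 * down_1 + alpha_2 * down_2

# with alpha_1 == alpha_2 == 0.5 the merge is just the average of the two LoRAs
assert torch.allclose(merged_up, (up_1 + up_2) / 2)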
lora_diffusion/cli_lora_pti.py
ADDED
@@ -0,0 +1,1040 @@

# Bootstrapped from:
# https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py

import argparse
import hashlib
import inspect
import itertools
import math
import os
import random
import re
from pathlib import Path
from typing import Optional, List, Literal

import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.checkpoint
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from huggingface_hub import HfFolder, Repository, whoami
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
import wandb
import fire

from lora_diffusion import (
    PivotalTuningDatasetCapation,
    extract_lora_ups_down,
    inject_trainable_lora,
    inject_trainable_lora_extended,
    inspect_lora,
    save_lora_weight,
    save_all,
    prepare_clip_model_sets,
    evaluate_pipe,
    UNET_EXTENDED_TARGET_REPLACE,
)


def get_models(
    pretrained_model_name_or_path,
    pretrained_vae_name_or_path,
    revision,
    placeholder_tokens: List[str],
    initializer_tokens: List[str],
    device="cuda:0",
):

    tokenizer = CLIPTokenizer.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="tokenizer",
        revision=revision,
    )

    text_encoder = CLIPTextModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )

    placeholder_token_ids = []

    for token, init_tok in zip(placeholder_tokens, initializer_tokens):
        num_added_tokens = tokenizer.add_tokens(token)
        if num_added_tokens == 0:
            raise ValueError(
                f"The tokenizer already contains the token {token}. Please pass a different"
                " `placeholder_token` that is not already in the tokenizer."
            )

        placeholder_token_id = tokenizer.convert_tokens_to_ids(token)

        placeholder_token_ids.append(placeholder_token_id)

        # Load models and create wrapper for stable diffusion

        text_encoder.resize_token_embeddings(len(tokenizer))
        token_embeds = text_encoder.get_input_embeddings().weight.data
        if init_tok.startswith("<rand"):
            # <rand-"sigma">, e.g. <rand-0.5>
            sigma_val = float(re.findall(r"<rand-(.*)>", init_tok)[0])

            token_embeds[placeholder_token_id] = (
                torch.randn_like(token_embeds[0]) * sigma_val
            )
            print(
                f"Initialized {token} with random noise (sigma={sigma_val}), empirically {token_embeds[placeholder_token_id].mean().item():.3f} +- {token_embeds[placeholder_token_id].std().item():.3f}"
            )
            print(f"Norm : {token_embeds[placeholder_token_id].norm():.4f}")

        elif init_tok == "<zero>":
            token_embeds[placeholder_token_id] = torch.zeros_like(token_embeds[0])
        else:
            token_ids = tokenizer.encode(init_tok, add_special_tokens=False)
            # Check if initializer_token is a single token or a sequence of tokens
            if len(token_ids) > 1:
                raise ValueError("The initializer token must be a single token.")

            initializer_token_id = token_ids[0]
            token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]

    vae = AutoencoderKL.from_pretrained(
        pretrained_vae_name_or_path or pretrained_model_name_or_path,
        subfolder=None if pretrained_vae_name_or_path else "vae",
        revision=None if pretrained_vae_name_or_path else revision,
    )
    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="unet",
        revision=revision,
    )

    return (
        text_encoder.to(device),
        vae.to(device),
        unet.to(device),
        tokenizer,
        placeholder_token_ids,
    )


@torch.no_grad()
def text2img_dataloader(
    train_dataset,
    train_batch_size,
    tokenizer,
    vae,
    text_encoder,
    cached_latents: bool = False,
):

    if cached_latents:
        cached_latents_dataset = []
        for idx in tqdm(range(len(train_dataset))):
            batch = train_dataset[idx]
            # rint(batch)
            latents = vae.encode(
                batch["instance_images"].unsqueeze(0).to(dtype=vae.dtype).to(vae.device)
            ).latent_dist.sample()
            latents = latents * 0.18215
            batch["instance_images"] = latents.squeeze(0)
            cached_latents_dataset.append(batch)

    def collate_fn(examples):
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]
        pixel_values = torch.stack(pixel_values)
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

        input_ids = tokenizer.pad(
            {"input_ids": input_ids},
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

        batch = {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
        }

        if examples[0].get("mask", None) is not None:
            batch["mask"] = torch.stack([example["mask"] for example in examples])

        return batch

    if cached_latents:

        train_dataloader = torch.utils.data.DataLoader(
            cached_latents_dataset,
            batch_size=train_batch_size,
            shuffle=True,
            collate_fn=collate_fn,
        )

        print("PTI : Using cached latent.")

    else:
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=train_batch_size,
            shuffle=True,
            collate_fn=collate_fn,
        )

    return train_dataloader


def inpainting_dataloader(
    train_dataset, train_batch_size, tokenizer, vae, text_encoder
):
    def collate_fn(examples):
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]
        mask_values = [example["instance_masks"] for example in examples]
        masked_image_values = [
            example["instance_masked_images"] for example in examples
        ]

        # Concat class and instance examples for prior preservation.
        # We do this to avoid doing two forward passes.
        if examples[0].get("class_prompt_ids", None) is not None:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]
            mask_values += [example["class_masks"] for example in examples]
            masked_image_values += [
                example["class_masked_images"] for example in examples
            ]

        pixel_values = (
            torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
        )
        mask_values = (
            torch.stack(mask_values).to(memory_format=torch.contiguous_format).float()
        )
        masked_image_values = (
            torch.stack(masked_image_values)
            .to(memory_format=torch.contiguous_format)
            .float()
        )

        input_ids = tokenizer.pad(
            {"input_ids": input_ids},
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

        batch = {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
            "mask_values": mask_values,
            "masked_image_values": masked_image_values,
        }

        if examples[0].get("mask", None) is not None:
            batch["mask"] = torch.stack([example["mask"] for example in examples])

        return batch

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    )

    return train_dataloader


def loss_step(
    batch,
    unet,
    vae,
    text_encoder,
    scheduler,
    train_inpainting=False,
    t_mutliplier=1.0,
    mixed_precision=False,
    mask_temperature=1.0,
    cached_latents: bool = False,
):
    weight_dtype = torch.float32
    if not cached_latents:
        latents = vae.encode(
            batch["pixel_values"].to(dtype=weight_dtype).to(unet.device)
        ).latent_dist.sample()
        latents = latents * 0.18215

        if train_inpainting:
            masked_image_latents = vae.encode(
                batch["masked_image_values"].to(dtype=weight_dtype).to(unet.device)
            ).latent_dist.sample()
            masked_image_latents = masked_image_latents * 0.18215
            mask = F.interpolate(
                batch["mask_values"].to(dtype=weight_dtype).to(unet.device),
                scale_factor=1 / 8,
            )
    else:
        latents = batch["pixel_values"]

        if train_inpainting:
            masked_image_latents = batch["masked_image_latents"]
            mask = batch["mask_values"]

    noise = torch.randn_like(latents)
    bsz = latents.shape[0]

    timesteps = torch.randint(
        0,
        int(scheduler.config.num_train_timesteps * t_mutliplier),
        (bsz,),
        device=latents.device,
    )
    timesteps = timesteps.long()

    noisy_latents = scheduler.add_noise(latents, noise, timesteps)

    if train_inpainting:
        latent_model_input = torch.cat(
            [noisy_latents, mask, masked_image_latents], dim=1
        )
    else:
        latent_model_input = noisy_latents

    if mixed_precision:
        with torch.cuda.amp.autocast():

            encoder_hidden_states = text_encoder(
                batch["input_ids"].to(text_encoder.device)
            )[0]

            model_pred = unet(
                latent_model_input, timesteps, encoder_hidden_states
            ).sample
    else:

        encoder_hidden_states = text_encoder(
            batch["input_ids"].to(text_encoder.device)
        )[0]

        model_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample

    if scheduler.config.prediction_type == "epsilon":
        target = noise
    elif scheduler.config.prediction_type == "v_prediction":
        target = scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}")

    if batch.get("mask", None) is not None:

        mask = (
            batch["mask"]
            .to(model_pred.device)
            .reshape(
                model_pred.shape[0], 1, model_pred.shape[2] * 8, model_pred.shape[3] * 8
            )
        )
        # resize to match model_pred
        mask = F.interpolate(
            mask.float(),
            size=model_pred.shape[-2:],
            mode="nearest",
        )

        mask = (mask + 0.01).pow(mask_temperature)

        mask = mask / mask.max()

        model_pred = model_pred * mask

        target = target * mask

    loss = (
        F.mse_loss(model_pred.float(), target.float(), reduction="none")
        .mean([1, 2, 3])
        .mean()
    )

    return loss


def train_inversion(
    unet,
    vae,
    text_encoder,
    dataloader,
    num_steps: int,
    scheduler,
    index_no_updates,
    optimizer,
    save_steps: int,
    placeholder_token_ids,
    placeholder_tokens,
    save_path: str,
    tokenizer,
    lr_scheduler,
    test_image_path: str,
    cached_latents: bool,
    accum_iter: int = 1,
    log_wandb: bool = False,
    wandb_log_prompt_cnt: int = 10,
    class_token: str = "person",
    train_inpainting: bool = False,
    mixed_precision: bool = False,
    clip_ti_decay: bool = True,
):

    progress_bar = tqdm(range(num_steps))
    progress_bar.set_description("Steps")
    global_step = 0

    # Original Emb for TI
    orig_embeds_params = text_encoder.get_input_embeddings().weight.data.clone()

    if log_wandb:
        preped_clip = prepare_clip_model_sets()

    index_updates = ~index_no_updates
    loss_sum = 0.0

    for epoch in range(math.ceil(num_steps / len(dataloader))):
        unet.eval()
        text_encoder.train()
        for batch in dataloader:

            lr_scheduler.step()

            with torch.set_grad_enabled(True):
                loss = (
                    loss_step(
                        batch,
                        unet,
                        vae,
                        text_encoder,
                        scheduler,
                        train_inpainting=train_inpainting,
                        mixed_precision=mixed_precision,
                        cached_latents=cached_latents,
                    )
                    / accum_iter
                )

                loss.backward()
                loss_sum += loss.detach().item()

                if global_step % accum_iter == 0:
                    # print gradient of text encoder embedding
                    print(
                        text_encoder.get_input_embeddings()
                        .weight.grad[index_updates, :]
                        .norm(dim=-1)
                        .mean()
                    )
                    optimizer.step()
                    optimizer.zero_grad()

                    with torch.no_grad():

                        # normalize embeddings
                        if clip_ti_decay:
                            pre_norm = (
                                text_encoder.get_input_embeddings()
                                .weight[index_updates, :]
                                .norm(dim=-1, keepdim=True)
                            )

                            lambda_ = min(1.0, 100 * lr_scheduler.get_last_lr()[0])
                            text_encoder.get_input_embeddings().weight[
                                index_updates
                            ] = F.normalize(
                                text_encoder.get_input_embeddings().weight[
                                    index_updates, :
                                ],
                                dim=-1,
                            ) * (
                                pre_norm + lambda_ * (0.4 - pre_norm)
                            )
                            print(pre_norm)

                        current_norm = (
                            text_encoder.get_input_embeddings()
                            .weight[index_updates, :]
                            .norm(dim=-1)
                        )

                        text_encoder.get_input_embeddings().weight[
                            index_no_updates
                        ] = orig_embeds_params[index_no_updates]

                        print(f"Current Norm : {current_norm}")

                global_step += 1
                progress_bar.update(1)

                logs = {
                    "loss": loss.detach().item(),
                    "lr": lr_scheduler.get_last_lr()[0],
                }
                progress_bar.set_postfix(**logs)

            if global_step % save_steps == 0:
                save_all(
                    unet=unet,
                    text_encoder=text_encoder,
                    placeholder_token_ids=placeholder_token_ids,
                    placeholder_tokens=placeholder_tokens,
                    save_path=os.path.join(
                        save_path, f"step_inv_{global_step}.safetensors"
                    ),
                    save_lora=False,
                )
                if log_wandb:
                    with torch.no_grad():
                        pipe = StableDiffusionPipeline(
                            vae=vae,
                            text_encoder=text_encoder,
                            tokenizer=tokenizer,
                            unet=unet,
                            scheduler=scheduler,
                            safety_checker=None,
                            feature_extractor=None,
                        )

                        # open all images in test_image_path
                        images = []
                        for file in os.listdir(test_image_path):
                            if (
                                file.lower().endswith(".png")
                                or file.lower().endswith(".jpg")
                                or file.lower().endswith(".jpeg")
                            ):
                                images.append(
                                    Image.open(os.path.join(test_image_path, file))
                                )

                        wandb.log({"loss": loss_sum / save_steps})
                        loss_sum = 0.0
                        wandb.log(
                            evaluate_pipe(
                                pipe,
                                target_images=images,
                                class_token=class_token,
                                learnt_token="".join(placeholder_tokens),
                                n_test=wandb_log_prompt_cnt,
                                n_step=50,
                                clip_model_sets=preped_clip,
                            )
                        )

            if global_step >= num_steps:
                return


def perform_tuning(
    unet,
    vae,
    text_encoder,
    dataloader,
    num_steps,
    scheduler,
    optimizer,
    save_steps: int,
    placeholder_token_ids,
    placeholder_tokens,
    save_path,
    lr_scheduler_lora,
    lora_unet_target_modules,
    lora_clip_target_modules,
    mask_temperature,
    out_name: str,
    tokenizer,
    test_image_path: str,
    cached_latents: bool,
    log_wandb: bool = False,
    wandb_log_prompt_cnt: int = 10,
    class_token: str = "person",
    train_inpainting: bool = False,
):

    progress_bar = tqdm(range(num_steps))
    progress_bar.set_description("Steps")
    global_step = 0

    weight_dtype = torch.float16

    unet.train()
    text_encoder.train()

    if log_wandb:
        preped_clip = prepare_clip_model_sets()

    loss_sum = 0.0

    for epoch in range(math.ceil(num_steps / len(dataloader))):
        for batch in dataloader:
            lr_scheduler_lora.step()

            optimizer.zero_grad()

            loss = loss_step(
                batch,
                unet,
                vae,
                text_encoder,
                scheduler,
                train_inpainting=train_inpainting,
                t_mutliplier=0.8,
                mixed_precision=True,
                mask_temperature=mask_temperature,
                cached_latents=cached_latents,
            )
            loss_sum += loss.detach().item()

            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                itertools.chain(unet.parameters(), text_encoder.parameters()), 1.0
            )
            optimizer.step()
            progress_bar.update(1)
            logs = {
                "loss": loss.detach().item(),
                "lr": lr_scheduler_lora.get_last_lr()[0],
            }
            progress_bar.set_postfix(**logs)

            global_step += 1

            if global_step % save_steps == 0:
                save_all(
                    unet,
                    text_encoder,
                    placeholder_token_ids=placeholder_token_ids,
                    placeholder_tokens=placeholder_tokens,
                    save_path=os.path.join(
                        save_path, f"step_{global_step}.safetensors"
                    ),
                    target_replace_module_text=lora_clip_target_modules,
                    target_replace_module_unet=lora_unet_target_modules,
                )
                moved = (
                    torch.tensor(list(itertools.chain(*inspect_lora(unet).values())))
                    .mean()
                    .item()
                )

                print("LORA Unet Moved", moved)
                moved = (
                    torch.tensor(
                        list(itertools.chain(*inspect_lora(text_encoder).values()))
                    )
                    .mean()
                    .item()
                )

                print("LORA CLIP Moved", moved)

                if log_wandb:
                    with torch.no_grad():
                        pipe = StableDiffusionPipeline(
                            vae=vae,
                            text_encoder=text_encoder,
                            tokenizer=tokenizer,
                            unet=unet,
                            scheduler=scheduler,
                            safety_checker=None,
                            feature_extractor=None,
                        )

                        # open all images in test_image_path
                        images = []
                        for file in os.listdir(test_image_path):
                            if file.endswith(".png") or file.endswith(".jpg"):
                                images.append(
                                    Image.open(os.path.join(test_image_path, file))
                                )

                        wandb.log({"loss": loss_sum / save_steps})
                        loss_sum = 0.0
                        wandb.log(
                            evaluate_pipe(
                                pipe,
                                target_images=images,
                                class_token=class_token,
                                learnt_token="".join(placeholder_tokens),
                                n_test=wandb_log_prompt_cnt,
                                n_step=50,
                                clip_model_sets=preped_clip,
                            )
                        )

            if global_step >= num_steps:
                break

    save_all(
        unet,
        text_encoder,
        placeholder_token_ids=placeholder_token_ids,
        placeholder_tokens=placeholder_tokens,
        save_path=os.path.join(save_path, f"{out_name}.safetensors"),
        target_replace_module_text=lora_clip_target_modules,
        target_replace_module_unet=lora_unet_target_modules,
    )


def train(
    instance_data_dir: str,
    pretrained_model_name_or_path: str,
    output_dir: str,
    train_text_encoder: bool = True,
    pretrained_vae_name_or_path: str = None,
    revision: Optional[str] = None,
    perform_inversion: bool = True,
    use_template: Literal[None, "object", "style"] = None,
    train_inpainting: bool = False,
    placeholder_tokens: str = "",
    placeholder_token_at_data: Optional[str] = None,
    initializer_tokens: Optional[str] = None,
    seed: int = 42,
    resolution: int = 512,
    color_jitter: bool = True,
    train_batch_size: int = 1,
    sample_batch_size: int = 1,
    max_train_steps_tuning: int = 1000,
    max_train_steps_ti: int = 1000,
    save_steps: int = 100,
    gradient_accumulation_steps: int = 4,
    gradient_checkpointing: bool = False,
    lora_rank: int = 4,
    lora_unet_target_modules={"CrossAttention", "Attention", "GEGLU"},
    lora_clip_target_modules={"CLIPAttention"},
    lora_dropout_p: float = 0.0,
    lora_scale: float = 1.0,
    use_extended_lora: bool = False,
    clip_ti_decay: bool = True,
    learning_rate_unet: float = 1e-4,
    learning_rate_text: float = 1e-5,
    learning_rate_ti: float = 5e-4,
    continue_inversion: bool = False,
    continue_inversion_lr: Optional[float] = None,
    use_face_segmentation_condition: bool = False,
    cached_latents: bool = True,
    use_mask_captioned_data: bool = False,
    mask_temperature: float = 1.0,
    scale_lr: bool = False,
    lr_scheduler: str = "linear",
    lr_warmup_steps: int = 0,
    lr_scheduler_lora: str = "linear",
    lr_warmup_steps_lora: int = 0,
    weight_decay_ti: float = 0.00,
    weight_decay_lora: float = 0.001,
    use_8bit_adam: bool = False,
    device="cuda:0",
    extra_args: Optional[dict] = None,
    log_wandb: bool = False,
    wandb_log_prompt_cnt: int = 10,
    wandb_project_name: str = "new_pti_project",
    wandb_entity: str = "new_pti_entity",
    proxy_token: str = "person",
    enable_xformers_memory_efficient_attention: bool = False,
    out_name: str = "final_lora",
):
    torch.manual_seed(seed)

    if log_wandb:
        wandb.init(
            project=wandb_project_name,
            entity=wandb_entity,
            name=f"steps_{max_train_steps_ti}_lr_{learning_rate_ti}_{instance_data_dir.split('/')[-1]}",
            reinit=True,
            config={
                **(extra_args if extra_args is not None else {}),
            },
        )

    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)
    # print(placeholder_tokens, initializer_tokens)
    if len(placeholder_tokens) == 0:
        placeholder_tokens = []
        print("PTI : Placeholder Tokens not given, using null token")
    else:
        placeholder_tokens = placeholder_tokens.split("|")

        assert (
            sorted(placeholder_tokens) == placeholder_tokens
        ), f"Placeholder tokens should be sorted. Use something like {'|'.join(sorted(placeholder_tokens))}'"

    if initializer_tokens is None:
        print("PTI : Initializer Tokens not given, doing random inits")
        initializer_tokens = ["<rand-0.017>"] * len(placeholder_tokens)
    else:
        initializer_tokens = initializer_tokens.split("|")

    assert len(initializer_tokens) == len(
        placeholder_tokens
    ), "Unequal Initializer token for Placeholder tokens."

    if proxy_token is not None:
        class_token = proxy_token
    class_token = "".join(initializer_tokens)

    if placeholder_token_at_data is not None:
        tok, pat = placeholder_token_at_data.split("|")
        token_map = {tok: pat}

    else:
        token_map = {"DUMMY": "".join(placeholder_tokens)}

    print("PTI : Placeholder Tokens", placeholder_tokens)
    print("PTI : Initializer Tokens", initializer_tokens)

    # get the models
    text_encoder, vae, unet, tokenizer, placeholder_token_ids = get_models(
        pretrained_model_name_or_path,
        pretrained_vae_name_or_path,
        revision,
        placeholder_tokens,
        initializer_tokens,
        device=device,
    )

    noise_scheduler = DDPMScheduler.from_config(
        pretrained_model_name_or_path, subfolder="scheduler"
    )

    if gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    if enable_xformers_memory_efficient_attention:
        from diffusers.utils.import_utils import is_xformers_available

        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError(
                "xformers is not available. Make sure it is installed correctly"
            )

    if scale_lr:
        unet_lr = learning_rate_unet * gradient_accumulation_steps * train_batch_size
        text_encoder_lr = (
            learning_rate_text * gradient_accumulation_steps * train_batch_size
        )
        ti_lr = learning_rate_ti * gradient_accumulation_steps * train_batch_size
    else:
        unet_lr = learning_rate_unet
        text_encoder_lr = learning_rate_text
        ti_lr = learning_rate_ti

    train_dataset = PivotalTuningDatasetCapation(
        instance_data_root=instance_data_dir,
        token_map=token_map,
        use_template=use_template,
        tokenizer=tokenizer,
        size=resolution,
        color_jitter=color_jitter,
        use_face_segmentation_condition=use_face_segmentation_condition,
        use_mask_captioned_data=use_mask_captioned_data,
        train_inpainting=train_inpainting,
    )

    train_dataset.blur_amount = 200

    if train_inpainting:
        assert not cached_latents, "Cached latents not supported for inpainting"

        train_dataloader = inpainting_dataloader(
            train_dataset, train_batch_size, tokenizer, vae, text_encoder
        )
    else:
        train_dataloader = text2img_dataloader(
            train_dataset,
            train_batch_size,
            tokenizer,
            vae,
            text_encoder,
            cached_latents=cached_latents,
        )

    index_no_updates = torch.arange(len(tokenizer)) != -1

    for tok_id in placeholder_token_ids:
        index_no_updates[tok_id] = False

    unet.requires_grad_(False)
    vae.requires_grad_(False)

    params_to_freeze = itertools.chain(
        text_encoder.text_model.encoder.parameters(),
        text_encoder.text_model.final_layer_norm.parameters(),
        text_encoder.text_model.embeddings.position_embedding.parameters(),
    )
    for param in params_to_freeze:
        param.requires_grad = False

    if cached_latents:
        vae = None
    # STEP 1 : Perform Inversion
    if perform_inversion:
        ti_optimizer = optim.AdamW(
            text_encoder.get_input_embeddings().parameters(),
            lr=ti_lr,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=weight_decay_ti,
        )

        lr_scheduler = get_scheduler(
            lr_scheduler,
            optimizer=ti_optimizer,
            num_warmup_steps=lr_warmup_steps,
            num_training_steps=max_train_steps_ti,
        )

        train_inversion(
            unet,
            vae,
            text_encoder,
            train_dataloader,
            max_train_steps_ti,
            cached_latents=cached_latents,
            accum_iter=gradient_accumulation_steps,
            scheduler=noise_scheduler,
            index_no_updates=index_no_updates,
            optimizer=ti_optimizer,
            lr_scheduler=lr_scheduler,
            save_steps=save_steps,
            placeholder_tokens=placeholder_tokens,
            placeholder_token_ids=placeholder_token_ids,
            save_path=output_dir,
            test_image_path=instance_data_dir,
            log_wandb=log_wandb,
            wandb_log_prompt_cnt=wandb_log_prompt_cnt,
            class_token=class_token,
            train_inpainting=train_inpainting,
            mixed_precision=False,
            tokenizer=tokenizer,
            clip_ti_decay=clip_ti_decay,
        )

        del ti_optimizer

    # Next perform Tuning with LoRA:
    if not use_extended_lora:
        unet_lora_params, _ = inject_trainable_lora(
            unet,
            r=lora_rank,
            target_replace_module=lora_unet_target_modules,
            dropout_p=lora_dropout_p,
            scale=lora_scale,
        )
    else:
        print("PTI : USING EXTENDED UNET!!!")
        lora_unet_target_modules = (
            lora_unet_target_modules | UNET_EXTENDED_TARGET_REPLACE
        )
        print("PTI : Will replace modules: ", lora_unet_target_modules)

        unet_lora_params, _ = inject_trainable_lora_extended(
            unet, r=lora_rank, target_replace_module=lora_unet_target_modules
        )
    print(f"PTI : has {len(unet_lora_params)} lora")

    print("PTI : Before training:")
    inspect_lora(unet)

    params_to_optimize = [
        {"params": itertools.chain(*unet_lora_params), "lr": unet_lr},
    ]

    text_encoder.requires_grad_(False)

    if continue_inversion:
        params_to_optimize += [
            {
                "params": text_encoder.get_input_embeddings().parameters(),
                "lr": continue_inversion_lr
                if continue_inversion_lr is not None
                else ti_lr,
            }
        ]
        text_encoder.requires_grad_(True)
        params_to_freeze = itertools.chain(
            text_encoder.text_model.encoder.parameters(),
            text_encoder.text_model.final_layer_norm.parameters(),
            text_encoder.text_model.embeddings.position_embedding.parameters(),
        )
        for param in params_to_freeze:
            param.requires_grad = False
    else:
        text_encoder.requires_grad_(False)
        if train_text_encoder:
            text_encoder_lora_params, _ = inject_trainable_lora(
                text_encoder,
                target_replace_module=lora_clip_target_modules,
                r=lora_rank,
            )
            params_to_optimize += [
                {
                    "params": itertools.chain(*text_encoder_lora_params),
                    "lr": text_encoder_lr,
                }
            ]
    inspect_lora(text_encoder)

    lora_optimizers = optim.AdamW(params_to_optimize, weight_decay=weight_decay_lora)

    unet.train()
    if train_text_encoder:
        text_encoder.train()

    train_dataset.blur_amount = 70

    lr_scheduler_lora = get_scheduler(
        lr_scheduler_lora,
        optimizer=lora_optimizers,
        num_warmup_steps=lr_warmup_steps_lora,
        num_training_steps=max_train_steps_tuning,
    )

    perform_tuning(
        unet,
        vae,
        text_encoder,
        train_dataloader,
        max_train_steps_tuning,
        cached_latents=cached_latents,
        scheduler=noise_scheduler,
        optimizer=lora_optimizers,
        save_steps=save_steps,
        placeholder_tokens=placeholder_tokens,
        placeholder_token_ids=placeholder_token_ids,
        save_path=output_dir,
        lr_scheduler_lora=lr_scheduler_lora,
        lora_unet_target_modules=lora_unet_target_modules,
        lora_clip_target_modules=lora_clip_target_modules,
        mask_temperature=mask_temperature,
        tokenizer=tokenizer,
        out_name=out_name,
        test_image_path=instance_data_dir,
        log_wandb=log_wandb,
        wandb_log_prompt_cnt=wandb_log_prompt_cnt,
        class_token=class_token,
        train_inpainting=train_inpainting,
    )


def main():
    fire.Fire(train)
lora_diffusion/cli_pt_to_safetensors.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import fire
|
| 4 |
+
import torch
|
| 5 |
+
from lora_diffusion import (
|
| 6 |
+
DEFAULT_TARGET_REPLACE,
|
| 7 |
+
TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
|
| 8 |
+
UNET_DEFAULT_TARGET_REPLACE,
|
| 9 |
+
convert_loras_to_safeloras_with_embeds,
|
| 10 |
+
safetensors_available,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
_target_by_name = {
|
| 14 |
+
"unet": UNET_DEFAULT_TARGET_REPLACE,
|
| 15 |
+
"text_encoder": TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def convert(*paths, outpath, overwrite=False, **settings):
|
| 20 |
+
"""
|
| 21 |
+
Converts one or more pytorch Lora and/or Textual Embedding pytorch files
|
| 22 |
+
into a safetensor file.
|
| 23 |
+
|
| 24 |
+
Pass all the input paths as arguments. Whether they are Textual Embedding
|
| 25 |
+
or Lora models will be auto-detected.
|
| 26 |
+
|
| 27 |
+
For Lora models, their name will be taken from the path, i.e.
|
| 28 |
+
"lora_weight.pt" => unet
|
| 29 |
+
"lora_weight.text_encoder.pt" => text_encoder
|
| 30 |
+
|
| 31 |
+
You can also set target_modules and/or rank by providing an argument prefixed
|
| 32 |
+
by the name.
|
| 33 |
+
|
| 34 |
+
So a complete example might be something like:
|
| 35 |
+
|
| 36 |
+
```
|
| 37 |
+
python -m lora_diffusion.cli_pt_to_safetensors lora_weight.* --outpath lora_weight.safetensor --unet.rank 8
|
| 38 |
+
```
|
| 39 |
+
"""
|
| 40 |
+
modelmap = {}
|
| 41 |
+
embeds = {}
|
| 42 |
+
|
| 43 |
+
if os.path.exists(outpath) and not overwrite:
|
| 44 |
+
raise ValueError(
|
| 45 |
+
f"Output path {outpath} already exists, and overwrite is not True"
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
for path in paths:
|
| 49 |
+
data = torch.load(path)
|
| 50 |
+
|
| 51 |
+
if isinstance(data, dict):
|
| 52 |
+
print(f"Loading textual inversion embeds {data.keys()} from {path}")
|
| 53 |
+
embeds.update(data)
|
| 54 |
+
|
| 55 |
+
else:
|
| 56 |
+
name_parts = os.path.split(path)[1].split(".")
|
| 57 |
+
name = name_parts[-2] if len(name_parts) > 2 else "unet"
|
| 58 |
+
|
| 59 |
+
model_settings = {
|
| 60 |
+
"target_modules": _target_by_name.get(name, DEFAULT_TARGET_REPLACE),
|
| 61 |
+
"rank": 4,
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
prefix = f"{name}."
|
| 65 |
+
|
| 66 |
+
arg_settings = { k[len(prefix) :]: v for k, v in settings.items() if k.startswith(prefix) }
|
| 67 |
+
model_settings = { **model_settings, **arg_settings }
|
| 68 |
+
|
| 69 |
+
print(f"Loading Lora for {name} from {path} with settings {model_settings}")
|
| 70 |
+
|
| 71 |
+
modelmap[name] = (
|
| 72 |
+
path,
|
| 73 |
+
model_settings["target_modules"],
|
| 74 |
+
model_settings["rank"],
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
convert_loras_to_safeloras_with_embeds(modelmap, embeds, outpath)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def main():
|
| 81 |
+
fire.Fire(convert)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
if __name__ == "__main__":
|
| 85 |
+
main()
|
lora_diffusion/cli_svd.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import fire
|
| 2 |
+
from diffusers import StableDiffusionPipeline
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
from .lora import (
|
| 7 |
+
save_all,
|
| 8 |
+
_find_modules,
|
| 9 |
+
LoraInjectedConv2d,
|
| 10 |
+
LoraInjectedLinear,
|
| 11 |
+
inject_trainable_lora,
|
| 12 |
+
inject_trainable_lora_extended,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _iter_lora(model):
|
| 17 |
+
for module in model.modules():
|
| 18 |
+
if isinstance(module, LoraInjectedConv2d) or isinstance(
|
| 19 |
+
module, LoraInjectedLinear
|
| 20 |
+
):
|
| 21 |
+
yield module
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def overwrite_base(base_model, tuned_model, rank, clamp_quantile):
|
| 25 |
+
device = base_model.device
|
| 26 |
+
dtype = base_model.dtype
|
| 27 |
+
|
| 28 |
+
for lor_base, lor_tune in zip(_iter_lora(base_model), _iter_lora(tuned_model)):
|
| 29 |
+
|
| 30 |
+
if isinstance(lor_base, LoraInjectedLinear):
|
| 31 |
+
residual = lor_tune.linear.weight.data - lor_base.linear.weight.data
|
| 32 |
+
# SVD on residual
|
| 33 |
+
print("Distill Linear shape ", residual.shape)
|
| 34 |
+
residual = residual.float()
|
| 35 |
+
U, S, Vh = torch.linalg.svd(residual)
|
| 36 |
+
U = U[:, :rank]
|
| 37 |
+
S = S[:rank]
|
| 38 |
+
U = U @ torch.diag(S)
|
| 39 |
+
|
| 40 |
+
Vh = Vh[:rank, :]
|
| 41 |
+
|
| 42 |
+
dist = torch.cat([U.flatten(), Vh.flatten()])
|
| 43 |
+
hi_val = torch.quantile(dist, clamp_quantile)
|
| 44 |
+
low_val = -hi_val
|
| 45 |
+
|
| 46 |
+
U = U.clamp(low_val, hi_val)
|
| 47 |
+
Vh = Vh.clamp(low_val, hi_val)
|
| 48 |
+
|
| 49 |
+
assert lor_base.lora_up.weight.shape == U.shape
|
| 50 |
+
assert lor_base.lora_down.weight.shape == Vh.shape
|
| 51 |
+
|
| 52 |
+
lor_base.lora_up.weight.data = U.to(device=device, dtype=dtype)
|
| 53 |
+
lor_base.lora_down.weight.data = Vh.to(device=device, dtype=dtype)
|
| 54 |
+
|
| 55 |
+
if isinstance(lor_base, LoraInjectedConv2d):
|
| 56 |
+
residual = lor_tune.conv.weight.data - lor_base.conv.weight.data
|
| 57 |
+
print("Distill Conv shape ", residual.shape)
|
| 58 |
+
|
| 59 |
+
residual = residual.float()
|
| 60 |
+
residual = residual.flatten(start_dim=1)
|
| 61 |
+
|
| 62 |
+
# SVD on residual
|
| 63 |
+
U, S, Vh = torch.linalg.svd(residual)
|
| 64 |
+
U = U[:, :rank]
|
| 65 |
+
S = S[:rank]
|
| 66 |
+
U = U @ torch.diag(S)
|
| 67 |
+
|
| 68 |
+
Vh = Vh[:rank, :]
|
| 69 |
+
|
| 70 |
+
dist = torch.cat([U.flatten(), Vh.flatten()])
|
| 71 |
+
hi_val = torch.quantile(dist, clamp_quantile)
|
| 72 |
+
low_val = -hi_val
|
| 73 |
+
|
| 74 |
+
U = U.clamp(low_val, hi_val)
|
| 75 |
+
Vh = Vh.clamp(low_val, hi_val)
|
| 76 |
+
|
| 77 |
+
# U is (out_channels, rank) with 1x1 conv. So,
|
| 78 |
+
U = U.reshape(U.shape[0], U.shape[1], 1, 1)
|
| 79 |
+
# V is (rank, in_channels * kernel_size1 * kernel_size2)
|
| 80 |
+
# now reshape:
|
| 81 |
+
Vh = Vh.reshape(
|
| 82 |
+
Vh.shape[0],
|
| 83 |
+
lor_base.conv.in_channels,
|
| 84 |
+
lor_base.conv.kernel_size[0],
|
| 85 |
+
lor_base.conv.kernel_size[1],
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
assert lor_base.lora_up.weight.shape == U.shape
|
| 89 |
+
assert lor_base.lora_down.weight.shape == Vh.shape
|
| 90 |
+
|
| 91 |
+
lor_base.lora_up.weight.data = U.to(device=device, dtype=dtype)
|
| 92 |
+
lor_base.lora_down.weight.data = Vh.to(device=device, dtype=dtype)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def svd_distill(
|
| 96 |
+
target_model: str,
|
| 97 |
+
base_model: str,
|
| 98 |
+
rank: int = 4,
|
| 99 |
+
clamp_quantile: float = 0.99,
|
| 100 |
+
device: str = "cuda:0",
|
| 101 |
+
save_path: str = "svd_distill.safetensors",
|
| 102 |
+
):
|
| 103 |
+
pipe_base = StableDiffusionPipeline.from_pretrained(
|
| 104 |
+
base_model, torch_dtype=torch.float16
|
| 105 |
+
).to(device)
|
| 106 |
+
|
| 107 |
+
pipe_tuned = StableDiffusionPipeline.from_pretrained(
|
| 108 |
+
target_model, torch_dtype=torch.float16
|
| 109 |
+
).to(device)
|
| 110 |
+
|
| 111 |
+
# Inject unet
|
| 112 |
+
_ = inject_trainable_lora_extended(pipe_base.unet, r=rank)
|
| 113 |
+
_ = inject_trainable_lora_extended(pipe_tuned.unet, r=rank)
|
| 114 |
+
|
| 115 |
+
overwrite_base(
|
| 116 |
+
pipe_base.unet, pipe_tuned.unet, rank=rank, clamp_quantile=clamp_quantile
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
# Inject text encoder
|
| 120 |
+
_ = inject_trainable_lora(
|
| 121 |
+
pipe_base.text_encoder, r=rank, target_replace_module={"CLIPAttention"}
|
| 122 |
+
)
|
| 123 |
+
_ = inject_trainable_lora(
|
| 124 |
+
pipe_tuned.text_encoder, r=rank, target_replace_module={"CLIPAttention"}
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
overwrite_base(
|
| 128 |
+
pipe_base.text_encoder,
|
| 129 |
+
pipe_tuned.text_encoder,
|
| 130 |
+
rank=rank,
|
| 131 |
+
clamp_quantile=clamp_quantile,
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
save_all(
|
| 135 |
+
unet=pipe_base.unet,
|
| 136 |
+
text_encoder=pipe_base.text_encoder,
|
| 137 |
+
placeholder_token_ids=None,
|
| 138 |
+
placeholder_tokens=None,
|
| 139 |
+
save_path=save_path,
|
| 140 |
+
save_lora=True,
|
| 141 |
+
save_ti=False,
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def main():
|
| 146 |
+
fire.Fire(svd_distill)
|
lora_diffusion/dataset.py
ADDED
|
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 4 |
+
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from torch import zeros_like
|
| 7 |
+
from torch.utils.data import Dataset
|
| 8 |
+
from torchvision import transforms
|
| 9 |
+
import glob
|
| 10 |
+
from .preprocess_files import face_mask_google_mediapipe
|
| 11 |
+
|
| 12 |
+
OBJECT_TEMPLATE = [
|
| 13 |
+
"a photo of a {}",
|
| 14 |
+
"a rendering of a {}",
|
| 15 |
+
"a cropped photo of the {}",
|
| 16 |
+
"the photo of a {}",
|
| 17 |
+
"a photo of a clean {}",
|
| 18 |
+
"a photo of a dirty {}",
|
| 19 |
+
"a dark photo of the {}",
|
| 20 |
+
"a photo of my {}",
|
| 21 |
+
"a photo of the cool {}",
|
| 22 |
+
"a close-up photo of a {}",
|
| 23 |
+
"a bright photo of the {}",
|
| 24 |
+
"a cropped photo of a {}",
|
| 25 |
+
"a photo of the {}",
|
| 26 |
+
"a good photo of the {}",
|
| 27 |
+
"a photo of one {}",
|
| 28 |
+
"a close-up photo of the {}",
|
| 29 |
+
"a rendition of the {}",
|
| 30 |
+
"a photo of the clean {}",
|
| 31 |
+
"a rendition of a {}",
|
| 32 |
+
"a photo of a nice {}",
|
| 33 |
+
"a good photo of a {}",
|
| 34 |
+
"a photo of the nice {}",
|
| 35 |
+
"a photo of the small {}",
|
| 36 |
+
"a photo of the weird {}",
|
| 37 |
+
"a photo of the large {}",
|
| 38 |
+
"a photo of a cool {}",
|
| 39 |
+
"a photo of a small {}",
|
| 40 |
+
]
|
| 41 |
+
|
| 42 |
+
STYLE_TEMPLATE = [
|
| 43 |
+
"a painting in the style of {}",
|
| 44 |
+
"a rendering in the style of {}",
|
| 45 |
+
"a cropped painting in the style of {}",
|
| 46 |
+
"the painting in the style of {}",
|
| 47 |
+
"a clean painting in the style of {}",
|
| 48 |
+
"a dirty painting in the style of {}",
|
| 49 |
+
"a dark painting in the style of {}",
|
| 50 |
+
"a picture in the style of {}",
|
| 51 |
+
"a cool painting in the style of {}",
|
| 52 |
+
"a close-up painting in the style of {}",
|
| 53 |
+
"a bright painting in the style of {}",
|
| 54 |
+
"a cropped painting in the style of {}",
|
| 55 |
+
"a good painting in the style of {}",
|
| 56 |
+
"a close-up painting in the style of {}",
|
| 57 |
+
"a rendition in the style of {}",
|
| 58 |
+
"a nice painting in the style of {}",
|
| 59 |
+
"a small painting in the style of {}",
|
| 60 |
+
"a weird painting in the style of {}",
|
| 61 |
+
"a large painting in the style of {}",
|
| 62 |
+
]
|
| 63 |
+
|
| 64 |
+
NULL_TEMPLATE = ["{}"]
|
| 65 |
+
|
| 66 |
+
TEMPLATE_MAP = {
|
| 67 |
+
"object": OBJECT_TEMPLATE,
|
| 68 |
+
"style": STYLE_TEMPLATE,
|
| 69 |
+
"null": NULL_TEMPLATE,
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _randomset(lis):
|
| 74 |
+
ret = []
|
| 75 |
+
for i in range(len(lis)):
|
| 76 |
+
if random.random() < 0.5:
|
| 77 |
+
ret.append(lis[i])
|
| 78 |
+
return ret
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def _shuffle(lis):
|
| 82 |
+
|
| 83 |
+
return random.sample(lis, len(lis))
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def _get_cutout_holes(
|
| 87 |
+
height,
|
| 88 |
+
width,
|
| 89 |
+
min_holes=8,
|
| 90 |
+
max_holes=32,
|
| 91 |
+
min_height=16,
|
| 92 |
+
max_height=128,
|
| 93 |
+
min_width=16,
|
| 94 |
+
max_width=128,
|
| 95 |
+
):
|
| 96 |
+
holes = []
|
| 97 |
+
for _n in range(random.randint(min_holes, max_holes)):
|
| 98 |
+
hole_height = random.randint(min_height, max_height)
|
| 99 |
+
hole_width = random.randint(min_width, max_width)
|
| 100 |
+
y1 = random.randint(0, height - hole_height)
|
| 101 |
+
x1 = random.randint(0, width - hole_width)
|
| 102 |
+
y2 = y1 + hole_height
|
| 103 |
+
x2 = x1 + hole_width
|
| 104 |
+
holes.append((x1, y1, x2, y2))
|
| 105 |
+
return holes
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def _generate_random_mask(image):
|
| 109 |
+
mask = zeros_like(image[:1])
|
| 110 |
+
holes = _get_cutout_holes(mask.shape[1], mask.shape[2])
|
| 111 |
+
for (x1, y1, x2, y2) in holes:
|
| 112 |
+
mask[:, y1:y2, x1:x2] = 1.0
|
| 113 |
+
if random.uniform(0, 1) < 0.25:
|
| 114 |
+
mask.fill_(1.0)
|
| 115 |
+
masked_image = image * (mask < 0.5)
|
| 116 |
+
return mask, masked_image
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class PivotalTuningDatasetCapation(Dataset):
|
| 120 |
+
"""
|
| 121 |
+
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
|
| 122 |
+
It pre-processes the images and the tokenizes prompts.
|
| 123 |
+
"""
|
| 124 |
+
|
| 125 |
+
def __init__(
|
| 126 |
+
self,
|
| 127 |
+
instance_data_root,
|
| 128 |
+
tokenizer,
|
| 129 |
+
token_map: Optional[dict] = None,
|
| 130 |
+
use_template: Optional[str] = None,
|
| 131 |
+
size=512,
|
| 132 |
+
h_flip=True,
|
| 133 |
+
color_jitter=False,
|
| 134 |
+
resize=True,
|
| 135 |
+
use_mask_captioned_data=False,
|
| 136 |
+
use_face_segmentation_condition=False,
|
| 137 |
+
train_inpainting=False,
|
| 138 |
+
blur_amount: int = 70,
|
| 139 |
+
):
|
| 140 |
+
self.size = size
|
| 141 |
+
self.tokenizer = tokenizer
|
| 142 |
+
self.resize = resize
|
| 143 |
+
self.train_inpainting = train_inpainting
|
| 144 |
+
|
| 145 |
+
instance_data_root = Path(instance_data_root)
|
| 146 |
+
if not instance_data_root.exists():
|
| 147 |
+
raise ValueError("Instance images root doesn't exists.")
|
| 148 |
+
|
| 149 |
+
self.instance_images_path = []
|
| 150 |
+
self.mask_path = []
|
| 151 |
+
|
| 152 |
+
assert not (
|
| 153 |
+
use_mask_captioned_data and use_template
|
| 154 |
+
), "Can't use both mask caption data and template."
|
| 155 |
+
|
| 156 |
+
# Prepare the instance images
|
| 157 |
+
if use_mask_captioned_data:
|
| 158 |
+
src_imgs = glob.glob(str(instance_data_root) + "/*src.jpg")
|
| 159 |
+
for f in src_imgs:
|
| 160 |
+
idx = int(str(Path(f).stem).split(".")[0])
|
| 161 |
+
mask_path = f"{instance_data_root}/{idx}.mask.png"
|
| 162 |
+
|
| 163 |
+
if Path(mask_path).exists():
|
| 164 |
+
self.instance_images_path.append(f)
|
| 165 |
+
self.mask_path.append(mask_path)
|
| 166 |
+
else:
|
| 167 |
+
print(f"Mask not found for {f}")
|
| 168 |
+
|
| 169 |
+
self.captions = open(f"{instance_data_root}/caption.txt").readlines()
|
| 170 |
+
|
| 171 |
+
else:
|
| 172 |
+
possibily_src_images = (
|
| 173 |
+
glob.glob(str(instance_data_root) + "/*.jpg")
|
| 174 |
+
+ glob.glob(str(instance_data_root) + "/*.png")
|
| 175 |
+
+ glob.glob(str(instance_data_root) + "/*.jpeg")
|
| 176 |
+
)
|
| 177 |
+
possibily_src_images = (
|
| 178 |
+
set(possibily_src_images)
|
| 179 |
+
- set(glob.glob(str(instance_data_root) + "/*mask.png"))
|
| 180 |
+
- set([str(instance_data_root) + "/caption.txt"])
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
self.instance_images_path = list(set(possibily_src_images))
|
| 184 |
+
self.captions = [
|
| 185 |
+
x.split("/")[-1].split(".")[0] for x in self.instance_images_path
|
| 186 |
+
]
|
| 187 |
+
|
| 188 |
+
assert (
|
| 189 |
+
len(self.instance_images_path) > 0
|
| 190 |
+
), "No images found in the instance data root."
|
| 191 |
+
|
| 192 |
+
self.instance_images_path = sorted(self.instance_images_path)
|
| 193 |
+
|
| 194 |
+
self.use_mask = use_face_segmentation_condition or use_mask_captioned_data
|
| 195 |
+
self.use_mask_captioned_data = use_mask_captioned_data
|
| 196 |
+
|
| 197 |
+
if use_face_segmentation_condition:
|
| 198 |
+
|
| 199 |
+
for idx in range(len(self.instance_images_path)):
|
| 200 |
+
targ = f"{instance_data_root}/{idx}.mask.png"
|
| 201 |
+
# see if the mask exists
|
| 202 |
+
if not Path(targ).exists():
|
| 203 |
+
print(f"Mask not found for {targ}")
|
| 204 |
+
|
| 205 |
+
print(
|
| 206 |
+
"Warning : this will pre-process all the images in the instance data root."
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
if len(self.mask_path) > 0:
|
| 210 |
+
print(
|
| 211 |
+
"Warning : masks already exists, but will be overwritten."
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
masks = face_mask_google_mediapipe(
|
| 215 |
+
[
|
| 216 |
+
Image.open(f).convert("RGB")
|
| 217 |
+
for f in self.instance_images_path
|
| 218 |
+
]
|
| 219 |
+
)
|
| 220 |
+
for idx, mask in enumerate(masks):
|
| 221 |
+
mask.save(f"{instance_data_root}/{idx}.mask.png")
|
| 222 |
+
|
| 223 |
+
break
|
| 224 |
+
|
| 225 |
+
for idx in range(len(self.instance_images_path)):
|
| 226 |
+
self.mask_path.append(f"{instance_data_root}/{idx}.mask.png")
|
| 227 |
+
|
| 228 |
+
self.num_instance_images = len(self.instance_images_path)
|
| 229 |
+
self.token_map = token_map
|
| 230 |
+
|
| 231 |
+
self.use_template = use_template
|
| 232 |
+
if use_template is not None:
|
| 233 |
+
self.templates = TEMPLATE_MAP[use_template]
|
| 234 |
+
|
| 235 |
+
self._length = self.num_instance_images
|
| 236 |
+
|
| 237 |
+
self.h_flip = h_flip
|
| 238 |
+
self.image_transforms = transforms.Compose(
|
| 239 |
+
[
|
| 240 |
+
transforms.Resize(
|
| 241 |
+
size, interpolation=transforms.InterpolationMode.BILINEAR
|
| 242 |
+
)
|
| 243 |
+
if resize
|
| 244 |
+
else transforms.Lambda(lambda x: x),
|
| 245 |
+
transforms.ColorJitter(0.1, 0.1)
|
| 246 |
+
if color_jitter
|
| 247 |
+
else transforms.Lambda(lambda x: x),
|
| 248 |
+
transforms.CenterCrop(size),
|
| 249 |
+
transforms.ToTensor(),
|
| 250 |
+
transforms.Normalize([0.5], [0.5]),
|
| 251 |
+
]
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
self.blur_amount = blur_amount
|
| 255 |
+
|
| 256 |
+
def __len__(self):
|
| 257 |
+
return self._length
|
| 258 |
+
|
| 259 |
+
def __getitem__(self, index):
|
| 260 |
+
example = {}
|
| 261 |
+
instance_image = Image.open(
|
| 262 |
+
self.instance_images_path[index % self.num_instance_images]
|
| 263 |
+
)
|
| 264 |
+
if not instance_image.mode == "RGB":
|
| 265 |
+
instance_image = instance_image.convert("RGB")
|
| 266 |
+
example["instance_images"] = self.image_transforms(instance_image)
|
| 267 |
+
|
| 268 |
+
if self.train_inpainting:
|
| 269 |
+
(
|
| 270 |
+
example["instance_masks"],
|
| 271 |
+
example["instance_masked_images"],
|
| 272 |
+
) = _generate_random_mask(example["instance_images"])
|
| 273 |
+
|
| 274 |
+
if self.use_template:
|
| 275 |
+
assert self.token_map is not None
|
| 276 |
+
input_tok = list(self.token_map.values())[0]
|
| 277 |
+
|
| 278 |
+
text = random.choice(self.templates).format(input_tok)
|
| 279 |
+
else:
|
| 280 |
+
text = self.captions[index % self.num_instance_images].strip()
|
| 281 |
+
|
| 282 |
+
if self.token_map is not None:
|
| 283 |
+
for token, value in self.token_map.items():
|
| 284 |
+
text = text.replace(token, value)
|
| 285 |
+
|
| 286 |
+
print(text)
|
| 287 |
+
|
| 288 |
+
if self.use_mask:
|
| 289 |
+
example["mask"] = (
|
| 290 |
+
self.image_transforms(
|
| 291 |
+
Image.open(self.mask_path[index % self.num_instance_images])
|
| 292 |
+
)
|
| 293 |
+
* 0.5
|
| 294 |
+
+ 1.0
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
if self.h_flip and random.random() > 0.5:
|
| 298 |
+
hflip = transforms.RandomHorizontalFlip(p=1)
|
| 299 |
+
|
| 300 |
+
example["instance_images"] = hflip(example["instance_images"])
|
| 301 |
+
if self.use_mask:
|
| 302 |
+
example["mask"] = hflip(example["mask"])
|
| 303 |
+
|
| 304 |
+
example["instance_prompt_ids"] = self.tokenizer(
|
| 305 |
+
text,
|
| 306 |
+
padding="do_not_pad",
|
| 307 |
+
truncation=True,
|
| 308 |
+
max_length=self.tokenizer.model_max_length,
|
| 309 |
+
).input_ids
|
| 310 |
+
|
| 311 |
+
return example
|
lora_diffusion/lora.py
ADDED
|
@@ -0,0 +1,1110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import math
|
| 3 |
+
from itertools import groupby
|
| 4 |
+
from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import PIL
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
from safetensors.torch import safe_open
|
| 14 |
+
from safetensors.torch import save_file as safe_save
|
| 15 |
+
|
| 16 |
+
safetensors_available = True
|
| 17 |
+
except ImportError:
|
| 18 |
+
from .safe_open import safe_open
|
| 19 |
+
|
| 20 |
+
def safe_save(
|
| 21 |
+
tensors: Dict[str, torch.Tensor],
|
| 22 |
+
filename: str,
|
| 23 |
+
metadata: Optional[Dict[str, str]] = None,
|
| 24 |
+
) -> None:
|
| 25 |
+
raise EnvironmentError(
|
| 26 |
+
"Saving safetensors requires the safetensors library. Please install with pip or similar."
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
safetensors_available = False
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class LoraInjectedLinear(nn.Module):
|
| 33 |
+
def __init__(
|
| 34 |
+
self, in_features, out_features, bias=False, r=4, dropout_p=0.1, scale=1.0
|
| 35 |
+
):
|
| 36 |
+
super().__init__()
|
| 37 |
+
|
| 38 |
+
if r > min(in_features, out_features):
|
| 39 |
+
raise ValueError(
|
| 40 |
+
f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
|
| 41 |
+
)
|
| 42 |
+
self.r = r
|
| 43 |
+
self.linear = nn.Linear(in_features, out_features, bias)
|
| 44 |
+
self.lora_down = nn.Linear(in_features, r, bias=False)
|
| 45 |
+
self.dropout = nn.Dropout(dropout_p)
|
| 46 |
+
self.lora_up = nn.Linear(r, out_features, bias=False)
|
| 47 |
+
self.scale = scale
|
| 48 |
+
self.selector = nn.Identity()
|
| 49 |
+
|
| 50 |
+
nn.init.normal_(self.lora_down.weight, std=1 / r)
|
| 51 |
+
nn.init.zeros_(self.lora_up.weight)
|
| 52 |
+
|
| 53 |
+
def forward(self, input):
|
| 54 |
+
return (
|
| 55 |
+
self.linear(input)
|
| 56 |
+
+ self.dropout(self.lora_up(self.selector(self.lora_down(input))))
|
| 57 |
+
* self.scale
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
def realize_as_lora(self):
|
| 61 |
+
return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
|
| 62 |
+
|
| 63 |
+
def set_selector_from_diag(self, diag: torch.Tensor):
|
| 64 |
+
# diag is a 1D tensor of size (r,)
|
| 65 |
+
assert diag.shape == (self.r,)
|
| 66 |
+
self.selector = nn.Linear(self.r, self.r, bias=False)
|
| 67 |
+
self.selector.weight.data = torch.diag(diag)
|
| 68 |
+
self.selector.weight.data = self.selector.weight.data.to(
|
| 69 |
+
self.lora_up.weight.device
|
| 70 |
+
).to(self.lora_up.weight.dtype)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class LoraInjectedConv2d(nn.Module):
|
| 74 |
+
def __init__(
|
| 75 |
+
self,
|
| 76 |
+
in_channels: int,
|
| 77 |
+
out_channels: int,
|
| 78 |
+
kernel_size,
|
| 79 |
+
stride=1,
|
| 80 |
+
padding=0,
|
| 81 |
+
dilation=1,
|
| 82 |
+
groups: int = 1,
|
| 83 |
+
bias: bool = True,
|
| 84 |
+
r: int = 4,
|
| 85 |
+
dropout_p: float = 0.1,
|
| 86 |
+
scale: float = 1.0,
|
| 87 |
+
):
|
| 88 |
+
super().__init__()
|
| 89 |
+
if r > min(in_channels, out_channels):
|
| 90 |
+
raise ValueError(
|
| 91 |
+
f"LoRA rank {r} must be less or equal than {min(in_channels, out_channels)}"
|
| 92 |
+
)
|
| 93 |
+
self.r = r
|
| 94 |
+
self.conv = nn.Conv2d(
|
| 95 |
+
in_channels=in_channels,
|
| 96 |
+
out_channels=out_channels,
|
| 97 |
+
kernel_size=kernel_size,
|
| 98 |
+
stride=stride,
|
| 99 |
+
padding=padding,
|
| 100 |
+
dilation=dilation,
|
| 101 |
+
groups=groups,
|
| 102 |
+
bias=bias,
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
self.lora_down = nn.Conv2d(
|
| 106 |
+
in_channels=in_channels,
|
| 107 |
+
out_channels=r,
|
| 108 |
+
kernel_size=kernel_size,
|
| 109 |
+
stride=stride,
|
| 110 |
+
padding=padding,
|
| 111 |
+
dilation=dilation,
|
| 112 |
+
groups=groups,
|
| 113 |
+
bias=False,
|
| 114 |
+
)
|
| 115 |
+
self.dropout = nn.Dropout(dropout_p)
|
| 116 |
+
self.lora_up = nn.Conv2d(
|
| 117 |
+
in_channels=r,
|
| 118 |
+
out_channels=out_channels,
|
| 119 |
+
kernel_size=1,
|
| 120 |
+
stride=1,
|
| 121 |
+
padding=0,
|
| 122 |
+
bias=False,
|
| 123 |
+
)
|
| 124 |
+
self.selector = nn.Identity()
|
| 125 |
+
self.scale = scale
|
| 126 |
+
|
| 127 |
+
nn.init.normal_(self.lora_down.weight, std=1 / r)
|
| 128 |
+
nn.init.zeros_(self.lora_up.weight)
|
| 129 |
+
|
| 130 |
+
def forward(self, input):
|
| 131 |
+
return (
|
| 132 |
+
self.conv(input)
|
| 133 |
+
+ self.dropout(self.lora_up(self.selector(self.lora_down(input))))
|
| 134 |
+
* self.scale
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
def realize_as_lora(self):
|
| 138 |
+
return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
|
| 139 |
+
|
| 140 |
+
def set_selector_from_diag(self, diag: torch.Tensor):
|
| 141 |
+
# diag is a 1D tensor of size (r,)
|
| 142 |
+
assert diag.shape == (self.r,)
|
| 143 |
+
self.selector = nn.Conv2d(
|
| 144 |
+
in_channels=self.r,
|
| 145 |
+
out_channels=self.r,
|
| 146 |
+
kernel_size=1,
|
| 147 |
+
stride=1,
|
| 148 |
+
padding=0,
|
| 149 |
+
bias=False,
|
| 150 |
+
)
|
| 151 |
+
self.selector.weight.data = torch.diag(diag)
|
| 152 |
+
|
| 153 |
+
# same device + dtype as lora_up
|
| 154 |
+
self.selector.weight.data = self.selector.weight.data.to(
|
| 155 |
+
self.lora_up.weight.device
|
| 156 |
+
).to(self.lora_up.weight.dtype)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"}
|
| 160 |
+
|
| 161 |
+
UNET_EXTENDED_TARGET_REPLACE = {"ResnetBlock2D", "CrossAttention", "Attention", "GEGLU"}
|
| 162 |
+
|
| 163 |
+
TEXT_ENCODER_DEFAULT_TARGET_REPLACE = {"CLIPAttention"}
|
| 164 |
+
|
| 165 |
+
TEXT_ENCODER_EXTENDED_TARGET_REPLACE = {"CLIPAttention"}
|
| 166 |
+
|
| 167 |
+
DEFAULT_TARGET_REPLACE = UNET_DEFAULT_TARGET_REPLACE
|
| 168 |
+
|
| 169 |
+
EMBED_FLAG = "<embed>"
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def _find_children(
|
| 173 |
+
model,
|
| 174 |
+
search_class: List[Type[nn.Module]] = [nn.Linear],
|
| 175 |
+
):
|
| 176 |
+
"""
|
| 177 |
+
Find all modules of a certain class (or union of classes).
|
| 178 |
+
|
| 179 |
+
Returns all matching modules, along with the parent of those moduless and the
|
| 180 |
+
names they are referenced by.
|
| 181 |
+
"""
|
| 182 |
+
# For each target find every linear_class module that isn't a child of a LoraInjectedLinear
|
| 183 |
+
for parent in model.modules():
|
| 184 |
+
for name, module in parent.named_children():
|
| 185 |
+
if any([isinstance(module, _class) for _class in search_class]):
|
| 186 |
+
yield parent, name, module
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def _find_modules_v2(
|
| 190 |
+
model,
|
| 191 |
+
ancestor_class: Optional[Set[str]] = None,
|
| 192 |
+
search_class: List[Type[nn.Module]] = [nn.Linear],
|
| 193 |
+
exclude_children_of: Optional[List[Type[nn.Module]]] = [
|
| 194 |
+
LoraInjectedLinear,
|
| 195 |
+
LoraInjectedConv2d,
|
| 196 |
+
],
|
| 197 |
+
):
|
| 198 |
+
"""
|
| 199 |
+
Find all modules of a certain class (or union of classes) that are direct or
|
| 200 |
+
indirect descendants of other modules of a certain class (or union of classes).
|
| 201 |
+
|
| 202 |
+
Returns all matching modules, along with the parent of those moduless and the
|
| 203 |
+
names they are referenced by.
|
| 204 |
+
"""
|
| 205 |
+
|
| 206 |
+
# Get the targets we should replace all linears under
|
| 207 |
+
if ancestor_class is not None:
|
| 208 |
+
ancestors = (
|
| 209 |
+
module
|
| 210 |
+
for module in model.modules()
|
| 211 |
+
if module.__class__.__name__ in ancestor_class
|
| 212 |
+
)
|
| 213 |
+
else:
|
| 214 |
+
# this, incase you want to naively iterate over all modules.
|
| 215 |
+
ancestors = [module for module in model.modules()]
|
| 216 |
+
|
| 217 |
+
# For each target find every linear_class module that isn't a child of a LoraInjectedLinear
|
| 218 |
+
for ancestor in ancestors:
|
| 219 |
+
for fullname, module in ancestor.named_modules():
|
| 220 |
+
if any([isinstance(module, _class) for _class in search_class]):
|
| 221 |
+
# Find the direct parent if this is a descendant, not a child, of target
|
| 222 |
+
*path, name = fullname.split(".")
|
| 223 |
+
parent = ancestor
|
| 224 |
+
while path:
|
| 225 |
+
parent = parent.get_submodule(path.pop(0))
|
| 226 |
+
# Skip this linear if it's a child of a LoraInjectedLinear
|
| 227 |
+
if exclude_children_of and any(
|
| 228 |
+
[isinstance(parent, _class) for _class in exclude_children_of]
|
| 229 |
+
):
|
| 230 |
+
continue
|
| 231 |
+
# Otherwise, yield it
|
| 232 |
+
yield parent, name, module
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def _find_modules_old(
|
| 236 |
+
model,
|
| 237 |
+
ancestor_class: Set[str] = DEFAULT_TARGET_REPLACE,
|
| 238 |
+
search_class: List[Type[nn.Module]] = [nn.Linear],
|
| 239 |
+
exclude_children_of: Optional[List[Type[nn.Module]]] = [LoraInjectedLinear],
|
| 240 |
+
):
|
| 241 |
+
ret = []
|
| 242 |
+
for _module in model.modules():
|
| 243 |
+
if _module.__class__.__name__ in ancestor_class:
|
| 244 |
+
|
| 245 |
+
for name, _child_module in _module.named_modules():
|
| 246 |
+
if _child_module.__class__ in search_class:
|
| 247 |
+
ret.append((_module, name, _child_module))
|
| 248 |
+
print(ret)
|
| 249 |
+
return ret
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
_find_modules = _find_modules_v2
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def inject_trainable_lora(
|
| 256 |
+
model: nn.Module,
|
| 257 |
+
target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE,
|
| 258 |
+
r: int = 4,
|
| 259 |
+
loras=None, # path to lora .pt
|
| 260 |
+
verbose: bool = False,
|
| 261 |
+
dropout_p: float = 0.0,
|
| 262 |
+
scale: float = 1.0,
|
| 263 |
+
):
|
| 264 |
+
"""
|
| 265 |
+
inject lora into model, and returns lora parameter groups.
|
| 266 |
+
"""
|
| 267 |
+
|
| 268 |
+
require_grad_params = []
|
| 269 |
+
names = []
|
| 270 |
+
|
| 271 |
+
if loras != None:
|
| 272 |
+
loras = torch.load(loras)
|
| 273 |
+
|
| 274 |
+
for _module, name, _child_module in _find_modules(
|
| 275 |
+
model, target_replace_module, search_class=[nn.Linear]
|
| 276 |
+
):
|
| 277 |
+
weight = _child_module.weight
|
| 278 |
+
bias = _child_module.bias
|
| 279 |
+
if verbose:
|
| 280 |
+
print("LoRA Injection : injecting lora into ", name)
|
| 281 |
+
print("LoRA Injection : weight shape", weight.shape)
|
| 282 |
+
_tmp = LoraInjectedLinear(
|
| 283 |
+
_child_module.in_features,
|
| 284 |
+
_child_module.out_features,
|
| 285 |
+
_child_module.bias is not None,
|
| 286 |
+
r=r,
|
| 287 |
+
dropout_p=dropout_p,
|
| 288 |
+
scale=scale,
|
| 289 |
+
)
|
| 290 |
+
_tmp.linear.weight = weight
|
| 291 |
+
if bias is not None:
|
| 292 |
+
_tmp.linear.bias = bias
|
| 293 |
+
|
| 294 |
+
# switch the module
|
| 295 |
+
_tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
|
| 296 |
+
_module._modules[name] = _tmp
|
| 297 |
+
|
| 298 |
+
require_grad_params.append(_module._modules[name].lora_up.parameters())
|
| 299 |
+
require_grad_params.append(_module._modules[name].lora_down.parameters())
|
| 300 |
+
|
| 301 |
+
if loras != None:
|
| 302 |
+
_module._modules[name].lora_up.weight = loras.pop(0)
|
| 303 |
+
_module._modules[name].lora_down.weight = loras.pop(0)
|
| 304 |
+
|
| 305 |
+
_module._modules[name].lora_up.weight.requires_grad = True
|
| 306 |
+
_module._modules[name].lora_down.weight.requires_grad = True
|
| 307 |
+
names.append(name)
|
| 308 |
+
|
| 309 |
+
return require_grad_params, names
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def inject_trainable_lora_extended(
|
| 313 |
+
model: nn.Module,
|
| 314 |
+
target_replace_module: Set[str] = UNET_EXTENDED_TARGET_REPLACE,
|
| 315 |
+
r: int = 4,
|
| 316 |
+
loras=None, # path to lora .pt
|
| 317 |
+
):
|
| 318 |
+
"""
|
| 319 |
+
inject lora into model, and returns lora parameter groups.
|
| 320 |
+
"""
|
| 321 |
+
|
| 322 |
+
require_grad_params = []
|
| 323 |
+
names = []
|
| 324 |
+
|
| 325 |
+
if loras != None:
|
| 326 |
+
loras = torch.load(loras)
|
| 327 |
+
|
| 328 |
+
for _module, name, _child_module in _find_modules(
|
| 329 |
+
model, target_replace_module, search_class=[nn.Linear, nn.Conv2d]
|
| 330 |
+
):
|
| 331 |
+
if _child_module.__class__ == nn.Linear:
|
| 332 |
+
weight = _child_module.weight
|
| 333 |
+
bias = _child_module.bias
|
| 334 |
+
_tmp = LoraInjectedLinear(
|
| 335 |
+
_child_module.in_features,
|
| 336 |
+
_child_module.out_features,
|
| 337 |
+
_child_module.bias is not None,
|
| 338 |
+
r=r,
|
| 339 |
+
)
|
| 340 |
+
_tmp.linear.weight = weight
|
| 341 |
+
if bias is not None:
|
| 342 |
+
_tmp.linear.bias = bias
|
| 343 |
+
elif _child_module.__class__ == nn.Conv2d:
|
| 344 |
+
weight = _child_module.weight
|
| 345 |
+
bias = _child_module.bias
|
| 346 |
+
_tmp = LoraInjectedConv2d(
|
| 347 |
+
_child_module.in_channels,
|
| 348 |
+
_child_module.out_channels,
|
| 349 |
+
_child_module.kernel_size,
|
| 350 |
+
_child_module.stride,
|
| 351 |
+
_child_module.padding,
|
| 352 |
+
_child_module.dilation,
|
| 353 |
+
_child_module.groups,
|
| 354 |
+
_child_module.bias is not None,
|
| 355 |
+
r=r,
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
_tmp.conv.weight = weight
|
| 359 |
+
if bias is not None:
|
| 360 |
+
_tmp.conv.bias = bias
|
| 361 |
+
|
| 362 |
+
# switch the module
|
| 363 |
+
_tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
|
| 364 |
+
if bias is not None:
|
| 365 |
+
_tmp.to(_child_module.bias.device).to(_child_module.bias.dtype)
|
| 366 |
+
|
| 367 |
+
_module._modules[name] = _tmp
|
| 368 |
+
|
| 369 |
+
require_grad_params.append(_module._modules[name].lora_up.parameters())
|
| 370 |
+
require_grad_params.append(_module._modules[name].lora_down.parameters())
|
| 371 |
+
|
| 372 |
+
if loras != None:
|
| 373 |
+
_module._modules[name].lora_up.weight = loras.pop(0)
|
| 374 |
+
_module._modules[name].lora_down.weight = loras.pop(0)
|
| 375 |
+
|
| 376 |
+
_module._modules[name].lora_up.weight.requires_grad = True
|
| 377 |
+
_module._modules[name].lora_down.weight.requires_grad = True
|
| 378 |
+
names.append(name)
|
| 379 |
+
|
| 380 |
+
return require_grad_params, names
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def extract_lora_ups_down(model, target_replace_module=DEFAULT_TARGET_REPLACE):
|
| 384 |
+
|
| 385 |
+
loras = []
|
| 386 |
+
|
| 387 |
+
for _m, _n, _child_module in _find_modules(
|
| 388 |
+
model,
|
| 389 |
+
target_replace_module,
|
| 390 |
+
search_class=[LoraInjectedLinear, LoraInjectedConv2d],
|
| 391 |
+
):
|
| 392 |
+
loras.append((_child_module.lora_up, _child_module.lora_down))
|
| 393 |
+
|
| 394 |
+
if len(loras) == 0:
|
| 395 |
+
raise ValueError("No lora injected.")
|
| 396 |
+
|
| 397 |
+
return loras
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
def extract_lora_as_tensor(
|
| 401 |
+
model, target_replace_module=DEFAULT_TARGET_REPLACE, as_fp16=True
|
| 402 |
+
):
|
| 403 |
+
|
| 404 |
+
loras = []
|
| 405 |
+
|
| 406 |
+
for _m, _n, _child_module in _find_modules(
|
| 407 |
+
model,
|
| 408 |
+
target_replace_module,
|
| 409 |
+
search_class=[LoraInjectedLinear, LoraInjectedConv2d],
|
| 410 |
+
):
|
| 411 |
+
up, down = _child_module.realize_as_lora()
|
| 412 |
+
if as_fp16:
|
| 413 |
+
up = up.to(torch.float16)
|
| 414 |
+
down = down.to(torch.float16)
|
| 415 |
+
|
| 416 |
+
loras.append((up, down))
|
| 417 |
+
|
| 418 |
+
if len(loras) == 0:
|
| 419 |
+
raise ValueError("No lora injected.")
|
| 420 |
+
|
| 421 |
+
return loras
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
def save_lora_weight(
|
| 425 |
+
model,
|
| 426 |
+
path="./lora.pt",
|
| 427 |
+
target_replace_module=DEFAULT_TARGET_REPLACE,
|
| 428 |
+
):
|
| 429 |
+
weights = []
|
| 430 |
+
for _up, _down in extract_lora_ups_down(
|
| 431 |
+
model, target_replace_module=target_replace_module
|
| 432 |
+
):
|
| 433 |
+
weights.append(_up.weight.to("cpu").to(torch.float16))
|
| 434 |
+
weights.append(_down.weight.to("cpu").to(torch.float16))
|
| 435 |
+
|
| 436 |
+
torch.save(weights, path)
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
def save_lora_as_json(model, path="./lora.json"):
|
| 440 |
+
weights = []
|
| 441 |
+
for _up, _down in extract_lora_ups_down(model):
|
| 442 |
+
weights.append(_up.weight.detach().cpu().numpy().tolist())
|
| 443 |
+
weights.append(_down.weight.detach().cpu().numpy().tolist())
|
| 444 |
+
|
| 445 |
+
import json
|
| 446 |
+
|
| 447 |
+
with open(path, "w") as f:
|
| 448 |
+
json.dump(weights, f)
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
def save_safeloras_with_embeds(
|
| 452 |
+
modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {},
|
| 453 |
+
embeds: Dict[str, torch.Tensor] = {},
|
| 454 |
+
outpath="./lora.safetensors",
|
| 455 |
+
):
|
| 456 |
+
"""
|
| 457 |
+
Saves the Lora from multiple modules in a single safetensor file.
|
| 458 |
+
|
| 459 |
+
modelmap is a dictionary of {
|
| 460 |
+
"module name": (module, target_replace_module)
|
| 461 |
+
}
|
| 462 |
+
"""
|
| 463 |
+
weights = {}
|
| 464 |
+
metadata = {}
|
| 465 |
+
|
| 466 |
+
for name, (model, target_replace_module) in modelmap.items():
|
| 467 |
+
metadata[name] = json.dumps(list(target_replace_module))
|
| 468 |
+
|
| 469 |
+
for i, (_up, _down) in enumerate(
|
| 470 |
+
extract_lora_as_tensor(model, target_replace_module)
|
| 471 |
+
):
|
| 472 |
+
rank = _down.shape[0]
|
| 473 |
+
|
| 474 |
+
metadata[f"{name}:{i}:rank"] = str(rank)
|
| 475 |
+
weights[f"{name}:{i}:up"] = _up
|
| 476 |
+
weights[f"{name}:{i}:down"] = _down
|
| 477 |
+
|
| 478 |
+
for token, tensor in embeds.items():
|
| 479 |
+
metadata[token] = EMBED_FLAG
|
| 480 |
+
weights[token] = tensor
|
| 481 |
+
|
| 482 |
+
print(f"Saving weights to {outpath}")
|
| 483 |
+
safe_save(weights, outpath, metadata)
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
def save_safeloras(
|
| 487 |
+
modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {},
|
| 488 |
+
outpath="./lora.safetensors",
|
| 489 |
+
):
|
| 490 |
+
return save_safeloras_with_embeds(modelmap=modelmap, outpath=outpath)
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
def convert_loras_to_safeloras_with_embeds(
|
| 494 |
+
modelmap: Dict[str, Tuple[str, Set[str], int]] = {},
|
| 495 |
+
embeds: Dict[str, torch.Tensor] = {},
|
| 496 |
+
outpath="./lora.safetensors",
|
| 497 |
+
):
|
| 498 |
+
"""
|
| 499 |
+
Converts the Lora from multiple pytorch .pt files into a single safetensor file.
|
| 500 |
+
|
| 501 |
+
modelmap is a dictionary of {
|
| 502 |
+
"module name": (pytorch_model_path, target_replace_module, rank)
|
| 503 |
+
}
|
| 504 |
+
"""
|
| 505 |
+
|
| 506 |
+
weights = {}
|
| 507 |
+
metadata = {}
|
| 508 |
+
|
| 509 |
+
for name, (path, target_replace_module, r) in modelmap.items():
|
| 510 |
+
metadata[name] = json.dumps(list(target_replace_module))
|
| 511 |
+
|
| 512 |
+
lora = torch.load(path)
|
| 513 |
+
for i, weight in enumerate(lora):
|
| 514 |
+
is_up = i % 2 == 0
|
| 515 |
+
i = i // 2
|
| 516 |
+
|
| 517 |
+
if is_up:
|
| 518 |
+
metadata[f"{name}:{i}:rank"] = str(r)
|
| 519 |
+
weights[f"{name}:{i}:up"] = weight
|
| 520 |
+
else:
|
| 521 |
+
weights[f"{name}:{i}:down"] = weight
|
| 522 |
+
|
| 523 |
+
for token, tensor in embeds.items():
|
| 524 |
+
metadata[token] = EMBED_FLAG
|
| 525 |
+
weights[token] = tensor
|
| 526 |
+
|
| 527 |
+
print(f"Saving weights to {outpath}")
|
| 528 |
+
safe_save(weights, outpath, metadata)
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
def convert_loras_to_safeloras(
|
| 532 |
+
modelmap: Dict[str, Tuple[str, Set[str], int]] = {},
|
| 533 |
+
outpath="./lora.safetensors",
|
| 534 |
+
):
|
| 535 |
+
convert_loras_to_safeloras_with_embeds(modelmap=modelmap, outpath=outpath)
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
def parse_safeloras(
|
| 539 |
+
safeloras,
|
| 540 |
+
) -> Dict[str, Tuple[List[nn.parameter.Parameter], List[int], List[str]]]:
|
| 541 |
+
"""
|
| 542 |
+
Converts a loaded safetensor file that contains a set of module Loras
|
| 543 |
+
into Parameters and other information
|
| 544 |
+
|
| 545 |
+
Output is a dictionary of {
|
| 546 |
+
"module name": (
|
| 547 |
+
[list of weights],
|
| 548 |
+
[list of ranks],
|
| 549 |
+
target_replacement_modules
|
| 550 |
+
)
|
| 551 |
+
}
|
| 552 |
+
"""
|
| 553 |
+
loras = {}
|
| 554 |
+
metadata = safeloras.metadata()
|
| 555 |
+
|
| 556 |
+
get_name = lambda k: k.split(":")[0]
|
| 557 |
+
|
| 558 |
+
keys = list(safeloras.keys())
|
| 559 |
+
keys.sort(key=get_name)
|
| 560 |
+
|
| 561 |
+
for name, module_keys in groupby(keys, get_name):
|
| 562 |
+
info = metadata.get(name)
|
| 563 |
+
|
| 564 |
+
if not info:
|
| 565 |
+
raise ValueError(
|
| 566 |
+
f"Tensor {name} has no metadata - is this a Lora safetensor?"
|
| 567 |
+
)
|
| 568 |
+
|
| 569 |
+
# Skip Textual Inversion embeds
|
| 570 |
+
if info == EMBED_FLAG:
|
| 571 |
+
continue
|
| 572 |
+
|
| 573 |
+
# Handle Loras
|
| 574 |
+
# Extract the targets
|
| 575 |
+
target = json.loads(info)
|
| 576 |
+
|
| 577 |
+
# Build the result lists - Python needs us to preallocate lists to insert into them
|
| 578 |
+
module_keys = list(module_keys)
|
| 579 |
+
ranks = [4] * (len(module_keys) // 2)
|
| 580 |
+
weights = [None] * len(module_keys)
|
| 581 |
+
|
| 582 |
+
for key in module_keys:
|
| 583 |
+
# Split the model name and index out of the key
|
| 584 |
+
_, idx, direction = key.split(":")
|
| 585 |
+
idx = int(idx)
|
| 586 |
+
|
| 587 |
+
# Add the rank
|
| 588 |
+
ranks[idx] = int(metadata[f"{name}:{idx}:rank"])
|
| 589 |
+
|
| 590 |
+
# Insert the weight into the list
|
| 591 |
+
idx = idx * 2 + (1 if direction == "down" else 0)
|
| 592 |
+
weights[idx] = nn.parameter.Parameter(safeloras.get_tensor(key))
|
| 593 |
+
|
| 594 |
+
loras[name] = (weights, ranks, target)
|
| 595 |
+
|
| 596 |
+
return loras
|
| 597 |
+
|
| 598 |
+
|
| 599 |
+
def parse_safeloras_embeds(
|
| 600 |
+
safeloras,
|
| 601 |
+
) -> Dict[str, torch.Tensor]:
|
| 602 |
+
"""
|
| 603 |
+
Converts a loaded safetensor file that contains Textual Inversion embeds into
|
| 604 |
+
a dictionary of embed_token: Tensor
|
| 605 |
+
"""
|
| 606 |
+
embeds = {}
|
| 607 |
+
metadata = safeloras.metadata()
|
| 608 |
+
|
| 609 |
+
for key in safeloras.keys():
|
| 610 |
+
# Only handle Textual Inversion embeds
|
| 611 |
+
meta = metadata.get(key)
|
| 612 |
+
if not meta or meta != EMBED_FLAG:
|
| 613 |
+
continue
|
| 614 |
+
|
| 615 |
+
embeds[key] = safeloras.get_tensor(key)
|
| 616 |
+
|
| 617 |
+
    return embeds


def load_safeloras(path, device="cpu"):
    safeloras = safe_open(path, framework="pt", device=device)
    return parse_safeloras(safeloras)


def load_safeloras_embeds(path, device="cpu"):
    safeloras = safe_open(path, framework="pt", device=device)
    return parse_safeloras_embeds(safeloras)


def load_safeloras_both(path, device="cpu"):
    safeloras = safe_open(path, framework="pt", device=device)
    return parse_safeloras(safeloras), parse_safeloras_embeds(safeloras)


def collapse_lora(model, alpha=1.0):

    for _module, name, _child_module in _find_modules(
        model,
        UNET_EXTENDED_TARGET_REPLACE | TEXT_ENCODER_EXTENDED_TARGET_REPLACE,
        search_class=[LoraInjectedLinear, LoraInjectedConv2d],
    ):

        if isinstance(_child_module, LoraInjectedLinear):
            print("Collapsing Lin Lora in", name)

            _child_module.linear.weight = nn.Parameter(
                _child_module.linear.weight.data
                + alpha
                * (
                    _child_module.lora_up.weight.data
                    @ _child_module.lora_down.weight.data
                )
                .type(_child_module.linear.weight.dtype)
                .to(_child_module.linear.weight.device)
            )

        else:
            print("Collapsing Conv Lora in", name)
            _child_module.conv.weight = nn.Parameter(
                _child_module.conv.weight.data
                + alpha
                * (
                    _child_module.lora_up.weight.data.flatten(start_dim=1)
                    @ _child_module.lora_down.weight.data.flatten(start_dim=1)
                )
                .reshape(_child_module.conv.weight.data.shape)
                .type(_child_module.conv.weight.dtype)
                .to(_child_module.conv.weight.device)
            )


def monkeypatch_or_replace_lora(
    model,
    loras,
    target_replace_module=DEFAULT_TARGET_REPLACE,
    r: Union[int, List[int]] = 4,
):
    for _module, name, _child_module in _find_modules(
        model, target_replace_module, search_class=[nn.Linear, LoraInjectedLinear]
    ):
        _source = (
            _child_module.linear
            if isinstance(_child_module, LoraInjectedLinear)
            else _child_module
        )

        weight = _source.weight
        bias = _source.bias
        _tmp = LoraInjectedLinear(
            _source.in_features,
            _source.out_features,
            _source.bias is not None,
            r=r.pop(0) if isinstance(r, list) else r,
        )
        _tmp.linear.weight = weight

        if bias is not None:
            _tmp.linear.bias = bias

        # switch the module
        _module._modules[name] = _tmp

        up_weight = loras.pop(0)
        down_weight = loras.pop(0)

        _module._modules[name].lora_up.weight = nn.Parameter(
            up_weight.type(weight.dtype)
        )
        _module._modules[name].lora_down.weight = nn.Parameter(
            down_weight.type(weight.dtype)
        )

        _module._modules[name].to(weight.device)


def monkeypatch_or_replace_lora_extended(
    model,
    loras,
    target_replace_module=DEFAULT_TARGET_REPLACE,
    r: Union[int, List[int]] = 4,
):
    for _module, name, _child_module in _find_modules(
        model,
        target_replace_module,
        search_class=[nn.Linear, LoraInjectedLinear, nn.Conv2d, LoraInjectedConv2d],
    ):

        if (_child_module.__class__ == nn.Linear) or (
            _child_module.__class__ == LoraInjectedLinear
        ):
            if len(loras[0].shape) != 2:
                continue

            _source = (
                _child_module.linear
                if isinstance(_child_module, LoraInjectedLinear)
                else _child_module
            )

            weight = _source.weight
            bias = _source.bias
            _tmp = LoraInjectedLinear(
                _source.in_features,
                _source.out_features,
                _source.bias is not None,
                r=r.pop(0) if isinstance(r, list) else r,
            )
            _tmp.linear.weight = weight

            if bias is not None:
                _tmp.linear.bias = bias

        elif (_child_module.__class__ == nn.Conv2d) or (
            _child_module.__class__ == LoraInjectedConv2d
        ):
            if len(loras[0].shape) != 4:
                continue
            _source = (
                _child_module.conv
                if isinstance(_child_module, LoraInjectedConv2d)
                else _child_module
            )

            weight = _source.weight
            bias = _source.bias
            _tmp = LoraInjectedConv2d(
                _source.in_channels,
                _source.out_channels,
                _source.kernel_size,
                _source.stride,
                _source.padding,
                _source.dilation,
                _source.groups,
                _source.bias is not None,
                r=r.pop(0) if isinstance(r, list) else r,
            )

            _tmp.conv.weight = weight

            if bias is not None:
                _tmp.conv.bias = bias

        # switch the module
        _module._modules[name] = _tmp

        up_weight = loras.pop(0)
        down_weight = loras.pop(0)

        _module._modules[name].lora_up.weight = nn.Parameter(
            up_weight.type(weight.dtype)
        )
        _module._modules[name].lora_down.weight = nn.Parameter(
            down_weight.type(weight.dtype)
        )

        _module._modules[name].to(weight.device)


def monkeypatch_or_replace_safeloras(models, safeloras):
    loras = parse_safeloras(safeloras)

    for name, (lora, ranks, target) in loras.items():
        model = getattr(models, name, None)

        if not model:
            print(f"No model provided for {name}, contained in Lora")
            continue

        monkeypatch_or_replace_lora_extended(model, lora, target, ranks)


def monkeypatch_remove_lora(model):
    for _module, name, _child_module in _find_modules(
        model, search_class=[LoraInjectedLinear, LoraInjectedConv2d]
    ):
        if isinstance(_child_module, LoraInjectedLinear):
            _source = _child_module.linear
            weight, bias = _source.weight, _source.bias

            _tmp = nn.Linear(
                _source.in_features, _source.out_features, bias is not None
            )

            _tmp.weight = weight
            if bias is not None:
                _tmp.bias = bias

        else:
            _source = _child_module.conv
            weight, bias = _source.weight, _source.bias

            _tmp = nn.Conv2d(
                in_channels=_source.in_channels,
                out_channels=_source.out_channels,
                kernel_size=_source.kernel_size,
                stride=_source.stride,
                padding=_source.padding,
                dilation=_source.dilation,
                groups=_source.groups,
                bias=bias is not None,
            )

            _tmp.weight = weight
            if bias is not None:
                _tmp.bias = bias

        _module._modules[name] = _tmp


def monkeypatch_add_lora(
    model,
    loras,
    target_replace_module=DEFAULT_TARGET_REPLACE,
    alpha: float = 1.0,
    beta: float = 1.0,
):
    for _module, name, _child_module in _find_modules(
        model, target_replace_module, search_class=[LoraInjectedLinear]
    ):
        weight = _child_module.linear.weight

        up_weight = loras.pop(0)
        down_weight = loras.pop(0)

        _module._modules[name].lora_up.weight = nn.Parameter(
            up_weight.type(weight.dtype).to(weight.device) * alpha
            + _module._modules[name].lora_up.weight.to(weight.device) * beta
        )
        _module._modules[name].lora_down.weight = nn.Parameter(
            down_weight.type(weight.dtype).to(weight.device) * alpha
            + _module._modules[name].lora_down.weight.to(weight.device) * beta
        )

        _module._modules[name].to(weight.device)


def tune_lora_scale(model, alpha: float = 1.0):
    for _module in model.modules():
        if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d"]:
            _module.scale = alpha


def set_lora_diag(model, diag: torch.Tensor):
    for _module in model.modules():
        if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d"]:
            _module.set_selector_from_diag(diag)


def _text_lora_path(path: str) -> str:
    assert path.endswith(".pt"), "Only .pt files are supported"
    return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])


def _ti_lora_path(path: str) -> str:
    assert path.endswith(".pt"), "Only .pt files are supported"
    return ".".join(path.split(".")[:-1] + ["ti", "pt"])


def apply_learned_embed_in_clip(
    learned_embeds,
    text_encoder,
    tokenizer,
    token: Optional[Union[str, List[str]]] = None,
    idempotent=False,
):
    if isinstance(token, str):
        trained_tokens = [token]
    elif isinstance(token, list):
        assert len(learned_embeds.keys()) == len(
            token
        ), "The number of tokens and the number of embeds should be the same"
        trained_tokens = token
    else:
        trained_tokens = list(learned_embeds.keys())

    for token in trained_tokens:
        print(token)
        embeds = learned_embeds[token]

        # cast to dtype of text_encoder
        dtype = text_encoder.get_input_embeddings().weight.dtype
        num_added_tokens = tokenizer.add_tokens(token)

        i = 1
        if not idempotent:
            while num_added_tokens == 0:
                print(f"The tokenizer already contains the token {token}.")
                token = f"{token[:-1]}-{i}>"
                print(f"Attempting to add the token {token}.")
                num_added_tokens = tokenizer.add_tokens(token)
                i += 1
        elif num_added_tokens == 0 and idempotent:
            print(f"The tokenizer already contains the token {token}.")
            print(f"Replacing {token} embedding.")

        # resize the token embeddings
        text_encoder.resize_token_embeddings(len(tokenizer))

        # get the id for the token and assign the embeds
        token_id = tokenizer.convert_tokens_to_ids(token)
        text_encoder.get_input_embeddings().weight.data[token_id] = embeds
    return token


def load_learned_embed_in_clip(
    learned_embeds_path,
    text_encoder,
    tokenizer,
    token: Optional[Union[str, List[str]]] = None,
    idempotent=False,
):
    learned_embeds = torch.load(learned_embeds_path)
    apply_learned_embed_in_clip(
        learned_embeds, text_encoder, tokenizer, token, idempotent
    )


def patch_pipe(
    pipe,
    maybe_unet_path,
    token: Optional[str] = None,
    r: int = 4,
    patch_unet=True,
    patch_text=True,
    patch_ti=True,
    idempotent_token=True,
    unet_target_replace_module=DEFAULT_TARGET_REPLACE,
    text_target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
):
    if maybe_unet_path.endswith(".pt"):
        # torch format

        if maybe_unet_path.endswith(".ti.pt"):
            unet_path = maybe_unet_path[:-6] + ".pt"
        elif maybe_unet_path.endswith(".text_encoder.pt"):
            unet_path = maybe_unet_path[:-16] + ".pt"
        else:
            unet_path = maybe_unet_path

        ti_path = _ti_lora_path(unet_path)
        text_path = _text_lora_path(unet_path)

        if patch_unet:
            print("LoRA : Patching Unet")
            monkeypatch_or_replace_lora(
                pipe.unet,
                torch.load(unet_path),
                r=r,
                target_replace_module=unet_target_replace_module,
            )

        if patch_text:
            print("LoRA : Patching text encoder")
            monkeypatch_or_replace_lora(
                pipe.text_encoder,
                torch.load(text_path),
                target_replace_module=text_target_replace_module,
                r=r,
            )
        if patch_ti:
            print("LoRA : Patching token input")
            token = load_learned_embed_in_clip(
                ti_path,
                pipe.text_encoder,
                pipe.tokenizer,
                token=token,
                idempotent=idempotent_token,
            )

    elif maybe_unet_path.endswith(".safetensors"):
        safeloras = safe_open(maybe_unet_path, framework="pt", device="cpu")
        monkeypatch_or_replace_safeloras(pipe, safeloras)
        tok_dict = parse_safeloras_embeds(safeloras)
        if patch_ti:
            apply_learned_embed_in_clip(
                tok_dict,
                pipe.text_encoder,
                pipe.tokenizer,
                token=token,
                idempotent=idempotent_token,
            )
        return tok_dict


@torch.no_grad()
def inspect_lora(model):
    moved = {}

    for name, _module in model.named_modules():
        if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d"]:
            ups = _module.lora_up.weight.data.clone()
            downs = _module.lora_down.weight.data.clone()

            wght: torch.Tensor = ups.flatten(1) @ downs.flatten(1)

            dist = wght.flatten().abs().mean().item()
            if name in moved:
                moved[name].append(dist)
            else:
                moved[name] = [dist]

    return moved


def save_all(
    unet,
    text_encoder,
    save_path,
    placeholder_token_ids=None,
    placeholder_tokens=None,
    save_lora=True,
    save_ti=True,
    target_replace_module_text=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
    target_replace_module_unet=DEFAULT_TARGET_REPLACE,
    safe_form=True,
):
    if not safe_form:
        # save ti
        if save_ti:
            ti_path = _ti_lora_path(save_path)
            learned_embeds_dict = {}
            for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
                learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
                print(
                    f"Current Learned Embeddings for {tok}:, id {tok_id} ",
                    learned_embeds[:4],
                )
                learned_embeds_dict[tok] = learned_embeds.detach().cpu()

            torch.save(learned_embeds_dict, ti_path)
            print("Ti saved to ", ti_path)

        # save text encoder
        if save_lora:

            save_lora_weight(
                unet, save_path, target_replace_module=target_replace_module_unet
            )
            print("Unet saved to ", save_path)

            save_lora_weight(
                text_encoder,
                _text_lora_path(save_path),
                target_replace_module=target_replace_module_text,
            )
            print("Text Encoder saved to ", _text_lora_path(save_path))

    else:
        assert save_path.endswith(
            ".safetensors"
        ), f"Save path : {save_path} should end with .safetensors"

        loras = {}
        embeds = {}

        if save_lora:

            loras["unet"] = (unet, target_replace_module_unet)
            loras["text_encoder"] = (text_encoder, target_replace_module_text)

        if save_ti:
            for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
                learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
                print(
                    f"Current Learned Embeddings for {tok}:, id {tok_id} ",
                    learned_embeds[:4],
                )
                embeds[tok] = learned_embeds.detach().cpu()

        save_safeloras_with_embeds(loras, embeds, save_path)
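As a rough usage sketch (not part of this commit), the patching helpers above are typically driven from a diffusers pipeline; the checkpoint path, the placeholder token, and the top-level import are assumptions here, not names defined in this diff:

# Hypothetical example: load a trained LoRA into a pipeline and set its strength.
import torch
from diffusers import StableDiffusionPipeline
from lora_diffusion import patch_pipe, tune_lora_scale  # assumes the package re-exports these

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

patch_pipe(pipe, "./example_lora.safetensors", patch_unet=True, patch_text=True, patch_ti=True)
tune_lora_scale(pipe.unet, 1.0)          # alpha applied to every LoraInjected* module in the unet
tune_lora_scale(pipe.text_encoder, 1.0)  # and in the text encoder
image = pipe("a photo of <s1><s2>", num_inference_steps=50).images[0]  # <s1><s2> is a placeholder token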
lora_diffusion/lora_manager.py
ADDED
@@ -0,0 +1,144 @@
from typing import List
import torch
from safetensors import safe_open
from diffusers import StableDiffusionPipeline
from .lora import (
    monkeypatch_or_replace_safeloras,
    apply_learned_embed_in_clip,
    set_lora_diag,
    parse_safeloras_embeds,
)


def lora_join(lora_safetenors: list):
    metadatas = [dict(safelora.metadata()) for safelora in lora_safetenors]
    _total_metadata = {}
    total_metadata = {}
    total_tensor = {}
    total_rank = 0
    ranklist = []
    for _metadata in metadatas:
        rankset = []
        for k, v in _metadata.items():
            if k.endswith("rank"):
                rankset.append(int(v))

        assert len(set(rankset)) <= 1, "Rank should be the same per model"
        if len(rankset) == 0:
            rankset = [0]

        total_rank += rankset[0]
        _total_metadata.update(_metadata)
        ranklist.append(rankset[0])

    # remove metadata about tokens
    for k, v in _total_metadata.items():
        if v != "<embed>":
            total_metadata[k] = v

    tensorkeys = set()
    for safelora in lora_safetenors:
        tensorkeys.update(safelora.keys())

    for keys in tensorkeys:
        if keys.startswith("text_encoder") or keys.startswith("unet"):
            tensorset = [safelora.get_tensor(keys) for safelora in lora_safetenors]

            is_down = keys.endswith("down")

            if is_down:
                _tensor = torch.cat(tensorset, dim=0)
                assert _tensor.shape[0] == total_rank
            else:
                _tensor = torch.cat(tensorset, dim=1)
                assert _tensor.shape[1] == total_rank

            total_tensor[keys] = _tensor
            keys_rank = ":".join(keys.split(":")[:-1]) + ":rank"
            total_metadata[keys_rank] = str(total_rank)
    token_size_list = []
    for idx, safelora in enumerate(lora_safetenors):
        tokens = [k for k, v in safelora.metadata().items() if v == "<embed>"]
        for jdx, token in enumerate(sorted(tokens)):

            total_tensor[f"<s{idx}-{jdx}>"] = safelora.get_tensor(token)
            total_metadata[f"<s{idx}-{jdx}>"] = "<embed>"

            print(f"Embedding {token} replaced to <s{idx}-{jdx}>")

        token_size_list.append(len(tokens))

    return total_tensor, total_metadata, ranklist, token_size_list


class DummySafeTensorObject:
    def __init__(self, tensor: dict, metadata):
        self.tensor = tensor
        self._metadata = metadata

    def keys(self):
        return self.tensor.keys()

    def metadata(self):
        return self._metadata

    def get_tensor(self, key):
        return self.tensor[key]


class LoRAManager:
    def __init__(self, lora_paths_list: List[str], pipe: StableDiffusionPipeline):

        self.lora_paths_list = lora_paths_list
        self.pipe = pipe
        self._setup()

    def _setup(self):

        self._lora_safetenors = [
            safe_open(path, framework="pt", device="cpu")
            for path in self.lora_paths_list
        ]

        (
            total_tensor,
            total_metadata,
            self.ranklist,
            self.token_size_list,
        ) = lora_join(self._lora_safetenors)

        self.total_safelora = DummySafeTensorObject(total_tensor, total_metadata)

        monkeypatch_or_replace_safeloras(self.pipe, self.total_safelora)
        tok_dict = parse_safeloras_embeds(self.total_safelora)

        apply_learned_embed_in_clip(
            tok_dict,
            self.pipe.text_encoder,
            self.pipe.tokenizer,
            token=None,
            idempotent=True,
        )

    def tune(self, scales):

        assert len(scales) == len(
            self.ranklist
        ), "Scale list should be the same length as ranklist"

        diags = []
        for scale, rank in zip(scales, self.ranklist):
            diags = diags + [scale] * rank

        set_lora_diag(self.pipe.unet, torch.tensor(diags))

    def prompt(self, prompt):
        if prompt is not None:
            for idx, tok_size in enumerate(self.token_size_list):
                prompt = prompt.replace(
                    f"<{idx + 1}>",
                    "".join([f"<s{idx}-{jdx}>" for jdx in range(tok_size)]),
                )
        # TODO : Rescale LoRA + Text inputs based on prompt scale params

        return prompt
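A hedged sketch of how LoRAManager might be used (the .safetensors paths are placeholders, `pipe` is an already-loaded StableDiffusionPipeline as above, and the import path assumes the module is reachable under the package):

# Hypothetical example: join two LoRAs, rescale them, and rewrite the prompt tokens.
from lora_diffusion.lora_manager import LoRAManager

manager = LoRAManager(["./lora_a.safetensors", "./lora_b.safetensors"], pipe)
manager.tune([0.8, 0.5])  # one scale per joined LoRA, broadcast over its rank block
prompt = manager.prompt("photo of <1> standing next to <2>")  # <1>, <2> expand to <s0-*>, <s1-*>
image = pipe(prompt, num_inference_steps=50).images[0]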
lora_diffusion/preprocess_files.py
ADDED
@@ -0,0 +1,327 @@
# Have SwinIR upsample
# Have BLIP auto caption
# Have CLIPSeg auto mask concept

from typing import List, Literal, Union, Optional, Tuple
import os
from PIL import Image, ImageFilter
import torch
import numpy as np
import fire
from tqdm import tqdm
import glob
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation


@torch.no_grad()
def swin_ir_sr(
    images: List[Image.Image],
    model_id: Literal[
        "caidas/swin2SR-classical-sr-x2-64", "caidas/swin2SR-classical-sr-x4-48"
    ] = "caidas/swin2SR-classical-sr-x2-64",
    target_size: Optional[Tuple[int, int]] = None,
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    **kwargs,
) -> List[Image.Image]:
    """
    Upscales images using SwinIR. Returns a list of PIL images.
    """
    # So this is currently in main branch, so this can be used in the future I guess?
    from transformers import Swin2SRForImageSuperResolution, Swin2SRImageProcessor

    model = Swin2SRForImageSuperResolution.from_pretrained(
        model_id,
    ).to(device)
    processor = Swin2SRImageProcessor()

    out_images = []

    for image in tqdm(images):

        ori_w, ori_h = image.size
        if target_size is not None:
            if ori_w >= target_size[0] and ori_h >= target_size[1]:
                out_images.append(image)
                continue

        inputs = processor(image, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)

        output = (
            outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        )
        output = np.moveaxis(output, source=0, destination=-1)
        output = (output * 255.0).round().astype(np.uint8)
        output = Image.fromarray(output)

        out_images.append(output)

    return out_images


@torch.no_grad()
def clipseg_mask_generator(
    images: List[Image.Image],
    target_prompts: Union[List[str], str],
    model_id: Literal[
        "CIDAS/clipseg-rd64-refined", "CIDAS/clipseg-rd16"
    ] = "CIDAS/clipseg-rd64-refined",
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    bias: float = 0.01,
    temp: float = 1.0,
    **kwargs,
) -> List[Image.Image]:
    """
    Returns a greyscale mask for each image, where the mask is the probability of the target prompt being present in the image
    """

    if isinstance(target_prompts, str):
        print(
            f'Warning: only one target prompt "{target_prompts}" was given, so it will be used for all images'
        )

        target_prompts = [target_prompts] * len(images)

    processor = CLIPSegProcessor.from_pretrained(model_id)
    model = CLIPSegForImageSegmentation.from_pretrained(model_id).to(device)

    masks = []

    for image, prompt in tqdm(zip(images, target_prompts)):

        original_size = image.size

        inputs = processor(
            text=[prompt, ""],
            images=[image] * 2,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).to(device)

        outputs = model(**inputs)

        logits = outputs.logits
        probs = torch.nn.functional.softmax(logits / temp, dim=0)[0]
        probs = (probs + bias).clamp_(0, 1)
        probs = 255 * probs / probs.max()

        # make mask greyscale
        mask = Image.fromarray(probs.cpu().numpy()).convert("L")

        # resize mask to original size
        mask = mask.resize(original_size)

        masks.append(mask)

    return masks


@torch.no_grad()
def blip_captioning_dataset(
    images: List[Image.Image],
    text: Optional[str] = None,
    model_id: Literal[
        "Salesforce/blip-image-captioning-large",
        "Salesforce/blip-image-captioning-base",
    ] = "Salesforce/blip-image-captioning-large",
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    **kwargs,
) -> List[str]:
    """
    Returns a list of captions for the given images
    """

    from transformers import BlipProcessor, BlipForConditionalGeneration

    processor = BlipProcessor.from_pretrained(model_id)
    model = BlipForConditionalGeneration.from_pretrained(model_id).to(device)
    captions = []

    for image in tqdm(images):
        inputs = processor(image, text=text, return_tensors="pt").to("cuda")
        out = model.generate(
            **inputs, max_length=150, do_sample=True, top_k=50, temperature=0.7
        )
        caption = processor.decode(out[0], skip_special_tokens=True)

        captions.append(caption)

    return captions


def face_mask_google_mediapipe(
    images: List[Image.Image], blur_amount: float = 80.0, bias: float = 0.05
) -> List[Image.Image]:
    """
    Returns a list of images with mask on the face parts.
    """
    import mediapipe as mp

    mp_face_detection = mp.solutions.face_detection

    face_detection = mp_face_detection.FaceDetection(
        model_selection=1, min_detection_confidence=0.5
    )

    masks = []
    for image in tqdm(images):

        image = np.array(image)

        results = face_detection.process(image)
        black_image = np.ones((image.shape[0], image.shape[1]), dtype=np.uint8)

        if results.detections:

            for detection in results.detections:

                x_min = int(
                    detection.location_data.relative_bounding_box.xmin * image.shape[1]
                )
                y_min = int(
                    detection.location_data.relative_bounding_box.ymin * image.shape[0]
                )
                width = int(
                    detection.location_data.relative_bounding_box.width * image.shape[1]
                )
                height = int(
                    detection.location_data.relative_bounding_box.height
                    * image.shape[0]
                )

                # draw the colored rectangle
                black_image[y_min : y_min + height, x_min : x_min + width] = 255

        black_image = Image.fromarray(black_image)
        masks.append(black_image)

    return masks


def _crop_to_square(
    image: Image.Image, com: List[Tuple[int, int]], resize_to: Optional[int] = None
):
    cx, cy = com
    width, height = image.size
    if width > height:
        left_possible = max(cx - height / 2, 0)
        left = min(left_possible, width - height)
        right = left + height
        top = 0
        bottom = height
    else:
        left = 0
        right = width
        top_possible = max(cy - width / 2, 0)
        top = min(top_possible, height - width)
        bottom = top + width

    image = image.crop((left, top, right, bottom))

    if resize_to:
        image = image.resize((resize_to, resize_to), Image.Resampling.LANCZOS)

    return image


def _center_of_mass(mask: Image.Image):
    """
    Returns the center of mass of the mask
    """
    x, y = np.meshgrid(np.arange(mask.size[0]), np.arange(mask.size[1]))

    x_ = x * np.array(mask)
    y_ = y * np.array(mask)

    x = np.sum(x_) / np.sum(mask)
    y = np.sum(y_) / np.sum(mask)

    return x, y


def load_and_save_masks_and_captions(
    files: Union[str, List[str]],
    output_dir: str,
    caption_text: Optional[str] = None,
    target_prompts: Optional[Union[List[str], str]] = None,
    target_size: int = 512,
    crop_based_on_salience: bool = True,
    use_face_detection_instead: bool = False,
    temp: float = 1.0,
    n_length: int = -1,
):
    """
    Loads images from the given files, generates masks for them, and saves the masks and captions and upscale images
    to output dir.
    """
    os.makedirs(output_dir, exist_ok=True)

    # load images
    if isinstance(files, str):
        # check if it is a directory
        if os.path.isdir(files):
            # get all the .png .jpg in the directory
            files = glob.glob(os.path.join(files, "*.png")) + glob.glob(
                os.path.join(files, "*.jpg")
            )

        if len(files) == 0:
            raise Exception(
                f"No files found in {files}. Either {files} is not a directory or it does not contain any .png or .jpg files."
            )
        if n_length == -1:
            n_length = len(files)
        files = sorted(files)[:n_length]

    images = [Image.open(file) for file in files]

    # captions
    print(f"Generating {len(images)} captions...")
    captions = blip_captioning_dataset(images, text=caption_text)

    if target_prompts is None:
        target_prompts = captions

    print(f"Generating {len(images)} masks...")
    if not use_face_detection_instead:
        seg_masks = clipseg_mask_generator(
            images=images, target_prompts=target_prompts, temp=temp
        )
    else:
        seg_masks = face_mask_google_mediapipe(images=images)

    # find the center of mass of the mask
    if crop_based_on_salience:
        coms = [_center_of_mass(mask) for mask in seg_masks]
    else:
        coms = [(image.size[0] / 2, image.size[1] / 2) for image in images]
    # based on the center of mass, crop the image to a square
    images = [
        _crop_to_square(image, com, resize_to=None) for image, com in zip(images, coms)
    ]

    print(f"Upscaling {len(images)} images...")
    # upscale images anyways
    images = swin_ir_sr(images, target_size=(target_size, target_size))
    images = [
        image.resize((target_size, target_size), Image.Resampling.LANCZOS)
        for image in images
    ]

    seg_masks = [
        _crop_to_square(mask, com, resize_to=target_size)
        for mask, com in zip(seg_masks, coms)
    ]
    with open(os.path.join(output_dir, "caption.txt"), "w") as f:
        # save images and masks
        for idx, (image, mask, caption) in enumerate(zip(images, seg_masks, captions)):
            image.save(os.path.join(output_dir, f"{idx}.src.jpg"), quality=99)
            mask.save(os.path.join(output_dir, f"{idx}.mask.png"))

            f.write(caption + "\n")


def main():
    fire.Fire(load_and_save_masks_and_captions)
lora_diffusion/safe_open.py
ADDED
@@ -0,0 +1,68 @@
"""
Pure python version of Safetensors safe_open
From https://gist.github.com/Narsil/3edeec2669a5e94e4707aa0f901d2282
"""

import json
import mmap
import os

import torch


class SafetensorsWrapper:
    def __init__(self, metadata, tensors):
        self._metadata = metadata
        self._tensors = tensors

    def metadata(self):
        return self._metadata

    def keys(self):
        return self._tensors.keys()

    def get_tensor(self, k):
        return self._tensors[k]


DTYPES = {
    "F32": torch.float32,
    "F16": torch.float16,
    "BF16": torch.bfloat16,
}


def create_tensor(storage, info, offset):
    dtype = DTYPES[info["dtype"]]
    shape = info["shape"]
    start, stop = info["data_offsets"]
    return (
        torch.asarray(storage[start + offset : stop + offset], dtype=torch.uint8)
        .view(dtype=dtype)
        .reshape(shape)
    )


def safe_open(filename, framework="pt", device="cpu"):
    if framework != "pt":
        raise ValueError("`framework` must be 'pt'")

    with open(filename, mode="r", encoding="utf8") as file_obj:
        with mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ) as m:
            header = m.read(8)
            n = int.from_bytes(header, "little")
            metadata_bytes = m.read(n)
            metadata = json.loads(metadata_bytes)

    size = os.stat(filename).st_size
    storage = torch.ByteStorage.from_file(filename, shared=False, size=size).untyped()
    offset = n + 8

    return SafetensorsWrapper(
        metadata=metadata.get("__metadata__", {}),
        tensors={
            name: create_tensor(storage, info, offset).to(device)
            for name, info in metadata.items()
            if name != "__metadata__"
        },
    )
lora_diffusion/to_ckpt_v2.py
ADDED
@@ -0,0 +1,232 @@
# from https://gist.github.com/jachiam/8a5c0b607e38fcc585168b90c686eb05
# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
# *Only* converts the UNet, VAE, and Text Encoder.
# Does not convert optimizer state or any other thing.
# Written by jachiam
import argparse
import os.path as osp

import torch


# =================#
# UNet Conversion #
# =================#

unet_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("time_embed.0.weight", "time_embedding.linear_1.weight"),
    ("time_embed.0.bias", "time_embedding.linear_1.bias"),
    ("time_embed.2.weight", "time_embedding.linear_2.weight"),
    ("time_embed.2.bias", "time_embedding.linear_2.bias"),
    ("input_blocks.0.0.weight", "conv_in.weight"),
    ("input_blocks.0.0.bias", "conv_in.bias"),
    ("out.0.weight", "conv_norm_out.weight"),
    ("out.0.bias", "conv_norm_out.bias"),
    ("out.2.weight", "conv_out.weight"),
    ("out.2.bias", "conv_out.bias"),
]

unet_conversion_map_resnet = [
    # (stable-diffusion, HF Diffusers)
    ("in_layers.0", "norm1"),
    ("in_layers.2", "conv1"),
    ("out_layers.0", "norm2"),
    ("out_layers.3", "conv2"),
    ("emb_layers.1", "time_emb_proj"),
    ("skip_connection", "conv_shortcut"),
]

unet_conversion_map_layer = []
# hardcoded number of downblocks and resnets/attentions...
# would need smarter logic for other networks.
for i in range(4):
    # loop over downblocks/upblocks

    for j in range(2):
        # loop over resnets/attentions for downblocks
        hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
        sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
        unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

        if i < 3:
            # no attention layers in down_blocks.3
            hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
            sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
            unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

    for j in range(3):
        # loop over resnets/attentions for upblocks
        hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
        sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
        unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

        if i > 0:
            # no attention layers in up_blocks.0
            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

    if i < 3:
        # no downsample in down_blocks.3
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
        sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
        unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

        # no upsample in up_blocks.3
        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
        unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

for j in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{j}."
    sd_mid_res_prefix = f"middle_block.{2*j}."
    unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))


def convert_unet_state_dict(unet_state_dict):
    # buyer beware: this is a *brittle* function,
    # and correct output requires that all of these pieces interact in
    # the exact order in which I have arranged them.
    mapping = {k: k for k in unet_state_dict.keys()}
    for sd_name, hf_name in unet_conversion_map:
        mapping[hf_name] = sd_name
    for k, v in mapping.items():
        if "resnets" in k:
            for sd_part, hf_part in unet_conversion_map_resnet:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    for k, v in mapping.items():
        for sd_part, hf_part in unet_conversion_map_layer:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
    return new_state_dict


# ================#
# VAE Conversion #
# ================#

vae_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("nin_shortcut", "conv_shortcut"),
    ("norm_out", "conv_norm_out"),
    ("mid.attn_1.", "mid_block.attentions.0."),
]

for i in range(4):
    # down_blocks have two resnets
    for j in range(2):
        hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
        sd_down_prefix = f"encoder.down.{i}.block.{j}."
        vae_conversion_map.append((sd_down_prefix, hf_down_prefix))

    if i < 3:
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
        sd_downsample_prefix = f"down.{i}.downsample."
        vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))

        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"up.{3-i}.upsample."
        vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))

    # up_blocks have three resnets
    # also, up blocks in hf are numbered in reverse from sd
    for j in range(3):
        hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
        sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
        vae_conversion_map.append((sd_up_prefix, hf_up_prefix))

# this part accounts for mid blocks in both the encoder and the decoder
for i in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{i}."
    sd_mid_res_prefix = f"mid.block_{i+1}."
    vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))


vae_conversion_map_attn = [
    # (stable-diffusion, HF Diffusers)
    ("norm.", "group_norm."),
    ("q.", "query."),
    ("k.", "key."),
    ("v.", "value."),
    ("proj_out.", "proj_attn."),
]


def reshape_weight_for_sd(w):
    # convert HF linear weights to SD conv2d weights
    return w.reshape(*w.shape, 1, 1)


def convert_vae_state_dict(vae_state_dict):
    mapping = {k: k for k in vae_state_dict.keys()}
    for k, v in mapping.items():
        for sd_part, hf_part in vae_conversion_map:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    for k, v in mapping.items():
        if "attentions" in k:
            for sd_part, hf_part in vae_conversion_map_attn:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
    weights_to_convert = ["q", "k", "v", "proj_out"]
    for k, v in new_state_dict.items():
        for weight_name in weights_to_convert:
            if f"mid.attn_1.{weight_name}.weight" in k:
                print(f"Reshaping {k} for SD format")
                new_state_dict[k] = reshape_weight_for_sd(v)
    return new_state_dict


# =========================#
# Text Encoder Conversion #
# =========================#
# pretty much a no-op


def convert_text_enc_state_dict(text_enc_dict):
    return text_enc_dict


def convert_to_ckpt(model_path, checkpoint_path, as_half):

    assert model_path is not None, "Must provide a model path!"

    assert checkpoint_path is not None, "Must provide a checkpoint path!"

    unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
    vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
    text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")

    # Convert the UNet model
    unet_state_dict = torch.load(unet_path, map_location="cpu")
    unet_state_dict = convert_unet_state_dict(unet_state_dict)
    unet_state_dict = {
        "model.diffusion_model." + k: v for k, v in unet_state_dict.items()
    }

    # Convert the VAE model
    vae_state_dict = torch.load(vae_path, map_location="cpu")
    vae_state_dict = convert_vae_state_dict(vae_state_dict)
    vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}

    # Convert the text encoder model
    text_enc_dict = torch.load(text_enc_path, map_location="cpu")
    text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
    text_enc_dict = {
        "cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()
    }

    # Put together new checkpoint
    state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
    if as_half:
        state_dict = {k: v.half() for k, v in state_dict.items()}
    state_dict = {"state_dict": state_dict}
    torch.save(state_dict, checkpoint_path)
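A hedged example of driving the converter above (both paths are placeholders):

# Hypothetical example: fold a diffusers pipeline directory into a single SD-style .ckpt.
from lora_diffusion.to_ckpt_v2 import convert_to_ckpt

convert_to_ckpt(
    model_path="./my-diffusers-model",  # must contain unet/, vae/, and text_encoder/ weights
    checkpoint_path="./my-model.ckpt",
    as_half=True,                       # store fp16 weights to roughly halve the file size
)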
lora_diffusion/utils.py
ADDED
@@ -0,0 +1,214 @@
| 1 |
+
from typing import List, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from transformers import (
|
| 6 |
+
CLIPProcessor,
|
| 7 |
+
CLIPTextModelWithProjection,
|
| 8 |
+
CLIPTokenizer,
|
| 9 |
+
CLIPVisionModelWithProjection,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
from diffusers import StableDiffusionPipeline
|
| 13 |
+
from .lora import patch_pipe, tune_lora_scale, _text_lora_path, _ti_lora_path
|
| 14 |
+
import os
|
| 15 |
+
import glob
|
| 16 |
+
import math
|
| 17 |
+
|
| 18 |
+
EXAMPLE_PROMPTS = [
|
| 19 |
+
"<obj> swimming in a pool",
|
| 20 |
+
"<obj> at a beach with a view of seashore",
|
| 21 |
+
"<obj> in times square",
|
| 22 |
+
"<obj> wearing sunglasses",
|
| 23 |
+
"<obj> in a construction outfit",
|
| 24 |
+
"<obj> playing with a ball",
|
| 25 |
+
"<obj> wearing headphones",
|
| 26 |
+
"<obj> oil painting ghibli inspired",
|
| 27 |
+
"<obj> working on the laptop",
|
| 28 |
+
"<obj> with mountains and sunset in background",
|
| 29 |
+
"Painting of <obj> at a beach by artist claude monet",
|
| 30 |
+
"<obj> digital painting 3d render geometric style",
|
| 31 |
+
"A screaming <obj>",
|
| 32 |
+
"A depressed <obj>",
|
| 33 |
+
"A sleeping <obj>",
|
| 34 |
+
"A sad <obj>",
|
| 35 |
+
"A joyous <obj>",
|
| 36 |
+
"A frowning <obj>",
|
| 37 |
+
"A sculpture of <obj>",
|
| 38 |
+
"<obj> near a pool",
|
| 39 |
+
"<obj> at a beach with a view of seashore",
|
| 40 |
+
"<obj> in a garden",
|
| 41 |
+
"<obj> in grand canyon",
|
| 42 |
+
"<obj> floating in ocean",
|
| 43 |
+
"<obj> and an armchair",
|
| 44 |
+
"A maple tree on the side of <obj>",
|
| 45 |
+
"<obj> and an orange sofa",
|
| 46 |
+
"<obj> with chocolate cake on it",
|
| 47 |
+
"<obj> with a vase of rose flowers on it",
|
| 48 |
+
"A digital illustration of <obj>",
|
| 49 |
+
"Georgia O'Keeffe style <obj> painting",
|
| 50 |
+
"A watercolor painting of <obj> on a beach",
|
| 51 |
+
]
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def image_grid(_imgs, rows=None, cols=None):
|
| 55 |
+
|
| 56 |
+
if rows is None and cols is None:
|
| 57 |
+
rows = cols = math.ceil(len(_imgs) ** 0.5)
|
| 58 |
+
|
| 59 |
+
if rows is None:
|
| 60 |
+
rows = math.ceil(len(_imgs) / cols)
|
| 61 |
+
if cols is None:
|
| 62 |
+
cols = math.ceil(len(_imgs) / rows)
|
| 63 |
+
|
| 64 |
+
w, h = _imgs[0].size
|
| 65 |
+
grid = Image.new("RGB", size=(cols * w, rows * h))
|
| 66 |
+
grid_w, grid_h = grid.size
|
| 67 |
+
|
| 68 |
+
for i, img in enumerate(_imgs):
|
| 69 |
+
grid.paste(img, box=(i % cols * w, i // cols * h))
|
| 70 |
+
return grid
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def text_img_alignment(img_embeds, text_embeds, target_img_embeds):
|
| 74 |
+
# evaluation inspired from textual inversion paper
|
| 75 |
+
# https://arxiv.org/abs/2208.01618
|
| 76 |
+
|
| 77 |
+
# text alignment
|
| 78 |
+
assert img_embeds.shape[0] == text_embeds.shape[0]
|
| 79 |
+
text_img_sim = (img_embeds * text_embeds).sum(dim=-1) / (
|
| 80 |
+
img_embeds.norm(dim=-1) * text_embeds.norm(dim=-1)
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
# image alignment
|
| 84 |
+
img_embed_normalized = img_embeds / img_embeds.norm(dim=-1, keepdim=True)
|
| 85 |
+
|
| 86 |
+
avg_target_img_embed = (
|
| 87 |
+
(target_img_embeds / target_img_embeds.norm(dim=-1, keepdim=True))
|
| 88 |
+
.mean(dim=0)
|
| 89 |
+
.unsqueeze(0)
|
| 90 |
+
.repeat(img_embeds.shape[0], 1)
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
img_img_sim = (img_embed_normalized * avg_target_img_embed).sum(dim=-1)
|
| 94 |
+
|
| 95 |
+
return {
|
| 96 |
+
"text_alignment_avg": text_img_sim.mean().item(),
|
| 97 |
+
"image_alignment_avg": img_img_sim.mean().item(),
|
| 98 |
+
"text_alignment_all": text_img_sim.tolist(),
|
| 99 |
+
"image_alignment_all": img_img_sim.tolist(),
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def prepare_clip_model_sets(eval_clip_id: str = "openai/clip-vit-large-patch14"):
|
| 104 |
+
text_model = CLIPTextModelWithProjection.from_pretrained(eval_clip_id)
|
| 105 |
+
tokenizer = CLIPTokenizer.from_pretrained(eval_clip_id)
|
| 106 |
+
vis_model = CLIPVisionModelWithProjection.from_pretrained(eval_clip_id)
|
| 107 |
+
processor = CLIPProcessor.from_pretrained(eval_clip_id)
|
| 108 |
+
|
| 109 |
+
return text_model, tokenizer, vis_model, processor
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def evaluate_pipe(
|
| 113 |
+
pipe,
|
| 114 |
+
target_images: List[Image.Image],
|
| 115 |
+
class_token: str = "",
|
| 116 |
+
learnt_token: str = "",
|
| 117 |
+
guidance_scale: float = 5.0,
|
| 118 |
+
seed=0,
|
| 119 |
+
clip_model_sets=None,
|
| 120 |
+
eval_clip_id: str = "openai/clip-vit-large-patch14",
|
| 121 |
+
n_test: int = 10,
|
| 122 |
+
n_step: int = 50,
|
| 123 |
+
):
|
| 124 |
+
|
| 125 |
+
if clip_model_sets is not None:
|
| 126 |
+
text_model, tokenizer, vis_model, processor = clip_model_sets
|
| 127 |
+
else:
|
| 128 |
+
text_model, tokenizer, vis_model, processor = prepare_clip_model_sets(
|
| 129 |
+
eval_clip_id
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
images = []
|
| 133 |
+
img_embeds = []
|
| 134 |
+
text_embeds = []
|
| 135 |
+
for prompt in EXAMPLE_PROMPTS[:n_test]:
|
| 136 |
+
prompt = prompt.replace("<obj>", learnt_token)
|
| 137 |
+
torch.manual_seed(seed)
|
| 138 |
+
with torch.autocast("cuda"):
|
| 139 |
+
img = pipe(
|
| 140 |
+
prompt, num_inference_steps=n_step, guidance_scale=guidance_scale
|
| 141 |
+
).images[0]
|
| 142 |
+
images.append(img)
|
| 143 |
+
|
| 144 |
+
# image
|
| 145 |
+
inputs = processor(images=img, return_tensors="pt")
|
| 146 |
+
img_embed = vis_model(**inputs).image_embeds
|
| 147 |
+
img_embeds.append(img_embed)
|
| 148 |
+
|
| 149 |
+
prompt = prompt.replace(learnt_token, class_token)
|
| 150 |
+
# prompts
|
| 151 |
+
inputs = tokenizer([prompt], padding=True, return_tensors="pt")
|
| 152 |
+
outputs = text_model(**inputs)
|
| 153 |
+
text_embed = outputs.text_embeds
|
| 154 |
+
text_embeds.append(text_embed)
|
| 155 |
+
|
| 156 |
+
# target images
|
| 157 |
+
inputs = processor(images=target_images, return_tensors="pt")
|
| 158 |
+
target_img_embeds = vis_model(**inputs).image_embeds
|
| 159 |
+
|
| 160 |
+
img_embeds = torch.cat(img_embeds, dim=0)
|
| 161 |
+
text_embeds = torch.cat(text_embeds, dim=0)
|
| 162 |
+
|
| 163 |
+
return text_img_alignment(img_embeds, text_embeds, target_img_embeds)
|
| 164 |
+
|
| 165 |
+
|
def visualize_progress(
    path_alls: Union[str, List[str]],
    prompt: str,
    model_id: str = "runwayml/stable-diffusion-v1-5",
    device="cuda:0",
    patch_unet=True,
    patch_text=True,
    patch_ti=True,
    unet_scale=1.0,
    text_scale=1.0,
    num_inference_steps=50,
    guidance_scale=5.0,
    offset: int = 0,
    limit: int = 10,
    seed: int = 0,
):

    imgs = []
    if isinstance(path_alls, str):
        alls = list(set(glob.glob(path_alls)))

        alls.sort(key=os.path.getmtime)
    else:
        alls = path_alls

    pipe = StableDiffusionPipeline.from_pretrained(
        model_id, torch_dtype=torch.float16
    ).to(device)

    print(f"Found {len(alls)} checkpoints")
    for path in alls[offset:limit]:
        print(path)

        patch_pipe(
            pipe, path, patch_unet=patch_unet, patch_text=patch_text, patch_ti=patch_ti
        )

        tune_lora_scale(pipe.unet, unet_scale)
        tune_lora_scale(pipe.text_encoder, text_scale)

        torch.manual_seed(seed)
        image = pipe(
            prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        ).images[0]
        imgs.append(image)

    return imgs
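Similarly, a minimal sketch of calling visualize_progress over a series of saved checkpoints; the glob pattern and prompt are placeholders, and patch_ti is disabled because this repository's DreamBooth trainer does not save textual-inversion embeddings.

# Hypothetical usage sketch: the checkpoint pattern and prompt are placeholders.
from lora_diffusion.utils import visualize_progress

frames = visualize_progress(
    path_alls="./exps/example/lora_weight_e*_s*.pt",  # assumes the trainer's checkpoint naming
    prompt="a photo of sks person in the snow",
    patch_ti=False,  # no learned token embeddings in this training setup
    offset=0,
    limit=10,
)
for i, frame in enumerate(frames):
    frame.save(f"progress_{i:03d}.png")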
lora_diffusion/xformers_utils.py
ADDED
@@ -0,0 +1,70 @@
import functools

import torch
from diffusers.models.attention import BasicTransformerBlock
from diffusers.utils.import_utils import is_xformers_available

from .lora import LoraInjectedLinear

if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


@functools.cache
def test_xformers_backwards(size):
    @torch.enable_grad()
    def _grad(size):
        q = torch.randn((1, 4, size), device="cuda")
        k = torch.randn((1, 4, size), device="cuda")
        v = torch.randn((1, 4, size), device="cuda")

        q = q.detach().requires_grad_()
        k = k.detach().requires_grad_()
        v = v.detach().requires_grad_()

        out = xformers.ops.memory_efficient_attention(q, k, v)
        loss = out.sum(2).mean(0).sum()

        return torch.autograd.grad(loss, v)

    try:
        _grad(size)
        print(size, "pass")
        return True
    except Exception as e:
        print(size, "fail")
        return False


def set_use_memory_efficient_attention_xformers(
    module: torch.nn.Module, valid: bool
) -> None:
    def fn_test_dim_head(module: torch.nn.Module):
        if isinstance(module, BasicTransformerBlock):
            # dim_head isn't stored anywhere, so back-calculate
            source = module.attn1.to_v
            if isinstance(source, LoraInjectedLinear):
                source = source.linear

            dim_head = source.out_features // module.attn1.heads

            result = test_xformers_backwards(dim_head)

            # If dim_head > dim_head_max, turn xformers off
            if not result:
                module.set_use_memory_efficient_attention_xformers(False)

        for child in module.children():
            fn_test_dim_head(child)

    if not is_xformers_available() and valid:
        print("XFormers is not available. Skipping.")
        return

    module.set_use_memory_efficient_attention_xformers(valid)

    if valid:
        fn_test_dim_head(module)
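A short sketch of applying this helper to a freshly loaded UNet follows; the model id is a placeholder, and it assumes xformers is installed and a CUDA device is present (the probe in test_xformers_backwards allocates CUDA tensors).

# Hypothetical usage sketch: requires CUDA and xformers; the model id is a placeholder.
from diffusers import UNet2DConditionModel
from lora_diffusion.xformers_utils import set_use_memory_efficient_attention_xformers

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
).to("cuda")

# Turns memory-efficient attention on, then probes each BasicTransformerBlock's head
# dimension and switches xformers back off wherever the backward pass is unsupported.
set_use_memory_efficient_attention_xformers(unet, True)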
requirements.txt
ADDED
@@ -0,0 +1,3 @@
diffusers
accelerate
transformers>=4.25.1
train_dreambooth_cloneofsimo_lora.py
ADDED
@@ -0,0 +1,1008 @@
# Bootstrapped from:
# https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py

import argparse
import hashlib
import itertools
import math
import os
import inspect
from pathlib import Path
from typing import Optional

import torch
import torch.nn.functional as F
import torch.utils.checkpoint


from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from huggingface_hub import HfFolder, Repository, whoami

from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from lora_diffusion import (
    extract_lora_ups_down,
    inject_trainable_lora,
    safetensors_available,
    save_lora_weight,
    save_safeloras,
)
from lora_diffusion.xformers_utils import set_use_memory_efficient_attention_xformers
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

from pathlib import Path

import random
import re


class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        size=512,
        center_crop=False,
        color_jitter=False,
        h_flip=False,
        resize=False,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer
        self.resize = resize

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exist.")

        self.instance_images_path = list(Path(instance_data_root).iterdir())
        self.num_instance_images = len(self.instance_images_path)
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(self.class_data_root.iterdir())
            self.num_class_images = len(self.class_images_path)
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        img_transforms = []

        if resize:
            img_transforms.append(
                transforms.Resize(
                    size, interpolation=transforms.InterpolationMode.BILINEAR
                )
            )
        if center_crop:
            img_transforms.append(transforms.CenterCrop(size))
        if color_jitter:
            img_transforms.append(transforms.ColorJitter(0.2, 0.1))
        if h_flip:
            img_transforms.append(transforms.RandomHorizontalFlip())

        self.image_transforms = transforms.Compose(
            [*img_transforms, transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        instance_image = Image.open(
            self.instance_images_path[index % self.num_instance_images]
        )
        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)
        example["instance_prompt_ids"] = self.tokenizer(
            self.instance_prompt,
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

        if self.class_data_root:
            class_image = Image.open(
                self.class_images_path[index % self.num_class_images]
            )
            if not class_image.mode == "RGB":
                class_image = class_image.convert("RGB")
            example["class_images"] = self.image_transforms(class_image)
            example["class_prompt_ids"] = self.tokenizer(
                self.class_prompt,
                padding="do_not_pad",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            ).input_ids

        return example


class PromptDataset(Dataset):
    "A simple dataset to prepare the prompts to generate class images on multiple GPUs."

    def __init__(self, prompt, num_samples):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        example = {}
        example["prompt"] = self.prompt
        example["index"] = index
        return example


logger = get_logger(__name__)

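As a quick sanity check, a minimal sketch of instantiating DreamBoothDataset outside the training loop; the data folder, prompt, and tokenizer id are placeholders.

# Hypothetical usage sketch: the data folder, prompt, and model id are placeholders.
from transformers import CLIPTokenizer
from train_dreambooth_cloneofsimo_lora import DreamBoothDataset

tokenizer = CLIPTokenizer.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="tokenizer"
)
dataset = DreamBoothDataset(
    instance_data_root="./data/example",       # folder of instance images
    instance_prompt="a photo of sks person",   # identifier prompt, placeholder
    tokenizer=tokenizer,
    size=512,
    center_crop=True,
    resize=True,
)
example = dataset[0]
print(example["instance_images"].shape)        # expected: torch.Size([3, 512, 512])
print(len(example["instance_prompt_ids"]))     # token ids for the instance prompt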
| 168 |
+
def parse_args(input_args=None):
|
| 169 |
+
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
| 170 |
+
parser.add_argument(
|
| 171 |
+
"--pretrained_model_name_or_path",
|
| 172 |
+
type=str,
|
| 173 |
+
default=None,
|
| 174 |
+
required=True,
|
| 175 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
| 176 |
+
)
|
| 177 |
+
parser.add_argument(
|
| 178 |
+
"--pretrained_vae_name_or_path",
|
| 179 |
+
type=str,
|
| 180 |
+
default=None,
|
| 181 |
+
help="Path to pretrained vae or vae identifier from huggingface.co/models.",
|
| 182 |
+
)
|
| 183 |
+
parser.add_argument(
|
| 184 |
+
"--revision",
|
| 185 |
+
type=str,
|
| 186 |
+
default=None,
|
| 187 |
+
required=False,
|
| 188 |
+
help="Revision of pretrained model identifier from huggingface.co/models.",
|
| 189 |
+
)
|
| 190 |
+
parser.add_argument(
|
| 191 |
+
"--tokenizer_name",
|
| 192 |
+
type=str,
|
| 193 |
+
default=None,
|
| 194 |
+
help="Pretrained tokenizer name or path if not the same as model_name",
|
| 195 |
+
)
|
| 196 |
+
parser.add_argument(
|
| 197 |
+
"--instance_data_dir",
|
| 198 |
+
type=str,
|
| 199 |
+
default=None,
|
| 200 |
+
required=True,
|
| 201 |
+
help="A folder containing the training data of instance images.",
|
| 202 |
+
)
|
| 203 |
+
parser.add_argument(
|
| 204 |
+
"--class_data_dir",
|
| 205 |
+
type=str,
|
| 206 |
+
default=None,
|
| 207 |
+
required=False,
|
| 208 |
+
help="A folder containing the training data of class images.",
|
| 209 |
+
)
|
| 210 |
+
parser.add_argument(
|
| 211 |
+
"--instance_prompt",
|
| 212 |
+
type=str,
|
| 213 |
+
default=None,
|
| 214 |
+
required=True,
|
| 215 |
+
help="The prompt with identifier specifying the instance",
|
| 216 |
+
)
|
| 217 |
+
parser.add_argument(
|
| 218 |
+
"--class_prompt",
|
| 219 |
+
type=str,
|
| 220 |
+
default=None,
|
| 221 |
+
help="The prompt to specify images in the same class as provided instance images.",
|
| 222 |
+
)
|
| 223 |
+
parser.add_argument(
|
| 224 |
+
"--with_prior_preservation",
|
| 225 |
+
default=False,
|
| 226 |
+
action="store_true",
|
| 227 |
+
help="Flag to add prior preservation loss.",
|
| 228 |
+
)
|
| 229 |
+
parser.add_argument(
|
| 230 |
+
"--prior_loss_weight",
|
| 231 |
+
type=float,
|
| 232 |
+
default=1.0,
|
| 233 |
+
help="The weight of prior preservation loss.",
|
| 234 |
+
)
|
| 235 |
+
parser.add_argument(
|
| 236 |
+
"--num_class_images",
|
| 237 |
+
type=int,
|
| 238 |
+
default=100,
|
| 239 |
+
help=(
|
| 240 |
+
"Minimal class images for prior preservation loss. If not have enough images, additional images will be"
|
| 241 |
+
" sampled with class_prompt."
|
| 242 |
+
),
|
| 243 |
+
)
|
| 244 |
+
parser.add_argument(
|
| 245 |
+
"--output_dir",
|
| 246 |
+
type=str,
|
| 247 |
+
default="text-inversion-model",
|
| 248 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
| 249 |
+
)
|
| 250 |
+
parser.add_argument(
|
| 251 |
+
"--output_format",
|
| 252 |
+
type=str,
|
| 253 |
+
choices=["pt", "safe", "both"],
|
| 254 |
+
default="both",
|
| 255 |
+
help="The output format of the model predicitions and checkpoints.",
|
| 256 |
+
)
|
| 257 |
+
parser.add_argument(
|
| 258 |
+
"--seed", type=int, default=None, help="A seed for reproducible training."
|
| 259 |
+
)
|
| 260 |
+
parser.add_argument(
|
| 261 |
+
"--resolution",
|
| 262 |
+
type=int,
|
| 263 |
+
default=512,
|
| 264 |
+
help=(
|
| 265 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
| 266 |
+
" resolution"
|
| 267 |
+
),
|
| 268 |
+
)
|
| 269 |
+
parser.add_argument(
|
| 270 |
+
"--center_crop",
|
| 271 |
+
action="store_true",
|
| 272 |
+
help="Whether to center crop images before resizing to resolution",
|
| 273 |
+
)
|
| 274 |
+
parser.add_argument(
|
| 275 |
+
"--color_jitter",
|
| 276 |
+
action="store_true",
|
| 277 |
+
help="Whether to apply color jitter to images",
|
| 278 |
+
)
|
| 279 |
+
parser.add_argument(
|
| 280 |
+
"--train_text_encoder",
|
| 281 |
+
action="store_true",
|
| 282 |
+
help="Whether to train the text encoder",
|
| 283 |
+
)
|
| 284 |
+
parser.add_argument(
|
| 285 |
+
"--train_batch_size",
|
| 286 |
+
type=int,
|
| 287 |
+
default=4,
|
| 288 |
+
help="Batch size (per device) for the training dataloader.",
|
| 289 |
+
)
|
| 290 |
+
parser.add_argument(
|
| 291 |
+
"--sample_batch_size",
|
| 292 |
+
type=int,
|
| 293 |
+
default=4,
|
| 294 |
+
help="Batch size (per device) for sampling images.",
|
| 295 |
+
)
|
| 296 |
+
parser.add_argument("--num_train_epochs", type=int, default=1)
|
| 297 |
+
parser.add_argument(
|
| 298 |
+
"--max_train_steps",
|
| 299 |
+
type=int,
|
| 300 |
+
default=None,
|
| 301 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
| 302 |
+
)
|
| 303 |
+
parser.add_argument(
|
| 304 |
+
"--save_steps",
|
| 305 |
+
type=int,
|
| 306 |
+
default=500,
|
| 307 |
+
help="Save checkpoint every X updates steps.",
|
| 308 |
+
)
|
| 309 |
+
parser.add_argument(
|
| 310 |
+
"--gradient_accumulation_steps",
|
| 311 |
+
type=int,
|
| 312 |
+
default=1,
|
| 313 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
| 314 |
+
)
|
| 315 |
+
parser.add_argument(
|
| 316 |
+
"--gradient_checkpointing",
|
| 317 |
+
action="store_true",
|
| 318 |
+
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
| 319 |
+
)
|
| 320 |
+
parser.add_argument(
|
| 321 |
+
"--lora_rank",
|
| 322 |
+
type=int,
|
| 323 |
+
default=4,
|
| 324 |
+
help="Rank of LoRA approximation.",
|
| 325 |
+
)
|
| 326 |
+
parser.add_argument(
|
| 327 |
+
"--learning_rate",
|
| 328 |
+
type=float,
|
| 329 |
+
default=None,
|
| 330 |
+
help="Initial learning rate (after the potential warmup period) to use.",
|
| 331 |
+
)
|
| 332 |
+
parser.add_argument(
|
| 333 |
+
"--learning_rate_text",
|
| 334 |
+
type=float,
|
| 335 |
+
default=5e-6,
|
| 336 |
+
help="Initial learning rate for text encoder (after the potential warmup period) to use.",
|
| 337 |
+
)
|
| 338 |
+
parser.add_argument(
|
| 339 |
+
"--scale_lr",
|
| 340 |
+
action="store_true",
|
| 341 |
+
default=False,
|
| 342 |
+
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
| 343 |
+
)
|
| 344 |
+
parser.add_argument(
|
| 345 |
+
"--lr_scheduler",
|
| 346 |
+
type=str,
|
| 347 |
+
default="constant",
|
| 348 |
+
help=(
|
| 349 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
| 350 |
+
' "constant", "constant_with_warmup"]'
|
| 351 |
+
),
|
| 352 |
+
)
|
| 353 |
+
parser.add_argument(
|
| 354 |
+
"--lr_warmup_steps",
|
| 355 |
+
type=int,
|
| 356 |
+
default=500,
|
| 357 |
+
help="Number of steps for the warmup in the lr scheduler.",
|
| 358 |
+
)
|
| 359 |
+
parser.add_argument(
|
| 360 |
+
"--use_8bit_adam",
|
| 361 |
+
action="store_true",
|
| 362 |
+
help="Whether or not to use 8-bit Adam from bitsandbytes.",
|
| 363 |
+
)
|
| 364 |
+
parser.add_argument(
|
| 365 |
+
"--adam_beta1",
|
| 366 |
+
type=float,
|
| 367 |
+
default=0.9,
|
| 368 |
+
help="The beta1 parameter for the Adam optimizer.",
|
| 369 |
+
)
|
| 370 |
+
parser.add_argument(
|
| 371 |
+
"--adam_beta2",
|
| 372 |
+
type=float,
|
| 373 |
+
default=0.999,
|
| 374 |
+
help="The beta2 parameter for the Adam optimizer.",
|
| 375 |
+
)
|
| 376 |
+
parser.add_argument(
|
| 377 |
+
"--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use."
|
| 378 |
+
)
|
| 379 |
+
parser.add_argument(
|
| 380 |
+
"--adam_epsilon",
|
| 381 |
+
type=float,
|
| 382 |
+
default=1e-08,
|
| 383 |
+
help="Epsilon value for the Adam optimizer",
|
| 384 |
+
)
|
| 385 |
+
parser.add_argument(
|
| 386 |
+
"--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
|
| 387 |
+
)
|
| 388 |
+
parser.add_argument(
|
| 389 |
+
"--push_to_hub",
|
| 390 |
+
action="store_true",
|
| 391 |
+
help="Whether or not to push the model to the Hub.",
|
| 392 |
+
)
|
| 393 |
+
parser.add_argument(
|
| 394 |
+
"--hub_token",
|
| 395 |
+
type=str,
|
| 396 |
+
default=None,
|
| 397 |
+
help="The token to use to push to the Model Hub.",
|
| 398 |
+
)
|
| 399 |
+
parser.add_argument(
|
| 400 |
+
"--logging_dir",
|
| 401 |
+
type=str,
|
| 402 |
+
default="logs",
|
| 403 |
+
help=(
|
| 404 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
| 405 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
| 406 |
+
),
|
| 407 |
+
)
|
| 408 |
+
parser.add_argument(
|
| 409 |
+
"--mixed_precision",
|
| 410 |
+
type=str,
|
| 411 |
+
default=None,
|
| 412 |
+
choices=["no", "fp16", "bf16"],
|
| 413 |
+
help=(
|
| 414 |
+
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
| 415 |
+
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
| 416 |
+
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
| 417 |
+
),
|
| 418 |
+
)
|
| 419 |
+
parser.add_argument(
|
| 420 |
+
"--local_rank",
|
| 421 |
+
type=int,
|
| 422 |
+
default=-1,
|
| 423 |
+
help="For distributed training: local_rank",
|
| 424 |
+
)
|
| 425 |
+
parser.add_argument(
|
| 426 |
+
"--resume_unet",
|
| 427 |
+
type=str,
|
| 428 |
+
default=None,
|
| 429 |
+
help=("File path for unet lora to resume training."),
|
| 430 |
+
)
|
| 431 |
+
parser.add_argument(
|
| 432 |
+
"--resume_text_encoder",
|
| 433 |
+
type=str,
|
| 434 |
+
default=None,
|
| 435 |
+
help=("File path for text encoder lora to resume training."),
|
| 436 |
+
)
|
| 437 |
+
parser.add_argument(
|
| 438 |
+
"--resize",
|
| 439 |
+
type=bool,
|
| 440 |
+
default=True,
|
| 441 |
+
required=False,
|
| 442 |
+
help="Should images be resized to --resolution before training?",
|
| 443 |
+
)
|
| 444 |
+
parser.add_argument(
|
| 445 |
+
"--use_xformers", action="store_true", help="Whether or not to use xformers"
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
if input_args is not None:
|
| 449 |
+
args = parser.parse_args(input_args)
|
| 450 |
+
else:
|
| 451 |
+
args = parser.parse_args()
|
| 452 |
+
|
| 453 |
+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
| 454 |
+
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
| 455 |
+
args.local_rank = env_local_rank
|
| 456 |
+
|
| 457 |
+
if args.with_prior_preservation:
|
| 458 |
+
if args.class_data_dir is None:
|
| 459 |
+
raise ValueError("You must specify a data directory for class images.")
|
| 460 |
+
if args.class_prompt is None:
|
| 461 |
+
raise ValueError("You must specify prompt for class images.")
|
| 462 |
+
else:
|
| 463 |
+
if args.class_data_dir is not None:
|
| 464 |
+
logger.warning(
|
| 465 |
+
"You need not use --class_data_dir without --with_prior_preservation."
|
| 466 |
+
)
|
| 467 |
+
if args.class_prompt is not None:
|
| 468 |
+
logger.warning(
|
| 469 |
+
"You need not use --class_prompt without --with_prior_preservation."
|
| 470 |
+
)
|
| 471 |
+
|
| 472 |
+
if not safetensors_available:
|
| 473 |
+
if args.output_format == "both":
|
| 474 |
+
print(
|
| 475 |
+
"Safetensors is not available - changing output format to just output PyTorch files"
|
| 476 |
+
)
|
| 477 |
+
args.output_format = "pt"
|
| 478 |
+
elif args.output_format == "safe":
|
| 479 |
+
raise ValueError(
|
| 480 |
+
"Safetensors is not available - either install it, or change output_format."
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
return args
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
def main(args):
|
| 487 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
| 488 |
+
|
| 489 |
+
accelerator = Accelerator(
|
| 490 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 491 |
+
mixed_precision=args.mixed_precision,
|
| 492 |
+
log_with="tensorboard",
|
| 493 |
+
logging_dir=logging_dir,
|
| 494 |
+
)
|
| 495 |
+
|
| 496 |
+
# Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
|
| 497 |
+
# This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
|
| 498 |
+
# TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
|
| 499 |
+
if (
|
| 500 |
+
args.train_text_encoder
|
| 501 |
+
and args.gradient_accumulation_steps > 1
|
| 502 |
+
and accelerator.num_processes > 1
|
| 503 |
+
):
|
| 504 |
+
raise ValueError(
|
| 505 |
+
"Gradient accumulation is not supported when training the text encoder in distributed training. "
|
| 506 |
+
"Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
|
| 507 |
+
)
|
| 508 |
+
|
| 509 |
+
if args.seed is not None:
|
| 510 |
+
set_seed(args.seed)
|
| 511 |
+
|
| 512 |
+
if args.with_prior_preservation:
|
| 513 |
+
class_images_dir = Path(args.class_data_dir)
|
| 514 |
+
if not class_images_dir.exists():
|
| 515 |
+
class_images_dir.mkdir(parents=True)
|
| 516 |
+
cur_class_images = len(list(class_images_dir.iterdir()))
|
| 517 |
+
|
| 518 |
+
if cur_class_images < args.num_class_images:
|
| 519 |
+
torch_dtype = (
|
| 520 |
+
torch.float16 if accelerator.device.type == "cuda" else torch.float32
|
| 521 |
+
)
|
| 522 |
+
pipeline = StableDiffusionPipeline.from_pretrained(
|
| 523 |
+
args.pretrained_model_name_or_path,
|
| 524 |
+
torch_dtype=torch_dtype,
|
| 525 |
+
safety_checker=None,
|
| 526 |
+
revision=args.revision,
|
| 527 |
+
)
|
| 528 |
+
pipeline.set_progress_bar_config(disable=True)
|
| 529 |
+
|
| 530 |
+
num_new_images = args.num_class_images - cur_class_images
|
| 531 |
+
logger.info(f"Number of class images to sample: {num_new_images}.")
|
| 532 |
+
|
| 533 |
+
sample_dataset = PromptDataset(args.class_prompt, num_new_images)
|
| 534 |
+
sample_dataloader = torch.utils.data.DataLoader(
|
| 535 |
+
sample_dataset, batch_size=args.sample_batch_size
|
| 536 |
+
)
|
| 537 |
+
|
| 538 |
+
sample_dataloader = accelerator.prepare(sample_dataloader)
|
| 539 |
+
pipeline.to(accelerator.device)
|
| 540 |
+
|
| 541 |
+
for example in tqdm(
|
| 542 |
+
sample_dataloader,
|
| 543 |
+
desc="Generating class images",
|
| 544 |
+
disable=not accelerator.is_local_main_process,
|
| 545 |
+
):
|
| 546 |
+
images = pipeline(example["prompt"]).images
|
| 547 |
+
|
| 548 |
+
for i, image in enumerate(images):
|
| 549 |
+
hash_image = hashlib.sha1(image.tobytes()).hexdigest()
|
| 550 |
+
image_filename = (
|
| 551 |
+
class_images_dir
|
| 552 |
+
/ f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
|
| 553 |
+
)
|
| 554 |
+
image.save(image_filename)
|
| 555 |
+
|
| 556 |
+
del pipeline
|
| 557 |
+
if torch.cuda.is_available():
|
| 558 |
+
torch.cuda.empty_cache()
|
| 559 |
+
|
| 560 |
+
# Handle the repository creation
|
| 561 |
+
if accelerator.is_main_process:
|
| 562 |
+
|
| 563 |
+
if args.output_dir is not None:
|
| 564 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 565 |
+
|
| 566 |
+
# Load the tokenizer
|
| 567 |
+
if args.tokenizer_name:
|
| 568 |
+
tokenizer = CLIPTokenizer.from_pretrained(
|
| 569 |
+
args.tokenizer_name,
|
| 570 |
+
revision=args.revision,
|
| 571 |
+
)
|
| 572 |
+
elif args.pretrained_model_name_or_path:
|
| 573 |
+
tokenizer = CLIPTokenizer.from_pretrained(
|
| 574 |
+
args.pretrained_model_name_or_path,
|
| 575 |
+
subfolder="tokenizer",
|
| 576 |
+
revision=args.revision,
|
| 577 |
+
)
|
| 578 |
+
|
| 579 |
+
# Load models and create wrapper for stable diffusion
|
| 580 |
+
text_encoder = CLIPTextModel.from_pretrained(
|
| 581 |
+
args.pretrained_model_name_or_path,
|
| 582 |
+
subfolder="text_encoder",
|
| 583 |
+
revision=args.revision,
|
| 584 |
+
)
|
| 585 |
+
vae = AutoencoderKL.from_pretrained(
|
| 586 |
+
args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path,
|
| 587 |
+
subfolder=None if args.pretrained_vae_name_or_path else "vae",
|
| 588 |
+
revision=None if args.pretrained_vae_name_or_path else args.revision,
|
| 589 |
+
)
|
| 590 |
+
unet = UNet2DConditionModel.from_pretrained(
|
| 591 |
+
args.pretrained_model_name_or_path,
|
| 592 |
+
subfolder="unet",
|
| 593 |
+
revision=args.revision,
|
| 594 |
+
)
|
| 595 |
+
unet.requires_grad_(False)
|
| 596 |
+
unet_lora_params, _ = inject_trainable_lora(
|
| 597 |
+
unet, r=args.lora_rank, loras=args.resume_unet
|
| 598 |
+
)
|
| 599 |
+
|
| 600 |
+
for _up, _down in extract_lora_ups_down(unet):
|
| 601 |
+
print("Before training: Unet First Layer lora up", _up.weight.data)
|
| 602 |
+
print("Before training: Unet First Layer lora down", _down.weight.data)
|
| 603 |
+
break
|
| 604 |
+
|
| 605 |
+
vae.requires_grad_(False)
|
| 606 |
+
text_encoder.requires_grad_(False)
|
| 607 |
+
|
| 608 |
+
if args.train_text_encoder:
|
| 609 |
+
text_encoder_lora_params, _ = inject_trainable_lora(
|
| 610 |
+
text_encoder,
|
| 611 |
+
target_replace_module=["CLIPAttention"],
|
| 612 |
+
r=args.lora_rank,
|
| 613 |
+
)
|
| 614 |
+
for _up, _down in extract_lora_ups_down(
|
| 615 |
+
text_encoder, target_replace_module=["CLIPAttention"]
|
| 616 |
+
):
|
| 617 |
+
print("Before training: text encoder First Layer lora up", _up.weight.data)
|
| 618 |
+
print(
|
| 619 |
+
"Before training: text encoder First Layer lora down", _down.weight.data
|
| 620 |
+
)
|
| 621 |
+
break
|
| 622 |
+
|
| 623 |
+
if args.use_xformers:
|
| 624 |
+
set_use_memory_efficient_attention_xformers(unet, True)
|
| 625 |
+
set_use_memory_efficient_attention_xformers(vae, True)
|
| 626 |
+
|
| 627 |
+
if args.gradient_checkpointing:
|
| 628 |
+
unet.enable_gradient_checkpointing()
|
| 629 |
+
if args.train_text_encoder:
|
| 630 |
+
text_encoder.gradient_checkpointing_enable()
|
| 631 |
+
|
| 632 |
+
if args.scale_lr:
|
| 633 |
+
args.learning_rate = (
|
| 634 |
+
args.learning_rate
|
| 635 |
+
* args.gradient_accumulation_steps
|
| 636 |
+
* args.train_batch_size
|
| 637 |
+
* accelerator.num_processes
|
| 638 |
+
)
|
| 639 |
+
|
| 640 |
+
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
|
| 641 |
+
if args.use_8bit_adam:
|
| 642 |
+
try:
|
| 643 |
+
import bitsandbytes as bnb
|
| 644 |
+
except ImportError:
|
| 645 |
+
raise ImportError(
|
| 646 |
+
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
|
| 647 |
+
)
|
| 648 |
+
|
| 649 |
+
optimizer_class = bnb.optim.AdamW8bit
|
| 650 |
+
else:
|
| 651 |
+
optimizer_class = torch.optim.AdamW
|
| 652 |
+
|
| 653 |
+
text_lr = (
|
| 654 |
+
args.learning_rate
|
| 655 |
+
if args.learning_rate_text is None
|
| 656 |
+
else args.learning_rate_text
|
| 657 |
+
)
|
| 658 |
+
|
| 659 |
+
params_to_optimize = (
|
| 660 |
+
[
|
| 661 |
+
{"params": itertools.chain(*unet_lora_params), "lr": args.learning_rate},
|
| 662 |
+
{
|
| 663 |
+
"params": itertools.chain(*text_encoder_lora_params),
|
| 664 |
+
"lr": text_lr,
|
| 665 |
+
},
|
| 666 |
+
]
|
| 667 |
+
if args.train_text_encoder
|
| 668 |
+
else itertools.chain(*unet_lora_params)
|
| 669 |
+
)
|
| 670 |
+
optimizer = optimizer_class(
|
| 671 |
+
params_to_optimize,
|
| 672 |
+
lr=args.learning_rate,
|
| 673 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 674 |
+
weight_decay=args.adam_weight_decay,
|
| 675 |
+
eps=args.adam_epsilon,
|
| 676 |
+
)
|
| 677 |
+
|
| 678 |
+
noise_scheduler = DDPMScheduler.from_config(
|
| 679 |
+
args.pretrained_model_name_or_path, subfolder="scheduler"
|
| 680 |
+
)
|
| 681 |
+
|
| 682 |
+
train_dataset = DreamBoothDataset(
|
| 683 |
+
instance_data_root=args.instance_data_dir,
|
| 684 |
+
instance_prompt=args.instance_prompt,
|
| 685 |
+
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
|
| 686 |
+
class_prompt=args.class_prompt,
|
| 687 |
+
tokenizer=tokenizer,
|
| 688 |
+
size=args.resolution,
|
| 689 |
+
center_crop=args.center_crop,
|
| 690 |
+
color_jitter=args.color_jitter,
|
| 691 |
+
resize=args.resize,
|
| 692 |
+
)
|
| 693 |
+
|
| 694 |
+
def collate_fn(examples):
|
| 695 |
+
input_ids = [example["instance_prompt_ids"] for example in examples]
|
| 696 |
+
pixel_values = [example["instance_images"] for example in examples]
|
| 697 |
+
|
| 698 |
+
# Concat class and instance examples for prior preservation.
|
| 699 |
+
# We do this to avoid doing two forward passes.
|
| 700 |
+
if args.with_prior_preservation:
|
| 701 |
+
input_ids += [example["class_prompt_ids"] for example in examples]
|
| 702 |
+
pixel_values += [example["class_images"] for example in examples]
|
| 703 |
+
|
| 704 |
+
pixel_values = torch.stack(pixel_values)
|
| 705 |
+
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
| 706 |
+
|
| 707 |
+
input_ids = tokenizer.pad(
|
| 708 |
+
{"input_ids": input_ids},
|
| 709 |
+
padding="max_length",
|
| 710 |
+
max_length=tokenizer.model_max_length,
|
| 711 |
+
return_tensors="pt",
|
| 712 |
+
).input_ids
|
| 713 |
+
|
| 714 |
+
batch = {
|
| 715 |
+
"input_ids": input_ids,
|
| 716 |
+
"pixel_values": pixel_values,
|
| 717 |
+
}
|
| 718 |
+
return batch
|
| 719 |
+
|
| 720 |
+
train_dataloader = torch.utils.data.DataLoader(
|
| 721 |
+
train_dataset,
|
| 722 |
+
batch_size=args.train_batch_size,
|
| 723 |
+
shuffle=True,
|
| 724 |
+
collate_fn=collate_fn,
|
| 725 |
+
num_workers=1,
|
| 726 |
+
)
|
| 727 |
+
|
| 728 |
+
# Scheduler and math around the number of training steps.
|
| 729 |
+
overrode_max_train_steps = False
|
| 730 |
+
num_update_steps_per_epoch = math.ceil(
|
| 731 |
+
len(train_dataloader) / args.gradient_accumulation_steps
|
| 732 |
+
)
|
| 733 |
+
if args.max_train_steps is None:
|
| 734 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 735 |
+
overrode_max_train_steps = True
|
| 736 |
+
|
| 737 |
+
lr_scheduler = get_scheduler(
|
| 738 |
+
args.lr_scheduler,
|
| 739 |
+
optimizer=optimizer,
|
| 740 |
+
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
|
| 741 |
+
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
| 742 |
+
)
|
| 743 |
+
|
| 744 |
+
if args.train_text_encoder:
|
| 745 |
+
(
|
| 746 |
+
unet,
|
| 747 |
+
text_encoder,
|
| 748 |
+
optimizer,
|
| 749 |
+
train_dataloader,
|
| 750 |
+
lr_scheduler,
|
| 751 |
+
) = accelerator.prepare(
|
| 752 |
+
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
|
| 753 |
+
)
|
| 754 |
+
else:
|
| 755 |
+
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 756 |
+
unet, optimizer, train_dataloader, lr_scheduler
|
| 757 |
+
)
|
| 758 |
+
|
| 759 |
+
weight_dtype = torch.float32
|
| 760 |
+
if accelerator.mixed_precision == "fp16":
|
| 761 |
+
weight_dtype = torch.float16
|
| 762 |
+
elif accelerator.mixed_precision == "bf16":
|
| 763 |
+
weight_dtype = torch.bfloat16
|
| 764 |
+
|
| 765 |
+
# Move text_encode and vae to gpu.
|
| 766 |
+
# For mixed precision training we cast the text_encoder and vae weights to half-precision
|
| 767 |
+
# as these models are only used for inference, keeping weights in full precision is not required.
|
| 768 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
| 769 |
+
if not args.train_text_encoder:
|
| 770 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
| 771 |
+
|
| 772 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
| 773 |
+
num_update_steps_per_epoch = math.ceil(
|
| 774 |
+
len(train_dataloader) / args.gradient_accumulation_steps
|
| 775 |
+
)
|
| 776 |
+
if overrode_max_train_steps:
|
| 777 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 778 |
+
# Afterwards we recalculate our number of training epochs
|
| 779 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
| 780 |
+
|
| 781 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
| 782 |
+
# The trackers initializes automatically on the main process.
|
| 783 |
+
if accelerator.is_main_process:
|
| 784 |
+
accelerator.init_trackers("dreambooth", config=vars(args))
|
| 785 |
+
|
| 786 |
+
# Train!
|
| 787 |
+
total_batch_size = (
|
| 788 |
+
args.train_batch_size
|
| 789 |
+
* accelerator.num_processes
|
| 790 |
+
* args.gradient_accumulation_steps
|
| 791 |
+
)
|
| 792 |
+
|
| 793 |
+
print("***** Running training *****")
|
| 794 |
+
print(f" Num examples = {len(train_dataset)}")
|
| 795 |
+
print(f" Num batches each epoch = {len(train_dataloader)}")
|
| 796 |
+
print(f" Num Epochs = {args.num_train_epochs}")
|
| 797 |
+
print(f" Instantaneous batch size per device = {args.train_batch_size}")
|
| 798 |
+
print(
|
| 799 |
+
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
|
| 800 |
+
)
|
| 801 |
+
print(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
| 802 |
+
print(f" Total optimization steps = {args.max_train_steps}")
|
| 803 |
+
# Only show the progress bar once on each machine.
|
| 804 |
+
progress_bar = tqdm(
|
| 805 |
+
range(args.max_train_steps), disable=not accelerator.is_local_main_process
|
| 806 |
+
)
|
| 807 |
+
progress_bar.set_description("Steps")
|
| 808 |
+
global_step = 0
|
| 809 |
+
last_save = 0
|
| 810 |
+
|
| 811 |
+
for epoch in range(args.num_train_epochs):
|
| 812 |
+
unet.train()
|
| 813 |
+
if args.train_text_encoder:
|
| 814 |
+
text_encoder.train()
|
| 815 |
+
|
| 816 |
+
for step, batch in enumerate(train_dataloader):
|
| 817 |
+
# Convert images to latent space
|
| 818 |
+
latents = vae.encode(
|
| 819 |
+
batch["pixel_values"].to(dtype=weight_dtype)
|
| 820 |
+
).latent_dist.sample()
|
| 821 |
+
latents = latents * 0.18215
|
| 822 |
+
|
| 823 |
+
# Sample noise that we'll add to the latents
|
| 824 |
+
noise = torch.randn_like(latents)
|
| 825 |
+
bsz = latents.shape[0]
|
| 826 |
+
# Sample a random timestep for each image
|
| 827 |
+
timesteps = torch.randint(
|
| 828 |
+
0,
|
| 829 |
+
noise_scheduler.config.num_train_timesteps,
|
| 830 |
+
(bsz,),
|
| 831 |
+
device=latents.device,
|
| 832 |
+
)
|
| 833 |
+
timesteps = timesteps.long()
|
| 834 |
+
|
| 835 |
+
# Add noise to the latents according to the noise magnitude at each timestep
|
| 836 |
+
# (this is the forward diffusion process)
|
| 837 |
+
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
| 838 |
+
|
| 839 |
+
# Get the text embedding for conditioning
|
| 840 |
+
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
|
| 841 |
+
|
| 842 |
+
# Predict the noise residual
|
| 843 |
+
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
| 844 |
+
|
| 845 |
+
# Get the target for loss depending on the prediction type
|
| 846 |
+
if noise_scheduler.config.prediction_type == "epsilon":
|
| 847 |
+
target = noise
|
| 848 |
+
elif noise_scheduler.config.prediction_type == "v_prediction":
|
| 849 |
+
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
| 850 |
+
else:
|
| 851 |
+
raise ValueError(
|
| 852 |
+
f"Unknown prediction type {noise_scheduler.config.prediction_type}"
|
| 853 |
+
)
|
| 854 |
+
|
| 855 |
+
if args.with_prior_preservation:
|
| 856 |
+
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
|
| 857 |
+
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
|
| 858 |
+
target, target_prior = torch.chunk(target, 2, dim=0)
|
| 859 |
+
|
| 860 |
+
# Compute instance loss
|
| 861 |
+
loss = (
|
| 862 |
+
F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
| 863 |
+
.mean([1, 2, 3])
|
| 864 |
+
.mean()
|
| 865 |
+
)
|
| 866 |
+
|
| 867 |
+
# Compute prior loss
|
| 868 |
+
prior_loss = F.mse_loss(
|
| 869 |
+
model_pred_prior.float(), target_prior.float(), reduction="mean"
|
| 870 |
+
)
|
| 871 |
+
|
| 872 |
+
# Add the prior loss to the instance loss.
|
| 873 |
+
loss = loss + args.prior_loss_weight * prior_loss
|
| 874 |
+
else:
|
| 875 |
+
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
| 876 |
+
|
| 877 |
+
accelerator.backward(loss)
|
| 878 |
+
if accelerator.sync_gradients:
|
| 879 |
+
params_to_clip = (
|
| 880 |
+
itertools.chain(unet.parameters(), text_encoder.parameters())
|
| 881 |
+
if args.train_text_encoder
|
| 882 |
+
else unet.parameters()
|
| 883 |
+
)
|
| 884 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
| 885 |
+
optimizer.step()
|
| 886 |
+
lr_scheduler.step()
|
| 887 |
+
progress_bar.update(1)
|
| 888 |
+
optimizer.zero_grad()
|
| 889 |
+
|
| 890 |
+
global_step += 1
|
| 891 |
+
|
| 892 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
| 893 |
+
if accelerator.sync_gradients:
|
| 894 |
+
if args.save_steps and global_step - last_save >= args.save_steps:
|
| 895 |
+
if accelerator.is_main_process:
|
| 896 |
+
# newer versions of accelerate allow the 'keep_fp32_wrapper' arg. without passing
|
| 897 |
+
# it, the models will be unwrapped, and when they are then used for further training,
|
| 898 |
+
# we will crash. pass this, but only to newer versions of accelerate. fixes
|
| 899 |
+
# https://github.com/huggingface/diffusers/issues/1566
|
| 900 |
+
accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set(
|
| 901 |
+
inspect.signature(
|
| 902 |
+
accelerator.unwrap_model
|
| 903 |
+
).parameters.keys()
|
| 904 |
+
)
|
| 905 |
+
extra_args = (
|
| 906 |
+
{"keep_fp32_wrapper": True}
|
| 907 |
+
if accepts_keep_fp32_wrapper
|
| 908 |
+
else {}
|
| 909 |
+
)
|
| 910 |
+
pipeline = StableDiffusionPipeline.from_pretrained(
|
| 911 |
+
args.pretrained_model_name_or_path,
|
| 912 |
+
unet=accelerator.unwrap_model(unet, **extra_args),
|
| 913 |
+
text_encoder=accelerator.unwrap_model(
|
| 914 |
+
text_encoder, **extra_args
|
| 915 |
+
),
|
| 916 |
+
revision=args.revision,
|
| 917 |
+
)
|
| 918 |
+
|
| 919 |
+
filename_unet = (
|
| 920 |
+
f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.pt"
|
| 921 |
+
)
|
| 922 |
+
filename_text_encoder = f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.text_encoder.pt"
|
| 923 |
+
print(f"save weights {filename_unet}, {filename_text_encoder}")
|
| 924 |
+
save_lora_weight(pipeline.unet, filename_unet)
|
| 925 |
+
if args.train_text_encoder:
|
| 926 |
+
save_lora_weight(
|
| 927 |
+
pipeline.text_encoder,
|
| 928 |
+
filename_text_encoder,
|
| 929 |
+
target_replace_module=["CLIPAttention"],
|
| 930 |
+
)
|
| 931 |
+
|
| 932 |
+
for _up, _down in extract_lora_ups_down(pipeline.unet):
|
| 933 |
+
print(
|
| 934 |
+
"First Unet Layer's Up Weight is now : ",
|
| 935 |
+
_up.weight.data,
|
| 936 |
+
)
|
| 937 |
+
print(
|
| 938 |
+
"First Unet Layer's Down Weight is now : ",
|
| 939 |
+
_down.weight.data,
|
| 940 |
+
)
|
| 941 |
+
break
|
| 942 |
+
if args.train_text_encoder:
|
| 943 |
+
for _up, _down in extract_lora_ups_down(
|
| 944 |
+
pipeline.text_encoder,
|
| 945 |
+
target_replace_module=["CLIPAttention"],
|
| 946 |
+
):
|
| 947 |
+
print(
|
| 948 |
+
"First Text Encoder Layer's Up Weight is now : ",
|
| 949 |
+
_up.weight.data,
|
| 950 |
+
)
|
| 951 |
+
print(
|
| 952 |
+
"First Text Encoder Layer's Down Weight is now : ",
|
| 953 |
+
_down.weight.data,
|
| 954 |
+
)
|
| 955 |
+
break
|
| 956 |
+
|
| 957 |
+
last_save = global_step
|
| 958 |
+
|
| 959 |
+
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
| 960 |
+
progress_bar.set_postfix(**logs)
|
| 961 |
+
accelerator.log(logs, step=global_step)
|
| 962 |
+
|
| 963 |
+
if global_step >= args.max_train_steps:
|
| 964 |
+
break
|
| 965 |
+
|
| 966 |
+
accelerator.wait_for_everyone()
|
| 967 |
+
|
| 968 |
+
# Create the pipeline using using the trained modules and save it.
|
| 969 |
+
if accelerator.is_main_process:
|
| 970 |
+
pipeline = StableDiffusionPipeline.from_pretrained(
|
| 971 |
+
args.pretrained_model_name_or_path,
|
| 972 |
+
unet=accelerator.unwrap_model(unet),
|
| 973 |
+
text_encoder=accelerator.unwrap_model(text_encoder),
|
| 974 |
+
revision=args.revision,
|
| 975 |
+
)
|
| 976 |
+
|
| 977 |
+
print("\n\nLora TRAINING DONE!\n\n")
|
| 978 |
+
|
| 979 |
+
if args.output_format == "pt" or args.output_format == "both":
|
| 980 |
+
save_lora_weight(pipeline.unet, args.output_dir + "/lora_weight.pt")
|
| 981 |
+
if args.train_text_encoder:
|
| 982 |
+
save_lora_weight(
|
| 983 |
+
pipeline.text_encoder,
|
| 984 |
+
args.output_dir + "/lora_weight.text_encoder.pt",
|
| 985 |
+
target_replace_module=["CLIPAttention"],
|
| 986 |
+
)
|
| 987 |
+
|
| 988 |
+
if args.output_format == "safe" or args.output_format == "both":
|
| 989 |
+
loras = {}
|
| 990 |
+
loras["unet"] = (pipeline.unet, {"CrossAttention", "Attention", "GEGLU"})
|
| 991 |
+
if args.train_text_encoder:
|
| 992 |
+
loras["text_encoder"] = (pipeline.text_encoder, {"CLIPAttention"})
|
| 993 |
+
|
| 994 |
+
save_safeloras(loras, args.output_dir + "/lora_weight.safetensors")
|
| 995 |
+
|
| 996 |
+
if args.push_to_hub:
|
| 997 |
+
repo.push_to_hub(
|
| 998 |
+
commit_message="End of training",
|
| 999 |
+
blocking=False,
|
| 1000 |
+
auto_lfs_prune=True,
|
| 1001 |
+
)
|
| 1002 |
+
|
| 1003 |
+
accelerator.end_training()
|
| 1004 |
+
|
| 1005 |
+
|
| 1006 |
+
if __name__ == "__main__":
|
| 1007 |
+
args = parse_args()
|
| 1008 |
+
main(args)
|
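Finally, a minimal sketch of driving this script programmatically through parse_args(input_args=...) and main(...); every path, prompt, and hyperparameter below is a placeholder, and in practice the script would normally be launched with accelerate.

# Hypothetical invocation sketch: all paths, prompts, and hyperparameters are placeholders.
from train_dreambooth_cloneofsimo_lora import parse_args, main

args = parse_args(
    [
        "--pretrained_model_name_or_path", "runwayml/stable-diffusion-v1-5",
        "--instance_data_dir", "./data/example",
        "--instance_prompt", "a photo of sks person",
        "--output_dir", "./exps/example",
        "--resolution", "512",
        "--train_batch_size", "1",
        "--learning_rate", "1e-4",
        "--lr_scheduler", "constant",
        "--lr_warmup_steps", "0",
        "--max_train_steps", "1000",
        "--train_text_encoder",
    ]
)
main(args)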