Spaces:
Configuration error
Configuration error
Duplicate from lora-library/Low-rank-Adaptation
Browse filesCo-authored-by: yuvraj sharma <[email protected]>
- .gitattributes +36 -0
- .gitignore +6 -0
- README.md +14 -0
- README1.md +139 -0
- app.py +121 -0
- contents/alpha_scale.gif +3 -0
- contents/alpha_scale.mp4 +3 -0
- contents/disney_lora.jpg +0 -0
- contents/pop_art.jpg +0 -0
- lora_diffusion/__init__.py +1 -0
- lora_diffusion/cli_lora_add.py +118 -0
- lora_diffusion/lora.py +355 -0
- lora_diffusion/to_ckpt_v2.py +232 -0
- lora_disney.pt +3 -0
- lora_illust.pt +3 -0
- lora_playgroundai_wt.pt +3 -0
- lora_pop.pt +3 -0
- lora_weight.pt +3 -0
- output_example/dummy.txt +0 -0
- requirements.txt +7 -0
- run_lora_db.sh +17 -0
- scripts/make_alpha_gifs.ipynb +0 -0
- scripts/run_inference.ipynb +0 -0
- setup.py +25 -0
- simba1.jpg +0 -0
- simba2.jpg +0 -0
- simba3.jpg +0 -0
- simba4.jpg +0 -0
- train_lora_dreambooth.py +956 -0
- train_lora_dreambooth1.py +964 -0
.gitattributes
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
+
contents/alpha_scale.gif filter=lfs diff=lfs merge=lfs -text
|
36 |
+
contents/alpha_scale.mp4 filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
data_*
|
2 |
+
output_*
|
3 |
+
__pycache__
|
4 |
+
*.pyc
|
5 |
+
__test*
|
6 |
+
merged_lora*
|
README.md
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: LORA-Low-rank-Adaptation
|
3 |
+
emoji: 📚
|
4 |
+
colorFrom: purple
|
5 |
+
colorTo: gray
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 3.12
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
license: mit
|
11 |
+
duplicated_from: lora-library/Low-rank-Adaptation
|
12 |
+
---
|
13 |
+
|
14 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
README1.md
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Low-rank Adaptation for Fast Text-to-Image Diffusion Fine-tuning
|
2 |
+
|
3 |
+
<!-- #region -->
|
4 |
+
<p align="center">
|
5 |
+
<img src="contents/alpha_scale.gif">
|
6 |
+
</p>
|
7 |
+
<!-- #endregion -->
|
8 |
+
|
9 |
+
> Using LORA to fine tune on illustration dataset : $W = W_0 + \alpha \Delta W$, where $\alpha$ is the merging ratio. Above gif is scaling alpha from 0 to 1. Setting alpha to 0 is same as using the original model, and setting alpha to 1 is same as using the fully fine-tuned model.
|
10 |
+
|
11 |
+
<!-- #region -->
|
12 |
+
<p align="center">
|
13 |
+
<img src="contents/disney_lora.jpg">
|
14 |
+
</p>
|
15 |
+
<!-- #endregion -->
|
16 |
+
|
17 |
+
> "style of sks, baby lion", with disney-style LORA model.
|
18 |
+
|
19 |
+
<!-- #region -->
|
20 |
+
<p align="center">
|
21 |
+
<img src="contents/pop_art.jpg">
|
22 |
+
</p>
|
23 |
+
<!-- #endregion -->
|
24 |
+
|
25 |
+
> "style of sks, superman", with pop-art style LORA model.
|
26 |
+
|
27 |
+
## Main Features
|
28 |
+
|
29 |
+
- Fine-tune Stable diffusion models twice as faster than dreambooth method, by Low-rank Adaptation
|
30 |
+
- Get insanely small end result, easy to share and download.
|
31 |
+
- Easy to use, compatible with diffusers
|
32 |
+
- Sometimes even better performance than full fine-tuning (but left as future work for extensive comparisons)
|
33 |
+
- Merge checkpoints by merging LORA
|
34 |
+
|
35 |
+
# Lengthy Introduction
|
36 |
+
|
37 |
+
Thanks to the generous work of Stability AI and Huggingface, so many people have enjoyed fine-tuning stable diffusion models to fit their needs and generate higher fidelity images. **However, the fine-tuning process is very slow, and it is not easy to find a good balance between the number of steps and the quality of the results.**
|
38 |
+
|
39 |
+
Also, the final results (fully fined-tuned model) is very large. Some people instead works with textual-inversion as an alternative for this. But clearly this is suboptimal: textual inversion only creates a small word-embedding, and the final image is not as good as a fully fine-tuned model.
|
40 |
+
|
41 |
+
Well, what's the alternative? In the domain of LLM, researchers have developed Efficient fine-tuning methods. LORA, especially, tackles the very problem the community currently has: end users with Open-sourced stable-diffusion model want to try various other fine-tuned model that is created by the community, but the model is too large to download and use. LORA instead attempts to fine-tune the "residual" of the model instead of the entire model: i.e., train the $\Delta W$ instead of $W$.
|
42 |
+
|
43 |
+
$$
|
44 |
+
W' = W + \Delta W
|
45 |
+
$$
|
46 |
+
|
47 |
+
Where we can further decompose $\Delta W$ into low-rank matrices : $\Delta W = A B^T $, where $A, \in \mathbb{R}^{n \times d}, B \in \mathbb{R}^{m \times d}, d << n$.
|
48 |
+
This is the key idea of LORA. We can then fine-tune $A$ and $B$ instead of $W$. In the end, you get an insanely small model as $A$ and $B$ are much smaller than $W$.
|
49 |
+
|
50 |
+
Also, not all of the parameters need tuning: they found that often, $Q, K, V, O$ (i.e., attention layer) of the transformer model is enough to tune. (This is also the reason why the end result is so small). This repo will follow the same idea.
|
51 |
+
|
52 |
+
Enough of the lengthy introduction, let's get to the code.
|
53 |
+
|
54 |
+
# Installation
|
55 |
+
|
56 |
+
```bash
|
57 |
+
pip install git+https://github.com/cloneofsimo/lora.git
|
58 |
+
```
|
59 |
+
|
60 |
+
# Getting Started
|
61 |
+
|
62 |
+
## Fine-tuning Stable diffusion with LORA.
|
63 |
+
|
64 |
+
Basic usage is as follows: prepare sets of $A, B$ matrices in an unet model, and fine-tune them.
|
65 |
+
|
66 |
+
```python
|
67 |
+
from lora_diffusion import inject_trainable_lora, extract_lora_up_downs
|
68 |
+
|
69 |
+
...
|
70 |
+
|
71 |
+
unet = UNet2DConditionModel.from_pretrained(
|
72 |
+
pretrained_model_name_or_path,
|
73 |
+
subfolder="unet",
|
74 |
+
)
|
75 |
+
unet.requires_grad_(False)
|
76 |
+
unet_lora_params, train_names = inject_trainable_lora(unet) # This will
|
77 |
+
# turn off all of the gradients of unet, except for the trainable LORA params.
|
78 |
+
optimizer = optim.Adam(
|
79 |
+
itertools.chain(*unet_lora_params, text_encoder.parameters()), lr=1e-4
|
80 |
+
)
|
81 |
+
```
|
82 |
+
|
83 |
+
An example of this can be found in `train_lora_dreambooth.py`. Run this example with
|
84 |
+
|
85 |
+
```bash
|
86 |
+
run_lora_db.sh
|
87 |
+
```
|
88 |
+
|
89 |
+
## Loading, merging, and interpolating trained LORAs.
|
90 |
+
|
91 |
+
We've seen that people have been merging different checkpoints with different ratios, and this seems to be very useful to the community. LORA is extremely easy to merge.
|
92 |
+
|
93 |
+
By the nature of LORA, one can interpolate between different fine-tuned models by adding different $A, B$ matrices.
|
94 |
+
|
95 |
+
Currently, LORA cli has two options : merge unet with LORA, or merge LORA with LORA.
|
96 |
+
|
97 |
+
### Merging unet with LORA
|
98 |
+
|
99 |
+
```bash
|
100 |
+
$ lora_add --path_1 PATH_TO_DIFFUSER_FORMAT_MODEL --path_2 PATH_TO_LORA.PT --mode upl --alpha 1.0 --output_path OUTPUT_PATH
|
101 |
+
```
|
102 |
+
|
103 |
+
`path_1` can be both local path or huggingface model name. When adding LORA to unet, alpha is the constant as below:
|
104 |
+
|
105 |
+
$$
|
106 |
+
W' = W + \alpha \Delta W
|
107 |
+
$$
|
108 |
+
|
109 |
+
So, set alpha to 1.0 to fully add LORA. If the LORA seems to have too much effect (i.e., overfitted), set alpha to lower value. If the LORA seems to have too little effect, set alpha to higher than 1.0. You can tune these values to your needs.
|
110 |
+
|
111 |
+
**Example**
|
112 |
+
|
113 |
+
```bash
|
114 |
+
$ lora_add --path_1 stabilityai/stable-diffusion-2-base --path_2 lora_illust.pt --mode upl --alpha 1.0 --output_path merged_model
|
115 |
+
```
|
116 |
+
|
117 |
+
### Merging LORA with LORA
|
118 |
+
|
119 |
+
```bash
|
120 |
+
$ lora_add --path_1 PATH_TO_LORA.PT --path_2 PATH_TO_LORA.PT --mode lpl --alpha 0.5 --output_path OUTPUT_PATH.PT
|
121 |
+
```
|
122 |
+
|
123 |
+
alpha is the ratio of the first model to the second model. i.e.,
|
124 |
+
|
125 |
+
$$
|
126 |
+
\Delta W = (\alpha A_1 + (1 - \alpha) A_2) (B_1 + (1 - \alpha) B_2)^T
|
127 |
+
$$
|
128 |
+
|
129 |
+
Set alpha to 0.5 to get the average of the two models. Set alpha close to 1.0 to get more effect of the first model, and set alpha close to 0.0 to get more effect of the second model.
|
130 |
+
|
131 |
+
**Example**
|
132 |
+
|
133 |
+
```bash
|
134 |
+
$ lora_add --path_1 lora_illust.pt --path_2 lora_pop.pt --alpha 0.3 --output_path lora_merged.pt
|
135 |
+
```
|
136 |
+
|
137 |
+
### Making Inference with trained LORA
|
138 |
+
|
139 |
+
Checkout `scripts/run_inference.ipynb` for an example of how to make inference with LORA.
|
app.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from diffusers import StableDiffusionPipeline
|
2 |
+
from lora_diffusion import monkeypatch_lora, tune_lora_scale
|
3 |
+
import torch
|
4 |
+
import os, shutil
|
5 |
+
import gradio as gr
|
6 |
+
import subprocess
|
7 |
+
|
8 |
+
MODEL_NAME="stabilityai/stable-diffusion-2-1-base"
|
9 |
+
INSTANCE_DIR="./data_example"
|
10 |
+
OUTPUT_DIR="./output_example"
|
11 |
+
|
12 |
+
model_id = "stabilityai/stable-diffusion-2-1-base"
|
13 |
+
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
|
14 |
+
#prompt = "style of sks, baby lion"
|
15 |
+
torch.manual_seed(1)
|
16 |
+
#image = pipe(prompt, num_inference_steps=50, guidance_scale= 7).images[0] #no need
|
17 |
+
#image # nice. diffusers are cool. #no need
|
18 |
+
#finetuned_lora_weights = "./lora_weight.pt"
|
19 |
+
|
20 |
+
#global var
|
21 |
+
counter = 0
|
22 |
+
|
23 |
+
#Getting Lora fine-tuned weights
|
24 |
+
def monkeypatching(unet_alpha,texten_alpha, in_prompt, wts): #, prompt, pipe): finetuned_lora_weights
|
25 |
+
print("****** inside monkeypatching *******")
|
26 |
+
print(f"in_prompt is - {str(in_prompt)}")
|
27 |
+
global counter
|
28 |
+
#if model == 'Text-encoder':
|
29 |
+
unet_wt = wts[-2]
|
30 |
+
#else:
|
31 |
+
texten_wt = wts[-1]
|
32 |
+
print(f"UNET weight is = {unet_wt}, Text-encoder weight is = {texten_wt}")
|
33 |
+
if counter == 0 :
|
34 |
+
#if wt == "./lora_playgroundai_wt.pt" :
|
35 |
+
monkeypatch_lora(pipe.unet, torch.load(unet_wt)) #finetuned_lora_weights
|
36 |
+
monkeypatch_lora(pipe.text_encoder, torch.load(texten_wt), target_replace_module=["CLIPAttention"]) #text-encoder #"./lora/lora_kiriko.text_encoder.pt"
|
37 |
+
#tune_lora_scale(pipe.unet, alpha) #1.00)
|
38 |
+
tune_lora_scale(pipe.unet, unet_alpha)
|
39 |
+
tune_lora_scale(pipe.text_encoder, texten_alpha)
|
40 |
+
counter +=1
|
41 |
+
#else:
|
42 |
+
#monkeypatch_lora(pipe.unet, torch.load("./output_example/lora_weight.pt")) #finetuned_lora_weights
|
43 |
+
#tune_lora_scale(pipe.unet, alpha) #1.00)
|
44 |
+
#counter +=1
|
45 |
+
else :
|
46 |
+
tune_lora_scale(pipe.unet, unet_alpha)
|
47 |
+
tune_lora_scale(pipe.text_encoder, texten_alpha)
|
48 |
+
#tune_lora_scale(pipe.unet, alpha) #1.00)
|
49 |
+
prompt = str(in_prompt) #"style of hclu, " + str(in_prompt) #"baby lion"
|
50 |
+
image = pipe(prompt, num_inference_steps=50, guidance_scale=7).images[0]
|
51 |
+
image.save("./illust_lora.jpg") #"./contents/illust_lora.jpg")
|
52 |
+
return image
|
53 |
+
|
54 |
+
#{in_prompt} line68 --pl ignore
|
55 |
+
def accelerate_train_lora(steps, images, in_prompt):
|
56 |
+
print("*********** inside accelerate_train_lora ***********")
|
57 |
+
print(f"images are -- {images}")
|
58 |
+
# path can be retrieved by file_obj.name and original filename can be retrieved with file_obj.orig_name
|
59 |
+
for file in images:
|
60 |
+
print(f"file passed -- {file.name}")
|
61 |
+
os.makedirs(INSTANCE_DIR, exist_ok=True)
|
62 |
+
shutil.copy( file.name, INSTANCE_DIR) #/{file.orig_name}
|
63 |
+
#subprocess.Popen(f'accelerate launch {"./train_lora_dreambooth.py"} \
|
64 |
+
os.system( f"accelerate launch {'./train_lora_dreambooth.py'} --pretrained_model_name_or_path={MODEL_NAME} --instance_data_dir={INSTANCE_DIR} --output_dir={OUTPUT_DIR} --instance_prompt='{in_prompt}' --train_text_encoder --resolution=512 --train_batch_size=1 --gradient_accumulation_steps=1 --learning_rate='1e-4' --learning_rate_text='5e-5' --color_jitter --lr_scheduler='constant' --lr_warmup_steps=0 --max_train_steps={int(steps)}") #10000
|
65 |
+
print("*********** completing accelerate_train_lora ***********")
|
66 |
+
print(f"files in output_dir -- {os.listdir(OUTPUT_DIR)}")
|
67 |
+
#lora_trained_weights = "./output_example/lora_weight.pt"
|
68 |
+
files = os.listdir(OUTPUT_DIR)
|
69 |
+
file_list = []
|
70 |
+
for file in files: #os.listdir(OUTPUT_DIR):
|
71 |
+
if file.endswith(".pt"):
|
72 |
+
print("weight files are -- ",os.path.join(f"{OUTPUT_DIR}", file))
|
73 |
+
file_list.append(os.path.join(f"{OUTPUT_DIR}", file))
|
74 |
+
return file_list #files[1:]
|
75 |
+
#return f"{OUTPUT_DIR}/*.pt"
|
76 |
+
|
77 |
+
with gr.Blocks() as demo:
|
78 |
+
gr.Markdown("""<h1><center><b>LORA</b> (Low-rank Adaptation) for Faster Text-to-Image Diffusion Fine-tuning (UNET+CLIP)</center></h1>
|
79 |
+
""")
|
80 |
+
gr.HTML("<p>You can skip the queue by duplicating this space and upgrading to GPU in settings: <a style='display:inline-block' href='https://huggingface.co/spaces/ysharma/Low-rank-Adaptation?duplicate=true'><img src='https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14' alt='Duplicate Space'></a></p>")
|
81 |
+
#gr.Markdown("""<b>NEW!!</b> : I have fine-tuned the SD model for 15,000 steps using 100 PlaygroundAI images and LORA. You can load this trained model using the example component. Load the weight and start using the Space with the Inference button. Feel free to toggle the Alpha value.""")
|
82 |
+
gr.Markdown(
|
83 |
+
"""**Main Features**<br>- Fine-tune Stable diffusion models twice as faster as Dreambooth method by Low-rank Adaptation.<br>- Get insanely small end results, easy to share and download.<br>- Easy to use, compatible with diffusers.<br>- Sometimes even better performance than full fine-tuning<br><br>Please refer to the GitHub repo this Space is based on, here - <a href = "https://github.com/cloneofsimo/lora">LORA</a>. You can also refer to this tweet by AK to quote/retweet/like here on <a href="https://twitter.com/_akhaliq/status/1601120767009513472">Twitter</a>.This Gradio Space is an attempt to explore this novel LORA approach to fine-tune Stable diffusion models, using the power and flexibility of Gradio! The higher number of steps results in longer training time and better fine-tuned SD models.<br><br><b>To use this Space well:</b><br>- First, upload your set of images (suggested number of images is between 4-9), enter the prompt, enter the number of fine-tuning steps (suggested value is between 2000-4000), and then press the 'Train LORA model' button. This will produce your fine-tuned model weights.<br>- Modify the previous prompt by adding to it (suffix), set the alpha value using the Slider (nearer to 1 implies overfitting to the uploaded images), and then press the 'Inference' button. This will produce an image by the newly fine-tuned UNET and Text-Encoder LORA models.<br><b>Bonus:</b>You can download your fine-tuned model weights from the Gradio file component. The smaller size of LORA models (around 3-4 MB files) is the main highlight of this 'Low-rank Adaptation' approach of fine-tuning.""")
|
84 |
+
|
85 |
+
with gr.Row():
|
86 |
+
in_images = gr.File(label="Upload images to fine-tune for LORA", file_count="multiple")
|
87 |
+
with gr.Column():
|
88 |
+
b1 = gr.Button(value="Train LORA model")
|
89 |
+
in_prompt = gr.Textbox(label="Enter a prompt for fine-tuned LORA model", visible=True)
|
90 |
+
b2 = gr.Button(value="Inference using LORA model")
|
91 |
+
|
92 |
+
with gr.Row():
|
93 |
+
out_image = gr.Image(label="Image generated by LORA model")
|
94 |
+
with gr.Column():
|
95 |
+
with gr.Accordion("Advance settings for Training and Inference", open=False):
|
96 |
+
gr.Markdown("Advance settings for a number of Training Steps and Alpha. Set alpha to 1.0 to fully add LORA. If the LORA seems to have too much effect (i.e., overfitting), set alpha to a lower value. If the LORA seems to have too little effect, set the alpha higher. You can tune these two values to your needs.")
|
97 |
+
in_steps = gr.Number(label="Enter the number of training steps", value = 2000)
|
98 |
+
in_alpha_unet = gr.Slider(0.1,1.0, step=0.01, label="Set UNET Alpha level", value=0.5)
|
99 |
+
in_alpha_texten = gr.Slider(0.1,1.0, step=0.01, label="Set Text-Encoder Alpha level", value=0.5)
|
100 |
+
#in_model = gr.Radio(["Text-encoder", "Unet"], label="Select the fine-tuned model for inference", value="Text-encoder", type="value")
|
101 |
+
out_file = gr.File(label="Lora trained model weights", file_count='multiple' )
|
102 |
+
|
103 |
+
#gr.Examples(
|
104 |
+
# examples=[[0.65, 0.6, "lion", ["./lora_playgroundai_wt.pt","./lora_playgroundai_wt.pt"], ],],
|
105 |
+
# inputs=[in_alpha_unet, in_alpha_texten, in_prompt, out_file ],
|
106 |
+
# outputs=out_image,
|
107 |
+
# fn=monkeypatching,
|
108 |
+
# cache_examples=True,)
|
109 |
+
#gr.Examples(
|
110 |
+
# examples=[[2500, ['./simba1.jpg', './simba2.jpg', './simba3.jpg', './simba4.jpg'], "baby lion in disney style"]],
|
111 |
+
# inputs=[in_steps, in_images, in_prompt],
|
112 |
+
# outputs=out_file,
|
113 |
+
# fn=accelerate_train_lora,
|
114 |
+
# cache_examples=False,
|
115 |
+
# run_on_click=False)
|
116 |
+
|
117 |
+
b1.click(fn = accelerate_train_lora, inputs=[in_steps, in_images, in_prompt] , outputs=out_file)
|
118 |
+
b2.click(fn = monkeypatching, inputs=[in_alpha_unet, in_alpha_texten, in_prompt, out_file,], outputs=out_image)
|
119 |
+
|
120 |
+
demo.queue(concurrency_count=3)
|
121 |
+
demo.launch(debug=True, show_error=True,)
|
contents/alpha_scale.gif
ADDED
![]() |
Git LFS Details
|
contents/alpha_scale.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1ad74f5f69d99bfcbeee1d4d2b3900ac1ca7ff83fba5ddf8269ffed8a56c9c6e
|
3 |
+
size 5247140
|
contents/disney_lora.jpg
ADDED
![]() |
contents/pop_art.jpg
ADDED
![]() |
lora_diffusion/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .lora import *
|
lora_diffusion/cli_lora_add.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Literal, Union, Dict
|
2 |
+
import os
|
3 |
+
import shutil
|
4 |
+
import fire
|
5 |
+
from diffusers import StableDiffusionPipeline
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from .lora import tune_lora_scale, weight_apply_lora
|
9 |
+
from .to_ckpt_v2 import convert_to_ckpt
|
10 |
+
|
11 |
+
|
12 |
+
def _text_lora_path(path: str) -> str:
|
13 |
+
assert path.endswith(".pt"), "Only .pt files are supported"
|
14 |
+
return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])
|
15 |
+
|
16 |
+
|
17 |
+
def add(
|
18 |
+
path_1: str,
|
19 |
+
path_2: str,
|
20 |
+
output_path: str,
|
21 |
+
alpha: float = 0.5,
|
22 |
+
mode: Literal[
|
23 |
+
"lpl",
|
24 |
+
"upl",
|
25 |
+
"upl-ckpt-v2",
|
26 |
+
] = "lpl",
|
27 |
+
with_text_lora: bool = False,
|
28 |
+
):
|
29 |
+
print("Lora Add, mode " + mode)
|
30 |
+
if mode == "lpl":
|
31 |
+
for _path_1, _path_2, opt in [(path_1, path_2, "unet")] + (
|
32 |
+
[(_text_lora_path(path_1), _text_lora_path(path_2), "text_encoder")]
|
33 |
+
if with_text_lora
|
34 |
+
else []
|
35 |
+
):
|
36 |
+
print("Loading", _path_1, _path_2)
|
37 |
+
out_list = []
|
38 |
+
if opt == "text_encoder":
|
39 |
+
if not os.path.exists(_path_1):
|
40 |
+
print(f"No text encoder found in {_path_1}, skipping...")
|
41 |
+
continue
|
42 |
+
if not os.path.exists(_path_2):
|
43 |
+
print(f"No text encoder found in {_path_1}, skipping...")
|
44 |
+
continue
|
45 |
+
|
46 |
+
l1 = torch.load(_path_1)
|
47 |
+
l2 = torch.load(_path_2)
|
48 |
+
|
49 |
+
l1pairs = zip(l1[::2], l1[1::2])
|
50 |
+
l2pairs = zip(l2[::2], l2[1::2])
|
51 |
+
|
52 |
+
for (x1, y1), (x2, y2) in zip(l1pairs, l2pairs):
|
53 |
+
# print("Merging", x1.shape, y1.shape, x2.shape, y2.shape)
|
54 |
+
x1.data = alpha * x1.data + (1 - alpha) * x2.data
|
55 |
+
y1.data = alpha * y1.data + (1 - alpha) * y2.data
|
56 |
+
|
57 |
+
out_list.append(x1)
|
58 |
+
out_list.append(y1)
|
59 |
+
|
60 |
+
if opt == "unet":
|
61 |
+
|
62 |
+
print("Saving merged UNET to", output_path)
|
63 |
+
torch.save(out_list, output_path)
|
64 |
+
|
65 |
+
elif opt == "text_encoder":
|
66 |
+
print("Saving merged text encoder to", _text_lora_path(output_path))
|
67 |
+
torch.save(
|
68 |
+
out_list,
|
69 |
+
_text_lora_path(output_path),
|
70 |
+
)
|
71 |
+
|
72 |
+
elif mode == "upl":
|
73 |
+
|
74 |
+
loaded_pipeline = StableDiffusionPipeline.from_pretrained(
|
75 |
+
path_1,
|
76 |
+
).to("cpu")
|
77 |
+
|
78 |
+
weight_apply_lora(loaded_pipeline.unet, torch.load(path_2), alpha=alpha)
|
79 |
+
if with_text_lora:
|
80 |
+
|
81 |
+
weight_apply_lora(
|
82 |
+
loaded_pipeline.text_encoder,
|
83 |
+
torch.load(_text_lora_path(path_2)),
|
84 |
+
alpha=alpha,
|
85 |
+
target_replace_module=["CLIPAttention"],
|
86 |
+
)
|
87 |
+
|
88 |
+
loaded_pipeline.save_pretrained(output_path)
|
89 |
+
|
90 |
+
elif mode == "upl-ckpt-v2":
|
91 |
+
|
92 |
+
loaded_pipeline = StableDiffusionPipeline.from_pretrained(
|
93 |
+
path_1,
|
94 |
+
).to("cpu")
|
95 |
+
|
96 |
+
weight_apply_lora(loaded_pipeline.unet, torch.load(path_2), alpha=alpha)
|
97 |
+
if with_text_lora:
|
98 |
+
weight_apply_lora(
|
99 |
+
loaded_pipeline.text_encoder,
|
100 |
+
torch.load(_text_lora_path(path_2)),
|
101 |
+
alpha=alpha,
|
102 |
+
target_replace_module=["CLIPAttention"],
|
103 |
+
)
|
104 |
+
|
105 |
+
_tmp_output = output_path + ".tmp"
|
106 |
+
|
107 |
+
loaded_pipeline.save_pretrained(_tmp_output)
|
108 |
+
convert_to_ckpt(_tmp_output, output_path, as_half=True)
|
109 |
+
# remove the tmp_output folder
|
110 |
+
shutil.rmtree(_tmp_output)
|
111 |
+
|
112 |
+
else:
|
113 |
+
print("Unknown mode", mode)
|
114 |
+
raise ValueError(f"Unknown mode {mode}")
|
115 |
+
|
116 |
+
|
117 |
+
def main():
|
118 |
+
fire.Fire(add)
|
lora_diffusion/lora.py
ADDED
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Callable, Dict, List, Optional, Tuple
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import PIL
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
import torch.nn as nn
|
10 |
+
|
11 |
+
|
12 |
+
class LoraInjectedLinear(nn.Module):
|
13 |
+
def __init__(self, in_features, out_features, bias=False, r=4):
|
14 |
+
super().__init__()
|
15 |
+
|
16 |
+
if r > min(in_features, out_features):
|
17 |
+
raise ValueError(
|
18 |
+
f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
|
19 |
+
)
|
20 |
+
|
21 |
+
self.linear = nn.Linear(in_features, out_features, bias)
|
22 |
+
self.lora_down = nn.Linear(in_features, r, bias=False)
|
23 |
+
self.lora_up = nn.Linear(r, out_features, bias=False)
|
24 |
+
self.scale = 1.0
|
25 |
+
|
26 |
+
nn.init.normal_(self.lora_down.weight, std=1 / r**2)
|
27 |
+
nn.init.zeros_(self.lora_up.weight)
|
28 |
+
|
29 |
+
def forward(self, input):
|
30 |
+
return self.linear(input) + self.lora_up(self.lora_down(input)) * self.scale
|
31 |
+
|
32 |
+
|
33 |
+
def inject_trainable_lora(
|
34 |
+
model: nn.Module,
|
35 |
+
target_replace_module: List[str] = ["CrossAttention", "Attention"],
|
36 |
+
r: int = 4,
|
37 |
+
loras=None, # path to lora .pt
|
38 |
+
):
|
39 |
+
"""
|
40 |
+
inject lora into model, and returns lora parameter groups.
|
41 |
+
"""
|
42 |
+
|
43 |
+
require_grad_params = []
|
44 |
+
names = []
|
45 |
+
|
46 |
+
if loras != None:
|
47 |
+
loras = torch.load(loras)
|
48 |
+
|
49 |
+
for _module in model.modules():
|
50 |
+
if _module.__class__.__name__ in target_replace_module:
|
51 |
+
|
52 |
+
for name, _child_module in _module.named_modules():
|
53 |
+
if _child_module.__class__.__name__ == "Linear":
|
54 |
+
|
55 |
+
weight = _child_module.weight
|
56 |
+
bias = _child_module.bias
|
57 |
+
_tmp = LoraInjectedLinear(
|
58 |
+
_child_module.in_features,
|
59 |
+
_child_module.out_features,
|
60 |
+
_child_module.bias is not None,
|
61 |
+
r,
|
62 |
+
)
|
63 |
+
_tmp.linear.weight = weight
|
64 |
+
if bias is not None:
|
65 |
+
_tmp.linear.bias = bias
|
66 |
+
|
67 |
+
# switch the module
|
68 |
+
_module._modules[name] = _tmp
|
69 |
+
|
70 |
+
require_grad_params.append(
|
71 |
+
_module._modules[name].lora_up.parameters()
|
72 |
+
)
|
73 |
+
require_grad_params.append(
|
74 |
+
_module._modules[name].lora_down.parameters()
|
75 |
+
)
|
76 |
+
|
77 |
+
if loras != None:
|
78 |
+
_module._modules[name].lora_up.weight = loras.pop(0)
|
79 |
+
_module._modules[name].lora_down.weight = loras.pop(0)
|
80 |
+
|
81 |
+
_module._modules[name].lora_up.weight.requires_grad = True
|
82 |
+
_module._modules[name].lora_down.weight.requires_grad = True
|
83 |
+
names.append(name)
|
84 |
+
return require_grad_params, names
|
85 |
+
|
86 |
+
|
87 |
+
def extract_lora_ups_down(model, target_replace_module=["CrossAttention", "Attention"]):
|
88 |
+
|
89 |
+
loras = []
|
90 |
+
|
91 |
+
for _module in model.modules():
|
92 |
+
if _module.__class__.__name__ in target_replace_module:
|
93 |
+
for _child_module in _module.modules():
|
94 |
+
if _child_module.__class__.__name__ == "LoraInjectedLinear":
|
95 |
+
loras.append((_child_module.lora_up, _child_module.lora_down))
|
96 |
+
if len(loras) == 0:
|
97 |
+
raise ValueError("No lora injected.")
|
98 |
+
return loras
|
99 |
+
|
100 |
+
|
101 |
+
def save_lora_weight(
|
102 |
+
model, path="./lora.pt", target_replace_module=["CrossAttention", "Attention"]
|
103 |
+
):
|
104 |
+
weights = []
|
105 |
+
for _up, _down in extract_lora_ups_down(
|
106 |
+
model, target_replace_module=target_replace_module
|
107 |
+
):
|
108 |
+
weights.append(_up.weight)
|
109 |
+
weights.append(_down.weight)
|
110 |
+
|
111 |
+
torch.save(weights, path)
|
112 |
+
|
113 |
+
|
114 |
+
def save_lora_as_json(model, path="./lora.json"):
|
115 |
+
weights = []
|
116 |
+
for _up, _down in extract_lora_ups_down(model):
|
117 |
+
weights.append(_up.weight.detach().cpu().numpy().tolist())
|
118 |
+
weights.append(_down.weight.detach().cpu().numpy().tolist())
|
119 |
+
|
120 |
+
import json
|
121 |
+
|
122 |
+
with open(path, "w") as f:
|
123 |
+
json.dump(weights, f)
|
124 |
+
|
125 |
+
|
126 |
+
def weight_apply_lora(
|
127 |
+
model, loras, target_replace_module=["CrossAttention", "Attention"], alpha=1.0
|
128 |
+
):
|
129 |
+
|
130 |
+
for _module in model.modules():
|
131 |
+
if _module.__class__.__name__ in target_replace_module:
|
132 |
+
for _child_module in _module.modules():
|
133 |
+
if _child_module.__class__.__name__ == "Linear":
|
134 |
+
|
135 |
+
weight = _child_module.weight
|
136 |
+
|
137 |
+
up_weight = loras.pop(0).detach().to(weight.device)
|
138 |
+
down_weight = loras.pop(0).detach().to(weight.device)
|
139 |
+
|
140 |
+
# W <- W + U * D
|
141 |
+
weight = weight + alpha * (up_weight @ down_weight).type(
|
142 |
+
weight.dtype
|
143 |
+
)
|
144 |
+
_child_module.weight = nn.Parameter(weight)
|
145 |
+
|
146 |
+
|
147 |
+
def monkeypatch_lora(
|
148 |
+
model, loras, target_replace_module=["CrossAttention", "Attention"], r: int = 4
|
149 |
+
):
|
150 |
+
for _module in model.modules():
|
151 |
+
if _module.__class__.__name__ in target_replace_module:
|
152 |
+
for name, _child_module in _module.named_modules():
|
153 |
+
if _child_module.__class__.__name__ == "Linear":
|
154 |
+
|
155 |
+
weight = _child_module.weight
|
156 |
+
bias = _child_module.bias
|
157 |
+
_tmp = LoraInjectedLinear(
|
158 |
+
_child_module.in_features,
|
159 |
+
_child_module.out_features,
|
160 |
+
_child_module.bias is not None,
|
161 |
+
r=r,
|
162 |
+
)
|
163 |
+
_tmp.linear.weight = weight
|
164 |
+
|
165 |
+
if bias is not None:
|
166 |
+
_tmp.linear.bias = bias
|
167 |
+
|
168 |
+
# switch the module
|
169 |
+
_module._modules[name] = _tmp
|
170 |
+
|
171 |
+
up_weight = loras.pop(0)
|
172 |
+
down_weight = loras.pop(0)
|
173 |
+
|
174 |
+
_module._modules[name].lora_up.weight = nn.Parameter(
|
175 |
+
up_weight.type(weight.dtype)
|
176 |
+
)
|
177 |
+
_module._modules[name].lora_down.weight = nn.Parameter(
|
178 |
+
down_weight.type(weight.dtype)
|
179 |
+
)
|
180 |
+
|
181 |
+
_module._modules[name].to(weight.device)
|
182 |
+
|
183 |
+
|
184 |
+
def monkeypatch_replace_lora(
|
185 |
+
model, loras, target_replace_module=["CrossAttention", "Attention"], r: int = 4
|
186 |
+
):
|
187 |
+
for _module in model.modules():
|
188 |
+
if _module.__class__.__name__ in target_replace_module:
|
189 |
+
for name, _child_module in _module.named_modules():
|
190 |
+
if _child_module.__class__.__name__ == "LoraInjectedLinear":
|
191 |
+
|
192 |
+
weight = _child_module.linear.weight
|
193 |
+
bias = _child_module.linear.bias
|
194 |
+
_tmp = LoraInjectedLinear(
|
195 |
+
_child_module.linear.in_features,
|
196 |
+
_child_module.linear.out_features,
|
197 |
+
_child_module.linear.bias is not None,
|
198 |
+
r=r,
|
199 |
+
)
|
200 |
+
_tmp.linear.weight = weight
|
201 |
+
|
202 |
+
if bias is not None:
|
203 |
+
_tmp.linear.bias = bias
|
204 |
+
|
205 |
+
# switch the module
|
206 |
+
_module._modules[name] = _tmp
|
207 |
+
|
208 |
+
up_weight = loras.pop(0)
|
209 |
+
down_weight = loras.pop(0)
|
210 |
+
|
211 |
+
_module._modules[name].lora_up.weight = nn.Parameter(
|
212 |
+
up_weight.type(weight.dtype)
|
213 |
+
)
|
214 |
+
_module._modules[name].lora_down.weight = nn.Parameter(
|
215 |
+
down_weight.type(weight.dtype)
|
216 |
+
)
|
217 |
+
|
218 |
+
_module._modules[name].to(weight.device)
|
219 |
+
|
220 |
+
|
221 |
+
def monkeypatch_add_lora(
|
222 |
+
model,
|
223 |
+
loras,
|
224 |
+
target_replace_module=["CrossAttention", "Attention"],
|
225 |
+
alpha: float = 1.0,
|
226 |
+
beta: float = 1.0,
|
227 |
+
):
|
228 |
+
for _module in model.modules():
|
229 |
+
if _module.__class__.__name__ in target_replace_module:
|
230 |
+
for name, _child_module in _module.named_modules():
|
231 |
+
if _child_module.__class__.__name__ == "LoraInjectedLinear":
|
232 |
+
|
233 |
+
weight = _child_module.linear.weight
|
234 |
+
|
235 |
+
up_weight = loras.pop(0)
|
236 |
+
down_weight = loras.pop(0)
|
237 |
+
|
238 |
+
_module._modules[name].lora_up.weight = nn.Parameter(
|
239 |
+
up_weight.type(weight.dtype).to(weight.device) * alpha
|
240 |
+
+ _module._modules[name].lora_up.weight.to(weight.device) * beta
|
241 |
+
)
|
242 |
+
_module._modules[name].lora_down.weight = nn.Parameter(
|
243 |
+
down_weight.type(weight.dtype).to(weight.device) * alpha
|
244 |
+
+ _module._modules[name].lora_down.weight.to(weight.device)
|
245 |
+
* beta
|
246 |
+
)
|
247 |
+
|
248 |
+
_module._modules[name].to(weight.device)
|
249 |
+
|
250 |
+
|
251 |
+
def tune_lora_scale(model, alpha: float = 1.0):
|
252 |
+
for _module in model.modules():
|
253 |
+
if _module.__class__.__name__ == "LoraInjectedLinear":
|
254 |
+
_module.scale = alpha
|
255 |
+
|
256 |
+
|
257 |
+
def _text_lora_path(path: str) -> str:
|
258 |
+
assert path.endswith(".pt"), "Only .pt files are supported"
|
259 |
+
return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])
|
260 |
+
|
261 |
+
|
262 |
+
def _ti_lora_path(path: str) -> str:
|
263 |
+
assert path.endswith(".pt"), "Only .pt files are supported"
|
264 |
+
return ".".join(path.split(".")[:-1] + ["ti", "pt"])
|
265 |
+
|
266 |
+
|
267 |
+
def load_learned_embed_in_clip(
|
268 |
+
learned_embeds_path, text_encoder, tokenizer, token=None, idempotent=False
|
269 |
+
):
|
270 |
+
loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
|
271 |
+
|
272 |
+
# separate token and the embeds
|
273 |
+
trained_token = list(loaded_learned_embeds.keys())[0]
|
274 |
+
embeds = loaded_learned_embeds[trained_token]
|
275 |
+
|
276 |
+
# cast to dtype of text_encoder
|
277 |
+
dtype = text_encoder.get_input_embeddings().weight.dtype
|
278 |
+
|
279 |
+
# add the token in tokenizer
|
280 |
+
token = token if token is not None else trained_token
|
281 |
+
num_added_tokens = tokenizer.add_tokens(token)
|
282 |
+
i = 1
|
283 |
+
if num_added_tokens == 0 and idempotent:
|
284 |
+
return token
|
285 |
+
|
286 |
+
while num_added_tokens == 0:
|
287 |
+
print(f"The tokenizer already contains the token {token}.")
|
288 |
+
token = f"{token[:-1]}-{i}>"
|
289 |
+
print(f"Attempting to add the token {token}.")
|
290 |
+
num_added_tokens = tokenizer.add_tokens(token)
|
291 |
+
i += 1
|
292 |
+
|
293 |
+
# resize the token embeddings
|
294 |
+
text_encoder.resize_token_embeddings(len(tokenizer))
|
295 |
+
|
296 |
+
# get the id for the token and assign the embeds
|
297 |
+
token_id = tokenizer.convert_tokens_to_ids(token)
|
298 |
+
text_encoder.get_input_embeddings().weight.data[token_id] = embeds
|
299 |
+
return token
|
300 |
+
|
301 |
+
|
302 |
+
def patch_pipe(
|
303 |
+
pipe,
|
304 |
+
unet_path,
|
305 |
+
token,
|
306 |
+
alpha: float = 1.0,
|
307 |
+
r: int = 4,
|
308 |
+
patch_text=False,
|
309 |
+
patch_ti=False,
|
310 |
+
idempotent_token=True,
|
311 |
+
):
|
312 |
+
|
313 |
+
ti_path = _ti_lora_path(unet_path)
|
314 |
+
text_path = _text_lora_path(unet_path)
|
315 |
+
|
316 |
+
unet_has_lora = False
|
317 |
+
text_encoder_has_lora = False
|
318 |
+
|
319 |
+
for _module in pipe.unet.modules():
|
320 |
+
if _module.__class__.__name__ == "LoraInjectedLinear":
|
321 |
+
unet_has_lora = True
|
322 |
+
|
323 |
+
for _module in pipe.text_encoder.modules():
|
324 |
+
if _module.__class__.__name__ == "LoraInjectedLinear":
|
325 |
+
text_encoder_has_lora = True
|
326 |
+
|
327 |
+
if not unet_has_lora:
|
328 |
+
monkeypatch_lora(pipe.unet, torch.load(unet_path), r=r)
|
329 |
+
else:
|
330 |
+
monkeypatch_replace_lora(pipe.unet, torch.load(unet_path), r=r)
|
331 |
+
|
332 |
+
if patch_text:
|
333 |
+
if not text_encoder_has_lora:
|
334 |
+
monkeypatch_lora(
|
335 |
+
pipe.text_encoder,
|
336 |
+
torch.load(text_path),
|
337 |
+
target_replace_module=["CLIPAttention"],
|
338 |
+
r=r,
|
339 |
+
)
|
340 |
+
else:
|
341 |
+
|
342 |
+
monkeypatch_replace_lora(
|
343 |
+
pipe.text_encoder,
|
344 |
+
torch.load(text_path),
|
345 |
+
target_replace_module=["CLIPAttention"],
|
346 |
+
r=r,
|
347 |
+
)
|
348 |
+
if patch_ti:
|
349 |
+
token = load_learned_embed_in_clip(
|
350 |
+
ti_path,
|
351 |
+
pipe.text_encoder,
|
352 |
+
pipe.tokenizer,
|
353 |
+
token,
|
354 |
+
idempotent=idempotent_token,
|
355 |
+
)
|
lora_diffusion/to_ckpt_v2.py
ADDED
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# from https://gist.github.com/jachiam/8a5c0b607e38fcc585168b90c686eb05
|
2 |
+
# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
|
3 |
+
# *Only* converts the UNet, VAE, and Text Encoder.
|
4 |
+
# Does not convert optimizer state or any other thing.
|
5 |
+
# Written by jachiam
|
6 |
+
import argparse
|
7 |
+
import os.path as osp
|
8 |
+
|
9 |
+
import torch
|
10 |
+
|
11 |
+
|
12 |
+
# =================#
|
13 |
+
# UNet Conversion #
|
14 |
+
# =================#
|
15 |
+
|
16 |
+
unet_conversion_map = [
|
17 |
+
# (stable-diffusion, HF Diffusers)
|
18 |
+
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
19 |
+
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
20 |
+
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
21 |
+
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
22 |
+
("input_blocks.0.0.weight", "conv_in.weight"),
|
23 |
+
("input_blocks.0.0.bias", "conv_in.bias"),
|
24 |
+
("out.0.weight", "conv_norm_out.weight"),
|
25 |
+
("out.0.bias", "conv_norm_out.bias"),
|
26 |
+
("out.2.weight", "conv_out.weight"),
|
27 |
+
("out.2.bias", "conv_out.bias"),
|
28 |
+
]
|
29 |
+
|
30 |
+
unet_conversion_map_resnet = [
|
31 |
+
# (stable-diffusion, HF Diffusers)
|
32 |
+
("in_layers.0", "norm1"),
|
33 |
+
("in_layers.2", "conv1"),
|
34 |
+
("out_layers.0", "norm2"),
|
35 |
+
("out_layers.3", "conv2"),
|
36 |
+
("emb_layers.1", "time_emb_proj"),
|
37 |
+
("skip_connection", "conv_shortcut"),
|
38 |
+
]
|
39 |
+
|
40 |
+
unet_conversion_map_layer = []
|
41 |
+
# hardcoded number of downblocks and resnets/attentions...
|
42 |
+
# would need smarter logic for other networks.
|
43 |
+
for i in range(4):
|
44 |
+
# loop over downblocks/upblocks
|
45 |
+
|
46 |
+
for j in range(2):
|
47 |
+
# loop over resnets/attentions for downblocks
|
48 |
+
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
49 |
+
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
50 |
+
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
51 |
+
|
52 |
+
if i < 3:
|
53 |
+
# no attention layers in down_blocks.3
|
54 |
+
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
55 |
+
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
56 |
+
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
57 |
+
|
58 |
+
for j in range(3):
|
59 |
+
# loop over resnets/attentions for upblocks
|
60 |
+
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
61 |
+
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
62 |
+
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
63 |
+
|
64 |
+
if i > 0:
|
65 |
+
# no attention layers in up_blocks.0
|
66 |
+
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
67 |
+
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
68 |
+
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
69 |
+
|
70 |
+
if i < 3:
|
71 |
+
# no downsample in down_blocks.3
|
72 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
73 |
+
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
74 |
+
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
75 |
+
|
76 |
+
# no upsample in up_blocks.3
|
77 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
78 |
+
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
|
79 |
+
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
80 |
+
|
81 |
+
hf_mid_atn_prefix = "mid_block.attentions.0."
|
82 |
+
sd_mid_atn_prefix = "middle_block.1."
|
83 |
+
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
84 |
+
|
85 |
+
for j in range(2):
|
86 |
+
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
87 |
+
sd_mid_res_prefix = f"middle_block.{2*j}."
|
88 |
+
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
89 |
+
|
90 |
+
|
91 |
+
def convert_unet_state_dict(unet_state_dict):
|
92 |
+
# buyer beware: this is a *brittle* function,
|
93 |
+
# and correct output requires that all of these pieces interact in
|
94 |
+
# the exact order in which I have arranged them.
|
95 |
+
mapping = {k: k for k in unet_state_dict.keys()}
|
96 |
+
for sd_name, hf_name in unet_conversion_map:
|
97 |
+
mapping[hf_name] = sd_name
|
98 |
+
for k, v in mapping.items():
|
99 |
+
if "resnets" in k:
|
100 |
+
for sd_part, hf_part in unet_conversion_map_resnet:
|
101 |
+
v = v.replace(hf_part, sd_part)
|
102 |
+
mapping[k] = v
|
103 |
+
for k, v in mapping.items():
|
104 |
+
for sd_part, hf_part in unet_conversion_map_layer:
|
105 |
+
v = v.replace(hf_part, sd_part)
|
106 |
+
mapping[k] = v
|
107 |
+
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
|
108 |
+
return new_state_dict
|
109 |
+
|
110 |
+
|
111 |
+
# ================#
|
112 |
+
# VAE Conversion #
|
113 |
+
# ================#
|
114 |
+
|
115 |
+
vae_conversion_map = [
|
116 |
+
# (stable-diffusion, HF Diffusers)
|
117 |
+
("nin_shortcut", "conv_shortcut"),
|
118 |
+
("norm_out", "conv_norm_out"),
|
119 |
+
("mid.attn_1.", "mid_block.attentions.0."),
|
120 |
+
]
|
121 |
+
|
122 |
+
for i in range(4):
|
123 |
+
# down_blocks have two resnets
|
124 |
+
for j in range(2):
|
125 |
+
hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
|
126 |
+
sd_down_prefix = f"encoder.down.{i}.block.{j}."
|
127 |
+
vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
|
128 |
+
|
129 |
+
if i < 3:
|
130 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
|
131 |
+
sd_downsample_prefix = f"down.{i}.downsample."
|
132 |
+
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
|
133 |
+
|
134 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
135 |
+
sd_upsample_prefix = f"up.{3-i}.upsample."
|
136 |
+
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
|
137 |
+
|
138 |
+
# up_blocks have three resnets
|
139 |
+
# also, up blocks in hf are numbered in reverse from sd
|
140 |
+
for j in range(3):
|
141 |
+
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
|
142 |
+
sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
|
143 |
+
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
|
144 |
+
|
145 |
+
# this part accounts for mid blocks in both the encoder and the decoder
|
146 |
+
for i in range(2):
|
147 |
+
hf_mid_res_prefix = f"mid_block.resnets.{i}."
|
148 |
+
sd_mid_res_prefix = f"mid.block_{i+1}."
|
149 |
+
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
150 |
+
|
151 |
+
|
152 |
+
vae_conversion_map_attn = [
|
153 |
+
# (stable-diffusion, HF Diffusers)
|
154 |
+
("norm.", "group_norm."),
|
155 |
+
("q.", "query."),
|
156 |
+
("k.", "key."),
|
157 |
+
("v.", "value."),
|
158 |
+
("proj_out.", "proj_attn."),
|
159 |
+
]
|
160 |
+
|
161 |
+
|
162 |
+
def reshape_weight_for_sd(w):
|
163 |
+
# convert HF linear weights to SD conv2d weights
|
164 |
+
return w.reshape(*w.shape, 1, 1)
|
165 |
+
|
166 |
+
|
167 |
+
def convert_vae_state_dict(vae_state_dict):
|
168 |
+
mapping = {k: k for k in vae_state_dict.keys()}
|
169 |
+
for k, v in mapping.items():
|
170 |
+
for sd_part, hf_part in vae_conversion_map:
|
171 |
+
v = v.replace(hf_part, sd_part)
|
172 |
+
mapping[k] = v
|
173 |
+
for k, v in mapping.items():
|
174 |
+
if "attentions" in k:
|
175 |
+
for sd_part, hf_part in vae_conversion_map_attn:
|
176 |
+
v = v.replace(hf_part, sd_part)
|
177 |
+
mapping[k] = v
|
178 |
+
new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
|
179 |
+
weights_to_convert = ["q", "k", "v", "proj_out"]
|
180 |
+
for k, v in new_state_dict.items():
|
181 |
+
for weight_name in weights_to_convert:
|
182 |
+
if f"mid.attn_1.{weight_name}.weight" in k:
|
183 |
+
print(f"Reshaping {k} for SD format")
|
184 |
+
new_state_dict[k] = reshape_weight_for_sd(v)
|
185 |
+
return new_state_dict
|
186 |
+
|
187 |
+
|
188 |
+
# =========================#
|
189 |
+
# Text Encoder Conversion #
|
190 |
+
# =========================#
|
191 |
+
# pretty much a no-op
|
192 |
+
|
193 |
+
|
194 |
+
def convert_text_enc_state_dict(text_enc_dict):
|
195 |
+
return text_enc_dict
|
196 |
+
|
197 |
+
|
198 |
+
def convert_to_ckpt(model_path, checkpoint_path, as_half):
|
199 |
+
|
200 |
+
assert model_path is not None, "Must provide a model path!"
|
201 |
+
|
202 |
+
assert checkpoint_path is not None, "Must provide a checkpoint path!"
|
203 |
+
|
204 |
+
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
|
205 |
+
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
|
206 |
+
text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
|
207 |
+
|
208 |
+
# Convert the UNet model
|
209 |
+
unet_state_dict = torch.load(unet_path, map_location="cpu")
|
210 |
+
unet_state_dict = convert_unet_state_dict(unet_state_dict)
|
211 |
+
unet_state_dict = {
|
212 |
+
"model.diffusion_model." + k: v for k, v in unet_state_dict.items()
|
213 |
+
}
|
214 |
+
|
215 |
+
# Convert the VAE model
|
216 |
+
vae_state_dict = torch.load(vae_path, map_location="cpu")
|
217 |
+
vae_state_dict = convert_vae_state_dict(vae_state_dict)
|
218 |
+
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
|
219 |
+
|
220 |
+
# Convert the text encoder model
|
221 |
+
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
|
222 |
+
text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
|
223 |
+
text_enc_dict = {
|
224 |
+
"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()
|
225 |
+
}
|
226 |
+
|
227 |
+
# Put together new checkpoint
|
228 |
+
state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
|
229 |
+
if as_half:
|
230 |
+
state_dict = {k: v.half() for k, v in state_dict.items()}
|
231 |
+
state_dict = {"state_dict": state_dict}
|
232 |
+
torch.save(state_dict, checkpoint_path)
|
lora_disney.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:72f687f810b86bb8cc64d2ece59886e2e96d29e3f57f97340ee147d168b8a5fe
|
3 |
+
size 3397249
|
lora_illust.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7f6acb0bc0cd5f96299be7839f89f58727e2666e58861e55866ea02125c97aba
|
3 |
+
size 3397249
|
lora_playgroundai_wt.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f4b62bd14a24b58dd46a07e348b41b9c97cccb62859ecd28e7b855cea4b845e4
|
3 |
+
size 3398857
|
lora_pop.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:18a1565852a08cfcff63e90670286c9427e3958f57de9b84e3f8b2c9a3a14b6c
|
3 |
+
size 3397249
|
lora_weight.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:33c55ddef387cec18e5990ccae0809c8c842c244892b9502fd7e65be5301266f
|
3 |
+
size 3398857
|
output_example/dummy.txt
ADDED
File without changes
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
transformers
|
2 |
+
git+https://github.com/huggingface/accelerate
|
3 |
+
git+https://github.com/huggingface/diffusers
|
4 |
+
git+https://github.com/cloneofsimo/lora.git
|
5 |
+
scipy
|
6 |
+
ftfy
|
7 |
+
torchvision
|
run_lora_db.sh
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#https://github.com/huggingface/diffusers/tree/main/examples/dreambooth
|
2 |
+
export MODEL_NAME="stabilityai/stable-diffusion-2-1-base"
|
3 |
+
export INSTANCE_DIR="./data_example"
|
4 |
+
export OUTPUT_DIR="./output_example"
|
5 |
+
|
6 |
+
accelerate launch train_lora_dreambooth.py \
|
7 |
+
--pretrained_model_name_or_path=$MODEL_NAME \
|
8 |
+
--instance_data_dir=$INSTANCE_DIR \
|
9 |
+
--output_dir=$OUTPUT_DIR \
|
10 |
+
--instance_prompt="style of sks" \
|
11 |
+
--resolution=512 \
|
12 |
+
--train_batch_size=1 \
|
13 |
+
--gradient_accumulation_steps=1 \
|
14 |
+
--learning_rate=1e-4 \
|
15 |
+
--lr_scheduler="constant" \
|
16 |
+
--lr_warmup_steps=0 \
|
17 |
+
--max_train_steps=30000
|
scripts/make_alpha_gifs.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
scripts/run_inference.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
setup.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import pkg_resources
|
4 |
+
from setuptools import find_packages, setup
|
5 |
+
|
6 |
+
setup(
|
7 |
+
name="lora_diffusion",
|
8 |
+
py_modules=["lora_diffusion"],
|
9 |
+
version="0.0.1",
|
10 |
+
description="Low Rank Adaptation for Diffusion Models. Works with Stable Diffusion out-of-the-box.",
|
11 |
+
author="Simo Ryu",
|
12 |
+
packages=find_packages(),
|
13 |
+
entry_points={
|
14 |
+
"console_scripts": [
|
15 |
+
"lora_add = lora_diffusion.cli_lora_add:main",
|
16 |
+
],
|
17 |
+
},
|
18 |
+
install_requires=[
|
19 |
+
str(r)
|
20 |
+
for r in pkg_resources.parse_requirements(
|
21 |
+
open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
|
22 |
+
)
|
23 |
+
],
|
24 |
+
include_package_data=True,
|
25 |
+
)
|
simba1.jpg
ADDED
![]() |
simba2.jpg
ADDED
![]() |
simba3.jpg
ADDED
![]() |
simba4.jpg
ADDED
![]() |
train_lora_dreambooth.py
ADDED
@@ -0,0 +1,956 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Bootstrapped from:
|
2 |
+
# https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py
|
3 |
+
|
4 |
+
import argparse
|
5 |
+
import hashlib
|
6 |
+
import itertools
|
7 |
+
import math
|
8 |
+
import os
|
9 |
+
import inspect
|
10 |
+
from pathlib import Path
|
11 |
+
from typing import Optional
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.nn.functional as F
|
15 |
+
import torch.utils.checkpoint
|
16 |
+
|
17 |
+
|
18 |
+
from accelerate import Accelerator
|
19 |
+
from accelerate.logging import get_logger
|
20 |
+
from accelerate.utils import set_seed
|
21 |
+
from diffusers import (
|
22 |
+
AutoencoderKL,
|
23 |
+
DDPMScheduler,
|
24 |
+
StableDiffusionPipeline,
|
25 |
+
UNet2DConditionModel,
|
26 |
+
)
|
27 |
+
from diffusers.optimization import get_scheduler
|
28 |
+
from huggingface_hub import HfFolder, Repository, whoami
|
29 |
+
|
30 |
+
from tqdm.auto import tqdm
|
31 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
32 |
+
|
33 |
+
from lora_diffusion import (
|
34 |
+
inject_trainable_lora,
|
35 |
+
save_lora_weight,
|
36 |
+
extract_lora_ups_down,
|
37 |
+
)
|
38 |
+
|
39 |
+
from torch.utils.data import Dataset
|
40 |
+
from PIL import Image
|
41 |
+
from torchvision import transforms
|
42 |
+
|
43 |
+
from pathlib import Path
|
44 |
+
|
45 |
+
import random
|
46 |
+
import re
|
47 |
+
|
48 |
+
|
49 |
+
class DreamBoothDataset(Dataset):
|
50 |
+
"""
|
51 |
+
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
|
52 |
+
It pre-processes the images and the tokenizes prompts.
|
53 |
+
"""
|
54 |
+
|
55 |
+
def __init__(
|
56 |
+
self,
|
57 |
+
instance_data_root,
|
58 |
+
instance_prompt,
|
59 |
+
tokenizer,
|
60 |
+
class_data_root=None,
|
61 |
+
class_prompt=None,
|
62 |
+
size=512,
|
63 |
+
center_crop=False,
|
64 |
+
color_jitter=False,
|
65 |
+
):
|
66 |
+
self.size = size
|
67 |
+
self.center_crop = center_crop
|
68 |
+
self.tokenizer = tokenizer
|
69 |
+
|
70 |
+
self.instance_data_root = Path(instance_data_root)
|
71 |
+
if not self.instance_data_root.exists():
|
72 |
+
raise ValueError("Instance images root doesn't exists.")
|
73 |
+
|
74 |
+
self.instance_images_path = list(Path(instance_data_root).iterdir())
|
75 |
+
self.num_instance_images = len(self.instance_images_path)
|
76 |
+
self.instance_prompt = instance_prompt
|
77 |
+
self._length = self.num_instance_images
|
78 |
+
|
79 |
+
if class_data_root is not None:
|
80 |
+
self.class_data_root = Path(class_data_root)
|
81 |
+
self.class_data_root.mkdir(parents=True, exist_ok=True)
|
82 |
+
self.class_images_path = list(self.class_data_root.iterdir())
|
83 |
+
self.num_class_images = len(self.class_images_path)
|
84 |
+
self._length = max(self.num_class_images, self.num_instance_images)
|
85 |
+
self.class_prompt = class_prompt
|
86 |
+
else:
|
87 |
+
self.class_data_root = None
|
88 |
+
|
89 |
+
self.image_transforms = transforms.Compose(
|
90 |
+
[
|
91 |
+
transforms.Resize(
|
92 |
+
size, interpolation=transforms.InterpolationMode.BILINEAR
|
93 |
+
),
|
94 |
+
transforms.CenterCrop(size)
|
95 |
+
if center_crop
|
96 |
+
else transforms.RandomCrop(size),
|
97 |
+
transforms.ColorJitter(0.2, 0.1)
|
98 |
+
if color_jitter
|
99 |
+
else transforms.Lambda(lambda x: x),
|
100 |
+
transforms.ToTensor(),
|
101 |
+
transforms.Normalize([0.5], [0.5]),
|
102 |
+
]
|
103 |
+
)
|
104 |
+
|
105 |
+
def __len__(self):
|
106 |
+
return self._length
|
107 |
+
|
108 |
+
def __getitem__(self, index):
|
109 |
+
example = {}
|
110 |
+
instance_image = Image.open(
|
111 |
+
self.instance_images_path[index % self.num_instance_images]
|
112 |
+
)
|
113 |
+
if not instance_image.mode == "RGB":
|
114 |
+
instance_image = instance_image.convert("RGB")
|
115 |
+
example["instance_images"] = self.image_transforms(instance_image)
|
116 |
+
example["instance_prompt_ids"] = self.tokenizer(
|
117 |
+
self.instance_prompt,
|
118 |
+
padding="do_not_pad",
|
119 |
+
truncation=True,
|
120 |
+
max_length=self.tokenizer.model_max_length,
|
121 |
+
).input_ids
|
122 |
+
|
123 |
+
if self.class_data_root:
|
124 |
+
class_image = Image.open(
|
125 |
+
self.class_images_path[index % self.num_class_images]
|
126 |
+
)
|
127 |
+
if not class_image.mode == "RGB":
|
128 |
+
class_image = class_image.convert("RGB")
|
129 |
+
example["class_images"] = self.image_transforms(class_image)
|
130 |
+
example["class_prompt_ids"] = self.tokenizer(
|
131 |
+
self.class_prompt,
|
132 |
+
padding="do_not_pad",
|
133 |
+
truncation=True,
|
134 |
+
max_length=self.tokenizer.model_max_length,
|
135 |
+
).input_ids
|
136 |
+
|
137 |
+
return example
|
138 |
+
|
139 |
+
|
140 |
+
class PromptDataset(Dataset):
|
141 |
+
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."
|
142 |
+
|
143 |
+
def __init__(self, prompt, num_samples):
|
144 |
+
self.prompt = prompt
|
145 |
+
self.num_samples = num_samples
|
146 |
+
|
147 |
+
def __len__(self):
|
148 |
+
return self.num_samples
|
149 |
+
|
150 |
+
def __getitem__(self, index):
|
151 |
+
example = {}
|
152 |
+
example["prompt"] = self.prompt
|
153 |
+
example["index"] = index
|
154 |
+
return example
|
155 |
+
|
156 |
+
|
157 |
+
logger = get_logger(__name__)
|
158 |
+
|
159 |
+
|
160 |
+
def parse_args(input_args=None):
|
161 |
+
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
162 |
+
parser.add_argument(
|
163 |
+
"--pretrained_model_name_or_path",
|
164 |
+
type=str,
|
165 |
+
default=None,
|
166 |
+
required=True,
|
167 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
168 |
+
)
|
169 |
+
parser.add_argument(
|
170 |
+
"--pretrained_vae_name_or_path",
|
171 |
+
type=str,
|
172 |
+
default=None,
|
173 |
+
help="Path to pretrained vae or vae identifier from huggingface.co/models.",
|
174 |
+
)
|
175 |
+
parser.add_argument(
|
176 |
+
"--revision",
|
177 |
+
type=str,
|
178 |
+
default=None,
|
179 |
+
required=False,
|
180 |
+
help="Revision of pretrained model identifier from huggingface.co/models.",
|
181 |
+
)
|
182 |
+
parser.add_argument(
|
183 |
+
"--tokenizer_name",
|
184 |
+
type=str,
|
185 |
+
default=None,
|
186 |
+
help="Pretrained tokenizer name or path if not the same as model_name",
|
187 |
+
)
|
188 |
+
parser.add_argument(
|
189 |
+
"--instance_data_dir",
|
190 |
+
type=str,
|
191 |
+
default=None,
|
192 |
+
required=True,
|
193 |
+
help="A folder containing the training data of instance images.",
|
194 |
+
)
|
195 |
+
parser.add_argument(
|
196 |
+
"--class_data_dir",
|
197 |
+
type=str,
|
198 |
+
default=None,
|
199 |
+
required=False,
|
200 |
+
help="A folder containing the training data of class images.",
|
201 |
+
)
|
202 |
+
parser.add_argument(
|
203 |
+
"--instance_prompt",
|
204 |
+
type=str,
|
205 |
+
default=None,
|
206 |
+
required=True,
|
207 |
+
help="The prompt with identifier specifying the instance",
|
208 |
+
)
|
209 |
+
parser.add_argument(
|
210 |
+
"--class_prompt",
|
211 |
+
type=str,
|
212 |
+
default=None,
|
213 |
+
help="The prompt to specify images in the same class as provided instance images.",
|
214 |
+
)
|
215 |
+
parser.add_argument(
|
216 |
+
"--with_prior_preservation",
|
217 |
+
default=False,
|
218 |
+
action="store_true",
|
219 |
+
help="Flag to add prior preservation loss.",
|
220 |
+
)
|
221 |
+
parser.add_argument(
|
222 |
+
"--prior_loss_weight",
|
223 |
+
type=float,
|
224 |
+
default=1.0,
|
225 |
+
help="The weight of prior preservation loss.",
|
226 |
+
)
|
227 |
+
parser.add_argument(
|
228 |
+
"--num_class_images",
|
229 |
+
type=int,
|
230 |
+
default=100,
|
231 |
+
help=(
|
232 |
+
"Minimal class images for prior preservation loss. If not have enough images, additional images will be"
|
233 |
+
" sampled with class_prompt."
|
234 |
+
),
|
235 |
+
)
|
236 |
+
parser.add_argument(
|
237 |
+
"--output_dir",
|
238 |
+
type=str,
|
239 |
+
default="text-inversion-model",
|
240 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
241 |
+
)
|
242 |
+
parser.add_argument(
|
243 |
+
"--seed", type=int, default=None, help="A seed for reproducible training."
|
244 |
+
)
|
245 |
+
parser.add_argument(
|
246 |
+
"--resolution",
|
247 |
+
type=int,
|
248 |
+
default=512,
|
249 |
+
help=(
|
250 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
251 |
+
" resolution"
|
252 |
+
),
|
253 |
+
)
|
254 |
+
parser.add_argument(
|
255 |
+
"--center_crop",
|
256 |
+
action="store_true",
|
257 |
+
help="Whether to center crop images before resizing to resolution",
|
258 |
+
)
|
259 |
+
parser.add_argument(
|
260 |
+
"--color_jitter",
|
261 |
+
action="store_true",
|
262 |
+
help="Whether to apply color jitter to images",
|
263 |
+
)
|
264 |
+
parser.add_argument(
|
265 |
+
"--train_text_encoder",
|
266 |
+
action="store_true",
|
267 |
+
help="Whether to train the text encoder",
|
268 |
+
)
|
269 |
+
parser.add_argument(
|
270 |
+
"--train_batch_size",
|
271 |
+
type=int,
|
272 |
+
default=4,
|
273 |
+
help="Batch size (per device) for the training dataloader.",
|
274 |
+
)
|
275 |
+
parser.add_argument(
|
276 |
+
"--sample_batch_size",
|
277 |
+
type=int,
|
278 |
+
default=4,
|
279 |
+
help="Batch size (per device) for sampling images.",
|
280 |
+
)
|
281 |
+
parser.add_argument("--num_train_epochs", type=int, default=1)
|
282 |
+
parser.add_argument(
|
283 |
+
"--max_train_steps",
|
284 |
+
type=int,
|
285 |
+
default=None,
|
286 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
287 |
+
)
|
288 |
+
parser.add_argument(
|
289 |
+
"--save_steps",
|
290 |
+
type=int,
|
291 |
+
default=500,
|
292 |
+
help="Save checkpoint every X updates steps.",
|
293 |
+
)
|
294 |
+
parser.add_argument(
|
295 |
+
"--gradient_accumulation_steps",
|
296 |
+
type=int,
|
297 |
+
default=1,
|
298 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
299 |
+
)
|
300 |
+
parser.add_argument(
|
301 |
+
"--gradient_checkpointing",
|
302 |
+
action="store_true",
|
303 |
+
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
304 |
+
)
|
305 |
+
parser.add_argument(
|
306 |
+
"--lora_rank",
|
307 |
+
type=int,
|
308 |
+
default=4,
|
309 |
+
help="Rank of LoRA approximation.",
|
310 |
+
)
|
311 |
+
parser.add_argument(
|
312 |
+
"--learning_rate",
|
313 |
+
type=float,
|
314 |
+
default=None,
|
315 |
+
help="Initial learning rate (after the potential warmup period) to use.",
|
316 |
+
)
|
317 |
+
parser.add_argument(
|
318 |
+
"--learning_rate_text",
|
319 |
+
type=float,
|
320 |
+
default=5e-6,
|
321 |
+
help="Initial learning rate for text encoder (after the potential warmup period) to use.",
|
322 |
+
)
|
323 |
+
parser.add_argument(
|
324 |
+
"--scale_lr",
|
325 |
+
action="store_true",
|
326 |
+
default=False,
|
327 |
+
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
328 |
+
)
|
329 |
+
parser.add_argument(
|
330 |
+
"--lr_scheduler",
|
331 |
+
type=str,
|
332 |
+
default="constant",
|
333 |
+
help=(
|
334 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
335 |
+
' "constant", "constant_with_warmup"]'
|
336 |
+
),
|
337 |
+
)
|
338 |
+
parser.add_argument(
|
339 |
+
"--lr_warmup_steps",
|
340 |
+
type=int,
|
341 |
+
default=500,
|
342 |
+
help="Number of steps for the warmup in the lr scheduler.",
|
343 |
+
)
|
344 |
+
parser.add_argument(
|
345 |
+
"--use_8bit_adam",
|
346 |
+
action="store_true",
|
347 |
+
help="Whether or not to use 8-bit Adam from bitsandbytes.",
|
348 |
+
)
|
349 |
+
parser.add_argument(
|
350 |
+
"--adam_beta1",
|
351 |
+
type=float,
|
352 |
+
default=0.9,
|
353 |
+
help="The beta1 parameter for the Adam optimizer.",
|
354 |
+
)
|
355 |
+
parser.add_argument(
|
356 |
+
"--adam_beta2",
|
357 |
+
type=float,
|
358 |
+
default=0.999,
|
359 |
+
help="The beta2 parameter for the Adam optimizer.",
|
360 |
+
)
|
361 |
+
parser.add_argument(
|
362 |
+
"--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use."
|
363 |
+
)
|
364 |
+
parser.add_argument(
|
365 |
+
"--adam_epsilon",
|
366 |
+
type=float,
|
367 |
+
default=1e-08,
|
368 |
+
help="Epsilon value for the Adam optimizer",
|
369 |
+
)
|
370 |
+
parser.add_argument(
|
371 |
+
"--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
|
372 |
+
)
|
373 |
+
parser.add_argument(
|
374 |
+
"--push_to_hub",
|
375 |
+
action="store_true",
|
376 |
+
help="Whether or not to push the model to the Hub.",
|
377 |
+
)
|
378 |
+
parser.add_argument(
|
379 |
+
"--hub_token",
|
380 |
+
type=str,
|
381 |
+
default=None,
|
382 |
+
help="The token to use to push to the Model Hub.",
|
383 |
+
)
|
384 |
+
parser.add_argument(
|
385 |
+
"--logging_dir",
|
386 |
+
type=str,
|
387 |
+
default="logs",
|
388 |
+
help=(
|
389 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
390 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
391 |
+
),
|
392 |
+
)
|
393 |
+
parser.add_argument(
|
394 |
+
"--mixed_precision",
|
395 |
+
type=str,
|
396 |
+
default=None,
|
397 |
+
choices=["no", "fp16", "bf16"],
|
398 |
+
help=(
|
399 |
+
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
400 |
+
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
401 |
+
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
402 |
+
),
|
403 |
+
)
|
404 |
+
parser.add_argument(
|
405 |
+
"--local_rank",
|
406 |
+
type=int,
|
407 |
+
default=-1,
|
408 |
+
help="For distributed training: local_rank",
|
409 |
+
)
|
410 |
+
parser.add_argument(
|
411 |
+
"--resume_unet",
|
412 |
+
type=str,
|
413 |
+
default=None,
|
414 |
+
help=("File path for unet lora to resume training."),
|
415 |
+
)
|
416 |
+
parser.add_argument(
|
417 |
+
"--resume_text_encoder",
|
418 |
+
type=str,
|
419 |
+
default=None,
|
420 |
+
help=("File path for text encoder lora to resume training."),
|
421 |
+
)
|
422 |
+
|
423 |
+
if input_args is not None:
|
424 |
+
args = parser.parse_args(input_args)
|
425 |
+
else:
|
426 |
+
args = parser.parse_args()
|
427 |
+
|
428 |
+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
429 |
+
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
430 |
+
args.local_rank = env_local_rank
|
431 |
+
|
432 |
+
if args.with_prior_preservation:
|
433 |
+
if args.class_data_dir is None:
|
434 |
+
raise ValueError("You must specify a data directory for class images.")
|
435 |
+
if args.class_prompt is None:
|
436 |
+
raise ValueError("You must specify prompt for class images.")
|
437 |
+
else:
|
438 |
+
if args.class_data_dir is not None:
|
439 |
+
logger.warning(
|
440 |
+
"You need not use --class_data_dir without --with_prior_preservation."
|
441 |
+
)
|
442 |
+
if args.class_prompt is not None:
|
443 |
+
logger.warning(
|
444 |
+
"You need not use --class_prompt without --with_prior_preservation."
|
445 |
+
)
|
446 |
+
|
447 |
+
return args
|
448 |
+
|
449 |
+
|
450 |
+
def main(args):
|
451 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
452 |
+
|
453 |
+
accelerator = Accelerator(
|
454 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
455 |
+
mixed_precision=args.mixed_precision,
|
456 |
+
log_with="tensorboard",
|
457 |
+
logging_dir=logging_dir,
|
458 |
+
)
|
459 |
+
|
460 |
+
# Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
|
461 |
+
# This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
|
462 |
+
# TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
|
463 |
+
if (
|
464 |
+
args.train_text_encoder
|
465 |
+
and args.gradient_accumulation_steps > 1
|
466 |
+
and accelerator.num_processes > 1
|
467 |
+
):
|
468 |
+
raise ValueError(
|
469 |
+
"Gradient accumulation is not supported when training the text encoder in distributed training. "
|
470 |
+
"Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
|
471 |
+
)
|
472 |
+
|
473 |
+
if args.seed is not None:
|
474 |
+
set_seed(args.seed)
|
475 |
+
|
476 |
+
if args.with_prior_preservation:
|
477 |
+
class_images_dir = Path(args.class_data_dir)
|
478 |
+
if not class_images_dir.exists():
|
479 |
+
class_images_dir.mkdir(parents=True)
|
480 |
+
cur_class_images = len(list(class_images_dir.iterdir()))
|
481 |
+
|
482 |
+
if cur_class_images < args.num_class_images:
|
483 |
+
torch_dtype = (
|
484 |
+
torch.float16 if accelerator.device.type == "cuda" else torch.float32
|
485 |
+
)
|
486 |
+
pipeline = StableDiffusionPipeline.from_pretrained(
|
487 |
+
args.pretrained_model_name_or_path,
|
488 |
+
torch_dtype=torch_dtype,
|
489 |
+
safety_checker=None,
|
490 |
+
revision=args.revision,
|
491 |
+
)
|
492 |
+
pipeline.set_progress_bar_config(disable=True)
|
493 |
+
|
494 |
+
num_new_images = args.num_class_images - cur_class_images
|
495 |
+
logger.info(f"Number of class images to sample: {num_new_images}.")
|
496 |
+
|
497 |
+
sample_dataset = PromptDataset(args.class_prompt, num_new_images)
|
498 |
+
sample_dataloader = torch.utils.data.DataLoader(
|
499 |
+
sample_dataset, batch_size=args.sample_batch_size
|
500 |
+
)
|
501 |
+
|
502 |
+
sample_dataloader = accelerator.prepare(sample_dataloader)
|
503 |
+
pipeline.to(accelerator.device)
|
504 |
+
|
505 |
+
for example in tqdm(
|
506 |
+
sample_dataloader,
|
507 |
+
desc="Generating class images",
|
508 |
+
disable=not accelerator.is_local_main_process,
|
509 |
+
):
|
510 |
+
images = pipeline(example["prompt"]).images
|
511 |
+
|
512 |
+
for i, image in enumerate(images):
|
513 |
+
hash_image = hashlib.sha1(image.tobytes()).hexdigest()
|
514 |
+
image_filename = (
|
515 |
+
class_images_dir
|
516 |
+
/ f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
|
517 |
+
)
|
518 |
+
image.save(image_filename)
|
519 |
+
|
520 |
+
del pipeline
|
521 |
+
if torch.cuda.is_available():
|
522 |
+
torch.cuda.empty_cache()
|
523 |
+
|
524 |
+
# Handle the repository creation
|
525 |
+
if accelerator.is_main_process:
|
526 |
+
|
527 |
+
if args.output_dir is not None:
|
528 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
529 |
+
|
530 |
+
# Load the tokenizer
|
531 |
+
if args.tokenizer_name:
|
532 |
+
tokenizer = CLIPTokenizer.from_pretrained(
|
533 |
+
args.tokenizer_name,
|
534 |
+
revision=args.revision,
|
535 |
+
)
|
536 |
+
elif args.pretrained_model_name_or_path:
|
537 |
+
tokenizer = CLIPTokenizer.from_pretrained(
|
538 |
+
args.pretrained_model_name_or_path,
|
539 |
+
subfolder="tokenizer",
|
540 |
+
revision=args.revision,
|
541 |
+
)
|
542 |
+
|
543 |
+
# Load models and create wrapper for stable diffusion
|
544 |
+
text_encoder = CLIPTextModel.from_pretrained(
|
545 |
+
args.pretrained_model_name_or_path,
|
546 |
+
subfolder="text_encoder",
|
547 |
+
revision=args.revision,
|
548 |
+
)
|
549 |
+
vae = AutoencoderKL.from_pretrained(
|
550 |
+
args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path,
|
551 |
+
subfolder=None if args.pretrained_vae_name_or_path else "vae",
|
552 |
+
revision=None if args.pretrained_vae_name_or_path else args.revision,
|
553 |
+
)
|
554 |
+
unet = UNet2DConditionModel.from_pretrained(
|
555 |
+
args.pretrained_model_name_or_path,
|
556 |
+
subfolder="unet",
|
557 |
+
revision=args.revision,
|
558 |
+
)
|
559 |
+
unet.requires_grad_(False)
|
560 |
+
unet_lora_params, _ = inject_trainable_lora(
|
561 |
+
unet, r=args.lora_rank, loras=args.resume_unet
|
562 |
+
)
|
563 |
+
|
564 |
+
for _up, _down in extract_lora_ups_down(unet):
|
565 |
+
print("Before training: Unet First Layer lora up", _up.weight.data)
|
566 |
+
print("Before training: Unet First Layer lora down", _down.weight.data)
|
567 |
+
break
|
568 |
+
|
569 |
+
vae.requires_grad_(False)
|
570 |
+
text_encoder.requires_grad_(False)
|
571 |
+
|
572 |
+
if args.train_text_encoder:
|
573 |
+
text_encoder_lora_params, _ = inject_trainable_lora(
|
574 |
+
text_encoder,
|
575 |
+
target_replace_module=["CLIPAttention"],
|
576 |
+
r=args.lora_rank,
|
577 |
+
)
|
578 |
+
for _up, _down in extract_lora_ups_down(
|
579 |
+
text_encoder, target_replace_module=["CLIPAttention"]
|
580 |
+
):
|
581 |
+
print("Before training: text encoder First Layer lora up", _up.weight.data)
|
582 |
+
print(
|
583 |
+
"Before training: text encoder First Layer lora down", _down.weight.data
|
584 |
+
)
|
585 |
+
break
|
586 |
+
|
587 |
+
if args.gradient_checkpointing:
|
588 |
+
unet.enable_gradient_checkpointing()
|
589 |
+
if args.train_text_encoder:
|
590 |
+
text_encoder.gradient_checkpointing_enable()
|
591 |
+
|
592 |
+
if args.scale_lr:
|
593 |
+
args.learning_rate = (
|
594 |
+
args.learning_rate
|
595 |
+
* args.gradient_accumulation_steps
|
596 |
+
* args.train_batch_size
|
597 |
+
* accelerator.num_processes
|
598 |
+
)
|
599 |
+
|
600 |
+
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
|
601 |
+
if args.use_8bit_adam:
|
602 |
+
try:
|
603 |
+
import bitsandbytes as bnb
|
604 |
+
except ImportError:
|
605 |
+
raise ImportError(
|
606 |
+
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
|
607 |
+
)
|
608 |
+
|
609 |
+
optimizer_class = bnb.optim.AdamW8bit
|
610 |
+
else:
|
611 |
+
optimizer_class = torch.optim.AdamW
|
612 |
+
|
613 |
+
text_lr = (
|
614 |
+
args.learning_rate
|
615 |
+
if args.learning_rate_text is None
|
616 |
+
else args.learning_rate_text
|
617 |
+
)
|
618 |
+
|
619 |
+
params_to_optimize = (
|
620 |
+
[
|
621 |
+
{"params": itertools.chain(*unet_lora_params), "lr": args.learning_rate},
|
622 |
+
{
|
623 |
+
"params": itertools.chain(*text_encoder_lora_params),
|
624 |
+
"lr": text_lr,
|
625 |
+
},
|
626 |
+
]
|
627 |
+
if args.train_text_encoder
|
628 |
+
else itertools.chain(*unet_lora_params)
|
629 |
+
)
|
630 |
+
optimizer = optimizer_class(
|
631 |
+
params_to_optimize,
|
632 |
+
lr=args.learning_rate,
|
633 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
634 |
+
weight_decay=args.adam_weight_decay,
|
635 |
+
eps=args.adam_epsilon,
|
636 |
+
)
|
637 |
+
|
638 |
+
noise_scheduler = DDPMScheduler.from_config(
|
639 |
+
args.pretrained_model_name_or_path, subfolder="scheduler"
|
640 |
+
)
|
641 |
+
|
642 |
+
train_dataset = DreamBoothDataset(
|
643 |
+
instance_data_root=args.instance_data_dir,
|
644 |
+
instance_prompt=args.instance_prompt,
|
645 |
+
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
|
646 |
+
class_prompt=args.class_prompt,
|
647 |
+
tokenizer=tokenizer,
|
648 |
+
size=args.resolution,
|
649 |
+
center_crop=args.center_crop,
|
650 |
+
color_jitter=args.color_jitter,
|
651 |
+
)
|
652 |
+
|
653 |
+
def collate_fn(examples):
|
654 |
+
input_ids = [example["instance_prompt_ids"] for example in examples]
|
655 |
+
pixel_values = [example["instance_images"] for example in examples]
|
656 |
+
|
657 |
+
# Concat class and instance examples for prior preservation.
|
658 |
+
# We do this to avoid doing two forward passes.
|
659 |
+
if args.with_prior_preservation:
|
660 |
+
input_ids += [example["class_prompt_ids"] for example in examples]
|
661 |
+
pixel_values += [example["class_images"] for example in examples]
|
662 |
+
|
663 |
+
pixel_values = torch.stack(pixel_values)
|
664 |
+
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
665 |
+
|
666 |
+
input_ids = tokenizer.pad(
|
667 |
+
{"input_ids": input_ids},
|
668 |
+
padding="max_length",
|
669 |
+
max_length=tokenizer.model_max_length,
|
670 |
+
return_tensors="pt",
|
671 |
+
).input_ids
|
672 |
+
|
673 |
+
batch = {
|
674 |
+
"input_ids": input_ids,
|
675 |
+
"pixel_values": pixel_values,
|
676 |
+
}
|
677 |
+
return batch
|
678 |
+
|
679 |
+
train_dataloader = torch.utils.data.DataLoader(
|
680 |
+
train_dataset,
|
681 |
+
batch_size=args.train_batch_size,
|
682 |
+
shuffle=True,
|
683 |
+
collate_fn=collate_fn,
|
684 |
+
num_workers=1,
|
685 |
+
)
|
686 |
+
|
687 |
+
# Scheduler and math around the number of training steps.
|
688 |
+
overrode_max_train_steps = False
|
689 |
+
num_update_steps_per_epoch = math.ceil(
|
690 |
+
len(train_dataloader) / args.gradient_accumulation_steps
|
691 |
+
)
|
692 |
+
if args.max_train_steps is None:
|
693 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
694 |
+
overrode_max_train_steps = True
|
695 |
+
|
696 |
+
lr_scheduler = get_scheduler(
|
697 |
+
args.lr_scheduler,
|
698 |
+
optimizer=optimizer,
|
699 |
+
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
|
700 |
+
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
701 |
+
)
|
702 |
+
|
703 |
+
if args.train_text_encoder:
|
704 |
+
(
|
705 |
+
unet,
|
706 |
+
text_encoder,
|
707 |
+
optimizer,
|
708 |
+
train_dataloader,
|
709 |
+
lr_scheduler,
|
710 |
+
) = accelerator.prepare(
|
711 |
+
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
|
712 |
+
)
|
713 |
+
else:
|
714 |
+
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
715 |
+
unet, optimizer, train_dataloader, lr_scheduler
|
716 |
+
)
|
717 |
+
|
718 |
+
weight_dtype = torch.float32
|
719 |
+
if accelerator.mixed_precision == "fp16":
|
720 |
+
weight_dtype = torch.float16
|
721 |
+
elif accelerator.mixed_precision == "bf16":
|
722 |
+
weight_dtype = torch.bfloat16
|
723 |
+
|
724 |
+
# Move text_encode and vae to gpu.
|
725 |
+
# For mixed precision training we cast the text_encoder and vae weights to half-precision
|
726 |
+
# as these models are only used for inference, keeping weights in full precision is not required.
|
727 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
728 |
+
if not args.train_text_encoder:
|
729 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
730 |
+
|
731 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
732 |
+
num_update_steps_per_epoch = math.ceil(
|
733 |
+
len(train_dataloader) / args.gradient_accumulation_steps
|
734 |
+
)
|
735 |
+
if overrode_max_train_steps:
|
736 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
737 |
+
# Afterwards we recalculate our number of training epochs
|
738 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
739 |
+
|
740 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
741 |
+
# The trackers initializes automatically on the main process.
|
742 |
+
if accelerator.is_main_process:
|
743 |
+
accelerator.init_trackers("dreambooth", config=vars(args))
|
744 |
+
|
745 |
+
# Train!
|
746 |
+
total_batch_size = (
|
747 |
+
args.train_batch_size
|
748 |
+
* accelerator.num_processes
|
749 |
+
* args.gradient_accumulation_steps
|
750 |
+
)
|
751 |
+
|
752 |
+
print("***** Running training *****")
|
753 |
+
print(f" Num examples = {len(train_dataset)}")
|
754 |
+
print(f" Num batches each epoch = {len(train_dataloader)}")
|
755 |
+
print(f" Num Epochs = {args.num_train_epochs}")
|
756 |
+
print(f" Instantaneous batch size per device = {args.train_batch_size}")
|
757 |
+
print(
|
758 |
+
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
|
759 |
+
)
|
760 |
+
print(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
761 |
+
print(f" Total optimization steps = {args.max_train_steps}")
|
762 |
+
# Only show the progress bar once on each machine.
|
763 |
+
progress_bar = tqdm(
|
764 |
+
range(args.max_train_steps), disable=not accelerator.is_local_main_process
|
765 |
+
)
|
766 |
+
progress_bar.set_description("Steps")
|
767 |
+
global_step = 0
|
768 |
+
last_save = 0
|
769 |
+
|
770 |
+
for epoch in range(args.num_train_epochs):
|
771 |
+
unet.train()
|
772 |
+
if args.train_text_encoder:
|
773 |
+
text_encoder.train()
|
774 |
+
|
775 |
+
for step, batch in enumerate(train_dataloader):
|
776 |
+
# Convert images to latent space
|
777 |
+
latents = vae.encode(
|
778 |
+
batch["pixel_values"].to(dtype=weight_dtype)
|
779 |
+
).latent_dist.sample()
|
780 |
+
latents = latents * 0.18215
|
781 |
+
|
782 |
+
# Sample noise that we'll add to the latents
|
783 |
+
noise = torch.randn_like(latents)
|
784 |
+
bsz = latents.shape[0]
|
785 |
+
# Sample a random timestep for each image
|
786 |
+
timesteps = torch.randint(
|
787 |
+
0,
|
788 |
+
noise_scheduler.config.num_train_timesteps,
|
789 |
+
(bsz,),
|
790 |
+
device=latents.device,
|
791 |
+
)
|
792 |
+
timesteps = timesteps.long()
|
793 |
+
|
794 |
+
# Add noise to the latents according to the noise magnitude at each timestep
|
795 |
+
# (this is the forward diffusion process)
|
796 |
+
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
797 |
+
|
798 |
+
# Get the text embedding for conditioning
|
799 |
+
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
|
800 |
+
|
801 |
+
# Predict the noise residual
|
802 |
+
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
803 |
+
|
804 |
+
# Get the target for loss depending on the prediction type
|
805 |
+
if noise_scheduler.config.prediction_type == "epsilon":
|
806 |
+
target = noise
|
807 |
+
elif noise_scheduler.config.prediction_type == "v_prediction":
|
808 |
+
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
809 |
+
else:
|
810 |
+
raise ValueError(
|
811 |
+
f"Unknown prediction type {noise_scheduler.config.prediction_type}"
|
812 |
+
)
|
813 |
+
|
814 |
+
if args.with_prior_preservation:
|
815 |
+
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
|
816 |
+
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
|
817 |
+
target, target_prior = torch.chunk(target, 2, dim=0)
|
818 |
+
|
819 |
+
# Compute instance loss
|
820 |
+
loss = (
|
821 |
+
F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
822 |
+
.mean([1, 2, 3])
|
823 |
+
.mean()
|
824 |
+
)
|
825 |
+
|
826 |
+
# Compute prior loss
|
827 |
+
prior_loss = F.mse_loss(
|
828 |
+
model_pred_prior.float(), target_prior.float(), reduction="mean"
|
829 |
+
)
|
830 |
+
|
831 |
+
# Add the prior loss to the instance loss.
|
832 |
+
loss = loss + args.prior_loss_weight * prior_loss
|
833 |
+
else:
|
834 |
+
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
835 |
+
|
836 |
+
accelerator.backward(loss)
|
837 |
+
if accelerator.sync_gradients:
|
838 |
+
params_to_clip = (
|
839 |
+
itertools.chain(unet.parameters(), text_encoder.parameters())
|
840 |
+
if args.train_text_encoder
|
841 |
+
else unet.parameters()
|
842 |
+
)
|
843 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
844 |
+
optimizer.step()
|
845 |
+
lr_scheduler.step()
|
846 |
+
progress_bar.update(1)
|
847 |
+
optimizer.zero_grad()
|
848 |
+
|
849 |
+
global_step += 1
|
850 |
+
|
851 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
852 |
+
if accelerator.sync_gradients:
|
853 |
+
if args.save_steps and global_step - last_save >= args.save_steps:
|
854 |
+
if accelerator.is_main_process:
|
855 |
+
# newer versions of accelerate allow the 'keep_fp32_wrapper' arg. without passing
|
856 |
+
# it, the models will be unwrapped, and when they are then used for further training,
|
857 |
+
# we will crash. pass this, but only to newer versions of accelerate. fixes
|
858 |
+
# https://github.com/huggingface/diffusers/issues/1566
|
859 |
+
accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set(
|
860 |
+
inspect.signature(
|
861 |
+
accelerator.unwrap_model
|
862 |
+
).parameters.keys()
|
863 |
+
)
|
864 |
+
extra_args = (
|
865 |
+
{"keep_fp32_wrapper": True}
|
866 |
+
if accepts_keep_fp32_wrapper
|
867 |
+
else {}
|
868 |
+
)
|
869 |
+
pipeline = StableDiffusionPipeline.from_pretrained(
|
870 |
+
args.pretrained_model_name_or_path,
|
871 |
+
unet=accelerator.unwrap_model(unet, **extra_args),
|
872 |
+
text_encoder=accelerator.unwrap_model(
|
873 |
+
text_encoder, **extra_args
|
874 |
+
),
|
875 |
+
revision=args.revision,
|
876 |
+
)
|
877 |
+
|
878 |
+
filename_unet = (
|
879 |
+
f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.pt"
|
880 |
+
)
|
881 |
+
filename_text_encoder = f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.text_encoder.pt"
|
882 |
+
print(f"save weights {filename_unet}, {filename_text_encoder}")
|
883 |
+
save_lora_weight(pipeline.unet, filename_unet)
|
884 |
+
if args.train_text_encoder:
|
885 |
+
save_lora_weight(
|
886 |
+
pipeline.text_encoder,
|
887 |
+
filename_text_encoder,
|
888 |
+
target_replace_module=["CLIPAttention"],
|
889 |
+
)
|
890 |
+
|
891 |
+
for _up, _down in extract_lora_ups_down(pipeline.unet):
|
892 |
+
print(
|
893 |
+
"First Unet Layer's Up Weight is now : ",
|
894 |
+
_up.weight.data,
|
895 |
+
)
|
896 |
+
print(
|
897 |
+
"First Unet Layer's Down Weight is now : ",
|
898 |
+
_down.weight.data,
|
899 |
+
)
|
900 |
+
break
|
901 |
+
if args.train_text_encoder:
|
902 |
+
for _up, _down in extract_lora_ups_down(
|
903 |
+
pipeline.text_encoder,
|
904 |
+
target_replace_module=["CLIPAttention"],
|
905 |
+
):
|
906 |
+
print(
|
907 |
+
"First Text Encoder Layer's Up Weight is now : ",
|
908 |
+
_up.weight.data,
|
909 |
+
)
|
910 |
+
print(
|
911 |
+
"First Text Encoder Layer's Down Weight is now : ",
|
912 |
+
_down.weight.data,
|
913 |
+
)
|
914 |
+
break
|
915 |
+
|
916 |
+
last_save = global_step
|
917 |
+
|
918 |
+
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
919 |
+
progress_bar.set_postfix(**logs)
|
920 |
+
accelerator.log(logs, step=global_step)
|
921 |
+
|
922 |
+
if global_step >= args.max_train_steps:
|
923 |
+
break
|
924 |
+
|
925 |
+
accelerator.wait_for_everyone()
|
926 |
+
|
927 |
+
# Create the pipeline using using the trained modules and save it.
|
928 |
+
if accelerator.is_main_process:
|
929 |
+
pipeline = StableDiffusionPipeline.from_pretrained(
|
930 |
+
args.pretrained_model_name_or_path,
|
931 |
+
unet=accelerator.unwrap_model(unet),
|
932 |
+
text_encoder=accelerator.unwrap_model(text_encoder),
|
933 |
+
revision=args.revision,
|
934 |
+
)
|
935 |
+
|
936 |
+
print("\n\nLora TRAINING DONE!\n\n")
|
937 |
+
|
938 |
+
save_lora_weight(pipeline.unet, args.output_dir + "/lora_weight.pt")
|
939 |
+
if args.train_text_encoder:
|
940 |
+
save_lora_weight(
|
941 |
+
pipeline.text_encoder,
|
942 |
+
args.output_dir + "/lora_weight.text_encoder.pt",
|
943 |
+
target_replace_module=["CLIPAttention"],
|
944 |
+
)
|
945 |
+
|
946 |
+
if args.push_to_hub:
|
947 |
+
repo.push_to_hub(
|
948 |
+
commit_message="End of training", blocking=False, auto_lfs_prune=True
|
949 |
+
)
|
950 |
+
|
951 |
+
accelerator.end_training()
|
952 |
+
|
953 |
+
|
954 |
+
if __name__ == "__main__":
|
955 |
+
args = parse_args()
|
956 |
+
main(args)
|
train_lora_dreambooth1.py
ADDED
@@ -0,0 +1,964 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Bootstrapped from:
|
2 |
+
# https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py
|
3 |
+
|
4 |
+
import argparse
|
5 |
+
import hashlib
|
6 |
+
import itertools
|
7 |
+
import math
|
8 |
+
import os
|
9 |
+
from pathlib import Path
|
10 |
+
from typing import Optional
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn.functional as F
|
14 |
+
import torch.utils.checkpoint
|
15 |
+
|
16 |
+
|
17 |
+
from accelerate import Accelerator
|
18 |
+
from accelerate.logging import get_logger
|
19 |
+
from accelerate.utils import set_seed
|
20 |
+
from diffusers import (
|
21 |
+
AutoencoderKL,
|
22 |
+
DDPMScheduler,
|
23 |
+
StableDiffusionPipeline,
|
24 |
+
UNet2DConditionModel,
|
25 |
+
)
|
26 |
+
from diffusers.optimization import get_scheduler
|
27 |
+
from huggingface_hub import HfFolder, Repository, whoami
|
28 |
+
|
29 |
+
from tqdm.auto import tqdm
|
30 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
31 |
+
|
32 |
+
from lora_diffusion import (
|
33 |
+
inject_trainable_lora,
|
34 |
+
save_lora_weight,
|
35 |
+
extract_lora_ups_down,
|
36 |
+
)
|
37 |
+
|
38 |
+
from torch.utils.data import Dataset
|
39 |
+
from PIL import Image
|
40 |
+
from torchvision import transforms
|
41 |
+
|
42 |
+
from pathlib import Path
|
43 |
+
|
44 |
+
import random
|
45 |
+
import re
|
46 |
+
|
47 |
+
|
48 |
+
class DreamBoothDataset(Dataset):
|
49 |
+
"""
|
50 |
+
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
|
51 |
+
It pre-processes the images and the tokenizes prompts.
|
52 |
+
"""
|
53 |
+
|
54 |
+
def __init__(
|
55 |
+
self,
|
56 |
+
instance_data_root,
|
57 |
+
instance_prompt,
|
58 |
+
tokenizer,
|
59 |
+
class_data_root=None,
|
60 |
+
class_prompt=None,
|
61 |
+
size=512,
|
62 |
+
center_crop=False,
|
63 |
+
):
|
64 |
+
self.size = size
|
65 |
+
self.center_crop = center_crop
|
66 |
+
self.tokenizer = tokenizer
|
67 |
+
|
68 |
+
self.instance_data_root = Path(instance_data_root)
|
69 |
+
if not self.instance_data_root.exists():
|
70 |
+
raise ValueError("Instance images root doesn't exists.")
|
71 |
+
|
72 |
+
self.instance_images_path = list(Path(instance_data_root).iterdir())
|
73 |
+
self.num_instance_images = len(self.instance_images_path)
|
74 |
+
self.instance_prompt = instance_prompt
|
75 |
+
self._length = self.num_instance_images
|
76 |
+
|
77 |
+
if class_data_root is not None:
|
78 |
+
self.class_data_root = Path(class_data_root)
|
79 |
+
self.class_data_root.mkdir(parents=True, exist_ok=True)
|
80 |
+
self.class_images_path = list(self.class_data_root.iterdir())
|
81 |
+
self.num_class_images = len(self.class_images_path)
|
82 |
+
self._length = max(self.num_class_images, self.num_instance_images)
|
83 |
+
self.class_prompt = class_prompt
|
84 |
+
else:
|
85 |
+
self.class_data_root = None
|
86 |
+
|
87 |
+
self.image_transforms = transforms.Compose(
|
88 |
+
[
|
89 |
+
transforms.Resize(
|
90 |
+
size, interpolation=transforms.InterpolationMode.BILINEAR
|
91 |
+
),
|
92 |
+
transforms.CenterCrop(size)
|
93 |
+
if center_crop
|
94 |
+
else transforms.RandomCrop(size),
|
95 |
+
transforms.ToTensor(),
|
96 |
+
transforms.Normalize([0.5], [0.5]),
|
97 |
+
]
|
98 |
+
)
|
99 |
+
|
100 |
+
def __len__(self):
|
101 |
+
return self._length
|
102 |
+
|
103 |
+
def __getitem__(self, index):
|
104 |
+
example = {}
|
105 |
+
instance_image = Image.open(
|
106 |
+
self.instance_images_path[index % self.num_instance_images]
|
107 |
+
)
|
108 |
+
if not instance_image.mode == "RGB":
|
109 |
+
instance_image = instance_image.convert("RGB")
|
110 |
+
example["instance_images"] = self.image_transforms(instance_image)
|
111 |
+
example["instance_prompt_ids"] = self.tokenizer(
|
112 |
+
self.instance_prompt,
|
113 |
+
padding="do_not_pad",
|
114 |
+
truncation=True,
|
115 |
+
max_length=self.tokenizer.model_max_length,
|
116 |
+
).input_ids
|
117 |
+
|
118 |
+
if self.class_data_root:
|
119 |
+
class_image = Image.open(
|
120 |
+
self.class_images_path[index % self.num_class_images]
|
121 |
+
)
|
122 |
+
if not class_image.mode == "RGB":
|
123 |
+
class_image = class_image.convert("RGB")
|
124 |
+
example["class_images"] = self.image_transforms(class_image)
|
125 |
+
example["class_prompt_ids"] = self.tokenizer(
|
126 |
+
self.class_prompt,
|
127 |
+
padding="do_not_pad",
|
128 |
+
truncation=True,
|
129 |
+
max_length=self.tokenizer.model_max_length,
|
130 |
+
).input_ids
|
131 |
+
|
132 |
+
return example
|
133 |
+
|
134 |
+
|
135 |
+
class DreamBoothLabled(Dataset):
|
136 |
+
"""
|
137 |
+
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
|
138 |
+
It pre-processes the images and the tokenizes prompts.
|
139 |
+
"""
|
140 |
+
|
141 |
+
def __init__(
|
142 |
+
self,
|
143 |
+
instance_data_root,
|
144 |
+
instance_prompt,
|
145 |
+
tokenizer,
|
146 |
+
class_data_root=None,
|
147 |
+
class_prompt=None,
|
148 |
+
size=512,
|
149 |
+
center_crop=False,
|
150 |
+
):
|
151 |
+
self.size = size
|
152 |
+
self.center_crop = center_crop
|
153 |
+
self.tokenizer = tokenizer
|
154 |
+
|
155 |
+
self.instance_data_root = Path(instance_data_root)
|
156 |
+
if not self.instance_data_root.exists():
|
157 |
+
raise ValueError("Instance images root doesn't exists.")
|
158 |
+
|
159 |
+
self.instance_images_path = list(Path(instance_data_root).iterdir())
|
160 |
+
self.num_instance_images = len(self.instance_images_path)
|
161 |
+
self.instance_prompt = instance_prompt
|
162 |
+
self._length = self.num_instance_images
|
163 |
+
|
164 |
+
if class_data_root is not None:
|
165 |
+
self.class_data_root = Path(class_data_root)
|
166 |
+
self.class_data_root.mkdir(parents=True, exist_ok=True)
|
167 |
+
self.class_images_path = list(self.class_data_root.iterdir())
|
168 |
+
self.num_class_images = len(self.class_images_path)
|
169 |
+
self._length = max(self.num_class_images, self.num_instance_images)
|
170 |
+
self.class_prompt = class_prompt
|
171 |
+
else:
|
172 |
+
self.class_data_root = None
|
173 |
+
|
174 |
+
self.image_transforms = transforms.Compose(
|
175 |
+
[
|
176 |
+
transforms.Resize(
|
177 |
+
size, interpolation=transforms.InterpolationMode.BILINEAR
|
178 |
+
),
|
179 |
+
transforms.CenterCrop(size)
|
180 |
+
if center_crop
|
181 |
+
else transforms.RandomCrop(size),
|
182 |
+
transforms.ToTensor(),
|
183 |
+
transforms.Normalize([0.5], [0.5]),
|
184 |
+
]
|
185 |
+
)
|
186 |
+
|
187 |
+
def __len__(self):
|
188 |
+
return self._length
|
189 |
+
|
190 |
+
def __getitem__(self, index):
|
191 |
+
example = {}
|
192 |
+
instance_image = Image.open(
|
193 |
+
self.instance_images_path[index % self.num_instance_images]
|
194 |
+
)
|
195 |
+
|
196 |
+
instance_prompt = (
|
197 |
+
str(self.instance_images_path[index % self.num_instance_images])
|
198 |
+
.split("/")[-1]
|
199 |
+
.split(".")[0]
|
200 |
+
.replace("-", " ")
|
201 |
+
)
|
202 |
+
# remove numbers in prompt
|
203 |
+
instance_prompt = re.sub(r"\d+", "", instance_prompt)
|
204 |
+
# print(instance_prompt)
|
205 |
+
|
206 |
+
_svg = random.choice(["svg", "flat color", "vector illustration", "sks"])
|
207 |
+
instance_prompt = f"{instance_prompt}, style of {_svg}"
|
208 |
+
|
209 |
+
if not instance_image.mode == "RGB":
|
210 |
+
instance_image = instance_image.convert("RGB")
|
211 |
+
example["instance_images"] = self.image_transforms(instance_image)
|
212 |
+
example["instance_prompt_ids"] = self.tokenizer(
|
213 |
+
instance_prompt,
|
214 |
+
padding="do_not_pad",
|
215 |
+
truncation=True,
|
216 |
+
max_length=self.tokenizer.model_max_length,
|
217 |
+
).input_ids
|
218 |
+
|
219 |
+
if self.class_data_root:
|
220 |
+
class_image = Image.open(
|
221 |
+
self.class_images_path[index % self.num_class_images]
|
222 |
+
)
|
223 |
+
if not class_image.mode == "RGB":
|
224 |
+
class_image = class_image.convert("RGB")
|
225 |
+
example["class_images"] = self.image_transforms(class_image)
|
226 |
+
example["class_prompt_ids"] = self.tokenizer(
|
227 |
+
self.class_prompt,
|
228 |
+
padding="do_not_pad",
|
229 |
+
truncation=True,
|
230 |
+
max_length=self.tokenizer.model_max_length,
|
231 |
+
).input_ids
|
232 |
+
|
233 |
+
return example
|
234 |
+
|
235 |
+
|
236 |
+
class PromptDataset(Dataset):
|
237 |
+
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."
|
238 |
+
|
239 |
+
def __init__(self, prompt, num_samples):
|
240 |
+
self.prompt = prompt
|
241 |
+
self.num_samples = num_samples
|
242 |
+
|
243 |
+
def __len__(self):
|
244 |
+
return self.num_samples
|
245 |
+
|
246 |
+
def __getitem__(self, index):
|
247 |
+
example = {}
|
248 |
+
example["prompt"] = self.prompt
|
249 |
+
example["index"] = index
|
250 |
+
return example
|
251 |
+
|
252 |
+
|
253 |
+
logger = get_logger(__name__)
|
254 |
+
|
255 |
+
|
256 |
+
def parse_args(input_args=None):
|
257 |
+
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
258 |
+
parser.add_argument(
|
259 |
+
"--pretrained_model_name_or_path",
|
260 |
+
type=str,
|
261 |
+
default=None,
|
262 |
+
required=True,
|
263 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
264 |
+
)
|
265 |
+
parser.add_argument(
|
266 |
+
"--revision",
|
267 |
+
type=str,
|
268 |
+
default=None,
|
269 |
+
required=False,
|
270 |
+
help="Revision of pretrained model identifier from huggingface.co/models.",
|
271 |
+
)
|
272 |
+
parser.add_argument(
|
273 |
+
"--tokenizer_name",
|
274 |
+
type=str,
|
275 |
+
default=None,
|
276 |
+
help="Pretrained tokenizer name or path if not the same as model_name",
|
277 |
+
)
|
278 |
+
parser.add_argument(
|
279 |
+
"--instance_data_dir",
|
280 |
+
type=str,
|
281 |
+
default=None,
|
282 |
+
required=True,
|
283 |
+
help="A folder containing the training data of instance images.",
|
284 |
+
)
|
285 |
+
parser.add_argument(
|
286 |
+
"--class_data_dir",
|
287 |
+
type=str,
|
288 |
+
default=None,
|
289 |
+
required=False,
|
290 |
+
help="A folder containing the training data of class images.",
|
291 |
+
)
|
292 |
+
parser.add_argument(
|
293 |
+
"--instance_prompt",
|
294 |
+
type=str,
|
295 |
+
default=None,
|
296 |
+
required=True,
|
297 |
+
help="The prompt with identifier specifying the instance",
|
298 |
+
)
|
299 |
+
parser.add_argument(
|
300 |
+
"--class_prompt",
|
301 |
+
type=str,
|
302 |
+
default=None,
|
303 |
+
help="The prompt to specify images in the same class as provided instance images.",
|
304 |
+
)
|
305 |
+
parser.add_argument(
|
306 |
+
"--with_prior_preservation",
|
307 |
+
default=False,
|
308 |
+
action="store_true",
|
309 |
+
help="Flag to add prior preservation loss.",
|
310 |
+
)
|
311 |
+
parser.add_argument(
|
312 |
+
"--prior_loss_weight",
|
313 |
+
type=float,
|
314 |
+
default=1.0,
|
315 |
+
help="The weight of prior preservation loss.",
|
316 |
+
)
|
317 |
+
parser.add_argument(
|
318 |
+
"--num_class_images",
|
319 |
+
type=int,
|
320 |
+
default=100,
|
321 |
+
help=(
|
322 |
+
"Minimal class images for prior preservation loss. If not have enough images, additional images will be"
|
323 |
+
" sampled with class_prompt."
|
324 |
+
),
|
325 |
+
)
|
326 |
+
parser.add_argument(
|
327 |
+
"--output_dir",
|
328 |
+
type=str,
|
329 |
+
default="text-inversion-model",
|
330 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
331 |
+
)
|
332 |
+
parser.add_argument(
|
333 |
+
"--seed", type=int, default=None, help="A seed for reproducible training."
|
334 |
+
)
|
335 |
+
parser.add_argument(
|
336 |
+
"--resolution",
|
337 |
+
type=int,
|
338 |
+
default=512,
|
339 |
+
help=(
|
340 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
341 |
+
" resolution"
|
342 |
+
),
|
343 |
+
)
|
344 |
+
parser.add_argument(
|
345 |
+
"--center_crop",
|
346 |
+
action="store_true",
|
347 |
+
help="Whether to center crop images before resizing to resolution",
|
348 |
+
)
|
349 |
+
parser.add_argument(
|
350 |
+
"--train_text_encoder",
|
351 |
+
action="store_true",
|
352 |
+
help="Whether to train the text encoder",
|
353 |
+
)
|
354 |
+
parser.add_argument(
|
355 |
+
"--train_batch_size",
|
356 |
+
type=int,
|
357 |
+
default=4,
|
358 |
+
help="Batch size (per device) for the training dataloader.",
|
359 |
+
)
|
360 |
+
parser.add_argument(
|
361 |
+
"--sample_batch_size",
|
362 |
+
type=int,
|
363 |
+
default=4,
|
364 |
+
help="Batch size (per device) for sampling images.",
|
365 |
+
)
|
366 |
+
parser.add_argument("--num_train_epochs", type=int, default=1)
|
367 |
+
parser.add_argument(
|
368 |
+
"--max_train_steps",
|
369 |
+
type=int,
|
370 |
+
default=None,
|
371 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
372 |
+
)
|
373 |
+
parser.add_argument(
|
374 |
+
"--save_steps",
|
375 |
+
type=int,
|
376 |
+
default=500,
|
377 |
+
help="Save checkpoint every X updates steps.",
|
378 |
+
)
|
379 |
+
parser.add_argument(
|
380 |
+
"--gradient_accumulation_steps",
|
381 |
+
type=int,
|
382 |
+
default=1,
|
383 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
384 |
+
)
|
385 |
+
parser.add_argument(
|
386 |
+
"--gradient_checkpointing",
|
387 |
+
action="store_true",
|
388 |
+
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
389 |
+
)
|
390 |
+
parser.add_argument(
|
391 |
+
"--learning_rate",
|
392 |
+
type=float,
|
393 |
+
default=5e-6,
|
394 |
+
help="Initial learning rate (after the potential warmup period) to use.",
|
395 |
+
)
|
396 |
+
parser.add_argument(
|
397 |
+
"--scale_lr",
|
398 |
+
action="store_true",
|
399 |
+
default=False,
|
400 |
+
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
401 |
+
)
|
402 |
+
parser.add_argument(
|
403 |
+
"--lr_scheduler",
|
404 |
+
type=str,
|
405 |
+
default="constant",
|
406 |
+
help=(
|
407 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
408 |
+
' "constant", "constant_with_warmup"]'
|
409 |
+
),
|
410 |
+
)
|
411 |
+
parser.add_argument(
|
412 |
+
"--lr_warmup_steps",
|
413 |
+
type=int,
|
414 |
+
default=500,
|
415 |
+
help="Number of steps for the warmup in the lr scheduler.",
|
416 |
+
)
|
417 |
+
parser.add_argument(
|
418 |
+
"--use_8bit_adam",
|
419 |
+
action="store_true",
|
420 |
+
help="Whether or not to use 8-bit Adam from bitsandbytes.",
|
421 |
+
)
|
422 |
+
parser.add_argument(
|
423 |
+
"--adam_beta1",
|
424 |
+
type=float,
|
425 |
+
default=0.9,
|
426 |
+
help="The beta1 parameter for the Adam optimizer.",
|
427 |
+
)
|
428 |
+
parser.add_argument(
|
429 |
+
"--adam_beta2",
|
430 |
+
type=float,
|
431 |
+
default=0.999,
|
432 |
+
help="The beta2 parameter for the Adam optimizer.",
|
433 |
+
)
|
434 |
+
parser.add_argument(
|
435 |
+
"--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use."
|
436 |
+
)
|
437 |
+
parser.add_argument(
|
438 |
+
"--adam_epsilon",
|
439 |
+
type=float,
|
440 |
+
default=1e-08,
|
441 |
+
help="Epsilon value for the Adam optimizer",
|
442 |
+
)
|
443 |
+
parser.add_argument(
|
444 |
+
"--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
|
445 |
+
)
|
446 |
+
parser.add_argument(
|
447 |
+
"--push_to_hub",
|
448 |
+
action="store_true",
|
449 |
+
help="Whether or not to push the model to the Hub.",
|
450 |
+
)
|
451 |
+
parser.add_argument(
|
452 |
+
"--hub_token",
|
453 |
+
type=str,
|
454 |
+
default=None,
|
455 |
+
help="The token to use to push to the Model Hub.",
|
456 |
+
)
|
457 |
+
parser.add_argument(
|
458 |
+
"--hub_model_id",
|
459 |
+
type=str,
|
460 |
+
default=None,
|
461 |
+
help="The name of the repository to keep in sync with the local `output_dir`.",
|
462 |
+
)
|
463 |
+
parser.add_argument(
|
464 |
+
"--logging_dir",
|
465 |
+
type=str,
|
466 |
+
default="logs",
|
467 |
+
help=(
|
468 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
469 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
470 |
+
),
|
471 |
+
)
|
472 |
+
parser.add_argument(
|
473 |
+
"--mixed_precision",
|
474 |
+
type=str,
|
475 |
+
default=None,
|
476 |
+
choices=["no", "fp16", "bf16"],
|
477 |
+
help=(
|
478 |
+
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
479 |
+
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
480 |
+
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
481 |
+
),
|
482 |
+
)
|
483 |
+
parser.add_argument(
|
484 |
+
"--local_rank",
|
485 |
+
type=int,
|
486 |
+
default=-1,
|
487 |
+
help="For distributed training: local_rank",
|
488 |
+
)
|
489 |
+
|
490 |
+
if input_args is not None:
|
491 |
+
args = parser.parse_args(input_args)
|
492 |
+
else:
|
493 |
+
args = parser.parse_args()
|
494 |
+
|
495 |
+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
496 |
+
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
497 |
+
args.local_rank = env_local_rank
|
498 |
+
|
499 |
+
if args.with_prior_preservation:
|
500 |
+
if args.class_data_dir is None:
|
501 |
+
raise ValueError("You must specify a data directory for class images.")
|
502 |
+
if args.class_prompt is None:
|
503 |
+
raise ValueError("You must specify prompt for class images.")
|
504 |
+
else:
|
505 |
+
if args.class_data_dir is not None:
|
506 |
+
logger.warning(
|
507 |
+
"You need not use --class_data_dir without --with_prior_preservation."
|
508 |
+
)
|
509 |
+
if args.class_prompt is not None:
|
510 |
+
logger.warning(
|
511 |
+
"You need not use --class_prompt without --with_prior_preservation."
|
512 |
+
)
|
513 |
+
|
514 |
+
return args
|
515 |
+
|
516 |
+
|
517 |
+
def get_full_repo_name(
|
518 |
+
model_id: str, organization: Optional[str] = None, token: Optional[str] = None
|
519 |
+
):
|
520 |
+
if token is None:
|
521 |
+
token = HfFolder.get_token()
|
522 |
+
if organization is None:
|
523 |
+
username = whoami(token)["name"]
|
524 |
+
return f"{username}/{model_id}"
|
525 |
+
else:
|
526 |
+
return f"{organization}/{model_id}"
|
527 |
+
|
528 |
+
|
529 |
+
def main(args):
|
530 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
531 |
+
|
532 |
+
accelerator = Accelerator(
|
533 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
534 |
+
mixed_precision=args.mixed_precision,
|
535 |
+
log_with="tensorboard",
|
536 |
+
logging_dir=logging_dir,
|
537 |
+
)
|
538 |
+
|
539 |
+
# Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
|
540 |
+
# This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
|
541 |
+
# TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
|
542 |
+
if (
|
543 |
+
args.train_text_encoder
|
544 |
+
and args.gradient_accumulation_steps > 1
|
545 |
+
and accelerator.num_processes > 1
|
546 |
+
):
|
547 |
+
raise ValueError(
|
548 |
+
"Gradient accumulation is not supported when training the text encoder in distributed training. "
|
549 |
+
"Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
|
550 |
+
)
|
551 |
+
|
552 |
+
if args.seed is not None:
|
553 |
+
set_seed(args.seed)
|
554 |
+
|
555 |
+
if args.with_prior_preservation:
|
556 |
+
class_images_dir = Path(args.class_data_dir)
|
557 |
+
if not class_images_dir.exists():
|
558 |
+
class_images_dir.mkdir(parents=True)
|
559 |
+
cur_class_images = len(list(class_images_dir.iterdir()))
|
560 |
+
|
561 |
+
if cur_class_images < args.num_class_images:
|
562 |
+
torch_dtype = (
|
563 |
+
torch.float16 if accelerator.device.type == "cuda" else torch.float32
|
564 |
+
)
|
565 |
+
pipeline = StableDiffusionPipeline.from_pretrained(
|
566 |
+
args.pretrained_model_name_or_path,
|
567 |
+
torch_dtype=torch_dtype,
|
568 |
+
safety_checker=None,
|
569 |
+
revision=args.revision,
|
570 |
+
)
|
571 |
+
pipeline.set_progress_bar_config(disable=True)
|
572 |
+
|
573 |
+
num_new_images = args.num_class_images - cur_class_images
|
574 |
+
logger.info(f"Number of class images to sample: {num_new_images}.")
|
575 |
+
|
576 |
+
sample_dataset = PromptDataset(args.class_prompt, num_new_images)
|
577 |
+
sample_dataloader = torch.utils.data.DataLoader(
|
578 |
+
sample_dataset, batch_size=args.sample_batch_size
|
579 |
+
)
|
580 |
+
|
581 |
+
sample_dataloader = accelerator.prepare(sample_dataloader)
|
582 |
+
pipeline.to(accelerator.device)
|
583 |
+
|
584 |
+
for example in tqdm(
|
585 |
+
sample_dataloader,
|
586 |
+
desc="Generating class images",
|
587 |
+
disable=not accelerator.is_local_main_process,
|
588 |
+
):
|
589 |
+
images = pipeline(example["prompt"]).images
|
590 |
+
|
591 |
+
for i, image in enumerate(images):
|
592 |
+
hash_image = hashlib.sha1(image.tobytes()).hexdigest()
|
593 |
+
image_filename = (
|
594 |
+
class_images_dir
|
595 |
+
/ f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
|
596 |
+
)
|
597 |
+
image.save(image_filename)
|
598 |
+
|
599 |
+
del pipeline
|
600 |
+
if torch.cuda.is_available():
|
601 |
+
torch.cuda.empty_cache()
|
602 |
+
|
603 |
+
# Handle the repository creation
|
604 |
+
if accelerator.is_main_process:
|
605 |
+
if args.push_to_hub:
|
606 |
+
if args.hub_model_id is None:
|
607 |
+
repo_name = get_full_repo_name(
|
608 |
+
Path(args.output_dir).name, token=args.hub_token
|
609 |
+
)
|
610 |
+
else:
|
611 |
+
repo_name = args.hub_model_id
|
612 |
+
repo = Repository(args.output_dir, clone_from=repo_name)
|
613 |
+
|
614 |
+
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
|
615 |
+
if "step_*" not in gitignore:
|
616 |
+
gitignore.write("step_*\n")
|
617 |
+
if "epoch_*" not in gitignore:
|
618 |
+
gitignore.write("epoch_*\n")
|
619 |
+
elif args.output_dir is not None:
|
620 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
621 |
+
|
622 |
+
# Load the tokenizer
|
623 |
+
if args.tokenizer_name:
|
624 |
+
tokenizer = CLIPTokenizer.from_pretrained(
|
625 |
+
args.tokenizer_name,
|
626 |
+
revision=args.revision,
|
627 |
+
)
|
628 |
+
elif args.pretrained_model_name_or_path:
|
629 |
+
tokenizer = CLIPTokenizer.from_pretrained(
|
630 |
+
args.pretrained_model_name_or_path,
|
631 |
+
subfolder="tokenizer",
|
632 |
+
revision=args.revision,
|
633 |
+
)
|
634 |
+
|
635 |
+
# Load models and create wrapper for stable diffusion
|
636 |
+
text_encoder = CLIPTextModel.from_pretrained(
|
637 |
+
args.pretrained_model_name_or_path,
|
638 |
+
subfolder="text_encoder",
|
639 |
+
revision=args.revision,
|
640 |
+
)
|
641 |
+
vae = AutoencoderKL.from_pretrained(
|
642 |
+
args.pretrained_model_name_or_path,
|
643 |
+
subfolder="vae",
|
644 |
+
revision=args.revision,
|
645 |
+
)
|
646 |
+
unet = UNet2DConditionModel.from_pretrained(
|
647 |
+
args.pretrained_model_name_or_path,
|
648 |
+
subfolder="unet",
|
649 |
+
revision=args.revision,
|
650 |
+
)
|
651 |
+
unet.requires_grad_(False)
|
652 |
+
unet_lora_params, train_names = inject_trainable_lora(unet)
|
653 |
+
|
654 |
+
for _up, _down in extract_lora_ups_down(unet):
|
655 |
+
print(_up.weight)
|
656 |
+
print(_down.weight)
|
657 |
+
break
|
658 |
+
|
659 |
+
vae.requires_grad_(False)
|
660 |
+
if not args.train_text_encoder:
|
661 |
+
text_encoder.requires_grad_(False)
|
662 |
+
|
663 |
+
if args.gradient_checkpointing:
|
664 |
+
unet.enable_gradient_checkpointing()
|
665 |
+
if args.train_text_encoder:
|
666 |
+
text_encoder.gradient_checkpointing_enable()
|
667 |
+
|
668 |
+
if args.scale_lr:
|
669 |
+
args.learning_rate = (
|
670 |
+
args.learning_rate
|
671 |
+
* args.gradient_accumulation_steps
|
672 |
+
* args.train_batch_size
|
673 |
+
* accelerator.num_processes
|
674 |
+
)
|
675 |
+
|
676 |
+
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
|
677 |
+
if args.use_8bit_adam:
|
678 |
+
try:
|
679 |
+
import bitsandbytes as bnb
|
680 |
+
except ImportError:
|
681 |
+
raise ImportError(
|
682 |
+
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
|
683 |
+
)
|
684 |
+
|
685 |
+
optimizer_class = bnb.optim.AdamW8bit
|
686 |
+
else:
|
687 |
+
optimizer_class = torch.optim.AdamW
|
688 |
+
|
689 |
+
params_to_optimize = (
|
690 |
+
itertools.chain(*unet_lora_params, text_encoder.parameters())
|
691 |
+
if args.train_text_encoder
|
692 |
+
else itertools.chain(*unet_lora_params)
|
693 |
+
)
|
694 |
+
optimizer = optimizer_class(
|
695 |
+
params_to_optimize,
|
696 |
+
lr=args.learning_rate,
|
697 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
698 |
+
weight_decay=args.adam_weight_decay,
|
699 |
+
eps=args.adam_epsilon,
|
700 |
+
)
|
701 |
+
|
702 |
+
noise_scheduler = DDPMScheduler.from_config(
|
703 |
+
args.pretrained_model_name_or_path, subfolder="scheduler"
|
704 |
+
)
|
705 |
+
|
706 |
+
train_dataset = DreamBoothDataset(
|
707 |
+
instance_data_root=args.instance_data_dir,
|
708 |
+
instance_prompt=args.instance_prompt,
|
709 |
+
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
|
710 |
+
class_prompt=args.class_prompt,
|
711 |
+
tokenizer=tokenizer,
|
712 |
+
size=args.resolution,
|
713 |
+
center_crop=args.center_crop,
|
714 |
+
)
|
715 |
+
|
716 |
+
def collate_fn(examples):
|
717 |
+
input_ids = [example["instance_prompt_ids"] for example in examples]
|
718 |
+
pixel_values = [example["instance_images"] for example in examples]
|
719 |
+
|
720 |
+
# Concat class and instance examples for prior preservation.
|
721 |
+
# We do this to avoid doing two forward passes.
|
722 |
+
if args.with_prior_preservation:
|
723 |
+
input_ids += [example["class_prompt_ids"] for example in examples]
|
724 |
+
pixel_values += [example["class_images"] for example in examples]
|
725 |
+
|
726 |
+
pixel_values = torch.stack(pixel_values)
|
727 |
+
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
728 |
+
|
729 |
+
input_ids = tokenizer.pad(
|
730 |
+
{"input_ids": input_ids},
|
731 |
+
padding="max_length",
|
732 |
+
max_length=tokenizer.model_max_length,
|
733 |
+
return_tensors="pt",
|
734 |
+
).input_ids
|
735 |
+
|
736 |
+
batch = {
|
737 |
+
"input_ids": input_ids,
|
738 |
+
"pixel_values": pixel_values,
|
739 |
+
}
|
740 |
+
return batch
|
741 |
+
|
742 |
+
train_dataloader = torch.utils.data.DataLoader(
|
743 |
+
train_dataset,
|
744 |
+
batch_size=args.train_batch_size,
|
745 |
+
shuffle=True,
|
746 |
+
collate_fn=collate_fn,
|
747 |
+
num_workers=1,
|
748 |
+
)
|
749 |
+
|
750 |
+
# Scheduler and math around the number of training steps.
|
751 |
+
overrode_max_train_steps = False
|
752 |
+
num_update_steps_per_epoch = math.ceil(
|
753 |
+
len(train_dataloader) / args.gradient_accumulation_steps
|
754 |
+
)
|
755 |
+
if args.max_train_steps is None:
|
756 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
757 |
+
overrode_max_train_steps = True
|
758 |
+
|
759 |
+
lr_scheduler = get_scheduler(
|
760 |
+
args.lr_scheduler,
|
761 |
+
optimizer=optimizer,
|
762 |
+
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
|
763 |
+
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
764 |
+
)
|
765 |
+
|
766 |
+
if args.train_text_encoder:
|
767 |
+
(
|
768 |
+
unet,
|
769 |
+
text_encoder,
|
770 |
+
optimizer,
|
771 |
+
train_dataloader,
|
772 |
+
lr_scheduler,
|
773 |
+
) = accelerator.prepare(
|
774 |
+
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
|
775 |
+
)
|
776 |
+
else:
|
777 |
+
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
778 |
+
unet, optimizer, train_dataloader, lr_scheduler
|
779 |
+
)
|
780 |
+
|
781 |
+
weight_dtype = torch.float32
|
782 |
+
if accelerator.mixed_precision == "fp16":
|
783 |
+
weight_dtype = torch.float16
|
784 |
+
elif accelerator.mixed_precision == "bf16":
|
785 |
+
weight_dtype = torch.bfloat16
|
786 |
+
|
787 |
+
# Move text_encode and vae to gpu.
|
788 |
+
# For mixed precision training we cast the text_encoder and vae weights to half-precision
|
789 |
+
# as these models are only used for inference, keeping weights in full precision is not required.
|
790 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
791 |
+
if not args.train_text_encoder:
|
792 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
793 |
+
|
794 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
795 |
+
num_update_steps_per_epoch = math.ceil(
|
796 |
+
len(train_dataloader) / args.gradient_accumulation_steps
|
797 |
+
)
|
798 |
+
if overrode_max_train_steps:
|
799 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
800 |
+
# Afterwards we recalculate our number of training epochs
|
801 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
802 |
+
|
803 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
804 |
+
# The trackers initializes automatically on the main process.
|
805 |
+
if accelerator.is_main_process:
|
806 |
+
accelerator.init_trackers("dreambooth", config=vars(args))
|
807 |
+
|
808 |
+
# Train!
|
809 |
+
total_batch_size = (
|
810 |
+
args.train_batch_size
|
811 |
+
* accelerator.num_processes
|
812 |
+
* args.gradient_accumulation_steps
|
813 |
+
)
|
814 |
+
|
815 |
+
print("***** Running training *****")
|
816 |
+
print(f" Num examples = {len(train_dataset)}")
|
817 |
+
print(f" Num batches each epoch = {len(train_dataloader)}")
|
818 |
+
print(f" Num Epochs = {args.num_train_epochs}")
|
819 |
+
print(f" Instantaneous batch size per device = {args.train_batch_size}")
|
820 |
+
print(
|
821 |
+
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
|
822 |
+
)
|
823 |
+
print(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
824 |
+
print(f" Total optimization steps = {args.max_train_steps}")
|
825 |
+
# Only show the progress bar once on each machine.
|
826 |
+
progress_bar = tqdm(
|
827 |
+
range(args.max_train_steps), disable=not accelerator.is_local_main_process
|
828 |
+
)
|
829 |
+
progress_bar.set_description("Steps")
|
830 |
+
global_step = 0
|
831 |
+
|
832 |
+
for epoch in range(args.num_train_epochs):
|
833 |
+
unet.train()
|
834 |
+
if args.train_text_encoder:
|
835 |
+
text_encoder.train()
|
836 |
+
for step, batch in enumerate(train_dataloader):
|
837 |
+
|
838 |
+
# Convert images to latent space
|
839 |
+
latents = vae.encode(
|
840 |
+
batch["pixel_values"].to(dtype=weight_dtype)
|
841 |
+
).latent_dist.sample()
|
842 |
+
latents = latents * 0.18215
|
843 |
+
|
844 |
+
# Sample noise that we'll add to the latents
|
845 |
+
noise = torch.randn_like(latents)
|
846 |
+
bsz = latents.shape[0]
|
847 |
+
# Sample a random timestep for each image
|
848 |
+
timesteps = torch.randint(
|
849 |
+
0,
|
850 |
+
noise_scheduler.config.num_train_timesteps,
|
851 |
+
(bsz,),
|
852 |
+
device=latents.device,
|
853 |
+
)
|
854 |
+
timesteps = timesteps.long()
|
855 |
+
|
856 |
+
# Add noise to the latents according to the noise magnitude at each timestep
|
857 |
+
# (this is the forward diffusion process)
|
858 |
+
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
859 |
+
|
860 |
+
# Get the text embedding for conditioning
|
861 |
+
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
|
862 |
+
|
863 |
+
# Predict the noise residual
|
864 |
+
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
865 |
+
|
866 |
+
# Get the target for loss depending on the prediction type
|
867 |
+
if noise_scheduler.config.prediction_type == "epsilon":
|
868 |
+
target = noise
|
869 |
+
elif noise_scheduler.config.prediction_type == "v_prediction":
|
870 |
+
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
871 |
+
else:
|
872 |
+
raise ValueError(
|
873 |
+
f"Unknown prediction type {noise_scheduler.config.prediction_type}"
|
874 |
+
)
|
875 |
+
|
876 |
+
if args.with_prior_preservation:
|
877 |
+
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
|
878 |
+
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
|
879 |
+
target, target_prior = torch.chunk(target, 2, dim=0)
|
880 |
+
|
881 |
+
# Compute instance loss
|
882 |
+
loss = (
|
883 |
+
F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
884 |
+
.mean([1, 2, 3])
|
885 |
+
.mean()
|
886 |
+
)
|
887 |
+
|
888 |
+
# Compute prior loss
|
889 |
+
prior_loss = F.mse_loss(
|
890 |
+
model_pred_prior.float(), target_prior.float(), reduction="mean"
|
891 |
+
)
|
892 |
+
|
893 |
+
# Add the prior loss to the instance loss.
|
894 |
+
loss = loss + args.prior_loss_weight * prior_loss
|
895 |
+
else:
|
896 |
+
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
897 |
+
|
898 |
+
accelerator.backward(loss)
|
899 |
+
if accelerator.sync_gradients:
|
900 |
+
params_to_clip = (
|
901 |
+
itertools.chain(unet.parameters(), text_encoder.parameters())
|
902 |
+
if args.train_text_encoder
|
903 |
+
else unet.parameters()
|
904 |
+
)
|
905 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
906 |
+
optimizer.step()
|
907 |
+
lr_scheduler.step()
|
908 |
+
progress_bar.update(1)
|
909 |
+
optimizer.zero_grad()
|
910 |
+
|
911 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
912 |
+
if accelerator.sync_gradients:
|
913 |
+
|
914 |
+
global_step += 1
|
915 |
+
|
916 |
+
if global_step % args.save_steps == 0:
|
917 |
+
if accelerator.is_main_process:
|
918 |
+
pipeline = StableDiffusionPipeline.from_pretrained(
|
919 |
+
args.pretrained_model_name_or_path,
|
920 |
+
unet=accelerator.unwrap_model(unet),
|
921 |
+
text_encoder=accelerator.unwrap_model(text_encoder),
|
922 |
+
revision=args.revision,
|
923 |
+
)
|
924 |
+
|
925 |
+
save_lora_weight(pipeline.unet, args.output_dir + "/lora_weight.pt")
|
926 |
+
|
927 |
+
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
928 |
+
progress_bar.set_postfix(**logs)
|
929 |
+
accelerator.log(logs, step=global_step)
|
930 |
+
|
931 |
+
if global_step >= args.max_train_steps:
|
932 |
+
break
|
933 |
+
|
934 |
+
accelerator.wait_for_everyone()
|
935 |
+
|
936 |
+
# Create the pipeline using using the trained modules and save it.
|
937 |
+
if accelerator.is_main_process:
|
938 |
+
pipeline = StableDiffusionPipeline.from_pretrained(
|
939 |
+
args.pretrained_model_name_or_path,
|
940 |
+
unet=accelerator.unwrap_model(unet),
|
941 |
+
text_encoder=accelerator.unwrap_model(text_encoder),
|
942 |
+
revision=args.revision,
|
943 |
+
)
|
944 |
+
|
945 |
+
print("\n\nLora TRAINING DONE!\n\n")
|
946 |
+
|
947 |
+
save_lora_weight(pipeline.unet, args.output_dir + "/lora_weight.pt")
|
948 |
+
|
949 |
+
for _up, _down in extract_lora_ups_down(pipeline.unet):
|
950 |
+
print("First Layer's Up Weight is now : ", _up.weight)
|
951 |
+
print("First Layer's Down Weight is now : ", _down.weight)
|
952 |
+
break
|
953 |
+
|
954 |
+
if args.push_to_hub:
|
955 |
+
repo.push_to_hub(
|
956 |
+
commit_message="End of training", blocking=False, auto_lfs_prune=True
|
957 |
+
)
|
958 |
+
|
959 |
+
accelerator.end_training()
|
960 |
+
|
961 |
+
|
962 |
+
if __name__ == "__main__":
|
963 |
+
args = parse_args()
|
964 |
+
main(args)
|