ysharma (HF Staff) committed on
Commit ac93824 · 0 Parent(s):

Duplicate from lora-library/Low-rank-Adaptation

Browse files

Co-authored-by: yuvraj sharma <[email protected]>

.gitattributes ADDED
@@ -0,0 +1,36 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ contents/alpha_scale.gif filter=lfs diff=lfs merge=lfs -text
+ contents/alpha_scale.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,6 @@
+ data_*
+ output_*
+ __pycache__
+ *.pyc
+ __test*
+ merged_lora*
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: LORA-Low-rank-Adaptation
+ emoji: 📚
+ colorFrom: purple
+ colorTo: gray
+ sdk: gradio
+ sdk_version: 3.12
+ app_file: app.py
+ pinned: false
+ license: mit
+ duplicated_from: lora-library/Low-rank-Adaptation
+ ---
+ 
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
README1.md ADDED
@@ -0,0 +1,139 @@
+ # Low-rank Adaptation for Fast Text-to-Image Diffusion Fine-tuning
+ 
+ <!-- #region -->
+ <p align="center">
+ <img src="contents/alpha_scale.gif">
+ </p>
+ <!-- #endregion -->
+ 
+ > Using LORA to fine-tune on an illustration dataset: $W = W_0 + \alpha \Delta W$, where $\alpha$ is the merging ratio. The GIF above scales alpha from 0 to 1. Setting alpha to 0 is the same as using the original model, and setting alpha to 1 is the same as using the fully fine-tuned model.
+ 
+ <!-- #region -->
+ <p align="center">
+ <img src="contents/disney_lora.jpg">
+ </p>
+ <!-- #endregion -->
+ 
+ > "style of sks, baby lion", with the disney-style LORA model.
+ 
+ <!-- #region -->
+ <p align="center">
+ <img src="contents/pop_art.jpg">
+ </p>
+ <!-- #endregion -->
+ 
+ > "style of sks, superman", with the pop-art-style LORA model.
+ 
+ ## Main Features
+ 
+ - Fine-tune Stable Diffusion models twice as fast as the DreamBooth method, via Low-rank Adaptation
+ - Get an insanely small end result, easy to share and download
+ - Easy to use, compatible with `diffusers`
+ - Sometimes even better performance than full fine-tuning (extensive comparisons are left as future work)
+ - Merge checkpoints by merging LORAs
+ 
+ # Lengthy Introduction
+ 
+ Thanks to the generous work of Stability AI and Hugging Face, many people have enjoyed fine-tuning Stable Diffusion models to fit their needs and generate higher-fidelity images. **However, the fine-tuning process is very slow, and it is not easy to find a good balance between the number of steps and the quality of the results.**
+ 
+ Also, the final result (a fully fine-tuned model) is very large. Some people work with textual inversion as an alternative, but this is clearly suboptimal: textual inversion only creates a small word embedding, and the final images are not as good as those from a fully fine-tuned model.
+ 
+ Well, what's the alternative? In the domain of LLMs, researchers have developed efficient fine-tuning methods. LORA, especially, tackles the very problem the community currently has: end users with the open-sourced Stable Diffusion model want to try the various fine-tuned models created by the community, but each model is too large to download and use. LORA instead fine-tunes the "residual" of the model rather than the entire model: i.e., it trains $\Delta W$ instead of $W$:
+ 
+ $$
+ W' = W + \Delta W
+ $$
+ 
+ where we can further decompose $\Delta W$ into low-rank matrices: $\Delta W = A B^T$, with $A \in \mathbb{R}^{n \times d}$, $B \in \mathbb{R}^{m \times d}$, and $d \ll n$.
+ This is the key idea of LORA. We can then fine-tune $A$ and $B$ instead of $W$. In the end, you get an insanely small model, as $A$ and $B$ are much smaller than $W$.
+ 
+ Also, not all of the parameters need tuning: the authors found that tuning only $Q, K, V, O$ (i.e., the attention layers) of the transformer model is often enough. (This is also why the end result is so small.) This repo follows the same idea.
+ 
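+ To make the parameter savings concrete, here is a minimal sketch of the same idea on a single linear layer (not the repo's exact implementation — see `lora_diffusion/lora.py` for that); the dimensions and rank below are just illustrative:
+ 
+ ```python
+ import torch.nn as nn
+ 
+ 
+ class LowRankLinear(nn.Module):
+     # y = W x + scale * B(A x): the frozen base weight plus a low-rank residual
+     def __init__(self, in_features, out_features, r=4):
+         super().__init__()
+         self.linear = nn.Linear(in_features, out_features)  # pretrained W, frozen
+         self.linear.requires_grad_(False)
+         self.lora_down = nn.Linear(in_features, r, bias=False)  # A
+         self.lora_up = nn.Linear(r, out_features, bias=False)   # B
+         nn.init.zeros_(self.lora_up.weight)  # start as an exact copy of the original model
+         self.scale = 1.0
+ 
+     def forward(self, x):
+         return self.linear(x) + self.lora_up(self.lora_down(x)) * self.scale
+ 
+ 
+ layer = LowRankLinear(768, 768, r=4)
+ print(768 * 768)    # 589,824 parameters in the full W
+ print(2 * 768 * 4)  # 6,144 trainable parameters in A and B
+ ```
+ 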
+ Enough of the lengthy introduction, let's get to the code.
+ 
+ # Installation
+ 
+ ```bash
+ pip install git+https://github.com/cloneofsimo/lora.git
+ ```
+ 
+ # Getting Started
+ 
+ ## Fine-tuning Stable Diffusion with LORA
+ 
+ Basic usage is as follows: inject sets of $A, B$ matrices into a UNet model, and fine-tune them.
+ 
+ ```python
+ from lora_diffusion import inject_trainable_lora, extract_lora_ups_down
+ 
+ ...
+ 
+ unet = UNet2DConditionModel.from_pretrained(
+     pretrained_model_name_or_path,
+     subfolder="unet",
+ )
+ unet.requires_grad_(False)
+ unet_lora_params, train_names = inject_trainable_lora(unet)  # with the unet frozen above,
+ # only the injected LORA parameters remain trainable
+ optimizer = optim.Adam(
+     itertools.chain(*unet_lora_params, text_encoder.parameters()), lr=1e-4
+ )
+ ```
+ 
+ An example of this can be found in `train_lora_dreambooth.py`. Run this example with
+ 
+ ```bash
+ run_lora_db.sh
+ ```
+ 
+ ## Loading, merging, and interpolating trained LORAs
+ 
+ We've seen that people have been merging different checkpoints with different ratios, and this seems to be very useful to the community. LORA is extremely easy to merge.
+ 
+ By the nature of LORA, one can interpolate between different fine-tuned models by adding together differently weighted $A, B$ matrices.
+ 
+ Currently, the LORA CLI has two options: merge a UNet with a LORA, or merge a LORA with a LORA.
+ 
+ ### Merging a UNet with a LORA
+ 
+ ```bash
+ $ lora_add --path_1 PATH_TO_DIFFUSER_FORMAT_MODEL --path_2 PATH_TO_LORA.PT --mode upl --alpha 1.0 --output_path OUTPUT_PATH
+ ```
+ 
+ `path_1` can be either a local path or a Hugging Face model name. When adding a LORA to a UNet, alpha is the constant below:
+ 
+ $$
+ W' = W + \alpha \Delta W
+ $$
+ 
+ So, set alpha to 1.0 to fully add the LORA. If the LORA seems to have too much effect (i.e., it has overfitted), set alpha to a lower value. If the LORA seems to have too little effect, set alpha higher than 1.0. You can tune these values to your needs.
+ 
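+ For intuition, the `upl` mode folds the low-rank update directly into the frozen weight. Here is a rough sketch of the underlying arithmetic (the actual implementation is `weight_apply_lora` in `lora_diffusion/lora.py`), using toy tensors in place of the saved weight list:
+ 
+ ```python
+ import torch
+ 
+ W = torch.randn(768, 768)                            # pretrained weight
+ up, down = torch.randn(768, 4), torch.randn(4, 768)  # trained LORA pair
+ 
+ alpha = 1.0
+ W_merged = W + alpha * (up @ down)                   # W' = W + alpha * Delta W
+ ```
+ 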
+ **Example**
+ 
+ ```bash
+ $ lora_add --path_1 stabilityai/stable-diffusion-2-base --path_2 lora_illust.pt --mode upl --alpha 1.0 --output_path merged_model
+ ```
+ 
+ ### Merging a LORA with a LORA
+ 
+ ```bash
+ $ lora_add --path_1 PATH_TO_LORA.PT --path_2 PATH_TO_LORA.PT --mode lpl --alpha 0.5 --output_path OUTPUT_PATH.PT
+ ```
+ 
+ Alpha is the ratio of the first model to the second model, i.e.,
+ 
+ $$
+ \Delta W = (\alpha A_1 + (1 - \alpha) A_2) (\alpha B_1 + (1 - \alpha) B_2)^T
+ $$
+ 
+ Set alpha to 0.5 to get the average of the two models. Set alpha close to 1.0 to get more of the effect of the first model, and set alpha close to 0.0 to get more of the effect of the second model.
+ 
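+ Under the hood, the `lpl` mode interpolates each saved up/down pair elementwise before they are multiplied together. A sketch of that step (mirroring the loop in `lora_diffusion/cli_lora_add.py`), again with toy tensors standing in for the saved weight lists:
+ 
+ ```python
+ import torch
+ 
+ alpha = 0.5
+ up1, down1 = torch.randn(768, 4), torch.randn(4, 768)  # first LORA
+ up2, down2 = torch.randn(768, 4), torch.randn(4, 768)  # second LORA
+ 
+ up_merged = alpha * up1 + (1 - alpha) * up2
+ down_merged = alpha * down1 + (1 - alpha) * down2
+ delta_W = up_merged @ down_merged                       # Delta W of the merged LORA
+ ```
+ 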
+ **Example**
+ 
+ ```bash
+ $ lora_add --path_1 lora_illust.pt --path_2 lora_pop.pt --alpha 0.3 --output_path lora_merged.pt
+ ```
+ 
+ ### Making Inference with a Trained LORA
+ 
+ Check out `scripts/run_inference.ipynb` for an example of how to run inference with a LORA.
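+ 
+ If you just want the gist without opening the notebook, the calls used by `app.py` in this Space boil down to something like the following sketch (the weight path, alpha, and prompt are placeholders):
+ 
+ ```python
+ import torch
+ from diffusers import StableDiffusionPipeline
+ from lora_diffusion import monkeypatch_lora, tune_lora_scale
+ 
+ pipe = StableDiffusionPipeline.from_pretrained(
+     "stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16
+ ).to("cuda")
+ 
+ # patch the UNet with trained LORA weights, then set the merging ratio alpha
+ monkeypatch_lora(pipe.unet, torch.load("./lora_weight.pt"))
+ tune_lora_scale(pipe.unet, 0.7)
+ 
+ image = pipe("style of sks, baby lion", num_inference_steps=50, guidance_scale=7).images[0]
+ image.save("illust_lora.jpg")
+ ```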
app.py ADDED
@@ -0,0 +1,121 @@
1
+ from diffusers import StableDiffusionPipeline
2
+ from lora_diffusion import monkeypatch_lora, tune_lora_scale
3
+ import torch
4
+ import os, shutil
5
+ import gradio as gr
6
+ import subprocess
7
+
8
+ MODEL_NAME="stabilityai/stable-diffusion-2-1-base"
9
+ INSTANCE_DIR="./data_example"
10
+ OUTPUT_DIR="./output_example"
11
+
12
+ model_id = "stabilityai/stable-diffusion-2-1-base"
13
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
14
+ #prompt = "style of sks, baby lion"
15
+ torch.manual_seed(1)
16
+ #image = pipe(prompt, num_inference_steps=50, guidance_scale= 7).images[0] #no need
17
+ #image # nice. diffusers are cool. #no need
18
+ #finetuned_lora_weights = "./lora_weight.pt"
19
+
20
+ #global var
21
+ counter = 0
22
+
23
+ #Getting Lora fine-tuned weights
24
+ def monkeypatching(unet_alpha,texten_alpha, in_prompt, wts): #, prompt, pipe): finetuned_lora_weights
25
+ print("****** inside monkeypatching *******")
26
+ print(f"in_prompt is - {str(in_prompt)}")
27
+ global counter
28
+ #if model == 'Text-encoder':
29
+ unet_wt = wts[-2]
30
+ #else:
31
+ texten_wt = wts[-1]
32
+ print(f"UNET weight is = {unet_wt}, Text-encoder weight is = {texten_wt}")
33
+ if counter == 0 :
34
+ #if wt == "./lora_playgroundai_wt.pt" :
35
+ monkeypatch_lora(pipe.unet, torch.load(unet_wt)) #finetuned_lora_weights
36
+ monkeypatch_lora(pipe.text_encoder, torch.load(texten_wt), target_replace_module=["CLIPAttention"]) #text-encoder #"./lora/lora_kiriko.text_encoder.pt"
37
+ #tune_lora_scale(pipe.unet, alpha) #1.00)
38
+ tune_lora_scale(pipe.unet, unet_alpha)
39
+ tune_lora_scale(pipe.text_encoder, texten_alpha)
40
+ counter +=1
41
+ #else:
42
+ #monkeypatch_lora(pipe.unet, torch.load("./output_example/lora_weight.pt")) #finetuned_lora_weights
43
+ #tune_lora_scale(pipe.unet, alpha) #1.00)
44
+ #counter +=1
45
+ else :
46
+ tune_lora_scale(pipe.unet, unet_alpha)
47
+ tune_lora_scale(pipe.text_encoder, texten_alpha)
48
+ #tune_lora_scale(pipe.unet, alpha) #1.00)
49
+ prompt = str(in_prompt) #"style of hclu, " + str(in_prompt) #"baby lion"
50
+ image = pipe(prompt, num_inference_steps=50, guidance_scale=7).images[0]
51
+ image.save("./illust_lora.jpg") #"./contents/illust_lora.jpg")
52
+ return image
53
+
54
+ #{in_prompt} line68 --pl ignore
55
+ def accelerate_train_lora(steps, images, in_prompt):
56
+ print("*********** inside accelerate_train_lora ***********")
57
+ print(f"images are -- {images}")
58
+ # path can be retrieved by file_obj.name and original filename can be retrieved with file_obj.orig_name
59
+ for file in images:
60
+ print(f"file passed -- {file.name}")
61
+ os.makedirs(INSTANCE_DIR, exist_ok=True)
62
+ shutil.copy( file.name, INSTANCE_DIR) #/{file.orig_name}
63
+ #subprocess.Popen(f'accelerate launch {"./train_lora_dreambooth.py"} \
64
+ os.system( f"accelerate launch {'./train_lora_dreambooth.py'} --pretrained_model_name_or_path={MODEL_NAME} --instance_data_dir={INSTANCE_DIR} --output_dir={OUTPUT_DIR} --instance_prompt='{in_prompt}' --train_text_encoder --resolution=512 --train_batch_size=1 --gradient_accumulation_steps=1 --learning_rate='1e-4' --learning_rate_text='5e-5' --color_jitter --lr_scheduler='constant' --lr_warmup_steps=0 --max_train_steps={int(steps)}") #10000
65
+ print("*********** completing accelerate_train_lora ***********")
66
+ print(f"files in output_dir -- {os.listdir(OUTPUT_DIR)}")
67
+ #lora_trained_weights = "./output_example/lora_weight.pt"
68
+ files = os.listdir(OUTPUT_DIR)
69
+ file_list = []
70
+ for file in files: #os.listdir(OUTPUT_DIR):
71
+ if file.endswith(".pt"):
72
+ print("weight files are -- ",os.path.join(f"{OUTPUT_DIR}", file))
73
+ file_list.append(os.path.join(f"{OUTPUT_DIR}", file))
74
+ return file_list #files[1:]
75
+ #return f"{OUTPUT_DIR}/*.pt"
76
+
77
+ with gr.Blocks() as demo:
78
+ gr.Markdown("""<h1><center><b>LORA</b> (Low-rank Adaptation) for Faster Text-to-Image Diffusion Fine-tuning (UNET+CLIP)</center></h1>
79
+ """)
80
+ gr.HTML("<p>You can skip the queue by duplicating this space and upgrading to GPU in settings: <a style='display:inline-block' href='https://huggingface.co/spaces/ysharma/Low-rank-Adaptation?duplicate=true'><img src='https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14' alt='Duplicate Space'></a></p>")
81
+ #gr.Markdown("""<b>NEW!!</b> : I have fine-tuned the SD model for 15,000 steps using 100 PlaygroundAI images and LORA. You can load this trained model using the example component. Load the weight and start using the Space with the Inference button. Feel free to toggle the Alpha value.""")
82
+ gr.Markdown(
83
+ """**Main Features**<br>- Fine-tune Stable Diffusion models twice as fast as the DreamBooth method, via Low-rank Adaptation.<br>- Get insanely small end results, easy to share and download.<br>- Easy to use, compatible with diffusers.<br>- Sometimes even better performance than full fine-tuning.<br><br>Please refer to the GitHub repo this Space is based on, here - <a href = "https://github.com/cloneofsimo/lora">LORA</a>. You can also refer to this tweet by AK to quote/retweet/like here on <a href="https://twitter.com/_akhaliq/status/1601120767009513472">Twitter</a>. This Gradio Space is an attempt to explore this novel LORA approach to fine-tuning Stable Diffusion models, using the power and flexibility of Gradio! A higher number of steps results in a longer training time and a better fine-tuned SD model.<br><br><b>To use this Space well:</b><br>- First, upload your set of images (the suggested number of images is between 4 and 9), enter the prompt, enter the number of fine-tuning steps (the suggested value is between 2000 and 4000), and then press the 'Train LORA model' button. This will produce your fine-tuned model weights.<br>- Modify the previous prompt by adding to it (a suffix), set the alpha values using the sliders (nearer to 1 implies overfitting to the uploaded images), and then press the 'Inference' button. This will produce an image using the newly fine-tuned UNET and Text-Encoder LORA models.<br><b>Bonus:</b> You can download your fine-tuned model weights from the Gradio file component. The small size of LORA models (around 3-4 MB per file) is the main highlight of this 'Low-rank Adaptation' approach to fine-tuning.""")
84
+
85
+ with gr.Row():
86
+ in_images = gr.File(label="Upload images to fine-tune for LORA", file_count="multiple")
87
+ with gr.Column():
88
+ b1 = gr.Button(value="Train LORA model")
89
+ in_prompt = gr.Textbox(label="Enter a prompt for fine-tuned LORA model", visible=True)
90
+ b2 = gr.Button(value="Inference using LORA model")
91
+
92
+ with gr.Row():
93
+ out_image = gr.Image(label="Image generated by LORA model")
94
+ with gr.Column():
95
+ with gr.Accordion("Advanced settings for Training and Inference", open=False):
96
+ gr.Markdown("Advanced settings for the number of training steps and alpha. Set alpha to 1.0 to fully add the LORA. If the LORA seems to have too much effect (i.e., overfitting), set alpha to a lower value. If the LORA seems to have too little effect, set alpha higher. You can tune these two values to your needs.")
97
+ in_steps = gr.Number(label="Enter the number of training steps", value = 2000)
98
+ in_alpha_unet = gr.Slider(0.1,1.0, step=0.01, label="Set UNET Alpha level", value=0.5)
99
+ in_alpha_texten = gr.Slider(0.1,1.0, step=0.01, label="Set Text-Encoder Alpha level", value=0.5)
100
+ #in_model = gr.Radio(["Text-encoder", "Unet"], label="Select the fine-tuned model for inference", value="Text-encoder", type="value")
101
+ out_file = gr.File(label="Lora trained model weights", file_count='multiple' )
102
+
103
+ #gr.Examples(
104
+ # examples=[[0.65, 0.6, "lion", ["./lora_playgroundai_wt.pt","./lora_playgroundai_wt.pt"], ],],
105
+ # inputs=[in_alpha_unet, in_alpha_texten, in_prompt, out_file ],
106
+ # outputs=out_image,
107
+ # fn=monkeypatching,
108
+ # cache_examples=True,)
109
+ #gr.Examples(
110
+ # examples=[[2500, ['./simba1.jpg', './simba2.jpg', './simba3.jpg', './simba4.jpg'], "baby lion in disney style"]],
111
+ # inputs=[in_steps, in_images, in_prompt],
112
+ # outputs=out_file,
113
+ # fn=accelerate_train_lora,
114
+ # cache_examples=False,
115
+ # run_on_click=False)
116
+
117
+ b1.click(fn = accelerate_train_lora, inputs=[in_steps, in_images, in_prompt] , outputs=out_file)
118
+ b2.click(fn = monkeypatching, inputs=[in_alpha_unet, in_alpha_texten, in_prompt, out_file,], outputs=out_image)
119
+
120
+ demo.queue(concurrency_count=3)
121
+ demo.launch(debug=True, show_error=True,)
contents/alpha_scale.gif ADDED

Git LFS Details

  • SHA256: 43e9966f27a2b9823956545970d3b2ed5b2f376a1dab5d653f21a977a919e164
  • Pointer size: 132 Bytes
  • Size of remote file: 5.23 MB
contents/alpha_scale.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1ad74f5f69d99bfcbeee1d4d2b3900ac1ca7ff83fba5ddf8269ffed8a56c9c6e
3
+ size 5247140
contents/disney_lora.jpg ADDED
contents/pop_art.jpg ADDED
lora_diffusion/__init__.py ADDED
@@ -0,0 +1 @@
+ from .lora import *
lora_diffusion/cli_lora_add.py ADDED
@@ -0,0 +1,118 @@
1
+ from typing import Literal, Union, Dict
2
+ import os
3
+ import shutil
4
+ import fire
5
+ from diffusers import StableDiffusionPipeline
6
+
7
+ import torch
8
+ from .lora import tune_lora_scale, weight_apply_lora
9
+ from .to_ckpt_v2 import convert_to_ckpt
10
+
11
+
12
+ def _text_lora_path(path: str) -> str:
13
+ assert path.endswith(".pt"), "Only .pt files are supported"
14
+ return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])
15
+
16
+
17
+ def add(
18
+ path_1: str,
19
+ path_2: str,
20
+ output_path: str,
21
+ alpha: float = 0.5,
22
+ mode: Literal[
23
+ "lpl",
24
+ "upl",
25
+ "upl-ckpt-v2",
26
+ ] = "lpl",
27
+ with_text_lora: bool = False,
28
+ ):
29
+ print("Lora Add, mode " + mode)
30
+ if mode == "lpl":
31
+ for _path_1, _path_2, opt in [(path_1, path_2, "unet")] + (
32
+ [(_text_lora_path(path_1), _text_lora_path(path_2), "text_encoder")]
33
+ if with_text_lora
34
+ else []
35
+ ):
36
+ print("Loading", _path_1, _path_2)
37
+ out_list = []
38
+ if opt == "text_encoder":
39
+ if not os.path.exists(_path_1):
40
+ print(f"No text encoder found in {_path_1}, skipping...")
41
+ continue
42
+ if not os.path.exists(_path_2):
43
+ print(f"No text encoder found in {_path_2}, skipping...")
44
+ continue
45
+
46
+ l1 = torch.load(_path_1)
47
+ l2 = torch.load(_path_2)
48
+
49
+ l1pairs = zip(l1[::2], l1[1::2])
50
+ l2pairs = zip(l2[::2], l2[1::2])
51
+
52
+ for (x1, y1), (x2, y2) in zip(l1pairs, l2pairs):
53
+ # print("Merging", x1.shape, y1.shape, x2.shape, y2.shape)
54
+ x1.data = alpha * x1.data + (1 - alpha) * x2.data
55
+ y1.data = alpha * y1.data + (1 - alpha) * y2.data
56
+
57
+ out_list.append(x1)
58
+ out_list.append(y1)
59
+
60
+ if opt == "unet":
61
+
62
+ print("Saving merged UNET to", output_path)
63
+ torch.save(out_list, output_path)
64
+
65
+ elif opt == "text_encoder":
66
+ print("Saving merged text encoder to", _text_lora_path(output_path))
67
+ torch.save(
68
+ out_list,
69
+ _text_lora_path(output_path),
70
+ )
71
+
72
+ elif mode == "upl":
73
+
74
+ loaded_pipeline = StableDiffusionPipeline.from_pretrained(
75
+ path_1,
76
+ ).to("cpu")
77
+
78
+ weight_apply_lora(loaded_pipeline.unet, torch.load(path_2), alpha=alpha)
79
+ if with_text_lora:
80
+
81
+ weight_apply_lora(
82
+ loaded_pipeline.text_encoder,
83
+ torch.load(_text_lora_path(path_2)),
84
+ alpha=alpha,
85
+ target_replace_module=["CLIPAttention"],
86
+ )
87
+
88
+ loaded_pipeline.save_pretrained(output_path)
89
+
90
+ elif mode == "upl-ckpt-v2":
91
+
92
+ loaded_pipeline = StableDiffusionPipeline.from_pretrained(
93
+ path_1,
94
+ ).to("cpu")
95
+
96
+ weight_apply_lora(loaded_pipeline.unet, torch.load(path_2), alpha=alpha)
97
+ if with_text_lora:
98
+ weight_apply_lora(
99
+ loaded_pipeline.text_encoder,
100
+ torch.load(_text_lora_path(path_2)),
101
+ alpha=alpha,
102
+ target_replace_module=["CLIPAttention"],
103
+ )
104
+
105
+ _tmp_output = output_path + ".tmp"
106
+
107
+ loaded_pipeline.save_pretrained(_tmp_output)
108
+ convert_to_ckpt(_tmp_output, output_path, as_half=True)
109
+ # remove the tmp_output folder
110
+ shutil.rmtree(_tmp_output)
111
+
112
+ else:
113
+ print("Unknown mode", mode)
114
+ raise ValueError(f"Unknown mode {mode}")
115
+
116
+
117
+ def main():
118
+ fire.Fire(add)
lora_diffusion/lora.py ADDED
@@ -0,0 +1,355 @@
1
+ import math
2
+ from typing import Callable, Dict, List, Optional, Tuple
3
+
4
+ import numpy as np
5
+ import PIL
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+ import torch.nn as nn
10
+
11
+
12
+ class LoraInjectedLinear(nn.Module):
13
+ def __init__(self, in_features, out_features, bias=False, r=4):
14
+ super().__init__()
15
+
16
+ if r > min(in_features, out_features):
17
+ raise ValueError(
18
+ f"LoRA rank {r} must be less than or equal to {min(in_features, out_features)}"
19
+ )
20
+
21
+ self.linear = nn.Linear(in_features, out_features, bias)
22
+ self.lora_down = nn.Linear(in_features, r, bias=False)
23
+ self.lora_up = nn.Linear(r, out_features, bias=False)
24
+ self.scale = 1.0
25
+
26
+ nn.init.normal_(self.lora_down.weight, std=1 / r**2)
27
+ nn.init.zeros_(self.lora_up.weight)
28
+
29
+ def forward(self, input):
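+ # frozen base projection plus a scaled low-rank update: W x + scale * B(A x)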
30
+ return self.linear(input) + self.lora_up(self.lora_down(input)) * self.scale
31
+
32
+
33
+ def inject_trainable_lora(
34
+ model: nn.Module,
35
+ target_replace_module: List[str] = ["CrossAttention", "Attention"],
36
+ r: int = 4,
37
+ loras=None, # path to lora .pt
38
+ ):
39
+ """
40
+ inject lora into model, and returns lora parameter groups.
41
+ """
42
+
43
+ require_grad_params = []
44
+ names = []
45
+
46
+ if loras is not None:
47
+ loras = torch.load(loras)
48
+
49
+ for _module in model.modules():
50
+ if _module.__class__.__name__ in target_replace_module:
51
+
52
+ for name, _child_module in _module.named_modules():
53
+ if _child_module.__class__.__name__ == "Linear":
54
+
55
+ weight = _child_module.weight
56
+ bias = _child_module.bias
57
+ _tmp = LoraInjectedLinear(
58
+ _child_module.in_features,
59
+ _child_module.out_features,
60
+ _child_module.bias is not None,
61
+ r,
62
+ )
63
+ _tmp.linear.weight = weight
64
+ if bias is not None:
65
+ _tmp.linear.bias = bias
66
+
67
+ # switch the module
68
+ _module._modules[name] = _tmp
69
+
70
+ require_grad_params.append(
71
+ _module._modules[name].lora_up.parameters()
72
+ )
73
+ require_grad_params.append(
74
+ _module._modules[name].lora_down.parameters()
75
+ )
76
+
77
+ if loras is not None:
78
+ _module._modules[name].lora_up.weight = loras.pop(0)
79
+ _module._modules[name].lora_down.weight = loras.pop(0)
80
+
81
+ _module._modules[name].lora_up.weight.requires_grad = True
82
+ _module._modules[name].lora_down.weight.requires_grad = True
83
+ names.append(name)
84
+ return require_grad_params, names
85
+
86
+
87
+ def extract_lora_ups_down(model, target_replace_module=["CrossAttention", "Attention"]):
88
+
89
+ loras = []
90
+
91
+ for _module in model.modules():
92
+ if _module.__class__.__name__ in target_replace_module:
93
+ for _child_module in _module.modules():
94
+ if _child_module.__class__.__name__ == "LoraInjectedLinear":
95
+ loras.append((_child_module.lora_up, _child_module.lora_down))
96
+ if len(loras) == 0:
97
+ raise ValueError("No lora injected.")
98
+ return loras
99
+
100
+
101
+ def save_lora_weight(
102
+ model, path="./lora.pt", target_replace_module=["CrossAttention", "Attention"]
103
+ ):
104
+ weights = []
105
+ for _up, _down in extract_lora_ups_down(
106
+ model, target_replace_module=target_replace_module
107
+ ):
108
+ weights.append(_up.weight)
109
+ weights.append(_down.weight)
110
+
111
+ torch.save(weights, path)
112
+
113
+
114
+ def save_lora_as_json(model, path="./lora.json"):
115
+ weights = []
116
+ for _up, _down in extract_lora_ups_down(model):
117
+ weights.append(_up.weight.detach().cpu().numpy().tolist())
118
+ weights.append(_down.weight.detach().cpu().numpy().tolist())
119
+
120
+ import json
121
+
122
+ with open(path, "w") as f:
123
+ json.dump(weights, f)
124
+
125
+
126
+ def weight_apply_lora(
127
+ model, loras, target_replace_module=["CrossAttention", "Attention"], alpha=1.0
128
+ ):
129
+
130
+ for _module in model.modules():
131
+ if _module.__class__.__name__ in target_replace_module:
132
+ for _child_module in _module.modules():
133
+ if _child_module.__class__.__name__ == "Linear":
134
+
135
+ weight = _child_module.weight
136
+
137
+ up_weight = loras.pop(0).detach().to(weight.device)
138
+ down_weight = loras.pop(0).detach().to(weight.device)
139
+
140
+ # W <- W + U * D
141
+ weight = weight + alpha * (up_weight @ down_weight).type(
142
+ weight.dtype
143
+ )
144
+ _child_module.weight = nn.Parameter(weight)
145
+
146
+
147
+ def monkeypatch_lora(
148
+ model, loras, target_replace_module=["CrossAttention", "Attention"], r: int = 4
149
+ ):
150
+ for _module in model.modules():
151
+ if _module.__class__.__name__ in target_replace_module:
152
+ for name, _child_module in _module.named_modules():
153
+ if _child_module.__class__.__name__ == "Linear":
154
+
155
+ weight = _child_module.weight
156
+ bias = _child_module.bias
157
+ _tmp = LoraInjectedLinear(
158
+ _child_module.in_features,
159
+ _child_module.out_features,
160
+ _child_module.bias is not None,
161
+ r=r,
162
+ )
163
+ _tmp.linear.weight = weight
164
+
165
+ if bias is not None:
166
+ _tmp.linear.bias = bias
167
+
168
+ # switch the module
169
+ _module._modules[name] = _tmp
170
+
171
+ up_weight = loras.pop(0)
172
+ down_weight = loras.pop(0)
173
+
174
+ _module._modules[name].lora_up.weight = nn.Parameter(
175
+ up_weight.type(weight.dtype)
176
+ )
177
+ _module._modules[name].lora_down.weight = nn.Parameter(
178
+ down_weight.type(weight.dtype)
179
+ )
180
+
181
+ _module._modules[name].to(weight.device)
182
+
183
+
184
+ def monkeypatch_replace_lora(
185
+ model, loras, target_replace_module=["CrossAttention", "Attention"], r: int = 4
186
+ ):
187
+ for _module in model.modules():
188
+ if _module.__class__.__name__ in target_replace_module:
189
+ for name, _child_module in _module.named_modules():
190
+ if _child_module.__class__.__name__ == "LoraInjectedLinear":
191
+
192
+ weight = _child_module.linear.weight
193
+ bias = _child_module.linear.bias
194
+ _tmp = LoraInjectedLinear(
195
+ _child_module.linear.in_features,
196
+ _child_module.linear.out_features,
197
+ _child_module.linear.bias is not None,
198
+ r=r,
199
+ )
200
+ _tmp.linear.weight = weight
201
+
202
+ if bias is not None:
203
+ _tmp.linear.bias = bias
204
+
205
+ # switch the module
206
+ _module._modules[name] = _tmp
207
+
208
+ up_weight = loras.pop(0)
209
+ down_weight = loras.pop(0)
210
+
211
+ _module._modules[name].lora_up.weight = nn.Parameter(
212
+ up_weight.type(weight.dtype)
213
+ )
214
+ _module._modules[name].lora_down.weight = nn.Parameter(
215
+ down_weight.type(weight.dtype)
216
+ )
217
+
218
+ _module._modules[name].to(weight.device)
219
+
220
+
221
+ def monkeypatch_add_lora(
222
+ model,
223
+ loras,
224
+ target_replace_module=["CrossAttention", "Attention"],
225
+ alpha: float = 1.0,
226
+ beta: float = 1.0,
227
+ ):
228
+ for _module in model.modules():
229
+ if _module.__class__.__name__ in target_replace_module:
230
+ for name, _child_module in _module.named_modules():
231
+ if _child_module.__class__.__name__ == "LoraInjectedLinear":
232
+
233
+ weight = _child_module.linear.weight
234
+
235
+ up_weight = loras.pop(0)
236
+ down_weight = loras.pop(0)
237
+
238
+ _module._modules[name].lora_up.weight = nn.Parameter(
239
+ up_weight.type(weight.dtype).to(weight.device) * alpha
240
+ + _module._modules[name].lora_up.weight.to(weight.device) * beta
241
+ )
242
+ _module._modules[name].lora_down.weight = nn.Parameter(
243
+ down_weight.type(weight.dtype).to(weight.device) * alpha
244
+ + _module._modules[name].lora_down.weight.to(weight.device)
245
+ * beta
246
+ )
247
+
248
+ _module._modules[name].to(weight.device)
249
+
250
+
251
+ def tune_lora_scale(model, alpha: float = 1.0):
252
+ for _module in model.modules():
253
+ if _module.__class__.__name__ == "LoraInjectedLinear":
254
+ _module.scale = alpha
255
+
256
+
257
+ def _text_lora_path(path: str) -> str:
258
+ assert path.endswith(".pt"), "Only .pt files are supported"
259
+ return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])
260
+
261
+
262
+ def _ti_lora_path(path: str) -> str:
263
+ assert path.endswith(".pt"), "Only .pt files are supported"
264
+ return ".".join(path.split(".")[:-1] + ["ti", "pt"])
265
+
266
+
267
+ def load_learned_embed_in_clip(
268
+ learned_embeds_path, text_encoder, tokenizer, token=None, idempotent=False
269
+ ):
270
+ loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
271
+
272
+ # separate token and the embeds
273
+ trained_token = list(loaded_learned_embeds.keys())[0]
274
+ embeds = loaded_learned_embeds[trained_token]
275
+
276
+ # cast to dtype of text_encoder
277
+ dtype = text_encoder.get_input_embeddings().weight.dtype
278
+
279
+ # add the token in tokenizer
280
+ token = token if token is not None else trained_token
281
+ num_added_tokens = tokenizer.add_tokens(token)
282
+ i = 1
283
+ if num_added_tokens == 0 and idempotent:
284
+ return token
285
+
286
+ while num_added_tokens == 0:
287
+ print(f"The tokenizer already contains the token {token}.")
288
+ token = f"{token[:-1]}-{i}>"
289
+ print(f"Attempting to add the token {token}.")
290
+ num_added_tokens = tokenizer.add_tokens(token)
291
+ i += 1
292
+
293
+ # resize the token embeddings
294
+ text_encoder.resize_token_embeddings(len(tokenizer))
295
+
296
+ # get the id for the token and assign the embeds
297
+ token_id = tokenizer.convert_tokens_to_ids(token)
298
+ text_encoder.get_input_embeddings().weight.data[token_id] = embeds
299
+ return token
300
+
301
+
302
+ def patch_pipe(
303
+ pipe,
304
+ unet_path,
305
+ token,
306
+ alpha: float = 1.0,
307
+ r: int = 4,
308
+ patch_text=False,
309
+ patch_ti=False,
310
+ idempotent_token=True,
311
+ ):
312
+
313
+ ti_path = _ti_lora_path(unet_path)
314
+ text_path = _text_lora_path(unet_path)
315
+
316
+ unet_has_lora = False
317
+ text_encoder_has_lora = False
318
+
319
+ for _module in pipe.unet.modules():
320
+ if _module.__class__.__name__ == "LoraInjectedLinear":
321
+ unet_has_lora = True
322
+
323
+ for _module in pipe.text_encoder.modules():
324
+ if _module.__class__.__name__ == "LoraInjectedLinear":
325
+ text_encoder_has_lora = True
326
+
327
+ if not unet_has_lora:
328
+ monkeypatch_lora(pipe.unet, torch.load(unet_path), r=r)
329
+ else:
330
+ monkeypatch_replace_lora(pipe.unet, torch.load(unet_path), r=r)
331
+
332
+ if patch_text:
333
+ if not text_encoder_has_lora:
334
+ monkeypatch_lora(
335
+ pipe.text_encoder,
336
+ torch.load(text_path),
337
+ target_replace_module=["CLIPAttention"],
338
+ r=r,
339
+ )
340
+ else:
341
+
342
+ monkeypatch_replace_lora(
343
+ pipe.text_encoder,
344
+ torch.load(text_path),
345
+ target_replace_module=["CLIPAttention"],
346
+ r=r,
347
+ )
348
+ if patch_ti:
349
+ token = load_learned_embed_in_clip(
350
+ ti_path,
351
+ pipe.text_encoder,
352
+ pipe.tokenizer,
353
+ token,
354
+ idempotent=idempotent_token,
355
+ )
lora_diffusion/to_ckpt_v2.py ADDED
@@ -0,0 +1,232 @@
1
+ # from https://gist.github.com/jachiam/8a5c0b607e38fcc585168b90c686eb05
2
+ # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
3
+ # *Only* converts the UNet, VAE, and Text Encoder.
4
+ # Does not convert optimizer state or any other thing.
5
+ # Written by jachiam
6
+ import argparse
7
+ import os.path as osp
8
+
9
+ import torch
10
+
11
+
12
+ # =================#
13
+ # UNet Conversion #
14
+ # =================#
15
+
16
+ unet_conversion_map = [
17
+ # (stable-diffusion, HF Diffusers)
18
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
19
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
20
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
21
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
22
+ ("input_blocks.0.0.weight", "conv_in.weight"),
23
+ ("input_blocks.0.0.bias", "conv_in.bias"),
24
+ ("out.0.weight", "conv_norm_out.weight"),
25
+ ("out.0.bias", "conv_norm_out.bias"),
26
+ ("out.2.weight", "conv_out.weight"),
27
+ ("out.2.bias", "conv_out.bias"),
28
+ ]
29
+
30
+ unet_conversion_map_resnet = [
31
+ # (stable-diffusion, HF Diffusers)
32
+ ("in_layers.0", "norm1"),
33
+ ("in_layers.2", "conv1"),
34
+ ("out_layers.0", "norm2"),
35
+ ("out_layers.3", "conv2"),
36
+ ("emb_layers.1", "time_emb_proj"),
37
+ ("skip_connection", "conv_shortcut"),
38
+ ]
39
+
40
+ unet_conversion_map_layer = []
41
+ # hardcoded number of downblocks and resnets/attentions...
42
+ # would need smarter logic for other networks.
43
+ for i in range(4):
44
+ # loop over downblocks/upblocks
45
+
46
+ for j in range(2):
47
+ # loop over resnets/attentions for downblocks
48
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
49
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
50
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
51
+
52
+ if i < 3:
53
+ # no attention layers in down_blocks.3
54
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
55
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
56
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
57
+
58
+ for j in range(3):
59
+ # loop over resnets/attentions for upblocks
60
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
61
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
62
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
63
+
64
+ if i > 0:
65
+ # no attention layers in up_blocks.0
66
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
67
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
68
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
69
+
70
+ if i < 3:
71
+ # no downsample in down_blocks.3
72
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
73
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
74
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
75
+
76
+ # no upsample in up_blocks.3
77
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
78
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
79
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
80
+
81
+ hf_mid_atn_prefix = "mid_block.attentions.0."
82
+ sd_mid_atn_prefix = "middle_block.1."
83
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
84
+
85
+ for j in range(2):
86
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
87
+ sd_mid_res_prefix = f"middle_block.{2*j}."
88
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
89
+
90
+
91
+ def convert_unet_state_dict(unet_state_dict):
92
+ # buyer beware: this is a *brittle* function,
93
+ # and correct output requires that all of these pieces interact in
94
+ # the exact order in which I have arranged them.
95
+ mapping = {k: k for k in unet_state_dict.keys()}
96
+ for sd_name, hf_name in unet_conversion_map:
97
+ mapping[hf_name] = sd_name
98
+ for k, v in mapping.items():
99
+ if "resnets" in k:
100
+ for sd_part, hf_part in unet_conversion_map_resnet:
101
+ v = v.replace(hf_part, sd_part)
102
+ mapping[k] = v
103
+ for k, v in mapping.items():
104
+ for sd_part, hf_part in unet_conversion_map_layer:
105
+ v = v.replace(hf_part, sd_part)
106
+ mapping[k] = v
107
+ new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
108
+ return new_state_dict
109
+
110
+
111
+ # ================#
112
+ # VAE Conversion #
113
+ # ================#
114
+
115
+ vae_conversion_map = [
116
+ # (stable-diffusion, HF Diffusers)
117
+ ("nin_shortcut", "conv_shortcut"),
118
+ ("norm_out", "conv_norm_out"),
119
+ ("mid.attn_1.", "mid_block.attentions.0."),
120
+ ]
121
+
122
+ for i in range(4):
123
+ # down_blocks have two resnets
124
+ for j in range(2):
125
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
126
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
127
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
128
+
129
+ if i < 3:
130
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
131
+ sd_downsample_prefix = f"down.{i}.downsample."
132
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
133
+
134
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
135
+ sd_upsample_prefix = f"up.{3-i}.upsample."
136
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
137
+
138
+ # up_blocks have three resnets
139
+ # also, up blocks in hf are numbered in reverse from sd
140
+ for j in range(3):
141
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
142
+ sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
143
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
144
+
145
+ # this part accounts for mid blocks in both the encoder and the decoder
146
+ for i in range(2):
147
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
148
+ sd_mid_res_prefix = f"mid.block_{i+1}."
149
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
150
+
151
+
152
+ vae_conversion_map_attn = [
153
+ # (stable-diffusion, HF Diffusers)
154
+ ("norm.", "group_norm."),
155
+ ("q.", "query."),
156
+ ("k.", "key."),
157
+ ("v.", "value."),
158
+ ("proj_out.", "proj_attn."),
159
+ ]
160
+
161
+
162
+ def reshape_weight_for_sd(w):
163
+ # convert HF linear weights to SD conv2d weights
164
+ return w.reshape(*w.shape, 1, 1)
165
+
166
+
167
+ def convert_vae_state_dict(vae_state_dict):
168
+ mapping = {k: k for k in vae_state_dict.keys()}
169
+ for k, v in mapping.items():
170
+ for sd_part, hf_part in vae_conversion_map:
171
+ v = v.replace(hf_part, sd_part)
172
+ mapping[k] = v
173
+ for k, v in mapping.items():
174
+ if "attentions" in k:
175
+ for sd_part, hf_part in vae_conversion_map_attn:
176
+ v = v.replace(hf_part, sd_part)
177
+ mapping[k] = v
178
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
179
+ weights_to_convert = ["q", "k", "v", "proj_out"]
180
+ for k, v in new_state_dict.items():
181
+ for weight_name in weights_to_convert:
182
+ if f"mid.attn_1.{weight_name}.weight" in k:
183
+ print(f"Reshaping {k} for SD format")
184
+ new_state_dict[k] = reshape_weight_for_sd(v)
185
+ return new_state_dict
186
+
187
+
188
+ # =========================#
189
+ # Text Encoder Conversion #
190
+ # =========================#
191
+ # pretty much a no-op
192
+
193
+
194
+ def convert_text_enc_state_dict(text_enc_dict):
195
+ return text_enc_dict
196
+
197
+
198
+ def convert_to_ckpt(model_path, checkpoint_path, as_half):
199
+
200
+ assert model_path is not None, "Must provide a model path!"
201
+
202
+ assert checkpoint_path is not None, "Must provide a checkpoint path!"
203
+
204
+ unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
205
+ vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
206
+ text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
207
+
208
+ # Convert the UNet model
209
+ unet_state_dict = torch.load(unet_path, map_location="cpu")
210
+ unet_state_dict = convert_unet_state_dict(unet_state_dict)
211
+ unet_state_dict = {
212
+ "model.diffusion_model." + k: v for k, v in unet_state_dict.items()
213
+ }
214
+
215
+ # Convert the VAE model
216
+ vae_state_dict = torch.load(vae_path, map_location="cpu")
217
+ vae_state_dict = convert_vae_state_dict(vae_state_dict)
218
+ vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
219
+
220
+ # Convert the text encoder model
221
+ text_enc_dict = torch.load(text_enc_path, map_location="cpu")
222
+ text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
223
+ text_enc_dict = {
224
+ "cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()
225
+ }
226
+
227
+ # Put together new checkpoint
228
+ state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
229
+ if as_half:
230
+ state_dict = {k: v.half() for k, v in state_dict.items()}
231
+ state_dict = {"state_dict": state_dict}
232
+ torch.save(state_dict, checkpoint_path)
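+ 
+ # Example usage (hypothetical paths): convert a saved Diffusers pipeline folder
+ # into a single Stable Diffusion .ckpt file at half precision:
+ #
+ #   convert_to_ckpt("./merged_model", "./merged_model.ckpt", as_half=True)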
lora_disney.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:72f687f810b86bb8cc64d2ece59886e2e96d29e3f57f97340ee147d168b8a5fe
3
+ size 3397249
lora_illust.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7f6acb0bc0cd5f96299be7839f89f58727e2666e58861e55866ea02125c97aba
3
+ size 3397249
lora_playgroundai_wt.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f4b62bd14a24b58dd46a07e348b41b9c97cccb62859ecd28e7b855cea4b845e4
3
+ size 3398857
lora_pop.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:18a1565852a08cfcff63e90670286c9427e3958f57de9b84e3f8b2c9a3a14b6c
3
+ size 3397249
lora_weight.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:33c55ddef387cec18e5990ccae0809c8c842c244892b9502fd7e65be5301266f
3
+ size 3398857
output_example/dummy.txt ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,7 @@
1
+ transformers
2
+ git+https://github.com/huggingface/accelerate
3
+ git+https://github.com/huggingface/diffusers
4
+ git+https://github.com/cloneofsimo/lora.git
5
+ scipy
6
+ ftfy
7
+ torchvision
run_lora_db.sh ADDED
@@ -0,0 +1,17 @@
1
+ #https://github.com/huggingface/diffusers/tree/main/examples/dreambooth
2
+ export MODEL_NAME="stabilityai/stable-diffusion-2-1-base"
3
+ export INSTANCE_DIR="./data_example"
4
+ export OUTPUT_DIR="./output_example"
5
+
6
+ accelerate launch train_lora_dreambooth.py \
7
+ --pretrained_model_name_or_path=$MODEL_NAME \
8
+ --instance_data_dir=$INSTANCE_DIR \
9
+ --output_dir=$OUTPUT_DIR \
10
+ --instance_prompt="style of sks" \
11
+ --resolution=512 \
12
+ --train_batch_size=1 \
13
+ --gradient_accumulation_steps=1 \
14
+ --learning_rate=1e-4 \
15
+ --lr_scheduler="constant" \
16
+ --lr_warmup_steps=0 \
17
+ --max_train_steps=30000
scripts/make_alpha_gifs.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
scripts/run_inference.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
setup.py ADDED
@@ -0,0 +1,25 @@
1
+ import os
2
+
3
+ import pkg_resources
4
+ from setuptools import find_packages, setup
5
+
6
+ setup(
7
+ name="lora_diffusion",
8
+ py_modules=["lora_diffusion"],
9
+ version="0.0.1",
10
+ description="Low Rank Adaptation for Diffusion Models. Works with Stable Diffusion out-of-the-box.",
11
+ author="Simo Ryu",
12
+ packages=find_packages(),
13
+ entry_points={
14
+ "console_scripts": [
15
+ "lora_add = lora_diffusion.cli_lora_add:main",
16
+ ],
17
+ },
18
+ install_requires=[
19
+ str(r)
20
+ for r in pkg_resources.parse_requirements(
21
+ open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
22
+ )
23
+ ],
24
+ include_package_data=True,
25
+ )
simba1.jpg ADDED
simba2.jpg ADDED
simba3.jpg ADDED
simba4.jpg ADDED
train_lora_dreambooth.py ADDED
@@ -0,0 +1,956 @@
1
+ # Bootstrapped from:
2
+ # https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py
3
+
4
+ import argparse
5
+ import hashlib
6
+ import itertools
7
+ import math
8
+ import os
9
+ import inspect
10
+ from pathlib import Path
11
+ from typing import Optional
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+ import torch.utils.checkpoint
16
+
17
+
18
+ from accelerate import Accelerator
19
+ from accelerate.logging import get_logger
20
+ from accelerate.utils import set_seed
21
+ from diffusers import (
22
+ AutoencoderKL,
23
+ DDPMScheduler,
24
+ StableDiffusionPipeline,
25
+ UNet2DConditionModel,
26
+ )
27
+ from diffusers.optimization import get_scheduler
28
+ from huggingface_hub import HfFolder, Repository, whoami
29
+
30
+ from tqdm.auto import tqdm
31
+ from transformers import CLIPTextModel, CLIPTokenizer
32
+
33
+ from lora_diffusion import (
34
+ inject_trainable_lora,
35
+ save_lora_weight,
36
+ extract_lora_ups_down,
37
+ )
38
+
39
+ from torch.utils.data import Dataset
40
+ from PIL import Image
41
+ from torchvision import transforms
42
+
43
+ from pathlib import Path
44
+
45
+ import random
46
+ import re
47
+
48
+
49
+ class DreamBoothDataset(Dataset):
50
+ """
51
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
52
+ It pre-processes the images and tokenizes the prompts.
53
+ """
54
+
55
+ def __init__(
56
+ self,
57
+ instance_data_root,
58
+ instance_prompt,
59
+ tokenizer,
60
+ class_data_root=None,
61
+ class_prompt=None,
62
+ size=512,
63
+ center_crop=False,
64
+ color_jitter=False,
65
+ ):
66
+ self.size = size
67
+ self.center_crop = center_crop
68
+ self.tokenizer = tokenizer
69
+
70
+ self.instance_data_root = Path(instance_data_root)
71
+ if not self.instance_data_root.exists():
72
+ raise ValueError("Instance images root doesn't exist.")
73
+
74
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
75
+ self.num_instance_images = len(self.instance_images_path)
76
+ self.instance_prompt = instance_prompt
77
+ self._length = self.num_instance_images
78
+
79
+ if class_data_root is not None:
80
+ self.class_data_root = Path(class_data_root)
81
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
82
+ self.class_images_path = list(self.class_data_root.iterdir())
83
+ self.num_class_images = len(self.class_images_path)
84
+ self._length = max(self.num_class_images, self.num_instance_images)
85
+ self.class_prompt = class_prompt
86
+ else:
87
+ self.class_data_root = None
88
+
89
+ self.image_transforms = transforms.Compose(
90
+ [
91
+ transforms.Resize(
92
+ size, interpolation=transforms.InterpolationMode.BILINEAR
93
+ ),
94
+ transforms.CenterCrop(size)
95
+ if center_crop
96
+ else transforms.RandomCrop(size),
97
+ transforms.ColorJitter(0.2, 0.1)
98
+ if color_jitter
99
+ else transforms.Lambda(lambda x: x),
100
+ transforms.ToTensor(),
101
+ transforms.Normalize([0.5], [0.5]),
102
+ ]
103
+ )
104
+
105
+ def __len__(self):
106
+ return self._length
107
+
108
+ def __getitem__(self, index):
109
+ example = {}
110
+ instance_image = Image.open(
111
+ self.instance_images_path[index % self.num_instance_images]
112
+ )
113
+ if not instance_image.mode == "RGB":
114
+ instance_image = instance_image.convert("RGB")
115
+ example["instance_images"] = self.image_transforms(instance_image)
116
+ example["instance_prompt_ids"] = self.tokenizer(
117
+ self.instance_prompt,
118
+ padding="do_not_pad",
119
+ truncation=True,
120
+ max_length=self.tokenizer.model_max_length,
121
+ ).input_ids
122
+
123
+ if self.class_data_root:
124
+ class_image = Image.open(
125
+ self.class_images_path[index % self.num_class_images]
126
+ )
127
+ if not class_image.mode == "RGB":
128
+ class_image = class_image.convert("RGB")
129
+ example["class_images"] = self.image_transforms(class_image)
130
+ example["class_prompt_ids"] = self.tokenizer(
131
+ self.class_prompt,
132
+ padding="do_not_pad",
133
+ truncation=True,
134
+ max_length=self.tokenizer.model_max_length,
135
+ ).input_ids
136
+
137
+ return example
138
+
139
+
140
+ class PromptDataset(Dataset):
141
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
142
+
143
+ def __init__(self, prompt, num_samples):
144
+ self.prompt = prompt
145
+ self.num_samples = num_samples
146
+
147
+ def __len__(self):
148
+ return self.num_samples
149
+
150
+ def __getitem__(self, index):
151
+ example = {}
152
+ example["prompt"] = self.prompt
153
+ example["index"] = index
154
+ return example
155
+
156
+
157
+ logger = get_logger(__name__)
158
+
159
+
160
+ def parse_args(input_args=None):
161
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
162
+ parser.add_argument(
163
+ "--pretrained_model_name_or_path",
164
+ type=str,
165
+ default=None,
166
+ required=True,
167
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
168
+ )
169
+ parser.add_argument(
170
+ "--pretrained_vae_name_or_path",
171
+ type=str,
172
+ default=None,
173
+ help="Path to pretrained vae or vae identifier from huggingface.co/models.",
174
+ )
175
+ parser.add_argument(
176
+ "--revision",
177
+ type=str,
178
+ default=None,
179
+ required=False,
180
+ help="Revision of pretrained model identifier from huggingface.co/models.",
181
+ )
182
+ parser.add_argument(
183
+ "--tokenizer_name",
184
+ type=str,
185
+ default=None,
186
+ help="Pretrained tokenizer name or path if not the same as model_name",
187
+ )
188
+ parser.add_argument(
189
+ "--instance_data_dir",
190
+ type=str,
191
+ default=None,
192
+ required=True,
193
+ help="A folder containing the training data of instance images.",
194
+ )
195
+ parser.add_argument(
196
+ "--class_data_dir",
197
+ type=str,
198
+ default=None,
199
+ required=False,
200
+ help="A folder containing the training data of class images.",
201
+ )
202
+ parser.add_argument(
203
+ "--instance_prompt",
204
+ type=str,
205
+ default=None,
206
+ required=True,
207
+ help="The prompt with identifier specifying the instance",
208
+ )
209
+ parser.add_argument(
210
+ "--class_prompt",
211
+ type=str,
212
+ default=None,
213
+ help="The prompt to specify images in the same class as provided instance images.",
214
+ )
215
+ parser.add_argument(
216
+ "--with_prior_preservation",
217
+ default=False,
218
+ action="store_true",
219
+ help="Flag to add prior preservation loss.",
220
+ )
221
+ parser.add_argument(
222
+ "--prior_loss_weight",
223
+ type=float,
224
+ default=1.0,
225
+ help="The weight of prior preservation loss.",
226
+ )
227
+ parser.add_argument(
228
+ "--num_class_images",
229
+ type=int,
230
+ default=100,
231
+ help=(
232
+ "Minimal class images for prior preservation loss. If there are not enough images, additional images will be"
233
+ " sampled with class_prompt."
234
+ ),
235
+ )
236
+ parser.add_argument(
237
+ "--output_dir",
238
+ type=str,
239
+ default="text-inversion-model",
240
+ help="The output directory where the model predictions and checkpoints will be written.",
241
+ )
242
+ parser.add_argument(
243
+ "--seed", type=int, default=None, help="A seed for reproducible training."
244
+ )
245
+ parser.add_argument(
246
+ "--resolution",
247
+ type=int,
248
+ default=512,
249
+ help=(
250
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
251
+ " resolution"
252
+ ),
253
+ )
254
+ parser.add_argument(
255
+ "--center_crop",
256
+ action="store_true",
257
+ help="Whether to center crop images before resizing to resolution",
258
+ )
259
+ parser.add_argument(
260
+ "--color_jitter",
261
+ action="store_true",
262
+ help="Whether to apply color jitter to images",
263
+ )
264
+ parser.add_argument(
265
+ "--train_text_encoder",
266
+ action="store_true",
267
+ help="Whether to train the text encoder",
268
+ )
269
+ parser.add_argument(
270
+ "--train_batch_size",
271
+ type=int,
272
+ default=4,
273
+ help="Batch size (per device) for the training dataloader.",
274
+ )
275
+ parser.add_argument(
276
+ "--sample_batch_size",
277
+ type=int,
278
+ default=4,
279
+ help="Batch size (per device) for sampling images.",
280
+ )
281
+ parser.add_argument("--num_train_epochs", type=int, default=1)
282
+ parser.add_argument(
283
+ "--max_train_steps",
284
+ type=int,
285
+ default=None,
286
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
287
+ )
288
+ parser.add_argument(
289
+ "--save_steps",
290
+ type=int,
291
+ default=500,
292
+ help="Save checkpoint every X updates steps.",
293
+ )
294
+ parser.add_argument(
295
+ "--gradient_accumulation_steps",
296
+ type=int,
297
+ default=1,
298
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
299
+ )
300
+ parser.add_argument(
301
+ "--gradient_checkpointing",
302
+ action="store_true",
303
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
304
+ )
305
+ parser.add_argument(
306
+ "--lora_rank",
307
+ type=int,
308
+ default=4,
309
+ help="Rank of LoRA approximation.",
310
+ )
311
+ parser.add_argument(
312
+ "--learning_rate",
313
+ type=float,
314
+ default=None,
315
+ help="Initial learning rate (after the potential warmup period) to use.",
316
+ )
317
+ parser.add_argument(
318
+ "--learning_rate_text",
319
+ type=float,
320
+ default=5e-6,
321
+ help="Initial learning rate for text encoder (after the potential warmup period) to use.",
322
+ )
323
+ parser.add_argument(
324
+ "--scale_lr",
325
+ action="store_true",
326
+ default=False,
327
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
328
+ )
329
+ parser.add_argument(
330
+ "--lr_scheduler",
331
+ type=str,
332
+ default="constant",
333
+ help=(
334
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
335
+ ' "constant", "constant_with_warmup"]'
336
+ ),
337
+ )
338
+ parser.add_argument(
339
+ "--lr_warmup_steps",
340
+ type=int,
341
+ default=500,
342
+ help="Number of steps for the warmup in the lr scheduler.",
343
+ )
344
+ parser.add_argument(
345
+ "--use_8bit_adam",
346
+ action="store_true",
347
+ help="Whether or not to use 8-bit Adam from bitsandbytes.",
348
+ )
349
+ parser.add_argument(
350
+ "--adam_beta1",
351
+ type=float,
352
+ default=0.9,
353
+ help="The beta1 parameter for the Adam optimizer.",
354
+ )
355
+ parser.add_argument(
356
+ "--adam_beta2",
357
+ type=float,
358
+ default=0.999,
359
+ help="The beta2 parameter for the Adam optimizer.",
360
+ )
361
+ parser.add_argument(
362
+ "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use."
363
+ )
364
+ parser.add_argument(
365
+ "--adam_epsilon",
366
+ type=float,
367
+ default=1e-08,
368
+ help="Epsilon value for the Adam optimizer",
369
+ )
370
+ parser.add_argument(
371
+ "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
372
+ )
373
+ parser.add_argument(
374
+ "--push_to_hub",
375
+ action="store_true",
376
+ help="Whether or not to push the model to the Hub.",
377
+ )
378
+ parser.add_argument(
379
+ "--hub_token",
380
+ type=str,
381
+ default=None,
382
+ help="The token to use to push to the Model Hub.",
383
+ )
384
+ parser.add_argument(
385
+ "--logging_dir",
386
+ type=str,
387
+ default="logs",
388
+ help=(
389
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
390
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
391
+ ),
392
+ )
393
+ parser.add_argument(
394
+ "--mixed_precision",
395
+ type=str,
396
+ default=None,
397
+ choices=["no", "fp16", "bf16"],
398
+ help=(
399
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
400
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
401
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
402
+ ),
403
+ )
404
+ parser.add_argument(
405
+ "--local_rank",
406
+ type=int,
407
+ default=-1,
408
+ help="For distributed training: local_rank",
409
+ )
410
+ parser.add_argument(
411
+ "--resume_unet",
412
+ type=str,
413
+ default=None,
414
+ help=("File path for unet lora to resume training."),
415
+ )
416
+ parser.add_argument(
417
+ "--resume_text_encoder",
418
+ type=str,
419
+ default=None,
420
+ help=("File path for text encoder lora to resume training."),
421
+ )
422
+
423
+ if input_args is not None:
424
+ args = parser.parse_args(input_args)
425
+ else:
426
+ args = parser.parse_args()
427
+
428
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
429
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
430
+ args.local_rank = env_local_rank
431
+
432
+ if args.with_prior_preservation:
433
+ if args.class_data_dir is None:
434
+ raise ValueError("You must specify a data directory for class images.")
435
+ if args.class_prompt is None:
436
+ raise ValueError("You must specify prompt for class images.")
437
+ else:
438
+ if args.class_data_dir is not None:
439
+ logger.warning(
440
+ "You need not use --class_data_dir without --with_prior_preservation."
441
+ )
442
+ if args.class_prompt is not None:
443
+ logger.warning(
444
+ "You need not use --class_prompt without --with_prior_preservation."
445
+ )
446
+
447
+ return args
448
+
449
+
450
+ def main(args):
451
+ logging_dir = Path(args.output_dir, args.logging_dir)
452
+
453
+ accelerator = Accelerator(
454
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
455
+ mixed_precision=args.mixed_precision,
456
+ log_with="tensorboard",
457
+ logging_dir=logging_dir,
458
+ )
459
+
460
+ # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
461
+ # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
462
+ # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
463
+ if (
464
+ args.train_text_encoder
465
+ and args.gradient_accumulation_steps > 1
466
+ and accelerator.num_processes > 1
467
+ ):
468
+ raise ValueError(
469
+ "Gradient accumulation is not supported when training the text encoder in distributed training. "
470
+ "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
471
+ )
472
+
473
+ if args.seed is not None:
474
+ set_seed(args.seed)
475
+
476
+ if args.with_prior_preservation:
477
+ class_images_dir = Path(args.class_data_dir)
478
+ if not class_images_dir.exists():
479
+ class_images_dir.mkdir(parents=True)
480
+ cur_class_images = len(list(class_images_dir.iterdir()))
481
+
482
+ if cur_class_images < args.num_class_images:
483
+ torch_dtype = (
484
+ torch.float16 if accelerator.device.type == "cuda" else torch.float32
485
+ )
486
+ pipeline = StableDiffusionPipeline.from_pretrained(
487
+ args.pretrained_model_name_or_path,
488
+ torch_dtype=torch_dtype,
489
+ safety_checker=None,
490
+ revision=args.revision,
491
+ )
492
+ pipeline.set_progress_bar_config(disable=True)
493
+
494
+ num_new_images = args.num_class_images - cur_class_images
495
+ logger.info(f"Number of class images to sample: {num_new_images}.")
496
+
497
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
498
+ sample_dataloader = torch.utils.data.DataLoader(
499
+ sample_dataset, batch_size=args.sample_batch_size
500
+ )
501
+
502
+ sample_dataloader = accelerator.prepare(sample_dataloader)
503
+ pipeline.to(accelerator.device)
504
+
505
+ for example in tqdm(
506
+ sample_dataloader,
507
+ desc="Generating class images",
508
+ disable=not accelerator.is_local_main_process,
509
+ ):
510
+ images = pipeline(example["prompt"]).images
511
+
512
+ for i, image in enumerate(images):
513
+ hash_image = hashlib.sha1(image.tobytes()).hexdigest()
514
+ image_filename = (
515
+ class_images_dir
516
+ / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
517
+ )
518
+ image.save(image_filename)
519
+
520
+ del pipeline
521
+ if torch.cuda.is_available():
522
+ torch.cuda.empty_cache()
523
+
524
+ # Handle the repository creation
525
+ if accelerator.is_main_process:
526
+
527
+ if args.output_dir is not None:
528
+ os.makedirs(args.output_dir, exist_ok=True)
529
+
530
+ # Load the tokenizer
531
+ if args.tokenizer_name:
532
+ tokenizer = CLIPTokenizer.from_pretrained(
533
+ args.tokenizer_name,
534
+ revision=args.revision,
535
+ )
536
+ elif args.pretrained_model_name_or_path:
537
+ tokenizer = CLIPTokenizer.from_pretrained(
538
+ args.pretrained_model_name_or_path,
539
+ subfolder="tokenizer",
540
+ revision=args.revision,
541
+ )
542
+
543
+ # Load models and create wrapper for stable diffusion
544
+ text_encoder = CLIPTextModel.from_pretrained(
545
+ args.pretrained_model_name_or_path,
546
+ subfolder="text_encoder",
547
+ revision=args.revision,
548
+ )
549
+ vae = AutoencoderKL.from_pretrained(
550
+ args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path,
551
+ subfolder=None if args.pretrained_vae_name_or_path else "vae",
552
+ revision=None if args.pretrained_vae_name_or_path else args.revision,
553
+ )
554
+ unet = UNet2DConditionModel.from_pretrained(
555
+ args.pretrained_model_name_or_path,
556
+ subfolder="unet",
557
+ revision=args.revision,
558
+ )
559
+ unet.requires_grad_(False)
560
+ unet_lora_params, _ = inject_trainable_lora(
561
+ unet, r=args.lora_rank, loras=args.resume_unet
562
+ )
563
+
564
+ for _up, _down in extract_lora_ups_down(unet):
565
+ print("Before training: Unet First Layer lora up", _up.weight.data)
566
+ print("Before training: Unet First Layer lora down", _down.weight.data)
567
+ break
568
+
569
+ vae.requires_grad_(False)
570
+ text_encoder.requires_grad_(False)
571
+
572
+ if args.train_text_encoder:
573
+ text_encoder_lora_params, _ = inject_trainable_lora(
574
+ text_encoder,
575
+ target_replace_module=["CLIPAttention"],
576
+ r=args.lora_rank,
577
+ )
578
+ for _up, _down in extract_lora_ups_down(
579
+ text_encoder, target_replace_module=["CLIPAttention"]
580
+ ):
581
+ print("Before training: text encoder First Layer lora up", _up.weight.data)
582
+ print(
583
+ "Before training: text encoder First Layer lora down", _down.weight.data
584
+ )
585
+ break
586
+
587
+ if args.gradient_checkpointing:
588
+ unet.enable_gradient_checkpointing()
589
+ if args.train_text_encoder:
590
+ text_encoder.gradient_checkpointing_enable()
591
+
592
+ if args.scale_lr:
593
+ args.learning_rate = (
594
+ args.learning_rate
595
+ * args.gradient_accumulation_steps
596
+ * args.train_batch_size
597
+ * accelerator.num_processes
598
+ )
599
+
600
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
601
+ if args.use_8bit_adam:
602
+ try:
603
+ import bitsandbytes as bnb
604
+ except ImportError:
605
+ raise ImportError(
606
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
607
+ )
608
+
609
+ optimizer_class = bnb.optim.AdamW8bit
610
+ else:
611
+ optimizer_class = torch.optim.AdamW
612
+
613
+ text_lr = (
614
+ args.learning_rate
615
+ if args.learning_rate_text is None
616
+ else args.learning_rate_text
617
+ )
618
+
619
+ params_to_optimize = (
620
+ [
621
+ {"params": itertools.chain(*unet_lora_params), "lr": args.learning_rate},
622
+ {
623
+ "params": itertools.chain(*text_encoder_lora_params),
624
+ "lr": text_lr,
625
+ },
626
+ ]
627
+ if args.train_text_encoder
628
+ else itertools.chain(*unet_lora_params)
629
+ )
630
+ optimizer = optimizer_class(
631
+ params_to_optimize,
632
+ lr=args.learning_rate,
633
+ betas=(args.adam_beta1, args.adam_beta2),
634
+ weight_decay=args.adam_weight_decay,
635
+ eps=args.adam_epsilon,
636
+ )
637
+
638
+ noise_scheduler = DDPMScheduler.from_config(
639
+ args.pretrained_model_name_or_path, subfolder="scheduler"
640
+ )
641
+
642
+ train_dataset = DreamBoothDataset(
643
+ instance_data_root=args.instance_data_dir,
644
+ instance_prompt=args.instance_prompt,
645
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
646
+ class_prompt=args.class_prompt,
647
+ tokenizer=tokenizer,
648
+ size=args.resolution,
649
+ center_crop=args.center_crop,
650
+ color_jitter=args.color_jitter,
651
+ )
652
+
653
+ def collate_fn(examples):
654
+ input_ids = [example["instance_prompt_ids"] for example in examples]
655
+ pixel_values = [example["instance_images"] for example in examples]
656
+
657
+ # Concat class and instance examples for prior preservation.
658
+ # We do this to avoid doing two forward passes.
659
+ if args.with_prior_preservation:
660
+ input_ids += [example["class_prompt_ids"] for example in examples]
661
+ pixel_values += [example["class_images"] for example in examples]
662
+
663
+ pixel_values = torch.stack(pixel_values)
664
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
665
+
666
+ input_ids = tokenizer.pad(
667
+ {"input_ids": input_ids},
668
+ padding="max_length",
669
+ max_length=tokenizer.model_max_length,
670
+ return_tensors="pt",
671
+ ).input_ids
672
+
673
+ batch = {
674
+ "input_ids": input_ids,
675
+ "pixel_values": pixel_values,
676
+ }
677
+ return batch
678
+
679
+ train_dataloader = torch.utils.data.DataLoader(
680
+ train_dataset,
681
+ batch_size=args.train_batch_size,
682
+ shuffle=True,
683
+ collate_fn=collate_fn,
684
+ num_workers=1,
685
+ )
686
+
687
+ # Scheduler and math around the number of training steps.
688
+ overrode_max_train_steps = False
689
+ num_update_steps_per_epoch = math.ceil(
690
+ len(train_dataloader) / args.gradient_accumulation_steps
691
+ )
692
+ if args.max_train_steps is None:
693
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
694
+ overrode_max_train_steps = True
695
+
696
+ lr_scheduler = get_scheduler(
697
+ args.lr_scheduler,
698
+ optimizer=optimizer,
699
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
700
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
701
+ )
702
+
703
+ if args.train_text_encoder:
704
+ (
705
+ unet,
706
+ text_encoder,
707
+ optimizer,
708
+ train_dataloader,
709
+ lr_scheduler,
710
+ ) = accelerator.prepare(
711
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler
712
+ )
713
+ else:
714
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
715
+ unet, optimizer, train_dataloader, lr_scheduler
716
+ )
717
+
718
+ weight_dtype = torch.float32
719
+ if accelerator.mixed_precision == "fp16":
720
+ weight_dtype = torch.float16
721
+ elif accelerator.mixed_precision == "bf16":
722
+ weight_dtype = torch.bfloat16
723
+
724
+ # Move text_encoder and vae to the GPU.
725
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
726
+ # as these models are only used for inference, keeping weights in full precision is not required.
727
+ vae.to(accelerator.device, dtype=weight_dtype)
728
+ if not args.train_text_encoder:
729
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
730
+
731
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
732
+ num_update_steps_per_epoch = math.ceil(
733
+ len(train_dataloader) / args.gradient_accumulation_steps
734
+ )
735
+ if overrode_max_train_steps:
736
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
737
+ # Afterwards we recalculate our number of training epochs
738
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
739
+
740
+ # We need to initialize the trackers we use, and also store our configuration.
741
+ # The trackers initialize automatically on the main process.
742
+ if accelerator.is_main_process:
743
+ accelerator.init_trackers("dreambooth", config=vars(args))
744
+
745
+ # Train!
746
+ total_batch_size = (
747
+ args.train_batch_size
748
+ * accelerator.num_processes
749
+ * args.gradient_accumulation_steps
750
+ )
751
+
752
+ print("***** Running training *****")
753
+ print(f" Num examples = {len(train_dataset)}")
754
+ print(f" Num batches each epoch = {len(train_dataloader)}")
755
+ print(f" Num Epochs = {args.num_train_epochs}")
756
+ print(f" Instantaneous batch size per device = {args.train_batch_size}")
757
+ print(
758
+ f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
759
+ )
760
+ print(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
761
+ print(f" Total optimization steps = {args.max_train_steps}")
762
+ # Only show the progress bar once on each machine.
763
+ progress_bar = tqdm(
764
+ range(args.max_train_steps), disable=not accelerator.is_local_main_process
765
+ )
766
+ progress_bar.set_description("Steps")
767
+ global_step = 0
768
+ last_save = 0
769
+
770
+ for epoch in range(args.num_train_epochs):
771
+ unet.train()
772
+ if args.train_text_encoder:
773
+ text_encoder.train()
774
+
775
+ for step, batch in enumerate(train_dataloader):
776
+ # Convert images to latent space
777
+ latents = vae.encode(
778
+ batch["pixel_values"].to(dtype=weight_dtype)
779
+ ).latent_dist.sample()
780
+ latents = latents * 0.18215
781
+
782
+ # Sample noise that we'll add to the latents
783
+ noise = torch.randn_like(latents)
784
+ bsz = latents.shape[0]
785
+ # Sample a random timestep for each image
786
+ timesteps = torch.randint(
787
+ 0,
788
+ noise_scheduler.config.num_train_timesteps,
789
+ (bsz,),
790
+ device=latents.device,
791
+ )
792
+ timesteps = timesteps.long()
793
+
794
+ # Add noise to the latents according to the noise magnitude at each timestep
795
+ # (this is the forward diffusion process)
796
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
797
+
798
+ # Get the text embedding for conditioning
799
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
800
+
801
+ # Predict the noise residual
802
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
803
+
804
+ # Get the target for loss depending on the prediction type
805
+ if noise_scheduler.config.prediction_type == "epsilon":
806
+ target = noise
807
+ elif noise_scheduler.config.prediction_type == "v_prediction":
808
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
809
+ else:
810
+ raise ValueError(
811
+ f"Unknown prediction type {noise_scheduler.config.prediction_type}"
812
+ )
813
+
814
+ if args.with_prior_preservation:
815
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
816
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
817
+ target, target_prior = torch.chunk(target, 2, dim=0)
818
+
819
+ # Compute instance loss
820
+ loss = (
821
+ F.mse_loss(model_pred.float(), target.float(), reduction="none")
822
+ .mean([1, 2, 3])
823
+ .mean()
824
+ )
825
+
826
+ # Compute prior loss
827
+ prior_loss = F.mse_loss(
828
+ model_pred_prior.float(), target_prior.float(), reduction="mean"
829
+ )
830
+
831
+ # Add the prior loss to the instance loss.
832
+ loss = loss + args.prior_loss_weight * prior_loss
833
+ else:
834
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
835
+
836
+ accelerator.backward(loss)
837
+ if accelerator.sync_gradients:
838
+ params_to_clip = (
839
+ itertools.chain(unet.parameters(), text_encoder.parameters())
840
+ if args.train_text_encoder
841
+ else unet.parameters()
842
+ )
843
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
844
+ optimizer.step()
845
+ lr_scheduler.step()
846
+ progress_bar.update(1)
847
+ optimizer.zero_grad()
848
+
849
+ global_step += 1
850
+
851
+ # Checks if the accelerator has performed an optimization step behind the scenes
852
+ if accelerator.sync_gradients:
853
+ if args.save_steps and global_step - last_save >= args.save_steps:
854
+ if accelerator.is_main_process:
855
+ # newer versions of accelerate allow the 'keep_fp32_wrapper' arg. without passing
856
+ # it, the models will be unwrapped, and when they are then used for further training,
857
+ # we will crash. pass this, but only to newer versions of accelerate. fixes
858
+ # https://github.com/huggingface/diffusers/issues/1566
859
+ accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set(
860
+ inspect.signature(
861
+ accelerator.unwrap_model
862
+ ).parameters.keys()
863
+ )
864
+ extra_args = (
865
+ {"keep_fp32_wrapper": True}
866
+ if accepts_keep_fp32_wrapper
867
+ else {}
868
+ )
869
+ pipeline = StableDiffusionPipeline.from_pretrained(
870
+ args.pretrained_model_name_or_path,
871
+ unet=accelerator.unwrap_model(unet, **extra_args),
872
+ text_encoder=accelerator.unwrap_model(
873
+ text_encoder, **extra_args
874
+ ),
875
+ revision=args.revision,
876
+ )
877
+
878
+ filename_unet = (
879
+ f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.pt"
880
+ )
881
+ filename_text_encoder = f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.text_encoder.pt"
882
+ print(f"save weights {filename_unet}, {filename_text_encoder}")
883
+ save_lora_weight(pipeline.unet, filename_unet)
884
+ if args.train_text_encoder:
885
+ save_lora_weight(
886
+ pipeline.text_encoder,
887
+ filename_text_encoder,
888
+ target_replace_module=["CLIPAttention"],
889
+ )
890
+
891
+ for _up, _down in extract_lora_ups_down(pipeline.unet):
892
+ print(
893
+ "First Unet Layer's Up Weight is now : ",
894
+ _up.weight.data,
895
+ )
896
+ print(
897
+ "First Unet Layer's Down Weight is now : ",
898
+ _down.weight.data,
899
+ )
900
+ break
901
+ if args.train_text_encoder:
902
+ for _up, _down in extract_lora_ups_down(
903
+ pipeline.text_encoder,
904
+ target_replace_module=["CLIPAttention"],
905
+ ):
906
+ print(
907
+ "First Text Encoder Layer's Up Weight is now : ",
908
+ _up.weight.data,
909
+ )
910
+ print(
911
+ "First Text Encoder Layer's Down Weight is now : ",
912
+ _down.weight.data,
913
+ )
914
+ break
915
+
916
+ last_save = global_step
917
+
918
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
919
+ progress_bar.set_postfix(**logs)
920
+ accelerator.log(logs, step=global_step)
921
+
922
+ if global_step >= args.max_train_steps:
923
+ break
924
+
925
+ accelerator.wait_for_everyone()
926
+
927
+ # Create the pipeline using the trained modules and save it.
928
+ if accelerator.is_main_process:
929
+ pipeline = StableDiffusionPipeline.from_pretrained(
930
+ args.pretrained_model_name_or_path,
931
+ unet=accelerator.unwrap_model(unet),
932
+ text_encoder=accelerator.unwrap_model(text_encoder),
933
+ revision=args.revision,
934
+ )
935
+
936
+ print("\n\nLora TRAINING DONE!\n\n")
937
+
938
+ save_lora_weight(pipeline.unet, args.output_dir + "/lora_weight.pt")
939
+ if args.train_text_encoder:
940
+ save_lora_weight(
941
+ pipeline.text_encoder,
942
+ args.output_dir + "/lora_weight.text_encoder.pt",
943
+ target_replace_module=["CLIPAttention"],
944
+ )
945
+
946
+ if args.push_to_hub:
947
+ repo.push_to_hub(
948
+ commit_message="End of training", blocking=False, auto_lfs_prune=True
949
+ )
950
+
951
+ accelerator.end_training()
952
+
953
+
954
+ if __name__ == "__main__":
955
+ args = parse_args()
956
+ main(args)
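
Note that `--learning_rate` defaults to `None` in this script, so a value has to be supplied explicitly or the optimizer construction in `main()` will fail. Because `parse_args` accepts an explicit `input_args` list, the script can also be driven from Python. The following is a minimal sketch of such an invocation; the base model identifier, data paths, prompt, and hyperparameter values are placeholders, not settings taken from this commit:

```python
# Hypothetical invocation of train_lora_dreambooth.py from Python.
# Every path, the model identifier, and all hyperparameter values are placeholders.
example_args = parse_args([
    "--pretrained_model_name_or_path", "runwayml/stable-diffusion-v1-5",  # assumed base model
    "--instance_data_dir", "./data_example",
    "--instance_prompt", "style of sks, an illustration",
    "--output_dir", "./output_example",
    "--lora_rank", "4",
    "--learning_rate", "1e-4",        # must be set: the argparse default is None
    "--learning_rate_text", "5e-5",
    "--train_text_encoder",
    "--max_train_steps", "1000",
])
main(example_args)
```
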
train_lora_dreambooth1.py ADDED
@@ -0,0 +1,964 @@
1
+ # Bootstrapped from:
2
+ # https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py
3
+
4
+ import argparse
5
+ import hashlib
6
+ import itertools
7
+ import math
8
+ import os
9
+ from pathlib import Path
10
+ from typing import Optional
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ import torch.utils.checkpoint
15
+
16
+
17
+ from accelerate import Accelerator
18
+ from accelerate.logging import get_logger
19
+ from accelerate.utils import set_seed
20
+ from diffusers import (
21
+ AutoencoderKL,
22
+ DDPMScheduler,
23
+ StableDiffusionPipeline,
24
+ UNet2DConditionModel,
25
+ )
26
+ from diffusers.optimization import get_scheduler
27
+ from huggingface_hub import HfFolder, Repository, whoami
28
+
29
+ from tqdm.auto import tqdm
30
+ from transformers import CLIPTextModel, CLIPTokenizer
31
+
32
+ from lora_diffusion import (
33
+ inject_trainable_lora,
34
+ save_lora_weight,
35
+ extract_lora_ups_down,
36
+ )
37
+
38
+ from torch.utils.data import Dataset
39
+ from PIL import Image
40
+ from torchvision import transforms
41
+
42
+ from pathlib import Path
43
+
44
+ import random
45
+ import re
46
+
47
+
48
+ class DreamBoothDataset(Dataset):
49
+ """
50
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
51
+ It pre-processes the images and the tokenizes prompts.
52
+ """
53
+
54
+ def __init__(
55
+ self,
56
+ instance_data_root,
57
+ instance_prompt,
58
+ tokenizer,
59
+ class_data_root=None,
60
+ class_prompt=None,
61
+ size=512,
62
+ center_crop=False,
63
+ ):
64
+ self.size = size
65
+ self.center_crop = center_crop
66
+ self.tokenizer = tokenizer
67
+
68
+ self.instance_data_root = Path(instance_data_root)
69
+ if not self.instance_data_root.exists():
70
+ raise ValueError("Instance images root doesn't exists.")
71
+
72
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
73
+ self.num_instance_images = len(self.instance_images_path)
74
+ self.instance_prompt = instance_prompt
75
+ self._length = self.num_instance_images
76
+
77
+ if class_data_root is not None:
78
+ self.class_data_root = Path(class_data_root)
79
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
80
+ self.class_images_path = list(self.class_data_root.iterdir())
81
+ self.num_class_images = len(self.class_images_path)
82
+ self._length = max(self.num_class_images, self.num_instance_images)
83
+ self.class_prompt = class_prompt
84
+ else:
85
+ self.class_data_root = None
86
+
87
+ self.image_transforms = transforms.Compose(
88
+ [
89
+ transforms.Resize(
90
+ size, interpolation=transforms.InterpolationMode.BILINEAR
91
+ ),
92
+ transforms.CenterCrop(size)
93
+ if center_crop
94
+ else transforms.RandomCrop(size),
95
+ transforms.ToTensor(),
96
+ transforms.Normalize([0.5], [0.5]),
97
+ ]
98
+ )
99
+
100
+ def __len__(self):
101
+ return self._length
102
+
103
+ def __getitem__(self, index):
104
+ example = {}
105
+ instance_image = Image.open(
106
+ self.instance_images_path[index % self.num_instance_images]
107
+ )
108
+ if not instance_image.mode == "RGB":
109
+ instance_image = instance_image.convert("RGB")
110
+ example["instance_images"] = self.image_transforms(instance_image)
111
+ example["instance_prompt_ids"] = self.tokenizer(
112
+ self.instance_prompt,
113
+ padding="do_not_pad",
114
+ truncation=True,
115
+ max_length=self.tokenizer.model_max_length,
116
+ ).input_ids
117
+
118
+ if self.class_data_root:
119
+ class_image = Image.open(
120
+ self.class_images_path[index % self.num_class_images]
121
+ )
122
+ if not class_image.mode == "RGB":
123
+ class_image = class_image.convert("RGB")
124
+ example["class_images"] = self.image_transforms(class_image)
125
+ example["class_prompt_ids"] = self.tokenizer(
126
+ self.class_prompt,
127
+ padding="do_not_pad",
128
+ truncation=True,
129
+ max_length=self.tokenizer.model_max_length,
130
+ ).input_ids
131
+
132
+ return example
133
+
134
+
135
+ class DreamBoothLabled(Dataset):
136
+ """
137
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
138
+ It pre-processes the images and tokenizes the prompts.
139
+ """
140
+
141
+ def __init__(
142
+ self,
143
+ instance_data_root,
144
+ instance_prompt,
145
+ tokenizer,
146
+ class_data_root=None,
147
+ class_prompt=None,
148
+ size=512,
149
+ center_crop=False,
150
+ ):
151
+ self.size = size
152
+ self.center_crop = center_crop
153
+ self.tokenizer = tokenizer
154
+
155
+ self.instance_data_root = Path(instance_data_root)
156
+ if not self.instance_data_root.exists():
157
+ raise ValueError("Instance images root doesn't exists.")
158
+
159
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
160
+ self.num_instance_images = len(self.instance_images_path)
161
+ self.instance_prompt = instance_prompt
162
+ self._length = self.num_instance_images
163
+
164
+ if class_data_root is not None:
165
+ self.class_data_root = Path(class_data_root)
166
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
167
+ self.class_images_path = list(self.class_data_root.iterdir())
168
+ self.num_class_images = len(self.class_images_path)
169
+ self._length = max(self.num_class_images, self.num_instance_images)
170
+ self.class_prompt = class_prompt
171
+ else:
172
+ self.class_data_root = None
173
+
174
+ self.image_transforms = transforms.Compose(
175
+ [
176
+ transforms.Resize(
177
+ size, interpolation=transforms.InterpolationMode.BILINEAR
178
+ ),
179
+ transforms.CenterCrop(size)
180
+ if center_crop
181
+ else transforms.RandomCrop(size),
182
+ transforms.ToTensor(),
183
+ transforms.Normalize([0.5], [0.5]),
184
+ ]
185
+ )
186
+
187
+ def __len__(self):
188
+ return self._length
189
+
190
+ def __getitem__(self, index):
191
+ example = {}
192
+ instance_image = Image.open(
193
+ self.instance_images_path[index % self.num_instance_images]
194
+ )
195
+
196
+ instance_prompt = (
197
+ str(self.instance_images_path[index % self.num_instance_images])
198
+ .split("/")[-1]
199
+ .split(".")[0]
200
+ .replace("-", " ")
201
+ )
202
+ # remove numbers in prompt
203
+ instance_prompt = re.sub(r"\d+", "", instance_prompt)
204
+ # print(instance_prompt)
205
+
206
+ _svg = random.choice(["svg", "flat color", "vector illustration", "sks"])
207
+ instance_prompt = f"{instance_prompt}, style of {_svg}"
208
+
209
+ if not instance_image.mode == "RGB":
210
+ instance_image = instance_image.convert("RGB")
211
+ example["instance_images"] = self.image_transforms(instance_image)
212
+ example["instance_prompt_ids"] = self.tokenizer(
213
+ instance_prompt,
214
+ padding="do_not_pad",
215
+ truncation=True,
216
+ max_length=self.tokenizer.model_max_length,
217
+ ).input_ids
218
+
219
+ if self.class_data_root:
220
+ class_image = Image.open(
221
+ self.class_images_path[index % self.num_class_images]
222
+ )
223
+ if not class_image.mode == "RGB":
224
+ class_image = class_image.convert("RGB")
225
+ example["class_images"] = self.image_transforms(class_image)
226
+ example["class_prompt_ids"] = self.tokenizer(
227
+ self.class_prompt,
228
+ padding="do_not_pad",
229
+ truncation=True,
230
+ max_length=self.tokenizer.model_max_length,
231
+ ).input_ids
232
+
233
+ return example
234
+
235
+
236
+ class PromptDataset(Dataset):
237
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
238
+
239
+ def __init__(self, prompt, num_samples):
240
+ self.prompt = prompt
241
+ self.num_samples = num_samples
242
+
243
+ def __len__(self):
244
+ return self.num_samples
245
+
246
+ def __getitem__(self, index):
247
+ example = {}
248
+ example["prompt"] = self.prompt
249
+ example["index"] = index
250
+ return example
251
+
252
+
253
+ logger = get_logger(__name__)
254
+
255
+
256
+ def parse_args(input_args=None):
257
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
258
+ parser.add_argument(
259
+ "--pretrained_model_name_or_path",
260
+ type=str,
261
+ default=None,
262
+ required=True,
263
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
264
+ )
265
+ parser.add_argument(
266
+ "--revision",
267
+ type=str,
268
+ default=None,
269
+ required=False,
270
+ help="Revision of pretrained model identifier from huggingface.co/models.",
271
+ )
272
+ parser.add_argument(
273
+ "--tokenizer_name",
274
+ type=str,
275
+ default=None,
276
+ help="Pretrained tokenizer name or path if not the same as model_name",
277
+ )
278
+ parser.add_argument(
279
+ "--instance_data_dir",
280
+ type=str,
281
+ default=None,
282
+ required=True,
283
+ help="A folder containing the training data of instance images.",
284
+ )
285
+ parser.add_argument(
286
+ "--class_data_dir",
287
+ type=str,
288
+ default=None,
289
+ required=False,
290
+ help="A folder containing the training data of class images.",
291
+ )
292
+ parser.add_argument(
293
+ "--instance_prompt",
294
+ type=str,
295
+ default=None,
296
+ required=True,
297
+ help="The prompt with identifier specifying the instance",
298
+ )
299
+ parser.add_argument(
300
+ "--class_prompt",
301
+ type=str,
302
+ default=None,
303
+ help="The prompt to specify images in the same class as provided instance images.",
304
+ )
305
+ parser.add_argument(
306
+ "--with_prior_preservation",
307
+ default=False,
308
+ action="store_true",
309
+ help="Flag to add prior preservation loss.",
310
+ )
311
+ parser.add_argument(
312
+ "--prior_loss_weight",
313
+ type=float,
314
+ default=1.0,
315
+ help="The weight of prior preservation loss.",
316
+ )
317
+ parser.add_argument(
318
+ "--num_class_images",
319
+ type=int,
320
+ default=100,
321
+ help=(
322
+ "Minimal class images for prior preservation loss. If not have enough images, additional images will be"
323
+ " sampled with class_prompt."
324
+ ),
325
+ )
326
+ parser.add_argument(
327
+ "--output_dir",
328
+ type=str,
329
+ default="text-inversion-model",
330
+ help="The output directory where the model predictions and checkpoints will be written.",
331
+ )
332
+ parser.add_argument(
333
+ "--seed", type=int, default=None, help="A seed for reproducible training."
334
+ )
335
+ parser.add_argument(
336
+ "--resolution",
337
+ type=int,
338
+ default=512,
339
+ help=(
340
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
341
+ " resolution"
342
+ ),
343
+ )
344
+ parser.add_argument(
345
+ "--center_crop",
346
+ action="store_true",
347
+ help="Whether to center crop images before resizing to resolution",
348
+ )
349
+ parser.add_argument(
350
+ "--train_text_encoder",
351
+ action="store_true",
352
+ help="Whether to train the text encoder",
353
+ )
354
+ parser.add_argument(
355
+ "--train_batch_size",
356
+ type=int,
357
+ default=4,
358
+ help="Batch size (per device) for the training dataloader.",
359
+ )
360
+ parser.add_argument(
361
+ "--sample_batch_size",
362
+ type=int,
363
+ default=4,
364
+ help="Batch size (per device) for sampling images.",
365
+ )
366
+ parser.add_argument("--num_train_epochs", type=int, default=1)
367
+ parser.add_argument(
368
+ "--max_train_steps",
369
+ type=int,
370
+ default=None,
371
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
372
+ )
373
+ parser.add_argument(
374
+ "--save_steps",
375
+ type=int,
376
+ default=500,
377
+ help="Save checkpoint every X updates steps.",
378
+ )
379
+ parser.add_argument(
380
+ "--gradient_accumulation_steps",
381
+ type=int,
382
+ default=1,
383
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
384
+ )
385
+ parser.add_argument(
386
+ "--gradient_checkpointing",
387
+ action="store_true",
388
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
389
+ )
390
+ parser.add_argument(
391
+ "--learning_rate",
392
+ type=float,
393
+ default=5e-6,
394
+ help="Initial learning rate (after the potential warmup period) to use.",
395
+ )
396
+ parser.add_argument(
397
+ "--scale_lr",
398
+ action="store_true",
399
+ default=False,
400
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
401
+ )
402
+ parser.add_argument(
403
+ "--lr_scheduler",
404
+ type=str,
405
+ default="constant",
406
+ help=(
407
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
408
+ ' "constant", "constant_with_warmup"]'
409
+ ),
410
+ )
411
+ parser.add_argument(
412
+ "--lr_warmup_steps",
413
+ type=int,
414
+ default=500,
415
+ help="Number of steps for the warmup in the lr scheduler.",
416
+ )
417
+ parser.add_argument(
418
+ "--use_8bit_adam",
419
+ action="store_true",
420
+ help="Whether or not to use 8-bit Adam from bitsandbytes.",
421
+ )
422
+ parser.add_argument(
423
+ "--adam_beta1",
424
+ type=float,
425
+ default=0.9,
426
+ help="The beta1 parameter for the Adam optimizer.",
427
+ )
428
+ parser.add_argument(
429
+ "--adam_beta2",
430
+ type=float,
431
+ default=0.999,
432
+ help="The beta2 parameter for the Adam optimizer.",
433
+ )
434
+ parser.add_argument(
435
+ "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use."
436
+ )
437
+ parser.add_argument(
438
+ "--adam_epsilon",
439
+ type=float,
440
+ default=1e-08,
441
+ help="Epsilon value for the Adam optimizer",
442
+ )
443
+ parser.add_argument(
444
+ "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
445
+ )
446
+ parser.add_argument(
447
+ "--push_to_hub",
448
+ action="store_true",
449
+ help="Whether or not to push the model to the Hub.",
450
+ )
451
+ parser.add_argument(
452
+ "--hub_token",
453
+ type=str,
454
+ default=None,
455
+ help="The token to use to push to the Model Hub.",
456
+ )
457
+ parser.add_argument(
458
+ "--hub_model_id",
459
+ type=str,
460
+ default=None,
461
+ help="The name of the repository to keep in sync with the local `output_dir`.",
462
+ )
463
+ parser.add_argument(
464
+ "--logging_dir",
465
+ type=str,
466
+ default="logs",
467
+ help=(
468
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
469
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
470
+ ),
471
+ )
472
+ parser.add_argument(
473
+ "--mixed_precision",
474
+ type=str,
475
+ default=None,
476
+ choices=["no", "fp16", "bf16"],
477
+ help=(
478
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
479
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
480
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
481
+ ),
482
+ )
483
+ parser.add_argument(
484
+ "--local_rank",
485
+ type=int,
486
+ default=-1,
487
+ help="For distributed training: local_rank",
488
+ )
489
+
490
+ if input_args is not None:
491
+ args = parser.parse_args(input_args)
492
+ else:
493
+ args = parser.parse_args()
494
+
495
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
496
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
497
+ args.local_rank = env_local_rank
498
+
499
+ if args.with_prior_preservation:
500
+ if args.class_data_dir is None:
501
+ raise ValueError("You must specify a data directory for class images.")
502
+ if args.class_prompt is None:
503
+ raise ValueError("You must specify prompt for class images.")
504
+ else:
505
+ if args.class_data_dir is not None:
506
+ logger.warning(
507
+ "You need not use --class_data_dir without --with_prior_preservation."
508
+ )
509
+ if args.class_prompt is not None:
510
+ logger.warning(
511
+ "You need not use --class_prompt without --with_prior_preservation."
512
+ )
513
+
514
+ return args
515
+
516
+
517
+ def get_full_repo_name(
518
+ model_id: str, organization: Optional[str] = None, token: Optional[str] = None
519
+ ):
520
+ if token is None:
521
+ token = HfFolder.get_token()
522
+ if organization is None:
523
+ username = whoami(token)["name"]
524
+ return f"{username}/{model_id}"
525
+ else:
526
+ return f"{organization}/{model_id}"
527
+
528
+
529
+ def main(args):
530
+ logging_dir = Path(args.output_dir, args.logging_dir)
531
+
532
+ accelerator = Accelerator(
533
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
534
+ mixed_precision=args.mixed_precision,
535
+ log_with="tensorboard",
536
+ logging_dir=logging_dir,
537
+ )
538
+
539
+ # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
540
+ # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
541
+ # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
542
+ if (
543
+ args.train_text_encoder
544
+ and args.gradient_accumulation_steps > 1
545
+ and accelerator.num_processes > 1
546
+ ):
547
+ raise ValueError(
548
+ "Gradient accumulation is not supported when training the text encoder in distributed training. "
549
+ "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
550
+ )
551
+
552
+ if args.seed is not None:
553
+ set_seed(args.seed)
554
+
555
+ if args.with_prior_preservation:
556
+ class_images_dir = Path(args.class_data_dir)
557
+ if not class_images_dir.exists():
558
+ class_images_dir.mkdir(parents=True)
559
+ cur_class_images = len(list(class_images_dir.iterdir()))
560
+
561
+ if cur_class_images < args.num_class_images:
562
+ torch_dtype = (
563
+ torch.float16 if accelerator.device.type == "cuda" else torch.float32
564
+ )
565
+ pipeline = StableDiffusionPipeline.from_pretrained(
566
+ args.pretrained_model_name_or_path,
567
+ torch_dtype=torch_dtype,
568
+ safety_checker=None,
569
+ revision=args.revision,
570
+ )
571
+ pipeline.set_progress_bar_config(disable=True)
572
+
573
+ num_new_images = args.num_class_images - cur_class_images
574
+ logger.info(f"Number of class images to sample: {num_new_images}.")
575
+
576
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
577
+ sample_dataloader = torch.utils.data.DataLoader(
578
+ sample_dataset, batch_size=args.sample_batch_size
579
+ )
580
+
581
+ sample_dataloader = accelerator.prepare(sample_dataloader)
582
+ pipeline.to(accelerator.device)
583
+
584
+ for example in tqdm(
585
+ sample_dataloader,
586
+ desc="Generating class images",
587
+ disable=not accelerator.is_local_main_process,
588
+ ):
589
+ images = pipeline(example["prompt"]).images
590
+
591
+ for i, image in enumerate(images):
592
+ hash_image = hashlib.sha1(image.tobytes()).hexdigest()
593
+ image_filename = (
594
+ class_images_dir
595
+ / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
596
+ )
597
+ image.save(image_filename)
598
+
599
+ del pipeline
600
+ if torch.cuda.is_available():
601
+ torch.cuda.empty_cache()
602
+
603
+ # Handle the repository creation
604
+ if accelerator.is_main_process:
605
+ if args.push_to_hub:
606
+ if args.hub_model_id is None:
607
+ repo_name = get_full_repo_name(
608
+ Path(args.output_dir).name, token=args.hub_token
609
+ )
610
+ else:
611
+ repo_name = args.hub_model_id
612
+ repo = Repository(args.output_dir, clone_from=repo_name)
613
+
614
+ with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
615
+ if "step_*" not in gitignore:
616
+ gitignore.write("step_*\n")
617
+ if "epoch_*" not in gitignore:
618
+ gitignore.write("epoch_*\n")
619
+ elif args.output_dir is not None:
620
+ os.makedirs(args.output_dir, exist_ok=True)
621
+
622
+ # Load the tokenizer
623
+ if args.tokenizer_name:
624
+ tokenizer = CLIPTokenizer.from_pretrained(
625
+ args.tokenizer_name,
626
+ revision=args.revision,
627
+ )
628
+ elif args.pretrained_model_name_or_path:
629
+ tokenizer = CLIPTokenizer.from_pretrained(
630
+ args.pretrained_model_name_or_path,
631
+ subfolder="tokenizer",
632
+ revision=args.revision,
633
+ )
634
+
635
+ # Load models and create wrapper for stable diffusion
636
+ text_encoder = CLIPTextModel.from_pretrained(
637
+ args.pretrained_model_name_or_path,
638
+ subfolder="text_encoder",
639
+ revision=args.revision,
640
+ )
641
+ vae = AutoencoderKL.from_pretrained(
642
+ args.pretrained_model_name_or_path,
643
+ subfolder="vae",
644
+ revision=args.revision,
645
+ )
646
+ unet = UNet2DConditionModel.from_pretrained(
647
+ args.pretrained_model_name_or_path,
648
+ subfolder="unet",
649
+ revision=args.revision,
650
+ )
651
+ unet.requires_grad_(False)
652
+ unet_lora_params, train_names = inject_trainable_lora(unet)
653
+
654
+ for _up, _down in extract_lora_ups_down(unet):
655
+ print(_up.weight)
656
+ print(_down.weight)
657
+ break
658
+
659
+ vae.requires_grad_(False)
660
+ if not args.train_text_encoder:
661
+ text_encoder.requires_grad_(False)
662
+
663
+ if args.gradient_checkpointing:
664
+ unet.enable_gradient_checkpointing()
665
+ if args.train_text_encoder:
666
+ text_encoder.gradient_checkpointing_enable()
667
+
668
+ if args.scale_lr:
669
+ args.learning_rate = (
670
+ args.learning_rate
671
+ * args.gradient_accumulation_steps
672
+ * args.train_batch_size
673
+ * accelerator.num_processes
674
+ )
675
+
676
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
677
+ if args.use_8bit_adam:
678
+ try:
679
+ import bitsandbytes as bnb
680
+ except ImportError:
681
+ raise ImportError(
682
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
683
+ )
684
+
685
+ optimizer_class = bnb.optim.AdamW8bit
686
+ else:
687
+ optimizer_class = torch.optim.AdamW
688
+
689
+ params_to_optimize = (
690
+ itertools.chain(*unet_lora_params, text_encoder.parameters())
691
+ if args.train_text_encoder
692
+ else itertools.chain(*unet_lora_params)
693
+ )
694
+ optimizer = optimizer_class(
695
+ params_to_optimize,
696
+ lr=args.learning_rate,
697
+ betas=(args.adam_beta1, args.adam_beta2),
698
+ weight_decay=args.adam_weight_decay,
699
+ eps=args.adam_epsilon,
700
+ )
701
+
702
+ noise_scheduler = DDPMScheduler.from_config(
703
+ args.pretrained_model_name_or_path, subfolder="scheduler"
704
+ )
705
+
706
+ train_dataset = DreamBoothDataset(
707
+ instance_data_root=args.instance_data_dir,
708
+ instance_prompt=args.instance_prompt,
709
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
710
+ class_prompt=args.class_prompt,
711
+ tokenizer=tokenizer,
712
+ size=args.resolution,
713
+ center_crop=args.center_crop,
714
+ )
715
+
716
+ def collate_fn(examples):
717
+ input_ids = [example["instance_prompt_ids"] for example in examples]
718
+ pixel_values = [example["instance_images"] for example in examples]
719
+
720
+ # Concat class and instance examples for prior preservation.
721
+ # We do this to avoid doing two forward passes.
722
+ if args.with_prior_preservation:
723
+ input_ids += [example["class_prompt_ids"] for example in examples]
724
+ pixel_values += [example["class_images"] for example in examples]
725
+
726
+ pixel_values = torch.stack(pixel_values)
727
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
728
+
729
+ input_ids = tokenizer.pad(
730
+ {"input_ids": input_ids},
731
+ padding="max_length",
732
+ max_length=tokenizer.model_max_length,
733
+ return_tensors="pt",
734
+ ).input_ids
735
+
736
+ batch = {
737
+ "input_ids": input_ids,
738
+ "pixel_values": pixel_values,
739
+ }
740
+ return batch
741
+
742
+ train_dataloader = torch.utils.data.DataLoader(
743
+ train_dataset,
744
+ batch_size=args.train_batch_size,
745
+ shuffle=True,
746
+ collate_fn=collate_fn,
747
+ num_workers=1,
748
+ )
749
+
750
+ # Scheduler and math around the number of training steps.
751
+ overrode_max_train_steps = False
752
+ num_update_steps_per_epoch = math.ceil(
753
+ len(train_dataloader) / args.gradient_accumulation_steps
754
+ )
755
+ if args.max_train_steps is None:
756
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
757
+ overrode_max_train_steps = True
758
+
759
+ lr_scheduler = get_scheduler(
760
+ args.lr_scheduler,
761
+ optimizer=optimizer,
762
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
763
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
764
+ )
765
+
766
+ if args.train_text_encoder:
767
+ (
768
+ unet,
769
+ text_encoder,
770
+ optimizer,
771
+ train_dataloader,
772
+ lr_scheduler,
773
+ ) = accelerator.prepare(
774
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler
775
+ )
776
+ else:
777
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
778
+ unet, optimizer, train_dataloader, lr_scheduler
779
+ )
780
+
781
+ weight_dtype = torch.float32
782
+ if accelerator.mixed_precision == "fp16":
783
+ weight_dtype = torch.float16
784
+ elif accelerator.mixed_precision == "bf16":
785
+ weight_dtype = torch.bfloat16
786
+
787
+ # Move text_encode and vae to gpu.
788
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
789
+ # as these models are only used for inference, keeping weights in full precision is not required.
790
+ vae.to(accelerator.device, dtype=weight_dtype)
791
+ if not args.train_text_encoder:
792
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
793
+
794
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
795
+ num_update_steps_per_epoch = math.ceil(
796
+ len(train_dataloader) / args.gradient_accumulation_steps
797
+ )
798
+ if overrode_max_train_steps:
799
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
800
+ # Afterwards we recalculate our number of training epochs
801
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
802
+
803
+ # We need to initialize the trackers we use, and also store our configuration.
804
+ # The trackers initializes automatically on the main process.
805
+ if accelerator.is_main_process:
806
+ accelerator.init_trackers("dreambooth", config=vars(args))
807
+
808
+ # Train!
809
+ total_batch_size = (
810
+ args.train_batch_size
811
+ * accelerator.num_processes
812
+ * args.gradient_accumulation_steps
813
+ )
814
+
815
+ print("***** Running training *****")
816
+ print(f" Num examples = {len(train_dataset)}")
817
+ print(f" Num batches each epoch = {len(train_dataloader)}")
818
+ print(f" Num Epochs = {args.num_train_epochs}")
819
+ print(f" Instantaneous batch size per device = {args.train_batch_size}")
820
+ print(
821
+ f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
822
+ )
823
+ print(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
824
+ print(f" Total optimization steps = {args.max_train_steps}")
825
+ # Only show the progress bar once on each machine.
826
+ progress_bar = tqdm(
827
+ range(args.max_train_steps), disable=not accelerator.is_local_main_process
828
+ )
829
+ progress_bar.set_description("Steps")
830
+ global_step = 0
831
+
832
+ for epoch in range(args.num_train_epochs):
833
+ unet.train()
834
+ if args.train_text_encoder:
835
+ text_encoder.train()
836
+ for step, batch in enumerate(train_dataloader):
837
+
838
+ # Convert images to latent space
839
+ latents = vae.encode(
840
+ batch["pixel_values"].to(dtype=weight_dtype)
841
+ ).latent_dist.sample()
842
+ latents = latents * 0.18215
843
+
844
+ # Sample noise that we'll add to the latents
845
+ noise = torch.randn_like(latents)
846
+ bsz = latents.shape[0]
847
+ # Sample a random timestep for each image
848
+ timesteps = torch.randint(
849
+ 0,
850
+ noise_scheduler.config.num_train_timesteps,
851
+ (bsz,),
852
+ device=latents.device,
853
+ )
854
+ timesteps = timesteps.long()
855
+
856
+ # Add noise to the latents according to the noise magnitude at each timestep
857
+ # (this is the forward diffusion process)
858
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
859
+
860
+ # Get the text embedding for conditioning
861
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
862
+
863
+ # Predict the noise residual
864
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
865
+
866
+ # Get the target for loss depending on the prediction type
867
+ if noise_scheduler.config.prediction_type == "epsilon":
868
+ target = noise
869
+ elif noise_scheduler.config.prediction_type == "v_prediction":
870
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
871
+ else:
872
+ raise ValueError(
873
+ f"Unknown prediction type {noise_scheduler.config.prediction_type}"
874
+ )
875
+
876
+ if args.with_prior_preservation:
877
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
878
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
879
+ target, target_prior = torch.chunk(target, 2, dim=0)
880
+
881
+ # Compute instance loss
882
+ loss = (
883
+ F.mse_loss(model_pred.float(), target.float(), reduction="none")
884
+ .mean([1, 2, 3])
885
+ .mean()
886
+ )
887
+
888
+ # Compute prior loss
889
+ prior_loss = F.mse_loss(
890
+ model_pred_prior.float(), target_prior.float(), reduction="mean"
891
+ )
892
+
893
+ # Add the prior loss to the instance loss.
894
+ loss = loss + args.prior_loss_weight * prior_loss
895
+ else:
896
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
897
+
898
+ accelerator.backward(loss)
899
+ if accelerator.sync_gradients:
900
+ params_to_clip = (
901
+ itertools.chain(unet.parameters(), text_encoder.parameters())
902
+ if args.train_text_encoder
903
+ else unet.parameters()
904
+ )
905
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
906
+ optimizer.step()
907
+ lr_scheduler.step()
908
+ progress_bar.update(1)
909
+ optimizer.zero_grad()
910
+
911
+ # Checks if the accelerator has performed an optimization step behind the scenes
912
+ if accelerator.sync_gradients:
913
+
914
+ global_step += 1
915
+
916
+ if global_step % args.save_steps == 0:
917
+ if accelerator.is_main_process:
918
+ pipeline = StableDiffusionPipeline.from_pretrained(
919
+ args.pretrained_model_name_or_path,
920
+ unet=accelerator.unwrap_model(unet),
921
+ text_encoder=accelerator.unwrap_model(text_encoder),
922
+ revision=args.revision,
923
+ )
924
+
925
+ save_lora_weight(pipeline.unet, args.output_dir + "/lora_weight.pt")
926
+
927
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
928
+ progress_bar.set_postfix(**logs)
929
+ accelerator.log(logs, step=global_step)
930
+
931
+ if global_step >= args.max_train_steps:
932
+ break
933
+
934
+ accelerator.wait_for_everyone()
935
+
936
+ # Create the pipeline using using the trained modules and save it.
937
+ if accelerator.is_main_process:
938
+ pipeline = StableDiffusionPipeline.from_pretrained(
939
+ args.pretrained_model_name_or_path,
940
+ unet=accelerator.unwrap_model(unet),
941
+ text_encoder=accelerator.unwrap_model(text_encoder),
942
+ revision=args.revision,
943
+ )
944
+
945
+ print("\n\nLora TRAINING DONE!\n\n")
946
+
947
+ save_lora_weight(pipeline.unet, args.output_dir + "/lora_weight.pt")
948
+
949
+ for _up, _down in extract_lora_ups_down(pipeline.unet):
950
+ print("First Layer's Up Weight is now : ", _up.weight)
951
+ print("First Layer's Down Weight is now : ", _down.weight)
952
+ break
953
+
954
+ if args.push_to_hub:
955
+ repo.push_to_hub(
956
+ commit_message="End of training", blocking=False, auto_lfs_prune=True
957
+ )
958
+
959
+ accelerator.end_training()
960
+
961
+
962
+ if __name__ == "__main__":
963
+ args = parse_args()
964
+ main(args)
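
One addition in this variant of the script is the `DreamBoothLabled` dataset, defined above but not wired into `main()`, which derives each training prompt from the image file name instead of using a fixed `--instance_prompt`: the extension and any digits are stripped, dashes become spaces, and a randomly chosen style token is appended. A small sketch of that transformation in isolation, using an invented file name:

```python
import random
import re


def filename_to_prompt(path: str) -> str:
    """Mirrors the prompt construction in DreamBoothLabled.__getitem__.

    "data_example/baby-lion-03.png" -> "baby lion , style of sks"
    (the style suffix is drawn at random; the file name here is an invented example).
    """
    name = path.split("/")[-1].split(".")[0].replace("-", " ")
    name = re.sub(r"\d+", "", name)  # drop any digits left in the name
    style = random.choice(["svg", "flat color", "vector illustration", "sks"])
    return f"{name}, style of {style}"


print(filename_to_prompt("data_example/baby-lion-03.png"))
```
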