Space status: Build error
Commit dbc579c · Nupur Kumari committed · Parent: f4d0eb6

update

Files changed:
- app.py: +86 -25
- inference.py: +1 -0
- trainer.py: +54 -24
app.py
CHANGED

@@ -25,8 +25,7 @@ It is recommended to upgrade to GPU in Settings after duplicating this space to
 DETAILDESCRIPTION='''
 Custom Diffusion allows you to fine-tune text-to-image diffusion models, such as Stable Diffusion, given a few images of a new concept (~4-20).
 We fine-tune only a subset of model parameters, namely the key and value projection matrices in the cross-attention layers, and the modifier token used to represent the object.
-This also reduces the extra storage for each additional concept to 75MB.
-Our method further allows you to use a combination of concepts. Demo for multiple concepts will be added soon.
+This also reduces the extra storage for each additional concept to 75MB. Our method also allows you to use a combination of concepts. There are still limitations on which compositions work. For more analysis, please refer to our [website](https://www.cs.cmu.edu/~custom-diffusion/).
 <center>
 <img src="https://huggingface.co/spaces/nupurkmr9/custom-diffusion/resolve/main/method.jpg" width="600" align="center" >
 </center>
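The description above is the crux of the method: only the cross-attention key/value projections (plus the new modifier-token embedding) are trained. Below is a minimal sketch of what that parameter selection looks like for a diffusers UNet; it is illustrative only, not part of this commit, and assumes diffusers' "attn2.to_k"/"attn2.to_v" naming for the cross-attention projections.

import torch
from diffusers import UNet2DConditionModel

# Load the UNet of a Stable Diffusion checkpoint (model name is an example).
unet = UNet2DConditionModel.from_pretrained(
    'CompVis/stable-diffusion-v1-4', subfolder='unet')

# Freeze everything, then unfreeze only the cross-attention key/value
# projections. In diffusers, cross-attention blocks are named "attn2".
for name, param in unet.named_parameters():
    param.requires_grad = 'attn2.to_k' in name or 'attn2.to_v' in name

trainable = sum(p.numel() for p in unet.parameters() if p.requires_grad)
total = sum(p.numel() for p in unet.parameters())
print(f'training {trainable / total:.1%} of UNet parameters')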
@@ -81,27 +80,82 @@ def create_training_demo(trainer: Trainer,
 
     with gr.Row():
         with gr.Box():
-            [12 removed lines — not rendered in this diff view]
+            concept_images_collection = []
+            concept_prompt_collection = []
+            class_prompt_collection = []
+            buttons_collection = []
+            delete_collection = []
+            is_visible = []
+            maximum_concepts = 3
+            row = [None] * maximum_concepts
+            for x in range(maximum_concepts):
+                ordinal = lambda n: "%d%s" % (n, "tsnrhtdd"[(n // 10 % 10 != 1) * (n % 10 < 4) * n % 10::4])
+                ordinal_concept = ["<new1> cat", "<new2> wooden pot", "<new3> chair"]
+                if(x == 0):
+                    visible = True
+                    is_visible.append(gr.State(value=True))
+                else:
+                    visible = False
+                    is_visible.append(gr.State(value=False))
+
+                concept_images_collection.append(gr.Files(label=f'''Upload the images for your {ordinal(x+1) if (x>0) else ""} concept''', visible=visible))
+                with gr.Column(visible=visible) as row[x]:
+                    concept_prompt_collection.append(
+                        gr.Textbox(label=f'''{ordinal(x+1) if (x>0) else ""} concept prompt''', max_lines=1,
+                                   placeholder=f'''Example: "photo of a {ordinal_concept[x]}"''')
+                    )
+                    class_prompt_collection.append(
+                        gr.Textbox(label=f'''{ordinal(x+1) if (x>0) else ""} class prompt''',
+                                   max_lines=1, placeholder=f'''Example: "{ordinal_concept[x][7:]}"''')
+                    )
+                with gr.Row():
+                    if(x < maximum_concepts - 1):
+                        buttons_collection.append(gr.Button(value=f"Add {ordinal(x+2)} concept", visible=visible))
+                    if(x > 0):
+                        delete_collection.append(gr.Button(value=f"Delete {ordinal(x+1)} concept"))
+
+            counter_add = 1
+            for button in buttons_collection:
+                if(counter_add < len(buttons_collection)):
+                    button.click(lambda:
+                                 [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), True, None],
+                                 None,
+                                 [row[counter_add], concept_images_collection[counter_add], buttons_collection[counter_add - 1], buttons_collection[counter_add], is_visible[counter_add], concept_images_collection[counter_add]], queue=False)
+                else:
+                    button.click(lambda:
+                                 [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), True],
+                                 None,
+                                 [row[counter_add], concept_images_collection[counter_add], buttons_collection[counter_add - 1], is_visible[counter_add]], queue=False)
+                counter_add += 1
+
+            counter_delete = 1
+            for delete_button in delete_collection:
+                if(counter_delete < len(delete_collection) + 1):
+                    if counter_delete == 1:
+                        delete_button.click(lambda:
+                                            [gr.update(visible=False, value=None), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), False],
+                                            None,
+                                            [concept_images_collection[counter_delete], row[counter_delete], buttons_collection[counter_delete - 1], buttons_collection[counter_delete], is_visible[counter_delete]], queue=False)
+                    else:
+                        delete_button.click(lambda:
+                                            [gr.update(visible=False, value=None), gr.update(visible=False), gr.update(visible=True), False],
+                                            None,
+                                            [concept_images_collection[counter_delete], row[counter_delete], buttons_collection[counter_delete - 1], is_visible[counter_delete]], queue=False)
+                counter_delete += 1
             gr.Markdown('''
-            [6 removed lines — not rendered in this diff view]
+            - We use the "\<new1\>" modifier token in front of the concept, e.g., "\<new1\> cat". For multiple concepts, use "\<new2\>", "\<new3\>", etc. Increase the number of steps when training more concepts.
+            - For a new concept, an example concept prompt is "photo of a \<new1\> cat" and the class prompt is "cat".
+            - For a style concept, use "painting in the style of \<new1\> art" as the concept prompt and "art" as the class prompt.
+            - The class prompt should be the object category.
+            - If "Train Text Encoder" is checked, disable "modifier token" and use any unique text to describe the concept, e.g. "ktn cat".
+            ''')
             with gr.Box():
                 gr.Markdown('Training Parameters')
+                with gr.Row():
+                    modifier_token = gr.Checkbox(label='modifier token',
+                                                 value=True)
+                    train_text_encoder = gr.Checkbox(label='Train Text Encoder',
+                                                     value=False)
                 num_training_steps = gr.Number(
                     label='Number of Training Steps', value=1000, precision=0)
                 learning_rate = gr.Number(label='Learning Rate', value=0.00001)

@@ -115,6 +169,10 @@ def create_training_demo(trainer: Trainer,
                     label='Number of Gradient Accumulation',
                     value=1,
                     precision=0)
+                num_reg_images = gr.Number(
+                    label='Number of Class Concept images',
+                    value=200,
+                    precision=0)
                 gen_images = gr.Checkbox(label='Generated images as regularization',
                                          value=False)
                 gr.Markdown('''

@@ -122,6 +180,7 @@ def create_training_demo(trainer: Trainer,
             - Our results in the paper are trained with batch-size 4 (8 including class regularization samples).
             - Enable gradient checkpointing for lower memory requirements (~14GB) at the expense of a slower backward pass.
             - Note that your trained models will be deleted when a second training run is started. You can upload your trained model in the "Upload" tab.
+            - We retrieve real images for the class concept using the clip_retrieval library, which can take some time.
             ''')
 
             run_button = gr.Button('Start Training')

@@ -141,9 +200,6 @@ def create_training_demo(trainer: Trainer,
             inputs=[
                 base_model,
                 resolution,
-                concept_images,
-                concept_prompt,
-                class_prompt,
                 num_training_steps,
                 learning_rate,
                 train_text_encoder,

@@ -152,8 +208,13 @@ def create_training_demo(trainer: Trainer,
                 batch_size,
                 use_8bit_adam,
                 gradient_checkpointing,
-                gen_images
-            ],
+                gen_images,
+                num_reg_images,
+            ] +
+            concept_images_collection +
+            concept_prompt_collection +
+            class_prompt_collection
+            ,
             outputs=[
                 training_status,
                 output_files,
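Two details of the block above are easy to miss: the ordinal lambda renders ordinal suffixes (1 → "1st", 2 → "2nd", 3 → "3rd"), and ordinal_concept[x][7:] strips the leading "<newN> " (7 characters) to derive the example class prompt. The add/delete buttons work by returning gr.update(visible=...) from click callbacks; a minimal self-contained sketch of that pattern, assuming the Gradio 3.x API this Space uses:

import gradio as gr

with gr.Blocks() as demo:
    first = gr.Textbox(label='1st concept', visible=True)
    second = gr.Textbox(label='2nd concept', visible=False)
    add_btn = gr.Button('Add 2nd concept')
    # The callback returns one gr.update per output component:
    # reveal the second textbox and hide the button itself.
    add_btn.click(lambda: [gr.update(visible=True), gr.update(visible=False)],
                  None, [second, add_btn], queue=False)

demo.launch()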
inference.py
CHANGED

@@ -75,6 +75,7 @@ class InferencePipeline:
             height=resolution, width=resolution,
             eta = eta,
             generator=generator)  # type: ignore
+        torch.cuda.empty_cache()
         out = out.images
         out = PIL.Image.fromarray(np.hstack([np.array(x) for x in out]))
         return out
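The only change here is the torch.cuda.empty_cache() call after generation: it returns cached allocator blocks to the driver so a subsequent training run in the same process sees more free GPU memory. It does not free tensors that are still referenced. A small illustrative snippet, not part of the commit:

import torch

x = torch.empty(1024, 1024, 256, device='cuda')  # ~1 GiB allocation
del x                                 # memory returns to PyTorch's cache...
print(torch.cuda.memory_reserved())   # ...so it is still reserved
torch.cuda.empty_cache()              # hand cached blocks back to the driver
print(torch.cuda.memory_reserved())   # now near zero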
trainer.py
CHANGED

@@ -9,6 +9,7 @@ import subprocess
 import gradio as gr
 import PIL.Image
 import torch
+import json
 
 os.environ['PYTHONPATH'] = f'custom-diffusion:{os.getenv("PYTHONPATH", "")}'
 

@@ -45,23 +46,41 @@ class Trainer:
     def cleanup_dirs(self) -> None:
         shutil.rmtree(self.output_dir, ignore_errors=True)
 
-    def prepare_dataset(self, concept_images: list, resolution: int) -> None:
+    def prepare_dataset(self, concept_images_collection: list, concept_prompt_collection: list, class_prompt_collection: list, resolution: int) -> None:
         self.instance_data_dir.mkdir(parents=True)
-        for i, temp_path in enumerate(concept_images):
-            image = PIL.Image.open(temp_path.name)
-            image = pad_image(image)
-            image = image.resize((resolution, resolution))
-            image = image.convert('RGB')
-            out_path = self.instance_data_dir / f'{i:03d}.jpg'
-            image.save(out_path, format='JPEG', quality=100)
-
+        concepts_list = []
+
+        for i in range(len(concept_images_collection)):
+            concept_dir = self.instance_data_dir / f'{i}'
+            class_dir = self.class_data_dir / f'{i}'
+            concept_dir.mkdir(parents=True)
+            concept_images = concept_images_collection[i]
+
+            concepts_list.append(
+                {
+                    "instance_prompt": concept_prompt_collection[i],
+                    "class_prompt": class_prompt_collection[i],
+                    "instance_data_dir": f'{concept_dir}',
+                    "class_data_dir": f'{class_dir}'
+                }
+            )
+
+            for i, temp_path in enumerate(concept_images):
+                image = PIL.Image.open(temp_path.name)
+                image = pad_image(image)
+                # image = image.resize((resolution, resolution))
+                image = image.convert('RGB')
+                out_path = concept_dir / f'{i:03d}.jpg'
+                image.save(out_path, format='JPEG', quality=100)
+
+        print(concepts_list)
+        json.dump(concepts_list, open(f'{self.output_dir}/temp.json', 'w'))
+
+
     def run(
         self,
         base_model: str,
         resolution_s: str,
-        concept_images: list | None,
-        concept_prompt: str,
-        class_prompt: str,
         n_steps: int,
         learning_rate: float,
         train_text_encoder: bool,

@@ -71,32 +90,40 @@ class Trainer:
         use_8bit_adam: bool,
         gradient_checkpointing: bool,
         gen_images: bool,
+        num_reg_images: int,
+        *inputs,
     ) -> tuple[dict, list[pathlib.Path]]:
         if not torch.cuda.is_available():
             raise gr.Error('CUDA is not available.')
 
+        num_concept = 0
+        for i in range(len(inputs) // 3):
+            if inputs[i] != None:
+                num_concept += 1
+
+        print(num_concept, inputs)
+        concept_images_collection = inputs[: num_concept]
+        concept_prompt_collection = inputs[3: 3 + num_concept]
+        class_prompt_collection = inputs[6: 6 + num_concept]
         if self.is_running:
             return gr.update(value=self.is_running_message), []
 
-        if concept_images is None:
+        if concept_images_collection is None:
             raise gr.Error('You need to upload images.')
-        if not concept_prompt:
+        if not concept_prompt_collection:
             raise gr.Error('The concept prompt is missing.')
 
         resolution = int(resolution_s)
 
         self.cleanup_dirs()
-        self.prepare_dataset(concept_images, resolution)
-
+        self.prepare_dataset(concept_images_collection, concept_prompt_collection, class_prompt_collection, resolution)
+        torch.cuda.empty_cache()
         command = f'''
         accelerate launch custom-diffusion/src/diffuser_training.py \
             --pretrained_model_name_or_path={base_model} \
-            --instance_data_dir={self.instance_data_dir} \
             --output_dir={self.output_dir} \
-            --instance_prompt="{concept_prompt}" \
-            --class_data_dir={self.class_data_dir} \
+            --concepts_list={f'{self.output_dir}/temp.json'} \
             --with_prior_preservation --prior_loss_weight=1.0 \
-            --class_prompt="{class_prompt}" \
             --resolution={resolution} \
             --train_batch_size={batch_size} \
             --gradient_accumulation_steps={gradient_accumulation} \

@@ -104,11 +131,14 @@ class Trainer:
             --lr_scheduler="constant" \
             --lr_warmup_steps=0 \
             --max_train_steps={n_steps} \
-            --num_class_images=
-            --
+            --num_class_images={num_reg_images} \
+            --initializer_token="ktn+pll+ucd" \
+            --scale_lr --hflip
         '''
         if modifier_token:
-            [1 removed line — not rendered in this diff view]
+            tokens = '+'.join([f'<new{i+1}>' for i in range(num_concept)])
+            command += f' --modifier_token {tokens}'
+
         if not gen_images:
             command += ' --real_prior'
         if use_8bit_adam:

@@ -117,7 +147,7 @@ class Trainer:
             command += f' --train_text_encoder'
         if gradient_checkpointing:
             command += f' --gradient_checkpointing'
-
+
         with open(self.output_dir / 'train.sh', 'w') as f:
             command_s = ' '.join(command.split())
             f.write(command_s)
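For two uploaded concepts, prepare_dataset above would write a temp.json along these lines; the prompts are taken from the UI placeholders, and the directory paths are illustrative, since they depend on the Trainer's instance_data_dir/class_data_dir settings:

concepts_list = [
    {
        'instance_prompt': 'photo of a <new1> cat',
        'class_prompt': 'cat',
        'instance_data_dir': 'training_data/0',  # assumed path
        'class_data_dir': 'class_data/0',        # assumed path
    },
    {
        'instance_prompt': 'photo of a <new2> wooden pot',
        'class_prompt': 'wooden pot',
        'instance_data_dir': 'training_data/1',  # assumed path
        'class_data_dir': 'class_data/1',        # assumed path
    },
]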
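The *inputs unpacking in run relies on how app.py flattens its component lists: with maximum_concepts = 3, the trailing inputs always arrive as three file uploads, then three concept prompts, then three class prompts, so the fixed offsets 0, 3, and 6 recover the per-concept values. The slicing also assumes filled slots come first, which the add-button UI enforces. A minimal sketch of the same bookkeeping:

# Trailing Gradio inputs for one filled concept out of a possible three;
# 'cat-0.jpg' etc. stand in for files from a gr.Files upload.
inputs = (['cat-0.jpg', 'cat-1.jpg'], None, None,
          'photo of a <new1> cat', '', '',
          'cat', '', '')

num_concept = sum(1 for f in inputs[:3] if f is not None)  # -> 1
concept_images_collection = inputs[:num_concept]       # (['cat-0.jpg', 'cat-1.jpg'],)
concept_prompt_collection = inputs[3:3 + num_concept]  # ('photo of a <new1> cat',)
class_prompt_collection = inputs[6:6 + num_concept]    # ('cat',)

# With 'modifier token' checked, run() joins one token per concept:
tokens = '+'.join(f'<new{i + 1}>' for i in range(num_concept))
print(tokens)                                          # -> <new1>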