afeng commited on
Commit
d807efd
·
0 Parent(s):
.gitignore ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ example1_512/
2
+ example1_1024/
3
+ example1_example2_512/
4
+ example1_example2_1024/
5
+ example1/
6
+ old/
7
+
8
+ out_active.png
9
+ out_mask.png
10
+ out_soft.png
11
+
12
+ # Byte-compiled / optimized / DLL files
13
+ __pycache__/
14
+ *.py[cod]
15
+ *$py.class
16
+
17
+ # C extensions
18
+ *.so
19
+
20
+ # Distribution / packaging
21
+ .Python
22
+ build/
23
+ develop-eggs/
24
+ dist/
25
+ downloads/
26
+ eggs/
27
+ .eggs/
28
+ lib/
29
+ lib64/
30
+ parts/
31
+ sdist/
32
+ var/
33
+ wheels/
34
+ share/python-wheels/
35
+ *.egg-info/
36
+ .installed.cfg
37
+ *.egg
38
+ MANIFEST
39
+
40
+ # PyInstaller
41
+ # Usually these files are written by a python script from a template
42
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
43
+ *.manifest
44
+ *.spec
45
+
46
+ # Installer logs
47
+ pip-log.txt
48
+ pip-delete-this-directory.txt
49
+
50
+ # Unit test / coverage reports
51
+ htmlcov/
52
+ .tox/
53
+ .nox/
54
+ .coverage
55
+ .coverage.*
56
+ .cache
57
+ nosetests.xml
58
+ coverage.xml
59
+ *.cover
60
+ *.py,cover
61
+ .hypothesis/
62
+ .pytest_cache/
63
+ cover/
64
+
65
+ # Translations
66
+ *.mo
67
+ *.pot
68
+
69
+ # Django stuff:
70
+ *.log
71
+ local_settings.py
72
+ db.sqlite3
73
+ db.sqlite3-journal
74
+
75
+ # Flask stuff:
76
+ instance/
77
+ .webassets-cache
78
+
79
+ # Scrapy stuff:
80
+ .scrapy
81
+
82
+ # Sphinx documentation
83
+ docs/_build/
84
+
85
+ # PyBuilder
86
+ .pybuilder/
87
+ target/
88
+
89
+ # Jupyter Notebook
90
+ .ipynb_checkpoints
91
+
92
+ # IPython
93
+ profile_default/
94
+ ipython_config.py
95
+
96
+ # pyenv
97
+ # For a library or package, you might want to ignore these files since the code is
98
+ # intended to run in multiple environments; otherwise, check them in:
99
+ # .python-version
100
+
101
+ # pipenv
102
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
103
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
104
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
105
+ # install all needed dependencies.
106
+ #Pipfile.lock
107
+
108
+ # poetry
109
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
110
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
111
+ # commonly ignored for libraries.
112
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
113
+ #poetry.lock
114
+
115
+ # pdm
116
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
117
+ #pdm.lock
118
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
119
+ # in version control.
120
+ # https://pdm.fming.dev/#use-with-ide
121
+ .pdm.toml
122
+
123
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
124
+ __pypackages__/
125
+
126
+ # Celery stuff
127
+ celerybeat-schedule
128
+ celerybeat.pid
129
+
130
+ # SageMath parsed files
131
+ *.sage.py
132
+
133
+ # Environments
134
+ .env
135
+ .venv
136
+ env/
137
+ venv/
138
+ ENV/
139
+ env.bak/
140
+ venv.bak/
141
+
142
+ # Spyder project settings
143
+ .spyderproject
144
+ .spyproject
145
+
146
+ # Rope project settings
147
+ .ropeproject
148
+
149
+ # mkdocs documentation
150
+ /site
151
+
152
+ # mypy
153
+ .mypy_cache/
154
+ .dmypy.json
155
+ dmypy.json
156
+
157
+ # Pyre type checker
158
+ .pyre/
159
+
160
+ # pytype static type analyzer
161
+ .pytype/
162
+
163
+ # Cython debug symbols
164
+ cython_debug/
165
+
166
+ # PyCharm
167
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
168
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
169
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
170
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
171
+ #.idea/
README.md ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+ <h1>An item is Worth a Prompt: Versatile Image Editing with Disentangled Control</h1>
3
+
4
+
5
+
6
+ <a href='https://arxiv.org/abs/2403.04880'><img src='https://img.shields.io/badge/Technique-Report-red'></a>
7
+
8
+
9
+ </div>
10
+ D-Edit is a versatile image editing framework based on diffusion models, supporting text, image, mask-based editing.
11
+
12
+ <!-- <img src='assets/applications.png'> -->
13
+ ## Release
14
+ - [2024/03/12] 🔥 Code uploaded.
15
+
16
+
17
+
18
+
19
+ ## 🔥 Examples
20
+
21
+ <p align="center">
22
+ <img alt="text" src="assets/demo1.gif" width="45%">
23
+ &nbsp; &nbsp; &nbsp; &nbsp;
24
+ <img alt="image" src="assets/demo2.gif" width="45%">
25
+ </p>
26
+
27
+ 1. **Text-Guided Editing**:Allows users to select an object within an image and replace or refine it based on a text description.
28
+ - Key features:
29
+ - Generates more realistic details and smoother transitions than alternative methods
30
+ - Focuses edits specifically on the targeted object
31
+ - Preserves unrelated parts of the image
32
+
33
+ 2. **Image-Guided Editing**: Enables users to choose an object from a reference image and transplant it into another image while preserving its identity.
34
+ - Key features:
35
+ - Ensures seamless integration of the object into the new context
36
+ - Adapts the object's appearance to match the target image's style
37
+ - Works effectively even when the object's appearance differs significantly between reference and target images
38
+
39
+
40
+
41
+ <p align="center">
42
+ <img alt="mask" src="assets/demo3.gif" width="45%">
43
+ &nbsp; &nbsp; &nbsp; &nbsp; &nbsp;
44
+ <img alt="remove" src="assets/demo4.gif" width="45%">
45
+ </p>
46
+
47
+
48
+
49
+ 3. **Mask-Based Editing**: Involves manipulating objects by directly editing their masks.
50
+ - Key features:
51
+ - Allows for operations like moving, reshaping, resizing, and refining objects
52
+ - Fills in new details according to the object's associated prompt
53
+ - Produces natural-looking results that maintain consistency with the overall image
54
+
55
+ 4. **Item Removal**: Enables users to remove objects from images by deleting the mask-object associations.
56
+ - Key features:
57
+ - Intelligently fills in the empty space left by removed objects
58
+ - Ensures a coherent final image
59
+ - Maintains the integrity of the surrounding image elements
60
+
61
+ ## 🔧 Dependencies and Installation
62
+ - Python >= 3.8 (Recommend to use [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html))
63
+ - [PyTorch >= 2.1.0](https://pytorch.org/)
64
+ ```bash
65
+ conda create --name dedit python=3.10
66
+ conda activate dedit
67
+ pip install -U pip
68
+
69
+ # Install requirements
70
+ pip install -r requirements.txt
71
+ ```
72
+
73
+
74
+ ## 💻 Run
75
+
76
+ ### 1. Segmentation
77
+ Put the image (of any resolution) to be edited into the folder with a specified name, and rename the image as "img.png" or "img.jpg".
78
+ Then run the segmentation model
79
+ ```
80
+ sh ./scripts/run_segment.sh
81
+ ```
82
+ Alternatively, run [GroundedSAM](https://github.com/IDEA-Research/Grounded-Segment-Anything) to detect with text prompt
83
+ ```
84
+ sh ./scripts/run_segmentSAM.sh
85
+ ```
86
+
87
+ Optionally, if segmentation is not good, refine masks with GUI by locally running the mask editing web:
88
+ ```
89
+ python ui_edit_mask.py
90
+ ```
91
+ For image-based editing, repeat this step for both reference and target images.
92
+
93
+ ### 2. Model Finetuning
94
+ Finetune UNet cross-attention layer of diffusion models by running
95
+ ```
96
+ sh ./scripts/sdxl/run_ft_sdxl_1024.sh
97
+ ```
98
+ or finetune full UNet with lora
99
+ ```
100
+ sh ./scripts/sdxl/run_ft_sdxl_1024_fulllora.sh
101
+ ```
102
+ If image-based editing is needed, finetune the model with both reference and target images using
103
+
104
+ ```
105
+ sh ./scripts/sdxl/run_ft_sdxl_1024_fulllora_2imgs.sh
106
+ ```
107
+
108
+ ### 3. Edit \!
109
+ #### 3.1 Reconstruction
110
+ To see if the original image can be constructed
111
+ ```
112
+ sh ./scripts/sdxl/run_recon.sh
113
+ ```
114
+ #### 3.1 Text-based
115
+ Replace the target item (tgt_index) with the item described by the text prompt (tgt_prompt)
116
+ ```
117
+ sh ./scripts/sdxl/run_text.sh
118
+ ```
119
+ #### 3.2 Image-based
120
+ Replace the target item (tgt_index) in the target image (tgt_name) with the item (src_index) in the reference image
121
+ ```
122
+ sh ./scripts/sdxl/run_image.sh
123
+ ```
124
+ #### 3.3 Mask-based
125
+ For target items (tgt_indices_list), resize it (resize_list), move it (delta_x, delta_y) or reshape it by manually editing the mask shape (using UI).
126
+
127
+ The resulting new masks (processed by a simple algorithm) can be visualized in './example1/move_resize/seg_move_resize.png', if it is not reasonable, edit using the UI.
128
+
129
+ ```
130
+ sh ./scripts/sdxl/run_move_resize.sh
131
+ ```
132
+ #### 3.4 Remove
133
+ Remove the target item (tgt_index), the remaining region will be reassigned to the nearby regions with a simple algorithm.
134
+ The resulting new masks (processed by a simple algorithm) can be visualized in './example1/remove/seg_removed.png', if it is not reasonable, edit using the UI.
135
+
136
+ ```
137
+ sh ./scripts/sdxl/run_move_resize.sh
138
+ ```
139
+
140
+ #### 3.4 General editing parameters
141
+ - We partition the image into three regions as shown below. Regions with the hard mask are frozen, regions with the active mask are generated with diffusion model, and regions with soft mask keep the original content in the first "strength*N" sampling steps.
142
+ <p align="center">
143
+ <img src="assets/mask_def.png" height=200>
144
+ </p>
145
+
146
+ - During editing, if you use an edited segmentation that is different from finetuning, add --load_edited_mask; For mask-based and remove, if you edit the masks automatically processed by the algorithm as mentioned, add --load_edited_processed_mask.
147
+
148
+ ### Cite
149
+ If you find D-Edit useful for your research and applications, please cite us using this BibTeX:
150
+
151
+ ```bibtex
152
+ @article{feng2024dedit,
153
+ title={An item is Worth a Prompt: Versatile Image Editing with Disentangled Control},
154
+ author={Aosong Feng, Weikang Qiu, Jinbin Bai, Kaicheng Zhou, Zhen Dong, Xiao Zhang, Rex Ying, and Leandros Tassiulas},
155
+ journal={arXiv preprint arXiv:2403.04880},
156
+ year={2024}
157
+ }
158
+ ```
app.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import copy
4
+ from PIL import Image
5
+ import matplotlib
6
+ import numpy as np
7
+ import gradio as gr
8
+ from utils import load_mask, load_mask_edit
9
+ from utils_mask import process_mask_to_follow_priority, mask_union, visualize_mask_list_clean
10
+ from pathlib import Path
11
+ import subprocess
12
+ from PIL import Image
13
+
14
+ LENGTH=512 #length of the square area displaying/editing images
15
+ TRANSPARENCY = 150 # transparency of the mask in display
16
+
17
+ def add_mask(mask_np_list_updated, mask_label_list):
18
+ mask_new = np.zeros_like(mask_np_list_updated[0])
19
+ mask_np_list_updated.append(mask_new)
20
+ mask_label_list.append("new")
21
+ return mask_np_list_updated, mask_label_list
22
+
23
+ def create_segmentation(mask_np_list):
24
+ viridis = matplotlib.pyplot.get_cmap(name = 'viridis', lut = len(mask_np_list))
25
+ segmentation = 0
26
+ for i, m in enumerate(mask_np_list):
27
+ color = matplotlib.colors.to_rgb(viridis(i))
28
+ color_mat = np.ones_like(m)
29
+ color_mat = np.stack([color_mat*color[0], color_mat*color[1],color_mat*color[2] ], axis = 2)
30
+ color_mat = color_mat * m[:,:,np.newaxis]
31
+ segmentation += color_mat
32
+ segmentation = Image.fromarray(np.uint8(segmentation*255))
33
+ return segmentation
34
+
35
+ def load_mask_ui(input_folder,load_edit = False):
36
+ if not load_edit:
37
+ mask_list, mask_label_list = load_mask(input_folder)
38
+ else:
39
+ mask_list, mask_label_list = load_mask_edit(input_folder)
40
+
41
+ mask_np_list = []
42
+ for m in mask_list:
43
+ mask_np_list. append( m.cpu().numpy())
44
+
45
+ return mask_np_list, mask_label_list
46
+
47
+ def load_image_ui(input_folder, load_edit):
48
+ try:
49
+ for img_path in Path(input_folder).iterdir():
50
+ if img_path.name in ["img.png", "img_1024.png", "img_512.png"]:
51
+ image = Image.open(img_path)
52
+ mask_np_list, mask_label_list = load_mask_ui(input_folder, load_edit = load_edit)
53
+ image = image.convert('RGB')
54
+ segmentation = create_segmentation(mask_np_list)
55
+ return image, segmentation, mask_np_list, mask_label_list, image
56
+ except:
57
+ print("Image folder invalid: The folder should contain image.png")
58
+ return None, None, None, None, None
59
+
60
+ def run_segmentation(input_folder):
61
+ subprocess.run(["python", "segment.py" , "--name={}".format(input_folder)])
62
+ return
63
+
64
+
65
+
66
+ def run_edit_text(
67
+ input_folder,
68
+ num_tokens,
69
+ num_sampling_steps,
70
+ strength,
71
+ edge_thickness,
72
+ tgt_prompt,
73
+ tgt_idx,
74
+ guidance_scale
75
+ ):
76
+ subprocess.run(["python",
77
+ "main.py" ,
78
+ "--text",
79
+ "--name={}".format(input_folder),
80
+ "--dpm={}".format("sd"),
81
+ "--resolution={}".format(512),
82
+ "--load_trained",
83
+ "--num_tokens={}".format(num_tokens),
84
+ "--seed={}".format(2024),
85
+ "--guidance_scale={}".format(guidance_scale),
86
+ "--num_sampling_step={}".format(num_sampling_steps),
87
+ "--strength={}".format(strength),
88
+ "--edge_thickness={}".format(edge_thickness),
89
+ "--num_imgs={}".format(2),
90
+ "--tgt_prompt={}".format(tgt_prompt) ,
91
+ "--tgt_index={}".format(tgt_idx)
92
+ ])
93
+
94
+ return Image.open(os.path.join(input_folder, "text", "out_text_0.png"))
95
+
96
+
97
+ def run_optimization(
98
+ input_folder,
99
+ num_tokens,
100
+ embedding_learning_rate,
101
+ max_emb_train_steps,
102
+ diffusion_model_learning_rate,
103
+ max_diffusion_train_steps,
104
+ train_batch_size,
105
+ gradient_accumulation_steps
106
+ ):
107
+ subprocess.run(["python",
108
+ "main.py" ,
109
+ "--name={}".format(input_folder),
110
+ "--dpm={}".format("sd"),
111
+ "--resolution={}".format(512),
112
+ "--num_tokens={}".format(num_tokens),
113
+ "--embedding_learning_rate={}".format(embedding_learning_rate),
114
+ "--diffusion_model_learning_rate={}".format(diffusion_model_learning_rate),
115
+ "--max_emb_train_steps={}".format(max_emb_train_steps),
116
+ "--max_diffusion_train_steps={}".format(max_diffusion_train_steps),
117
+ "--train_batch_size={}".format(train_batch_size),
118
+ "--gradient_accumulation_steps={}".format(gradient_accumulation_steps)
119
+
120
+ ])
121
+ return
122
+
123
+
124
+ def transparent_paste_with_mask(backimg, foreimg, mask_np,transparency = 128):
125
+ backimg_solid_np = np.array(backimg)
126
+ bimg = backimg.copy()
127
+ fimg = foreimg.copy()
128
+ fimg.putalpha(transparency)
129
+ bimg.paste(fimg, (0,0), fimg)
130
+
131
+ bimg_np = np.array(bimg)
132
+ mask_np = mask_np[:,:,np.newaxis]
133
+ try:
134
+ new_img_np = bimg_np*mask_np + (1-mask_np)* backimg_solid_np
135
+ return Image.fromarray(new_img_np)
136
+ except:
137
+ import pdb; pdb.set_trace()
138
+
139
+ def show_segmentation(image, segmentation, flag):
140
+ if flag is False:
141
+ flag = True
142
+ mask_np = np.ones([image.size[0],image.size[1]]).astype(np.uint8)
143
+ image_edit = transparent_paste_with_mask(image, segmentation, mask_np ,transparency = TRANSPARENCY)
144
+ return image_edit, flag
145
+ else:
146
+ flag = False
147
+ return image,flag
148
+
149
+ def edit_mask_add(canvas, image, idx, mask_np_list):
150
+ mask_sel = mask_np_list[idx]
151
+ mask_new = np.uint8(canvas["mask"][:, :, 0]/ 255.)
152
+ mask_np_list_updated = []
153
+ for midx, m in enumerate(mask_np_list):
154
+ if midx == idx:
155
+ mask_np_list_updated.append(mask_union(mask_sel, mask_new))
156
+ else:
157
+ mask_np_list_updated.append(m)
158
+
159
+ priority_list = [0 for _ in range(len(mask_np_list_updated))]
160
+ priority_list[idx] = 1
161
+ mask_np_list_updated = process_mask_to_follow_priority(mask_np_list_updated, priority_list)
162
+ mask_ones = np.ones([mask_sel.shape[0], mask_sel.shape[1]]).astype(np.uint8)
163
+ segmentation = create_segmentation(mask_np_list_updated)
164
+ image_edit = transparent_paste_with_mask(image, segmentation, mask_ones ,transparency = TRANSPARENCY)
165
+ return mask_np_list_updated, image_edit
166
+
167
+ def slider_release(index, image, mask_np_list_updated, mask_label_list):
168
+ if index > len(mask_np_list_updated):
169
+ return image, "out of range"
170
+ else:
171
+ mask_np = mask_np_list_updated[index]
172
+ mask_label = mask_label_list[index]
173
+ segmentation = create_segmentation(mask_np_list_updated)
174
+ new_image = transparent_paste_with_mask(image, segmentation, mask_np, transparency = TRANSPARENCY)
175
+ return new_image, mask_label
176
+
177
+ def save_as_orig_mask(mask_np_list_updated, mask_label_list, input_folder):
178
+ try:
179
+ assert np.all(sum(mask_np_list_updated)==1)
180
+ except:
181
+ print("please check mask")
182
+ # plt.imsave( "out_mask.png", mask_list_edit[0])
183
+ import pdb; pdb.set_trace()
184
+
185
+ for midx, (mask, mask_label) in enumerate(zip(mask_np_list_updated, mask_label_list)):
186
+ # np.save(os.path.join(input_folder, "maskEDIT{}_{}.npy".format(midx, mask_label)),mask )
187
+ np.save(os.path.join(input_folder, "mask{}_{}.npy".format(midx, mask_label)),mask )
188
+ savepath = os.path.join(input_folder, "seg_current.png")
189
+ visualize_mask_list_clean(mask_np_list_updated, savepath)
190
+
191
+ def save_as_edit_mask(mask_np_list_updated, mask_label_list, input_folder):
192
+ try:
193
+ assert np.all(sum(mask_np_list_updated)==1)
194
+ except:
195
+ print("please check mask")
196
+ # plt.imsave( "out_mask.png", mask_list_edit[0])
197
+ import pdb; pdb.set_trace()
198
+ for midx, (mask, mask_label) in enumerate(zip(mask_np_list_updated, mask_label_list)):
199
+ np.save(os.path.join(input_folder, "maskEdited{}_{}.npy".format(midx, mask_label)), mask)
200
+ savepath = os.path.join(input_folder, "seg_edited.png")
201
+ visualize_mask_list_clean(mask_np_list_updated, savepath)
202
+
203
+ with gr.Blocks() as demo:
204
+ image = gr.State() # store mask
205
+ image_loaded = gr.State()
206
+ segmentation = gr.State()
207
+
208
+ mask_np_list = gr.State([])
209
+ mask_label_list = gr.State([])
210
+ mask_np_list_updated = gr.State([])
211
+ true = gr.State(True)
212
+ false = gr.State(False)
213
+
214
+
215
+ with gr.Row():
216
+ gr.Markdown("""# D-Edit""")
217
+
218
+ with gr.Tab(label="1 Edit mask"):
219
+ with gr.Row():
220
+ with gr.Column():
221
+ canvas = gr.Image(value = None, type="numpy", tool="sketch", label="Draw Mask", show_label=True, height=LENGTH, width=LENGTH, interactive=True)
222
+ input_folder = gr.Textbox(value="example1", label="input folder", interactive= True, )
223
+
224
+ segment_button = gr.Button("1.1 Run segmentation")
225
+ segment_button.click(run_segmentation,
226
+ [input_folder] ,
227
+ [] )
228
+
229
+
230
+ text_button = gr.Button("1.2 Load original masks")
231
+ text_button.click(load_image_ui,
232
+ [input_folder, false] ,
233
+ [image_loaded, segmentation, mask_np_list, mask_label_list, canvas] )
234
+
235
+ load_edit_button = gr.Button("1.2 Load edited masks")
236
+ load_edit_button.click(load_image_ui,
237
+ [input_folder, true] ,
238
+ [image_loaded, segmentation, mask_np_list, mask_label_list, canvas] )
239
+
240
+ show_segment = gr.Checkbox(label = "Show Segmentation")
241
+
242
+ flag = gr.State(False)
243
+ show_segment.select(show_segmentation,
244
+ [image_loaded, segmentation, flag],
245
+ [canvas, flag])
246
+
247
+ mask_np_list_updated = copy.deepcopy(mask_np_list)
248
+
249
+ with gr.Column():
250
+ gr.Markdown("""<p style="text-align: center; font-size: 20px">Draw Mask</p>""")
251
+ slider = gr.Slider(0, 20, step=1, interactive=True)
252
+ label = gr.Textbox()
253
+ slider.release(slider_release,
254
+ inputs = [slider, image_loaded, mask_np_list_updated, mask_label_list],
255
+ outputs= [canvas, label]
256
+ )
257
+ add_button = gr.Button("Add")
258
+ add_button.click( edit_mask_add,
259
+ [canvas, image_loaded, slider, mask_np_list_updated] ,
260
+ [mask_np_list_updated, canvas]
261
+ )
262
+
263
+ save_button2 = gr.Button("Set and Save as edited masks")
264
+ save_button2.click( save_as_edit_mask,
265
+ [mask_np_list_updated, mask_label_list, input_folder] ,
266
+ [] )
267
+
268
+ save_button = gr.Button("Set and Save as original masks")
269
+ save_button.click( save_as_orig_mask,
270
+ [mask_np_list_updated, mask_label_list, input_folder] ,
271
+ [] )
272
+
273
+ back_button = gr.Button("Back to current seg")
274
+ back_button.click( load_mask_ui,
275
+ [input_folder] ,
276
+ [ mask_np_list_updated,mask_label_list] )
277
+
278
+ add_mask_button = gr.Button("Add new empty mask")
279
+ add_mask_button.click(add_mask,
280
+ [mask_np_list_updated, mask_label_list] ,
281
+ [mask_np_list_updated, mask_label_list] )
282
+
283
+ with gr.Tab(label="2 Optimization"):
284
+ with gr.Row():
285
+ with gr.Column():
286
+ canvas_opt = gr.Image(value = canvas.value, type="pil", tool="sketch", label="Loaded Image", show_label=True, height=LENGTH, width=LENGTH, interactive=True)
287
+
288
+ with gr.Column():
289
+ gr.Markdown("""<p style="text-align: center; font-size: 20px">Optimization settings (SD)</p>""")
290
+ num_tokens = gr.Textbox(value="5", label="num tokens to represent each object", interactive= True)
291
+ embedding_learning_rate = gr.Textbox(value="1e-4", label="Embedding optimization: Learning rate", interactive= True )
292
+ max_emb_train_steps = gr.Textbox(value="500", label="embedding optimization: Training steps", interactive= True )
293
+
294
+ diffusion_model_learning_rate = gr.Textbox(value="5e-5", label="UNet Optimization: Learning rate", interactive= True )
295
+ max_diffusion_train_steps = gr.Textbox(value="500", label="UNet Optimization: Learning rate: Training steps", interactive= True )
296
+
297
+ train_batch_size = gr.Textbox(value="5", label="Batch size", interactive= True )
298
+ gradient_accumulation_steps=gr.Textbox(value="5", label="Gradient accumulation", interactive= True )
299
+
300
+ add_button = gr.Button("Run optimization")
301
+ add_button.click(run_optimization,
302
+ inputs = [
303
+ input_folder,
304
+ num_tokens,
305
+ embedding_learning_rate,
306
+ max_emb_train_steps,
307
+ diffusion_model_learning_rate,
308
+ max_diffusion_train_steps,
309
+ train_batch_size,gradient_accumulation_steps
310
+ ],
311
+ outputs = []
312
+ )
313
+
314
+
315
+ with gr.Tab(label="3 Editing"):
316
+ with gr.Tab(label="3.1 Text-based editing"):
317
+ canvas_text_edit = gr.State() # store mask
318
+ with gr.Row():
319
+ with gr.Column():
320
+ canvas_text_edit = gr.Image(value = None, label="Editing results", show_label=True, height=LENGTH, width=LENGTH)
321
+ # canvas_text_edit = gr.Gallery(label = "Edited results")
322
+
323
+ with gr.Column():
324
+ gr.Markdown("""<p style="text-align: center; font-size: 20px">Editing setting (SD)</p>""")
325
+
326
+ tgt_prompt = gr.Textbox(value="Dog", label="Editing: Text prompt", interactive= True )
327
+ tgt_idx = gr.Textbox(value="0", label="Editing: Object index", interactive= True )
328
+ guidance_scale = gr.Textbox(value="6", label="Editing: CFG guidance scale", interactive= True )
329
+ num_sampling_steps = gr.Textbox(value="50", label="Editing: Sampling steps", interactive= True )
330
+ edge_thickness = gr.Textbox(value="10", label="Editing: Edge thickness", interactive= True )
331
+ strength = gr.Textbox(value="0.5", label="Editing: Mask strength", interactive= True )
332
+
333
+ add_button = gr.Button("Run Editing")
334
+ add_button.click(run_edit_text,
335
+ inputs = [
336
+ input_folder,
337
+ num_tokens,
338
+ num_sampling_steps,
339
+ strength,
340
+ edge_thickness,
341
+ tgt_prompt,
342
+ tgt_idx,
343
+ guidance_scale
344
+ ],
345
+ outputs = [canvas_text_edit]
346
+ )
347
+
348
+
349
+ demo.queue().launch(share=True, debug=True)
assets/demo1.gif ADDED
assets/demo2.gif ADDED
assets/demo3.gif ADDED
assets/demo4.gif ADDED
assets/mask_def.png ADDED
controller.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import math
4
+ import xformers
5
+
6
+ class DummyController:
7
+ def __call__(self, *args):
8
+ return args[0]
9
+ def __init__(self):
10
+ self.num_att_layers = 0
11
+
12
+ class GroupedCAController:
13
+ def __init__(self, mask_list = None):
14
+ self.mask_list = mask_list
15
+ if self.mask_list is None:
16
+ self.is_decom = False
17
+ else:
18
+ self.is_decom = True
19
+
20
+ def mask_img_to_mask_vec(self, mask, length):
21
+ mask_vec = torch.nn.functional.interpolate(mask.unsqueeze(0).unsqueeze(0), (length, length)).squeeze()
22
+ mask_vec = mask_vec.flatten()
23
+ return mask_vec
24
+
25
+ def ca_forward_decom(self, q, k_list, v_list, scale, place_in_unet):
26
+ # attn [Bh, N, d ]
27
+ # [8, 4096, 77]
28
+ # q [Bh, N, d] [8, 4096, 40] [8, 1024, 80] [8, 256,160] [8, 64, 160]
29
+ # k [Bh, P, d] [8, 77 , 40] [8, 77, 80] [8, 77, 160] [8, 77, 160]
30
+ # v [Bh, P, d] [8, 77 , 40] [8, 77, 80] [8, 77, 160] [8, 77, 160]
31
+ N = q.shape[1]
32
+ mask_vec_list = []
33
+ for mask in self.mask_list:
34
+ mask_vec = self.mask_img_to_mask_vec(mask, int(math.sqrt(N))) # [1,N,1]
35
+ mask_vec = mask_vec.unsqueeze(0).unsqueeze(-1)
36
+ mask_vec_list.append(mask_vec)
37
+ out = 0
38
+ for mask_vec, k, v in zip(mask_vec_list, k_list, v_list):
39
+ sim = torch.einsum("b i d, b j d -> b i j", q, k) * scale # [8, 4096, 20]
40
+ attn = sim.softmax(dim=-1) # [Bh,N,P] [8,4096,20]
41
+ attn = attn.masked_fill(mask_vec==0, 0)
42
+ masked_out = torch.einsum("b i j, b j d -> b i d", attn, v) # [Bh,N,d] [8,4096,320/h]
43
+ # mask_vec_inf = torch.where(mask_vec>0, 0, torch.finfo(k.dtype).min)
44
+ # masked_out1 = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask_vec_inf, op=None, scale=scale)
45
+ out += masked_out
46
+ return out
47
+
48
+ def reshape_heads_to_batch_dim(self):
49
+ def func(tensor):
50
+ batch_size, seq_len, dim = tensor.shape
51
+ head_size = self.num_heads
52
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
53
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
54
+ return func
55
+
56
+ def reshape_batch_dim_to_heads(self):
57
+ def func(tensor):
58
+ batch_size, seq_len, dim = tensor.shape
59
+ head_size = self.num_heads
60
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
61
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
62
+ return func
63
+
64
+ def register_attention_disentangled_control(unet, controller):
65
+ def ca_forward(self, place_in_unet):
66
+ to_out = self.to_out
67
+ if type(to_out) is torch.nn.modules.container.ModuleList:
68
+ to_out = self.to_out[0]
69
+ else:
70
+ to_out = self.to_out
71
+ def forward(x, encoder_hidden_states =None, attention_mask=None):
72
+ if isinstance(controller, DummyController): # SA CA full
73
+ q = self.to_q(x)
74
+ is_cross = encoder_hidden_states is not None
75
+ encoder_hidden_states = encoder_hidden_states if is_cross else x
76
+ k = self.to_k(encoder_hidden_states)
77
+ v = self.to_v(encoder_hidden_states)
78
+ q = self.head_to_batch_dim(q)
79
+ k = self.head_to_batch_dim(k)
80
+ v = self.head_to_batch_dim(v)
81
+
82
+ # sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
83
+ # attn = sim.softmax(dim=-1)
84
+ # attn = controller(attn, is_cross, place_in_unet)
85
+ # out = torch.einsum("b i j, b j d -> b i d", attn, v)
86
+ out = xformers.ops.memory_efficient_attention(
87
+ q, k, v, attn_bias=None, op=None, scale=self.scale
88
+ )
89
+ out = self.batch_to_head_dim(out)
90
+ else: # decom: CA+SA
91
+ is_cross = encoder_hidden_states is not None
92
+ assert is_cross is not None
93
+ encoder_hidden_states_list = encoder_hidden_states if is_cross else x
94
+ q = self.to_q(x)
95
+ q = self.head_to_batch_dim(q) # [Bh, 4096, 320/h ] h: 8
96
+ if is_cross: #CA
97
+ k_list = []
98
+ v_list = []
99
+ assert type(encoder_hidden_states_list) is list
100
+ for encoder_hidden_states in encoder_hidden_states_list:
101
+ k = self.to_k(encoder_hidden_states)
102
+ k = self.head_to_batch_dim(k) # [Bh, 77, 320/h ]
103
+ k_list.append(k)
104
+ v = self.to_v(encoder_hidden_states)
105
+ v = self.head_to_batch_dim(v) # [Bh, 77, 320/h ]
106
+ v_list.append(v)
107
+ out = controller.ca_forward_decom(q, k_list, v_list, self.scale, place_in_unet) # [Bh,N,d]
108
+ out = self.batch_to_head_dim(out)
109
+ else: # SA
110
+ exit("decomposing SA!")
111
+ k = self.to_k(x)
112
+ v = self.to_v(x)
113
+ k = self.head_to_batch_dim(k) # [Bh, 77, 320/h ]
114
+ v = self.head_to_batch_dim(v) # [Bh, 77, 320/h ]
115
+ import pdb; pdb.set_trace()
116
+ if k.shape[1] <= 1024 ** 2:
117
+ out = controller.sa_forward(q, k, v, self.scale, place_in_unet) # [Bh,N,d]
118
+ else:
119
+ print("warining")
120
+ out = controller.sa_forward_decom(q, k, v, self.scale, place_in_unet) # [Bh,N,d]
121
+ # sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
122
+ # attn = sim.softmax(dim=-1) # [8,4096,4096] [Bh,N,N]
123
+ # out = torch.einsum("b i j, b j d -> b i d", attn, v) # [Bh,N,d] [8,4096,320/h]
124
+
125
+ out = self.batch_to_head_dim(out) # [B, H, N, D]
126
+
127
+ return to_out(out)
128
+
129
+ return forward
130
+
131
+ if controller is None:
132
+ controller = DummyController()
133
+
134
+ def register_recr(net_, count, place_in_unet):
135
+ if net_.__class__.__name__ == 'Attention' and net_.to_k.in_features == unet.ca_dim:
136
+ net_.forward = ca_forward(net_, place_in_unet)
137
+ return count + 1
138
+ elif hasattr(net_, 'children'):
139
+ for net__ in net_.children():
140
+ count = register_recr(net__, count, place_in_unet)
141
+ return count
142
+
143
+ cross_att_count = 0
144
+ sub_nets = unet.named_children()
145
+
146
+ for net in sub_nets:
147
+ if "down" in net[0]:
148
+ down_count = register_recr(net[1], 0, "down")#6
149
+ cross_att_count += down_count
150
+ elif "up" in net[0]:
151
+ up_count = register_recr(net[1], 0, "up") #9
152
+ cross_att_count += up_count
153
+ elif "mid" in net[0]:
154
+ mid_count = register_recr(net[1], 0, "mid") #1
155
+ cross_att_count += mid_count
156
+ controller.num_att_layers = cross_att_count
example2/img.png ADDED
main.py ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import argparse
5
+ from peft import LoraConfig
6
+ from pipeline_dedit_sdxl import DEditSDXLPipeline
7
+ from pipeline_dedit_sd import DEditSDPipeline
8
+ from utils import load_image, load_mask, load_mask_edit
9
+ from utils_mask import process_mask_move_torch, process_mask_remove_torch, mask_union_torch, mask_substract_torch, create_outer_edge_mask_torch
10
+ from utils_mask import check_mask_overlap_torch, check_cover_all_torch, visualize_mask_list, get_mask_difference_torch, save_mask_list_to_npys
11
+
12
+ parser = argparse.ArgumentParser()
13
+ parser.add_argument("--name", type=str,required=True, default=None)
14
+ parser.add_argument("--name_2", type=str,required=False, default=None)
15
+ parser.add_argument("--dpm", type=str,required=True, default="sd")
16
+ parser.add_argument("--resolution", type=int, default=1024)
17
+ parser.add_argument("--seed", type=int, default=42)
18
+ parser.add_argument("--embedding_learning_rate", type=float, default=1e-4)
19
+ parser.add_argument("--max_emb_train_steps", type=int, default=200)
20
+ parser.add_argument("--diffusion_model_learning_rate", type=float, default=5e-5)
21
+ parser.add_argument("--max_diffusion_train_steps", type=int, default=200)
22
+ parser.add_argument("--train_batch_size", type=int, default=1)
23
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
24
+ parser.add_argument("--num_tokens", type=int, default=1)
25
+
26
+
27
+ parser.add_argument("--load_trained", default=False, action="store_true" )
28
+ parser.add_argument("--num_sampling_steps", type=int, default=50)
29
+ parser.add_argument("--guidance_scale", type=float, default = 3 )
30
+ parser.add_argument("--strength", type=float, default=0.8)
31
+
32
+ parser.add_argument("--train_full_lora", default=False, action="store_true" )
33
+ parser.add_argument("--lora_rank", type=int, default=4)
34
+ parser.add_argument("--lora_alpha", type=int, default=4)
35
+
36
+ parser.add_argument("--prompt_auxin_list", nargs="+", type=str, default = None)
37
+ parser.add_argument("--prompt_auxin_idx_list", nargs="+", type=int, default = None)
38
+
39
+ # general editing configs
40
+ parser.add_argument("--load_edited_mask", default=False, action="store_true")
41
+ parser.add_argument("--load_edited_processed_mask", default=False, action="store_true")
42
+ parser.add_argument("--edge_thickness", type=int, default=20)
43
+ parser.add_argument("--num_imgs", type=int, default = 1 )
44
+ parser.add_argument('--active_mask_list', nargs="+", type=int)
45
+ parser.add_argument("--tgt_index", type=int, default=None)
46
+
47
+ # recon
48
+ parser.add_argument("--recon", default=False, action="store_true" )
49
+ parser.add_argument("--recon_an_item", default=False, action="store_true" )
50
+ parser.add_argument("--recon_prompt", type=str, default=None)
51
+
52
+ # text-based editing
53
+ parser.add_argument("--text", default=False, action="store_true")
54
+ parser.add_argument("--tgt_prompt", type=str, default=None)
55
+
56
+ # image-based editing
57
+ parser.add_argument("--image", default=False, action="store_true" )
58
+ parser.add_argument("--src_index", type=int, default=None)
59
+ parser.add_argument("--tgt_name", type=str, default=None)
60
+
61
+ # mask-based move
62
+ parser.add_argument("--move_resize", default=False, action="store_true" )
63
+ parser.add_argument('--tgt_indices_list', nargs="+", type=int)
64
+ parser.add_argument("--delta_x_list", nargs="+", type=int)
65
+ parser.add_argument("--delta_y_list", nargs="+", type=int)
66
+ parser.add_argument("--priority_list", nargs="+", type=int)
67
+ parser.add_argument("--force_mask_remain", type=int, default=None)
68
+ parser.add_argument("--resize_list", nargs="+", type=float)
69
+
70
+ # remove
71
+ parser.add_argument("--remove", default=False, action="store_true" )
72
+ parser.add_argument("--load_edited_removemask", default=False, action="store_true")
73
+
74
+ args = parser.parse_args()
75
+
76
+ torch.cuda.manual_seed_all(args.seed)
77
+ torch.manual_seed(args.seed)
78
+ base_input_folder = "."
79
+ base_output_folder = "."
80
+
81
+ input_folder = os.path.join(base_input_folder, args.name)
82
+
83
+
84
+ mask_list, mask_label_list = load_mask(input_folder)
85
+ assert mask_list[0].shape[0] == args.resolution, "Segmentation should be done on size {}".format(args.resolution)
86
+ try:
87
+ image_gt = load_image(os.path.join(input_folder, "img_{}.png".format(args.resolution) ), size = args.resolution)
88
+ except:
89
+ image_gt = load_image(os.path.join(input_folder, "img_{}.jpg".format(args.resolution) ), size = args.resolution)
90
+
91
+ if args.image:
92
+ input_folder_2 = os.path.join(base_input_folder, args.name_2)
93
+ mask_list_2, mask_label_list_2 = load_mask(input_folder_2)
94
+ assert mask_list_2[0].shape[0] == args.resolution, "Segmentation should be done on size {}".format(args.resolution)
95
+ try:
96
+ image_gt_2 = load_image(os.path.join(input_folder_2, "img_{}.png".format(args.resolution) ), size = args.resolution)
97
+ except:
98
+ image_gt_2 = load_image(os.path.join(input_folder_2, "img_{}.jpg".format(args.resolution) ), size = args.resolution)
99
+ output_dir = os.path.join(base_output_folder, args.name + "_" + args.name_2)
100
+ os.makedirs(output_dir, exist_ok = True)
101
+ else:
102
+ output_dir = os.path.join(base_output_folder, args.name)
103
+ os.makedirs(output_dir, exist_ok = True)
104
+
105
+ if args.dpm == "sd":
106
+ if args.image:
107
+ pipe = DEditSDPipeline(mask_list, mask_label_list, mask_list_2, mask_label_list_2, resolution = args.resolution, num_tokens = args.num_tokens)
108
+ else:
109
+ pipe = DEditSDPipeline(mask_list, mask_label_list, resolution = args.resolution, num_tokens = args.num_tokens)
110
+
111
+ elif args.dpm == "sdxl":
112
+ if args.image:
113
+ pipe = DEditSDXLPipeline(mask_list, mask_label_list, mask_list_2, mask_label_list_2, resolution = args.resolution, num_tokens = args.num_tokens)
114
+ else:
115
+ pipe = DEditSDXLPipeline(mask_list, mask_label_list, resolution = args.resolution, num_tokens = args.num_tokens)
116
+
117
+ else:
118
+ raise NotImplementedError
119
+
120
+ set_string_list = pipe.set_string_list
121
+ if args.prompt_auxin_list is not None:
122
+ for auxin_idx, auxin_prompt in zip(args.prompt_auxin_idx_list, args.prompt_auxin_list):
123
+ set_string_list[auxin_idx] = auxin_prompt.replace("*", set_string_list[auxin_idx] )
124
+ print(set_string_list)
125
+
126
+ if args.image:
127
+ set_string_list_2 = pipe.set_string_list_2
128
+ print(set_string_list_2)
129
+
130
+ if args.load_trained:
131
+ unet_save_path = os.path.join(output_dir, "unet.pt")
132
+ unet_state_dict = torch.load(unet_save_path)
133
+ text_encoder1_save_path = os.path.join(output_dir, "text_encoder1.pt")
134
+ text_encoder1_state_dict = torch.load(text_encoder1_save_path)
135
+ if args.dpm == "sdxl":
136
+ text_encoder2_save_path = os.path.join(output_dir, "text_encoder2.pt")
137
+ text_encoder2_state_dict = torch.load(text_encoder2_save_path)
138
+
139
+ if 'lora' in ''.join(unet_state_dict.keys()):
140
+ unet_lora_config = LoraConfig(
141
+ r=args.lora_rank,
142
+ lora_alpha=args.lora_alpha,
143
+ init_lora_weights="gaussian",
144
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
145
+ )
146
+ pipe.unet.add_adapter(unet_lora_config)
147
+
148
+ pipe.unet.load_state_dict(unet_state_dict)
149
+ pipe.text_encoder.load_state_dict(text_encoder1_state_dict)
150
+ if args.dpm == "sdxl":
151
+ pipe.text_encoder_2.load_state_dict(text_encoder2_state_dict)
152
+ else:
153
+ if args.image:
154
+ pipe.mask_list = [m.cuda() for m in pipe.mask_list]
155
+ pipe.mask_list_2 = [m.cuda() for m in pipe.mask_list_2]
156
+ pipe.train_emb_2imgs(
157
+ image_gt,
158
+ image_gt_2,
159
+ set_string_list,
160
+ set_string_list_2,
161
+ gradient_accumulation_steps = args.gradient_accumulation_steps,
162
+ embedding_learning_rate = args.embedding_learning_rate,
163
+ max_emb_train_steps = args.max_emb_train_steps,
164
+ train_batch_size = args.train_batch_size,
165
+ )
166
+
167
+ pipe.train_model_2imgs(
168
+ image_gt,
169
+ image_gt_2,
170
+ set_string_list,
171
+ set_string_list_2,
172
+ gradient_accumulation_steps = args.gradient_accumulation_steps,
173
+ max_diffusion_train_steps = args.max_diffusion_train_steps,
174
+ diffusion_model_learning_rate = args.diffusion_model_learning_rate ,
175
+ train_batch_size =args.train_batch_size,
176
+ train_full_lora = args.train_full_lora,
177
+ lora_rank = args.lora_rank,
178
+ lora_alpha = args.lora_alpha
179
+ )
180
+
181
+ else:
182
+ pipe.mask_list = [m.cuda() for m in pipe.mask_list]
183
+ pipe.train_emb(
184
+ image_gt,
185
+ set_string_list,
186
+ gradient_accumulation_steps = args.gradient_accumulation_steps,
187
+ embedding_learning_rate = args.embedding_learning_rate,
188
+ max_emb_train_steps = args.max_emb_train_steps,
189
+ train_batch_size = args.train_batch_size,
190
+ )
191
+
192
+ pipe.train_model(
193
+ image_gt,
194
+ set_string_list,
195
+ gradient_accumulation_steps = args.gradient_accumulation_steps,
196
+ max_diffusion_train_steps = args.max_diffusion_train_steps,
197
+ diffusion_model_learning_rate = args.diffusion_model_learning_rate ,
198
+ train_batch_size = args.train_batch_size,
199
+ train_full_lora = args.train_full_lora,
200
+ lora_rank = args.lora_rank,
201
+ lora_alpha = args.lora_alpha
202
+ )
203
+
204
+
205
+ unet_save_path = os.path.join(output_dir, "unet.pt")
206
+ torch.save(pipe.unet.state_dict(),unet_save_path )
207
+ text_encoder1_save_path = os.path.join(output_dir, "text_encoder1.pt")
208
+ torch.save(pipe.text_encoder.state_dict(), text_encoder1_save_path)
209
+ if args.dpm == "sdxl":
210
+ text_encoder2_save_path = os.path.join(output_dir, "text_encoder2.pt")
211
+ torch.save(pipe.text_encoder_2.state_dict(), text_encoder2_save_path )
212
+
213
+
214
+ if args.recon:
215
+ output_dir = os.path.join(output_dir, "recon")
216
+ os.makedirs(output_dir, exist_ok = True)
217
+ if args.recon_an_item:
218
+ mask_list = [torch.from_numpy(np.ones_like(mask_list[0].numpy()))]
219
+ tgt_string = set_string_list[args.tgt_index]
220
+ tgt_string = args.recon_prompt.replace("*", tgt_string)
221
+ set_string_list = [tgt_string]
222
+ print(set_string_list)
223
+ save_path = os.path.join(output_dir, "out_recon.png")
224
+ x_np = pipe.inference_with_mask(
225
+ save_path,
226
+ guidance_scale = args.guidance_scale,
227
+ num_sampling_steps = args.num_sampling_steps,
228
+ seed = args.seed,
229
+ num_imgs = args.num_imgs,
230
+ set_string_list = set_string_list,
231
+ mask_list = mask_list
232
+ )
233
+
234
+ if args.text:
235
+ print("Text-guided editing ")
236
+ output_dir = os.path.join(output_dir, "text")
237
+ os.makedirs(output_dir, exist_ok = True)
238
+ save_path = os.path.join(output_dir, "out_text.png")
239
+ set_string_list[args.tgt_index] = args.tgt_prompt
240
+ mask_active = torch.zeros_like(mask_list[0])
241
+ mask_active = mask_union_torch(mask_active, mask_list[args.tgt_index])
242
+
243
+ if args.active_mask_list is not None:
244
+ for midx in args.active_mask_list:
245
+ mask_active = mask_union_torch(mask_active, mask_list[midx])
246
+
247
+ if args.load_edited_mask:
248
+ mask_list_edited, mask_label_list_edited = load_mask_edit(input_folder)
249
+ mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
250
+ mask_active = mask_union_torch(mask_active, mask_diff)
251
+ mask_list = mask_list_edited
252
+ save_path = os.path.join(output_dir, "out_textEdited.png")
253
+
254
+ mask_hard = mask_substract_torch(torch.ones_like(mask_list[0]), mask_active)
255
+ mask_soft = create_outer_edge_mask_torch(mask_active, edge_thickness = args.edge_thickness)
256
+ mask_hard = mask_substract_torch(mask_hard, mask_soft)
257
+
258
+ pipe.inference_with_mask(
259
+ save_path,
260
+ orig_image = image_gt,
261
+ set_string_list = set_string_list,
262
+ guidance_scale = args.guidance_scale,
263
+ strength = args.strength,
264
+ num_imgs = args.num_imgs,
265
+ mask_hard= mask_hard,
266
+ mask_soft = mask_soft,
267
+ mask_list = mask_list,
268
+ seed = args.seed,
269
+ num_sampling_steps = args.num_sampling_steps
270
+ )
271
+
272
+ if args.remove:
273
+ output_dir = os.path.join(output_dir, "remove")
274
+ save_path = os.path.join(output_dir, "out_remove.png")
275
+ os.makedirs(output_dir, exist_ok = True)
276
+ mask_active = torch.zeros_like(mask_list[0])
277
+
278
+ if args.load_edited_mask:
279
+ mask_list_edited, _ = load_mask_edit(input_folder)
280
+ mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
281
+ mask_active = mask_union_torch(mask_active, mask_diff)
282
+ mask_list = mask_list_edited
283
+
284
+ if args.load_edited_processed_mask:
285
+ # manually edit or draw masks after removing one index, then load
286
+ mask_list_processed, _ = load_mask_edit(output_dir)
287
+ mask_remain = get_mask_difference_torch(mask_list_processed, mask_list)
288
+ else:
289
+ # generate masks after removing one index, using nearest neighbor algorithm
290
+ mask_list_processed, mask_remain = process_mask_remove_torch(mask_list, args.tgt_index)
291
+ save_mask_list_to_npys(output_dir, mask_list_processed, mask_label_list, name = "mask")
292
+ visualize_mask_list(mask_list_processed, os.path.join(output_dir, "seg_removed.png"))
293
+ check_cover_all_torch(*mask_list_processed)
294
+ mask_active = mask_union_torch(mask_active, mask_remain)
295
+
296
+ if args.active_mask_list is not None:
297
+ for midx in args.active_mask_list:
298
+ mask_active = mask_union_torch(mask_active, mask_list[midx])
299
+
300
+ mask_hard = 1 - mask_active
301
+ mask_soft = create_outer_edge_mask_torch(mask_remain, edge_thickness = args.edge_thickness)
302
+ mask_hard = mask_substract_torch(mask_hard, mask_soft)
303
+
304
+ pipe.inference_with_mask(
305
+ save_path,
306
+ orig_image = image_gt,
307
+ guidance_scale = args.guidance_scale,
308
+ strength = args.strength,
309
+ num_imgs = args.num_imgs,
310
+ mask_hard= mask_hard,
311
+ mask_soft = mask_soft,
312
+ mask_list = mask_list_processed,
313
+ seed = args.seed,
314
+ num_sampling_steps = args.num_sampling_steps
315
+ )
316
+
317
+ if args.image:
318
+ output_dir = os.path.join(output_dir, "image")
319
+ save_path = os.path.join(output_dir, "out_image.png")
320
+ os.makedirs(output_dir, exist_ok = True)
321
+ mask_active = torch.zeros_like(mask_list[0])
322
+
323
+ if None not in (args.tgt_name, args.src_index, args.tgt_index):
324
+ if args.tgt_name == args.name:
325
+ set_string_list_tgt = set_string_list
326
+ set_string_list_src = set_string_list_2
327
+ image_tgt = image_gt
328
+ if args.load_edited_mask:
329
+ mask_list_edited, _ = load_mask_edit(input_folder)
330
+ mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
331
+ mask_active = mask_union_torch(mask_active, mask_diff)
332
+ mask_list = mask_list_edited
333
+ save_path = os.path.join(output_dir, "out_imageEdited.png")
334
+ mask_list_tgt = mask_list
335
+
336
+ elif args.tgt_name == args.name_2:
337
+ set_string_list_tgt = set_string_list_2
338
+ set_string_list_src = set_string_list
339
+ image_tgt = image_gt_2
340
+ if args.load_edited_mask:
341
+ mask_list_2_edited, _ = load_mask_edit(input_folder_2)
342
+ mask_diff = get_mask_difference_torch(mask_list_2_edited, mask_list_2)
343
+ mask_active = mask_union_torch(mask_active, mask_diff)
344
+ mask_list_2 = mask_list_2_edited
345
+ save_path = os.path.join(output_dir, "out_imageEdited.png")
346
+ mask_list_tgt = mask_list_2
347
+ else:
348
+ exit("tgt_name should be either name or name_2")
349
+
350
+ set_string_list_tgt[args.tgt_index] = set_string_list_src[args.src_index]
351
+
352
+ mask_active = mask_list_tgt[args.tgt_index]
353
+ mask_frozen = (1-mask_active.float()).to(mask_active.device)
354
+ mask_soft = create_outer_edge_mask_torch(mask_active.cpu(), edge_thickness = args.edge_thickness)
355
+ mask_hard = mask_substract_torch(mask_frozen.cpu(), mask_soft.cpu())
356
+
357
+ mask_list_tgt = [m.cuda() for m in mask_list_tgt]
358
+
359
+ pipe.inference_with_mask(
360
+ save_path,
361
+ set_string_list = set_string_list_tgt,
362
+ mask_list = mask_list_tgt,
363
+ guidance_scale = args.guidance_scale,
364
+ num_sampling_steps = args.num_sampling_steps,
365
+ mask_hard = mask_hard.cuda(),
366
+ mask_soft = mask_soft.cuda(),
367
+ num_imgs = args.num_imgs,
368
+ orig_image = image_tgt,
369
+ strength = args.strength,
370
+ )
371
+
372
+ if args.move_resize:
373
+ output_dir = os.path.join(output_dir, "move_resize")
374
+ os.makedirs(output_dir, exist_ok = True)
375
+ save_path = os.path.join(output_dir, "out_moveresize.png")
376
+ mask_active = torch.zeros_like(mask_list[0])
377
+
378
+ if args.load_edited_mask:
379
+ mask_list_edited, _ = load_mask_edit(input_folder)
380
+ mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
381
+ mask_active = mask_union_torch(mask_active, mask_diff)
382
+ mask_list = mask_list_edited
383
+ # save_path = os.path.join(output_dir, "out_moveresizeEdited.png")
384
+
385
+ if args.load_edited_processed_mask:
386
+ mask_list_processed, _ = load_mask_edit(output_dir)
387
+ mask_remain = get_mask_difference_torch(mask_list_processed, mask_list)
388
+ else:
389
+ mask_list_processed, mask_remain = process_mask_move_torch(
390
+ mask_list,
391
+ args.tgt_indices_list,
392
+ args.delta_x_list,
393
+ args.delta_y_list, args.priority_list,
394
+ force_mask_remain = args.force_mask_remain,
395
+ resize_list = args.resize_list
396
+ )
397
+ save_mask_list_to_npys(output_dir, mask_list_processed, mask_label_list, name = "mask")
398
+ visualize_mask_list(mask_list_processed, os.path.join(output_dir, "seg_move_resize.png"))
399
+ active_idxs = args.tgt_indices_list
400
+
401
+ mask_active = mask_union_torch(mask_active, *[m for midx, m in enumerate(mask_list_processed) if midx in active_idxs])
402
+ mask_active = mask_union_torch(mask_remain, mask_active)
403
+ if args.active_mask_list is not None:
404
+ for midx in args.active_mask_list:
405
+ mask_active = mask_union_torch(mask_active, mask_list_processed[midx])
406
+
407
+ mask_frozen =(1 - mask_active.float())
408
+ mask_soft = create_outer_edge_mask_torch(mask_active, edge_thickness = args.edge_thickness)
409
+ mask_hard = mask_substract_torch(mask_frozen, mask_soft)
410
+
411
+ check_mask_overlap_torch(mask_hard, mask_soft)
412
+
413
+ pipe.inference_with_mask(
414
+ save_path,
415
+ strength = args.strength,
416
+ orig_image = image_gt,
417
+ guidance_scale = args.guidance_scale,
418
+ num_sampling_steps = args.num_sampling_steps,
419
+ num_imgs = args.num_imgs,
420
+ mask_hard= mask_hard,
421
+ mask_soft = mask_soft,
422
+ mask_list = mask_list_processed,
423
+ seed = args.seed
424
+ )
pipeline_dedit_sd.py ADDED
@@ -0,0 +1,813 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from utils import import_model_class_from_model_name_or_path
3
+ from transformers import AutoTokenizer
4
+ from diffusers import (
5
+ AutoencoderKL,
6
+ DDPMScheduler,
7
+ DDIMScheduler,
8
+ UNet2DConditionModel,
9
+ )
10
+ from accelerate import Accelerator
11
+ from tqdm.auto import tqdm
12
+ from utils import sd_prepare_input_decom, save_images
13
+ import torch.nn.functional as F
14
+ import itertools
15
+ from peft import LoraConfig
16
+ from controller import GroupedCAController, register_attention_disentangled_control, DummyController
17
+ from utils import image2latent, latent2image
18
+ import matplotlib.pyplot as plt
19
+ from utils_mask import check_mask_overlap_torch
20
+
21
+ device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
22
+
23
+ class DEditSDPipeline:
24
+ def __init__(
25
+ self,
26
+ mask_list,
27
+ mask_label_list,
28
+ mask_list_2 = None,
29
+ mask_label_list_2 = None,
30
+ resolution = 1024,
31
+ num_tokens = 1
32
+ ):
33
+ super().__init__()
34
+ model_id = "./stable-diffusion-v1-5"
35
+ self.model_id = model_id
36
+ self.tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", use_fast=False)
37
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(model_id, subfolder = "text_encoder")
38
+ self.text_encoder = text_encoder_cls_one.from_pretrained(model_id, subfolder="text_encoder" ).to(device)
39
+
40
+ self.unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
41
+ self.unet.ca_dim = 768
42
+ self.vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
43
+ self.scheduler = DDPMScheduler.from_pretrained(model_id , subfolder="scheduler")
44
+ self.scheduler = DDIMScheduler(
45
+ beta_start=0.00085,
46
+ beta_end=0.012,
47
+ beta_schedule="scaled_linear",
48
+ clip_sample=False,
49
+ set_alpha_to_one=True,
50
+ rescale_betas_zero_snr = False,
51
+ )
52
+ self.mixed_precision = "fp16"
53
+ self.resolution = resolution
54
+ self.num_tokens = num_tokens
55
+
56
+ self.mask_list = mask_list
57
+ self.mask_label_list = mask_label_list
58
+ notation_token_list = [phrase.split(" ")[-1] for phrase in mask_label_list]
59
+ placeholder_token_list = ["#"+word+"{}".format(widx) for widx, word in enumerate(notation_token_list)]
60
+ self.set_string_list, placeholder_token_ids = self.add_tokens(placeholder_token_list)
61
+ self.min_added_id = min(placeholder_token_ids)
62
+ self.max_added_id = max(placeholder_token_ids)
63
+
64
+ if mask_list_2 is not None:
65
+ self.mask_list_2 = mask_list_2
66
+ self.mask_label_list_2 = mask_label_list_2
67
+ notation_token_list_2 = [phrase.split(" ")[-1] for phrase in mask_label_list_2]
68
+
69
+ placeholder_token_list_2 = ["$"+word+"{}".format(widx) for widx, word in enumerate(notation_token_list_2)]
70
+ self.set_string_list_2, placeholder_token_ids_2 = self.add_tokens(placeholder_token_list_2)
71
+ self.max_added_id = max(placeholder_token_ids_2)
72
+
73
+ def add_tokens_text_encoder_random_init(self, placeholder_token, num_tokens=1):
74
+ # Add the placeholder token in tokenizer
75
+ placeholder_tokens = [placeholder_token]
76
+ # add dummy tokens for multi-vector
77
+ additional_tokens = []
78
+ for i in range(1, num_tokens):
79
+ additional_tokens.append(f"{placeholder_token}_{i}")
80
+ placeholder_tokens += additional_tokens
81
+ num_added_tokens = self.tokenizer.add_tokens(placeholder_tokens) # new ids start at 49408, the base CLIP vocab size
82
+
83
+ if num_added_tokens != num_tokens:
84
+ raise ValueError(
85
+ f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
86
+ " `placeholder_token` that is not already in the tokenizer."
87
+ )
88
+ placeholder_token_ids = self.tokenizer.convert_tokens_to_ids(placeholder_tokens)
89
+
90
+ self.text_encoder.resize_token_embeddings(len(self.tokenizer))
91
+ token_embeds = self.text_encoder.get_input_embeddings().weight.data
92
+ std, mean = torch.std_mean(token_embeds)
93
+ with torch.no_grad():
94
+ for token_id in placeholder_token_ids:
95
+ token_embeds[token_id] = torch.randn_like(token_embeds[token_id])*std + mean
96
+
97
+ set_string = " ".join(self.tokenizer.convert_ids_to_tokens(placeholder_token_ids))
98
+
99
+ return set_string, placeholder_token_ids
100
+
101
+ def add_tokens(self, placeholder_token_list):
102
+ set_string_list = []
103
+ placeholder_token_ids_list = []
104
+ for str_idx in range(len(placeholder_token_list)):
105
+ placeholder_token = placeholder_token_list[str_idx]
106
+ set_string, placeholder_token_ids = self.add_tokens_text_encoder_random_init(placeholder_token, num_tokens=self.num_tokens)
107
+ set_string_list.append(set_string)
108
+ placeholder_token_ids_list.append(placeholder_token_ids)
109
+ placeholder_token_ids = list(itertools.chain(*placeholder_token_ids_list))
110
+ return set_string_list, placeholder_token_ids
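+ # Illustration (hypothetical labels, not from this repo): mask_label_list = ["a brown dog", "green grass"]
+ # yields placeholder tokens ["#dog0", "#grass1"]; with num_tokens=1 each set string is just that
+ # placeholder, and its embedding row is randomly initialized from the vocabulary statistics above.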
111
+
112
+ def train_emb(
113
+ self,
114
+ image_gt,
115
+ set_string_list,
116
+ gradient_accumulation_steps = 5,
117
+ embedding_learning_rate = 1e-4,
118
+ max_emb_train_steps = 100,
119
+ train_batch_size = 1,
120
+ ):
121
+ decom_controller = GroupedCAController(mask_list = self.mask_list)
122
+ register_attention_disentangled_control(self.unet, decom_controller)
123
+
124
+ accelerator = Accelerator(mixed_precision=self.mixed_precision, gradient_accumulation_steps=gradient_accumulation_steps)
125
+ self.vae.requires_grad_(False)
126
+ self.unet.requires_grad_(False)
127
+
128
+ self.text_encoder.requires_grad_(True)
129
+
130
+ self.text_encoder.text_model.encoder.requires_grad_(False)
131
+ self.text_encoder.text_model.final_layer_norm.requires_grad_(False)
132
+ self.text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
133
+
134
+ weight_dtype = torch.float32
135
+ if accelerator.mixed_precision == "fp16":
136
+ weight_dtype = torch.float16
137
+ elif accelerator.mixed_precision == "bf16":
138
+ weight_dtype = torch.bfloat16
139
+
140
+ self.unet.to(device, dtype=weight_dtype)
141
+ self.vae.to(device, dtype=weight_dtype)
142
+
143
+ trainable_embmat_list_1 = [param for param in self.text_encoder.get_input_embeddings().parameters()]
144
+ optimizer = torch.optim.AdamW(trainable_embmat_list_1, lr=embedding_learning_rate)
145
+
146
+ self.text_encoder, optimizer = accelerator.prepare(self.text_encoder, optimizer)
147
+
148
+ orig_embeds_params_1 = accelerator.unwrap_model(self.text_encoder).get_input_embeddings().weight.data.clone()
149
+
150
+ self.text_encoder.train()
151
+
152
+ effective_emb_train_steps = max_emb_train_steps//gradient_accumulation_steps
153
+
154
+ if accelerator.is_main_process:
155
+ accelerator.init_trackers("DEdit EmbSteps", config={
156
+ "embedding_learning_rate": embedding_learning_rate,
157
+ "text_embedding_optimization_steps": effective_emb_train_steps,
158
+ })
159
+ global_step = 0
160
+ noise_scheduler = self.scheduler
161
+ progress_bar = tqdm(range(0, effective_emb_train_steps), initial = global_step, desc="EmbSteps")
162
+ latents0 = image2latent(image_gt, vae = self.vae, dtype = weight_dtype)
163
+ latents0 = latents0.repeat(train_batch_size, 1, 1, 1)
164
+
165
+ for _ in range(max_emb_train_steps):
166
+ with accelerator.accumulate(self.text_encoder):
167
+ latents = latents0.clone().detach()
168
+ noise = torch.randn_like(latents)
169
+ bsz = latents.shape[0]
170
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
171
+ timesteps = timesteps.long()
172
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
173
+ encoder_hidden_states_list = sd_prepare_input_decom(
174
+ set_string_list,
175
+ self.tokenizer,
176
+ self.text_encoder,
177
+ length = 40,
178
+ bsz = train_batch_size,
179
+ weight_dtype = weight_dtype
180
+ )
181
+
182
+ model_pred = self.unet(
183
+ noisy_latents,
184
+ timesteps,
185
+ encoder_hidden_states = encoder_hidden_states_list,
186
+ ).sample
187
+
188
+ loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")
189
+ accelerator.backward(loss)
190
+ optimizer.step()
191
+ optimizer.zero_grad()
192
+
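+ # After each update, restore all original embedding rows so that only the newly added
+ # placeholder tokens (ids min_added_id..max_added_id) actually get trained.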
193
+ index_no_updates = torch.ones((len(self.tokenizer),), dtype=torch.bool)
194
+ index_no_updates[self.min_added_id : self.max_added_id + 1] = False
195
+ with torch.no_grad():
196
+ accelerator.unwrap_model(self.text_encoder).get_input_embeddings().weight[
197
+ index_no_updates] = orig_embeds_params_1[index_no_updates]
198
+
199
+ logs = {"loss": loss.detach().item(), "lr": embedding_learning_rate}
200
+ progress_bar.set_postfix(**logs)
201
+ accelerator.log(logs, step=global_step)
202
+ if accelerator.sync_gradients:
203
+ progress_bar.update(1)
204
+ global_step += 1
205
+
206
+ if global_step >= max_emb_train_steps:
207
+ break
208
+
209
+ accelerator.wait_for_everyone()
210
+ accelerator.end_training()
211
+ self.text_encoder = accelerator.unwrap_model(self.text_encoder).to(dtype = weight_dtype)
212
+
213
+ def train_model(
214
+ self,
215
+ image_gt,
216
+ set_string_list,
217
+ gradient_accumulation_steps = 5,
218
+ max_diffusion_train_steps = 100,
219
+ diffusion_model_learning_rate = 1e-5,
220
+ train_batch_size = 1,
221
+ train_full_lora = False,
222
+ lora_rank = 4,
223
+ lora_alpha = 4
224
+ ):
225
+ self.unet = UNet2DConditionModel.from_pretrained(self.model_id, subfolder="unet").to(device)
226
+ self.unet.ca_dim = 768
227
+ decom_controller = GroupedCAController(mask_list = self.mask_list)
228
+ register_attention_disentangled_control(self.unet, decom_controller)
229
+
230
+ mixed_precision = "fp16"
231
+ accelerator = Accelerator(gradient_accumulation_steps = gradient_accumulation_steps, mixed_precision = mixed_precision)
232
+
233
+ weight_dtype = torch.float32
234
+ if accelerator.mixed_precision == "fp16":
235
+ weight_dtype = torch.float16
236
+ elif accelerator.mixed_precision == "bf16":
237
+ weight_dtype = torch.bfloat16
238
+
239
+ self.vae.requires_grad_(False)
240
+ self.vae.to(device, dtype=weight_dtype)
241
+
242
+ self.unet.requires_grad_(False)
243
+ self.unet.train()
244
+
245
+ self.text_encoder.requires_grad_(False)
246
+
247
+ if not train_full_lora:
248
+ trainable_params_list = []
249
+ for _, module in self.unet.named_modules():
250
+ module_name = type(module).__name__
251
+ if module_name == "Attention":
252
+ if module.to_k.in_features == self.unet.ca_dim: # this is cross attention:
253
+ module.to_k.weight.requires_grad = True
254
+ trainable_params_list.append(module.to_k.weight)
255
+ if module.to_k.bias is not None:
256
+ module.to_k.bias.requires_grad = True
257
+ trainable_params_list.append(module.to_k.bias)
258
+ module.to_v.weight.requires_grad = True
259
+ trainable_params_list.append(module.to_v.weight)
260
+ if module.to_v.bias is not None:
261
+ module.to_v.bias.requires_grad = True
262
+ trainable_params_list.append(module.to_v.bias)
263
+ module.to_q.weight.requires_grad = True
264
+ trainable_params_list.append(module.to_q.weight)
265
+ if module.to_q.bias is not None:
266
+ module.to_q.bias.requires_grad = True
267
+ trainable_params_list.append(module.to_q.bias)
268
+ else:
269
+ unet_lora_config = LoraConfig(
270
+ r=lora_rank,
271
+ lora_alpha=lora_alpha,
272
+ init_lora_weights="gaussian",
273
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
274
+ )
275
+ self.unet.add_adapter(unet_lora_config)
276
+ print("training full parameters using lora!")
277
+ trainable_params_list = list(filter(lambda p: p.requires_grad, self.unet.parameters()))
278
+
279
+ self.text_encoder.to(device, dtype=weight_dtype)
280
+
281
+ optimizer = torch.optim.AdamW(trainable_params_list, lr=diffusion_model_learning_rate)
282
+ self.unet, optimizer = accelerator.prepare(self.unet, optimizer)
283
+ psum2 = sum(p.numel() for p in trainable_params_list)
284
+
285
+ effective_diffusion_train_steps = max_diffusion_train_steps // gradient_accumulation_steps
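+ # e.g. max_diffusion_train_steps=100 with gradient_accumulation_steps=5 yields 20 optimizer
+ # updates; the progress bar below counts these effective steps.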
286
+ if accelerator.is_main_process:
287
+ accelerator.init_trackers("textual_inversion", config={
288
+ "diffusion_model_learning_rate": diffusion_model_learning_rate,
289
+ "diffusion_model_optimization_steps": effective_diffusion_train_steps,
290
+ })
291
+
292
+ global_step = 0
293
+ progress_bar = tqdm( range(0, effective_diffusion_train_steps),initial=global_step, desc="ModelSteps")
294
+
295
+ noise_scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0" , subfolder="scheduler")
296
+
297
+ latents0 = image2latent(image_gt, vae = self.vae, dtype=weight_dtype)
298
+ latents0 = latents0.repeat(train_batch_size, 1, 1, 1)
299
+
300
+ with torch.no_grad():
301
+ encoder_hidden_states_list = sd_prepare_input_decom(
302
+ set_string_list,
303
+ self.tokenizer,
304
+ self.text_encoder,
305
+ length = 40,
306
+ bsz = train_batch_size,
307
+ weight_dtype = weight_dtype
308
+ )
309
+
310
+ for _ in range(max_diffusion_train_steps):
311
+ with accelerator.accumulate(self.unet):
312
+ latents = latents0.clone().detach()
313
+ noise = torch.randn_like(latents)
314
+ bsz = latents.shape[0]
315
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
316
+ timesteps = timesteps.long()
317
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
318
+ model_pred = self.unet(
319
+ noisy_latents,
320
+ timesteps,
321
+ encoder_hidden_states=encoder_hidden_states_list,
322
+ ).sample
323
+ loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")
324
+ accelerator.backward(loss)
325
+ optimizer.step()
326
+ optimizer.zero_grad()
327
+
328
+ logs = {"loss": loss.detach().item(), "lr": diffusion_model_learning_rate}
329
+ progress_bar.set_postfix(**logs)
330
+ accelerator.log(logs, step=global_step)
331
+ if accelerator.sync_gradients:
332
+ progress_bar.update(1)
333
+ global_step += 1
334
+ if global_step >=max_diffusion_train_steps:
335
+ break
336
+ accelerator.wait_for_everyone()
337
+ accelerator.end_training()
338
+ self.unet = accelerator.unwrap_model(self.unet).to(dtype = weight_dtype)
339
+
340
+ def train_emb_2imgs(
341
+ self,
342
+ image_gt_1,
343
+ image_gt_2,
344
+ set_string_list_1,
345
+ set_string_list_2,
346
+ gradient_accumulation_steps = 5,
347
+ embedding_learning_rate = 1e-4,
348
+ max_emb_train_steps = 100,
349
+ train_batch_size = 1,
350
+ ):
351
+ decom_controller_1 = GroupedCAController(mask_list = self.mask_list)
352
+ decom_controller_2 = GroupedCAController(mask_list = self.mask_list_2)
353
+ accelerator = Accelerator(mixed_precision=self.mixed_precision, gradient_accumulation_steps=gradient_accumulation_steps)
354
+ self.vae.requires_grad_(False)
355
+ self.unet.requires_grad_(False)
356
+
357
+ self.text_encoder.requires_grad_(True)
358
+
359
+ self.text_encoder.text_model.encoder.requires_grad_(False)
360
+ self.text_encoder.text_model.final_layer_norm.requires_grad_(False)
361
+ self.text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
362
+
363
+
364
+ weight_dtype = torch.float32
365
+ if accelerator.mixed_precision == "fp16":
366
+ weight_dtype = torch.float16
367
+ elif accelerator.mixed_precision == "bf16":
368
+ weight_dtype = torch.bfloat16
369
+
370
+ self.unet.to(device, dtype=weight_dtype)
371
+ self.vae.to(device, dtype=weight_dtype)
372
+
373
+
374
+ trainable_embmat_list_1 = [param for param in self.text_encoder.get_input_embeddings().parameters()]
375
+
376
+ optimizer = torch.optim.AdamW(trainable_embmat_list_1, lr=embedding_learning_rate)
377
+ self.text_encoder, optimizer = accelerator.prepare(self.text_encoder, optimizer)
378
+ orig_embeds_params_1 = accelerator.unwrap_model(self.text_encoder).get_input_embeddings().weight.data.clone()
379
+
380
+ self.text_encoder.train()
381
+
382
+ effective_emb_train_steps = max_emb_train_steps//gradient_accumulation_steps
383
+
384
+ if accelerator.is_main_process:
385
+ accelerator.init_trackers("EmbFt", config={
386
+ "embedding_learning_rate": embedding_learning_rate,
387
+ "text_embedding_optimization_steps": effective_emb_train_steps,
388
+ })
389
+
390
+ global_step = 0
391
+
392
+ noise_scheduler = DDPMScheduler.from_pretrained(self.model_id , subfolder="scheduler")
393
+ progress_bar = tqdm(range(0, effective_emb_train_steps),initial=global_step,desc="EmbSteps")
394
+ latents0_1 = image2latent(image_gt_1, vae = self.vae, dtype=weight_dtype)
395
+ latents0_1 = latents0_1.repeat(train_batch_size,1,1,1)
396
+
397
+ latents0_2 = image2latent(image_gt_2, vae = self.vae, dtype=weight_dtype)
398
+ latents0_2 = latents0_2.repeat(train_batch_size,1,1,1)
399
+
400
+ for step in range(max_emb_train_steps):
401
+ with accelerator.accumulate(self.text_encoder):
402
+ latents_1 = latents0_1.clone().detach()
403
+ noise_1 = torch.randn_like(latents_1)
404
+
405
+ latents_2 = latents0_2.clone().detach()
406
+ noise_2 = torch.randn_like(latents_2)
407
+
408
+ bsz = latents_1.shape[0]
409
+
410
+ timesteps_1 = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents_1.device)
411
+ timesteps_1 = timesteps_1.long()
412
+ noisy_latents_1 = noise_scheduler.add_noise(latents_1, noise_1, timesteps_1)
413
+
414
+ timesteps_2 = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents_2.device)
415
+ timesteps_2 = timesteps_2.long()
416
+ noisy_latents_2 = noise_scheduler.add_noise(latents_2, noise_2, timesteps_2)
417
+
418
+ register_attention_disentangled_control(self.unet, decom_controller_1)
419
+ encoder_hidden_states_list_1 = sd_prepare_input_decom(
420
+ set_string_list_1,
421
+ self.tokenizer,
422
+ self.text_encoder,
423
+ length = 40,
424
+ bsz = train_batch_size,
425
+ weight_dtype = weight_dtype
426
+ )
427
+
428
+ model_pred_1 = self.unet(
429
+ noisy_latents_1,
430
+ timesteps_1,
431
+ encoder_hidden_states=encoder_hidden_states_list_1,
432
+ ).sample
433
+
434
+ register_attention_disentangled_control(self.unet, decom_controller_2)
435
+ # import pdb; pdb.set_trace()
436
+ encoder_hidden_states_list_2= sd_prepare_input_decom(
437
+ set_string_list_2,
438
+ self.tokenizer,
439
+ self.text_encoder,
440
+ length = 40,
441
+ bsz = train_batch_size,
442
+ weight_dtype = weight_dtype
443
+ )
444
+
445
+ model_pred_2 = self.unet(
446
+ noisy_latents_2,
447
+ timesteps_2,
448
+ encoder_hidden_states = encoder_hidden_states_list_2,
449
+ ).sample
450
+
451
+ loss_1 = F.mse_loss(model_pred_1.float(), noise_1.float(), reduction="mean") /2
452
+ loss_2 = F.mse_loss(model_pred_2.float(), noise_2.float(), reduction="mean") /2
453
+ loss = loss_1 + loss_2
454
+ accelerator.backward(loss)
455
+ optimizer.step()
456
+ optimizer.zero_grad()
457
+
458
+ index_no_updates = torch.ones((len(self.tokenizer),), dtype=torch.bool)
459
+ index_no_updates[self.min_added_id : self.max_added_id + 1] = False
460
+ with torch.no_grad():
461
+ accelerator.unwrap_model(self.text_encoder).get_input_embeddings().weight[
462
+ index_no_updates] = orig_embeds_params_1[index_no_updates]
463
+
464
+ logs = {"loss": loss.detach().item(), "lr": embedding_learning_rate}
465
+ progress_bar.set_postfix(**logs)
466
+ accelerator.log(logs, step=global_step)
467
+ if accelerator.sync_gradients:
468
+ progress_bar.update(1)
469
+ global_step += 1
470
+
471
+ if global_step >= max_emb_train_steps:
472
+ break
473
+ accelerator.wait_for_everyone()
474
+ accelerator.end_training()
475
+ self.text_encoder = accelerator.unwrap_model(self.text_encoder).to(dtype = weight_dtype)
476
+
477
+ def train_model_2imgs(
478
+ self,
479
+ image_gt_1,
480
+ image_gt_2,
481
+ set_string_list_1,
482
+ set_string_list_2,
483
+ gradient_accumulation_steps = 5,
484
+ max_diffusion_train_steps = 100,
485
+ diffusion_model_learning_rate = 1e-5,
486
+ train_batch_size = 1,
487
+ train_full_lora = False,
488
+ lora_rank = 4,
489
+ lora_alpha = 4
490
+ ):
491
+ self.unet = UNet2DConditionModel.from_pretrained(self.model_id, subfolder="unet").to(device)
492
+ self.unet.ca_dim = 768
493
+ decom_controller_1 = GroupedCAController(mask_list = self.mask_list)
494
+ decom_controller_2 = GroupedCAController(mask_list = self.mask_list_2)
495
+
496
+ mixed_precision = "fp16"
497
+ accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps,mixed_precision=mixed_precision)
498
+
499
+ weight_dtype = torch.float32
500
+ if accelerator.mixed_precision == "fp16":
501
+ weight_dtype = torch.float16
502
+ elif accelerator.mixed_precision == "bf16":
503
+ weight_dtype = torch.bfloat16
504
+
505
+
506
+ self.vae.requires_grad_(False)
507
+ self.vae.to(device, dtype=weight_dtype)
508
+ self.unet.requires_grad_(False)
509
+ self.unet.train()
510
+
511
+ self.text_encoder.requires_grad_(False)
512
+
513
+ if not train_full_lora:
514
+ trainable_params_list = []
515
+ for name, module in self.unet.named_modules():
516
+ module_name = type(module).__name__
517
+ if module_name == "Attention":
518
+ if module.to_k.in_features == self.unet.ca_dim: # this is cross attention:
519
+ module.to_k.weight.requires_grad = True
520
+ trainable_params_list.append(module.to_k.weight)
521
+ if module.to_k.bias is not None:
522
+ module.to_k.bias.requires_grad = True
523
+ trainable_params_list.append(module.to_k.bias)
524
+
525
+ module.to_v.weight.requires_grad = True
526
+ trainable_params_list.append(module.to_v.weight)
527
+ if module.to_v.bias is not None:
528
+ module.to_v.bias.requires_grad = True
529
+ trainable_params_list.append(module.to_v.bias)
530
+ module.to_q.weight.requires_grad = True
531
+ trainable_params_list.append(module.to_q.weight)
532
+ if module.to_q.bias is not None:
533
+ module.to_q.bias.requires_grad = True
534
+ trainable_params_list.append(module.to_q.bias)
535
+ else:
536
+ unet_lora_config = LoraConfig(
537
+ r = lora_rank,
538
+ lora_alpha = lora_alpha,
539
+ init_lora_weights="gaussian",
540
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
541
+ )
542
+ self.unet.add_adapter(unet_lora_config)
543
+ print("training full parameters using lora!")
544
+ trainable_params_list = list(filter(lambda p: p.requires_grad, self.unet.parameters()))
545
+
546
+ self.text_encoder.to(device, dtype=weight_dtype)
547
+ optimizer = torch.optim.AdamW(trainable_params_list, lr=diffusion_model_learning_rate)
548
+ self.unet, optimizer = accelerator.prepare(self.unet, optimizer)
549
+ psum2 = sum(p.numel() for p in trainable_params_list)
550
+
551
+ effective_diffusion_train_steps = max_diffusion_train_steps // gradient_accumulation_steps
552
+ if accelerator.is_main_process:
553
+ accelerator.init_trackers("ModelFt", config={
554
+ "diffusion_model_learning_rate": diffusion_model_learning_rate,
555
+ "diffusion_model_optimization_steps": effective_diffusion_train_steps,
556
+ })
557
+
558
+ global_step = 0
559
+ progress_bar = tqdm(range(0, effective_diffusion_train_steps),initial=global_step, desc="ModelSteps")
560
+ noise_scheduler = DDPMScheduler.from_pretrained(self.model_id, subfolder="scheduler")
561
+
562
+ latents0_1 = image2latent(image_gt_1, vae = self.vae, dtype=weight_dtype)
563
+ latents0_1 = latents0_1.repeat(train_batch_size, 1, 1, 1)
564
+
565
+ latents0_2 = image2latent(image_gt_2, vae = self.vae, dtype=weight_dtype)
566
+ latents0_2 = latents0_2.repeat(train_batch_size,1, 1, 1)
567
+
568
+ with torch.no_grad():
569
+ encoder_hidden_states_list_1 = sd_prepare_input_decom(
570
+ set_string_list_1,
571
+ self.tokenizer,
572
+ self.text_encoder,
573
+ length = 40,
574
+ bsz = train_batch_size,
575
+ weight_dtype = weight_dtype
576
+ )
577
+ encoder_hidden_states_list_2 = sd_prepare_input_decom(
578
+ set_string_list_2,
579
+ self.tokenizer,
580
+ self.text_encoder,
581
+ length = 40,
582
+ bsz = train_batch_size,
583
+ weight_dtype = weight_dtype
584
+ )
585
+
586
+ for _ in range(max_diffusion_train_steps):
587
+ with accelerator.accumulate(self.unet):
588
+ latents_1 = latents0_1.clone().detach()
589
+ noise_1 = torch.randn_like(latents_1)
590
+ bsz = latents_1.shape[0]
591
+ timesteps_1 = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents_1.device)
592
+ timesteps_1 = timesteps_1.long()
593
+ noisy_latents_1 = noise_scheduler.add_noise(latents_1, noise_1, timesteps_1)
594
+
595
+ latents_2 = latents0_2.clone().detach()
596
+ noise_2 = torch.randn_like(latents_2)
597
+ bsz = latents_2.shape[0]
598
+ timesteps_2 = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents_2.device)
599
+ timesteps_2 = timesteps_2.long()
600
+ noisy_latents_2 = noise_scheduler.add_noise(latents_2, noise_2, timesteps_2)
601
+
602
+ register_attention_disentangled_control(self.unet, decom_controller_1)
603
+ model_pred_1 = self.unet(
604
+ noisy_latents_1,
605
+ timesteps_1,
606
+ encoder_hidden_states = encoder_hidden_states_list_1,
607
+ ).sample
608
+
609
+ register_attention_disentangled_control(self.unet, decom_controller_2)
610
+ model_pred_2 = self.unet(
611
+ noisy_latents_2,
612
+ timesteps_2,
613
+ encoder_hidden_states = encoder_hidden_states_list_2,
614
+ ).sample
615
+
616
+ loss_1 = F.mse_loss(model_pred_1.float(), noise_1.float(), reduction="mean")
617
+ loss_2 = F.mse_loss(model_pred_2.float(), noise_2.float(), reduction="mean")
618
+ loss = loss_1 + loss_2
619
+ accelerator.backward(loss)
620
+ optimizer.step()
621
+ optimizer.zero_grad()
622
+
623
+
624
+ logs = {"loss": loss.detach().item(), "lr": diffusion_model_learning_rate}
625
+ progress_bar.set_postfix(**logs)
626
+ accelerator.log(logs, step=global_step)
627
+ if accelerator.sync_gradients:
628
+ progress_bar.update(1)
629
+ global_step += 1
630
+
631
+ if global_step >=max_diffusion_train_steps:
632
+ break
633
+ accelerator.wait_for_everyone()
634
+ accelerator.end_training()
635
+ self.unet = accelerator.unwrap_model(self.unet).to(dtype = weight_dtype)
636
+
637
+ @torch.no_grad()
638
+ def backward_zT_to_z0_euler_decom(
639
+ self,
640
+ zT,
641
+ cond_emb_list,
642
+ uncond_emb=None,
643
+ guidance_scale = 1,
644
+ num_sampling_steps = 20,
645
+ cond_controller = None,
646
+ uncond_controller = None,
647
+ mask_hard = None,
648
+ mask_soft = None,
649
+ orig_image = None,
650
+ return_intermediate = False,
651
+ strength = 1
652
+ ):
653
+ latent_cur = zT
654
+ if uncond_emb is None:
655
+ uncond_emb = torch.zeros(zT.shape[0], 77, self.unet.ca_dim).to(dtype = zT.dtype, device = zT.device)
656
+
657
+ if mask_soft is not None:
658
+ init_latents_orig = image2latent(orig_image, self.vae, dtype=self.vae.dtype)
659
+ length = init_latents_orig.shape[-1]
660
+ noise = torch.randn_like(init_latents_orig)
661
+ mask_soft = torch.nn.functional.interpolate(mask_soft.float().unsqueeze(0).unsqueeze(0), (length, length)).to(self.vae.dtype)
662
+
663
+ if mask_hard is not None:
664
+ init_latents_orig = image2latent(orig_image, self.vae, dtype=self.vae.dtype)
665
+ length = init_latents_orig.shape[-1]
666
+ noise = torch.randn_like(init_latents_orig)
667
+ mask_hard = torch.nn.functional.interpolate(mask_hard.float().unsqueeze(0).unsqueeze(0), (length, length)).to(self.vae.dtype)
668
+
669
+ intermediate_list = [latent_cur.detach()]
670
+ for i in tqdm(range(num_sampling_steps)):
671
+ t = self.scheduler.timesteps[i]
672
+ latent_input = self.scheduler.scale_model_input(latent_cur, t)
673
+
674
+ register_attention_disentangled_control(self.unet, uncond_controller)
675
+ noise_pred_uncond = self.unet(
676
+ latent_input,
677
+ t,
678
+ encoder_hidden_states=uncond_emb,
679
+ ).sample
680
+
681
+ register_attention_disentangled_control(self.unet, cond_controller)
682
+ noise_pred_cond = self.unet(
683
+ latent_input,
684
+ t,
685
+ encoder_hidden_states=cond_emb_list,
686
+ ).sample
687
+
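+ # Classifier-free guidance: move the prediction away from the unconditional branch by
+ # guidance_scale; guidance_scale = 1 reduces to the purely conditional prediction.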
688
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
689
+ latent_cur = self.scheduler.step(noise_pred, t, latent_cur, generator = None, return_dict=False)[0]
690
+
691
+ if return_intermediate is True:
692
+ intermediate_list.append(latent_cur)
693
+
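+ # Inpainting-style blending: where the combined mask is 1, the latent is replaced by a
+ # re-noised copy of the original image, so those regions are preserved; mask_soft is only
+ # enforced for the first strength * num_sampling_steps steps, while mask_hard is kept throughout.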
694
+ if mask_hard is not None and mask_soft is not None and i <= strength *num_sampling_steps:
695
+ init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
696
+ mask = mask_soft.to(latent_cur.device, latent_cur.dtype) + mask_hard.to(latent_cur.device, latent_cur.dtype)
697
+ latent_cur = (init_latents_proper * mask) + (latent_cur * (1 - mask))
698
+
699
+ elif mask_hard is not None and mask_soft is not None and i > strength *num_sampling_steps:
700
+ init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
701
+ mask = mask_hard.to(latent_cur.device, latent_cur.dtype)
702
+ latent_cur = (init_latents_proper * mask) + (latent_cur * (1 - mask))
703
+
704
+ elif mask_hard is None and mask_soft is not None and i <= strength *num_sampling_steps:
705
+ init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
706
+ mask = mask_soft.to(latent_cur.device, latent_cur.dtype)
707
+ latent_cur = (init_latents_proper * mask) + (latent_cur * (1 - mask))
708
+
709
+ elif mask_hard is None and mask_soft is not None and i > strength *num_sampling_steps:
710
+ pass
711
+
712
+ elif mask_hard is not None and mask_soft is None:
713
+ init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
714
+ mask = mask_hard.to(latent_cur.dtype)
715
+ latent_cur = (init_latents_proper * mask) + (latent_cur * (1 - mask))
716
+
717
+ else: # hard and soft are both none
718
+ pass
719
+
720
+ if return_intermediate is True:
721
+ return latent_cur, intermediate_list
722
+ else:
723
+ return latent_cur
724
+
725
+ @torch.no_grad()
726
+ def sampling(
727
+ self,
728
+ set_string_list,
729
+ cond_controller = None,
730
+ uncond_controller = None,
731
+ guidance_scale = 7,
732
+ num_sampling_steps = 20,
733
+ mask_hard = None,
734
+ mask_soft = None,
735
+ orig_image = None,
736
+ strength = 1.,
737
+ num_imgs = 1,
738
+ normal_token_id_list = [],
739
+ seed = 1
740
+ ):
741
+ weight_dtype = torch.float16
742
+ self.scheduler.set_timesteps(num_sampling_steps)
743
+ self.unet.to(device, dtype=weight_dtype)
744
+ self.vae.to(device, dtype=weight_dtype)
745
+ self.text_encoder.to(device, dtype=weight_dtype)
746
+
747
+ torch.manual_seed(seed)
748
+ torch.cuda.manual_seed(seed)
749
+
750
+ vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
751
+ zT = torch.randn(num_imgs, 4, self.resolution//vae_scale_factor,self.resolution//vae_scale_factor).to(device,dtype=weight_dtype)
752
+ zT = zT * self.scheduler.init_noise_sigma
753
+
754
+ cond_emb_list = sd_prepare_input_decom(
755
+ set_string_list,
756
+ self.tokenizer,
757
+ self.text_encoder,
758
+ length = 40,
759
+ bsz = num_imgs,
760
+ weight_dtype = weight_dtype,
761
+ normal_token_id_list = normal_token_id_list
762
+ )
763
+
764
+ z0 = self.backward_zT_to_z0_euler_decom(zT, cond_emb_list,
765
+ guidance_scale = guidance_scale, num_sampling_steps = num_sampling_steps,
766
+ cond_controller = cond_controller, uncond_controller = uncond_controller,
767
+ mask_hard = mask_hard, mask_soft = mask_soft, orig_image = orig_image, strength = strength
768
+ )
769
+ x0 = latent2image(z0, vae = self.vae)
770
+ return x0
771
+
772
+ @torch.no_grad()
773
+ def inference_with_mask(
774
+ self,
775
+ save_path,
776
+ guidance_scale = 3,
777
+ num_sampling_steps = 50,
778
+ strength = 1,
779
+ mask_soft = None,
780
+ mask_hard= None,
781
+ orig_image=None,
782
+ mask_list = None,
783
+ num_imgs = 1,
784
+ seed = 1,
785
+ set_string_list = None
786
+ ):
787
+ if mask_list is not None:
788
+ mask_list = [m.to(device) for m in mask_list]
789
+ else:
790
+ mask_list = self.mask_list
791
+ if set_string_list is not None:
792
+ self.set_string_list = set_string_list
793
+
794
+ if mask_hard is not None and mask_soft is not None:
795
+ check_mask_overlap_torch(mask_hard, mask_soft)
796
+ null_controller = DummyController()
797
+ decom_controller = GroupedCAController(mask_list = mask_list)
798
+
799
+ x0 = self.sampling(
800
+ self.set_string_list,
801
+ guidance_scale = guidance_scale,
802
+ num_sampling_steps = num_sampling_steps,
803
+ strength = strength,
804
+ cond_controller = decom_controller,
805
+ uncond_controller = null_controller,
806
+ mask_soft = mask_soft,
807
+ mask_hard = mask_hard,
808
+ orig_image = orig_image,
809
+ num_imgs = num_imgs,
810
+ seed = seed
811
+ )
812
+ save_images(x0, save_path)
813
+ return x0
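A minimal usage sketch for the pipeline above (the module name pipeline_dedit_sd, the masks, labels, image path, and step counts are hypothetical placeholders, not values taken from this repo):

import torch
from PIL import Image
from pipeline_dedit_sd import DEditSDPipeline  # assumed module name for the file above

# Hypothetical inputs: two non-overlapping binary masks plus one text label per mask.
h = w = 512
top = torch.zeros(h, w)
top[: h // 2] = 1
mask_list = [top, 1 - top]
mask_label_list = ["a brown dog", "green grass"]
image_gt = Image.open("example.png").convert("RGB").resize((w, h))

pipe = DEditSDPipeline(mask_list, mask_label_list, resolution=512)

# Stage 1: optimize the placeholder-token embeddings; Stage 2: finetune the UNet cross-attention.
pipe.train_emb(image_gt, pipe.set_string_list, max_emb_train_steps=200)
pipe.train_model(image_gt, pipe.set_string_list, max_diffusion_train_steps=200)

# Regenerate the image from the learned mask/token decomposition.
pipe.inference_with_mask("out.png", guidance_scale=3, num_sampling_steps=50)

Whether the mask tensors need to be pre-moved to the GPU or resized to the latent resolution depends on GroupedCAController (defined in controller.py), which is not shown here.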
pipeline_dedit_sdxl.py ADDED
@@ -0,0 +1,875 @@
1
+ import torch
2
+ from utils import import_model_class_from_model_name_or_path
3
+ from transformers import AutoTokenizer
4
+ from diffusers import (
5
+ AutoencoderKL,
6
+ DDPMScheduler,
7
+ StableDiffusionXLPipeline,
8
+ UNet2DConditionModel,
9
+ )
10
+ from accelerate import Accelerator
11
+ from tqdm.auto import tqdm
12
+ from utils import sdxl_prepare_input_decom, save_images
13
+ import torch.nn.functional as F
14
+ import itertools
15
+ from peft import LoraConfig
16
+ from controller import GroupedCAController, register_attention_disentangled_control, DummyController
17
+ from utils import image2latent, latent2image
18
+ import matplotlib.pyplot as plt
19
+ from utils_mask import check_mask_overlap_torch
20
+
21
+ device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
22
+ max_length = 40
23
+ class DEditSDXLPipeline:
24
+ def __init__(
25
+ self,
26
+ mask_list,
27
+ mask_label_list,
28
+ mask_list_2 = None,
29
+ mask_label_list_2 = None,
30
+ resolution = 1024,
31
+ num_tokens = 1
32
+ ):
33
+ super().__init__()
34
+ model_id = "stabilityai/stable-diffusion-xl-base-1.0"
35
+ self.model_id = model_id
36
+ self.tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", use_fast=False)
37
+ self.tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2", use_fast=False)
38
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(model_id, subfolder = "text_encoder")
39
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(model_id, subfolder="text_encoder_2")
40
+ self.text_encoder = text_encoder_cls_one.from_pretrained(model_id, subfolder="text_encoder" ).to(device)
41
+ self.text_encoder_2 = text_encoder_cls_two.from_pretrained(model_id, subfolder="text_encoder_2").to(device)
42
+ self.unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet" )
43
+ self.unet.ca_dim = 2048
44
+ self.vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix")
45
+ self.scheduler = DDPMScheduler.from_pretrained(model_id , subfolder="scheduler")
46
+
47
+ self.mixed_precision = "fp16"
48
+ self.resolution = resolution
49
+ self.num_tokens = num_tokens
50
+
51
+ self.mask_list = mask_list
52
+ self.mask_label_list = mask_label_list
53
+ notation_token_list = [phrase.split(" ")[-1] for phrase in mask_label_list]
54
+ placeholder_token_list = ["#"+word+"{}".format(widx) for widx, word in enumerate(notation_token_list)]
55
+ self.set_string_list, placeholder_token_ids = self.add_tokens(placeholder_token_list)
56
+ self.min_added_id = min(placeholder_token_ids)
57
+ self.max_added_id = max(placeholder_token_ids)
58
+
59
+ if mask_list_2 is not None:
60
+ self.mask_list_2 = mask_list_2
61
+ self.mask_label_list_2 = mask_label_list_2
62
+ notation_token_list_2 = [phrase.split(" ")[-1] for phrase in mask_label_list_2]
63
+
64
+ placeholder_token_list_2 = ["$"+word+"{}".format(widx) for widx, word in enumerate(notation_token_list_2)]
65
+ self.set_string_list_2, placeholder_token_ids_2 = self.add_tokens(placeholder_token_list_2)
66
+ self.max_added_id = max(placeholder_token_ids_2)
67
+
68
+ def add_tokens_text_encoder_random_init(self, placeholder_token, num_tokens=1):
69
+ # Add the placeholder token in tokenizer
70
+ placeholder_tokens = [placeholder_token]
71
+ # add dummy tokens for multi-vector
72
+ additional_tokens = []
73
+ for i in range(1, num_tokens):
74
+ additional_tokens.append(f"{placeholder_token}_{i}")
75
+ placeholder_tokens += additional_tokens
76
+ num_added_tokens = self.tokenizer.add_tokens(placeholder_tokens)
77
+ num_added_tokens_2 = self.tokenizer_2.add_tokens(placeholder_tokens)
78
+
79
+ if num_added_tokens != num_tokens or num_added_tokens_2 != num_tokens:
80
+ raise ValueError(
81
+ f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
82
+ " `placeholder_token` that is not already in the tokenizer."
83
+ )
84
+ placeholder_token_ids = self.tokenizer.convert_tokens_to_ids(placeholder_tokens)
85
+ placeholder_token_ids_2 = self.tokenizer_2.convert_tokens_to_ids(placeholder_tokens)
86
+ assert placeholder_token_ids == placeholder_token_ids_2, "Two text encoders are expected to have same vocabs"
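+ # SDXL conditions on two text encoders, so the placeholder tokens are added to both
+ # tokenizers and both embedding tables; the assert above relies on the two vocabularies
+ # assigning identical ids to the new tokens.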
87
+
88
+ self.text_encoder.resize_token_embeddings(len(self.tokenizer))
89
+ token_embeds = self.text_encoder.get_input_embeddings().weight.data
90
+ std, mean = torch.std_mean(token_embeds)
91
+ with torch.no_grad():
92
+ for token_id in placeholder_token_ids:
93
+ token_embeds[token_id] = torch.randn_like(token_embeds[token_id])*std + mean
94
+
95
+ self.text_encoder_2.resize_token_embeddings(len(self.tokenizer_2))
96
+ token_embeds = self.text_encoder_2.get_input_embeddings().weight.data
97
+ std, mean = torch.std_mean(token_embeds)
98
+ with torch.no_grad():
99
+ for token_id in placeholder_token_ids:
100
+ token_embeds[token_id] = torch.randn_like(token_embeds[token_id])*std + mean
101
+
102
+ set_string = " ".join(self.tokenizer.convert_ids_to_tokens(placeholder_token_ids))
103
+
104
+ return set_string, placeholder_token_ids
105
+
106
+ def add_tokens(self, placeholder_token_list):
107
+ set_string_list = []
108
+ placeholder_token_ids_list = []
109
+ for str_idx in range(len(placeholder_token_list)):
110
+ placeholder_token = placeholder_token_list[str_idx]
111
+ set_string, placeholder_token_ids = self.add_tokens_text_encoder_random_init(placeholder_token, num_tokens=self.num_tokens)
112
+ set_string_list.append(set_string)
113
+ placeholder_token_ids_list.append(placeholder_token_ids)
114
+ placeholder_token_ids = list(itertools.chain(*placeholder_token_ids_list))
115
+ return set_string_list, placeholder_token_ids
116
+
117
+ def train_emb(
118
+ self,
119
+ image_gt,
120
+ set_string_list,
121
+ gradient_accumulation_steps = 5,
122
+ embedding_learning_rate = 1e-4,
123
+ max_emb_train_steps = 100,
124
+ train_batch_size = 1,
125
+ train_full_lora = False
126
+ ):
127
+ decom_controller = GroupedCAController(mask_list = self.mask_list)
128
+ register_attention_disentangled_control(self.unet, decom_controller)
129
+
130
+ accelerator = Accelerator(mixed_precision=self.mixed_precision, gradient_accumulation_steps=gradient_accumulation_steps)
131
+ self.vae.requires_grad_(False)
132
+ self.unet.requires_grad_(False)
133
+
134
+ self.text_encoder.requires_grad_(True)
135
+ self.text_encoder_2.requires_grad_(True)
136
+
137
+ self.text_encoder.text_model.encoder.requires_grad_(False)
138
+ self.text_encoder.text_model.final_layer_norm.requires_grad_(False)
139
+ self.text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
140
+
141
+ self.text_encoder_2.text_model.encoder.requires_grad_(False)
142
+ self.text_encoder_2.text_model.final_layer_norm.requires_grad_(False)
143
+ self.text_encoder_2.text_model.embeddings.position_embedding.requires_grad_(False)
144
+
145
+ weight_dtype = torch.float32
146
+ if accelerator.mixed_precision == "fp16":
147
+ weight_dtype = torch.float16
148
+ elif accelerator.mixed_precision == "bf16":
149
+ weight_dtype = torch.bfloat16
150
+
151
+ self.unet.to(device, dtype=weight_dtype)
152
+ self.vae.to(device, dtype=weight_dtype)
153
+
154
+ trainable_embmat_list_1 = [param for param in self.text_encoder.get_input_embeddings().parameters()]
155
+ trainable_embmat_list_2 = [param for param in self.text_encoder_2.get_input_embeddings().parameters()]
156
+
157
+ optimizer = torch.optim.AdamW(trainable_embmat_list_1 + trainable_embmat_list_2, lr=embedding_learning_rate)
158
+
159
+ self.text_encoder, self.text_encoder_2, optimizer = accelerator.prepare(self.text_encoder, self.text_encoder_2, optimizer)
160
+
161
+ orig_embeds_params_1 = accelerator.unwrap_model(self.text_encoder) .get_input_embeddings().weight.data.clone()
162
+ orig_embeds_params_2 = accelerator.unwrap_model(self.text_encoder_2).get_input_embeddings().weight.data.clone()
163
+
164
+ self.text_encoder.train()
165
+ self.text_encoder_2.train()
166
+
167
+ effective_emb_train_steps = max_emb_train_steps//gradient_accumulation_steps
168
+
169
+ if accelerator.is_main_process:
170
+ accelerator.init_trackers("DEdit EmbSteps", config={
171
+ "embedding_learning_rate": embedding_learning_rate,
172
+ "text_embedding_optimization_steps": effective_emb_train_steps,
173
+ })
174
+ global_step = 0
175
+ noise_scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0" , subfolder="scheduler")
176
+ progress_bar = tqdm(range(0, effective_emb_train_steps), initial = global_step, desc="EmbSteps")
177
+ latents0 = image2latent(image_gt, vae = self.vae, dtype=weight_dtype)
178
+ latents0 = latents0.repeat(train_batch_size, 1, 1, 1)
179
+
180
+ for _ in range(max_emb_train_steps):
181
+ with accelerator.accumulate(self.text_encoder, self.text_encoder_2):
182
+ latents = latents0.clone().detach()
183
+ noise = torch.randn_like(latents)
184
+ bsz = latents.shape[0]
185
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
186
+ timesteps = timesteps.long()
187
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
188
+ encoder_hidden_states_list, add_text_embeds, add_time_ids = sdxl_prepare_input_decom(
189
+ set_string_list,
190
+ self.tokenizer,
191
+ self.tokenizer_2,
192
+ self.text_encoder,
193
+ self.text_encoder_2,
194
+ length = max_length,
195
+ bsz = train_batch_size,
196
+ weight_dtype = weight_dtype
197
+ )
198
+
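+ # SDXL additionally conditions on pooled text embeddings and size/crop "time ids",
+ # passed via added_cond_kwargs alongside the per-mask encoder hidden states.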
199
+ model_pred = self.unet(
200
+ noisy_latents,
201
+ timesteps,
202
+ encoder_hidden_states = encoder_hidden_states_list,
203
+ cross_attention_kwargs = None,
204
+ added_cond_kwargs={"text_embeds": add_text_embeds, "time_ids": add_time_ids},
205
+ return_dict=False
206
+ )[0]
207
+ loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")
208
+ accelerator.backward(loss)
209
+ optimizer.step()
210
+ optimizer.zero_grad()
211
+
212
+ index_no_updates = torch.ones((len(self.tokenizer),), dtype=torch.bool)
213
+ index_no_updates[self.min_added_id : self.max_added_id + 1] = False
214
+ with torch.no_grad():
215
+ accelerator.unwrap_model(self.text_encoder).get_input_embeddings().weight[
216
+ index_no_updates] = orig_embeds_params_1[index_no_updates]
217
+
218
+ index_no_updates = torch.ones((len(self.tokenizer_2),), dtype=torch.bool)
219
+ index_no_updates[self.min_added_id : self.max_added_id + 1] = False
220
+ with torch.no_grad():
221
+ accelerator.unwrap_model(self.text_encoder_2).get_input_embeddings().weight[
222
+ index_no_updates] = orig_embeds_params_2[index_no_updates]
223
+
224
+ logs = {"loss": loss.detach().item(), "lr": embedding_learning_rate}
225
+ progress_bar.set_postfix(**logs)
226
+ accelerator.log(logs, step=global_step)
227
+ if accelerator.sync_gradients:
228
+ progress_bar.update(1)
229
+ global_step += 1
230
+
231
+ if global_step >= max_emb_train_steps:
232
+ break
233
+ accelerator.wait_for_everyone()
234
+ accelerator.end_training()
235
+ self.text_encoder = accelerator.unwrap_model(self.text_encoder).to(dtype = weight_dtype)
236
+ self.text_encoder_2 = accelerator.unwrap_model(self.text_encoder_2).to(dtype = weight_dtype)
237
+
238
+ def train_model(
239
+ self,
240
+ image_gt,
241
+ set_string_list,
242
+ gradient_accumulation_steps = 5,
243
+ max_diffusion_train_steps = 100,
244
+ diffusion_model_learning_rate = 1e-5,
245
+ train_batch_size = 1,
246
+ train_full_lora = False,
247
+ lora_rank = 4,
248
+ lora_alpha = 4
249
+ ):
250
+ self.unet = UNet2DConditionModel.from_pretrained(self.model_id, subfolder="unet").to(device)
251
+ self.unet.ca_dim = 2048
252
+ decom_controller = GroupedCAController(mask_list = self.mask_list)
253
+ register_attention_disentangled_control(self.unet, decom_controller)
254
+
255
+ mixed_precision = "fp16"
256
+ accelerator = Accelerator(gradient_accumulation_steps = gradient_accumulation_steps, mixed_precision = mixed_precision)
257
+
258
+ weight_dtype = torch.float32
259
+ if accelerator.mixed_precision == "fp16":
260
+ weight_dtype = torch.float16
261
+ elif accelerator.mixed_precision == "bf16":
262
+ weight_dtype = torch.bfloat16
263
+
264
+ self.vae.requires_grad_(False)
265
+ self.vae.to(device, dtype=weight_dtype)
266
+
267
+ self.unet.requires_grad_(False)
268
+ self.unet.train()
269
+
270
+ self.text_encoder.requires_grad_(False)
271
+ self.text_encoder_2.requires_grad_(False)
272
+
273
+ if not train_full_lora:
274
+ trainable_params_list = []
275
+ for _, module in self.unet.named_modules():
276
+ module_name = type(module).__name__
277
+ if module_name == "Attention":
278
+ if module.to_k.in_features == 2048: # this is cross attention:
279
+ module.to_k.weight.requires_grad = True
280
+ trainable_params_list.append(module.to_k.weight)
281
+ if module.to_k.bias is not None:
282
+ module.to_k.bias.requires_grad = True
283
+ trainable_params_list.append(module.to_k.bias)
284
+ module.to_v.weight.requires_grad = True
285
+ trainable_params_list.append(module.to_v.weight)
286
+ if module.to_v.bias is not None:
287
+ module.to_v.bias.requires_grad = True
288
+ trainable_params_list.append(module.to_v.bias)
289
+ module.to_q.weight.requires_grad = True
290
+ trainable_params_list.append(module.to_q.weight)
291
+ if module.to_q.bias is not None:
292
+ module.to_q.bias.requires_grad = True
293
+ trainable_params_list.append(module.to_q.bias)
294
+ else:
295
+ unet_lora_config = LoraConfig(
296
+ r=lora_rank,
297
+ lora_alpha=lora_alpha,
298
+ init_lora_weights="gaussian",
299
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
300
+ )
301
+ self.unet.add_adapter(unet_lora_config)
302
+ print("training full parameters using lora!")
303
+ trainable_params_list = list(filter(lambda p: p.requires_grad, self.unet.parameters()))
304
+
305
+ self.text_encoder.to(device, dtype=weight_dtype)
306
+ self.text_encoder_2.to(device, dtype=weight_dtype)
307
+ optimizer = torch.optim.AdamW(trainable_params_list, lr=diffusion_model_learning_rate)
308
+ self.unet, optimizer = accelerator.prepare(self.unet, optimizer)
309
+ psum2 = sum(p.numel() for p in trainable_params_list)
310
+
311
+ effective_diffusion_train_steps = max_diffusion_train_steps // gradient_accumulation_steps
312
+ if accelerator.is_main_process:
313
+ accelerator.init_trackers("textual_inversion", config={
314
+ "diffusion_model_learning_rate": diffusion_model_learning_rate,
315
+ "diffusion_model_optimization_steps": effective_diffusion_train_steps,
316
+ })
317
+
318
+ global_step = 0
319
+ progress_bar = tqdm( range(0, effective_diffusion_train_steps),initial=global_step, desc="ModelSteps")
320
+
321
+ noise_scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0" , subfolder="scheduler")
322
+
323
+ latents0 = image2latent(image_gt, vae = self.vae, dtype=weight_dtype)
324
+ latents0 = latents0.repeat(train_batch_size, 1, 1, 1)
325
+
326
+ with torch.no_grad():
327
+ encoder_hidden_states_list, add_text_embeds, add_time_ids = sdxl_prepare_input_decom(
328
+ set_string_list,
329
+ self.tokenizer,
330
+ self.tokenizer_2,
331
+ self.text_encoder,
332
+ self.text_encoder_2,
333
+ length = max_length,
334
+ bsz = train_batch_size,
335
+ weight_dtype = weight_dtype
336
+ )
337
+
338
+ for _ in range(max_diffusion_train_steps):
339
+ with accelerator.accumulate(self.unet):
340
+ latents = latents0.clone().detach()
341
+ noise = torch.randn_like(latents)
342
+ bsz = latents.shape[0]
343
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
344
+ timesteps = timesteps.long()
345
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
346
+ model_pred = self.unet(
347
+ noisy_latents,
348
+ timesteps,
349
+ encoder_hidden_states=encoder_hidden_states_list,
350
+ cross_attention_kwargs=None, return_dict=False,
351
+ added_cond_kwargs={"text_embeds": add_text_embeds, "time_ids": add_time_ids}
352
+ )[0]
353
+ loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")
354
+ accelerator.backward(loss)
355
+ optimizer.step()
356
+ optimizer.zero_grad()
357
+
358
+ logs = {"loss": loss.detach().item(), "lr": diffusion_model_learning_rate}
359
+ progress_bar.set_postfix(**logs)
360
+ accelerator.log(logs, step=global_step)
361
+ if accelerator.sync_gradients:
362
+ progress_bar.update(1)
363
+ global_step += 1
364
+ if global_step >=max_diffusion_train_steps:
365
+ break
366
+ accelerator.wait_for_everyone()
367
+ accelerator.end_training()
368
+ self.unet = accelerator.unwrap_model(self.unet).to(dtype = weight_dtype)
369
+
370
+ def train_emb_2imgs(
371
+ self,
372
+ image_gt_1,
373
+ image_gt_2,
374
+ set_string_list_1,
375
+ set_string_list_2,
376
+ gradient_accumulation_steps = 5,
377
+ embedding_learning_rate = 1e-4,
378
+ max_emb_train_steps = 100,
379
+ train_batch_size = 1,
380
+ train_full_lora = False
381
+ ):
382
+ decom_controller_1 = GroupedCAController(mask_list = self.mask_list)
383
+ decom_controller_2 = GroupedCAController(mask_list = self.mask_list_2)
384
+ accelerator = Accelerator(mixed_precision=self.mixed_precision, gradient_accumulation_steps=gradient_accumulation_steps)
385
+ self.vae.requires_grad_(False)
386
+ self.unet.requires_grad_(False)
387
+
388
+ self.text_encoder.requires_grad_(True)
389
+ self.text_encoder_2.requires_grad_(True)
390
+
391
+ self.text_encoder.text_model.encoder.requires_grad_(False)
392
+ self.text_encoder.text_model.final_layer_norm.requires_grad_(False)
393
+ self.text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
394
+
395
+ self.text_encoder_2.text_model.encoder.requires_grad_(False)
396
+ self.text_encoder_2.text_model.final_layer_norm.requires_grad_(False)
397
+ self.text_encoder_2.text_model.embeddings.position_embedding.requires_grad_(False)
398
+
399
+ weight_dtype = torch.float32
400
+ if accelerator.mixed_precision == "fp16":
401
+ weight_dtype = torch.float16
402
+ elif accelerator.mixed_precision == "bf16":
403
+ weight_dtype = torch.bfloat16
404
+
405
+ self.unet.to(device, dtype=weight_dtype)
406
+ self.vae.to(device, dtype=weight_dtype)
407
+
408
+
409
+ trainable_embmat_list_1 = [param for param in self.text_encoder.get_input_embeddings().parameters()]
410
+ trainable_embmat_list_2 = [param for param in self.text_encoder_2.get_input_embeddings().parameters()]
411
+
412
+ optimizer = torch.optim.AdamW(trainable_embmat_list_1 + trainable_embmat_list_2, lr=embedding_learning_rate)
413
+ self.text_encoder, self.text_encoder_2, optimizer= accelerator.prepare(self.text_encoder, self.text_encoder_2, optimizer) ###
414
+ orig_embeds_params_1 = accelerator.unwrap_model(self.text_encoder) .get_input_embeddings().weight.data.clone()
415
+ orig_embeds_params_2 = accelerator.unwrap_model(self.text_encoder_2).get_input_embeddings().weight.data.clone()
416
+
417
+ self.text_encoder.train()
418
+ self.text_encoder_2.train()
419
+
420
+ effective_emb_train_steps = max_emb_train_steps//gradient_accumulation_steps
421
+
422
+ if accelerator.is_main_process:
423
+ accelerator.init_trackers("EmbFt", config={
424
+ "embedding_learning_rate": embedding_learning_rate,
425
+ "text_embedding_optimization_steps": effective_emb_train_steps,
426
+ })
427
+
428
+ global_step = 0
429
+
430
+ noise_scheduler = DDPMScheduler.from_pretrained(self.model_id , subfolder="scheduler")
431
+ progress_bar = tqdm(range(0, effective_emb_train_steps),initial=global_step,desc="EmbSteps")
432
+ latents0_1 = image2latent(image_gt_1, vae = self.vae, dtype=weight_dtype)
433
+ latents0_1 = latents0_1.repeat(train_batch_size,1,1,1)
434
+
435
+ latents0_2 = image2latent(image_gt_2, vae = self.vae, dtype=weight_dtype)
436
+ latents0_2 = latents0_2.repeat(train_batch_size,1,1,1)
437
+
438
+ for step in range(max_emb_train_steps):
439
+ with accelerator.accumulate(self.text_encoder, self.text_encoder_2):
440
+ latents_1 = latents0_1.clone().detach()
441
+ noise_1 = torch.randn_like(latents_1)
442
+
443
+ latents_2 = latents0_2.clone().detach()
444
+ noise_2 = torch.randn_like(latents_2)
445
+
446
+ bsz = latents_1.shape[0]
447
+
448
+ timesteps_1 = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents_1.device)
449
+ timesteps_1 = timesteps_1.long()
450
+ noisy_latents_1 = noise_scheduler.add_noise(latents_1, noise_1, timesteps_1)
451
+
452
+ timesteps_2 = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents_2.device)
453
+ timesteps_2 = timesteps_2.long()
454
+ noisy_latents_2 = noise_scheduler.add_noise(latents_2, noise_2, timesteps_2)
455
+
456
+ register_attention_disentangled_control(self.unet, decom_controller_1)
457
+ encoder_hidden_states_list_1, add_text_embeds_1, add_time_ids_1 = sdxl_prepare_input_decom(
458
+ set_string_list_1,
459
+ self.tokenizer,
460
+ self.tokenizer_2,
461
+ self.text_encoder,
462
+ self.text_encoder_2,
463
+ length = max_length,
464
+ bsz = train_batch_size,
465
+ weight_dtype = weight_dtype
466
+ )
467
+
468
+ model_pred_1 = self.unet(
469
+ noisy_latents_1,
470
+ timesteps_1,
471
+ encoder_hidden_states=encoder_hidden_states_list_1,
472
+ cross_attention_kwargs=None,
473
+ added_cond_kwargs={"text_embeds": add_text_embeds_1, "time_ids": add_time_ids_1},
474
+ return_dict=False
475
+ )[0]
476
+
477
+ register_attention_disentangled_control(self.unet, decom_controller_2)
478
+ # import pdb; pdb.set_trace()
479
+ encoder_hidden_states_list_2, add_text_embeds_2, add_time_ids_2 = sdxl_prepare_input_decom(
480
+ set_string_list_2,
481
+ self.tokenizer,
482
+ self.tokenizer_2,
483
+ self.text_encoder,
484
+ self.text_encoder_2,
485
+ length = max_length,
486
+ bsz = train_batch_size,
487
+ weight_dtype = weight_dtype
488
+ )
489
+
490
+ model_pred_2 = self.unet(
491
+ noisy_latents_2,
492
+ timesteps_2,
493
+ encoder_hidden_states = encoder_hidden_states_list_2,
494
+ cross_attention_kwargs=None,
495
+ added_cond_kwargs={"text_embeds": add_text_embeds_2, "time_ids": add_time_ids_2},
496
+ return_dict=False
497
+ )[0]
498
+
499
+ loss_1 = F.mse_loss(model_pred_1.float(), noise_1.float(), reduction="mean") /2
500
+ loss_2 = F.mse_loss(model_pred_2.float(), noise_2.float(), reduction="mean") /2
501
+ loss = loss_1 + loss_2
502
+ accelerator.backward(loss)
503
+ optimizer.step()
504
+ optimizer.zero_grad()
505
+
506
+ index_no_updates = torch.ones((len(self.tokenizer),), dtype=torch.bool)
507
+ index_no_updates[self.min_added_id : self.max_added_id + 1] = False
508
+ with torch.no_grad():
509
+ accelerator.unwrap_model(self.text_encoder).get_input_embeddings().weight[
510
+ index_no_updates] = orig_embeds_params_1[index_no_updates]
511
+ index_no_updates = torch.ones((len(self.tokenizer_2),), dtype=torch.bool)
512
+ index_no_updates[self.min_added_id : self.max_added_id + 1] = False
513
+ with torch.no_grad():
514
+ accelerator.unwrap_model(self.text_encoder_2).get_input_embeddings().weight[
515
+ index_no_updates] = orig_embeds_params_2[index_no_updates]
516
+
517
+ logs = {"loss": loss.detach().item(), "lr": embedding_learning_rate}
518
+ progress_bar.set_postfix(**logs)
519
+ accelerator.log(logs, step=global_step)
520
+ if accelerator.sync_gradients:
521
+ progress_bar.update(1)
522
+ global_step += 1
523
+
524
+ if global_step >= max_emb_train_steps:
525
+ break
526
+ accelerator.wait_for_everyone()
527
+ accelerator.end_training()
528
+ self.text_encoder = accelerator.unwrap_model(self.text_encoder) .to(dtype = weight_dtype)
529
+ self.text_encoder_2 = accelerator.unwrap_model(self.text_encoder_2).to(dtype = weight_dtype)
530
+
531
+ def train_model_2imgs(
532
+ self,
533
+ image_gt_1,
534
+ image_gt_2,
535
+ set_string_list_1,
536
+ set_string_list_2,
537
+ gradient_accumulation_steps = 5,
538
+ max_diffusion_train_steps = 100,
539
+ diffusion_model_learning_rate = 1e-5,
540
+ train_batch_size = 1,
541
+ train_full_lora = False,
542
+ lora_rank = 4,
543
+ lora_alpha = 4
544
+ ):
545
+ self.unet = UNet2DConditionModel.from_pretrained(self.model_id, subfolder="unet").to(device)
546
+ self.unet.ca_dim = 2048
547
+ decom_controller_1 = GroupedCAController(mask_list = self.mask_list)
548
+ decom_controller_2 = GroupedCAController(mask_list = self.mask_list_2)
549
+
550
+ mixed_precision = "fp16"
551
+ accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps,mixed_precision=mixed_precision)
552
+
553
+ weight_dtype = torch.float32
554
+ if accelerator.mixed_precision == "fp16":
555
+ weight_dtype = torch.float16
556
+ elif accelerator.mixed_precision == "bf16":
557
+ weight_dtype = torch.bfloat16
558
+
559
+
560
+ self.vae.requires_grad_(False)
561
+ self.vae.to(device, dtype=weight_dtype)
562
+ self.unet.requires_grad_(False)
563
+ self.unet.train()
564
+
565
+ self.text_encoder.requires_grad_(False)
566
+ self.text_encoder_2.requires_grad_(False)
567
+ if not train_full_lora:
568
+ trainable_params_list = []
569
+ for name, module in self.unet.named_modules():
570
+ module_name = type(module).__name__
571
+ if module_name == "Attention":
572
+ if module.to_k.in_features == 2048: # this is cross attention:
573
+ module.to_k.weight.requires_grad = True
574
+ trainable_params_list.append(module.to_k.weight)
575
+ if module.to_k.bias is not None:
576
+ module.to_k.bias.requires_grad = True
577
+ trainable_params_list.append(module.to_k.bias)
578
+
579
+ module.to_v.weight.requires_grad = True
580
+ trainable_params_list.append(module.to_v.weight)
581
+ if module.to_v.bias is not None:
582
+ module.to_v.bias.requires_grad = True
583
+ trainable_params_list.append(module.to_v.bias)
584
+ module.to_q.weight.requires_grad = True
585
+ trainable_params_list.append(module.to_q.weight)
586
+ if module.to_q.bias is not None:
587
+ module.to_q.bias.requires_grad = True
588
+ trainable_params_list.append(module.to_q.bias)
589
+ else:
590
+ unet_lora_config = LoraConfig(
591
+ r = lora_rank,
592
+ lora_alpha = lora_alpha,
593
+ init_lora_weights="gaussian",
594
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
595
+ )
596
+ self.unet.add_adapter(unet_lora_config)
597
+ print("training full parameters using lora!")
598
+ trainable_params_list = list(filter(lambda p: p.requires_grad, self.unet.parameters()))
599
+
600
+ self.text_encoder.to(device, dtype=weight_dtype)
601
+ self.text_encoder_2.to(device, dtype=weight_dtype)
602
+ optimizer = torch.optim.AdamW(trainable_params_list, lr=diffusion_model_learning_rate)
603
+ self.unet, optimizer = accelerator.prepare(self.unet, optimizer)
604
+ psum2 = sum(p.numel() for p in trainable_params_list)  # total number of trainable parameters (useful for logging)
605
+
606
+ effective_diffusion_train_steps = max_diffusion_train_steps // gradient_accumulation_steps
607
+ if accelerator.is_main_process:
608
+ accelerator.init_trackers("ModelFt", config={
609
+ "diffusion_model_learning_rate": diffusion_model_learning_rate,
610
+ "diffusion_model_optimization_steps": effective_diffusion_train_steps,
611
+ })
612
+
613
+ global_step = 0
614
+ progress_bar = tqdm(range(0, effective_diffusion_train_steps),initial=global_step, desc="ModelSteps")
615
+ noise_scheduler = DDPMScheduler.from_pretrained(self.model_id, subfolder="scheduler")
616
+
617
+ latents0_1 = image2latent(image_gt_1, vae = self.vae, dtype=weight_dtype)
618
+ latents0_1 = latents0_1.repeat(train_batch_size, 1, 1, 1)
619
+
620
+ latents0_2 = image2latent(image_gt_2, vae = self.vae, dtype=weight_dtype)
621
+ latents0_2 = latents0_2.repeat(train_batch_size,1, 1, 1)
622
+
623
+ with torch.no_grad():
624
+ encoder_hidden_states_list_1, add_text_embeds_1, add_time_ids_1 = sdxl_prepare_input_decom(
625
+ set_string_list_1,
626
+ self.tokenizer,
627
+ self.tokenizer_2,
628
+ self.text_encoder,
629
+ self.text_encoder_2,
630
+ length = max_length,
631
+ bsz = train_batch_size,
632
+ weight_dtype = weight_dtype
633
+ )
634
+ encoder_hidden_states_list_2, add_text_embeds_2, add_time_ids_2 = sdxl_prepare_input_decom(
635
+ set_string_list_2,
636
+ self.tokenizer,
637
+ self.tokenizer_2,
638
+ self.text_encoder,
639
+ self.text_encoder_2,
640
+ length = max_length,
641
+ bsz = train_batch_size,
642
+ weight_dtype = weight_dtype
643
+ )
644
+
645
+ for _ in range(max_diffusion_train_steps):
646
+ with accelerator.accumulate(self.unet):
647
+ latents_1 = latents0_1.clone().detach()
648
+ noise_1 = torch.randn_like(latents_1)
649
+ bsz = latents_1.shape[0]
650
+ timesteps_1 = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents_1.device)
651
+ timesteps_1 = timesteps_1.long()
652
+ noisy_latents_1 = noise_scheduler.add_noise(latents_1, noise_1, timesteps_1)
653
+
654
+ latents_2 = latents0_2.clone().detach()
655
+ noise_2 = torch.randn_like(latents_2)
656
+ bsz = latents_2.shape[0]
657
+ timesteps_2 = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents_2.device)
658
+ timesteps_2 = timesteps_2.long()
659
+ noisy_latents_2 = noise_scheduler.add_noise(latents_2, noise_2, timesteps_2)
660
+
661
+ register_attention_disentangled_control(self.unet, decom_controller_1)
662
+ model_pred_1 = self.unet(
663
+ noisy_latents_1,
664
+ timesteps_1,
665
+ encoder_hidden_states = encoder_hidden_states_list_1,
666
+ cross_attention_kwargs = None,
667
+ return_dict = False,
668
+ added_cond_kwargs = {"text_embeds": add_text_embeds_1, "time_ids": add_time_ids_1}
669
+ )[0]
670
+
671
+ register_attention_disentangled_control(self.unet, decom_controller_2)
672
+ model_pred_2 = self.unet(
673
+ noisy_latents_2,
674
+ timesteps_2,
675
+ encoder_hidden_states = encoder_hidden_states_list_2,
676
+ cross_attention_kwargs = None,
677
+ return_dict=False,
678
+ added_cond_kwargs={"text_embeds": add_text_embeds_2, "time_ids": add_time_ids_2}
679
+ )[0]
680
+
681
+ loss_1 = F.mse_loss(model_pred_1.float(), noise_1.float(), reduction="mean")
682
+ loss_2 = F.mse_loss(model_pred_2.float(), noise_2.float(), reduction="mean")
683
+ loss = loss_1 + loss_2
684
+ accelerator.backward(loss)
685
+ optimizer.step()
686
+ optimizer.zero_grad()
687
+
688
+
689
+ logs = {"loss": loss.detach().item(), "lr": diffusion_model_learning_rate}
690
+ progress_bar.set_postfix(**logs)
691
+ accelerator.log(logs, step=global_step)
692
+ if accelerator.sync_gradients:
693
+ progress_bar.update(1)
694
+ global_step += 1
695
+
696
+ if global_step >=max_diffusion_train_steps:
697
+ break
698
+ accelerator.wait_for_everyone()
699
+ accelerator.end_training()
700
+ self.unet = accelerator.unwrap_model(self.unet).to(dtype = weight_dtype)
701
+
702
+ @torch.no_grad()
703
+ def backward_zT_to_z0_euler_decom(
704
+ self,
705
+ zT,
706
+ cond_emb_list,
707
+ cond_add_text_embeds,
708
+ add_time_ids,
709
+ uncond_emb=None,
710
+ guidance_scale = 1,
711
+ num_sampling_steps = 20,
712
+ cond_controller = None,
713
+ uncond_controller = None,
714
+ mask_hard = None,
715
+ mask_soft = None,
716
+ orig_image = None,
717
+ return_intermediate = False,
718
+ strength = 1
719
+ ):
720
+ latent_cur = zT
721
+ if uncond_emb is None:
722
+ uncond_emb = torch.zeros(zT.shape[0], 77, 2048).to(dtype = zT.dtype, device = zT.device)
723
+ uncond_add_text_embeds = torch.zeros(1, 1280).to(dtype = zT.dtype, device = zT.device)
724
+ if mask_soft is not None:
725
+ init_latents_orig = image2latent(orig_image, self.vae, dtype=self.vae.dtype)
726
+ length = init_latents_orig.shape[-1]
727
+ noise = torch.randn_like(init_latents_orig)
728
+ mask_soft = torch.nn.functional.interpolate(mask_soft.float().unsqueeze(0).unsqueeze(0), (length, length)).to(self.vae.dtype) ###
729
+ if mask_hard is not None:
730
+ init_latents_orig = image2latent(orig_image, self.vae, dtype=self.vae.dtype)
731
+ length = init_latents_orig.shape[-1]
732
+ noise = torch.randn_like(init_latents_orig)
733
+ mask_hard = torch.nn.functional.interpolate(mask_hard.float().unsqueeze(0).unsqueeze(0), (length, length)).to(self.vae.dtype) ###
734
+
735
+ intermediate_list = [latent_cur.detach()]
736
+ for i in tqdm(range(num_sampling_steps)):
737
+ t = self.scheduler.timesteps[i]
738
+ latent_input = self.scheduler.scale_model_input(latent_cur, t)
739
+
740
+ register_attention_disentangled_control(self.unet, uncond_controller)
741
+ noise_pred_uncond = self.unet(latent_input, t,
742
+ encoder_hidden_states=uncond_emb,
743
+ added_cond_kwargs={"text_embeds": uncond_add_text_embeds, "time_ids": add_time_ids},
744
+ return_dict=False,)[0]
745
+
746
+ register_attention_disentangled_control(self.unet, cond_controller)
747
+ noise_pred_cond = self.unet(latent_input, t,
748
+ encoder_hidden_states=cond_emb_list,
749
+ added_cond_kwargs={"text_embeds": cond_add_text_embeds, "time_ids": add_time_ids},
750
+ return_dict=False,)[0]
751
+
752
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
753
+ latent_cur = self.scheduler.step(noise_pred, t, latent_cur, generator = None, return_dict=False)[0]
754
+ if return_intermediate is True:
755
+ intermediate_list.append(latent_cur)
756
+ if mask_hard is not None and mask_soft is not None and i <= strength *num_sampling_steps:
757
+ init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
758
+ mask = mask_soft.to(latent_cur.device, latent_cur.dtype) + mask_hard.to(latent_cur.device, latent_cur.dtype)
759
+ latent_cur = (init_latents_proper * mask) + (latent_cur * (1 - mask))
760
+
761
+ elif mask_hard is not None and mask_soft is not None and i > strength *num_sampling_steps:
762
+ init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
763
+ mask = mask_hard.to(latent_cur.device, latent_cur.dtype)
764
+ latent_cur = (init_latents_proper * mask) + (latent_cur * (1 - mask))
765
+
766
+ elif mask_hard is None and mask_soft is not None and i <= strength *num_sampling_steps:
767
+ init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
768
+ mask = mask_soft.to(latent_cur.device, latent_cur.dtype)
769
+ latent_cur = (init_latents_proper * mask) + (latent_cur * (1 - mask))
770
+
771
+ elif mask_hard is None and mask_soft is not None and i > strength *num_sampling_steps:
772
+ pass
773
+
774
+ elif mask_hard is not None and mask_soft is None:
775
+ init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
776
+ mask = mask_hard.to(latent_cur.dtype)
777
+ latent_cur = (init_latents_proper * mask) + (latent_cur * (1 - mask))
778
+
779
+ else: # hard and soft are both none
780
+ pass
781
+
782
+ if return_intermediate is True:
783
+ return latent_cur, intermediate_list
784
+ else:
785
+ return latent_cur
786
+
787
+ @torch.no_grad()
788
+ def sampling(
789
+ self,
790
+ set_string_list,
791
+ cond_controller = None,
792
+ uncond_controller = None,
793
+ guidance_scale = 7,
794
+ num_sampling_steps = 20,
795
+ mask_hard = None,
796
+ mask_soft = None,
797
+ orig_image = None,
798
+ strength = 1.,
799
+ num_imgs = 1,
800
+ normal_token_id_list = [],
801
+ seed = 1
802
+ ):
803
+ weight_dtype = torch.float16
804
+ self.scheduler.set_timesteps(num_sampling_steps)
805
+ self.unet.to(device, dtype=weight_dtype)
806
+ self.vae.to(device, dtype=weight_dtype)
807
+ self.text_encoder.to(device, dtype=weight_dtype)
808
+ self.text_encoder_2.to(device, dtype=weight_dtype)
809
+ torch.manual_seed(seed)
810
+ torch.cuda.manual_seed(seed)
811
+
812
+ vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
813
+ zT = torch.randn(num_imgs, 4, self.resolution//vae_scale_factor,self.resolution//vae_scale_factor).to(device,dtype=weight_dtype)
814
+ zT = zT * self.scheduler.init_noise_sigma
815
+
816
+ cond_emb_list, cond_add_text_embeds, add_time_ids = sdxl_prepare_input_decom(
817
+ set_string_list,
818
+ self.tokenizer,
819
+ self.tokenizer_2,
820
+ self.text_encoder,
821
+ self.text_encoder_2,
822
+ length = max_length,
823
+ bsz = num_imgs,
824
+ weight_dtype = weight_dtype,
825
+ normal_token_id_list = normal_token_id_list
826
+ )
827
+
828
+ z0 = self.backward_zT_to_z0_euler_decom(zT, cond_emb_list, cond_add_text_embeds, add_time_ids,
829
+ guidance_scale = guidance_scale, num_sampling_steps = num_sampling_steps,
830
+ cond_controller = cond_controller, uncond_controller = uncond_controller,
831
+ mask_hard = mask_hard, mask_soft = mask_soft, orig_image =orig_image, strength = strength
832
+ )
833
+ x0 = latent2image(z0, vae = self.vae)
834
+ return x0
835
+
836
+ @torch.no_grad()
837
+ def inference_with_mask(
838
+ self,
839
+ save_path,
840
+ guidance_scale = 3,
841
+ num_sampling_steps = 50,
842
+ strength = 1,
843
+ mask_soft = None,
844
+ mask_hard= None,
845
+ orig_image=None,
846
+ mask_list = None,
847
+ num_imgs = 1,
848
+ seed = 1,
849
+ set_string_list = None
850
+ ):
851
+ if mask_list is not None:
852
+ mask_list = [m.to(device) for m in mask_list]
853
+ else:
854
+ mask_list = self.mask_list
855
+ if set_string_list is not None:
856
+ self.set_string_list = set_string_list
857
+
858
+ if mask_hard is not None and mask_soft is not None:
859
+ check_mask_overlap_torch(mask_hard, mask_soft)
860
+ null_controller = DummyController()
861
+ decom_controller = GroupedCAController(mask_list = mask_list)
862
+ x0 = self.sampling(
863
+ self.set_string_list,
864
+ guidance_scale = guidance_scale,
865
+ num_sampling_steps = num_sampling_steps,
866
+ strength = strength,
867
+ cond_controller = decom_controller,
868
+ uncond_controller = null_controller,
869
+ mask_soft = mask_soft,
870
+ mask_hard = mask_hard,
871
+ orig_image = orig_image,
872
+ num_imgs = num_imgs,
873
+ seed = seed
874
+ )
875
+ save_images(x0, save_path)
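Note on the sampling loop above: backward_zT_to_z0_euler_decom combines classifier-free guidance on the noise prediction with a latent blending step, in which the region covered by the (hard or soft) mask is replaced by a re-noised copy of the original image's latents while the rest keeps the edited latent. A minimal, self-contained sketch of those two update rules on dummy tensors (no UNet, VAE, or scheduler assumed; all shapes are placeholders):

import torch

# dummy stand-ins for the quantities used in backward_zT_to_z0_euler_decom
latent_cur = torch.randn(1, 4, 64, 64)            # current edited latent at step i
init_latents_proper = torch.randn(1, 4, 64, 64)   # original-image latents re-noised to timestep t
mask = (torch.rand(1, 1, 64, 64) > 0.5).float()   # 1 = keep original content, 0 = keep the edit

# classifier-free guidance on the noise prediction
noise_pred_uncond = torch.randn(1, 4, 64, 64)
noise_pred_cond = torch.randn(1, 4, 64, 64)
guidance_scale = 3.0
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

# latent blending: original content inside the mask, edited content outside
latent_cur = init_latents_proper * mask + latent_cur * (1 - mask)
print(noise_pred.shape, latent_cur.shape)

In the real loop, init_latents_proper comes from scheduler.add_noise(init_latents_orig, noise, t) and mask is the interpolated mask_hard / mask_soft tensor.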
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.2.0
2
+ torchvision==0.17.0
3
+ transformers==4.37.2
4
+ accelerate==0.23.0
5
+ gradio==3.41.1
6
+ xformers==0.0.24
7
+ diffusers==0.26.3
8
+ scipy
9
+ tqdm
10
+ numpy
11
+ safetensors
12
+ peft
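A small optional sanity check (not part of the repository) that the pinned stack above imports cleanly, assuming the packages were installed from this requirements.txt:

# optional environment check for the pinned dependencies above
import torch, torchvision, transformers, accelerate, diffusers, peft

print("torch", torch.__version__, "| cuda available:", torch.cuda.is_available())
print("torchvision", torchvision.__version__, "| accelerate", accelerate.__version__)
print("diffusers", diffusers.__version__, "| transformers", transformers.__version__)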
scripts/run_segment.sh ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ export IMAGE_NAME="example1"
2
+ # python segment.py --name=$IMAGE_NAME --size=512
3
+ python segment.py --name=$IMAGE_NAME --size=1024
scripts/run_segmentSAM.sh ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ export IMAGE_NAME="example1"
2
+ python segment_sam.py --name=$IMAGE_NAME --text_prompt="bag"
scripts/sd/run_ft_sd_512.sh ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export IMAGE_NAME="example1"
2
+ python main.py --name=$IMAGE_NAME \
3
+ --dpm="sd" \
4
+ --resolution=512 \
5
+ --num_tokens=5 \
6
+ --embedding_learning_rate=1e-4 \
7
+ --diffusion_model_learning_rate=5e-5 \
8
+ --max_emb_train_steps=500 \
9
+ --max_diffusion_train_steps=500 \
10
+ --train_batch_size=5 \
11
+ --gradient_accumulation_steps=5
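For reference, with gradient accumulation the trainers above only perform an optimizer update every gradient_accumulation_steps iterations, so the number of actual parameter updates is max_*_train_steps // gradient_accumulation_steps (this mirrors the effective_diffusion_train_steps computation in the training code). A quick arithmetic sketch for the settings in this script:

# effective optimizer updates for the flags used above
max_emb_train_steps = 500
max_diffusion_train_steps = 500
gradient_accumulation_steps = 5
print(max_emb_train_steps // gradient_accumulation_steps)        # 100 embedding updates
print(max_diffusion_train_steps // gradient_accumulation_steps)  # 100 diffusion-model updates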
scripts/sd/run_ft_sd_512_2imgs.sh ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export IMAGE_NAME="example1"
2
+ export IMAGE_NAME_2="example2"
3
+ python main.py --name=$IMAGE_NAME \
4
+ --dpm="sd" \
5
+ --resolution=512 \
6
+ --image \
7
+ --name_2=$IMAGE_NAME_2 \
8
+ --embedding_learning_rate=1e-4 \
9
+ --diffusion_model_learning_rate=5e-5 \
10
+ --max_emb_train_steps=500 \
11
+ --max_diffusion_train_steps=500 \
12
+ --train_batch_size=5 \
13
+ --gradient_accumulation_steps=5
14
+
scripts/sd/run_image.sh ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export IMAGE_NAME="example1"
2
+ export IMAGE_NAME_2="example2"
3
+ python main.py --name=$IMAGE_NAME \
4
+ --name_2=$IMAGE_NAME_2 \
5
+ --dpm="sd" \
6
+ --resolution=512 \
7
+ --image \
8
+ --load_trained \
9
+ --guidance_scale=2 \
10
+ --num_imgs=2 \
11
+ --seed=2024 \
12
+ --strength=0.5 \
13
+ --edge_thickness=10 \
14
+ --src_index=1 --tgt_index=0 \
15
+ --tgt_name=$IMAGE_NAME
scripts/sd/run_move_resize.sh ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export IMAGE_NAME="example1"
2
+ CUDA_VISIBLE_DEVICES=1 python main.py --name=$IMAGE_NAME \
3
+ --dpm="sd" \
4
+ --resolution=512 \
5
+ --num_tokens=5 \
6
+ --load_trained \
7
+ --move_resize \
8
+ --seed=2023 \
9
+ --num_sampling_step=50 \
10
+ --strength=0.6 \
11
+ --edge_thickness=10 \
12
+ --guidance_scale=2 \
13
+ --num_imgs=1 \
14
+ --tgt_indices_list 0 \
15
+ --active_mask_list 2 \
16
+ --delta_x 100 --delta_y 60 \
17
+ --resize_list 0.6 \
18
+ --priority_list 1
scripts/sd/run_recon.sh ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export IMAGE_NAME="example1"
2
+ python main.py --name=$IMAGE_NAME \
3
+ --dpm="sd" \
4
+ --resolution=512 \
5
+ --num_tokens=5 \
6
+ --load_trained \
7
+ --recon \
8
+ --seed=2024 \
9
+ --guidance_scale=2 \
10
+ --num_sampling_step=20 \
11
+ --num_imgs=1 \
scripts/sd/run_remove.sh ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # export IMAGE_NAME="example1"
2
+ # CUDA_VISIBLE_DEVICES=1 python main.py --name=$IMAGE_NAME \
3
+ # --dpm="sd" \
4
+ # --resolution=512 \
5
+ # --num_tokens=5 \
6
+ # --load_trained \
7
+ # --load_edited_mask \
8
+ # --remove \
9
+ # --seed=2023 \
10
+ # --num_sampling_step=50 \
11
+ # --strength=0.7 \
12
+ # --edge_thickness=10 \
13
+ # --guidance_scale=2 \
14
+ # --num_imgs=1 \
15
+ # --tgt_index=0
16
+
17
+
18
+
19
+ # export IMAGE_NAME="example1"
20
+ # CUDA_VISIBLE_DEVICES=1 python main.py --name=$IMAGE_NAME \
21
+ # --dpm="sd" \
22
+ # --resolution=512 \
23
+ # --num_tokens=5 \
24
+ # --load_trained \
25
+ # --load_edited_processed_mask \
26
+ # --remove \
27
+ # --seed=2024 \
28
+ # --num_sampling_step=50 \
29
+ # --strength=0.5 \
30
+ # --edge_thickness=10 \
31
+ # --guidance_scale=7 \
32
+ # --num_imgs=1 \
33
+ # --tgt_index=2
34
+
35
+
36
+
37
+
38
+ export IMAGE_NAME="example1"
39
+ CUDA_VISIBLE_DEVICES=1 python main.py --name=$IMAGE_NAME \
40
+ --dpm="sd" \
41
+ --resolution=512 \
42
+ --num_tokens=5 \
43
+ --load_trained \
44
+ --load_edited_processed_mask \
45
+ --remove \
46
+ --seed=1 \
47
+ --num_sampling_step=50 \
48
+ --strength=0.6 \
49
+ --edge_thickness=10 \
50
+ --guidance_scale=7 \
51
+ --num_imgs=1 \
52
+ --tgt_index=2
53
+
54
+
55
+
scripts/sd/run_text.sh ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # export IMAGE_NAME="example1"
2
+ # python main.py --name=$IMAGE_NAME \
3
+ # --dpm="sd" \
4
+ # --resolution=512 \
5
+ # --load_trained \
6
+ # --text \
7
+ # --num_tokens=5 \
8
+ # --seed=2024 \
9
+ # --guidance_scale=7 \
10
+ # --num_sampling_step=50 \
11
+ # --strength=0.7 \
12
+ # --edge_thickness=15 \
13
+ # --num_imgs=1 \
14
+ # --tgt_prompt="a red bag" \
15
+ # --tgt_index=0
16
+
17
+
18
+ export IMAGE_NAME="example1"
19
+ python main.py --name=$IMAGE_NAME \
20
+ --dpm="sd" \
21
+ --resolution=512 \
22
+ --load_trained \
23
+ --text \
24
+ --num_tokens=5 \
25
+ --seed=2024 \
26
+ --guidance_scale=6 \
27
+ --num_sampling_step=50 \
28
+ --strength=0.5 \
29
+ --edge_thickness=15 \
30
+ --num_imgs=2 \
31
+ --tgt_prompt="a black bag" \
32
+ --tgt_index=0
33
+
scripts/sdxl/run_ft_sdxl_1024.sh ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export IMAGE_NAME="example1"
2
+ python main.py --name=$IMAGE_NAME \
3
+ --dpm="sdxl" \
4
+ --resolution=1024 \
5
+ --num_tokens=5 \
6
+ --embedding_learning_rate=1e-4 \
7
+ --diffusion_model_learning_rate=5e-5 \
8
+ --max_emb_train_steps=500 \
9
+ --max_diffusion_train_steps=500 \
10
+ --train_batch_size=2 \
11
+ --gradient_accumulation_steps=5
12
+
scripts/sdxl/run_ft_sdxl_1024_2imgs.sh ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export IMAGE_NAME="example1"
2
+ export IMAGE_NAME_2="example2"
3
+ python main.py --name=$IMAGE_NAME \
4
+ --dpm="sdxl" \
5
+ --image \
6
+ --name_2=$IMAGE_NAME_2 \
7
+ --resolution=1024 \
8
+ --embedding_learning_rate=1e-4 \
9
+ --diffusion_model_learning_rate=5e-5 \
10
+ --max_emb_train_steps=500 \
11
+ --max_diffusion_train_steps=500 \
12
+ --train_batch_size=1 \
13
+ --gradient_accumulation_steps=5
14
+
scripts/sdxl/run_ft_sdxl_1024_auxin_todo.sh ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export IMAGE_NAME="example1"
2
+ python main.py --name=$IMAGE_NAME \
3
+ --dpm="sdxl" \
4
+ --resolution=1024 \
5
+ --train_full_lora \
6
+ --embedding_learning_rate=1e-4 \
7
+ --diffusion_model_learning_rate=1e-3 \
8
+ --max_emb_train_steps=500 \
9
+ --max_diffusion_train_steps=500 \
10
+ --train_batch_size=2 \
11
+ --gradient_accumulation_steps=5 \
12
+ --prompt_auxin_idx_list 0 2 \
13
+ --prompt_auxin_list "a photo of * handbag" "a photo of * model"
14
+
15
+ export IMAGE_NAME="example1"
16
+ python main.py --name=$IMAGE_NAME \
17
+ --dpm="sdxl" \
18
+ --resolution=1024 \
19
+ --train_full_lora \
20
+ --load_trained \
21
+ --recon \
22
+ --seed=23 \
23
+ --guidance_scale=7 \
24
+ --num_sampling_step=20 \
25
+ --num_imgs=2 \
26
+ --prompt_auxin_idx_list 0 2 \
27
+ --prompt_auxin_list "a photo of * handbag" "a photo of * model"
28
+
29
+
scripts/sdxl/run_ft_sdxl_1024_fulllora.sh ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export IMAGE_NAME="example1"
2
+ python main.py --name=$IMAGE_NAME \
3
+ --dpm="sdxl" \
4
+ --train_full_lora \
5
+ --resolution=1024 \
6
+ --embedding_learning_rate=1e-4 \
7
+ --diffusion_model_learning_rate=1e-4 \
8
+ --max_emb_train_steps=500 \
9
+ --max_diffusion_train_steps=500 \
10
+ --train_batch_size=2 \
11
+ --gradient_accumulation_steps=5
12
+
scripts/sdxl/run_ft_sdxl_1024_fulllora_2imgs.sh ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export IMAGE_NAME="example1"
2
+ export IMAGE_NAME_2="example2"
3
+ # python main.py --name=$IMAGE_NAME \
4
+ # --dpm="sdxl" \
5
+ # --image \
6
+ # --name_2=$IMAGE_NAME_2 \
7
+ # --resolution=1024 \
8
+ # --embedding_learning_rate=1e-4 \
9
+ # --diffusion_model_learning_rate=5e-5 \
10
+ # --max_emb_train_steps=500 \
11
+ # --max_diffusion_train_steps=500 \
12
+ # --train_batch_size=1 \
13
+ # --gradient_accumulation_steps=5
14
+
15
+ python main.py --name=$IMAGE_NAME \
16
+ --dpm="sdxl" \
17
+ --image \
18
+ --train_full_lora \
19
+ --name_2=$IMAGE_NAME_2 \
20
+ --resolution=1024 \
21
+ --embedding_learning_rate=1e-4 \
22
+ --diffusion_model_learning_rate=5e-4 \
23
+ --max_emb_train_steps=500 \
24
+ --max_diffusion_train_steps=500 \
25
+ --train_batch_size=1 \
26
+ --gradient_accumulation_steps=5
27
+
28
+
29
+ # python main.py --load_trained \
30
+ # --dpm="sdxl" \
31
+ # --image \
32
+ # --name=$IMAGE_NAME \
33
+ # --name_2=$IMAGE_NAME_2 \
34
+ # --tgt_name=$IMAGE_NAME \
35
+ # --guidance_scale 2.5 \
36
+ # --edge_thickness 40 \
37
+ # --strength 0.5 \
38
+ # --seed 29 \
39
+ # --num_imgs 4 \
40
+ # --tgt_index=0 \
41
+ # --src_index=2
scripts/sdxl/run_image.sh ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export IMAGE_NAME="example1"
2
+ export IMAGE_NAME_2="example2"
3
+ python main.py --name=$IMAGE_NAME \
4
+ --name_2=$IMAGE_NAME_2 \
5
+ --dpm="sdxl" \
6
+ --image \
7
+ --load_trained \
8
+ --resolution=1024 \
9
+ --guidance_scale=2.8 \
10
+ --num_imgs=2 \
11
+ --seed=2023 \
12
+ --strength=0.5 \
13
+ --edge_thickness=20 \
14
+ --src_index=2 --tgt_index=0 \
15
+ --tgt_name=$IMAGE_NAME
scripts/sdxl/run_image_w_edited_mask.sh ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export IMAGE_NAME="example1"
2
+ export IMAGE_NAME_2="example2"
3
+ python main.py --name=$IMAGE_NAME \
4
+ --name_2=$IMAGE_NAME_2 \
5
+ --dpm="sdxl" \
6
+ --image \
7
+ --load_trained \
8
+ --load_edited_mask \
9
+ --resolution=1024 \
10
+ --guidance_scale=2.8 \
11
+ --num_imgs=2 \
12
+ --seed=2023 \
13
+ --strength=0.5 \
14
+ --edge_thickness=20 \
15
+ --src_index=2 --tgt_index=0 \
16
+ --tgt_name=$IMAGE_NAME
scripts/sdxl/run_move_resize.sh ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export IMAGE_NAME="example1"
2
+ CUDA_VISIBLE_DEVICES=1 python main.py --name=$IMAGE_NAME \
3
+ --dpm="sdxl" \
4
+ --resolution=1024 \
5
+ --num_tokens=5 \
6
+ --load_edited_mask \
7
+ --load_trained \
8
+ --move_resize \
9
+ --seed=2023 \
10
+ --num_sampling_step=20 \
11
+ --strength=0.5 \
12
+ --edge_thickness=20 \
13
+ --guidance_scale=2.8 \
14
+ --num_imgs=2 \
15
+ --tgt_indices_list 0 \
16
+ --active_mask_list 2 \
17
+ --delta_x 200 --delta_y 140 \
18
+ --resize_list 0.5 \
19
+ --priority_list 1
20
+
21
+ # --load_edited_processed_mask
scripts/sdxl/run_recon.sh ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export IMAGE_NAME="example1"
2
+ python main.py --name=$IMAGE_NAME \
3
+ --dpm="sdxl" \
4
+ --resolution=1024 \
5
+ --num_tokens=5 \
6
+ --load_trained \
7
+ --recon \
8
+ --seed=20 \
9
+ --guidance_scale=3 \
10
+ --num_sampling_step=20 \
11
+ --num_imgs=2 \
scripts/sdxl/run_recon_item_todo.sh ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # export IMAGE_NAME="example1"
2
+ # python main.py --name=$IMAGE_NAME \
3
+ # --dpm="sdxl" \
4
+ # --resolution=1024 \
5
+ # --load_trained \
6
+ # --recon \
7
+ # --recon_an_item \
8
+ # --seed=23 \
9
+ # --guidance_scale=6 \
10
+ # --num_sampling_step=20 \
11
+ # --num_imgs=2 \
12
+ # --tgt_index=0 \
13
+ # --recon_prompt="a photo of a * handbag on a table"
14
+
15
+ export IMAGE_NAME="example1"
16
+ python main.py --name=$IMAGE_NAME \
17
+ --dpm="sdxl" \
18
+ --resolution=1024 \
19
+ --load_trained \
20
+ --recon \
21
+ --recon_an_item \
22
+ --seed=23 \
23
+ --guidance_scale=6 \
24
+ --num_sampling_step=20 \
25
+ --num_imgs=2 \
26
+ --tgt_index=2 \
27
+ --recon_prompt="a photo of a * model on a chair"
scripts/sdxl/run_remove.sh ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export IMAGE_NAME="example1"
2
+ python main.py --name=$IMAGE_NAME \
3
+ --dpm="sdxl" \
4
+ --resolution=1024 \
5
+ --num_tokens=5 \
6
+ --load_edited_mask \
7
+ --load_trained \
8
+ --remove \
9
+ --seed=0 \
10
+ --num_sampling_step=20 \
11
+ --strength=0.4 \
12
+ --edge_thickness=20 \
13
+ --guidance_scale=3 \
14
+ --num_imgs=1 \
15
+ --tgt_index=0
16
+
17
+ # --load_edited_processed_mask
scripts/sdxl/run_text.sh ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export IMAGE_NAME="example1"
2
+ python main.py --name=$IMAGE_NAME \
3
+ --dpm="sdxl" \
4
+ --resolution=1024 \
5
+ --load_trained \
6
+ --num_tokens=5 \
7
+ --text \
8
+ --seed=23 \
9
+ --num_sampling_step=20 \
10
+ --strength=0.6 \
11
+ --edge_thickness=30 \
12
+ --num_imgs=2 \
13
+ --tgt_prompt="a white handbag" \
14
+ --tgt_index=0
scripts/sdxl/run_text_w_edited_mask.sh ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export IMAGE_NAME="example1"
2
+ python main.py --name=$IMAGE_NAME \
3
+ --dpm="sdxl" \
4
+ --resolution=1024 \
5
+ --num_tokens=5 \
6
+ --load_trained \
7
+ --load_edited_mask \
8
+ --text \
9
+ --seed=23 \
10
+ --num_sampling_step=50 \
11
+ --strength=0.7 \
12
+ --edge_thickness=30 \
13
+ --num_imgs=2 \
14
+ --tgt_prompt="a white handbag" \
15
+ --tgt_index=0
segment.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
3
+ from PIL import Image
4
+ import torch
5
+ from collections import defaultdict
6
+ import matplotlib.pyplot as plt
7
+ from matplotlib import cm
8
+ import matplotlib.patches as mpatches
9
+ import os
10
+ import numpy as np
11
+ import argparse
12
+ import matplotlib
13
+
14
+ def load_image(image_path, left=0, right=0, top=0, bottom=0, size = 512):
15
+ if type(image_path) is str:
16
+ image = np.array(Image.open(image_path))[:, :, :3]
17
+ else:
18
+ image = image_path
19
+ h, w, c = image.shape
20
+ left = min(left, w-1)
21
+ right = min(right, w - left - 1)
22
+ top = min(top, h - left - 1)
23
+ bottom = min(bottom, h - top - 1)
24
+ image = image[top:h-bottom, left:w-right]
25
+ h, w, c = image.shape
26
+ if h < w:
27
+ offset = (w - h) // 2
28
+ image = image[:, offset:offset + h]
29
+ elif w < h:
30
+ offset = (h - w) // 2
31
+ image = image[offset:offset + w]
32
+ image = np.array(Image.fromarray(image).resize((size, size)))
33
+ return image
34
+
35
+ def draw_panoptic_segmentation(segmentation, segments_info,save_folder=None, noseg = False):
36
+ if torch.max(segmentation)==torch.min(segmentation)==-1:
37
+ print("nothing is detected!")
38
+ noseg=True
39
+ viridis = matplotlib.colormaps['viridis'].resampled(1)
40
+ else:
41
+ viridis = matplotlib.colormaps['viridis'].resampled(torch.max(segmentation)-torch.min(segmentation)+1)
42
+ fig, ax = plt.subplots()
43
+ ax.imshow(segmentation)
44
+ instances_counter = defaultdict(int)
45
+ handles = []
46
+ label_list = []
47
+ if not noseg:
48
+ if torch.min(segmentation) == 0:
49
+ mask = segmentation==0
50
+ mask = mask.cpu().detach().numpy() # [512,512] bool
51
+ segment_label = "rest"
52
+ np.save( os.path.join(save_folder, "mask{}_{}.npy".format(0,"rest")) , mask)
53
+ color = viridis(0)
54
+ label = f"{segment_label}-{0}"
55
+ handles.append(mpatches.Patch(color=color, label=label))
56
+ label_list.append(label)
57
+
58
+ for segment in segments_info:
59
+ segment_id = segment['id']
60
+ mask = segmentation==segment_id
61
+ if torch.min(segmentation) != 0:
62
+ segment_id -= 1
63
+ mask = mask.cpu().detach().numpy() # [512,512] bool
64
+
65
+ segment_label = model.config.id2label[segment['label_id']]
66
+ instances_counter[segment['label_id']] += 1
67
+ np.save( os.path.join(save_folder, "mask{}_{}.npy".format(segment_id,segment_label)) , mask)
68
+ color = viridis(segment_id)
69
+
70
+ label = f"{segment_label}-{segment_id}"
71
+ handles.append(mpatches.Patch(color=color, label=label))
72
+ label_list.append(label)
73
+ else:
74
+ mask = np.full(segmentation.shape, True)
75
+ segment_label = "all"
76
+ np.save( os.path.join(save_folder, "mask{}_{}.npy".format(0,"all")) , mask)
77
+ color = viridis(0)
78
+ label = f"{segment_label}-{0}"
79
+ handles.append(mpatches.Patch(color=color, label=label))
80
+ label_list.append(label)
81
+
82
+ plt.xticks([])
83
+ plt.yticks([])
84
+ # plt.savefig(os.path.join(save_folder, 'mask_clear.png'), dpi=500)
85
+ ax.legend(handles=handles)
86
+ plt.savefig(os.path.join(save_folder, 'seg_init.png'), dpi=500 )
87
+ print("; ".join(label_list))
88
+
89
+
90
+
91
+ parser = argparse.ArgumentParser()
92
+ parser.add_argument("--name", type=str, default="obama")
93
+ parser.add_argument("--size", type=int, default=512)
94
+ parser.add_argument("--noseg", default=False, action="store_true" )
95
+ args = parser.parse_args()
96
+ base_folder_path = "."
97
+
98
+ processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")
99
+ model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")
100
+ input_folder = os.path.join(base_folder_path, args.name )
101
+ try:
102
+ image = load_image(os.path.join(input_folder, "img.png" ), size = args.size)
103
+ except FileNotFoundError:  # fall back to img.jpg when img.png is absent
104
+ image = load_image(os.path.join(input_folder, "img.jpg" ), size = args.size)
105
+
106
+ image =Image.fromarray(image)
107
+ image.save(os.path.join(input_folder,"img_{}.png".format(args.size)))
108
+ inputs = processor(image, return_tensors="pt")
109
+ with torch.no_grad():
110
+ outputs = model(**inputs)
111
+
112
+ panoptic_segmentation = processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
113
+ save_folder = os.path.join(base_folder_path, args.name)
114
+ os.makedirs(save_folder, exist_ok=True)
115
+ draw_panoptic_segmentation(**panoptic_segmentation, save_folder = save_folder, noseg = args.noseg)
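segment.py stores one boolean array per panoptic segment as mask{idx}_{label}.npy next to the resized image. A minimal sketch of reading those files back into torch masks, in the same spirit as load_mask in utils.py; the folder name "example1" is only an assumed example, and files are taken in simple filename order:

import os
import numpy as np
import torch

folder = "example1"  # assumed example folder produced by segment.py
masks, labels = [], []
for fname in sorted(f for f in os.listdir(folder)
                    if f.startswith("mask") and f.endswith(".npy") and "_" in f):
    masks.append(torch.from_numpy(np.load(os.path.join(folder, fname))))  # boolean HxW mask
    labels.append(fname.split("_", 1)[1][:-4])                            # label part of the filename
print(labels, [m.shape for m in masks])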
segment_sam.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import copy
4
+ import shutil
5
+
6
+ import numpy as np
7
+ import json
8
+ import torch
9
+ from PIL import Image, ImageDraw, ImageFont
10
+
11
+ # Grounding DINO
12
+ import sys
13
+
14
+ sys.path.append("/path/to/Grounded-Segment-Anything")
15
+ # change to your "Grounded-Segment-Anything" installation folder!!!!!
16
+ import GroundingDINO.groundingdino.datasets.transforms as T
17
+ from GroundingDINO.groundingdino.models import build_model
18
+ from GroundingDINO.groundingdino.util import box_ops
19
+ from GroundingDINO.groundingdino.util.slconfig import SLConfig
20
+ from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
21
+
22
+ # segment anything
23
+ from segment_anything import (
24
+ sam_model_registry,
25
+ sam_hq_model_registry,
26
+ SamPredictor
27
+ )
28
+ import cv2
29
+ import numpy as np
30
+ import matplotlib.pyplot as plt
31
+ def load_image_to_resize(image_path, left=0, right=0, top=0, bottom=0, size = 512):
32
+ if type(image_path) is str:
33
+ image = np.array(Image.open(image_path))[:, :, :3]
34
+ else:
35
+ image = image_path
36
+ h, w, c = image.shape
37
+ left = min(left, w-1)
38
+ right = min(right, w - left - 1)
39
+ top = min(top, h - 1)
40
+ bottom = min(bottom, h - top - 1)
41
+ image = image[top:h-bottom, left:w-right]
42
+ h, w, c = image.shape
43
+ if h < w:
44
+ offset = (w - h) // 2
45
+ image = image[:, offset:offset + h]
46
+ elif w < h:
47
+ offset = (h - w) // 2
48
+ image = image[offset:offset + w]
49
+ image = np.array(Image.fromarray(image).resize((size, size)))
50
+ return image
51
+
52
+
53
+ def load_image(image_path):
54
+ # load image
55
+ image_pil = Image.open(image_path).convert("RGB") # load image
56
+
57
+ transform = T.Compose(
58
+ [
59
+ T.RandomResize([800], max_size=1333),
60
+ T.ToTensor(),
61
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
62
+ ]
63
+ )
64
+ image, _ = transform(image_pil, None) # 3, h, w
65
+ return image_pil, image
66
+
67
+
68
+ def load_model(model_config_path, model_checkpoint_path, device):
69
+ args = SLConfig.fromfile(model_config_path)
70
+ args.device = device
71
+ model = build_model(args)
72
+ checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
73
+ load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
74
+ model.eval()
75
+ return model
76
+
77
+
78
+ def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
79
+ caption = caption.lower()
80
+ caption = caption.strip()
81
+ if not caption.endswith("."):
82
+ caption = caption + "."
83
+ model = model.to(device)
84
+ image = image.to(device)
85
+ with torch.no_grad():
86
+ outputs = model(image[None], captions=[caption])
87
+ logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
88
+ boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
89
+ logits.shape[0]
90
+
91
+ # filter output
92
+ logits_filt = logits.clone()
93
+ boxes_filt = boxes.clone()
94
+ filt_mask = logits_filt.max(dim=1)[0] > box_threshold
95
+ logits_filt = logits_filt[filt_mask] # num_filt, 256
96
+ boxes_filt = boxes_filt[filt_mask] # num_filt, 4
97
+ logits_filt.shape[0]
98
+
99
+ # get phrase
100
+ tokenlizer = model.tokenizer
101
+ tokenized = tokenlizer(caption)
102
+ # build pred
103
+ pred_phrases = []
104
+ for logit, box in zip(logits_filt, boxes_filt):
105
+ pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
106
+ if with_logits:
107
+ pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
108
+ else:
109
+ pred_phrases.append(pred_phrase)
110
+
111
+ return boxes_filt, pred_phrases
112
+
113
+ def show_mask(mask, ax, random_color=False):
114
+ if random_color:
115
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
116
+ else:
117
+ color = np.array([30/255, 144/255, 255/255, 0.6])
118
+ h, w = mask.shape[-2:]
119
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
120
+ ax.imshow(mask_image)
121
+
122
+
123
+ def show_box(box, ax, label):
124
+ x0, y0 = box[0], box[1]
125
+ w, h = box[2] - box[0], box[3] - box[1]
126
+ ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
127
+ ax.text(x0, y0, label)
128
+
129
+
130
+ def save_mask_data(output_dir, mask_list, box_list, label_list):
131
+ value = 0 # 0 for background
132
+
133
+ mask_img = torch.zeros(mask_list.shape[-2:])
134
+ for idx, mask in enumerate(mask_list):
135
+ mask_img[mask.cpu().numpy()[0] == True] = value + idx + 1
136
+ plt.figure(figsize=(10, 10))
137
+ plt.imshow(mask_img.numpy())
138
+ plt.axis('off')
139
+ plt.savefig(os.path.join(output_dir, 'mask.jpg'), bbox_inches="tight", dpi=300, pad_inches=0.0)
140
+
141
+ json_data = [{
142
+ 'value': value,
143
+ 'label': 'background'
144
+ }]
145
+ for label, box in zip(label_list, box_list):
146
+ value += 1
147
+ name, logit = label.split('(')
148
+ logit = logit[:-1] # the last is ')'
149
+ json_data.append({
150
+ 'value': value,
151
+ 'label': name,
152
+ 'logit': float(logit),
153
+ 'box': box.numpy().tolist(),
154
+ })
155
+ with open(os.path.join(output_dir, 'mask.json'), 'w') as f:
156
+ json.dump(json_data, f)
157
+
158
+
159
+ if __name__ == "__main__":
160
+
161
+ parser = argparse.ArgumentParser("Grounded-Segment-Anything Demo", add_help=True)
162
+ parser.add_argument("--sam_version", type=str, default="vit_h", required=False, help="SAM ViT version: vit_b / vit_l / vit_h")
163
+ parser.add_argument("--sam_checkpoint", type=str, required=False, help="path to sam checkpoint file")
164
+ parser.add_argument("--sam_hq_checkpoint", type=str, default=None, help="path to sam-hq checkpoint file")
165
+ parser.add_argument("--use_sam_hq", action="store_true", help="using sam-hq for prediction")
166
+ parser.add_argument("--text_prompt", type=str, required=True, help="text prompt")
167
+
168
+ parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
169
+ parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
170
+ parser.add_argument("--device", type=str, default="cpu", help="running on cpu only!, default=False")
171
+ parser.add_argument("--name", type=str, default="", help="name of the input image folder")
172
+ parser.add_argument("--size", type=int, default=1024, help="image size")
173
+
174
+ args = parser.parse_args()
175
+ args.base_folder = "/path/to/Grounded-Segment-Anything"
176
+ # change to your "Grounded-Segment-Anything" installation folder!!!!!
177
+ input_folder = os.path.join(".", args.name)
178
+
179
+ args.config = os.path.join(args.base_folder,"GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py")
180
+ args.grounded_checkpoint = "groundingdino_swint_ogc.pth"
181
+ args.sam_checkpoint="sam_vit_h_4b8939.pth"
182
+ args.box_threshold = 0.3
183
+ args.text_threshold = 0.25
184
+ args.device = "cuda"
185
+ # cfg
186
+
187
+ config_file = args.config # change the path of the model config file
188
+ grounded_checkpoint = os.path.join(args.base_folder,args.grounded_checkpoint) # change the path of the model
189
+ sam_version = args.sam_version
190
+ sam_checkpoint = os.path.join(args.base_folder,args.sam_checkpoint)
191
+ if args.sam_hq_checkpoint is not None:
192
+ sam_hq_checkpoint = os.path.join(args.base_folder,args.sam_hq_checkpoint)
193
+ use_sam_hq = args.use_sam_hq
194
+ # image_path = args.input_image
195
+ text_prompt = args.text_prompt
196
+ # output_dir = args.output_dir
197
+ box_threshold = args.box_threshold
198
+ text_threshold = args.text_threshold
199
+ device = args.device
200
+
201
+ output_dir = input_folder
202
+ os.makedirs(output_dir, exist_ok=True)
203
+
204
+ # unify names
205
+
206
+ if len(os.listdir(input_folder)) == 1:
207
+ for filename in os.listdir(input_folder):
208
+ imgtype = "." + filename.split(".")[-1]
209
+ shutil.move(os.path.join(input_folder, filename), os.path.join(input_folder, "img"+imgtype))
210
+
211
+
212
+
213
+ ### resizing and save
214
+ if os.path.exists(os.path.join(input_folder, "img.jpg")):
215
+ image_path = os.path.join(input_folder, "img.jpg")
216
+ else:
217
+ image_path = os.path.join(input_folder, "img.png")
218
+ image = load_image_to_resize(image_path, size = args.size)
219
+ image =Image.fromarray(image)
220
+ resized_image_path = os.path.join(input_folder, "img_{}.png".format(args.size))
221
+ image.save(resized_image_path)
222
+
223
+ image_path = resized_image_path
224
+ # load image
225
+ image_pil, image = load_image(image_path)
226
+ # load model
227
+ model = load_model(config_file, grounded_checkpoint, device=device)
228
+
229
+ # # visualize raw image
230
+ # image_pil.save(os.path.join(output_dir, "raw_image.jpg"))
231
+
232
+ # run grounding dino model
233
+ boxes_filt, pred_phrases = get_grounding_output(
234
+ model, image, text_prompt, box_threshold, text_threshold, device=device
235
+ )
236
+
237
+ # initialize SAM
238
+ if use_sam_hq:
239
+ predictor = SamPredictor(sam_hq_model_registry[sam_version](checkpoint=sam_hq_checkpoint).to(device))
240
+ else:
241
+ predictor = SamPredictor(sam_model_registry[sam_version](checkpoint=sam_checkpoint).to(device))
242
+ image = cv2.imread(image_path)
243
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
244
+ predictor.set_image(image)
245
+
246
+ size = image_pil.size
247
+ H, W = size[1], size[0]
248
+ for i in range(boxes_filt.size(0)):
249
+ boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
250
+ boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
251
+ boxes_filt[i][2:] += boxes_filt[i][:2]
252
+
253
+ boxes_filt = boxes_filt.cpu()
254
+ transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)
255
+
256
+ masks, _, _ = predictor.predict_torch(
257
+ point_coords = None,
258
+ point_labels = None,
259
+ boxes = transformed_boxes.to(device),
260
+ multimask_output = False,
261
+ )
262
+
263
+ tot_detect = len(masks)
264
+ # draw output image
265
+ plt.figure(figsize=(10, 10))
266
+ plt.imshow(image)
267
+ for idx, (mask,label) in enumerate(zip(masks,pred_phrases)):
268
+ show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
269
+ np.save( os.path.join(output_dir, "maskSAM{}_{}.npy".format(idx, label)) ,mask[0].cpu().numpy())
270
+
271
+ for idx, (box, label) in enumerate(zip(boxes_filt, pred_phrases)):
272
+ label = label + "_{}".format(idx)
273
+ show_box(box.numpy(), plt.gca(), label)
274
+
275
+ rec_mask = np.zeros_like(mask[0].cpu().numpy()).astype(np.bool_)
276
+ for idx, box in enumerate(boxes_filt):
277
+ x0 = box[0].numpy().astype(np.int32)
278
+ x1 = box[2].numpy().astype(np.int32)
279
+ y0 = box[1].numpy().astype(np.int32)
280
+ y1 = box[3].numpy().astype(np.int32)
281
+ rec_mask[y0:y1, x0:x1] = True  # rows are the y range, columns are the x range
282
+
283
+ plt.axis('off')
284
+ plt.savefig(
285
+ os.path.join(output_dir, "seg_init_SAM.png"),
286
+ bbox_inches="tight", dpi=300, pad_inches=0.0
287
+ )
288
+
289
+ mask_detected = np.logical_or.reduce([mask[0].cpu().numpy() for mask in masks ])
290
+ mask_undetected = np.logical_not(mask_detected)
291
+ np.save( os.path.join(output_dir, "SAM_detected.npy") ,mask_detected)
292
+ np.save( os.path.join(output_dir, "maskSAM{}_rest.npy".format(len(masks))) ,mask_undetected)
293
+ plt.imsave( os.path.join(output_dir,"mask_SAM-detected.png"), np.repeat(np.expand_dims( mask_detected.astype(float), axis=2), 3, axis = 2))
294
+
utils.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ from PIL import Image
3
+ import torch
4
+ import numpy as np
5
+ import PIL
6
+ import os
7
+ from tqdm.auto import tqdm
8
+ from diffusers.models.attention_processor import (
9
+ AttnProcessor2_0,
10
+ LoRAAttnProcessor2_0,
11
+ LoRAXFormersAttnProcessor,
12
+ XFormersAttnProcessor,
13
+ )
14
+
15
+ device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
16
+
17
+ def myroll2d(a, delta_x, delta_y):
18
+ h, w = a.shape[0], a.shape[1]
19
+ delta_x = -delta_x
20
+ delta_y = -delta_y
21
+ if isinstance(a, np.ndarray):
22
+ b = np.zeros ([h,w]).astype(np.uint8)
23
+ elif isinstance(a, torch.Tensor):
24
+ b = torch.zeros([h,w]).to(torch.uint8)
25
+ if delta_x > 0:
26
+ left_a = delta_x
27
+ right_a = w
28
+ left_b = 0
29
+ right_b = w - delta_x
30
+ else:
31
+ left_a = 0
32
+ right_a = w + delta_x
33
+ left_b = -delta_x
34
+ right_b = w
35
+ if delta_y > 0:
36
+ top_a = delta_y
37
+ bot_a = h
38
+ top_b = 0
39
+ bot_b = h-delta_y
40
+ else:
41
+ top_a = 0
42
+ bot_a = h + delta_y
43
+ top_b = -delta_y
44
+ bot_b = h
45
+ b[left_b: right_b, top_b: bot_b] = a[left_a: right_a, top_a: bot_a]
46
+ return b
47
+
48
+ def import_model_class_from_model_name_or_path(
49
+ pretrained_model_name_or_path: str, revision = None, subfolder: str = "text_encoder"
50
+ ):
51
+ text_encoder_config = PretrainedConfig.from_pretrained(
52
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
53
+ )
54
+ model_class = text_encoder_config.architectures[0]
55
+
56
+ if model_class == "CLIPTextModel":
57
+ from transformers import CLIPTextModel
58
+ return CLIPTextModel
59
+ elif model_class == "CLIPTextModelWithProjection":
60
+ from transformers import CLIPTextModelWithProjection
61
+ return CLIPTextModelWithProjection
62
+ else:
63
+ raise ValueError(f"{model_class} is not supported.")
64
+
65
+ @torch.no_grad()
66
+ def image2latent(image, vae = None, dtype=None):
67
+ with torch.no_grad():
68
+ if isinstance(image, Image.Image):  # covers PNG/JPEG and other PIL image types
69
+ image = np.array(image)
70
+ if type(image) is torch.Tensor and image.dim() == 4:
71
+ latents = image
72
+ else:
73
+ image = torch.from_numpy(image).float() / 127.5 - 1
74
+ image = image.permute(2, 0, 1).unsqueeze(0).to(device, dtype= dtype)
75
+ latents = vae.encode(image).latent_dist.sample()
76
+ latents = latents * vae.config.scaling_factor
77
+ return latents
78
+
79
+ @torch.no_grad()
80
+ def latent2image(latents, return_type = 'np', vae = None):
81
+ # needs_upcasting = vae.dtype == torch.float16 and vae.config.force_upcast
82
+ needs_upcasting = True
83
+ if needs_upcasting:
84
+ upcast_vae(vae)
85
+ latents = latents.to(next(iter(vae.post_quant_conv.parameters())).dtype)
86
+ image = vae.decode(latents /vae.config.scaling_factor, return_dict=False)[0]
87
+
88
+ if return_type == 'np':
89
+ image = (image / 2 + 0.5).clamp(0, 1)
90
+ image = image.cpu().permute(0, 2, 3, 1).numpy()#[0]
91
+ image = (image * 255).astype(np.uint8)
92
+ if needs_upcasting:
93
+ vae.to(dtype=torch.float16)
94
+ return image
95
+
96
+ def upcast_vae(vae):
97
+ dtype = vae.dtype
98
+ vae.to(dtype=torch.float32)
99
+ use_torch_2_0_or_xformers = isinstance(
100
+ vae.decoder.mid_block.attentions[0].processor,
101
+ (
102
+ AttnProcessor2_0,
103
+ XFormersAttnProcessor,
104
+ LoRAXFormersAttnProcessor,
105
+ LoRAAttnProcessor2_0,
106
+ ),
107
+ )
108
+ # if xformers or torch_2_0 is used attention block does not need
109
+ # to be in float32 which can save lots of memory
110
+ if use_torch_2_0_or_xformers:
111
+ vae.post_quant_conv.to(dtype)
112
+ vae.decoder.conv_in.to(dtype)
113
+ vae.decoder.mid_block.to(dtype)
114
+
115
+ def prompt_to_emb_length_sdxl(prompt, tokenizer, text_encoder, length = None):
116
+ text_input = tokenizer(
117
+ [prompt],
118
+ padding="max_length",
119
+ max_length=length,
120
+ truncation=True,
121
+ return_tensors="pt",
122
+ )
123
+ prompt_embeds = text_encoder(text_input.input_ids.to(device),output_hidden_states=True)
124
+ pooled_prompt_embeds = prompt_embeds[0]
125
+
126
+ prompt_embeds = prompt_embeds.hidden_states[-2]
127
+ bs_embed, seq_len, _ = prompt_embeds.shape
128
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
129
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
130
+
131
+ return {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds}
132
+
133
+
134
+
135
+
136
+ def prompt_to_emb_length_sd(prompt, tokenizer, text_encoder, length = None):
137
+ text_input = tokenizer(
138
+ [prompt],
139
+ padding="max_length",
140
+ max_length=length,
141
+ truncation=True,
142
+ return_tensors="pt",
143
+ )
144
+ emb = text_encoder(text_input.input_ids.to(device))[0]
145
+ return emb
146
+
147
+ def sdxl_prepare_input_decom(
148
+ set_string_list,
149
+ tokenizer,
150
+ tokenizer_2,
151
+ text_encoder_1,
152
+ text_encoder_2,
153
+ length = 20,
154
+ bsz = 1,
155
+ weight_dtype = torch.float32,
156
+ resolution = 1024,
157
+ normal_token_id_list = []
158
+ ):
159
+ encoder_hidden_states_list = []
160
+ pooled_prompt_embeds = 0
161
+
162
+ for m_idx in range(len(set_string_list)):
163
+ prompt_embeds_list = []
164
+ if ("#" in set_string_list[m_idx] or "$" in set_string_list[m_idx]) and m_idx not in normal_token_id_list : ###
165
+ out = prompt_to_emb_length_sdxl(
166
+ set_string_list[m_idx], tokenizer, text_encoder_1, length = length
167
+ )
168
+ else:
169
+ out = prompt_to_emb_length_sdxl(
170
+ set_string_list[m_idx], tokenizer, text_encoder_1, length = 77
171
+ )
172
+ print(m_idx, set_string_list[m_idx])
173
+ prompt_embeds, _ = out["prompt_embeds"].to(dtype=weight_dtype), out["pooled_prompt_embeds"].to(dtype=weight_dtype)
174
+ prompt_embeds = prompt_embeds.repeat(bsz, 1, 1)
175
+ prompt_embeds_list.append(prompt_embeds)
176
+ if ("#" in set_string_list[m_idx] or "$" in set_string_list[m_idx]) and m_idx not in normal_token_id_list:
177
+ out = prompt_to_emb_length_sdxl(
178
+ set_string_list[m_idx], tokenizer_2, text_encoder_2, length = length
179
+ )
180
+ else:
181
+ out = prompt_to_emb_length_sdxl(
182
+ set_string_list[m_idx], tokenizer_2, text_encoder_2, length = 77
183
+ )
184
+ print(m_idx, set_string_list[m_idx])
185
+
186
+ prompt_embeds = out["prompt_embeds"].to(dtype=weight_dtype)
187
+ pooled_prompt_embeds += out["pooled_prompt_embeds"].to(dtype=weight_dtype)
188
+ prompt_embeds = prompt_embeds.repeat(bsz, 1, 1)
189
+ prompt_embeds_list.append(prompt_embeds)
190
+
191
+ encoder_hidden_states_list.append(torch.concat(prompt_embeds_list, dim=-1))
192
+
193
+ add_text_embeds = pooled_prompt_embeds /len(set_string_list)
194
+ target_size, original_size,crops_coords_top_left = (resolution,resolution),(resolution,resolution),(0,0)
195
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
196
+
197
+ add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype,device = pooled_prompt_embeds.device) #[B,6]
198
+ return encoder_hidden_states_list, add_text_embeds, add_time_ids
199
+
200
+ def sd_prepare_input_decom(
201
+ set_string_list,
202
+ tokenizer,
203
+ text_encoder_1,
204
+ length = 20,
205
+ bsz = 1,
206
+ weight_dtype = torch.float32,
207
+ normal_token_id_list = []
208
+ ):
209
+ encoder_hidden_states_list = []
210
+ for m_idx in range(len(set_string_list)):
211
+ if ("#" in set_string_list[m_idx] or "$" in set_string_list[m_idx]) and m_idx not in normal_token_id_list : ###
212
+ encoder_hidden_states = prompt_to_emb_length_sd(
213
+ set_string_list[m_idx], tokenizer, text_encoder_1, length = length
214
+ )
215
+ else:
216
+ encoder_hidden_states = prompt_to_emb_length_sd(
217
+ set_string_list[m_idx], tokenizer, text_encoder_1, length = 77
218
+ )
219
+ print(m_idx, set_string_list[m_idx])
220
+ encoder_hidden_states = encoder_hidden_states.repeat(bsz, 1, 1)
221
+ encoder_hidden_states_list.append(encoder_hidden_states.to(dtype=weight_dtype))
222
+ return encoder_hidden_states_list
223
+
224
+
225
+ def load_mask (input_folder):
226
+ np_mask_dtype = 'uint8'
227
+ mask_np_list = []
228
+ mask_label_list = []
229
+ files = [
230
+ file_name for file_name in os.listdir(input_folder) \
231
+ if "mask" in file_name and ".npy" in file_name \
232
+ and "_" in file_name and "Edited" not in file_name
233
+ ]
234
+ files = sorted(files, key = lambda x: int(x.split("_")[0][4:]))
235
+
236
+ for idx, file_name in enumerate(files):
237
+ if "mask" in file_name and ".npy" in file_name and "_" in file_name \
238
+ and "Edited" not in file_name:
239
+ mask_np = np.load(os.path.join(input_folder, file_name)).astype(np_mask_dtype)
240
+ mask_np_list.append(mask_np)
241
+ mask_label = file_name.split("_")[1][:-4]
242
+ mask_label_list.append(mask_label)
243
+ mask_list = []
244
+ for mask_np in mask_np_list:
245
+ mask = torch.from_numpy(mask_np)
246
+ mask_list.append(mask)
247
+ try:
248
+ assert torch.all(sum(mask_list)==1)
249
+ except:
250
+ print("please check mask")
251
+ # plt.imsave( "out_mask.png", mask_list_edit[0])
252
+ import pdb; pdb.set_trace()
253
+ return mask_list, mask_label_list
254
+
255
+ def load_image(image_path, left=0, right=0, top=0, bottom=0, size = 512):
256
+ if type(image_path) is str:
257
+ image = np.array(Image.open(image_path))[:, :, :3]
258
+ else:
259
+ image = image_path
260
+ h, w, c = image.shape
261
+ left = min(left, w-1)
262
+ right = min(right, w - left - 1)
263
+ top = min(top, h - 1)
264
+ bottom = min(bottom, h - top - 1)
265
+ image = image[top:h-bottom, left:w-right]
266
+ h, w, c = image.shape
267
+ if h < w:
268
+ offset = (w - h) // 2
269
+ image = image[:, offset:offset + h]
270
+ elif w < h:
271
+ offset = (h - w) // 2
272
+ image = image[offset:offset + w]
273
+ image = np.array(Image.fromarray(image).resize((size, size)))
274
+ return image
275
+
276
+ def mask_union_torch(*masks):
277
+ masks = [m.to(torch.float) for m in masks]
278
+ res = sum(masks)>0
279
+ return res
280
+
281
+ def load_mask_edit(input_folder):
282
+ np_mask_dtype = 'uint8'
283
+ mask_np_list = []
284
+ mask_label_list = []
285
+
286
+ files = [file_name for file_name in os.listdir(input_folder) if "mask" in file_name and ".npy" in file_name and "_" in file_name and "Edited" in file_name and "-1" not in file_name]
287
+ files = sorted(files, key = lambda x: int(x.split("_")[0][10:]))
288
+
289
+ for idx, file_name in enumerate(files):
290
+ if "mask" in file_name and ".npy" in file_name and "_" in file_name and "Edited" in file_name and "-1" not in file_name:
291
+ mask_np = np.load(os.path.join(input_folder, file_name)).astype(np_mask_dtype)
292
+ mask_np_list.append(mask_np)
293
+ mask_label = file_name.split("_")[1][:-4]
294
+ # mask_label = mask_label.split("-")[0]
295
+ mask_label_list.append(mask_label)
296
+ mask_list = []
297
+ for mask_np in mask_np_list:
298
+ mask = torch.from_numpy(mask_np)
299
+ mask_list.append(mask)
300
+ try:
301
+ assert torch.all(sum(mask_list)==1)
302
+ except:
303
+ print("Make sure maskEdited is in the folder, if not, generate using the UI")
304
+ import pdb; pdb.set_trace()
305
+ return mask_list, mask_label_list
306
+
307
+ def save_images(images,filename, num_rows=1, offset_ratio=0.02):
308
+ if type(images) is list:
309
+ num_empty = len(images) % num_rows
310
+ elif images.ndim == 4:
311
+ num_empty = images.shape[0] % num_rows
312
+ else:
313
+ images = [images]
314
+ num_empty = 0
315
+
316
+ empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
317
+ images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
318
+ num_items = len(images)
319
+
320
+ folder = os.path.dirname(filename)
321
+ for i, image in enumerate(images):
322
+ pil_img = Image.fromarray(image)
323
+ name = filename.split("/")[-1]
324
+ name = name.split(".")[-2]+"_{}".format(i) +"."+filename.split(".")[-1]
325
+ pil_img.save(os.path.join(folder, name))
326
+ print("saved to ", os.path.join(folder, name))
utils_mask.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ from matplotlib import cm
4
+ import matplotlib.patches as mpatches
5
+ import matplotlib.pyplot as plt
6
+ import torch
7
+ from utils import myroll2d
8
+
9
+ def create_outer_edge_mask_torch(mask, edge_thickness = 20):
10
+ mask_down = myroll2d(mask, edge_thickness, 0 )
11
+ mask_edge_down = (mask_down.to(torch.float) -mask.to(torch.float))>0
12
+
13
+ mask_up = myroll2d(mask, -edge_thickness, 0)
14
+ mask_edge_up = (mask_up.to(torch.float) -mask.to(torch.float))>0
15
+
16
+ mask_left = myroll2d(mask, 0, -edge_thickness)
17
+ mask_edge_left = (mask_left.to(torch.float) -mask.to(torch.float))>0
18
+
19
+ mask_right = myroll2d(mask, 0, edge_thickness)
20
+ mask_edge_right = (mask_right.to(torch.float) -mask.to(torch.float))>0
21
+
22
+ mask_ur = myroll2d(mask, -edge_thickness,edge_thickness)
23
+ mask_edge_ur = (mask_ur.to(torch.float) -mask.to(torch.float))>0
24
+
25
+ mask_ul = myroll2d(mask, -edge_thickness,-edge_thickness)
26
+ mask_edge_ul = (mask_ul.to(torch.float) -mask.to(torch.float))>0
27
+
28
+ mask_dr = myroll2d(mask, edge_thickness,edge_thickness )
29
+ mask_edge_dr = (mask_dr.to(torch.float) -mask.to(torch.float))>0
30
+
31
+ mask_dl = myroll2d(mask, edge_thickness,-edge_thickness)
32
+ mask_edge_dl = (mask_dl.to(torch.float) - mask.to(torch.float)) > 0
33
+
34
+ mask_edge = mask_union_torch(mask_edge_down, mask_edge_up, mask_edge_left, mask_edge_right,
35
+ mask_edge_ur, mask_edge_ul, mask_edge_dr, mask_edge_dl)
36
+ return mask_edge
37
+
38
+ def mask_substract_torch(mask1, mask2):
39
+ return ((mask1.cpu().to(torch.float)-mask2.cpu().to(torch.float))>0).to(torch.uint8)
40
+
41
+ def check_mask_overlap_torch(*masks):
42
+ assert torch.all(sum([m.float() for m in masks]) <= 1)
43
+
44
+ def check_mask_overlap_numpy(*masks):
45
+ assert np.all(sum([m.astype(float) for m in masks])<=1 )
46
+
47
+ def check_cover_all_torch (*masks):
48
+ assert torch.all(sum([m.cpu().float() for m in masks])==1)
49
+
50
+ def process_mask_to_follow_priority(mask_list, priority_list):
51
+ for idx1, (m1 , p1) in enumerate(zip(mask_list, priority_list)):
52
+ for idx2, (m2 , p2) in enumerate(zip(mask_list, priority_list)):
53
+ if p2 > p1:
54
+ mask_list[idx1] = ((m1.astype(float)-m2.astype(float))>0).astype(np.uint8)
55
+ return mask_list
56
+
57
+ def mask_union(*masks):
58
+ masks = [m.astype(float) for m in masks]
59
+ res = sum(masks)>0
60
+ return res.astype(np.uint8)
61
+
62
+ def mask_intersection(mask1, mask2):
63
+ mask_uni = mask_union(mask1, mask2)
64
+ mask_intersec = ((mask1.astype(float)-mask2.astype(float))==0) * mask_uni
65
+ return mask_intersec
66
+
67
+ def mask_union_torch(*masks):
68
+ masks = [m.float() for m in masks]
69
+ res = sum(masks)>0
70
+ return res.to(torch.uint8)
71
+
72
+ def mask_intersection_torch(mask1, mask2):
73
+ mask_uni = mask_union_torch(mask1, mask2)
74
+ mask_intersec = ((mask1.float()-mask2.float())==0) * mask_uni
75
+ return mask_intersec.cpu().to(torch.uint8)
76
+
77
+
78
+ def visualize_mask_list(mask_list, savepath):
79
+ mask = 0
80
+ for midx, m in enumerate(mask_list):
81
+ try:
82
+ mask += m.astype(float)* midx
83
+ except AttributeError:  # torch tensors use .float() instead of .astype
84
+ mask += m.float()*midx
85
+ viridis = cm.get_cmap('viridis', len(mask_list))
86
+ fig, ax = plt.subplots()
87
+ ax.imshow( mask)
88
+
89
+ handles = []
90
+ label_list = []
91
+ for idx , _ in enumerate(mask_list):
92
+ color = viridis(idx)
93
+ label = f"{idx}"
94
+ handles.append(mpatches.Patch(color=color, label=label))
95
+ label_list.append(label)
96
+ ax.legend(handles=handles)
97
+ plt.savefig(savepath)
98
+
99
+ def visualize_mask_list_clean(mask_list, savepath):
100
+ mask = 0
101
+ for midx, m in enumerate(mask_list):
102
+ try:
103
+ mask += m.astype(float)* midx
104
+ except AttributeError:  # torch tensors use .float() instead of .astype
105
+ mask += m.float()*midx
106
+ viridis = cm.get_cmap('viridis', len(mask_list))
107
+ fig, ax = plt.subplots()
108
+ ax.imshow( mask)
109
+
110
+ handles = []
111
+ label_list = []
112
+ for idx , _ in enumerate(mask_list):
113
+ color = viridis(idx)
114
+ label = f"{idx}"
115
+ handles.append(mpatches.Patch(color=color, label=label))
116
+ label_list.append(label)
117
+ # ax.legend(handles=handles)
118
+ plt.savefig(savepath, dpi=500)
119
+
120
+
121
+ def move_mask(mask_select, delta_x, delta_y):
122
+ mask_edit = myroll2d(mask_select, delta_y, delta_x)
123
+ return mask_edit
124
+
125
+ def stack_mask_with_priority(mask_list_np, priority_list, edit_idx_list):
126
+ mask_sel = mask_union(*[mask_list_np[eid] for eid in edit_idx_list])
127
+ for midx, mask in enumerate(mask_list_np):
128
+ if midx not in edit_idx_list:
129
+ if priority_list[edit_idx_list[0]] >= priority_list[midx]:
130
+ mask = mask.astype(float) - np.logical_and(mask.astype(bool) , mask_sel.astype(bool)).astype(float)
131
+ mask_list_np[midx] = mask.astype("uint8")
132
+ for midx in edit_idx_list:
133
+ for midx_1 in edit_idx_list:
134
+ if midx != midx_1:
135
+ if priority_list[midx] <= priority_list[midx_1]:
136
+ mask = mask_list_np[midx].astype(float) - np.logical_and(mask_list_np[midx].astype(bool), mask_list_np[midx_1].astype(bool)).astype(float)
137
+ mask_list_np[midx] = mask.astype("uint8")
138
+ return mask_list_np
139
+
140
+ def process_remain_mask(mask_list, edit_idx_list = None, force_mask_remain = None):
141
+ print("Processing the remaining mask with nearest-neighbor assignment")
142
+ width = mask_list[0].shape[0]
143
+ height = mask_list[0].shape[1]
144
+ pixel_ind = np.arange( width* height)
145
+
146
+ y_axis = np.arange(width)
147
+ ymesh = np.repeat(y_axis[:,np.newaxis], height, axis = 1) #N, N
148
+ ymesh_vec = ymesh.reshape(-1) #N *N
149
+
150
+ x_axis = np.arange(height)
151
+ xmesh = np.repeat(x_axis[np.newaxis, : ], width, axis = 0)
152
+ xmesh_vec = xmesh.reshape(-1)
153
+
154
+ mask_remain = (1 - sum([m.astype(float) for m in mask_list])).astype(np.uint8)
155
+ if force_mask_remain is not None:
156
+ mask_list[force_mask_remain] = (mask_list[force_mask_remain].astype(float) + mask_remain.astype(float)).astype(np.uint8)
157
+ else:
158
+ if edit_idx_list is not None:
159
+ a = [mask_list[eidx] for eidx in edit_idx_list]
160
+ mask_edit = mask_union(*a)
161
+ else:
162
+ mask_edit = np.zeros_like(mask_remain).astype(np.uint8)
163
+ mask_feasible = (1 - mask_remain.astype(float) - mask_edit.astype(float)).astype(np.uint8)
164
+
165
+ edge_width = 2
166
+
167
+ mask_feasible_down = myroll2d(mask_feasible, edge_width, 0)
168
+ mask_edge_down = (mask_feasible_down.astype(float) -mask_feasible.astype(float))<0
169
+
170
+ mask_feasible_up = myroll2d(mask_feasible, -edge_width, 0)
171
+ mask_edge_up = (mask_feasible_up.astype(float) -mask_feasible.astype(float))<0
172
+
173
+ mask_feasible_left = myroll2d(mask_feasible, 0, -edge_width)
174
+ mask_edge_left = (mask_feasible_left.astype(float) -mask_feasible.astype(float))<0
175
+
176
+ mask_feasible_right = myroll2d(mask_feasible, 0, edge_width)
177
+ mask_edge_right = (mask_feasible_right.astype(float) -mask_feasible.astype(float))<0
178
+
179
+ mask_feasible_ur = myroll2d(mask_feasible, -edge_width,edge_width)
180
+ mask_edge_ur = (mask_feasible_ur.astype(float) -mask_feasible.astype(float))<0
181
+
182
+ mask_feasible_ul = myroll2d(mask_feasible, -edge_width,-edge_width )
183
+ mask_edge_ul = (mask_feasible_ul.astype(float) -mask_feasible.astype(float))<0
184
+
185
+ mask_feasible_dr = myroll2d(mask_feasible, edge_width,edge_width )
186
+ mask_edge_dr = (mask_feasible_dr.astype(float) -mask_feasible.astype(float))<0
187
+
188
+ mask_feasible_dl = myroll2d(mask_feasible, edge_width,-edge_width)
189
+ mask_edge_dl = (mask_feasible_dl.astype(float) - mask_feasible.astype(float)) < 0
190
+
191
+ mask_edge = mask_union(
192
+ mask_edge_down, mask_edge_up, mask_edge_left, mask_edge_right, mask_edge_ur, mask_edge_ul, mask_edge_dr, mask_edge_dl
193
+ )
194
+
195
+ mask_feasible_edge = mask_intersection(mask_edge, mask_feasible)
196
+
197
+ vec_mask_feasible_edge = mask_feasible_edge.reshape(-1)
198
+ vec_mask_remain = mask_remain.reshape(-1)
199
+
200
+ indvec_all = np.arange(width*height)
201
+ vec_region_partition= 0
202
+ for mask_idx, mask in enumerate(mask_list):
203
+ vec_region_partition += mask.reshape(-1) * mask_idx
204
+ vec_region_partition += mask_remain.reshape(-1) * mask_idx
205
+ # assert 0 in vec_region_partition
206
+
207
+ vec_ind_remain = np.nonzero(vec_mask_remain)[0]
208
+ vec_ind_feasible_edge = np.nonzero(vec_mask_feasible_edge)[0]
209
+
210
+ vec_x_remain = xmesh_vec[vec_ind_remain]
211
+ vec_y_remain = ymesh_vec[vec_ind_remain]
212
+
213
+ vec_x_feasible_edge = xmesh_vec[vec_ind_feasible_edge]
214
+ vec_y_feasible_edge = ymesh_vec[vec_ind_feasible_edge]
215
+
216
+ x_dis = vec_x_remain[:,np.newaxis] - vec_x_feasible_edge[np.newaxis,:]
217
+ y_dis = vec_y_remain[:,np.newaxis] - vec_y_feasible_edge[np.newaxis,:]
218
+ dis = x_dis **2 + y_dis **2
219
+ pos = np.argmin(dis, axis = 1)
220
+ nearest_point = vec_ind_feasible_edge[pos] # closest point to target point
221
+
222
+ nearest_region = vec_region_partition[nearest_point]
223
+ nearest_region_set = set(nearest_region)
224
+ if edit_idx_list is not None:
225
+ for edit_idx in edit_idx_list:
226
+ assert edit_idx not in nearest_region
227
+
228
+ for midx, m in enumerate(mask_list):
229
+ if midx in nearest_region_set:
230
+ vec_newmask = np.zeros_like(indvec_all)
231
+ add_ind = vec_ind_remain [np.argwhere(nearest_region==midx)]
232
+ vec_newmask[add_ind] = 1
233
+
234
+ mask_list[midx] = mask_list[midx].astype(float)+ vec_newmask.reshape( mask_list[midx].shape).astype(float)
235
+ mask_list[midx] = mask_list[midx] > 0
236
+
237
+ print("Finished processing the remaining mask; to edit it further, launch the UI")
238
+ return mask_list, mask_remain
239
+
240
+ def resize_mask(mask_np, resize_ratio = 1):
241
+ w, h = mask_np.shape[0], mask_np.shape[1]
242
+ resized_w, resized_h = int(w*resize_ratio),int(h*resize_ratio)
243
+ mask_resized = torch.nn.functional.interpolate(torch.from_numpy(mask_np).unsqueeze(0).unsqueeze(0), (resized_w, resized_h)).squeeze()
244
+
245
+ mask = torch.zeros(w, h)
246
+ if w > resized_w:
247
+ mask[:resized_w, :resized_h] = mask_resized
248
+ else:
249
+ assert h <= resized_h
250
+ mask = mask_resized[resized_w//2-w//2: resized_w//2-w//2+w, resized_h//2-h//2: resized_h//2-h//2+h]
251
+ return mask.cpu().numpy().astype(np.uint8)
252
+
253
+ def process_mask_move_torch(
254
+ mask_list,
255
+ move_index_list,
256
+ delta_x_list = None,
257
+ delta_y_list = None,
258
+ edit_priority_list = None,
259
+ force_mask_remain = None,
260
+ resize_list = None
261
+ ):
262
+ mask_list_np = [m.cpu().numpy() for m in mask_list]
263
+ priority_list = [0 for _ in range(len(mask_list_np))]
264
+ for idx, (move_index, delta_x, delta_y, priority) in enumerate(zip(move_index_list, delta_x_list, delta_y_list, edit_priority_list)):
265
+ priority_list[move_index] = priority
266
+ if resize_list is not None:
267
+ mask = resize_mask (mask_list_np[move_index], resize_list[idx])
268
+ else:
269
+ mask = mask_list_np[move_index]
270
+ mask_list_np[move_index] = move_mask(mask, delta_x = delta_x, delta_y = delta_y)
271
+ mask_list_np = stack_mask_with_priority(mask_list_np, priority_list, move_index_list)  # uncovered pixels may remain; filled in below
272
+ check_mask_overlap_numpy(*mask_list_np)
273
+ mask_list_np, mask_remain = process_remain_mask(mask_list_np, move_index_list,force_mask_remain)
274
+ mask_list = [torch.from_numpy(m).to( dtype=torch.uint8) for m in mask_list_np]
275
+ mask_remain = torch.from_numpy(mask_remain).to(dtype=torch.uint8)
276
+ return mask_list, mask_remain
277
+
278
+ def process_mask_remove_torch(mask_list, remove_idx):
279
+ mask_list_np = [m.cpu().numpy() for m in mask_list]
280
+ mask_list_np[remove_idx] = np.zeros_like(mask_list_np[0])
281
+ mask_list_np, mask_remain = process_remain_mask(mask_list_np)
282
+ mask_list = [torch.from_numpy(m).to(dtype=torch.uint8) for m in mask_list_np]
283
+ mask_remain = torch.from_numpy(mask_remain).to(dtype=torch.uint8)
284
+ return mask_list, mask_remain
285
+
286
+ def get_mask_difference_torch(mask_list1, mask_list2):
287
+ assert len(mask_list1) == len(mask_list2)
288
+ mask_diff = torch.zeros_like(mask_list1[0])
289
+ for mask1 , mask2 in zip(mask_list1, mask_list2):
290
+ diff = ((mask1.float() - mask2.float())!=0).to(torch.uint8)
291
+ mask_diff = mask_union_torch(mask_diff, diff)
292
+ return mask_diff
293
+
294
+ def save_mask_list_to_npys(folder, mask_list, mask_label_list, name = "mask"):
295
+ for midx, (mask, mask_label) in enumerate(zip(mask_list, mask_label_list)):
296
+ np.save(os.path.join(folder, "{}{}_{}.npy".format(name, midx, mask_label)), mask)
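A minimal usage sketch for the utils_mask.py helpers above, assuming that myroll2d (imported from utils) shifts a 2D mask down/right by the given row/column offsets and that all masks are same-shape binary uint8 tensors partitioning the canvas; the toy canvas size, offsets, and output filename are hypothetical and not part of the commit.

import torch
from utils_mask import (
    get_mask_difference_torch,
    process_mask_move_torch,
    visualize_mask_list,
)

# Toy segmentation of a 64x64 canvas: mask 0 is a square object, mask 1 is
# the background; together they cover every pixel exactly once.
h = w = 64
obj = torch.zeros(h, w, dtype=torch.uint8)
obj[10:30, 10:30] = 1
mask_list = [obj, 1 - obj]

# Move mask 0 by (delta_x=5, delta_y=3) with priority 1; the pixels it vacates
# are reassigned to the nearest remaining region (here, the background).
moved_list, mask_remain = process_mask_move_torch(
    [m.clone() for m in mask_list],  # clone so the originals stay untouched
    move_index_list=[0],
    delta_x_list=[5],
    delta_y_list=[3],
    edit_priority_list=[1],
)

# Report which pixels changed label and save a quick visualization.
diff = get_mask_difference_torch(mask_list, moved_list)
print("changed pixels:", int(diff.sum()))
visualize_mask_list([m.numpy() for m in moved_list], "moved_masks.png")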