Spaces · Runtime error
Commit d807efd
Parent(s): first
Browse files
- .gitignore +171 -0
- README.md +158 -0
- app.py +349 -0
- assets/demo1.gif +0 -0
- assets/demo2.gif +0 -0
- assets/demo3.gif +0 -0
- assets/demo4.gif +0 -0
- assets/mask_def.png +0 -0
- controller.py +156 -0
- example2/img.png +0 -0
- main.py +424 -0
- pipeline_dedit_sd.py +813 -0
- pipeline_dedit_sdxl.py +875 -0
- requirements.txt +12 -0
- scripts/run_segment.sh +3 -0
- scripts/run_segmentSAM.sh +2 -0
- scripts/sd/run_ft_sd_512.sh +11 -0
- scripts/sd/run_ft_sd_512_2imgs.sh +14 -0
- scripts/sd/run_image.sh +15 -0
- scripts/sd/run_move_resize.sh +18 -0
- scripts/sd/run_recon.sh +11 -0
- scripts/sd/run_remove.sh +55 -0
- scripts/sd/run_text.sh +33 -0
- scripts/sdxl/run_ft_sdxl_1024.sh +12 -0
- scripts/sdxl/run_ft_sdxl_1024_2imgs.sh +14 -0
- scripts/sdxl/run_ft_sdxl_1024_auxin_todo.sh +29 -0
- scripts/sdxl/run_ft_sdxl_1024_fulllora.sh +12 -0
- scripts/sdxl/run_ft_sdxl_1024_fulllora_2imgs.sh +41 -0
- scripts/sdxl/run_image.sh +15 -0
- scripts/sdxl/run_image_w_edited_mask.sh +16 -0
- scripts/sdxl/run_move_resize.sh +21 -0
- scripts/sdxl/run_recon.sh +11 -0
- scripts/sdxl/run_recon_item_todo.sh +27 -0
- scripts/sdxl/run_remove.sh +17 -0
- scripts/sdxl/run_text.sh +14 -0
- scripts/sdxl/run_text_w_edited_mask.sh +15 -0
- segment.py +115 -0
- segment_sam.py +294 -0
- utils.py +326 -0
- utils_mask.py +296 -0
.gitignore
ADDED
@@ -0,0 +1,171 @@
example1_512/
example1_1024/
example1_example2_512/
example1_example2_1024/
example1/
old/

out_active.png
out_mask.png
out_soft.png

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file. For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
README.md
ADDED
@@ -0,0 +1,158 @@
<div align="center">
<h1>An item is Worth a Prompt: Versatile Image Editing with Disentangled Control</h1>

<a href='https://arxiv.org/abs/2403.04880'><img src='https://img.shields.io/badge/Technique-Report-red'></a>

</div>

D-Edit is a versatile image editing framework based on diffusion models, supporting text-, image-, and mask-based editing.

<!-- <img src='assets/applications.png'> -->
## Release
- [2024/03/12] 🔥 Code uploaded.


## 🔥 Examples

<p align="center">
  <img alt="text" src="assets/demo1.gif" width="45%">
  <img alt="image" src="assets/demo2.gif" width="45%">
</p>

1. **Text-Guided Editing**: Allows users to select an object within an image and replace or refine it based on a text description.
   - Key features:
     - Generates more realistic details and smoother transitions than alternative methods
     - Focuses edits specifically on the targeted object
     - Preserves unrelated parts of the image

2. **Image-Guided Editing**: Enables users to choose an object from a reference image and transplant it into another image while preserving its identity.
   - Key features:
     - Ensures seamless integration of the object into the new context
     - Adapts the object's appearance to match the target image's style
     - Works effectively even when the object's appearance differs significantly between reference and target images

<p align="center">
  <img alt="mask" src="assets/demo3.gif" width="45%">
  <img alt="remove" src="assets/demo4.gif" width="45%">
</p>

3. **Mask-Based Editing**: Manipulates objects by directly editing their masks.
   - Key features:
     - Allows operations such as moving, reshaping, resizing, and refining objects
     - Fills in new details according to the object's associated prompt
     - Produces natural-looking results that stay consistent with the overall image

4. **Item Removal**: Enables users to remove objects from images by deleting the mask-object associations.
   - Key features:
     - Intelligently fills in the empty space left by removed objects
     - Ensures a coherent final image
     - Maintains the integrity of the surrounding image elements

## 🔧 Dependencies and Installation
- Python >= 3.8 (we recommend [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html))
- [PyTorch >= 2.1.0](https://pytorch.org/)
```bash
conda create --name dedit python=3.10
conda activate dedit
pip install -U pip

# Install requirements
pip install -r requirements.txt
```


## 💻 Run

### 1. Segmentation
Put the image to be edited (any resolution) into a folder with a name of your choice, and rename the image "img.png" or "img.jpg".
Then run the segmentation model:
```
sh ./scripts/run_segment.sh
```
Alternatively, run [GroundedSAM](https://github.com/IDEA-Research/Grounded-Segment-Anything) to detect objects with a text prompt:
```
sh ./scripts/run_segmentSAM.sh
```

Optionally, if the segmentation is not good enough, refine the masks in a GUI by running the mask-editing web app locally:
```
python ui_edit_mask.py
```
For image-based editing, repeat this step for both the reference and target images. A sketch of the expected folder contents after this step is shown below.
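Based on the loaders in `app.py` and `main.py`, the folder should end up containing the image plus one binary `.npy` mask per detected item. A rough sketch (the `background`/`dog` labels here are hypothetical examples):
```
example1/
├── img.png                # original input image
├── img_512.png            # copy at the working resolution (img_1024.png for SDXL)
├── mask0_background.npy   # binary masks named mask{index}_{label}.npy
├── mask1_dog.npy
└── seg_current.png        # visualization of the current segmentation
```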

### 2. Model Finetuning
Finetune the UNet cross-attention layers of the diffusion model by running
```
sh ./scripts/sdxl/run_ft_sdxl_1024.sh
```
or finetune the full UNet with LoRA:
```
sh ./scripts/sdxl/run_ft_sdxl_1024_fulllora.sh
```
If image-based editing is needed, finetune the model with both the reference and target images using

```
sh ./scripts/sdxl/run_ft_sdxl_1024_fulllora_2imgs.sh
```
All finetuning variants save their weights into the output folder; a sketch of how they are reloaded follows below.
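Concretely, per `main.py`, finetuning writes `unet.pt` and `text_encoder1.pt` (plus `text_encoder2.pt` for SDXL) into the output folder, and editing reloads them via `--load_trained`. A minimal sketch, assuming the SD pipeline at 512 resolution:
```python
import torch
from pipeline_dedit_sd import DEditSDPipeline
from utils import load_mask

# Sketch of what --load_trained restores (file names and constructor per main.py).
mask_list, mask_label_list = load_mask("example1")
pipe = DEditSDPipeline(mask_list, mask_label_list, resolution=512, num_tokens=5)
pipe.unet.load_state_dict(torch.load("example1/unet.pt"))
pipe.text_encoder.load_state_dict(torch.load("example1/text_encoder1.pt"))
# SDXL pipelines additionally save/load example1/text_encoder2.pt
```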
### 3. Edit!
#### 3.1 Reconstruction
To check that the original image can be reconstructed:
```
sh ./scripts/sdxl/run_recon.sh
```
#### 3.2 Text-based
Replace the target item (tgt_index) with the item described by the text prompt (tgt_prompt):
```
sh ./scripts/sdxl/run_text.sh
```
#### 3.3 Image-based
Replace the target item (tgt_index) in the target image (tgt_name) with the item (src_index) from the reference image:
```
sh ./scripts/sdxl/run_image.sh
```
#### 3.4 Mask-based
For the target items (tgt_indices_list), resize them (resize_list), move them (delta_x, delta_y), or reshape them by manually editing the mask shape (using the UI).

The resulting new masks (processed by a simple algorithm) are visualized in './example1/move_resize/seg_move_resize.png'; if they are not reasonable, edit them with the UI.

```
sh ./scripts/sdxl/run_move_resize.sh
```
#### 3.5 Remove
Remove the target item (tgt_index); the vacated region is reassigned to nearby regions by a simple algorithm.
The resulting new masks (processed by a simple algorithm) are visualized in './example1/remove/seg_removed.png'; if they are not reasonable, edit them with the UI.

```
sh ./scripts/sdxl/run_remove.sh
```

#### 3.6 General editing parameters
- We partition the image into three regions as shown below. Regions under the hard mask are frozen, regions under the active mask are generated by the diffusion model, and regions under the soft mask keep the original content for the first "strength * N" sampling steps (see the sketch after the figure).
<p align="center">
  <img src="assets/mask_def.png" height=200>
</p>
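A minimal sketch of this three-region schedule, for intuition only; the real logic lives in `pipeline_dedit_sd(xl).py` (not reproduced here), and `denoise`/`renoise` are stand-ins for the scheduler and UNet calls rather than functions from this repo:
```python
import torch

def blend(latents, orig, mask):
    # keep `orig` wherever the binary mask is set, generated content elsewhere
    return torch.where(mask.bool(), orig, latents)

def sample(latents, orig_latents, mask_hard, mask_soft, timesteps, strength, denoise, renoise):
    for i, t in enumerate(timesteps):
        latents = denoise(latents, t)          # active region: generated freely
        orig_t = renoise(orig_latents, t)      # original image, noised to step t
        latents = blend(latents, orig_t, mask_hard)   # hard region: always frozen
        if i < strength * len(timesteps):             # soft region: original kept
            latents = blend(latents, orig_t, mask_soft)  # only for the first strength*N steps
    return latents
```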

- During editing, if you use an edited segmentation that differs from the one used for finetuning, add --load_edited_mask. For mask-based editing and removal, if you edit the masks that were automatically processed by the algorithm mentioned above, add --load_edited_processed_mask.

### Cite
If you find D-Edit useful for your research and applications, please cite us using this BibTeX:

```bibtex
@article{feng2024dedit,
  title={An item is Worth a Prompt: Versatile Image Editing with Disentangled Control},
  author={Feng, Aosong and Qiu, Weikang and Bai, Jinbin and Zhou, Kaicheng and Dong, Zhen and Zhang, Xiao and Ying, Rex and Tassiulas, Leandros},
  journal={arXiv preprint arXiv:2403.04880},
  year={2024}
}
```
app.py
ADDED
@@ -0,0 +1,349 @@
import os
import copy
import subprocess
from pathlib import Path

import numpy as np
import matplotlib.colors
import matplotlib.pyplot  # needed for get_cmap; importing bare `matplotlib` is not enough
import gradio as gr
from PIL import Image

from utils import load_mask, load_mask_edit
from utils_mask import process_mask_to_follow_priority, mask_union, visualize_mask_list_clean

LENGTH = 512        # side length of the square area displaying/editing images
TRANSPARENCY = 150  # transparency of the mask overlay in the display

def add_mask(mask_np_list_updated, mask_label_list):
    mask_new = np.zeros_like(mask_np_list_updated[0])
    mask_np_list_updated.append(mask_new)
    mask_label_list.append("new")
    return mask_np_list_updated, mask_label_list

def create_segmentation(mask_np_list):
    # one distinct viridis color per mask, composited into a single RGB image
    viridis = matplotlib.pyplot.get_cmap(name='viridis', lut=len(mask_np_list))
    segmentation = 0
    for i, m in enumerate(mask_np_list):
        color = matplotlib.colors.to_rgb(viridis(i))
        color_mat = np.ones_like(m)
        color_mat = np.stack([color_mat * color[0], color_mat * color[1], color_mat * color[2]], axis=2)
        color_mat = color_mat * m[:, :, np.newaxis]
        segmentation += color_mat
    segmentation = Image.fromarray(np.uint8(segmentation * 255))
    return segmentation

def load_mask_ui(input_folder, load_edit=False):
    if not load_edit:
        mask_list, mask_label_list = load_mask(input_folder)
    else:
        mask_list, mask_label_list = load_mask_edit(input_folder)

    mask_np_list = []
    for m in mask_list:
        mask_np_list.append(m.cpu().numpy())

    return mask_np_list, mask_label_list

def load_image_ui(input_folder, load_edit):
    try:
        for img_path in Path(input_folder).iterdir():
            if img_path.name in ["img.png", "img_1024.png", "img_512.png"]:
                image = Image.open(img_path)
                mask_np_list, mask_label_list = load_mask_ui(input_folder, load_edit=load_edit)
                image = image.convert('RGB')
                segmentation = create_segmentation(mask_np_list)
                return image, segmentation, mask_np_list, mask_label_list, image
    except Exception:
        pass
    # also reached when no matching image file was found in the folder
    print("Image folder invalid: the folder should contain img.png")
    return None, None, None, None, None

def run_segmentation(input_folder):
    subprocess.run(["python", "segment.py", "--name={}".format(input_folder)])
    return

def run_edit_text(
    input_folder,
    num_tokens,
    num_sampling_steps,
    strength,
    edge_thickness,
    tgt_prompt,
    tgt_idx,
    guidance_scale
):
    subprocess.run(["python",
                    "main.py",
                    "--text",
                    "--name={}".format(input_folder),
                    "--dpm={}".format("sd"),
                    "--resolution={}".format(512),
                    "--load_trained",
                    "--num_tokens={}".format(num_tokens),
                    "--seed={}".format(2024),
                    "--guidance_scale={}".format(guidance_scale),
                    "--num_sampling_steps={}".format(num_sampling_steps),  # full flag name (was the prefix --num_sampling_step)
                    "--strength={}".format(strength),
                    "--edge_thickness={}".format(edge_thickness),
                    "--num_imgs={}".format(2),
                    "--tgt_prompt={}".format(tgt_prompt),
                    "--tgt_index={}".format(tgt_idx)
                    ])

    return Image.open(os.path.join(input_folder, "text", "out_text_0.png"))

def run_optimization(
    input_folder,
    num_tokens,
    embedding_learning_rate,
    max_emb_train_steps,
    diffusion_model_learning_rate,
    max_diffusion_train_steps,
    train_batch_size,
    gradient_accumulation_steps
):
    subprocess.run(["python",
                    "main.py",
                    "--name={}".format(input_folder),
                    "--dpm={}".format("sd"),
                    "--resolution={}".format(512),
                    "--num_tokens={}".format(num_tokens),
                    "--embedding_learning_rate={}".format(embedding_learning_rate),
                    "--diffusion_model_learning_rate={}".format(diffusion_model_learning_rate),
                    "--max_emb_train_steps={}".format(max_emb_train_steps),
                    "--max_diffusion_train_steps={}".format(max_diffusion_train_steps),
                    "--train_batch_size={}".format(train_batch_size),
                    "--gradient_accumulation_steps={}".format(gradient_accumulation_steps)
                    ])
    return

def transparent_paste_with_mask(backimg, foreimg, mask_np, transparency=128):
    backimg_solid_np = np.array(backimg)
    bimg = backimg.copy()
    fimg = foreimg.copy()
    fimg.putalpha(transparency)
    bimg.paste(fimg, (0, 0), fimg)

    bimg_np = np.array(bimg)
    mask_np = mask_np[:, :, np.newaxis]
    # Composite: overlay inside the mask, untouched background outside it.
    new_img_np = bimg_np * mask_np + (1 - mask_np) * backimg_solid_np
    return Image.fromarray(new_img_np)

def show_segmentation(image, segmentation, flag):
    if flag is False:
        flag = True
        mask_np = np.ones([image.size[0], image.size[1]]).astype(np.uint8)
        image_edit = transparent_paste_with_mask(image, segmentation, mask_np, transparency=TRANSPARENCY)
        return image_edit, flag
    else:
        flag = False
        return image, flag

def edit_mask_add(canvas, image, idx, mask_np_list):
    mask_sel = mask_np_list[idx]
    mask_new = np.uint8(canvas["mask"][:, :, 0] / 255.)
    mask_np_list_updated = []
    for midx, m in enumerate(mask_np_list):
        if midx == idx:
            mask_np_list_updated.append(mask_union(mask_sel, mask_new))
        else:
            mask_np_list_updated.append(m)

    priority_list = [0 for _ in range(len(mask_np_list_updated))]
    priority_list[idx] = 1  # the edited mask wins wherever masks overlap
    mask_np_list_updated = process_mask_to_follow_priority(mask_np_list_updated, priority_list)
    mask_ones = np.ones([mask_sel.shape[0], mask_sel.shape[1]]).astype(np.uint8)
    segmentation = create_segmentation(mask_np_list_updated)
    image_edit = transparent_paste_with_mask(image, segmentation, mask_ones, transparency=TRANSPARENCY)
    return mask_np_list_updated, image_edit

def slider_release(index, image, mask_np_list_updated, mask_label_list):
    if index >= len(mask_np_list_updated):  # was `>`, which allowed an off-by-one index error
        return image, "out of range"
    else:
        mask_np = mask_np_list_updated[index]
        mask_label = mask_label_list[index]
        segmentation = create_segmentation(mask_np_list_updated)
        new_image = transparent_paste_with_mask(image, segmentation, mask_np, transparency=TRANSPARENCY)
        return new_image, mask_label

def save_as_orig_mask(mask_np_list_updated, mask_label_list, input_folder):
    # The masks must partition the image exactly: every pixel in exactly one mask.
    if not np.all(sum(mask_np_list_updated) == 1):
        print("please check mask: the masks do not cover every pixel exactly once")
        return

    for midx, (mask, mask_label) in enumerate(zip(mask_np_list_updated, mask_label_list)):
        np.save(os.path.join(input_folder, "mask{}_{}.npy".format(midx, mask_label)), mask)
    savepath = os.path.join(input_folder, "seg_current.png")
    visualize_mask_list_clean(mask_np_list_updated, savepath)

def save_as_edit_mask(mask_np_list_updated, mask_label_list, input_folder):
    if not np.all(sum(mask_np_list_updated) == 1):
        print("please check mask: the masks do not cover every pixel exactly once")
        return
    for midx, (mask, mask_label) in enumerate(zip(mask_np_list_updated, mask_label_list)):
        np.save(os.path.join(input_folder, "maskEdited{}_{}.npy".format(midx, mask_label)), mask)
    savepath = os.path.join(input_folder, "seg_edited.png")
    visualize_mask_list_clean(mask_np_list_updated, savepath)

with gr.Blocks() as demo:
    image = gr.State()  # stores the loaded image
    image_loaded = gr.State()
    segmentation = gr.State()

    mask_np_list = gr.State([])
    mask_label_list = gr.State([])
    mask_np_list_updated = gr.State([])
    true = gr.State(True)
    false = gr.State(False)

    with gr.Row():
        gr.Markdown("""# D-Edit""")

    with gr.Tab(label="1 Edit mask"):
        with gr.Row():
            with gr.Column():
                canvas = gr.Image(value=None, type="numpy", tool="sketch", label="Draw Mask",
                                  show_label=True, height=LENGTH, width=LENGTH, interactive=True)
                input_folder = gr.Textbox(value="example1", label="input folder", interactive=True)

                segment_button = gr.Button("1.1 Run segmentation")
                segment_button.click(run_segmentation,
                                     [input_folder],
                                     [])

                text_button = gr.Button("1.2 Load original masks")
                text_button.click(load_image_ui,
                                  [input_folder, false],
                                  [image_loaded, segmentation, mask_np_list, mask_label_list, canvas])

                load_edit_button = gr.Button("1.2 Load edited masks")
                load_edit_button.click(load_image_ui,
                                       [input_folder, true],
                                       [image_loaded, segmentation, mask_np_list, mask_label_list, canvas])

                show_segment = gr.Checkbox(label="Show Segmentation")

                flag = gr.State(False)
                show_segment.select(show_segmentation,
                                    [image_loaded, segmentation, flag],
                                    [canvas, flag])

            mask_np_list_updated = copy.deepcopy(mask_np_list)

            with gr.Column():
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Draw Mask</p>""")
                slider = gr.Slider(0, 20, step=1, interactive=True)
                label = gr.Textbox()
                slider.release(slider_release,
                               inputs=[slider, image_loaded, mask_np_list_updated, mask_label_list],
                               outputs=[canvas, label])
                add_button = gr.Button("Add")
                add_button.click(edit_mask_add,
                                 [canvas, image_loaded, slider, mask_np_list_updated],
                                 [mask_np_list_updated, canvas])

                save_button2 = gr.Button("Set and Save as edited masks")
                save_button2.click(save_as_edit_mask,
                                   [mask_np_list_updated, mask_label_list, input_folder],
                                   [])

                save_button = gr.Button("Set and Save as original masks")
                save_button.click(save_as_orig_mask,
                                  [mask_np_list_updated, mask_label_list, input_folder],
                                  [])

                back_button = gr.Button("Back to current seg")
                back_button.click(load_mask_ui,
                                  [input_folder],
                                  [mask_np_list_updated, mask_label_list])

                add_mask_button = gr.Button("Add new empty mask")
                add_mask_button.click(add_mask,
                                      [mask_np_list_updated, mask_label_list],
                                      [mask_np_list_updated, mask_label_list])

    with gr.Tab(label="2 Optimization"):
        with gr.Row():
            with gr.Column():
                canvas_opt = gr.Image(value=canvas.value, type="pil", tool="sketch", label="Loaded Image",
                                      show_label=True, height=LENGTH, width=LENGTH, interactive=True)

            with gr.Column():
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Optimization settings (SD)</p>""")
                num_tokens = gr.Textbox(value="5", label="num tokens to represent each object", interactive=True)
                embedding_learning_rate = gr.Textbox(value="1e-4", label="Embedding optimization: Learning rate", interactive=True)
                max_emb_train_steps = gr.Textbox(value="500", label="Embedding optimization: Training steps", interactive=True)

                diffusion_model_learning_rate = gr.Textbox(value="5e-5", label="UNet optimization: Learning rate", interactive=True)
                max_diffusion_train_steps = gr.Textbox(value="500", label="UNet optimization: Training steps", interactive=True)

                train_batch_size = gr.Textbox(value="5", label="Batch size", interactive=True)
                gradient_accumulation_steps = gr.Textbox(value="5", label="Gradient accumulation", interactive=True)

                add_button = gr.Button("Run optimization")
                add_button.click(run_optimization,
                                 inputs=[
                                     input_folder,
                                     num_tokens,
                                     embedding_learning_rate,
                                     max_emb_train_steps,
                                     diffusion_model_learning_rate,
                                     max_diffusion_train_steps,
                                     train_batch_size, gradient_accumulation_steps
                                 ],
                                 outputs=[])

    with gr.Tab(label="3 Editing"):
        with gr.Tab(label="3.1 Text-based editing"):
            with gr.Row():
                with gr.Column():
                    canvas_text_edit = gr.Image(value=None, label="Editing results", show_label=True,
                                                height=LENGTH, width=LENGTH)
                    # canvas_text_edit = gr.Gallery(label="Edited results")

                with gr.Column():
                    gr.Markdown("""<p style="text-align: center; font-size: 20px">Editing setting (SD)</p>""")

                    tgt_prompt = gr.Textbox(value="Dog", label="Editing: Text prompt", interactive=True)
                    tgt_idx = gr.Textbox(value="0", label="Editing: Object index", interactive=True)
                    guidance_scale = gr.Textbox(value="6", label="Editing: CFG guidance scale", interactive=True)
                    num_sampling_steps = gr.Textbox(value="50", label="Editing: Sampling steps", interactive=True)
                    edge_thickness = gr.Textbox(value="10", label="Editing: Edge thickness", interactive=True)
                    strength = gr.Textbox(value="0.5", label="Editing: Mask strength", interactive=True)

                    add_button = gr.Button("Run Editing")
                    add_button.click(run_edit_text,
                                     inputs=[
                                         input_folder,
                                         num_tokens,
                                         num_sampling_steps,
                                         strength,
                                         edge_thickness,
                                         tgt_prompt,
                                         tgt_idx,
                                         guidance_scale
                                     ],
                                     outputs=[canvas_text_edit])

demo.queue().launch(share=True, debug=True)
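Note on the mask format the app relies on: masks are binary HxW arrays that must tile the image exactly (the save functions above assert their elementwise sum is 1 everywhere). A minimal sketch of constructing a valid two-mask partition by hand (region coordinates are hypothetical):
```python
import numpy as np

h = w = 512
mask_bg = np.ones((h, w), dtype=np.uint8)
mask_obj = np.zeros((h, w), dtype=np.uint8)
mask_obj[100:300, 150:350] = 1          # hypothetical object region
mask_bg -= mask_obj                     # background is the exact complement
assert np.all(mask_bg + mask_obj == 1)  # the invariant app.py checks before saving
np.save("example1/mask0_bg.npy", mask_bg)   # naming follows app.py: mask{idx}_{label}.npy
np.save("example1/mask1_obj.npy", mask_obj)
```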
assets/demo1.gif
ADDED
assets/demo2.gif
ADDED
assets/demo3.gif
ADDED
assets/demo4.gif
ADDED
assets/mask_def.png
ADDED
controller.py
ADDED
@@ -0,0 +1,156 @@
import math

import numpy as np
import torch
import xformers
import xformers.ops  # explicit import so xformers.ops is guaranteed to resolve

class DummyController:
    def __call__(self, *args):
        return args[0]

    def __init__(self):
        self.num_att_layers = 0

class GroupedCAController:
    def __init__(self, mask_list=None):
        self.mask_list = mask_list
        # "decom" (decomposed) mode is active when per-item masks are provided
        if self.mask_list is None:
            self.is_decom = False
        else:
            self.is_decom = True

    def mask_img_to_mask_vec(self, mask, length):
        mask_vec = torch.nn.functional.interpolate(mask.unsqueeze(0).unsqueeze(0), (length, length)).squeeze()
        mask_vec = mask_vec.flatten()
        return mask_vec

    def ca_forward_decom(self, q, k_list, v_list, scale, place_in_unet):
        # attn [Bh, N, P], e.g. [8, 4096, 77]
        # q [Bh, N, d]: [8, 4096, 40] [8, 1024, 80] [8, 256, 160] [8, 64, 160]
        # k [Bh, P, d]: [8, 77, 40]   [8, 77, 80]   [8, 77, 160]  [8, 77, 160]
        # v [Bh, P, d]: [8, 77, 40]   [8, 77, 80]   [8, 77, 160]  [8, 77, 160]
        N = q.shape[1]
        mask_vec_list = []
        for mask in self.mask_list:
            mask_vec = self.mask_img_to_mask_vec(mask, int(math.sqrt(N)))  # [N]
            mask_vec = mask_vec.unsqueeze(0).unsqueeze(-1)                 # [1, N, 1]
            mask_vec_list.append(mask_vec)
        out = 0
        # Each item's prompt attends only inside that item's mask; the per-item
        # attention outputs are summed to assemble the full feature map.
        for mask_vec, k, v in zip(mask_vec_list, k_list, v_list):
            sim = torch.einsum("b i d, b j d -> b i j", q, k) * scale  # [Bh, N, P]
            attn = sim.softmax(dim=-1)
            attn = attn.masked_fill(mask_vec == 0, 0)
            masked_out = torch.einsum("b i j, b j d -> b i d", attn, v)  # [Bh, N, d]
            # mask_vec_inf = torch.where(mask_vec > 0, 0, torch.finfo(k.dtype).min)
            # masked_out1 = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask_vec_inf, op=None, scale=scale)
            out += masked_out
        return out

    def reshape_heads_to_batch_dim(self):
        def func(tensor):
            batch_size, seq_len, dim = tensor.shape
            head_size = self.num_heads
            tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
            tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
            return tensor  # the original inner function was missing this return
        return func

    def reshape_batch_dim_to_heads(self):
        def func(tensor):
            batch_size, seq_len, dim = tensor.shape
            head_size = self.num_heads
            tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
            tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
            return tensor  # the original inner function was missing this return
        return func

def register_attention_disentangled_control(unet, controller):
    def ca_forward(self, place_in_unet):
        if type(self.to_out) is torch.nn.modules.container.ModuleList:
            to_out = self.to_out[0]
        else:
            to_out = self.to_out

        def forward(x, encoder_hidden_states=None, attention_mask=None):
            if isinstance(controller, DummyController):  # plain SA/CA, full attention
                q = self.to_q(x)
                is_cross = encoder_hidden_states is not None
                encoder_hidden_states = encoder_hidden_states if is_cross else x
                k = self.to_k(encoder_hidden_states)
                v = self.to_v(encoder_hidden_states)
                q = self.head_to_batch_dim(q)
                k = self.head_to_batch_dim(k)
                v = self.head_to_batch_dim(v)

                # sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
                # attn = sim.softmax(dim=-1)
                # attn = controller(attn, is_cross, place_in_unet)
                # out = torch.einsum("b i j, b j d -> b i d", attn, v)
                out = xformers.ops.memory_efficient_attention(
                    q, k, v, attn_bias=None, op=None, scale=self.scale
                )
                out = self.batch_to_head_dim(out)
            else:  # decomposed cross-attention
                is_cross = encoder_hidden_states is not None
                encoder_hidden_states_list = encoder_hidden_states if is_cross else x
                q = self.to_q(x)
                q = self.head_to_batch_dim(q)  # [Bh, 4096, 320/h], h: 8
                if is_cross:  # CA: one (k, v) pair per item prompt
                    k_list = []
                    v_list = []
                    assert type(encoder_hidden_states_list) is list
                    for encoder_hidden_states in encoder_hidden_states_list:
                        k = self.to_k(encoder_hidden_states)
                        k = self.head_to_batch_dim(k)  # [Bh, 77, 320/h]
                        k_list.append(k)
                        v = self.to_v(encoder_hidden_states)
                        v = self.head_to_batch_dim(v)  # [Bh, 77, 320/h]
                        v_list.append(v)
                    out = controller.ca_forward_decom(q, k_list, v_list, self.scale, place_in_unet)  # [Bh, N, d]
                    out = self.batch_to_head_dim(out)
                else:
                    # Decomposing self-attention is unsupported; the original dead code here
                    # sketched a controller.sa_forward / sa_forward_decom path behind a pdb trace.
                    exit("decomposing SA!")

            return to_out(out)

        return forward

    if controller is None:
        controller = DummyController()

    def register_recr(net_, count, place_in_unet):
        if net_.__class__.__name__ == 'Attention' and net_.to_k.in_features == unet.ca_dim:
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        elif hasattr(net_, 'children'):
            for net__ in net_.children():
                count = register_recr(net__, count, place_in_unet)
        return count

    cross_att_count = 0
    sub_nets = unet.named_children()

    for net in sub_nets:
        if "down" in net[0]:
            down_count = register_recr(net[1], 0, "down")  # 6
            cross_att_count += down_count
        elif "up" in net[0]:
            up_count = register_recr(net[1], 0, "up")  # 9
            cross_att_count += up_count
        elif "mid" in net[0]:
            mid_count = register_recr(net[1], 0, "mid")  # 1
            cross_att_count += mid_count
    controller.num_att_layers = cross_att_count
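A hypothetical wiring sketch for the controller above; `unet` and `mask_list` are assumed to come from the D-Edit pipeline classes, and `unet.ca_dim` must equal the text-encoder hidden width so `register_recr` can match cross-attention layers:
```python
from controller import GroupedCAController, register_attention_disentangled_control

unet.ca_dim = 768                                        # e.g. the SD 1.x text hidden size (assumption)
controller = GroupedCAController(mask_list=mask_list)    # one binary mask per item
register_attention_disentangled_control(unet, controller)
# With masks attached, every cross-attention layer routes each item's prompt
# only to the pixels inside that item's mask (see ca_forward_decom above).
```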
example2/img.png
ADDED
main.py
ADDED
@@ -0,0 +1,424 @@
import os
import argparse

import numpy as np
import torch
from peft import LoraConfig

from pipeline_dedit_sdxl import DEditSDXLPipeline
from pipeline_dedit_sd import DEditSDPipeline
from utils import load_image, load_mask, load_mask_edit
from utils_mask import process_mask_move_torch, process_mask_remove_torch, mask_union_torch, mask_substract_torch, create_outer_edge_mask_torch
from utils_mask import check_mask_overlap_torch, check_cover_all_torch, visualize_mask_list, get_mask_difference_torch, save_mask_list_to_npys

parser = argparse.ArgumentParser()
parser.add_argument("--name", type=str, required=True, default=None)
parser.add_argument("--name_2", type=str, required=False, default=None)
parser.add_argument("--dpm", type=str, required=True, default="sd")
parser.add_argument("--resolution", type=int, default=1024)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--embedding_learning_rate", type=float, default=1e-4)
parser.add_argument("--max_emb_train_steps", type=int, default=200)
parser.add_argument("--diffusion_model_learning_rate", type=float, default=5e-5)
parser.add_argument("--max_diffusion_train_steps", type=int, default=200)
parser.add_argument("--train_batch_size", type=int, default=1)
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--num_tokens", type=int, default=1)

parser.add_argument("--load_trained", default=False, action="store_true")
parser.add_argument("--num_sampling_steps", type=int, default=50)
parser.add_argument("--guidance_scale", type=float, default=3)
parser.add_argument("--strength", type=float, default=0.8)

parser.add_argument("--train_full_lora", default=False, action="store_true")
parser.add_argument("--lora_rank", type=int, default=4)
parser.add_argument("--lora_alpha", type=int, default=4)

parser.add_argument("--prompt_auxin_list", nargs="+", type=str, default=None)
parser.add_argument("--prompt_auxin_idx_list", nargs="+", type=int, default=None)

# general editing configs
parser.add_argument("--load_edited_mask", default=False, action="store_true")
parser.add_argument("--load_edited_processed_mask", default=False, action="store_true")
parser.add_argument("--edge_thickness", type=int, default=20)
parser.add_argument("--num_imgs", type=int, default=1)
parser.add_argument("--active_mask_list", nargs="+", type=int)
parser.add_argument("--tgt_index", type=int, default=None)

# recon
parser.add_argument("--recon", default=False, action="store_true")
parser.add_argument("--recon_an_item", default=False, action="store_true")
parser.add_argument("--recon_prompt", type=str, default=None)

# text-based editing
parser.add_argument("--text", default=False, action="store_true")
parser.add_argument("--tgt_prompt", type=str, default=None)

# image-based editing
parser.add_argument("--image", default=False, action="store_true")
parser.add_argument("--src_index", type=int, default=None)
parser.add_argument("--tgt_name", type=str, default=None)

# mask-based move/resize
parser.add_argument("--move_resize", default=False, action="store_true")
parser.add_argument("--tgt_indices_list", nargs="+", type=int)
parser.add_argument("--delta_x_list", nargs="+", type=int)
parser.add_argument("--delta_y_list", nargs="+", type=int)
parser.add_argument("--priority_list", nargs="+", type=int)
parser.add_argument("--force_mask_remain", type=int, default=None)
parser.add_argument("--resize_list", nargs="+", type=float)

# remove
parser.add_argument("--remove", default=False, action="store_true")
parser.add_argument("--load_edited_removemask", default=False, action="store_true")

args = parser.parse_args()
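
# Illustrative invocation (not part of the original file): app.py drives these
# flags via subprocess for a text-based edit, roughly equivalent to
#   python main.py --text --name=example1 --dpm=sd --resolution=512 \
#       --load_trained --num_tokens=5 --guidance_scale=6 --num_sampling_steps=50 \
#       --strength=0.5 --edge_thickness=10 --tgt_prompt=dog --tgt_index=0
# (the scripts/sd and scripts/sdxl shell scripts are the canonical entry points).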
75 |
+
|
76 |
+
torch.cuda.manual_seed_all(args.seed)
|
77 |
+
torch.manual_seed(args.seed)
|
78 |
+
base_input_folder = "."
|
79 |
+
base_output_folder = "."
|
80 |
+
|
81 |
+
input_folder = os.path.join(base_input_folder, args.name)
|
82 |
+
|
83 |
+
|
84 |
+
mask_list, mask_label_list = load_mask(input_folder)
|
85 |
+
assert mask_list[0].shape[0] == args.resolution, "Segmentation should be done on size {}".format(args.resolution)
|
86 |
+
try:
|
87 |
+
image_gt = load_image(os.path.join(input_folder, "img_{}.png".format(args.resolution) ), size = args.resolution)
|
88 |
+
except:
|
89 |
+
image_gt = load_image(os.path.join(input_folder, "img_{}.jpg".format(args.resolution) ), size = args.resolution)
|
90 |
+
|
91 |
+
if args.image:
|
92 |
+
input_folder_2 = os.path.join(base_input_folder, args.name_2)
|
93 |
+
mask_list_2, mask_label_list_2 = load_mask(input_folder_2)
|
94 |
+
assert mask_list_2[0].shape[0] == args.resolution, "Segmentation should be done on size {}".format(args.resolution)
|
95 |
+
try:
|
96 |
+
image_gt_2 = load_image(os.path.join(input_folder_2, "img_{}.png".format(args.resolution) ), size = args.resolution)
|
97 |
+
except:
|
98 |
+
image_gt_2 = load_image(os.path.join(input_folder_2, "img_{}.jpg".format(args.resolution) ), size = args.resolution)
|
99 |
+
output_dir = os.path.join(base_output_folder, args.name + "_" + args.name_2)
|
100 |
+
os.makedirs(output_dir, exist_ok = True)
|
101 |
+
else:
|
102 |
+
output_dir = os.path.join(base_output_folder, args.name)
|
103 |
+
os.makedirs(output_dir, exist_ok = True)
|
104 |
+
|
105 |
+
if args.dpm == "sd":
|
106 |
+
if args.image:
|
107 |
+
pipe = DEditSDPipeline(mask_list, mask_label_list, mask_list_2, mask_label_list_2, resolution = args.resolution, num_tokens = args.num_tokens)
|
108 |
+
else:
|
109 |
+
pipe = DEditSDPipeline(mask_list, mask_label_list, resolution = args.resolution, num_tokens = args.num_tokens)
|
110 |
+
|
111 |
+
elif args.dpm == "sdxl":
|
112 |
+
if args.image:
|
113 |
+
pipe = DEditSDXLPipeline(mask_list, mask_label_list, mask_list_2, mask_label_list_2, resolution = args.resolution, num_tokens = args.num_tokens)
|
114 |
+
else:
|
115 |
+
pipe = DEditSDXLPipeline(mask_list, mask_label_list, resolution = args.resolution, num_tokens = args.num_tokens)
|
116 |
+
|
117 |
+
else:
|
118 |
+
raise NotImplementedError
|
119 |
+
|
120 |
+
set_string_list = pipe.set_string_list
|
121 |
+
if args.prompt_auxin_list is not None:
|
122 |
+
for auxin_idx, auxin_prompt in zip(args.prompt_auxin_idx_list, args.prompt_auxin_list):
|
123 |
+
set_string_list[auxin_idx] = auxin_prompt.replace("*", set_string_list[auxin_idx] )
|
124 |
+
print(set_string_list)
|
125 |
+
|
126 |
+
if args.image:
|
127 |
+
set_string_list_2 = pipe.set_string_list_2
|
128 |
+
print(set_string_list_2)
|
129 |
+
|
130 |
+
if args.load_trained:
|
131 |
+
unet_save_path = os.path.join(output_dir, "unet.pt")
|
132 |
+
unet_state_dict = torch.load(unet_save_path)
|
133 |
+
text_encoder1_save_path = os.path.join(output_dir, "text_encoder1.pt")
|
134 |
+
text_encoder1_state_dict = torch.load(text_encoder1_save_path)
|
135 |
+
if args.dpm == "sdxl":
|
136 |
+
text_encoder2_save_path = os.path.join(output_dir, "text_encoder2.pt")
|
137 |
+
text_encoder2_state_dict = torch.load(text_encoder2_save_path)
|
138 |
+
|
139 |
+
if 'lora' in ''.join(unet_state_dict.keys()):
|
140 |
+
unet_lora_config = LoraConfig(
|
141 |
+
r=args.lora_rank,
|
142 |
+
lora_alpha=args.lora_alpha,
|
143 |
+
init_lora_weights="gaussian",
|
144 |
+
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
145 |
+
)
|
146 |
+
pipe.unet.add_adapter(unet_lora_config)
|
147 |
+
|
148 |
+
pipe.unet.load_state_dict(unet_state_dict)
|
149 |
+
pipe.text_encoder.load_state_dict(text_encoder1_state_dict)
|
150 |
+
if args.dpm == "sdxl":
|
151 |
+
pipe.text_encoder_2.load_state_dict(text_encoder2_state_dict)
|
152 |
+
else:
|
153 |
+
if args.image:
|
154 |
+
pipe.mask_list = [m.cuda() for m in pipe.mask_list]
|
155 |
+
pipe.mask_list_2 = [m.cuda() for m in pipe.mask_list_2]
|
156 |
+
pipe.train_emb_2imgs(
|
157 |
+
image_gt,
|
158 |
+
image_gt_2,
|
159 |
+
set_string_list,
|
160 |
+
set_string_list_2,
|
161 |
+
gradient_accumulation_steps = args.gradient_accumulation_steps,
|
162 |
+
embedding_learning_rate = args.embedding_learning_rate,
|
163 |
+
max_emb_train_steps = args.max_emb_train_steps,
|
164 |
+
train_batch_size = args.train_batch_size,
|
165 |
+
)
|
166 |
+
|
167 |
+
pipe.train_model_2imgs(
|
168 |
+
image_gt,
|
169 |
+
image_gt_2,
|
170 |
+
set_string_list,
|
171 |
+
set_string_list_2,
|
172 |
+
gradient_accumulation_steps = args.gradient_accumulation_steps,
|
173 |
+
max_diffusion_train_steps = args.max_diffusion_train_steps,
|
174 |
+
diffusion_model_learning_rate = args.diffusion_model_learning_rate ,
|
175 |
+
train_batch_size =args.train_batch_size,
|
176 |
+
train_full_lora = args.train_full_lora,
|
177 |
+
lora_rank = args.lora_rank,
|
178 |
+
lora_alpha = args.lora_alpha
|
179 |
+
)
|
180 |
+
|
181 |
+
else:
|
182 |
+
pipe.mask_list = [m.cuda() for m in pipe.mask_list]
|
183 |
+
pipe.train_emb(
|
184 |
+
image_gt,
|
185 |
+
set_string_list,
|
186 |
+
gradient_accumulation_steps = args.gradient_accumulation_steps,
|
187 |
+
embedding_learning_rate = args.embedding_learning_rate,
|
188 |
+
max_emb_train_steps = args.max_emb_train_steps,
|
189 |
+
train_batch_size = args.train_batch_size,
|
190 |
+
)
|
191 |
+
|
192 |
+
pipe.train_model(
|
193 |
+
image_gt,
|
194 |
+
set_string_list,
|
195 |
+
gradient_accumulation_steps = args.gradient_accumulation_steps,
|
196 |
+
max_diffusion_train_steps = args.max_diffusion_train_steps,
|
197 |
+
diffusion_model_learning_rate = args.diffusion_model_learning_rate ,
|
198 |
+
train_batch_size = args.train_batch_size,
|
199 |
+
train_full_lora = args.train_full_lora,
|
200 |
+
lora_rank = args.lora_rank,
|
201 |
+
lora_alpha = args.lora_alpha
|
202 |
+
)
|
203 |
+
|
204 |
+
|
205 |
+
unet_save_path = os.path.join(output_dir, "unet.pt")
|
206 |
+
torch.save(pipe.unet.state_dict(),unet_save_path )
|
207 |
+
text_encoder1_save_path = os.path.join(output_dir, "text_encoder1.pt")
|
208 |
+
torch.save(pipe.text_encoder.state_dict(), text_encoder1_save_path)
|
209 |
+
if args.dpm == "sdxl":
|
210 |
+
text_encoder2_save_path = os.path.join(output_dir, "text_encoder2.pt")
|
211 |
+
torch.save(pipe.text_encoder_2.state_dict(), text_encoder2_save_path )
|
212 |
+
|
213 |
+
|
214 |
+
if args.recon:
|
215 |
+
output_dir = os.path.join(output_dir, "recon")
|
216 |
+
os.makedirs(output_dir, exist_ok = True)
|
217 |
+
if args.recon_an_item:
|
218 |
+
mask_list = [torch.from_numpy(np.ones_like(mask_list[0].numpy()))]
|
219 |
+
tgt_string = set_string_list[args.tgt_index]
|
220 |
+
tgt_string = args.recon_prompt.replace("*", tgt_string)
|
221 |
+
set_string_list = [tgt_string]
|
222 |
+
print(set_string_list)
|
223 |
+
save_path = os.path.join(output_dir, "out_recon.png")
|
224 |
+
x_np = pipe.inference_with_mask(
|
225 |
+
save_path,
|
226 |
+
guidance_scale = args.guidance_scale,
|
227 |
+
num_sampling_steps = args.num_sampling_steps,
|
228 |
+
seed = args.seed,
|
229 |
+
num_imgs = args.num_imgs,
|
230 |
+
set_string_list = set_string_list,
|
231 |
+
mask_list = mask_list
|
232 |
+
)
|
233 |
+
|
234 |
+
if args.text:
|
235 |
+
print("Text-guided editing ")
|
236 |
+
output_dir = os.path.join(output_dir, "text")
|
237 |
+
os.makedirs(output_dir, exist_ok = True)
|
238 |
+
save_path = os.path.join(output_dir, "out_text.png")
|
239 |
+
set_string_list[args.tgt_index] = args.tgt_prompt
|
240 |
+
mask_active = torch.zeros_like(mask_list[0])
|
241 |
+
mask_active = mask_union_torch(mask_active, mask_list[args.tgt_index])
|
242 |
+
|
243 |
+
if args.active_mask_list is not None:
|
244 |
+
for midx in args.active_mask_list:
|
245 |
+
mask_active = mask_union_torch(mask_active, mask_list[midx])
|
246 |
+
|
247 |
+
if args.load_edited_mask:
|
248 |
+
mask_list_edited, mask_label_list_edited = load_mask_edit(input_folder)
|
249 |
+
mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
|
250 |
+
mask_active = mask_union_torch(mask_active, mask_diff)
|
251 |
+
mask_list = mask_list_edited
|
252 |
+
save_path = os.path.join(output_dir, "out_textEdited.png")
|
253 |
+
|
254 |
+
mask_hard = mask_substract_torch(torch.ones_like(mask_list[0]), mask_active)
|
255 |
+
mask_soft = create_outer_edge_mask_torch(mask_active, edge_thickness = args.edge_thickness)
|
256 |
+
mask_hard = mask_substract_torch(mask_hard, mask_soft)
|
257 |
+
|
258 |
+
pipe.inference_with_mask(
|
259 |
+
save_path,
|
260 |
+
orig_image = image_gt,
|
261 |
+
set_string_list = set_string_list,
|
262 |
+
guidance_scale = args.guidance_scale,
|
263 |
+
strength = args.strength,
|
264 |
+
num_imgs = args.num_imgs,
|
265 |
+
mask_hard= mask_hard,
|
266 |
+
mask_soft = mask_soft,
|
267 |
+
mask_list = mask_list,
|
268 |
+
seed = args.seed,
|
269 |
+
num_sampling_steps = args.num_sampling_steps
|
270 |
+
)
|
271 |
+
|
272 |
+
if args.remove:
|
273 |
+
output_dir = os.path.join(output_dir, "remove")
|
274 |
+
save_path = os.path.join(output_dir, "out_remove.png")
|
275 |
+
os.makedirs(output_dir, exist_ok = True)
|
276 |
+
mask_active = torch.zeros_like(mask_list[0])
|
277 |
+
|
278 |
+
if args.load_edited_mask:
|
279 |
+
mask_list_edited, _ = load_mask_edit(input_folder)
|
280 |
+
mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
|
281 |
+
mask_active = mask_union_torch(mask_active, mask_diff)
|
282 |
+
mask_list = mask_list_edited
|
283 |
+
|
284 |
+
if args.load_edited_processed_mask:
|
285 |
+
# manually edit or draw masks after removing one index, then load
|
286 |
+
mask_list_processed, _ = load_mask_edit(output_dir)
|
287 |
+
mask_remain = get_mask_difference_torch(mask_list_processed, mask_list)
|
288 |
+
else:
|
289 |
+
# generate masks after removing one index, using nearest neighbor algorithm
|
290 |
+
mask_list_processed, mask_remain = process_mask_remove_torch(mask_list, args.tgt_index)
|
291 |
+
save_mask_list_to_npys(output_dir, mask_list_processed, mask_label_list, name = "mask")
|
292 |
+
visualize_mask_list(mask_list_processed, os.path.join(output_dir, "seg_removed.png"))
|
293 |
+
check_cover_all_torch(*mask_list_processed)
|
294 |
+
mask_active = mask_union_torch(mask_active, mask_remain)
|
295 |
+
|
296 |
+
if args.active_mask_list is not None:
|
297 |
+
for midx in args.active_mask_list:
|
298 |
+
mask_active = mask_union_torch(mask_active, mask_list[midx])
|
299 |
+
|
300 |
+
mask_hard = 1 - mask_active
|
301 |
+
mask_soft = create_outer_edge_mask_torch(mask_remain, edge_thickness = args.edge_thickness)
|
302 |
+
mask_hard = mask_substract_torch(mask_hard, mask_soft)
|
303 |
+
|
304 |
+
pipe.inference_with_mask(
|
305 |
+
save_path,
|
306 |
+
orig_image = image_gt,
|
307 |
+
guidance_scale = args.guidance_scale,
|
308 |
+
strength = args.strength,
|
309 |
+
num_imgs = args.num_imgs,
|
310 |
+
mask_hard= mask_hard,
|
311 |
+
mask_soft = mask_soft,
|
312 |
+
mask_list = mask_list_processed,
|
313 |
+
seed = args.seed,
|
314 |
+
num_sampling_steps = args.num_sampling_steps
|
315 |
+
)
|
if args.image:
    output_dir = os.path.join(output_dir, "image")
    save_path = os.path.join(output_dir, "out_image.png")
    os.makedirs(output_dir, exist_ok=True)
    mask_active = torch.zeros_like(mask_list[0])

    if None not in (args.tgt_name, args.src_index, args.tgt_index):
        if args.tgt_name == args.name:
            set_string_list_tgt = set_string_list
            set_string_list_src = set_string_list_2
            image_tgt = image_gt
            if args.load_edited_mask:
                mask_list_edited, _ = load_mask_edit(input_folder)
                mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
                mask_active = mask_union_torch(mask_active, mask_diff)
                mask_list = mask_list_edited
                save_path = os.path.join(output_dir, "out_imageEdited.png")
            mask_list_tgt = mask_list

        elif args.tgt_name == args.name_2:
            set_string_list_tgt = set_string_list_2
            set_string_list_src = set_string_list
            image_tgt = image_gt_2
            if args.load_edited_mask:
                mask_list_2_edited, _ = load_mask_edit(input_folder_2)
                mask_diff = get_mask_difference_torch(mask_list_2_edited, mask_list_2)
                mask_active = mask_union_torch(mask_active, mask_diff)
                mask_list_2 = mask_list_2_edited
                save_path = os.path.join(output_dir, "out_imageEdited.png")
            mask_list_tgt = mask_list_2
        else:
            exit("tgt_name should be either name or name_2")

        set_string_list_tgt[args.tgt_index] = set_string_list_src[args.src_index]

        mask_active = mask_list_tgt[args.tgt_index]
        mask_frozen = (1 - mask_active.float()).to(mask_active.device)
        mask_soft = create_outer_edge_mask_torch(mask_active.cpu(), edge_thickness=args.edge_thickness)
        mask_hard = mask_substract_torch(mask_frozen.cpu(), mask_soft.cpu())

        mask_list_tgt = [m.cuda() for m in mask_list_tgt]

        pipe.inference_with_mask(
            save_path,
            set_string_list=set_string_list_tgt,
            mask_list=mask_list_tgt,
            guidance_scale=args.guidance_scale,
            num_sampling_steps=args.num_sampling_steps,
            mask_hard=mask_hard.cuda(),
            mask_soft=mask_soft.cuda(),
            num_imgs=args.num_imgs,
            orig_image=image_tgt,
            strength=args.strength,
        )

if args.move_resize:
    output_dir = os.path.join(output_dir, "move_resize")
    os.makedirs(output_dir, exist_ok=True)
    save_path = os.path.join(output_dir, "out_moveresize.png")
    mask_active = torch.zeros_like(mask_list[0])

    if args.load_edited_mask:
        mask_list_edited, _ = load_mask_edit(input_folder)
        mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
        mask_active = mask_union_torch(mask_active, mask_diff)
        mask_list = mask_list_edited
        # save_path = os.path.join(output_dir, "out_moveresizeEdited.png")

    if args.load_edited_processed_mask:
        mask_list_processed, _ = load_mask_edit(output_dir)
        mask_remain = get_mask_difference_torch(mask_list_processed, mask_list)
    else:
        mask_list_processed, mask_remain = process_mask_move_torch(
            mask_list,
            args.tgt_indices_list,
            args.delta_x_list,
            args.delta_y_list,
            args.priority_list,
            force_mask_remain=args.force_mask_remain,
            resize_list=args.resize_list
        )
    save_mask_list_to_npys(output_dir, mask_list_processed, mask_label_list, name="mask")
    visualize_mask_list(mask_list_processed, os.path.join(output_dir, "seg_move_resize.png"))
    active_idxs = args.tgt_indices_list

    mask_active = mask_union_torch(mask_active, *[m for midx, m in enumerate(mask_list_processed) if midx in active_idxs])
    mask_active = mask_union_torch(mask_remain, mask_active)
    if args.active_mask_list is not None:
        for midx in args.active_mask_list:
            mask_active = mask_union_torch(mask_active, mask_list_processed[midx])

    mask_frozen = (1 - mask_active.float())
    mask_soft = create_outer_edge_mask_torch(mask_active, edge_thickness=args.edge_thickness)
    mask_hard = mask_substract_torch(mask_frozen, mask_soft)

    check_mask_overlap_torch(mask_hard, mask_soft)

    pipe.inference_with_mask(
        save_path,
        strength=args.strength,
        orig_image=image_gt,
        guidance_scale=args.guidance_scale,
        num_sampling_steps=args.num_sampling_steps,
        num_imgs=args.num_imgs,
        mask_hard=mask_hard,
        mask_soft=mask_soft,
        mask_list=mask_list_processed,
        seed=args.seed
    )
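The move_resize branch consumes five parallel lists, one entry per moved object. A minimal sketch of how they line up, with purely hypothetical values; the overlap-priority semantics are assumed here, so see process_mask_move_torch for the authoritative behavior:

tgt_indices_list = [2, 5]      # mask indices to move
delta_x_list = [40, -15]       # horizontal shift in pixels
delta_y_list = [0, 30]         # vertical shift in pixels
priority_list = [1, 0]         # resolves overlaps between moved masks (semantics assumed)
resize_list = [1.0, 0.8]       # per-object scale factor (1.0 = unchanged)
assert len(tgt_indices_list) == len(delta_x_list) == len(delta_y_list) \
    == len(priority_list) == len(resize_list)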
pipeline_dedit_sd.py
ADDED
@@ -0,0 +1,813 @@
import torch
from utils import import_model_class_from_model_name_or_path
from transformers import AutoTokenizer
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DDIMScheduler,
    UNet2DConditionModel,
)
from accelerate import Accelerator
from tqdm.auto import tqdm
from utils import sd_prepare_input_decom, save_images
import torch.nn.functional as F
import itertools
from peft import LoraConfig
from controller import GroupedCAController, register_attention_disentangled_control, DummyController
from utils import image2latent, latent2image
import matplotlib.pyplot as plt
from utils_mask import check_mask_overlap_torch

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

class DEditSDPipeline:
    def __init__(
        self,
        mask_list,
        mask_label_list,
        mask_list_2=None,
        mask_label_list_2=None,
        resolution=1024,
        num_tokens=1
    ):
        super().__init__()
        model_id = "./stable-diffusion-v1-5"
        self.model_id = model_id
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", use_fast=False)
        text_encoder_cls_one = import_model_class_from_model_name_or_path(model_id, subfolder="text_encoder")
        self.text_encoder = text_encoder_cls_one.from_pretrained(model_id, subfolder="text_encoder").to(device)

        self.unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
        self.unet.ca_dim = 768  # cross-attention dimension of SD v1.5
        self.vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
        self.scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
        # the DDPM scheduler above is immediately overridden: DDIM is used for sampling
        self.scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=True,
            rescale_betas_zero_snr=False,
        )
        self.mixed_precision = "fp16"
        self.resolution = resolution
        self.num_tokens = num_tokens

        self.mask_list = mask_list
        self.mask_label_list = mask_label_list
        notation_token_list = [phrase.split(" ")[-1] for phrase in mask_label_list]
        placeholder_token_list = ["#" + word + "{}".format(widx) for widx, word in enumerate(notation_token_list)]
        self.set_string_list, placeholder_token_ids = self.add_tokens(placeholder_token_list)
        self.min_added_id = min(placeholder_token_ids)
        self.max_added_id = max(placeholder_token_ids)

        if mask_list_2 is not None:
            self.mask_list_2 = mask_list_2
            self.mask_label_list_2 = mask_label_list_2
            notation_token_list_2 = [phrase.split(" ")[-1] for phrase in mask_label_list_2]
            placeholder_token_list_2 = ["$" + word + "{}".format(widx) for widx, word in enumerate(notation_token_list_2)]
            self.set_string_list_2, placeholder_token_ids_2 = self.add_tokens(placeholder_token_list_2)
            self.max_added_id = max(placeholder_token_ids_2)

    def add_tokens_text_encoder_random_init(self, placeholder_token, num_tokens=1):
        # add the placeholder token to the tokenizer
        placeholder_tokens = [placeholder_token]
        # add dummy tokens for multi-vector embeddings
        additional_tokens = []
        for i in range(1, num_tokens):
            additional_tokens.append(f"{placeholder_token}_{i}")
        placeholder_tokens += additional_tokens
        num_added_tokens = self.tokenizer.add_tokens(placeholder_tokens)  # new ids start at 49408 for the base CLIP vocab

        if num_added_tokens != num_tokens:
            raise ValueError(
                f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
                " `placeholder_token` that is not already in the tokenizer."
            )
        placeholder_token_ids = self.tokenizer.convert_tokens_to_ids(placeholder_tokens)

        self.text_encoder.resize_token_embeddings(len(self.tokenizer))
        token_embeds = self.text_encoder.get_input_embeddings().weight.data
        std, mean = torch.std_mean(token_embeds)
        with torch.no_grad():
            for token_id in placeholder_token_ids:
                token_embeds[token_id] = torch.randn_like(token_embeds[token_id]) * std + mean

        set_string = " ".join(self.tokenizer.convert_ids_to_tokens(placeholder_token_ids))
        return set_string, placeholder_token_ids

    def add_tokens(self, placeholder_token_list):
        set_string_list = []
        placeholder_token_ids_list = []
        for str_idx in range(len(placeholder_token_list)):
            placeholder_token = placeholder_token_list[str_idx]
            set_string, placeholder_token_ids = self.add_tokens_text_encoder_random_init(placeholder_token, num_tokens=self.num_tokens)
            set_string_list.append(set_string)
            placeholder_token_ids_list.append(placeholder_token_ids)
        placeholder_token_ids = list(itertools.chain(*placeholder_token_ids_list))
        return set_string_list, placeholder_token_ids

    def train_emb(
        self,
        image_gt,
        set_string_list,
        gradient_accumulation_steps=5,
        embedding_learning_rate=1e-4,
        max_emb_train_steps=100,
        train_batch_size=1,
    ):
        decom_controller = GroupedCAController(mask_list=self.mask_list)
        register_attention_disentangled_control(self.unet, decom_controller)

        accelerator = Accelerator(mixed_precision=self.mixed_precision, gradient_accumulation_steps=gradient_accumulation_steps)
        self.vae.requires_grad_(False)
        self.unet.requires_grad_(False)

        # only the input embeddings of the text encoder remain trainable
        self.text_encoder.requires_grad_(True)
        self.text_encoder.text_model.encoder.requires_grad_(False)
        self.text_encoder.text_model.final_layer_norm.requires_grad_(False)
        self.text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)

        weight_dtype = torch.float32
        if accelerator.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif accelerator.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16

        self.unet.to(device, dtype=weight_dtype)
        self.vae.to(device, dtype=weight_dtype)

        trainable_embmat_list_1 = [param for param in self.text_encoder.get_input_embeddings().parameters()]
        optimizer = torch.optim.AdamW(trainable_embmat_list_1, lr=embedding_learning_rate)

        self.text_encoder, optimizer = accelerator.prepare(self.text_encoder, optimizer)
        orig_embeds_params_1 = accelerator.unwrap_model(self.text_encoder).get_input_embeddings().weight.data.clone()

        self.text_encoder.train()
        effective_emb_train_steps = max_emb_train_steps // gradient_accumulation_steps

        if accelerator.is_main_process:
            accelerator.init_trackers("DEdit EmbSteps", config={
                "embedding_learning_rate": embedding_learning_rate,
                "text_embedding_optimization_steps": effective_emb_train_steps,
            })
        global_step = 0
        noise_scheduler = self.scheduler
        progress_bar = tqdm(range(0, effective_emb_train_steps), initial=global_step, desc="EmbSteps")
        latents0 = image2latent(image_gt, vae=self.vae, dtype=weight_dtype)
        latents0 = latents0.repeat(train_batch_size, 1, 1, 1)

        for _ in range(max_emb_train_steps):
            with accelerator.accumulate(self.text_encoder):
                latents = latents0.clone().detach()
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
                encoder_hidden_states_list = sd_prepare_input_decom(
                    set_string_list,
                    self.tokenizer,
                    self.text_encoder,
                    length=40,
                    bsz=train_batch_size,
                    weight_dtype=weight_dtype
                )

                model_pred = self.unet(
                    noisy_latents,
                    timesteps,
                    encoder_hidden_states=encoder_hidden_states_list,
                ).sample

                loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

                # restore every embedding row except the newly added placeholder ids
                index_no_updates = torch.ones((len(self.tokenizer),), dtype=torch.bool)
                index_no_updates[self.min_added_id : self.max_added_id + 1] = False
                with torch.no_grad():
                    accelerator.unwrap_model(self.text_encoder).get_input_embeddings().weight[
                        index_no_updates] = orig_embeds_params_1[index_no_updates]

                logs = {"loss": loss.detach().item(), "lr": embedding_learning_rate}
                progress_bar.set_postfix(**logs)
                accelerator.log(logs, step=global_step)
                if accelerator.sync_gradients:
                    progress_bar.update(1)
                    global_step += 1

                if global_step >= max_emb_train_steps:
                    break

        accelerator.wait_for_everyone()
        accelerator.end_training()
        self.text_encoder = accelerator.unwrap_model(self.text_encoder).to(dtype=weight_dtype)
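    # The loop above is textual-inversion-style embedding learning: the
    # optimizer nominally updates the whole input-embedding matrix, but after
    # each step every row outside [min_added_id, max_added_id] is copied back
    # from orig_embeds_params_1, so only the newly added placeholder tokens
    # actually move. A standalone sketch of the same trick, with hypothetical
    # toy sizes (pretend ids 7..9 were just added to a 10-token vocab):
    #
    #     emb = torch.nn.Embedding(10, 4)
    #     orig = emb.weight.data.clone()        # snapshot before training
    #     emb.weight.data += 0.1                # stand-in for an optimizer step
    #     keep = torch.ones(10, dtype=torch.bool)
    #     keep[7:10] = False                    # only the added rows may change
    #     with torch.no_grad():
    #         emb.weight[keep] = orig[keep]     # restore all frozen rows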
    def train_model(
        self,
        image_gt,
        set_string_list,
        gradient_accumulation_steps=5,
        max_diffusion_train_steps=100,
        diffusion_model_learning_rate=1e-5,
        train_batch_size=1,
        train_full_lora=False,
        lora_rank=4,
        lora_alpha=4
    ):
        self.unet = UNet2DConditionModel.from_pretrained(self.model_id, subfolder="unet").to(device)
        self.unet.ca_dim = 768
        decom_controller = GroupedCAController(mask_list=self.mask_list)
        register_attention_disentangled_control(self.unet, decom_controller)

        mixed_precision = "fp16"
        accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps, mixed_precision=mixed_precision)

        weight_dtype = torch.float32
        if accelerator.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif accelerator.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16

        self.vae.requires_grad_(False)
        self.vae.to(device, dtype=weight_dtype)

        self.unet.requires_grad_(False)
        self.unet.train()

        self.text_encoder.requires_grad_(False)

        if not train_full_lora:
            trainable_params_list = []
            for _, module in self.unet.named_modules():
                module_name = type(module).__name__
                if module_name == "Attention":
                    if module.to_k.in_features == self.unet.ca_dim:  # cross-attention layer
                        module.to_k.weight.requires_grad = True
                        trainable_params_list.append(module.to_k.weight)
                        if module.to_k.bias is not None:
                            module.to_k.bias.requires_grad = True
                            trainable_params_list.append(module.to_k.bias)
                        module.to_v.weight.requires_grad = True
                        trainable_params_list.append(module.to_v.weight)
                        if module.to_v.bias is not None:
                            module.to_v.bias.requires_grad = True
                            trainable_params_list.append(module.to_v.bias)
                        module.to_q.weight.requires_grad = True
                        trainable_params_list.append(module.to_q.weight)
                        if module.to_q.bias is not None:
                            module.to_q.bias.requires_grad = True
                            trainable_params_list.append(module.to_q.bias)
        else:
            unet_lora_config = LoraConfig(
                r=lora_rank,
                lora_alpha=lora_alpha,
                init_lora_weights="gaussian",
                target_modules=["to_k", "to_q", "to_v", "to_out.0"],
            )
            self.unet.add_adapter(unet_lora_config)
            print("training full parameters using lora!")
            trainable_params_list = list(filter(lambda p: p.requires_grad, self.unet.parameters()))

        self.text_encoder.to(device, dtype=weight_dtype)

        optimizer = torch.optim.AdamW(trainable_params_list, lr=diffusion_model_learning_rate)
        self.unet, optimizer = accelerator.prepare(self.unet, optimizer)
        psum2 = sum(p.numel() for p in trainable_params_list)  # number of trainable parameters

        effective_diffusion_train_steps = max_diffusion_train_steps // gradient_accumulation_steps
        if accelerator.is_main_process:
            accelerator.init_trackers("textual_inversion", config={
                "diffusion_model_learning_rate": diffusion_model_learning_rate,
                "diffusion_model_optimization_steps": effective_diffusion_train_steps,
            })

        global_step = 0
        progress_bar = tqdm(range(0, effective_diffusion_train_steps), initial=global_step, desc="ModelSteps")

        # note: the scheduler config here is pulled from the SDXL repo, while
        # the two-image variants below use self.model_id
        noise_scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler")

        latents0 = image2latent(image_gt, vae=self.vae, dtype=weight_dtype)
        latents0 = latents0.repeat(train_batch_size, 1, 1, 1)

        with torch.no_grad():
            encoder_hidden_states_list = sd_prepare_input_decom(
                set_string_list,
                self.tokenizer,
                self.text_encoder,
                length=40,
                bsz=train_batch_size,
                weight_dtype=weight_dtype
            )

        for _ in range(max_diffusion_train_steps):
            with accelerator.accumulate(self.unet):
                latents = latents0.clone().detach()
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
                model_pred = self.unet(
                    noisy_latents,
                    timesteps,
                    encoder_hidden_states=encoder_hidden_states_list,
                ).sample
                loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

                logs = {"loss": loss.detach().item(), "lr": diffusion_model_learning_rate}
                progress_bar.set_postfix(**logs)
                accelerator.log(logs, step=global_step)
                if accelerator.sync_gradients:
                    progress_bar.update(1)
                    global_step += 1
                if global_step >= max_diffusion_train_steps:
                    break
        accelerator.wait_for_everyone()
        accelerator.end_training()
        self.unet = accelerator.unwrap_model(self.unet).to(dtype=weight_dtype)

    def train_emb_2imgs(
        self,
        image_gt_1,
        image_gt_2,
        set_string_list_1,
        set_string_list_2,
        gradient_accumulation_steps=5,
        embedding_learning_rate=1e-4,
        max_emb_train_steps=100,
        train_batch_size=1,
    ):
        decom_controller_1 = GroupedCAController(mask_list=self.mask_list)
        decom_controller_2 = GroupedCAController(mask_list=self.mask_list_2)
        accelerator = Accelerator(mixed_precision=self.mixed_precision, gradient_accumulation_steps=gradient_accumulation_steps)
        self.vae.requires_grad_(False)
        self.unet.requires_grad_(False)

        self.text_encoder.requires_grad_(True)
        self.text_encoder.text_model.encoder.requires_grad_(False)
        self.text_encoder.text_model.final_layer_norm.requires_grad_(False)
        self.text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)

        weight_dtype = torch.float32
        if accelerator.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif accelerator.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16

        self.unet.to(device, dtype=weight_dtype)
        self.vae.to(device, dtype=weight_dtype)

        trainable_embmat_list_1 = [param for param in self.text_encoder.get_input_embeddings().parameters()]
        optimizer = torch.optim.AdamW(trainable_embmat_list_1, lr=embedding_learning_rate)
        self.text_encoder, optimizer = accelerator.prepare(self.text_encoder, optimizer)
        orig_embeds_params_1 = accelerator.unwrap_model(self.text_encoder).get_input_embeddings().weight.data.clone()

        self.text_encoder.train()

        effective_emb_train_steps = max_emb_train_steps // gradient_accumulation_steps

        if accelerator.is_main_process:
            accelerator.init_trackers("EmbFt", config={
                "embedding_learning_rate": embedding_learning_rate,
                "text_embedding_optimization_steps": effective_emb_train_steps,
            })

        global_step = 0
        noise_scheduler = DDPMScheduler.from_pretrained(self.model_id, subfolder="scheduler")
        progress_bar = tqdm(range(0, effective_emb_train_steps), initial=global_step, desc="EmbSteps")
        latents0_1 = image2latent(image_gt_1, vae=self.vae, dtype=weight_dtype)
        latents0_1 = latents0_1.repeat(train_batch_size, 1, 1, 1)

        latents0_2 = image2latent(image_gt_2, vae=self.vae, dtype=weight_dtype)
        latents0_2 = latents0_2.repeat(train_batch_size, 1, 1, 1)

        for step in range(max_emb_train_steps):
            with accelerator.accumulate(self.text_encoder):
                latents_1 = latents0_1.clone().detach()
                noise_1 = torch.randn_like(latents_1)

                latents_2 = latents0_2.clone().detach()
                noise_2 = torch.randn_like(latents_2)

                bsz = latents_1.shape[0]

                timesteps_1 = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents_1.device)
                timesteps_1 = timesteps_1.long()
                noisy_latents_1 = noise_scheduler.add_noise(latents_1, noise_1, timesteps_1)

                timesteps_2 = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents_2.device)
                timesteps_2 = timesteps_2.long()
                noisy_latents_2 = noise_scheduler.add_noise(latents_2, noise_2, timesteps_2)

                register_attention_disentangled_control(self.unet, decom_controller_1)
                encoder_hidden_states_list_1 = sd_prepare_input_decom(
                    set_string_list_1,
                    self.tokenizer,
                    self.text_encoder,
                    length=40,
                    bsz=train_batch_size,
                    weight_dtype=weight_dtype
                )

                model_pred_1 = self.unet(
                    noisy_latents_1,
                    timesteps_1,
                    encoder_hidden_states=encoder_hidden_states_list_1,
                ).sample

                register_attention_disentangled_control(self.unet, decom_controller_2)
                encoder_hidden_states_list_2 = sd_prepare_input_decom(
                    set_string_list_2,
                    self.tokenizer,
                    self.text_encoder,
                    length=40,
                    bsz=train_batch_size,
                    weight_dtype=weight_dtype
                )

                model_pred_2 = self.unet(
                    noisy_latents_2,
                    timesteps_2,
                    encoder_hidden_states=encoder_hidden_states_list_2,
                ).sample

                loss_1 = F.mse_loss(model_pred_1.float(), noise_1.float(), reduction="mean") / 2
                loss_2 = F.mse_loss(model_pred_2.float(), noise_2.float(), reduction="mean") / 2
                loss = loss_1 + loss_2
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

                index_no_updates = torch.ones((len(self.tokenizer),), dtype=torch.bool)
                index_no_updates[self.min_added_id : self.max_added_id + 1] = False
                with torch.no_grad():
                    accelerator.unwrap_model(self.text_encoder).get_input_embeddings().weight[
                        index_no_updates] = orig_embeds_params_1[index_no_updates]

                logs = {"loss": loss.detach().item(), "lr": embedding_learning_rate}
                progress_bar.set_postfix(**logs)
                accelerator.log(logs, step=global_step)
                if accelerator.sync_gradients:
                    progress_bar.update(1)
                    global_step += 1

                if global_step >= max_emb_train_steps:
                    break
        accelerator.wait_for_everyone()
        accelerator.end_training()
        self.text_encoder = accelerator.unwrap_model(self.text_encoder).to(dtype=weight_dtype)

    def train_model_2imgs(
        self,
        image_gt_1,
        image_gt_2,
        set_string_list_1,
        set_string_list_2,
        gradient_accumulation_steps=5,
        max_diffusion_train_steps=100,
        diffusion_model_learning_rate=1e-5,
        train_batch_size=1,
        train_full_lora=False,
        lora_rank=4,
        lora_alpha=4
    ):
        self.unet = UNet2DConditionModel.from_pretrained(self.model_id, subfolder="unet").to(device)
        self.unet.ca_dim = 768
        decom_controller_1 = GroupedCAController(mask_list=self.mask_list)
        decom_controller_2 = GroupedCAController(mask_list=self.mask_list_2)

        mixed_precision = "fp16"
        accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps, mixed_precision=mixed_precision)

        weight_dtype = torch.float32
        if accelerator.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif accelerator.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16

        self.vae.requires_grad_(False)
        self.vae.to(device, dtype=weight_dtype)
        self.unet.requires_grad_(False)
        self.unet.train()

        self.text_encoder.requires_grad_(False)

        if not train_full_lora:
            trainable_params_list = []
            for _, module in self.unet.named_modules():
                module_name = type(module).__name__
                if module_name == "Attention":
                    if module.to_k.in_features == self.unet.ca_dim:  # cross-attention layer
                        module.to_k.weight.requires_grad = True
                        trainable_params_list.append(module.to_k.weight)
                        if module.to_k.bias is not None:
                            module.to_k.bias.requires_grad = True
                            trainable_params_list.append(module.to_k.bias)
                        module.to_v.weight.requires_grad = True
                        trainable_params_list.append(module.to_v.weight)
                        if module.to_v.bias is not None:
                            module.to_v.bias.requires_grad = True
                            trainable_params_list.append(module.to_v.bias)
                        module.to_q.weight.requires_grad = True
                        trainable_params_list.append(module.to_q.weight)
                        if module.to_q.bias is not None:
                            module.to_q.bias.requires_grad = True
                            trainable_params_list.append(module.to_q.bias)
        else:
            unet_lora_config = LoraConfig(
                r=lora_rank,
                lora_alpha=lora_alpha,
                init_lora_weights="gaussian",
                target_modules=["to_k", "to_q", "to_v", "to_out.0"],
            )
            self.unet.add_adapter(unet_lora_config)
            print("training full parameters using lora!")
            trainable_params_list = list(filter(lambda p: p.requires_grad, self.unet.parameters()))

        self.text_encoder.to(device, dtype=weight_dtype)
        optimizer = torch.optim.AdamW(trainable_params_list, lr=diffusion_model_learning_rate)
        self.unet, optimizer = accelerator.prepare(self.unet, optimizer)
        psum2 = sum(p.numel() for p in trainable_params_list)

        effective_diffusion_train_steps = max_diffusion_train_steps // gradient_accumulation_steps
        if accelerator.is_main_process:
            accelerator.init_trackers("ModelFt", config={
                "diffusion_model_learning_rate": diffusion_model_learning_rate,
                "diffusion_model_optimization_steps": effective_diffusion_train_steps,
            })

        global_step = 0
        progress_bar = tqdm(range(0, effective_diffusion_train_steps), initial=global_step, desc="ModelSteps")
        noise_scheduler = DDPMScheduler.from_pretrained(self.model_id, subfolder="scheduler")

        latents0_1 = image2latent(image_gt_1, vae=self.vae, dtype=weight_dtype)
        latents0_1 = latents0_1.repeat(train_batch_size, 1, 1, 1)

        latents0_2 = image2latent(image_gt_2, vae=self.vae, dtype=weight_dtype)
        latents0_2 = latents0_2.repeat(train_batch_size, 1, 1, 1)

        with torch.no_grad():
            encoder_hidden_states_list_1 = sd_prepare_input_decom(
                set_string_list_1,
                self.tokenizer,
                self.text_encoder,
                length=40,
                bsz=train_batch_size,
                weight_dtype=weight_dtype
            )
            encoder_hidden_states_list_2 = sd_prepare_input_decom(
                set_string_list_2,
                self.tokenizer,
                self.text_encoder,
                length=40,
                bsz=train_batch_size,
                weight_dtype=weight_dtype
            )

        for _ in range(max_diffusion_train_steps):
            with accelerator.accumulate(self.unet):
                latents_1 = latents0_1.clone().detach()
                noise_1 = torch.randn_like(latents_1)
                bsz = latents_1.shape[0]
                timesteps_1 = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents_1.device)
                timesteps_1 = timesteps_1.long()
                noisy_latents_1 = noise_scheduler.add_noise(latents_1, noise_1, timesteps_1)

                latents_2 = latents0_2.clone().detach()
                noise_2 = torch.randn_like(latents_2)
                bsz = latents_2.shape[0]
                timesteps_2 = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents_2.device)
                timesteps_2 = timesteps_2.long()
                noisy_latents_2 = noise_scheduler.add_noise(latents_2, noise_2, timesteps_2)

                register_attention_disentangled_control(self.unet, decom_controller_1)
                model_pred_1 = self.unet(
                    noisy_latents_1,
                    timesteps_1,
                    encoder_hidden_states=encoder_hidden_states_list_1,
                ).sample

                register_attention_disentangled_control(self.unet, decom_controller_2)
                model_pred_2 = self.unet(
                    noisy_latents_2,
                    timesteps_2,
                    encoder_hidden_states=encoder_hidden_states_list_2,
                ).sample

                loss_1 = F.mse_loss(model_pred_1.float(), noise_1.float(), reduction="mean")
                loss_2 = F.mse_loss(model_pred_2.float(), noise_2.float(), reduction="mean")
                loss = loss_1 + loss_2
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

                logs = {"loss": loss.detach().item(), "lr": diffusion_model_learning_rate}
                progress_bar.set_postfix(**logs)
                accelerator.log(logs, step=global_step)
                if accelerator.sync_gradients:
                    progress_bar.update(1)
                    global_step += 1

                if global_step >= max_diffusion_train_steps:
                    break
        accelerator.wait_for_everyone()
        accelerator.end_training()
        self.unet = accelerator.unwrap_model(self.unet).to(dtype=weight_dtype)

    @torch.no_grad()
    def backward_zT_to_z0_euler_decom(
        self,
        zT,
        cond_emb_list,
        uncond_emb=None,
        guidance_scale=1,
        num_sampling_steps=20,
        cond_controller=None,
        uncond_controller=None,
        mask_hard=None,
        mask_soft=None,
        orig_image=None,
        return_intermediate=False,
        strength=1
    ):
        latent_cur = zT
        if uncond_emb is None:
            uncond_emb = torch.zeros(zT.shape[0], 77, self.unet.ca_dim).to(dtype=zT.dtype, device=zT.device)

        if mask_soft is not None:
            init_latents_orig = image2latent(orig_image, self.vae, dtype=self.vae.dtype)
            length = init_latents_orig.shape[-1]
            noise = torch.randn_like(init_latents_orig)
            mask_soft = torch.nn.functional.interpolate(mask_soft.float().unsqueeze(0).unsqueeze(0), (length, length)).to(self.vae.dtype)

        if mask_hard is not None:
            init_latents_orig = image2latent(orig_image, self.vae, dtype=self.vae.dtype)
            length = init_latents_orig.shape[-1]
            noise = torch.randn_like(init_latents_orig)
            mask_hard = torch.nn.functional.interpolate(mask_hard.float().unsqueeze(0).unsqueeze(0), (length, length)).to(self.vae.dtype)

        intermediate_list = [latent_cur.detach()]
        for i in tqdm(range(num_sampling_steps)):
            t = self.scheduler.timesteps[i]
            latent_input = self.scheduler.scale_model_input(latent_cur, t)

            register_attention_disentangled_control(self.unet, uncond_controller)
            noise_pred_uncond = self.unet(
                latent_input,
                t,
                encoder_hidden_states=uncond_emb,
            ).sample

            register_attention_disentangled_control(self.unet, cond_controller)
            noise_pred_cond = self.unet(
                latent_input,
                t,
                encoder_hidden_states=cond_emb_list,
            ).sample

            # classifier-free guidance
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
            latent_cur = self.scheduler.step(noise_pred, t, latent_cur, generator=None, return_dict=False)[0]

            if return_intermediate is True:
                intermediate_list.append(latent_cur)

            if mask_hard is not None and mask_soft is not None and i <= strength * num_sampling_steps:
                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
                mask = mask_soft.to(latent_cur.device, latent_cur.dtype) + mask_hard.to(latent_cur.device, latent_cur.dtype)
                latent_cur = (init_latents_proper * mask) + (latent_cur * (1 - mask))

            elif mask_hard is not None and mask_soft is not None and i > strength * num_sampling_steps:
                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
                mask = mask_hard.to(latent_cur.device, latent_cur.dtype)
                latent_cur = (init_latents_proper * mask) + (latent_cur * (1 - mask))

            elif mask_hard is None and mask_soft is not None and i <= strength * num_sampling_steps:
                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
                mask = mask_soft.to(latent_cur.device, latent_cur.dtype)
                latent_cur = (init_latents_proper * mask) + (latent_cur * (1 - mask))

            elif mask_hard is None and mask_soft is not None and i > strength * num_sampling_steps:
                pass

            elif mask_hard is not None and mask_soft is None:
                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
                mask = mask_hard.to(latent_cur.dtype)
                latent_cur = (init_latents_proper * mask) + (latent_cur * (1 - mask))

            else:  # both hard and soft masks are None
                pass

        if return_intermediate is True:
            return latent_cur, intermediate_list
        else:
            return latent_cur

    @torch.no_grad()
    def sampling(
        self,
        set_string_list,
        cond_controller=None,
        uncond_controller=None,
        guidance_scale=7,
        num_sampling_steps=20,
        mask_hard=None,
        mask_soft=None,
        orig_image=None,
        strength=1.,
        num_imgs=1,
        normal_token_id_list=[],
        seed=1
    ):
        weight_dtype = torch.float16
        self.scheduler.set_timesteps(num_sampling_steps)
        self.unet.to(device, dtype=weight_dtype)
        self.vae.to(device, dtype=weight_dtype)
        self.text_encoder.to(device, dtype=weight_dtype)

        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        zT = torch.randn(num_imgs, 4, self.resolution // vae_scale_factor, self.resolution // vae_scale_factor).to(device, dtype=weight_dtype)
        zT = zT * self.scheduler.init_noise_sigma

        cond_emb_list = sd_prepare_input_decom(
            set_string_list,
            self.tokenizer,
            self.text_encoder,
            length=40,
            bsz=num_imgs,
            weight_dtype=weight_dtype,
            normal_token_id_list=normal_token_id_list
        )

        z0 = self.backward_zT_to_z0_euler_decom(
            zT, cond_emb_list,
            guidance_scale=guidance_scale, num_sampling_steps=num_sampling_steps,
            cond_controller=cond_controller, uncond_controller=uncond_controller,
            mask_hard=mask_hard, mask_soft=mask_soft, orig_image=orig_image, strength=strength
        )
        x0 = latent2image(z0, vae=self.vae)
        return x0

    @torch.no_grad()
    def inference_with_mask(
        self,
        save_path,
        guidance_scale=3,
        num_sampling_steps=50,
        strength=1,
        mask_soft=None,
        mask_hard=None,
        orig_image=None,
        mask_list=None,
        num_imgs=1,
        seed=1,
        set_string_list=None
    ):
        if mask_list is not None:
            mask_list = [m.to(device) for m in mask_list]
        else:
            mask_list = self.mask_list
        if set_string_list is not None:
            self.set_string_list = set_string_list

        if mask_hard is not None and mask_soft is not None:
            check_mask_overlap_torch(mask_hard, mask_soft)
        null_controller = DummyController()
        decom_controller = GroupedCAController(mask_list=mask_list)

        x0 = self.sampling(
            self.set_string_list,
            guidance_scale=guidance_scale,
            num_sampling_steps=num_sampling_steps,
            strength=strength,
            cond_controller=decom_controller,
            uncond_controller=null_controller,
            mask_soft=mask_soft,
            mask_hard=mask_hard,
            orig_image=orig_image,
            num_imgs=num_imgs,
            seed=seed
        )
        save_images(x0, save_path)
        return x0
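For orientation, a minimal usage sketch of the pipeline defined above, assuming a local ./stable-diffusion-v1-5 checkout as required by __init__; the toy masks, labels, and blank image below are hypothetical stand-ins for the outputs of the segmentation step:

import torch
from PIL import Image
from pipeline_dedit_sd import DEditSDPipeline

# toy inputs: two half-image masks covering the whole 512x512 canvas
mask_top = torch.zeros(512, 512, dtype=torch.uint8)
mask_top[:256] = 1
mask_bottom = 1 - mask_top
image_gt = Image.new("RGB", (512, 512))

pipe = DEditSDPipeline([mask_top, mask_bottom], ["the sky", "the ground"], resolution=512)

# stage 1: learn placeholder embeddings; stage 2: fine-tune cross-attention
pipe.train_emb(image_gt, pipe.set_string_list, max_emb_train_steps=100)
pipe.train_model(image_gt, pipe.set_string_list, max_diffusion_train_steps=100)

# regenerate the image under the learned decomposition
pipe.inference_with_mask("out_recon.png", guidance_scale=3, num_imgs=1)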
pipeline_dedit_sdxl.py
ADDED
@@ -0,0 +1,875 @@
1 |
+
import torch
|
2 |
+
from utils import import_model_class_from_model_name_or_path
|
3 |
+
from transformers import AutoTokenizer
|
4 |
+
from diffusers import (
|
5 |
+
AutoencoderKL,
|
6 |
+
DDPMScheduler,
|
7 |
+
StableDiffusionXLPipeline,
|
8 |
+
UNet2DConditionModel,
|
9 |
+
)
|
10 |
+
from accelerate import Accelerator
|
11 |
+
from tqdm.auto import tqdm
|
12 |
+
from utils import sdxl_prepare_input_decom, save_images
|
13 |
+
import torch.nn.functional as F
|
14 |
+
import itertools
|
15 |
+
from peft import LoraConfig
|
16 |
+
from controller import GroupedCAController, register_attention_disentangled_control, DummyController
|
17 |
+
from utils import image2latent, latent2image
|
18 |
+
import matplotlib.pyplot as plt
|
19 |
+
from utils_mask import check_mask_overlap_torch
|
20 |
+
|
21 |
+
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
|
22 |
+
max_length = 40
|
23 |
+
class DEditSDXLPipeline:
|
24 |
+
def __init__(
|
25 |
+
self,
|
26 |
+
mask_list,
|
27 |
+
mask_label_list,
|
28 |
+
mask_list_2 = None,
|
29 |
+
mask_label_list_2 = None,
|
30 |
+
resolution = 1024,
|
31 |
+
num_tokens = 1
|
32 |
+
):
|
33 |
+
super().__init__()
|
34 |
+
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
35 |
+
self.model_id = model_id
|
36 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", use_fast=False)
|
37 |
+
self.tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2", use_fast=False)
|
38 |
+
text_encoder_cls_one = import_model_class_from_model_name_or_path(model_id, subfolder = "text_encoder")
|
39 |
+
text_encoder_cls_two = import_model_class_from_model_name_or_path(model_id, subfolder="text_encoder_2")
|
40 |
+
self.text_encoder = text_encoder_cls_one.from_pretrained(model_id, subfolder="text_encoder" ).to(device)
|
41 |
+
self.text_encoder_2 = text_encoder_cls_two.from_pretrained(model_id, subfolder="text_encoder_2").to(device)
|
42 |
+
self.unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet" )
|
43 |
+
self.unet.ca_dim = 2048
|
44 |
+
self.vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix")
|
45 |
+
self.scheduler = DDPMScheduler.from_pretrained(model_id , subfolder="scheduler")
|
46 |
+
|
47 |
+
self.mixed_precision = "fp16"
|
48 |
+
self.resolution = resolution
|
49 |
+
self.num_tokens = num_tokens
|
50 |
+
|
51 |
+
self.mask_list = mask_list
|
52 |
+
self.mask_label_list = mask_label_list
|
53 |
+
notation_token_list = [phrase.split(" ")[-1] for phrase in mask_label_list]
|
54 |
+
placeholder_token_list = ["#"+word+"{}".format(widx) for widx, word in enumerate(notation_token_list)]
|
55 |
+
self.set_string_list, placeholder_token_ids = self.add_tokens(placeholder_token_list)
|
56 |
+
self.min_added_id = min(placeholder_token_ids)
|
57 |
+
self.max_added_id = max(placeholder_token_ids)
|
58 |
+
|
59 |
+
if mask_list_2 is not None:
|
60 |
+
self.mask_list_2 = mask_list_2
|
61 |
+
self.mask_label_list_2 = mask_label_list_2
|
62 |
+
notation_token_list_2 = [phrase.split(" ")[-1] for phrase in mask_label_list_2]
|
63 |
+
|
64 |
+
placeholder_token_list_2 = ["$"+word+"{}".format(widx) for widx, word in enumerate(notation_token_list_2)]
|
65 |
+
self.set_string_list_2, placeholder_token_ids_2 = self.add_tokens(placeholder_token_list_2)
|
66 |
+
self.max_added_id = max(placeholder_token_ids_2)
|
67 |
+
|
68 |
+
def add_tokens_text_encoder_random_init(self, placeholder_token, num_tokens=1):
|
69 |
+
# Add the placeholder token in tokenizer
|
70 |
+
placeholder_tokens = [placeholder_token]
|
71 |
+
# add dummy tokens for multi-vector
|
72 |
+
additional_tokens = []
|
73 |
+
for i in range(1, num_tokens):
|
74 |
+
additional_tokens.append(f"{placeholder_token}_{i}")
|
75 |
+
placeholder_tokens += additional_tokens
|
76 |
+
num_added_tokens = self.tokenizer.add_tokens(placeholder_tokens) # 49408
|
77 |
+
num_added_tokens = self.tokenizer_2.add_tokens(placeholder_tokens) # 49408
|
78 |
+
|
79 |
+
if num_added_tokens != num_tokens:
|
80 |
+
raise ValueError(
|
81 |
+
f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
|
82 |
+
" `placeholder_token` that is not already in the tokenizer."
|
83 |
+
)
|
84 |
+
placeholder_token_ids = self.tokenizer.convert_tokens_to_ids(placeholder_tokens)
|
85 |
+
placeholder_token_ids_2 = self.tokenizer_2.convert_tokens_to_ids(placeholder_tokens)
|
86 |
+
assert placeholder_token_ids == placeholder_token_ids_2, "Two text encoders are expected to have same vocabs"
|
87 |
+
|
88 |
+
self.text_encoder.resize_token_embeddings(len(self.tokenizer))
|
89 |
+
token_embeds = self.text_encoder.get_input_embeddings().weight.data
|
90 |
+
std, mean = torch.std_mean(token_embeds)
|
91 |
+
with torch.no_grad():
|
92 |
+
for token_id in placeholder_token_ids:
|
93 |
+
token_embeds[token_id] = torch.randn_like(token_embeds[token_id])*std + mean
|
94 |
+
|
95 |
+
self.text_encoder_2.resize_token_embeddings(len(self.tokenizer))
|
96 |
+
token_embeds = self.text_encoder_2.get_input_embeddings().weight.data
|
97 |
+
std, mean = torch.std_mean(token_embeds)
|
98 |
+
with torch.no_grad():
|
99 |
+
for token_id in placeholder_token_ids:
|
100 |
+
token_embeds[token_id] = torch.randn_like(token_embeds[token_id])*std + mean
|
101 |
+
|
102 |
+
set_string = " ".join(self.tokenizer.convert_ids_to_tokens(placeholder_token_ids))
|
103 |
+
|
104 |
+
return set_string, placeholder_token_ids
|
105 |
+
|
106 |
+
def add_tokens(self, placeholder_token_list):
|
107 |
+
set_string_list = []
|
108 |
+
placeholder_token_ids_list = []
|
109 |
+
for str_idx in range(len(placeholder_token_list)):
|
110 |
+
placeholder_token = placeholder_token_list[str_idx]
|
111 |
+
set_string, placeholder_token_ids = self.add_tokens_text_encoder_random_init(placeholder_token, num_tokens=self.num_tokens)
|
112 |
+
set_string_list.append(set_string)
|
113 |
+
placeholder_token_ids_list.append(placeholder_token_ids)
|
114 |
+
placeholder_token_ids = list(itertools.chain(*placeholder_token_ids_list))
|
115 |
+
return set_string_list, placeholder_token_ids
|
116 |
+
|
117 |
+
def train_emb(
|
118 |
+
self,
|
119 |
+
image_gt,
|
120 |
+
set_string_list,
|
121 |
+
gradient_accumulation_steps = 5,
|
122 |
+
embedding_learning_rate = 1e-4,
|
123 |
+
max_emb_train_steps = 100,
|
124 |
+
train_batch_size = 1,
|
125 |
+
train_full_lora = False
|
126 |
+
):
|
127 |
+
decom_controller = GroupedCAController(mask_list = self.mask_list)
|
128 |
+
register_attention_disentangled_control(self.unet, decom_controller)
|
129 |
+
|
130 |
+
accelerator = Accelerator(mixed_precision=self.mixed_precision, gradient_accumulation_steps=gradient_accumulation_steps)
|
131 |
+
self.vae.requires_grad_(False)
|
132 |
+
self.unet.requires_grad_(False)
|
133 |
+
|
134 |
+
self.text_encoder.requires_grad_(True)
|
135 |
+
self.text_encoder_2.requires_grad_(True)
|
136 |
+
|
137 |
+
self.text_encoder.text_model.encoder.requires_grad_(False)
|
138 |
+
        self.text_encoder.text_model.final_layer_norm.requires_grad_(False)
        self.text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)

        self.text_encoder_2.text_model.encoder.requires_grad_(False)
        self.text_encoder_2.text_model.final_layer_norm.requires_grad_(False)
        self.text_encoder_2.text_model.embeddings.position_embedding.requires_grad_(False)

        weight_dtype = torch.float32
        if accelerator.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif accelerator.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16

        self.unet.to(device, dtype=weight_dtype)
        self.vae.to(device, dtype=weight_dtype)

        trainable_embmat_list_1 = [param for param in self.text_encoder.get_input_embeddings().parameters()]
        trainable_embmat_list_2 = [param for param in self.text_encoder_2.get_input_embeddings().parameters()]

        optimizer = torch.optim.AdamW(trainable_embmat_list_1 + trainable_embmat_list_2, lr=embedding_learning_rate)

        self.text_encoder, self.text_encoder_2, optimizer = accelerator.prepare(
            self.text_encoder, self.text_encoder_2, optimizer)

        orig_embeds_params_1 = accelerator.unwrap_model(self.text_encoder).get_input_embeddings().weight.data.clone()
        orig_embeds_params_2 = accelerator.unwrap_model(self.text_encoder_2).get_input_embeddings().weight.data.clone()

        self.text_encoder.train()
        self.text_encoder_2.train()

        effective_emb_train_steps = max_emb_train_steps // gradient_accumulation_steps

        if accelerator.is_main_process:
            accelerator.init_trackers("DEdit EmbSteps", config={
                "embedding_learning_rate": embedding_learning_rate,
                "text_embedding_optimization_steps": effective_emb_train_steps,
            })
        global_step = 0
        noise_scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler")
        progress_bar = tqdm(range(0, effective_emb_train_steps), initial=global_step, desc="EmbSteps")
        latents0 = image2latent(image_gt, vae=self.vae, dtype=weight_dtype)
        latents0 = latents0.repeat(train_batch_size, 1, 1, 1)

        for _ in range(max_emb_train_steps):
            with accelerator.accumulate(self.text_encoder, self.text_encoder_2):
                latents = latents0.clone().detach()
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
                encoder_hidden_states_list, add_text_embeds, add_time_ids = sdxl_prepare_input_decom(
                    set_string_list,
                    self.tokenizer,
                    self.tokenizer_2,
                    self.text_encoder,
                    self.text_encoder_2,
                    length=max_length,
                    bsz=train_batch_size,
                    weight_dtype=weight_dtype
                )

                model_pred = self.unet(
                    noisy_latents,
                    timesteps,
                    encoder_hidden_states=encoder_hidden_states_list,
                    cross_attention_kwargs=None,
                    added_cond_kwargs={"text_embeds": add_text_embeds, "time_ids": add_time_ids},
                    return_dict=False
                )[0]
                loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

                index_no_updates = torch.ones((len(self.tokenizer),), dtype=torch.bool)
                index_no_updates[self.min_added_id : self.max_added_id + 1] = False
                with torch.no_grad():
                    accelerator.unwrap_model(self.text_encoder).get_input_embeddings().weight[
                        index_no_updates] = orig_embeds_params_1[index_no_updates]

                index_no_updates = torch.ones((len(self.tokenizer_2),), dtype=torch.bool)
                index_no_updates[self.min_added_id : self.max_added_id + 1] = False
                with torch.no_grad():
                    accelerator.unwrap_model(self.text_encoder_2).get_input_embeddings().weight[
                        index_no_updates] = orig_embeds_params_2[index_no_updates]

                logs = {"loss": loss.detach().item(), "lr": embedding_learning_rate}
                progress_bar.set_postfix(**logs)
                accelerator.log(logs, step=global_step)
                if accelerator.sync_gradients:
                    progress_bar.update(1)
                    global_step += 1

            if global_step >= max_emb_train_steps:
                break
        accelerator.wait_for_everyone()
        accelerator.end_training()
        self.text_encoder = accelerator.unwrap_model(self.text_encoder).to(dtype=weight_dtype)
        self.text_encoder_2 = accelerator.unwrap_model(self.text_encoder_2).to(dtype=weight_dtype)

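A minimal standalone sketch (not part of this file) of the embedding-restoration trick used in train_emb above: only the rows of newly added placeholder tokens may change, and every other row is copied back after each optimizer step. The placeholder strings and the CLIP checkpoint below are illustrative assumptions.

import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# Hypothetical placeholder tokens standing in for the per-segment tokens added before training.
num_added = tokenizer.add_tokens([f"<seg{i}>" for i in range(5)])
text_encoder.resize_token_embeddings(len(tokenizer))
min_added_id = len(tokenizer) - num_added

orig_embeds = text_encoder.get_input_embeddings().weight.data.clone()

# ... an optimizer step on the embedding matrix may have touched every row ...

index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
index_no_updates[min_added_id:] = False  # keep only the new rows trainable
with torch.no_grad():
    text_encoder.get_input_embeddings().weight[index_no_updates] = orig_embeds[index_no_updates]
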
    def train_model(
        self,
        image_gt,
        set_string_list,
        gradient_accumulation_steps=5,
        max_diffusion_train_steps=100,
        diffusion_model_learning_rate=1e-5,
        train_batch_size=1,
        train_full_lora=False,
        lora_rank=4,
        lora_alpha=4
    ):
        self.unet = UNet2DConditionModel.from_pretrained(self.model_id, subfolder="unet").to(device)
        self.unet.ca_dim = 2048
        decom_controller = GroupedCAController(mask_list=self.mask_list)
        register_attention_disentangled_control(self.unet, decom_controller)

        mixed_precision = "fp16"
        accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps, mixed_precision=mixed_precision)

        weight_dtype = torch.float32
        if accelerator.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif accelerator.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16

        self.vae.requires_grad_(False)
        self.vae.to(device, dtype=weight_dtype)

        self.unet.requires_grad_(False)
        self.unet.train()

        self.text_encoder.requires_grad_(False)
        self.text_encoder_2.requires_grad_(False)

        if not train_full_lora:
            trainable_params_list = []
            for _, module in self.unet.named_modules():
                module_name = type(module).__name__
                if module_name == "Attention":
                    if module.to_k.in_features == 2048:  # this is cross attention
                        module.to_k.weight.requires_grad = True
                        trainable_params_list.append(module.to_k.weight)
                        if module.to_k.bias is not None:
                            module.to_k.bias.requires_grad = True
                            trainable_params_list.append(module.to_k.bias)
                        module.to_v.weight.requires_grad = True
                        trainable_params_list.append(module.to_v.weight)
                        if module.to_v.bias is not None:
                            module.to_v.bias.requires_grad = True
                            trainable_params_list.append(module.to_v.bias)
                        module.to_q.weight.requires_grad = True
                        trainable_params_list.append(module.to_q.weight)
                        if module.to_q.bias is not None:
                            module.to_q.bias.requires_grad = True
                            trainable_params_list.append(module.to_q.bias)
        else:
            unet_lora_config = LoraConfig(
                r=lora_rank,
                lora_alpha=lora_alpha,
                init_lora_weights="gaussian",
                target_modules=["to_k", "to_q", "to_v", "to_out.0"],
            )
            self.unet.add_adapter(unet_lora_config)
            print("training full parameters using lora!")
            trainable_params_list = list(filter(lambda p: p.requires_grad, self.unet.parameters()))

        self.text_encoder.to(device, dtype=weight_dtype)
        self.text_encoder_2.to(device, dtype=weight_dtype)
        optimizer = torch.optim.AdamW(trainable_params_list, lr=diffusion_model_learning_rate)
        self.unet, optimizer = accelerator.prepare(self.unet, optimizer)
        psum2 = sum(p.numel() for p in trainable_params_list)

        effective_diffusion_train_steps = max_diffusion_train_steps // gradient_accumulation_steps
        if accelerator.is_main_process:
            accelerator.init_trackers("textual_inversion", config={
                "diffusion_model_learning_rate": diffusion_model_learning_rate,
                "diffusion_model_optimization_steps": effective_diffusion_train_steps,
            })

        global_step = 0
        progress_bar = tqdm(range(0, effective_diffusion_train_steps), initial=global_step, desc="ModelSteps")

        noise_scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler")

        latents0 = image2latent(image_gt, vae=self.vae, dtype=weight_dtype)
        latents0 = latents0.repeat(train_batch_size, 1, 1, 1)

        with torch.no_grad():
            encoder_hidden_states_list, add_text_embeds, add_time_ids = sdxl_prepare_input_decom(
                set_string_list,
                self.tokenizer,
                self.tokenizer_2,
                self.text_encoder,
                self.text_encoder_2,
                length=max_length,
                bsz=train_batch_size,
                weight_dtype=weight_dtype
            )

        for _ in range(max_diffusion_train_steps):
            with accelerator.accumulate(self.unet):
                latents = latents0.clone().detach()
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
                model_pred = self.unet(
                    noisy_latents,
                    timesteps,
                    encoder_hidden_states=encoder_hidden_states_list,
                    cross_attention_kwargs=None, return_dict=False,
                    added_cond_kwargs={"text_embeds": add_text_embeds, "time_ids": add_time_ids}
                )[0]
                loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

                logs = {"loss": loss.detach().item(), "lr": diffusion_model_learning_rate}
                progress_bar.set_postfix(**logs)
                accelerator.log(logs, step=global_step)
                if accelerator.sync_gradients:
                    progress_bar.update(1)
                    global_step += 1
            if global_step >= max_diffusion_train_steps:
                break
        accelerator.wait_for_everyone()
        accelerator.end_training()
        self.unet = accelerator.unwrap_model(self.unet).to(dtype=weight_dtype)

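A short sketch (not part of this file) of how the to_k.in_features == 2048 test in train_model singles out cross-attention blocks: in the SDXL UNet the text conditioning fed to cross-attention is 2048-dimensional (the two text encoders' hidden states concatenated), while self-attention projects from the smaller spatial feature width. Downloading the SDXL base checkpoint is assumed.

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)
cross_attn = [
    name for name, module in unet.named_modules()
    if type(module).__name__ == "Attention" and module.to_k.in_features == 2048
]
print(f"{len(cross_attn)} cross-attention blocks, e.g. {cross_attn[0]}")
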
    def train_emb_2imgs(
        self,
        image_gt_1,
        image_gt_2,
        set_string_list_1,
        set_string_list_2,
        gradient_accumulation_steps=5,
        embedding_learning_rate=1e-4,
        max_emb_train_steps=100,
        train_batch_size=1,
        train_full_lora=False
    ):
        decom_controller_1 = GroupedCAController(mask_list=self.mask_list)
        decom_controller_2 = GroupedCAController(mask_list=self.mask_list_2)
        accelerator = Accelerator(mixed_precision=self.mixed_precision, gradient_accumulation_steps=gradient_accumulation_steps)
        self.vae.requires_grad_(False)
        self.unet.requires_grad_(False)

        self.text_encoder.requires_grad_(True)
        self.text_encoder_2.requires_grad_(True)

        self.text_encoder.text_model.encoder.requires_grad_(False)
        self.text_encoder.text_model.final_layer_norm.requires_grad_(False)
        self.text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)

        self.text_encoder_2.text_model.encoder.requires_grad_(False)
        self.text_encoder_2.text_model.final_layer_norm.requires_grad_(False)
        self.text_encoder_2.text_model.embeddings.position_embedding.requires_grad_(False)

        weight_dtype = torch.float32
        if accelerator.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif accelerator.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16

        self.unet.to(device, dtype=weight_dtype)
        self.vae.to(device, dtype=weight_dtype)

        trainable_embmat_list_1 = [param for param in self.text_encoder.get_input_embeddings().parameters()]
        trainable_embmat_list_2 = [param for param in self.text_encoder_2.get_input_embeddings().parameters()]

        optimizer = torch.optim.AdamW(trainable_embmat_list_1 + trainable_embmat_list_2, lr=embedding_learning_rate)
        self.text_encoder, self.text_encoder_2, optimizer = accelerator.prepare(
            self.text_encoder, self.text_encoder_2, optimizer)
        orig_embeds_params_1 = accelerator.unwrap_model(self.text_encoder).get_input_embeddings().weight.data.clone()
        orig_embeds_params_2 = accelerator.unwrap_model(self.text_encoder_2).get_input_embeddings().weight.data.clone()

        self.text_encoder.train()
        self.text_encoder_2.train()

        effective_emb_train_steps = max_emb_train_steps // gradient_accumulation_steps

        if accelerator.is_main_process:
            accelerator.init_trackers("EmbFt", config={
                "embedding_learning_rate": embedding_learning_rate,
                "text_embedding_optimization_steps": effective_emb_train_steps,
            })

        global_step = 0

        noise_scheduler = DDPMScheduler.from_pretrained(self.model_id, subfolder="scheduler")
        progress_bar = tqdm(range(0, effective_emb_train_steps), initial=global_step, desc="EmbSteps")
        latents0_1 = image2latent(image_gt_1, vae=self.vae, dtype=weight_dtype)
        latents0_1 = latents0_1.repeat(train_batch_size, 1, 1, 1)

        latents0_2 = image2latent(image_gt_2, vae=self.vae, dtype=weight_dtype)
        latents0_2 = latents0_2.repeat(train_batch_size, 1, 1, 1)

        for step in range(max_emb_train_steps):
            with accelerator.accumulate(self.text_encoder, self.text_encoder_2):
                latents_1 = latents0_1.clone().detach()
                noise_1 = torch.randn_like(latents_1)

                latents_2 = latents0_2.clone().detach()
                noise_2 = torch.randn_like(latents_2)

                bsz = latents_1.shape[0]

                timesteps_1 = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents_1.device)
                timesteps_1 = timesteps_1.long()
                noisy_latents_1 = noise_scheduler.add_noise(latents_1, noise_1, timesteps_1)

                timesteps_2 = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents_2.device)
                timesteps_2 = timesteps_2.long()
                noisy_latents_2 = noise_scheduler.add_noise(latents_2, noise_2, timesteps_2)

                register_attention_disentangled_control(self.unet, decom_controller_1)
                encoder_hidden_states_list_1, add_text_embeds_1, add_time_ids_1 = sdxl_prepare_input_decom(
                    set_string_list_1,
                    self.tokenizer,
                    self.tokenizer_2,
                    self.text_encoder,
                    self.text_encoder_2,
                    length=max_length,
                    bsz=train_batch_size,
                    weight_dtype=weight_dtype
                )

                model_pred_1 = self.unet(
                    noisy_latents_1,
                    timesteps_1,
                    encoder_hidden_states=encoder_hidden_states_list_1,
                    cross_attention_kwargs=None,
                    added_cond_kwargs={"text_embeds": add_text_embeds_1, "time_ids": add_time_ids_1},
                    return_dict=False
                )[0]

                register_attention_disentangled_control(self.unet, decom_controller_2)
                encoder_hidden_states_list_2, add_text_embeds_2, add_time_ids_2 = sdxl_prepare_input_decom(
                    set_string_list_2,
                    self.tokenizer,
                    self.tokenizer_2,
                    self.text_encoder,
                    self.text_encoder_2,
                    length=max_length,
                    bsz=train_batch_size,
                    weight_dtype=weight_dtype
                )

                model_pred_2 = self.unet(
                    noisy_latents_2,
                    timesteps_2,
                    encoder_hidden_states=encoder_hidden_states_list_2,
                    cross_attention_kwargs=None,
                    added_cond_kwargs={"text_embeds": add_text_embeds_2, "time_ids": add_time_ids_2},
                    return_dict=False
                )[0]

                loss_1 = F.mse_loss(model_pred_1.float(), noise_1.float(), reduction="mean") / 2
                loss_2 = F.mse_loss(model_pred_2.float(), noise_2.float(), reduction="mean") / 2
                loss = loss_1 + loss_2
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

                index_no_updates = torch.ones((len(self.tokenizer),), dtype=torch.bool)
                index_no_updates[self.min_added_id : self.max_added_id + 1] = False
                with torch.no_grad():
                    accelerator.unwrap_model(self.text_encoder).get_input_embeddings().weight[
                        index_no_updates] = orig_embeds_params_1[index_no_updates]
                index_no_updates = torch.ones((len(self.tokenizer_2),), dtype=torch.bool)
                index_no_updates[self.min_added_id : self.max_added_id + 1] = False
                with torch.no_grad():
                    accelerator.unwrap_model(self.text_encoder_2).get_input_embeddings().weight[
                        index_no_updates] = orig_embeds_params_2[index_no_updates]

                logs = {"loss": loss.detach().item(), "lr": embedding_learning_rate}
                progress_bar.set_postfix(**logs)
                accelerator.log(logs, step=global_step)
                if accelerator.sync_gradients:
                    progress_bar.update(1)
                    global_step += 1

            if global_step >= max_emb_train_steps:
                break
        accelerator.wait_for_everyone()
        accelerator.end_training()
        self.text_encoder = accelerator.unwrap_model(self.text_encoder).to(dtype=weight_dtype)
        self.text_encoder_2 = accelerator.unwrap_model(self.text_encoder_2).to(dtype=weight_dtype)

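The joint objective above in miniature (not part of this file): each image contributes half of its denoising MSE, so the combined loss stays on the same scale as the single-image variant. The random tensors are placeholders for model predictions and target noise.

import torch
import torch.nn.functional as F

pred_1, noise_1 = torch.randn(1, 4, 128, 128), torch.randn(1, 4, 128, 128)
pred_2, noise_2 = torch.randn(1, 4, 128, 128), torch.randn(1, 4, 128, 128)

loss = F.mse_loss(pred_1, noise_1) / 2 + F.mse_loss(pred_2, noise_2) / 2
print(loss.item())
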
    def train_model_2imgs(
        self,
        image_gt_1,
        image_gt_2,
        set_string_list_1,
        set_string_list_2,
        gradient_accumulation_steps=5,
        max_diffusion_train_steps=100,
        diffusion_model_learning_rate=1e-5,
        train_batch_size=1,
        train_full_lora=False,
        lora_rank=4,
        lora_alpha=4
    ):
        self.unet = UNet2DConditionModel.from_pretrained(self.model_id, subfolder="unet").to(device)
        self.unet.ca_dim = 2048
        decom_controller_1 = GroupedCAController(mask_list=self.mask_list)
        decom_controller_2 = GroupedCAController(mask_list=self.mask_list_2)

        mixed_precision = "fp16"
        accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps, mixed_precision=mixed_precision)

        weight_dtype = torch.float32
        if accelerator.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif accelerator.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16

        self.vae.requires_grad_(False)
        self.vae.to(device, dtype=weight_dtype)
        self.unet.requires_grad_(False)
        self.unet.train()

        self.text_encoder.requires_grad_(False)
        self.text_encoder_2.requires_grad_(False)
        if not train_full_lora:
            trainable_params_list = []
            for name, module in self.unet.named_modules():
                module_name = type(module).__name__
                if module_name == "Attention":
                    if module.to_k.in_features == 2048:  # this is cross attention
                        module.to_k.weight.requires_grad = True
                        trainable_params_list.append(module.to_k.weight)
                        if module.to_k.bias is not None:
                            module.to_k.bias.requires_grad = True
                            trainable_params_list.append(module.to_k.bias)

                        module.to_v.weight.requires_grad = True
                        trainable_params_list.append(module.to_v.weight)
                        if module.to_v.bias is not None:
                            module.to_v.bias.requires_grad = True
                            trainable_params_list.append(module.to_v.bias)
                        module.to_q.weight.requires_grad = True
                        trainable_params_list.append(module.to_q.weight)
                        if module.to_q.bias is not None:
                            module.to_q.bias.requires_grad = True
                            trainable_params_list.append(module.to_q.bias)
        else:
            unet_lora_config = LoraConfig(
                r=lora_rank,
                lora_alpha=lora_alpha,
                init_lora_weights="gaussian",
                target_modules=["to_k", "to_q", "to_v", "to_out.0"],
            )
            self.unet.add_adapter(unet_lora_config)
            print("training full parameters using lora!")
            trainable_params_list = list(filter(lambda p: p.requires_grad, self.unet.parameters()))

        self.text_encoder.to(device, dtype=weight_dtype)
        self.text_encoder_2.to(device, dtype=weight_dtype)
        optimizer = torch.optim.AdamW(trainable_params_list, lr=diffusion_model_learning_rate)
        self.unet, optimizer = accelerator.prepare(self.unet, optimizer)
        psum2 = sum(p.numel() for p in trainable_params_list)

        effective_diffusion_train_steps = max_diffusion_train_steps // gradient_accumulation_steps
        if accelerator.is_main_process:
            accelerator.init_trackers("ModelFt", config={
                "diffusion_model_learning_rate": diffusion_model_learning_rate,
                "diffusion_model_optimization_steps": effective_diffusion_train_steps,
            })

        global_step = 0
        progress_bar = tqdm(range(0, effective_diffusion_train_steps), initial=global_step, desc="ModelSteps")
        noise_scheduler = DDPMScheduler.from_pretrained(self.model_id, subfolder="scheduler")

        latents0_1 = image2latent(image_gt_1, vae=self.vae, dtype=weight_dtype)
        latents0_1 = latents0_1.repeat(train_batch_size, 1, 1, 1)

        latents0_2 = image2latent(image_gt_2, vae=self.vae, dtype=weight_dtype)
        latents0_2 = latents0_2.repeat(train_batch_size, 1, 1, 1)

        with torch.no_grad():
            encoder_hidden_states_list_1, add_text_embeds_1, add_time_ids_1 = sdxl_prepare_input_decom(
                set_string_list_1,
                self.tokenizer,
                self.tokenizer_2,
                self.text_encoder,
                self.text_encoder_2,
                length=max_length,
                bsz=train_batch_size,
                weight_dtype=weight_dtype
            )
            encoder_hidden_states_list_2, add_text_embeds_2, add_time_ids_2 = sdxl_prepare_input_decom(
                set_string_list_2,
                self.tokenizer,
                self.tokenizer_2,
                self.text_encoder,
                self.text_encoder_2,
                length=max_length,
                bsz=train_batch_size,
                weight_dtype=weight_dtype
            )

        for _ in range(max_diffusion_train_steps):
            with accelerator.accumulate(self.unet):
                latents_1 = latents0_1.clone().detach()
                noise_1 = torch.randn_like(latents_1)
                bsz = latents_1.shape[0]
                timesteps_1 = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents_1.device)
                timesteps_1 = timesteps_1.long()
                noisy_latents_1 = noise_scheduler.add_noise(latents_1, noise_1, timesteps_1)

                latents_2 = latents0_2.clone().detach()
                noise_2 = torch.randn_like(latents_2)
                bsz = latents_2.shape[0]
                timesteps_2 = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents_2.device)
                timesteps_2 = timesteps_2.long()
                noisy_latents_2 = noise_scheduler.add_noise(latents_2, noise_2, timesteps_2)

                register_attention_disentangled_control(self.unet, decom_controller_1)
                model_pred_1 = self.unet(
                    noisy_latents_1,
                    timesteps_1,
                    encoder_hidden_states=encoder_hidden_states_list_1,
                    cross_attention_kwargs=None,
                    return_dict=False,
                    added_cond_kwargs={"text_embeds": add_text_embeds_1, "time_ids": add_time_ids_1}
                )[0]

                register_attention_disentangled_control(self.unet, decom_controller_2)
                model_pred_2 = self.unet(
                    noisy_latents_2,
                    timesteps_2,
                    encoder_hidden_states=encoder_hidden_states_list_2,
                    cross_attention_kwargs=None,
                    return_dict=False,
                    added_cond_kwargs={"text_embeds": add_text_embeds_2, "time_ids": add_time_ids_2}
                )[0]

                loss_1 = F.mse_loss(model_pred_1.float(), noise_1.float(), reduction="mean")
                loss_2 = F.mse_loss(model_pred_2.float(), noise_2.float(), reduction="mean")
                loss = loss_1 + loss_2
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

                logs = {"loss": loss.detach().item(), "lr": diffusion_model_learning_rate}
                progress_bar.set_postfix(**logs)
                accelerator.log(logs, step=global_step)
                if accelerator.sync_gradients:
                    progress_bar.update(1)
                    global_step += 1

            if global_step >= max_diffusion_train_steps:
                break
        accelerator.wait_for_everyone()
        accelerator.end_training()
        self.unet = accelerator.unwrap_model(self.unet).to(dtype=weight_dtype)

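A sketch (not part of this file) of the LoRA branch shared by train_model and train_model_2imgs, using the same LoraConfig the code passes to diffusers' add_adapter. Counting requires_grad parameters afterwards, as psum2 does, shows how small the trainable set is relative to the frozen UNet; the checkpoint download is assumed.

from diffusers import UNet2DConditionModel
from peft import LoraConfig

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)
unet.requires_grad_(False)
unet.add_adapter(LoraConfig(
    r=4, lora_alpha=4, init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
))
trainable = sum(p.numel() for p in unet.parameters() if p.requires_grad)
total = sum(p.numel() for p in unet.parameters())
print(f"LoRA trains {trainable:,} of {total:,} parameters")
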
    @torch.no_grad()
    def backward_zT_to_z0_euler_decom(
        self,
        zT,
        cond_emb_list,
        cond_add_text_embeds,
        add_time_ids,
        uncond_emb=None,
        guidance_scale=1,
        num_sampling_steps=20,
        cond_controller=None,
        uncond_controller=None,
        mask_hard=None,
        mask_soft=None,
        orig_image=None,
        return_intermediate=False,
        strength=1
    ):
        latent_cur = zT
        if uncond_emb is None:
            uncond_emb = torch.zeros(zT.shape[0], 77, 2048).to(dtype=zT.dtype, device=zT.device)
            uncond_add_text_embeds = torch.zeros(1, 1280).to(dtype=zT.dtype, device=zT.device)
        if mask_soft is not None:
            init_latents_orig = image2latent(orig_image, self.vae, dtype=self.vae.dtype)
            length = init_latents_orig.shape[-1]
            noise = torch.randn_like(init_latents_orig)
            mask_soft = torch.nn.functional.interpolate(mask_soft.float().unsqueeze(0).unsqueeze(0), (length, length)).to(self.vae.dtype)
        if mask_hard is not None:
            init_latents_orig = image2latent(orig_image, self.vae, dtype=self.vae.dtype)
            length = init_latents_orig.shape[-1]
            noise = torch.randn_like(init_latents_orig)
            mask_hard = torch.nn.functional.interpolate(mask_hard.float().unsqueeze(0).unsqueeze(0), (length, length)).to(self.vae.dtype)

        intermediate_list = [latent_cur.detach()]
        for i in tqdm(range(num_sampling_steps)):
            t = self.scheduler.timesteps[i]
            latent_input = self.scheduler.scale_model_input(latent_cur, t)

            register_attention_disentangled_control(self.unet, uncond_controller)
            noise_pred_uncond = self.unet(
                latent_input, t,
                encoder_hidden_states=uncond_emb,
                added_cond_kwargs={"text_embeds": uncond_add_text_embeds, "time_ids": add_time_ids},
                return_dict=False)[0]

            register_attention_disentangled_control(self.unet, cond_controller)
            noise_pred_cond = self.unet(
                latent_input, t,
                encoder_hidden_states=cond_emb_list,
                added_cond_kwargs={"text_embeds": cond_add_text_embeds, "time_ids": add_time_ids},
                return_dict=False)[0]

            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
            latent_cur = self.scheduler.step(noise_pred, t, latent_cur, generator=None, return_dict=False)[0]
            if return_intermediate is True:
                intermediate_list.append(latent_cur)
            if mask_hard is not None and mask_soft is not None and i <= strength * num_sampling_steps:
                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
                mask = mask_soft.to(latent_cur.device, latent_cur.dtype) + mask_hard.to(latent_cur.device, latent_cur.dtype)
                latent_cur = (init_latents_proper * mask) + (latent_cur * (1 - mask))

            elif mask_hard is not None and mask_soft is not None and i > strength * num_sampling_steps:
                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
                mask = mask_hard.to(latent_cur.device, latent_cur.dtype)
                latent_cur = (init_latents_proper * mask) + (latent_cur * (1 - mask))

            elif mask_hard is None and mask_soft is not None and i <= strength * num_sampling_steps:
                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
                mask = mask_soft.to(latent_cur.device, latent_cur.dtype)
                latent_cur = (init_latents_proper * mask) + (latent_cur * (1 - mask))

            elif mask_hard is None and mask_soft is not None and i > strength * num_sampling_steps:
                pass

            elif mask_hard is not None and mask_soft is None:
                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
                mask = mask_hard.to(latent_cur.dtype)
                latent_cur = (init_latents_proper * mask) + (latent_cur * (1 - mask))

            else:  # hard and soft are both None
                pass

        if return_intermediate is True:
            return latent_cur, intermediate_list
        else:
            return latent_cur

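The mask-guided blending step above, isolated (not part of this file): at each sampling step the region under the mask is reset to a freshly noised copy of the original image latent, so only the unmasked region is actually re-synthesized. Shapes and the half-image mask are illustrative.

import torch

def blend(latent_cur, init_latents_proper, mask):
    # mask == 1: keep the (re-noised) original content; mask == 0: keep the edit
    return init_latents_proper * mask + latent_cur * (1 - mask)

latent_cur = torch.randn(1, 4, 128, 128)
init_latents_proper = torch.randn(1, 4, 128, 128)
mask = torch.zeros(1, 1, 128, 128)
mask[..., :64, :] = 1.0  # hypothetical: freeze the top half of the latent
print(blend(latent_cur, init_latents_proper, mask).shape)
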
@torch.no_grad()
|
788 |
+
def sampling(
|
789 |
+
self,
|
790 |
+
set_string_list,
|
791 |
+
cond_controller = None,
|
792 |
+
uncond_controller = None,
|
793 |
+
guidance_scale = 7,
|
794 |
+
num_sampling_steps = 20,
|
795 |
+
mask_hard = None,
|
796 |
+
mask_soft = None,
|
797 |
+
orig_image = None,
|
798 |
+
strength = 1.,
|
799 |
+
num_imgs = 1,
|
800 |
+
normal_token_id_list = [],
|
801 |
+
seed = 1
|
802 |
+
):
|
803 |
+
weight_dtype = torch.float16
|
804 |
+
self.scheduler.set_timesteps(num_sampling_steps)
|
805 |
+
self.unet.to(device, dtype=weight_dtype)
|
806 |
+
self.vae.to(device, dtype=weight_dtype)
|
807 |
+
self.text_encoder.to(device, dtype=weight_dtype)
|
808 |
+
self.text_encoder_2.to(device, dtype=weight_dtype)
|
809 |
+
torch.manual_seed(seed)
|
810 |
+
torch.cuda.manual_seed(seed)
|
811 |
+
|
812 |
+
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
813 |
+
zT = torch.randn(num_imgs, 4, self.resolution//vae_scale_factor,self.resolution//vae_scale_factor).to(device,dtype=weight_dtype)
|
814 |
+
zT = zT * self.scheduler.init_noise_sigma
|
815 |
+
|
816 |
+
cond_emb_list, cond_add_text_embeds, add_time_ids = sdxl_prepare_input_decom(
|
817 |
+
set_string_list,
|
818 |
+
self.tokenizer,
|
819 |
+
self.tokenizer_2,
|
820 |
+
self.text_encoder,
|
821 |
+
self.text_encoder_2,
|
822 |
+
length = max_length,
|
823 |
+
bsz = num_imgs,
|
824 |
+
weight_dtype = weight_dtype,
|
825 |
+
normal_token_id_list = normal_token_id_list
|
826 |
+
)
|
827 |
+
|
828 |
+
z0 = self.backward_zT_to_z0_euler_decom(zT, cond_emb_list, cond_add_text_embeds, add_time_ids,
|
829 |
+
guidance_scale = guidance_scale, num_sampling_steps = num_sampling_steps,
|
830 |
+
cond_controller = cond_controller, uncond_controller = uncond_controller,
|
831 |
+
mask_hard = mask_hard, mask_soft = mask_soft, orig_image =orig_image, strength = strength
|
832 |
+
)
|
833 |
+
x0 = latent2image(z0, vae = self.vae)
|
834 |
+
return x0
|
835 |
+
|
836 |
+
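The classifier-free guidance combination used in the denoising loop, as a one-liner (not part of this file): at guidance_scale == 1 it reduces to the conditional prediction, and larger scales extrapolate away from the unconditional one.

import torch

def cfg(noise_uncond, noise_cond, guidance_scale):
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)

u, c = torch.randn(2, 4, 8, 8), torch.randn(2, 4, 8, 8)
assert torch.allclose(cfg(u, c, 1.0), c)  # scale 1 recovers the conditional branch
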
    @torch.no_grad()
    def inference_with_mask(
        self,
        save_path,
        guidance_scale=3,
        num_sampling_steps=50,
        strength=1,
        mask_soft=None,
        mask_hard=None,
        orig_image=None,
        mask_list=None,
        num_imgs=1,
        seed=1,
        set_string_list=None
    ):
        if mask_list is not None:
            mask_list = [m.to(device) for m in mask_list]
        else:
            mask_list = self.mask_list
        if set_string_list is not None:
            self.set_string_list = set_string_list

        if mask_hard is not None and mask_soft is not None:
            check_mask_overlap_torch(mask_hard, mask_soft)
        null_controller = DummyController()
        decom_controller = GroupedCAController(mask_list=mask_list)
        x0 = self.sampling(
            self.set_string_list,
            guidance_scale=guidance_scale,
            num_sampling_steps=num_sampling_steps,
            strength=strength,
            cond_controller=decom_controller,
            uncond_controller=null_controller,
            mask_soft=mask_soft,
            mask_hard=mask_hard,
            orig_image=orig_image,
            num_imgs=num_imgs,
            seed=seed
        )
        save_images(x0, save_path)
requirements.txt
ADDED

torch==2.2.0
torchvision==0.17.0
transformers==4.37.2
accelerate==0.23.0
gradio==3.41.1
xformers==0.0.24
diffusers==0.26.3
scipy
tqdm
numpy
safetensors
peft
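The pinned versions can be installed into a fresh environment with `pip install -r requirements.txt`. Note that each xformers release is built against one specific torch release, so the torch==2.2.0 / xformers==0.0.24 pair should be kept in sync when upgrading.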
scripts/run_segment.sh
ADDED

export IMAGE_NAME="example1"
# python segment.py --name=$IMAGE_NAME --size=512
python segment.py --name=$IMAGE_NAME --size=1024
scripts/run_segmentSAM.sh
ADDED

export IMAGE_NAME="example1"
python segment_sam.py --name=$IMAGE_NAME --text_prompt="bag"
scripts/sd/run_ft_sd_512.sh
ADDED

export IMAGE_NAME="example1"
python main.py --name=$IMAGE_NAME \
    --dpm="sd" \
    --resolution=512 \
    --num_tokens=5 \
    --embedding_learning_rate=1e-4 \
    --diffusion_model_learning_rate=5e-5 \
    --max_emb_train_steps=500 \
    --max_diffusion_train_steps=500 \
    --train_batch_size=5 \
    --gradient_accumulation_steps=5
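With --max_emb_train_steps=500 and --gradient_accumulation_steps=5, the accumulation logic in the pipeline performs 500 forward/backward passes but only 500 / 5 = 100 effective optimizer updates per phase (the effective_*_train_steps computation above), so the progress bars count to 100.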
scripts/sd/run_ft_sd_512_2imgs.sh
ADDED

export IMAGE_NAME="example1"
export IMAGE_NAME_2="example2"
python main.py --name=$IMAGE_NAME \
    --dpm="sd" \
    --resolution=512 \
    --image \
    --name_2=$IMAGE_NAME_2 \
    --embedding_learning_rate=1e-4 \
    --diffusion_model_learning_rate=5e-5 \
    --max_emb_train_steps=500 \
    --max_diffusion_train_steps=500 \
    --train_batch_size=5 \
    --gradient_accumulation_steps=5
scripts/sd/run_image.sh
ADDED

export IMAGE_NAME="example1"
export IMAGE_NAME_2="example2"
python main.py --name=$IMAGE_NAME \
    --name_2=$IMAGE_NAME_2 \
    --dpm="sd" \
    --resolution=512 \
    --image \
    --load_trained \
    --guidance_scale=2 \
    --num_imgs=2 \
    --seed=2024 \
    --strength=0.5 \
    --edge_thickness=10 \
    --src_index=1 --tgt_index=0 \
    --tgt_name=$IMAGE_NAME
scripts/sd/run_move_resize.sh
ADDED

export IMAGE_NAME="example1"
CUDA_VISIBLE_DEVICES=1 python main.py --name=$IMAGE_NAME \
    --dpm="sd" \
    --resolution=512 \
    --num_tokens=5 \
    --load_trained \
    --move_resize \
    --seed=2023 \
    --num_sampling_step=50 \
    --strength=0.6 \
    --edge_thickness=10 \
    --guidance_scale=2 \
    --num_imgs=1 \
    --tgt_indices_list 0 \
    --active_mask_list 2 \
    --delta_x 100 --delta_y 60 \
    --resize_list 0.6 \
    --priority_list 1
scripts/sd/run_recon.sh
ADDED

export IMAGE_NAME="example1"
python main.py --name=$IMAGE_NAME \
    --dpm="sd" \
    --resolution=512 \
    --num_tokens=5 \
    --load_trained \
    --recon \
    --seed=2024 \
    --guidance_scale=2 \
    --num_sampling_step=20 \
    --num_imgs=1
scripts/sd/run_remove.sh
ADDED

# export IMAGE_NAME="example1"
# CUDA_VISIBLE_DEVICES=1 python main.py --name=$IMAGE_NAME \
#     --dpm="sd" \
#     --resolution=512 \
#     --num_tokens=5 \
#     --load_trained \
#     --load_edited_mask \
#     --remove \
#     --seed=2023 \
#     --num_sampling_step=50 \
#     --strength=0.7 \
#     --edge_thickness=10 \
#     --guidance_scale=2 \
#     --num_imgs=1 \
#     --tgt_index=0

# export IMAGE_NAME="example1"
# CUDA_VISIBLE_DEVICES=1 python main.py --name=$IMAGE_NAME \
#     --dpm="sd" \
#     --resolution=512 \
#     --num_tokens=5 \
#     --load_trained \
#     --load_edited_processed_mask \
#     --remove \
#     --seed=2024 \
#     --num_sampling_step=50 \
#     --strength=0.5 \
#     --edge_thickness=10 \
#     --guidance_scale=7 \
#     --num_imgs=1 \
#     --tgt_index=2

export IMAGE_NAME="example1"
CUDA_VISIBLE_DEVICES=1 python main.py --name=$IMAGE_NAME \
    --dpm="sd" \
    --resolution=512 \
    --num_tokens=5 \
    --load_trained \
    --load_edited_processed_mask \
    --remove \
    --seed=1 \
    --num_sampling_step=50 \
    --strength=0.6 \
    --edge_thickness=10 \
    --guidance_scale=7 \
    --num_imgs=1 \
    --tgt_index=2
scripts/sd/run_text.sh
ADDED

# export IMAGE_NAME="example1"
# python main.py --name=$IMAGE_NAME \
#     --dpm="sd" \
#     --resolution=512 \
#     --load_trained \
#     --text \
#     --num_tokens=5 \
#     --seed=2024 \
#     --guidance_scale=7 \
#     --num_sampling_step=50 \
#     --strength=0.7 \
#     --edge_thickness=15 \
#     --num_imgs=1 \
#     --tgt_prompt="a red bag" \
#     --tgt_index=0

export IMAGE_NAME="example1"
python main.py --name=$IMAGE_NAME \
    --dpm="sd" \
    --resolution=512 \
    --load_trained \
    --text \
    --num_tokens=5 \
    --seed=2024 \
    --guidance_scale=6 \
    --num_sampling_step=50 \
    --strength=0.5 \
    --edge_thickness=15 \
    --num_imgs=2 \
    --tgt_prompt="a black bag" \
    --tgt_index=0
scripts/sdxl/run_ft_sdxl_1024.sh
ADDED

export IMAGE_NAME="example1"
python main.py --name=$IMAGE_NAME \
    --dpm="sdxl" \
    --resolution=1024 \
    --num_tokens=5 \
    --embedding_learning_rate=1e-4 \
    --diffusion_model_learning_rate=5e-5 \
    --max_emb_train_steps=500 \
    --max_diffusion_train_steps=500 \
    --train_batch_size=2 \
    --gradient_accumulation_steps=5
scripts/sdxl/run_ft_sdxl_1024_2imgs.sh
ADDED

export IMAGE_NAME="example1"
export IMAGE_NAME_2="example2"
python main.py --name=$IMAGE_NAME \
    --dpm="sdxl" \
    --image \
    --name_2=$IMAGE_NAME_2 \
    --resolution=1024 \
    --embedding_learning_rate=1e-4 \
    --diffusion_model_learning_rate=5e-5 \
    --max_emb_train_steps=500 \
    --max_diffusion_train_steps=500 \
    --train_batch_size=1 \
    --gradient_accumulation_steps=5
scripts/sdxl/run_ft_sdxl_1024_auxin_todo.sh
ADDED

export IMAGE_NAME="example1"
python main.py --name=$IMAGE_NAME \
    --dpm="sdxl" \
    --resolution=1024 \
    --train_full_lora \
    --embedding_learning_rate=1e-4 \
    --diffusion_model_learning_rate=1e-3 \
    --max_emb_train_steps=500 \
    --max_diffusion_train_steps=500 \
    --train_batch_size=2 \
    --gradient_accumulation_steps=5 \
    --prompt_auxin_idx_list 0 2 \
    --prompt_auxin_list "a photo of * handbag" "a photo of * model"

export IMAGE_NAME="example1"
python main.py --name=$IMAGE_NAME \
    --dpm="sdxl" \
    --resolution=1024 \
    --train_full_lora \
    --load_trained \
    --recon \
    --seed=23 \
    --guidance_scale=7 \
    --num_sampling_step=20 \
    --num_imgs=2 \
    --prompt_auxin_idx_list 0 2 \
    --prompt_auxin_list "a photo of * handbag" "a photo of * model"
scripts/sdxl/run_ft_sdxl_1024_fulllora.sh
ADDED

export IMAGE_NAME="example1"
python main.py --name=$IMAGE_NAME \
    --dpm="sdxl" \
    --train_full_lora \
    --resolution=1024 \
    --embedding_learning_rate=1e-4 \
    --diffusion_model_learning_rate=1e-4 \
    --max_emb_train_steps=500 \
    --max_diffusion_train_steps=500 \
    --train_batch_size=2 \
    --gradient_accumulation_steps=5
scripts/sdxl/run_ft_sdxl_1024_fulllora_2imgs.sh
ADDED

export IMAGE_NAME="example1"
export IMAGE_NAME_2="example2"
# python main.py --name=$IMAGE_NAME \
#     --dpm="sdxl" \
#     --image \
#     --name_2=$IMAGE_NAME_2 \
#     --resolution=1024 \
#     --embedding_learning_rate=1e-4 \
#     --diffusion_model_learning_rate=5e-5 \
#     --max_emb_train_steps=500 \
#     --max_diffusion_train_steps=500 \
#     --train_batch_size=1 \
#     --gradient_accumulation_steps=5

python main.py --name=$IMAGE_NAME \
    --dpm="sdxl" \
    --image \
    --train_full_lora \
    --name_2=$IMAGE_NAME_2 \
    --resolution=1024 \
    --embedding_learning_rate=1e-4 \
    --diffusion_model_learning_rate=5e-4 \
    --max_emb_train_steps=500 \
    --max_diffusion_train_steps=500 \
    --train_batch_size=1 \
    --gradient_accumulation_steps=5

# python main.py --load_trained \
#     --dpm="sdxl" \
#     --image \
#     --name=$IMAGE_NAME \
#     --name_2=$IMAGE_NAME_2 \
#     --tgt_name=$IMAGE_NAME \
#     --guidance_scale 2.5 \
#     --edge_thickness 40 \
#     --strength 0.5 \
#     --seed 29 \
#     --num_imgs 4 \
#     --tgt_index=0 \
#     --src_index=2
scripts/sdxl/run_image.sh
ADDED

export IMAGE_NAME="example1"
export IMAGE_NAME_2="example2"
python main.py --name=$IMAGE_NAME \
    --name_2=$IMAGE_NAME_2 \
    --dpm="sdxl" \
    --image \
    --load_trained \
    --resolution=1024 \
    --guidance_scale=2.8 \
    --num_imgs=2 \
    --seed=2023 \
    --strength=0.5 \
    --edge_thickness=20 \
    --src_index=2 --tgt_index=0 \
    --tgt_name=$IMAGE_NAME
scripts/sdxl/run_image_w_edited_mask.sh
ADDED

export IMAGE_NAME="example1"
export IMAGE_NAME_2="example2"
python main.py --name=$IMAGE_NAME \
    --name_2=$IMAGE_NAME_2 \
    --dpm="sdxl" \
    --image \
    --load_trained \
    --load_edited_mask \
    --resolution=1024 \
    --guidance_scale=2.8 \
    --num_imgs=2 \
    --seed=2023 \
    --strength=0.5 \
    --edge_thickness=20 \
    --src_index=2 --tgt_index=0 \
    --tgt_name=$IMAGE_NAME
scripts/sdxl/run_move_resize.sh
ADDED

export IMAGE_NAME="example1"
CUDA_VISIBLE_DEVICES=1 python main.py --name=$IMAGE_NAME \
    --dpm="sdxl" \
    --resolution=1024 \
    --num_tokens=5 \
    --load_edited_mask \
    --load_trained \
    --move_resize \
    --seed=2023 \
    --num_sampling_step=20 \
    --strength=0.5 \
    --edge_thickness=20 \
    --guidance_scale=2.8 \
    --num_imgs=2 \
    --tgt_indices_list 0 \
    --active_mask_list 2 \
    --delta_x 200 --delta_y 140 \
    --resize_list 0.5 \
    --priority_list 1

# --load_edited_processed_mask
scripts/sdxl/run_recon.sh
ADDED

export IMAGE_NAME="example1"
python main.py --name=$IMAGE_NAME \
    --dpm="sdxl" \
    --resolution=1024 \
    --num_tokens=5 \
    --load_trained \
    --recon \
    --seed=20 \
    --guidance_scale=3 \
    --num_sampling_step=20 \
    --num_imgs=2
scripts/sdxl/run_recon_item_todo.sh
ADDED

# export IMAGE_NAME="example1"
# python main.py --name=$IMAGE_NAME \
#     --dpm="sdxl" \
#     --resolution=1024 \
#     --load_trained \
#     --recon \
#     --recon_an_item \
#     --seed=23 \
#     --guidance_scale=6 \
#     --num_sampling_step=20 \
#     --num_imgs=2 \
#     --tgt_index=0 \
#     --recon_prompt="a photo of a * handbag on a table"

export IMAGE_NAME="example1"
python main.py --name=$IMAGE_NAME \
    --dpm="sdxl" \
    --resolution=1024 \
    --load_trained \
    --recon \
    --recon_an_item \
    --seed=23 \
    --guidance_scale=6 \
    --num_sampling_step=20 \
    --num_imgs=2 \
    --tgt_index=2 \
    --recon_prompt="a photo of a * model on a chair"
scripts/sdxl/run_remove.sh
ADDED

export IMAGE_NAME="example1"
python main.py --name=$IMAGE_NAME \
    --dpm="sdxl" \
    --resolution=1024 \
    --num_tokens=5 \
    --load_edited_mask \
    --load_trained \
    --remove \
    --seed=0 \
    --num_sampling_step=20 \
    --strength=0.4 \
    --edge_thickness=20 \
    --guidance_scale=3 \
    --num_imgs=1 \
    --tgt_index=0

# --load_edited_processed_mask
scripts/sdxl/run_text.sh
ADDED

export IMAGE_NAME="example1"
python main.py --name=$IMAGE_NAME \
    --dpm="sdxl" \
    --resolution=1024 \
    --load_trained \
    --num_tokens=5 \
    --text \
    --seed=23 \
    --num_sampling_step=20 \
    --strength=0.6 \
    --edge_thickness=30 \
    --num_imgs=2 \
    --tgt_prompt="a white handbag" \
    --tgt_index=0
scripts/sdxl/run_text_w_edited_mask.sh
ADDED

export IMAGE_NAME="example1"
python main.py --name=$IMAGE_NAME \
    --dpm="sdxl" \
    --resolution=1024 \
    --num_tokens=5 \
    --load_trained \
    --load_edited_mask \
    --text \
    --seed=23 \
    --num_sampling_step=50 \
    --strength=0.7 \
    --edge_thickness=30 \
    --num_imgs=2 \
    --tgt_prompt="a white handbag" \
    --tgt_index=0
segment.py
ADDED

from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
from PIL import Image
import torch
from collections import defaultdict
import matplotlib.pyplot as plt
from matplotlib import cm
import matplotlib.patches as mpatches
import os
import numpy as np
import argparse
import matplotlib

def load_image(image_path, left=0, right=0, top=0, bottom=0, size=512):
    if type(image_path) is str:
        image = np.array(Image.open(image_path))[:, :, :3]
    else:
        image = image_path
    h, w, c = image.shape
    left = min(left, w - 1)
    right = min(right, w - left - 1)
    top = min(top, h - left - 1)
    bottom = min(bottom, h - top - 1)
    image = image[top:h - bottom, left:w - right]
    h, w, c = image.shape
    if h < w:
        offset = (w - h) // 2
        image = image[:, offset:offset + h]
    elif w < h:
        offset = (h - w) // 2
        image = image[offset:offset + w]
    image = np.array(Image.fromarray(image).resize((size, size)))
    return image

def draw_panoptic_segmentation(segmentation, segments_info, save_folder=None, noseg=False):
    if torch.max(segmentation) == torch.min(segmentation) == -1:
        print("nothing is detected!")
        noseg = True
        viridis = matplotlib.colormaps['viridis'].resampled(1)
    else:
        viridis = matplotlib.colormaps['viridis'].resampled(torch.max(segmentation) - torch.min(segmentation) + 1)
    fig, ax = plt.subplots()
    ax.imshow(segmentation)
    instances_counter = defaultdict(int)
    handles = []
    label_list = []
    if not noseg:
        if torch.min(segmentation) == 0:
            mask = segmentation == 0
            mask = mask.cpu().detach().numpy()  # [512, 512] bool
            segment_label = "rest"
            np.save(os.path.join(save_folder, "mask{}_{}.npy".format(0, "rest")), mask)
            color = viridis(0)
            label = f"{segment_label}-{0}"
            handles.append(mpatches.Patch(color=color, label=label))
            label_list.append(label)

        for segment in segments_info:
            segment_id = segment['id']
            mask = segmentation == segment_id
            if torch.min(segmentation) != 0:
                segment_id -= 1
            mask = mask.cpu().detach().numpy()  # [512, 512] bool

            segment_label = model.config.id2label[segment['label_id']]
            instances_counter[segment['label_id']] += 1
            np.save(os.path.join(save_folder, "mask{}_{}.npy".format(segment_id, segment_label)), mask)
            color = viridis(segment_id)

            label = f"{segment_label}-{segment_id}"
            handles.append(mpatches.Patch(color=color, label=label))
            label_list.append(label)
    else:
        mask = np.full(segmentation.shape, True)
        segment_label = "all"
        np.save(os.path.join(save_folder, "mask{}_{}.npy".format(0, "all")), mask)
        color = viridis(0)
        label = f"{segment_label}-{0}"
        handles.append(mpatches.Patch(color=color, label=label))
        label_list.append(label)

    plt.xticks([])
    plt.yticks([])
    # plt.savefig(os.path.join(save_folder, 'mask_clear.png'), dpi=500)
    ax.legend(handles=handles)
    plt.savefig(os.path.join(save_folder, 'seg_init.png'), dpi=500)
    print("; ".join(label_list))


parser = argparse.ArgumentParser()
parser.add_argument("--name", type=str, default="obama")
parser.add_argument("--size", type=int, default=512)
parser.add_argument("--noseg", default=False, action="store_true")
args = parser.parse_args()
base_folder_path = "."

processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")
model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")
input_folder = os.path.join(base_folder_path, args.name)
try:
    image = load_image(os.path.join(input_folder, "img.png"), size=args.size)
except:
    image = load_image(os.path.join(input_folder, "img.jpg"), size=args.size)

image = Image.fromarray(image)
image.save(os.path.join(input_folder, "img_{}.png".format(args.size)))
inputs = processor(image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

panoptic_segmentation = processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
save_folder = os.path.join(base_folder_path, args.name)
os.makedirs(save_folder, exist_ok=True)
draw_panoptic_segmentation(**panoptic_segmentation, save_folder=save_folder, noseg=args.noseg)
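A sketch (not part of this file) of consuming what segment.py writes: each detected segment becomes a boolean array saved as mask<id>_<label>.npy inside the image folder. The exact file name below is hypothetical, since it depends on what Mask2Former detects.

import numpy as np

mask = np.load("example1/mask0_rest.npy")  # hypothetical name: id 0, label "rest"
print(mask.shape, mask.dtype, mask.mean())  # (size, size), bool, fraction of pixels covered
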
segment_sam.py
ADDED
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import copy
|
4 |
+
import shutil
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import json
|
8 |
+
import torch
|
9 |
+
from PIL import Image, ImageDraw, ImageFont
|
10 |
+
|
11 |
+
# Grounding DINO
|
12 |
+
import sys
|
13 |
+
|
14 |
+
sys.path.append("/path/to/Grounded-Segment-Anything")
|
15 |
+
# change to your "Grounded-Segment-Anything" installation folder!!!!!
|
16 |
+
import GroundingDINO.groundingdino.datasets.transforms as T
|
17 |
+
from GroundingDINO.groundingdino.models import build_model
|
18 |
+
from GroundingDINO.groundingdino.util import box_ops
|
19 |
+
from GroundingDINO.groundingdino.util.slconfig import SLConfig
|
20 |
+
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
|
21 |
+
|
22 |
+
# segment anything
|
23 |
+
from segment_anything import (
|
24 |
+
sam_model_registry,
|
25 |
+
sam_hq_model_registry,
|
26 |
+
SamPredictor
|
27 |
+
)
|
28 |
+
import cv2
|
29 |
+
import numpy as np
|
30 |
+
import matplotlib.pyplot as plt
|
31 |
+
def load_image_to_resize(image_path, left=0, right=0, top=0, bottom=0, size = 512):
|
32 |
+
if type(image_path) is str:
|
33 |
+
image = np.array(Image.open(image_path))[:, :, :3]
|
34 |
+
else:
|
35 |
+
image = image_path
|
36 |
+
h, w, c = image.shape
|
37 |
+
left = min(left, w-1)
|
38 |
+
right = min(right, w - left - 1)
|
39 |
+
top = min(top, h - left - 1)
|
40 |
+
bottom = min(bottom, h - top - 1)
|
41 |
+
image = image[top:h-bottom, left:w-right]
|
42 |
+
h, w, c = image.shape
|
43 |
+
if h < w:
|
44 |
+
offset = (w - h) // 2
|
45 |
+
image = image[:, offset:offset + h]
|
46 |
+
elif w < h:
|
47 |
+
offset = (h - w) // 2
|
48 |
+
image = image[offset:offset + w]
|
49 |
+
image = np.array(Image.fromarray(image).resize((size, size)))
|
50 |
+
return image
|
51 |
+
|
52 |
+
|
53 |
+
def load_image(image_path):
|
54 |
+
# load image
|
55 |
+
image_pil = Image.open(image_path).convert("RGB") # load image
|
56 |
+
|
57 |
+
transform = T.Compose(
|
58 |
+
[
|
59 |
+
T.RandomResize([800], max_size=1333),
|
60 |
+
T.ToTensor(),
|
61 |
+
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
62 |
+
]
|
63 |
+
)
|
64 |
+
image, _ = transform(image_pil, None) # 3, h, w
|
65 |
+
return image_pil, image
|
66 |
+
|
67 |
+
|
68 |
+
def load_model(model_config_path, model_checkpoint_path, device):
|
69 |
+
args = SLConfig.fromfile(model_config_path)
|
70 |
+
args.device = device
|
71 |
+
model = build_model(args)
|
72 |
+
checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
|
73 |
+
load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
|
74 |
+
model.eval()
|
75 |
+
return model
|
76 |
+
|
77 |
+
|
78 |
+
def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
    caption = caption.lower().strip()
    if not caption.endswith("."):
        caption = caption + "."
    model = model.to(device)
    image = image.to(device)
    with torch.no_grad():
        outputs = model(image[None], captions=[caption])
    logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
    boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)

    # filter output: keep queries whose best token logit clears the box threshold
    filt_mask = logits.max(dim=1)[0] > box_threshold
    logits_filt = logits[filt_mask]  # (num_filt, 256)
    boxes_filt = boxes[filt_mask]  # (num_filt, 4)

    # get phrase
    tokenizer = model.tokenizer
    tokenized = tokenizer(caption)
    # build pred
    pred_phrases = []
    for logit, box in zip(logits_filt, boxes_filt):
        pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer)
        if with_logits:
            pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
        else:
            pred_phrases.append(pred_phrase)

    return boxes_filt, pred_phrases


def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_box(box, ax, label):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
    ax.text(x0, y0, label)


def save_mask_data(output_dir, mask_list, box_list, label_list):
    value = 0  # 0 for background

    mask_img = torch.zeros(mask_list.shape[-2:])
    for idx, mask in enumerate(mask_list):
        mask_img[mask.cpu().numpy()[0] == True] = value + idx + 1
    plt.figure(figsize=(10, 10))
    plt.imshow(mask_img.numpy())
    plt.axis('off')
    plt.savefig(os.path.join(output_dir, 'mask.jpg'), bbox_inches="tight", dpi=300, pad_inches=0.0)

    json_data = [{
        'value': value,
        'label': 'background'
    }]
    for label, box in zip(label_list, box_list):
        value += 1
        name, logit = label.split('(')
        logit = logit[:-1]  # strip the trailing ')'
        json_data.append({
            'value': value,
            'label': name,
            'logit': float(logit),
            'box': box.numpy().tolist(),
        })
    with open(os.path.join(output_dir, 'mask.json'), 'w') as f:
        json.dump(json_data, f)


if __name__ == "__main__":

    parser = argparse.ArgumentParser("Grounded-Segment-Anything Demo", add_help=True)
    parser.add_argument("--sam_version", type=str, default="vit_h", required=False, help="SAM ViT version: vit_b / vit_l / vit_h")
    parser.add_argument("--sam_checkpoint", type=str, required=False, help="path to sam checkpoint file")
    parser.add_argument("--sam_hq_checkpoint", type=str, default=None, help="path to sam-hq checkpoint file")
    parser.add_argument("--use_sam_hq", action="store_true", help="use sam-hq for prediction")
    parser.add_argument("--text_prompt", type=str, required=True, help="text prompt")

    parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
    parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
    parser.add_argument("--device", type=str, default="cpu", help="device to run on (overridden to cuda below)")
    parser.add_argument("--name", type=str, default="", help="name of the input image folder")
    parser.add_argument("--size", type=int, default=1024, help="image size")

    args = parser.parse_args()
    args.base_folder = "/path/to/Grounded-Segment-Anything"
    # change this to your "Grounded-Segment-Anything" installation folder!
    input_folder = os.path.join(".", args.name)

    args.config = os.path.join(args.base_folder, "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py")
    args.grounded_checkpoint = "groundingdino_swint_ogc.pth"
    args.sam_checkpoint = "sam_vit_h_4b8939.pth"
    args.box_threshold = 0.3
    args.text_threshold = 0.25
    args.device = "cuda"

    # cfg
    config_file = args.config  # path of the model config file
    grounded_checkpoint = os.path.join(args.base_folder, args.grounded_checkpoint)  # path of the model checkpoint
    sam_version = args.sam_version
    sam_checkpoint = os.path.join(args.base_folder, args.sam_checkpoint)
    if args.sam_hq_checkpoint is not None:
        sam_hq_checkpoint = os.path.join(args.base_folder, args.sam_hq_checkpoint)
    use_sam_hq = args.use_sam_hq
    # image_path = args.input_image
    text_prompt = args.text_prompt
    # output_dir = args.output_dir
    box_threshold = args.box_threshold
    text_threshold = args.text_threshold
    device = args.device

    output_dir = input_folder
    os.makedirs(output_dir, exist_ok=True)

    # unify names: rename a single input image to img.<ext>
    if len(os.listdir(input_folder)) == 1:
        for filename in os.listdir(input_folder):
            imgtype = "." + filename.split(".")[-1]
            shutil.move(os.path.join(input_folder, filename), os.path.join(input_folder, "img" + imgtype))

    ### resize and save
    if os.path.exists(os.path.join(input_folder, "img.jpg")):
        image_path = os.path.join(input_folder, "img.jpg")
    else:
        image_path = os.path.join(input_folder, "img.png")
    image = load_image_to_resize(image_path, size=args.size)
    image = Image.fromarray(image)
    resized_image_path = os.path.join(input_folder, "img_{}.png".format(args.size))
    image.save(resized_image_path)

    image_path = resized_image_path
    # load image
    image_pil, image = load_image(image_path)
    # load model
    model = load_model(config_file, grounded_checkpoint, device=device)

    # # visualize raw image
    # image_pil.save(os.path.join(output_dir, "raw_image.jpg"))

    # run grounding dino model
    boxes_filt, pred_phrases = get_grounding_output(
        model, image, text_prompt, box_threshold, text_threshold, device=device
    )

    # initialize SAM
    if use_sam_hq:
        predictor = SamPredictor(sam_hq_model_registry[sam_version](checkpoint=sam_hq_checkpoint).to(device))
    else:
        predictor = SamPredictor(sam_model_registry[sam_version](checkpoint=sam_checkpoint).to(device))
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    predictor.set_image(image)

    # rescale boxes from normalized (cx, cy, w, h) to absolute (x0, y0, x1, y1)
    size = image_pil.size
    H, W = size[1], size[0]
    for i in range(boxes_filt.size(0)):
        boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
        boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
        boxes_filt[i][2:] += boxes_filt[i][:2]

    boxes_filt = boxes_filt.cpu()
    transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)

    masks, _, _ = predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes.to(device),
        multimask_output=False,
    )

    tot_detect = len(masks)
    # draw the output image and save one boolean mask per detection
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    for idx, (mask, label) in enumerate(zip(masks, pred_phrases)):
        show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
        np.save(os.path.join(output_dir, "maskSAM{}_{}.npy".format(idx, label)), mask[0].cpu().numpy())

    for idx, (box, label) in enumerate(zip(boxes_filt, pred_phrases)):
        label = label + "_{}".format(idx)
        show_box(box.numpy(), plt.gca(), label)

    # union of the detection boxes; boxes are (x0, y0, x1, y1), and rows of the
    # mask are indexed by y, columns by x
    rec_mask = np.zeros_like(mask[0].cpu().numpy()).astype(np.bool_)
    for idx, box in enumerate(boxes_filt):
        x0 = box[0].numpy().astype(np.int32)
        y0 = box[1].numpy().astype(np.int32)
        x1 = box[2].numpy().astype(np.int32)
        y1 = box[3].numpy().astype(np.int32)
        rec_mask[y0:y1, x0:x1] = True

    plt.axis('off')
    plt.savefig(
        os.path.join(output_dir, "seg_init_SAM.png"),
        bbox_inches="tight", dpi=300, pad_inches=0.0
    )

    mask_detected = np.logical_or.reduce([mask[0].cpu().numpy() for mask in masks])
    mask_undetected = np.logical_not(mask_detected)
    np.save(os.path.join(output_dir, "SAM_detected.npy"), mask_detected)
    np.save(os.path.join(output_dir, "maskSAM{}_rest.npy".format(len(masks))), mask_undetected)
    plt.imsave(os.path.join(output_dir, "mask_SAM-detected.png"), np.repeat(np.expand_dims(mask_detected.astype(float), axis=2), 3, axis=2))
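
The script writes one boolean `.npy` mask per detection plus a catch-all `maskSAM{N}_rest.npy` for the undetected pixels, so together the files span the whole image. A minimal sketch of a sanity check a downstream step could run over these outputs (the folder name is illustrative, matching whatever was passed as --name):

import os
import numpy as np

folder = "example1"  # hypothetical; the folder passed as --name
files = sorted(f for f in os.listdir(folder) if f.startswith("maskSAM") and f.endswith(".npy"))
masks = [np.load(os.path.join(folder, f)) for f in files]
# every pixel is covered by at least one detection mask or by the "rest" mask
assert np.logical_or.reduce(masks).all()
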
utils.py
ADDED
@@ -0,0 +1,326 @@
from transformers import PretrainedConfig
from PIL import Image
import torch
import numpy as np
import PIL
import os
from tqdm.auto import tqdm
from diffusers.models.attention_processor import (
    AttnProcessor2_0,
    LoRAAttnProcessor2_0,
    LoRAXFormersAttnProcessor,
    XFormersAttnProcessor,
)

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

def myroll2d(a, delta_x, delta_y):
    # shift a 2D array/tensor; unlike np.roll, pixels shifted past the border
    # are dropped and the vacated area is zero-filled. Note that the first
    # offset shifts axis 0 (rows) and the second axis 1 (columns).
    h, w = a.shape[0], a.shape[1]
    delta_x = -delta_x
    delta_y = -delta_y
    if isinstance(a, np.ndarray):
        b = np.zeros([h, w]).astype(np.uint8)
    elif isinstance(a, torch.Tensor):
        b = torch.zeros([h, w]).to(torch.uint8)
    if delta_x > 0:
        left_a = delta_x
        right_a = w
        left_b = 0
        right_b = w - delta_x
    else:
        left_a = 0
        right_a = w + delta_x
        left_b = -delta_x
        right_b = w
    if delta_y > 0:
        top_a = delta_y
        bot_a = h
        top_b = 0
        bot_b = h - delta_y
    else:
        top_a = 0
        bot_a = h + delta_y
        top_b = -delta_y
        bot_b = h
    b[left_b:right_b, top_b:bot_b] = a[left_a:right_a, top_a:bot_a]
    return b


def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision=None, subfolder: str = "text_encoder"
):
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel
        return CLIPTextModel
    elif model_class == "CLIPTextModelWithProjection":
        from transformers import CLIPTextModelWithProjection
        return CLIPTextModelWithProjection
    else:
        raise ValueError(f"{model_class} is not supported.")

@torch.no_grad()
def image2latent(image, vae=None, dtype=None):
    # PIL inputs (including PngImageFile / JpegImageFile) are converted to numpy first
    if isinstance(image, PIL.Image.Image):
        image = np.array(image)
    if type(image) is torch.Tensor and image.dim() == 4:
        latents = image
    else:
        image = torch.from_numpy(image).float() / 127.5 - 1
        image = image.permute(2, 0, 1).unsqueeze(0).to(device, dtype=dtype)
        latents = vae.encode(image).latent_dist.sample()
        latents = latents * vae.config.scaling_factor
    return latents


@torch.no_grad()
def latent2image(latents, return_type='np', vae=None):
    # needs_upcasting = vae.dtype == torch.float16 and vae.config.force_upcast
    needs_upcasting = True
    if needs_upcasting:
        upcast_vae(vae)
        latents = latents.to(next(iter(vae.post_quant_conv.parameters())).dtype)
    image = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]

    if return_type == 'np':
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        image = (image * 255).astype(np.uint8)
    if needs_upcasting:
        vae.to(dtype=torch.float16)
    return image


def upcast_vae(vae):
    dtype = vae.dtype
    vae.to(dtype=torch.float32)
    use_torch_2_0_or_xformers = isinstance(
        vae.decoder.mid_block.attentions[0].processor,
        (
            AttnProcessor2_0,
            XFormersAttnProcessor,
            LoRAXFormersAttnProcessor,
            LoRAAttnProcessor2_0,
        ),
    )
    # if xformers or torch_2_0 is used, the attention block does not need
    # to be in float32, which can save lots of memory
    if use_torch_2_0_or_xformers:
        vae.post_quant_conv.to(dtype)
        vae.decoder.conv_in.to(dtype)
        vae.decoder.mid_block.to(dtype)


def prompt_to_emb_length_sdxl(prompt, tokenizer, text_encoder, length=None):
    text_input = tokenizer(
        [prompt],
        padding="max_length",
        max_length=length,
        truncation=True,
        return_tensors="pt",
    )
    prompt_embeds = text_encoder(text_input.input_ids.to(device), output_hidden_states=True)
    pooled_prompt_embeds = prompt_embeds[0]

    prompt_embeds = prompt_embeds.hidden_states[-2]
    bs_embed, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)

    return {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds}


def prompt_to_emb_length_sd(prompt, tokenizer, text_encoder, length=None):
    text_input = tokenizer(
        [prompt],
        padding="max_length",
        max_length=length,
        truncation=True,
        return_tensors="pt",
    )
    emb = text_encoder(text_input.input_ids.to(device))[0]
    return emb


def sdxl_prepare_input_decom(
    set_string_list,
    tokenizer,
    tokenizer_2,
    text_encoder_1,
    text_encoder_2,
    length=20,
    bsz=1,
    weight_dtype=torch.float32,
    resolution=1024,
    normal_token_id_list=[]
):
    # encode each segment prompt with both SDXL text encoders; prompts that
    # contain the special "#"/"$" tokens are padded to `length`, ordinary
    # prompts to the standard 77 tokens
    encoder_hidden_states_list = []
    pooled_prompt_embeds = 0

    for m_idx in range(len(set_string_list)):
        prompt_embeds_list = []
        if ("#" in set_string_list[m_idx] or "$" in set_string_list[m_idx]) and m_idx not in normal_token_id_list:
            out = prompt_to_emb_length_sdxl(
                set_string_list[m_idx], tokenizer, text_encoder_1, length=length
            )
        else:
            out = prompt_to_emb_length_sdxl(
                set_string_list[m_idx], tokenizer, text_encoder_1, length=77
            )
        print(m_idx, set_string_list[m_idx])
        prompt_embeds, _ = out["prompt_embeds"].to(dtype=weight_dtype), out["pooled_prompt_embeds"].to(dtype=weight_dtype)
        prompt_embeds = prompt_embeds.repeat(bsz, 1, 1)
        prompt_embeds_list.append(prompt_embeds)
        if ("#" in set_string_list[m_idx] or "$" in set_string_list[m_idx]) and m_idx not in normal_token_id_list:
            out = prompt_to_emb_length_sdxl(
                set_string_list[m_idx], tokenizer_2, text_encoder_2, length=length
            )
        else:
            out = prompt_to_emb_length_sdxl(
                set_string_list[m_idx], tokenizer_2, text_encoder_2, length=77
            )
        print(m_idx, set_string_list[m_idx])

        prompt_embeds = out["prompt_embeds"].to(dtype=weight_dtype)
        pooled_prompt_embeds += out["pooled_prompt_embeds"].to(dtype=weight_dtype)
        prompt_embeds = prompt_embeds.repeat(bsz, 1, 1)
        prompt_embeds_list.append(prompt_embeds)

        encoder_hidden_states_list.append(torch.concat(prompt_embeds_list, dim=-1))

    # average the pooled embeddings over all segment prompts
    add_text_embeds = pooled_prompt_embeds / len(set_string_list)
    target_size, original_size, crops_coords_top_left = (resolution, resolution), (resolution, resolution), (0, 0)
    add_time_ids = list(original_size + crops_coords_top_left + target_size)

    add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype, device=pooled_prompt_embeds.device)  # [B, 6]
    return encoder_hidden_states_list, add_text_embeds, add_time_ids


def sd_prepare_input_decom(
    set_string_list,
    tokenizer,
    text_encoder_1,
    length=20,
    bsz=1,
    weight_dtype=torch.float32,
    normal_token_id_list=[]
):
    encoder_hidden_states_list = []
    for m_idx in range(len(set_string_list)):
        if ("#" in set_string_list[m_idx] or "$" in set_string_list[m_idx]) and m_idx not in normal_token_id_list:
            encoder_hidden_states = prompt_to_emb_length_sd(
                set_string_list[m_idx], tokenizer, text_encoder_1, length=length
            )
        else:
            encoder_hidden_states = prompt_to_emb_length_sd(
                set_string_list[m_idx], tokenizer, text_encoder_1, length=77
            )
        print(m_idx, set_string_list[m_idx])
        encoder_hidden_states = encoder_hidden_states.repeat(bsz, 1, 1)
        encoder_hidden_states_list.append(encoder_hidden_states.to(dtype=weight_dtype))
    return encoder_hidden_states_list


def load_mask(input_folder):
    # load masks saved as mask{idx}_{label}.npy (skipping edited masks), sorted by index
    np_mask_dtype = 'uint8'
    mask_np_list = []
    mask_label_list = []
    files = [
        file_name for file_name in os.listdir(input_folder)
        if "mask" in file_name and ".npy" in file_name
        and "_" in file_name and "Edited" not in file_name
    ]
    files = sorted(files, key=lambda x: int(x.split("_")[0][4:]))

    for idx, file_name in enumerate(files):
        mask_np = np.load(os.path.join(input_folder, file_name)).astype(np_mask_dtype)
        mask_np_list.append(mask_np)
        mask_label = file_name.split("_")[1][:-4]
        mask_label_list.append(mask_label)
    mask_list = []
    for mask_np in mask_np_list:
        mask = torch.from_numpy(mask_np)
        mask_list.append(mask)
    try:
        # the masks must partition the image: every pixel belongs to exactly one mask
        assert torch.all(sum(mask_list) == 1)
    except:
        print("please check mask")
        # plt.imsave("out_mask.png", mask_list_edit[0])
        import pdb; pdb.set_trace()
    return mask_list, mask_label_list


def load_image(image_path, left=0, right=0, top=0, bottom=0, size=512):
    if type(image_path) is str:
        image = np.array(Image.open(image_path))[:, :, :3]
    else:
        image = image_path
    h, w, c = image.shape
    left = min(left, w - 1)
    right = min(right, w - left - 1)
    top = min(top, h - 1)
    bottom = min(bottom, h - top - 1)
    image = image[top:h - bottom, left:w - right]
    h, w, c = image.shape
    if h < w:
        offset = (w - h) // 2
        image = image[:, offset:offset + h]
    elif w < h:
        offset = (h - w) // 2
        image = image[offset:offset + w]
    image = np.array(Image.fromarray(image).resize((size, size)))
    return image


def mask_union_torch(*masks):
    masks = [m.to(torch.float) for m in masks]
    res = sum(masks) > 0
    return res


def load_mask_edit(input_folder):
    # same as load_mask, but reads the user-edited masks (maskEdited{idx}_{label}.npy)
    np_mask_dtype = 'uint8'
    mask_np_list = []
    mask_label_list = []

    files = [
        file_name for file_name in os.listdir(input_folder)
        if "mask" in file_name and ".npy" in file_name and "_" in file_name
        and "Edited" in file_name and "-1" not in file_name
    ]
    files = sorted(files, key=lambda x: int(x.split("_")[0][10:]))

    for idx, file_name in enumerate(files):
        mask_np = np.load(os.path.join(input_folder, file_name)).astype(np_mask_dtype)
        mask_np_list.append(mask_np)
        mask_label = file_name.split("_")[1][:-4]
        # mask_label = mask_label.split("-")[0]
        mask_label_list.append(mask_label)
    mask_list = []
    for mask_np in mask_np_list:
        mask = torch.from_numpy(mask_np)
        mask_list.append(mask)
    try:
        assert torch.all(sum(mask_list) == 1)
    except:
        print("Make sure maskEdited is in the folder; if not, generate it using the UI")
        import pdb; pdb.set_trace()
    return mask_list, mask_label_list


def save_images(images, filename, num_rows=1, offset_ratio=0.02):
    if type(images) is list:
        num_empty = len(images) % num_rows
    elif images.ndim == 4:
        num_empty = images.shape[0] % num_rows
    else:
        images = [images]
        num_empty = 0

    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
    images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
    num_items = len(images)

    # each image is saved individually as <name>_<i>.<ext>
    folder = os.path.dirname(filename)
    for i, image in enumerate(images):
        pil_img = Image.fromarray(image)
        name = filename.split("/")[-1]
        name = name.split(".")[-2] + "_{}".format(i) + "." + filename.split(".")[-1]
        pil_img.save(os.path.join(folder, name))
        print("saved to ", os.path.join(folder, name))
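
For reference, a minimal sketch of what `myroll2d` does (illustration values only): unlike `np.roll`, pixels shifted past the border are dropped and the vacated area is zero-filled, and the first offset moves the row axis while the second moves the column axis, which is why `move_mask` in utils_mask.py passes (delta_y, delta_x):

import numpy as np
from utils import myroll2d

a = np.zeros((4, 4), dtype=np.uint8)
a[0, 0] = 1
b = myroll2d(a, 2, 1)   # shift rows by +2, columns by +1
assert b[2, 1] == 1 and b.sum() == 1
c = myroll2d(a, -1, 0)  # the pixel falls off the top edge and is dropped
assert c.sum() == 0
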
utils_mask.py
ADDED
@@ -0,0 +1,296 @@
import os
import numpy as np
from matplotlib import cm
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import torch
from utils import myroll2d

def create_outer_edge_mask_torch(mask, edge_thickness=20):
    # shift the mask in all eight directions and collect the newly covered
    # pixels: the band of width `edge_thickness` just outside the mask
    mask_down = myroll2d(mask, edge_thickness, 0)
    mask_edge_down = (mask_down.to(torch.float) - mask.to(torch.float)) > 0

    mask_up = myroll2d(mask, -edge_thickness, 0)
    mask_edge_up = (mask_up.to(torch.float) - mask.to(torch.float)) > 0

    mask_left = myroll2d(mask, 0, -edge_thickness)
    mask_edge_left = (mask_left.to(torch.float) - mask.to(torch.float)) > 0

    mask_right = myroll2d(mask, 0, edge_thickness)
    mask_edge_right = (mask_right.to(torch.float) - mask.to(torch.float)) > 0

    mask_ur = myroll2d(mask, -edge_thickness, edge_thickness)
    mask_edge_ur = (mask_ur.to(torch.float) - mask.to(torch.float)) > 0

    mask_ul = myroll2d(mask, -edge_thickness, -edge_thickness)
    mask_edge_ul = (mask_ul.to(torch.float) - mask.to(torch.float)) > 0

    mask_dr = myroll2d(mask, edge_thickness, edge_thickness)
    mask_edge_dr = (mask_dr.to(torch.float) - mask.to(torch.float)) > 0

    mask_dl = myroll2d(mask, edge_thickness, -edge_thickness)
    mask_edge_dl = (mask_dl.to(torch.float) - mask.to(torch.float)) > 0

    mask_edge = mask_union_torch(mask_edge_down, mask_edge_up, mask_edge_left, mask_edge_right,
                                 mask_edge_ur, mask_edge_ul, mask_edge_dr, mask_edge_dl)
    return mask_edge


def mask_substract_torch(mask1, mask2):
    return ((mask1.cpu().to(torch.float) - mask2.cpu().to(torch.float)) > 0).to(torch.uint8)

def check_mask_overlap_torch(*masks):
    # no pixel may belong to more than one mask
    assert torch.all(sum([m.float() for m in masks]) <= 1)

def check_mask_overlap_numpy(*masks):
    assert np.all(sum([m.astype(float) for m in masks]) <= 1)

def check_cover_all_torch(*masks):
    assert torch.all(sum([m.cpu().float() for m in masks]) == 1)

def process_mask_to_follow_priority(mask_list, priority_list):
    # every higher-priority mask carves its area out of each lower-priority mask
    for idx1, (m1, p1) in enumerate(zip(mask_list, priority_list)):
        for idx2, (m2, p2) in enumerate(zip(mask_list, priority_list)):
            if p2 > p1:
                mask_list[idx1] = ((mask_list[idx1].astype(float) - m2.astype(float)) > 0).astype(np.uint8)
    return mask_list

def mask_union(*masks):
    masks = [m.astype(float) for m in masks]
    res = sum(masks) > 0
    return res.astype(np.uint8)

def mask_intersection(mask1, mask2):
    # pixels where the two masks agree, restricted to their union
    mask_uni = mask_union(mask1, mask2)
    mask_intersec = ((mask1.astype(float) - mask2.astype(float)) == 0) * mask_uni
    return mask_intersec

def mask_union_torch(*masks):
    masks = [m.float() for m in masks]
    res = sum(masks) > 0
    return res.to(torch.uint8)

def mask_intersection_torch(mask1, mask2):
    mask_uni = mask_union_torch(mask1, mask2)
    mask_intersec = ((mask1.float() - mask2.float()) == 0) * mask_uni
    return mask_intersec.cpu().to(torch.uint8)


def visualize_mask_list(mask_list, savepath):
    # render the mask partition as a single labeled index map
    mask = 0
    for midx, m in enumerate(mask_list):
        try:
            mask += m.astype(float) * midx
        except AttributeError:  # torch tensors have no .astype
            mask += m.float() * midx
    viridis = cm.get_cmap('viridis', len(mask_list))
    fig, ax = plt.subplots()
    ax.imshow(mask)

    handles = []
    label_list = []
    for idx, _ in enumerate(mask_list):
        color = viridis(idx)
        label = f"{idx}"
        handles.append(mpatches.Patch(color=color, label=label))
        label_list.append(label)
    ax.legend(handles=handles)
    plt.savefig(savepath)

def visualize_mask_list_clean(mask_list, savepath):
    # same as visualize_mask_list, but without the legend and at higher dpi
    mask = 0
    for midx, m in enumerate(mask_list):
        try:
            mask += m.astype(float) * midx
        except AttributeError:
            mask += m.float() * midx
    viridis = cm.get_cmap('viridis', len(mask_list))
    fig, ax = plt.subplots()
    ax.imshow(mask)

    handles = []
    label_list = []
    for idx, _ in enumerate(mask_list):
        color = viridis(idx)
        label = f"{idx}"
        handles.append(mpatches.Patch(color=color, label=label))
        label_list.append(label)
    # ax.legend(handles=handles)
    plt.savefig(savepath, dpi=500)


def move_mask(mask_select, delta_x, delta_y):
    # myroll2d shifts (rows, cols), so pass (delta_y, delta_x)
    mask_edit = myroll2d(mask_select, delta_y, delta_x)
    return mask_edit

def stack_mask_with_priority(mask_list_np, priority_list, edit_idx_list):
    # the union of the edited masks overwrites every non-edited mask of equal
    # or lower priority; among the edited masks themselves, higher priority wins
    mask_sel = mask_union(*[mask_list_np[eid] for eid in edit_idx_list])
    for midx, mask in enumerate(mask_list_np):
        if midx not in edit_idx_list:
            if priority_list[edit_idx_list[0]] >= priority_list[midx]:
                mask = mask.astype(float) - np.logical_and(mask.astype(bool), mask_sel.astype(bool)).astype(float)
                mask_list_np[midx] = mask.astype("uint8")
    for midx in edit_idx_list:
        for midx_1 in edit_idx_list:
            if midx != midx_1:
                if priority_list[midx] <= priority_list[midx_1]:
                    mask = mask_list_np[midx].astype(float) - np.logical_and(mask_list_np[midx].astype(bool), mask_list_np[midx_1].astype(bool)).astype(float)
                    mask_list_np[midx] = mask.astype("uint8")
    return mask_list_np


def process_remain_mask(mask_list, edit_idx_list=None, force_mask_remain=None):
    # pixels left uncovered after an edit are either forced into one mask or
    # reassigned to the region whose feasible edge is nearest
    print("Start to process remaining mask using nearest neighbor")
    width = mask_list[0].shape[0]
    height = mask_list[0].shape[1]
    pixel_ind = np.arange(width * height)

    y_axis = np.arange(width)
    ymesh = np.repeat(y_axis[:, np.newaxis], height, axis=1)  # (N, N)
    ymesh_vec = ymesh.reshape(-1)  # N * N

    x_axis = np.arange(height)
    xmesh = np.repeat(x_axis[np.newaxis, :], width, axis=0)
    xmesh_vec = xmesh.reshape(-1)

    mask_remain = (1 - sum([m.astype(float) for m in mask_list])).astype(np.uint8)
    if force_mask_remain is not None:
        mask_list[force_mask_remain] = (mask_list[force_mask_remain].astype(float) + mask_remain.astype(float)).astype(np.uint8)
    else:
        if edit_idx_list is not None:
            a = [mask_list[eidx] for eidx in edit_idx_list]
            mask_edit = mask_union(*a)
        else:
            mask_edit = np.zeros_like(mask_remain).astype(np.uint8)
        # feasible regions: everything except the uncovered pixels and the edited masks
        mask_feasible = (1 - mask_remain.astype(float) - mask_edit.astype(float)).astype(np.uint8)

        edge_width = 2

        # shift the feasible mask in all eight directions; pixels that lose
        # coverage mark the inner edge of the feasible regions
        mask_feasible_down = myroll2d(mask_feasible, edge_width, 0)
        mask_edge_down = (mask_feasible_down.astype(float) - mask_feasible.astype(float)) < 0

        mask_feasible_up = myroll2d(mask_feasible, -edge_width, 0)
        mask_edge_up = (mask_feasible_up.astype(float) - mask_feasible.astype(float)) < 0

        mask_feasible_left = myroll2d(mask_feasible, 0, -edge_width)
        mask_edge_left = (mask_feasible_left.astype(float) - mask_feasible.astype(float)) < 0

        mask_feasible_right = myroll2d(mask_feasible, 0, edge_width)
        mask_edge_right = (mask_feasible_right.astype(float) - mask_feasible.astype(float)) < 0

        mask_feasible_ur = myroll2d(mask_feasible, -edge_width, edge_width)
        mask_edge_ur = (mask_feasible_ur.astype(float) - mask_feasible.astype(float)) < 0

        mask_feasible_ul = myroll2d(mask_feasible, -edge_width, -edge_width)
        mask_edge_ul = (mask_feasible_ul.astype(float) - mask_feasible.astype(float)) < 0

        mask_feasible_dr = myroll2d(mask_feasible, edge_width, edge_width)
        mask_edge_dr = (mask_feasible_dr.astype(float) - mask_feasible.astype(float)) < 0

        mask_feasible_dl = myroll2d(mask_feasible, edge_width, -edge_width)
        mask_edge_dl = (mask_feasible_dl.astype(float) - mask_feasible.astype(float)) < 0

        mask_edge = mask_union(
            mask_edge_down, mask_edge_up, mask_edge_left, mask_edge_right, mask_edge_ur, mask_edge_ul, mask_edge_dr, mask_edge_dl
        )

        mask_feasible_edge = mask_intersection(mask_edge, mask_feasible)

        vec_mask_feasible_edge = mask_feasible_edge.reshape(-1)
        vec_mask_remain = mask_remain.reshape(-1)

        indvec_all = np.arange(width * height)
        vec_region_partition = 0
        for mask_idx, mask in enumerate(mask_list):
            vec_region_partition += mask.reshape(-1) * mask_idx
        # placeholder id for uncovered pixels; never queried, since nearest
        # points always lie on the feasible edge
        vec_region_partition += mask_remain.reshape(-1) * mask_idx
        # assert 0 in vec_region_partition

        vec_ind_remain = np.nonzero(vec_mask_remain)[0]
        vec_ind_feasible_edge = np.nonzero(vec_mask_feasible_edge)[0]

        vec_x_remain = xmesh_vec[vec_ind_remain]
        vec_y_remain = ymesh_vec[vec_ind_remain]

        vec_x_feasible_edge = xmesh_vec[vec_ind_feasible_edge]
        vec_y_feasible_edge = ymesh_vec[vec_ind_feasible_edge]

        # squared distance from every uncovered pixel to every feasible edge pixel
        x_dis = vec_x_remain[:, np.newaxis] - vec_x_feasible_edge[np.newaxis, :]
        y_dis = vec_y_remain[:, np.newaxis] - vec_y_feasible_edge[np.newaxis, :]
        dis = x_dis ** 2 + y_dis ** 2
        pos = np.argmin(dis, axis=1)
        nearest_point = vec_ind_feasible_edge[pos]  # closest feasible edge pixel for each uncovered pixel

        nearest_region = vec_region_partition[nearest_point]
        nearest_region_set = set(nearest_region)
        if edit_idx_list is not None:
            for edit_idx in edit_idx_list:
                assert edit_idx not in nearest_region

        for midx, m in enumerate(mask_list):
            if midx in nearest_region_set:
                vec_newmask = np.zeros_like(indvec_all)
                add_ind = vec_ind_remain[np.argwhere(nearest_region == midx)]
                vec_newmask[add_ind] = 1

                mask_list[midx] = mask_list[midx].astype(float) + vec_newmask.reshape(mask_list[midx].shape).astype(float)
                mask_list[midx] = mask_list[midx] > 0

    print("Finish processing remaining mask; if you want to edit, launch the UI")
    return mask_list, mask_remain


def resize_mask(mask_np, resize_ratio=1):
    # rescale a mask by resize_ratio while keeping the original canvas size:
    # a shrunk mask is pasted at the top-left, an enlarged one is center-cropped
    w, h = mask_np.shape[0], mask_np.shape[1]
    resized_w, resized_h = int(w * resize_ratio), int(h * resize_ratio)
    mask_resized = torch.nn.functional.interpolate(torch.from_numpy(mask_np).unsqueeze(0).unsqueeze(0), (resized_w, resized_h)).squeeze()

    mask = torch.zeros(w, h)
    if w > resized_w:
        mask[:resized_w, :resized_h] = mask_resized
    else:
        assert h <= resized_h
        mask = mask_resized[resized_w // 2 - w // 2: resized_w // 2 - w // 2 + w, resized_h // 2 - h // 2: resized_h // 2 - h // 2 + h]
    return mask.cpu().numpy().astype(np.uint8)


def process_mask_move_torch(
    mask_list,
    move_index_list,
    delta_x_list=None,
    delta_y_list=None,
    edit_priority_list=None,
    force_mask_remain=None,
    resize_list=None
):
    # move (and optionally resize) the selected masks, resolve overlaps by
    # priority, then fill the uncovered pixels via process_remain_mask
    mask_list_np = [m.cpu().numpy() for m in mask_list]
    priority_list = [0 for _ in range(len(mask_list_np))]
    for idx, (move_index, delta_x, delta_y, priority) in enumerate(zip(move_index_list, delta_x_list, delta_y_list, edit_priority_list)):
        priority_list[move_index] = priority
        if resize_list is not None:
            mask = resize_mask(mask_list_np[move_index], resize_list[idx])
        else:
            mask = mask_list_np[move_index]
        mask_list_np[move_index] = move_mask(mask, delta_x=delta_x, delta_y=delta_y)
    mask_list_np = stack_mask_with_priority(mask_list_np, priority_list, move_index_list)  # blank pixels may remain
    check_mask_overlap_numpy(*mask_list_np)
    mask_list_np, mask_remain = process_remain_mask(mask_list_np, move_index_list, force_mask_remain)
    mask_list = [torch.from_numpy(m).to(dtype=torch.uint8) for m in mask_list_np]
    mask_remain = torch.from_numpy(mask_remain).to(dtype=torch.uint8)
    return mask_list, mask_remain

def process_mask_remove_torch(mask_list, remove_idx):
    # blank out one mask and let the neighboring regions grow into the freed area
    mask_list_np = [m.cpu().numpy() for m in mask_list]
    mask_list_np[remove_idx] = np.zeros_like(mask_list_np[0])
    mask_list_np, mask_remain = process_remain_mask(mask_list_np)
    mask_list = [torch.from_numpy(m).to(dtype=torch.uint8) for m in mask_list_np]
    mask_remain = torch.from_numpy(mask_remain).to(dtype=torch.uint8)
    return mask_list, mask_remain

def get_mask_difference_torch(mask_list1, mask_list2):
    assert len(mask_list1) == len(mask_list2)
    mask_diff = torch.zeros_like(mask_list1[0])
    for mask1, mask2 in zip(mask_list1, mask_list2):
        diff = ((mask1.float() - mask2.float()) != 0).to(torch.uint8)
        mask_diff = mask_union_torch(mask_diff, diff)
    return mask_diff

def save_mask_list_to_npys(folder, mask_list, mask_label_list, name="mask"):
    for midx, (mask, mask_label) in enumerate(zip(mask_list, mask_label_list)):
        np.save(os.path.join(folder, "{}{}_{}.npy".format(name, midx, mask_label)), mask)
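
Taken together, a minimal sketch of how these helpers compose for a move edit; the folder name, offsets, and output path are illustrative only, and `load_mask` comes from utils.py above:

import torch
from utils import load_mask
from utils_mask import process_mask_move_torch, visualize_mask_list

mask_list, mask_label_list = load_mask("example1")  # hypothetical input folder
mask_list, mask_remain = process_mask_move_torch(
    mask_list,
    move_index_list=[1],     # move mask 1 ...
    delta_x_list=[60],       # ... 60 px right
    delta_y_list=[-20],      # ... and 20 px up
    edit_priority_list=[1],  # the moved mask wins ties against untouched masks
)
visualize_mask_list([m.numpy() for m in mask_list], "mask_after_move.png")
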