Yw22 committed on
Commit
6444ed9
·
0 Parent(s):
Files changed (36)
  1. .gitattributes +62 -0
  2. .gitignore +178 -0
  3. README.md +11 -0
  4. app/gpt4_o/brushedit_all_in_one_pipeline.py +80 -0
  5. app/gpt4_o/brushedit_app.py +914 -0
  6. app/gpt4_o/instructions.py +106 -0
  7. app/gpt4_o/requirements.txt +18 -0
  8. app/gpt4_o/run_app.sh +5 -0
  9. app/gpt4_o/vlm_pipeline.py +138 -0
  10. app/utils/utils.py +197 -0
  11. assets/hedgehog_rm_fg/hedgehog.png +3 -0
  12. assets/hedgehog_rm_fg/image_edit_82314e18-c64c-4003-9ef9-52cebf254b2f_2.png +3 -0
  13. assets/hedgehog_rm_fg/mask_82314e18-c64c-4003-9ef9-52cebf254b2f.png +3 -0
  14. assets/hedgehog_rm_fg/masked_image_82314e18-c64c-4003-9ef9-52cebf254b2f.png +3 -0
  15. assets/hedgehog_rm_fg/prompt.txt +1 -0
  16. assets/hedgehog_rp_bg/hedgehog.png +3 -0
  17. assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png +3 -0
  18. assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png +3 -0
  19. assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png +3 -0
  20. assets/hedgehog_rp_bg/prompt.txt +1 -0
  21. assets/hedgehog_rp_fg/hedgehog.png +3 -0
  22. assets/hedgehog_rp_fg/image_edit_5cab3448-5a3a-459c-9144-35cca3d34273_0.png +3 -0
  23. assets/hedgehog_rp_fg/mask_5cab3448-5a3a-459c-9144-35cca3d34273.png +3 -0
  24. assets/hedgehog_rp_fg/masked_image_5cab3448-5a3a-459c-9144-35cca3d34273.png +3 -0
  25. assets/hedgehog_rp_fg/prompt.txt +1 -0
  26. assets/mona_lisa/image_edit_aae09152-4495-4332-b691-a0c7bff524be_2.png +3 -0
  27. assets/mona_lisa/mask_aae09152-4495-4332-b691-a0c7bff524be.png +3 -0
  28. assets/mona_lisa/masked_image_aae09152-4495-4332-b691-a0c7bff524be.png +3 -0
  29. assets/mona_lisa/mona_lisa.png +3 -0
  30. assets/mona_lisa/prompt.txt +1 -0
  31. assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png +3 -0
  32. assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png +3 -0
  33. assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png +3 -0
  34. assets/sunflower_girl/prompt.txt +1 -0
  35. assets/sunflower_girl/sunflower_girl.png +3 -0
  36. requirements.txt +20 -0
.gitattributes ADDED
@@ -0,0 +1,62 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
37
+ *.jpg filter=lfs diff=lfs merge=lfs -text
38
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
39
+ *.webp filter=lfs diff=lfs merge=lfs -text
40
+ *.gif filter=lfs diff=lfs merge=lfs -text
41
+ *.bmp filter=lfs diff=lfs merge=lfs -text
42
+ *.tiff filter=lfs diff=lfs merge=lfs -text
43
+ assets/hedgehog_rm_fg/hedgehog.png filter=lfs diff=lfs merge=lfs -text
44
+ assets/hedgehog_rm_fg/image_edit_82314e18-c64c-4003-9ef9-52cebf254b2f_2.png filter=lfs diff=lfs merge=lfs -text
45
+ assets/hedgehog_rm_fg/mask_82314e18-c64c-4003-9ef9-52cebf254b2f.png filter=lfs diff=lfs merge=lfs -text
46
+ assets/hedgehog_rm_fg/masked_image_82314e18-c64c-4003-9ef9-52cebf254b2f.png filter=lfs diff=lfs merge=lfs -text
47
+ assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png filter=lfs diff=lfs merge=lfs -text
48
+ assets/hedgehog_rp_bg/hedgehog.png filter=lfs diff=lfs merge=lfs -text
49
+ assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png filter=lfs diff=lfs merge=lfs -text
50
+ assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png filter=lfs diff=lfs merge=lfs -text
51
+ assets/hedgehog_rp_fg/hedgehog.png filter=lfs diff=lfs merge=lfs -text
52
+ assets/hedgehog_rp_fg/image_edit_5cab3448-5a3a-459c-9144-35cca3d34273_0.png filter=lfs diff=lfs merge=lfs -text
53
+ assets/hedgehog_rp_fg/mask_5cab3448-5a3a-459c-9144-35cca3d34273.png filter=lfs diff=lfs merge=lfs -text
54
+ assets/hedgehog_rp_fg/masked_image_5cab3448-5a3a-459c-9144-35cca3d34273.png filter=lfs diff=lfs merge=lfs -text
55
+ assets/mona_lisa/image_edit_aae09152-4495-4332-b691-a0c7bff524be_2.png filter=lfs diff=lfs merge=lfs -text
56
+ assets/mona_lisa/mask_aae09152-4495-4332-b691-a0c7bff524be.png filter=lfs diff=lfs merge=lfs -text
57
+ assets/mona_lisa/masked_image_aae09152-4495-4332-b691-a0c7bff524be.png filter=lfs diff=lfs merge=lfs -text
58
+ assets/mona_lisa/mona_lisa.png filter=lfs diff=lfs merge=lfs -text
59
+ assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png filter=lfs diff=lfs merge=lfs -text
60
+ assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png filter=lfs diff=lfs merge=lfs -text
61
+ assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png filter=lfs diff=lfs merge=lfs -text
62
+ assets/sunflower_girl/sunflower_girl.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,178 @@
1
+ # Initially taken from GitHub's Python gitignore file
2
+
3
+ # Byte-compiled / optimized / DLL files
4
+ __pycache__/
5
+ *.py[cod]
6
+ *$py.class
7
+
8
+ # C extensions
9
+ *.so
10
+
11
+ # tests and logs
12
+ tests/fixtures/cached_*_text.txt
13
+ logs/
14
+ lightning_logs/
15
+ lang_code_data/
16
+
17
+ # Distribution / packaging
18
+ .Python
19
+ build/
20
+ develop-eggs/
21
+ dist/
22
+ downloads/
23
+ eggs/
24
+ .eggs/
25
+ lib/
26
+ lib64/
27
+ parts/
28
+ sdist/
29
+ var/
30
+ wheels/
31
+ *.egg-info/
32
+ .installed.cfg
33
+ *.egg
34
+ MANIFEST
35
+
36
+ # PyInstaller
37
+ # Usually these files are written by a Python script from a template
38
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
39
+ *.manifest
40
+ *.spec
41
+
42
+ # Installer logs
43
+ pip-log.txt
44
+ pip-delete-this-directory.txt
45
+
46
+ # Unit test / coverage reports
47
+ htmlcov/
48
+ .tox/
49
+ .nox/
50
+ .coverage
51
+ .coverage.*
52
+ .cache
53
+ nosetests.xml
54
+ coverage.xml
55
+ *.cover
56
+ .hypothesis/
57
+ .pytest_cache/
58
+
59
+ # Translations
60
+ *.mo
61
+ *.pot
62
+
63
+ # Django stuff:
64
+ *.log
65
+ local_settings.py
66
+ db.sqlite3
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+
78
+ # PyBuilder
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ .python-version
90
+
91
+ # celery beat schedule file
92
+ celerybeat-schedule
93
+
94
+ # SageMath parsed files
95
+ *.sage.py
96
+
97
+ # Environments
98
+ .env
99
+ .venv
100
+ env/
101
+ venv/
102
+ ENV/
103
+ env.bak/
104
+ venv.bak/
105
+
106
+ # Spyder project settings
107
+ .spyderproject
108
+ .spyproject
109
+
110
+ # Rope project settings
111
+ .ropeproject
112
+
113
+ # mkdocs documentation
114
+ /site
115
+
116
+ # mypy
117
+ .mypy_cache/
118
+ .dmypy.json
119
+ dmypy.json
120
+
121
+ # Pyre type checker
122
+ .pyre/
123
+
124
+ # vscode
125
+ .vs
126
+ .vscode
127
+
128
+ # Pycharm
129
+ .idea
130
+
131
+ # TF code
132
+ tensorflow_code
133
+
134
+ # Models
135
+ proc_data
136
+
137
+ # examples
138
+ runs
139
+ /runs_old
140
+ /wandb
141
+ /examples/runs
142
+ /examples/**/*.args
143
+ /examples/rag/sweep
144
+
145
+ # data
146
+ /data
147
+ serialization_dir
148
+
149
+ # emacs
150
+ *.*~
151
+ debug.env
152
+
153
+ # vim
154
+ .*.swp
155
+
156
+ # ctags
157
+ tags
158
+
159
+ # pre-commit
160
+ .pre-commit*
161
+
162
+ # .lock
163
+ *.lock
164
+
165
+ # DS_Store (MacOS)
166
+ .DS_Store
167
+
168
+ # RL pipelines may produce mp4 outputs
169
+ *.mp4
170
+
171
+ # dependencies
172
+ /transformers
173
+
174
+ # ruff
175
+ .ruff_cache
176
+
177
+ # wandb
178
+ wandb
README.md ADDED
@@ -0,0 +1,11 @@
1
+ ---
2
+ title: BrushEdit
3
+ emoji: 🤠
4
+ colorFrom: indigo
5
+ colorTo: gray
6
+ sdk: gradio
7
+ sdk_version: 4.38.1
8
+ app_file: app/gpt4_o/brushedit_app.py
9
+ pinned: false
10
+ python_version: 3.10
11
+ ---
app/gpt4_o/brushedit_all_in_one_pipeline.py ADDED
@@ -0,0 +1,80 @@
1
+ from PIL import Image, ImageEnhance
2
+ from diffusers.image_processor import VaeImageProcessor
3
+
4
+ import numpy as np
5
+ import cv2
6
+
7
+
8
+
9
+ def BrushEdit_Pipeline(pipe,
10
+ prompts,
11
+ mask_np,
12
+ original_image,
13
+ generator,
14
+ num_inference_steps,
15
+ guidance_scale,
16
+ control_strength,
17
+ negative_prompt,
18
+ num_samples,
19
+ blending):
20
+ if mask_np.ndim != 3:
21
+ mask_np = mask_np[:, :, np.newaxis]
22
+
23
+ mask_np = mask_np / 255
24
+ height, width = mask_np.shape[0], mask_np.shape[1]
25
+ # back/foreground
26
+ # if mask_np[94:547,94:546].sum() < mask_np.sum() - mask_np[94:547,94:546].sum() and mask_np[0,:].sum()>0 and mask_np[-1,:].sum()>0 and mask_np[:,0].sum()>0 and mask_np[:,-1].sum()>0 and mask_np[1,:].sum()>0 and mask_np[-2,:].sum()>0 and mask_np[:,1].sum()>0 and mask_np[:,-2].sum()>0 :
27
+ # mask_np = 1 - mask_np
28
+
29
+ ## resize the mask and the original image to a common size that is divisible by vae_scale_factor
30
+ image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
31
+ height_new, width_new = image_processor.get_default_height_width(original_image, height, width)
32
+ mask_np = cv2.resize(mask_np, (width_new, height_new))[:,:,np.newaxis]
33
+ mask_blurred = cv2.GaussianBlur(mask_np*255, (21, 21), 0)/255
34
+ mask_blurred = mask_blurred[:, :, np.newaxis]
35
+
36
+ original_image = cv2.resize(original_image, (width_new, height_new))
37
+
38
+ init_image = original_image * (1 - mask_np)
39
+ init_image = Image.fromarray(init_image.astype(np.uint8)).convert("RGB")
40
+ mask_image = Image.fromarray((mask_np.repeat(3, -1) * 255).astype(np.uint8)).convert("RGB")
41
+
42
+ brushnet_conditioning_scale = float(control_strength)
43
+
44
+ images = pipe(
45
+ [prompts] * num_samples,
46
+ init_image,
47
+ mask_image,
48
+ num_inference_steps=num_inference_steps,
49
+ guidance_scale=guidance_scale,
50
+ generator=generator,
51
+ brushnet_conditioning_scale=brushnet_conditioning_scale,
52
+ negative_prompt=[negative_prompt]*num_samples,
53
+ height=height_new,
54
+ width=width_new,
55
+ ).images
56
+
57
+ if blending:
58
+
59
+ mask_blurred = mask_blurred * 0.5 + 0.5
60
+
61
+ ## convert to the VAE input format; dimensions must be divisible by 8
62
+ original_image_pil = Image.fromarray(original_image).convert("RGB")
63
+ init_image_np = np.array(image_processor.preprocess(original_image_pil, height=height_new, width=width_new).squeeze())
64
+ init_image_np = ((init_image_np.transpose(1,2,0) + 1.) / 2.) * 255
65
+ init_image_np = init_image_np.astype(np.uint8)
66
+ image_all = []
67
+ for image_i in images:
68
+ image_np = np.array(image_i)
69
+ ## blending
70
+ image_pasted = init_image_np * (1 - mask_blurred) + mask_blurred * image_np
71
+ image_pasted = image_pasted.astype(np.uint8)
72
+ image = Image.fromarray(image_pasted)
73
+ image_all.append(image)
74
+ else:
75
+ image_all = images
76
+
77
+
78
+ return image_all, mask_image
79
+
80
+
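For reference, here is a minimal sketch of calling BrushEdit_Pipeline directly, outside the Gradio app. It assumes pipe is a StableDiffusionBrushNetPipeline already loaded the same way as in app/gpt4_o/brushedit_app.py; the image, mask, and target prompt below are only illustrative placeholders taken from the bundled assets.

import numpy as np
import torch
from PIL import Image

from app.gpt4_o.brushedit_all_in_one_pipeline import BrushEdit_Pipeline

# pipe is assumed to be an already-loaded StableDiffusionBrushNetPipeline (see brushedit_app.py)
original_image = np.array(Image.open("assets/mona_lisa/mona_lisa.png").convert("RGB"))
mask_np = np.array(Image.open("assets/mona_lisa/mask_aae09152-4495-4332-b691-a0c7bff524be.png").convert("L"))  # H x W, values 0/255

generator = torch.Generator("cuda").manual_seed(648464818)
images, mask_image = BrushEdit_Pipeline(
    pipe,
    prompts="a woman wearing a shining necklace",  # illustrative target prompt
    mask_np=mask_np,
    original_image=original_image,
    generator=generator,
    num_inference_steps=50,
    guidance_scale=7.5,
    control_strength=1.0,
    negative_prompt="ugly, low quality",
    num_samples=4,
    blending=True,
)
images[0].save("outputs/example_edit.png")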
app/gpt4_o/brushedit_app.py ADDED
@@ -0,0 +1,914 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import os, random
4
+ import numpy as np
5
+ import torch
6
+
7
+ import gradio as gr
8
+ import spaces
9
+
10
+ from PIL import Image
11
+
12
+
13
+ from huggingface_hub import hf_hub_download
14
+
15
+ from segment_anything import SamPredictor, build_sam, SamAutomaticMaskGenerator
16
+ from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
17
+ from scipy.ndimage import binary_dilation, binary_erosion
18
+
19
+ from app.gpt4_o.vlm_pipeline import (
20
+ vlm_response_editing_type,
21
+ vlm_response_object_wait_for_edit,
22
+ vlm_response_mask,
23
+ vlm_response_prompt_after_apply_instruction
24
+ )
25
+ from app.gpt4_o.brushedit_all_in_one_pipeline import BrushEdit_Pipeline
26
+ from app.utils.utils import load_grounding_dino_model
27
+
28
+
29
+ #### Description ####
30
+ head = r"""
31
+ <div style="text-align: center;">
32
+ <h1> BrushEdit: All-In-One Image Inpainting and Editing</h1>
33
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
34
+ <a href='https://tencentarc.github.io/BrushNet/'><img src='https://img.shields.io/badge/Project_Page-BrushNet-green' alt='Project Page'></a>
35
+ <a href='https://github.com/TencentARC/BrushNet/blob/main/InstructionGuidedEditing/CVPR2024workshop_technique_report.pdf'><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
36
+ <a href='https://github.com/TencentARC/BrushNet'><img src='https://img.shields.io/badge/Code-Github-orange'></a>
37
+
38
+ </div>
39
+ </br>
40
+ </div>
41
+ """
42
+ descriptions = r"""
43
+ Official Gradio Demo for <a href='https://tencentarc.github.io/BrushNet/'><b>BrushEdit: All-In-One Image Inpainting and Editing</b></a><br>
44
+ 🧙 BrushEdit enables precise, user-friendly instruction-based image editing via an inpainting model.<br>
45
+ """
46
+
47
+ instructions = r"""
48
+ Currently, we support two modes: <b>fully automated instruction-based editing</b> and <b>interactive instruction-based editing</b>.
49
+
50
+ 🛠️ <b>Fully automated instruction-based editing</b>:
51
+ <ul>
52
+ <li> ⭐️ <b>step1:</b> Upload or select one image from Example. </li>
53
+ <li> ⭐️ <b>step2:</b> Input the instruction (addition, deletion, and modification are supported), e.g., remove the xxx.</li>
54
+ <li> ⭐️ <b>step3:</b> Click the <b>Run</b> button to automatically edit the image.</li>
55
+ </ul>
56
+
57
+ 🛠️ <b>Interactive instruction-based editing</b>:
58
+ <ul>
59
+ <li> ⭐️ <b>step1:</b> Upload or select one image from Example. </li>
60
+ <li> ⭐️ <b>step2:</b> Use a brush to outline the area you want to edit. </li>
61
+ <li> ⭐️ <b>step3:</b> Input the instructions. </li>
62
+ <li> ⭐️ <b>step4:</b> Click the <b>Run</b> button to automatically edit the image. </li>
63
+ </ul>
64
+
65
+ 💡 <b>Some tips</b>:
66
+ <ul>
67
+ <li> 🤠 After entering the instruction, you can click the <b>Generate Mask</b> button. The mask generated by the VLM will be displayed in the preview panel on the right. </li>
68
+ <li> 🤠 After generating the mask (or drawing one with the brush), you can apply operations such as <b>randomization</b>, <b>dilation</b>, <b>erosion</b>, and <b>movement</b>. </li>
69
+ <li> 🤠 After entering the instruction, you can click the <b>Generate Target Prompt</b> button. The target prompt will be displayed in the text box, where you can modify it as needed. </li>
70
+ </ul>
71
+
72
+ ☕️ Have fun!
73
+ """
74
+
75
+
76
+ # - - - - - examples - - - - - #
77
+ EXAMPLES = [
78
+ # [
79
+ # {"background": Image.open("assets/mona_lisa/image_edit_aae09152-4495-4332-b691-a0c7bff524be_2.png").convert("RGBA"),
80
+ # "layers": [Image.new("RGBA", (Image.open("assets/mona_lisa/image_edit_aae09152-4495-4332-b691-a0c7bff524be_2.png").width, Image.open("assets/mona_lisa/image_edit_aae09152-4495-4332-b691-a0c7bff524be_2.png").height), (0, 0, 0, 0))],
81
+ # "composite": Image.open("assets/mona_lisa/image_edit_aae09152-4495-4332-b691-a0c7bff524be_2.png").convert("RGBA")},
82
+ # # Image.open("assets/mona_lisa/image_edit_aae09152-4495-4332-b691-a0c7bff524be_2.png").convert("RGBA"),
83
+ # "add a shining necklace",
84
+ # # [Image.open("assets/mona_lisa/image_edit_aae09152-4495-4332-b691-a0c7bff524be_2.jpg")],
85
+ # # [Image.open("assets/mona_lisa/mask_aae09152-4495-4332-b691-a0c7bff524be.png")],
86
+ # # [Image.open("assets/mona_lisa/masked_image_aae09152-4495-4332-b691-a0c7bff524be.png")]
87
+ # ],
88
+
89
+ [
90
+ # load_image_from_url("https://github.com/liyaowei-stu/BrushEdit/blob/main/assets/mona_lisa/mona_lisa.png"),
91
+ Image.open("assets/mona_lisa/mona_lisa.png").convert("RGBA"),
92
+ "add a shining necklace",
93
+ # [Image.open("assets/mona_lisa/image_edit_aae09152-4495-4332-b691-a0c7bff524be_2.jpg")],
94
+ # [Image.open("assets/mona_lisa/mask_aae09152-4495-4332-b691-a0c7bff524be.png")],
95
+ # [Image.open("assets/mona_lisa/masked_image_aae09152-4495-4332-b691-a0c7bff524be.png")]
96
+ ],
97
+
98
+
99
+
100
+
101
+ ]
102
+
103
+
104
+ ## init VLM
105
+ from openai import OpenAI
106
+
107
+ OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
108
+ os.environ['OPENAI_API_KEY'] = OPENAI_API_KEY
109
+ vlm = OpenAI(base_url="http://v2.open.venus.oa.com/llmproxy")
110
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
111
+
112
+
113
+
114
+ # download hf models
115
+ base_model_path = hf_hub_download(
116
+ repo_id="Yw22/BrushEdit",
117
+ subfolder="base_model/realisticVisionV60B1_v51VAE",
118
+ token=os.getenv("HF_TOKEN"),
119
+ )
120
+
121
+
122
+ brushnet_path = hf_hub_download(
123
+ repo_id="Yw22/BrushEdit",
124
+ subfolder="brushnetX",
125
+ token=os.getenv("HF_TOKEN"),
126
+ )
127
+
128
+ sam_path = hf_hub_download(
129
+ repo_id="Yw22/BrushEdit",
130
+ subfolder="sam",
131
+ filename="sam_vit_h_4b8939.pth",
132
+ token=os.getenv("HF_TOKEN"),
133
+ )
134
+
135
+ groundingdino_path = hf_hub_download(
136
+ repo_id="Yw22/BrushEdit",
137
+ subfolder="grounding_dino",
138
+ filename="groundingdino_swint_ogc.pth",
139
+ token=os.getenv("HF_TOKEN"),
140
+ )
141
+
142
+
143
+ # input brushnetX ckpt path
144
+ brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch.float16)
145
+ pipe = StableDiffusionBrushNetPipeline.from_pretrained(
146
+ base_model_path, brushnet=brushnet, torch_dtype=torch.float16, low_cpu_mem_usage=False
147
+ )
148
+ # speed up diffusion process with faster scheduler and memory optimization
149
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
150
+ # remove following line if xformers is not installed or when using Torch 2.0.
151
+ # pipe.enable_xformers_memory_efficient_attention()
152
+ pipe.enable_model_cpu_offload()
153
+
154
+
155
+ ## init SAM
156
+ sam = build_sam(checkpoint=sam_path)
157
+ sam.to(device=device)
158
+ sam_predictor = SamPredictor(sam)
159
+ sam_automask_generator = SamAutomaticMaskGenerator(sam)
160
+
161
+ ## init groundingdino_model
162
+ config_file = 'third_party/Grounded-Segment-Anything/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
163
+ groundingdino_model = load_grounding_dino_model(config_file, groundingdino_path, device=device)
164
+
165
+ ## Ordinary function
166
+ def crop_and_resize(image: Image.Image,
167
+ target_width: int,
168
+ target_height: int) -> Image.Image:
169
+ """
170
+ Crops and resizes an image while preserving the aspect ratio.
171
+
172
+ Args:
173
+ image (Image.Image): Input PIL image to be cropped and resized.
174
+ target_width (int): Target width of the output image.
175
+ target_height (int): Target height of the output image.
176
+
177
+ Returns:
178
+ Image.Image: Cropped and resized image.
179
+ """
180
+ # Original dimensions
181
+ original_width, original_height = image.size
182
+ original_aspect = original_width / original_height
183
+ target_aspect = target_width / target_height
184
+
185
+ # Calculate crop box to maintain aspect ratio
186
+ if original_aspect > target_aspect:
187
+ # Crop horizontally
188
+ new_width = int(original_height * target_aspect)
189
+ new_height = original_height
190
+ left = (original_width - new_width) / 2
191
+ top = 0
192
+ right = left + new_width
193
+ bottom = original_height
194
+ else:
195
+ # Crop vertically
196
+ new_width = original_width
197
+ new_height = int(original_width / target_aspect)
198
+ left = 0
199
+ top = (original_height - new_height) / 2
200
+ right = original_width
201
+ bottom = top + new_height
202
+
203
+ # Crop and resize
204
+ cropped_image = image.crop((left, top, right, bottom))
205
+ resized_image = cropped_image.resize((target_width, target_height), Image.NEAREST)
206
+
207
+ return resized_image
208
+
209
+
210
+ def move_mask_func(mask, direction, units):
211
+ binary_mask = mask.squeeze()>0
212
+ rows, cols = binary_mask.shape
213
+
214
+ moved_mask = np.zeros_like(binary_mask, dtype=bool)
215
+
216
+ if direction == 'down':
217
+ # move down
218
+ moved_mask[max(0, units):, :] = binary_mask[:rows - units, :]
219
+
220
+ elif direction == 'up':
221
+ # move up
222
+ moved_mask[:rows - units, :] = binary_mask[units:, :]
223
+
224
+ elif direction == 'right':
225
+ # move right
226
+ moved_mask[:, max(0, units):] = binary_mask[:, :cols - units]
227
+
228
+ elif direction == 'left':
229
+ # move left
230
+ moved_mask[:, :cols - units] = binary_mask[:, units:]
231
+
232
+ return moved_mask
233
+
234
+
235
+ def random_mask_func(mask, dilation_type='square'):
236
+ # Randomly select the size of dilation
237
+ dilation_size = np.random.randint(20, 40)
238
+ binary_mask = mask.squeeze()>0
239
+
240
+ if dilation_type == 'square_dilation':
241
+ structure = np.ones((dilation_size, dilation_size), dtype=bool)
242
+ dilated_mask = binary_dilation(binary_mask, structure=structure)
243
+ elif dilation_type == 'square_erosion':
244
+ structure = np.ones((dilation_size, dilation_size), dtype=bool)
245
+ dilated_mask = binary_erosion(binary_mask, structure=structure)
246
+ elif dilation_type == 'bounding_box':
247
+ # find the bounding box of the mask
248
+ rows, cols = np.where(binary_mask)
249
+ if len(rows) == 0 or len(cols) == 0:
250
+ return mask # return original mask if no valid points
251
+
252
+ min_row = np.min(rows)
253
+ max_row = np.max(rows)
254
+ min_col = np.min(cols)
255
+ max_col = np.max(cols)
256
+
257
+ # create a bounding box
258
+ dilated_mask = np.zeros_like(binary_mask, dtype=bool)
259
+ dilated_mask[min_row:max_row + 1, min_col:max_col + 1] = True
260
+
261
+ elif dilation_type == 'bounding_ellipse':
262
+ # find the bounding box of the mask
263
+ rows, cols = np.where(binary_mask)
264
+ if len(rows) == 0 or len(cols) == 0:
265
+ return mask # return original mask if no valid points
266
+
267
+ min_row = np.min(rows)
268
+ max_row = np.max(rows)
269
+ min_col = np.min(cols)
270
+ max_col = np.max(cols)
271
+
272
+ # calculate the center and axis length of the ellipse
273
+ center = ((min_col + max_col) // 2, (min_row + max_row) // 2)
274
+ a = (max_col - min_col) // 2 # semi-axis along x
275
+ b = (max_row - min_row) // 2 # semi-axis along y
276
+
277
+ # create a bounding ellipse
278
+ y, x = np.ogrid[:mask.shape[0], :mask.shape[1]]
279
+ ellipse_mask = ((x - center[0])**2 / a**2 + (y - center[1])**2 / b**2) <= 1
280
+ dilated_mask = np.zeros_like(binary_mask, dtype=bool)
281
+ dilated_mask[ellipse_mask] = True
282
+ else:
283
+ raise ValueError("dilation_type must be 'square_dilation', 'square_erosion', 'bounding_box', or 'bounding_ellipse'")
284
+
285
+ # convert the boolean mask to a uint8 mask in [0, 255]
286
+ dilated_mask = np.uint8(dilated_mask[:,:,np.newaxis]) * 255
287
+ return dilated_mask
288
+
289
+
290
+ ## Gradio component function
291
+ @spaces.GPU(duration=180)
292
+ def process(input_image,
293
+ original_image,
294
+ original_mask,
295
+ prompt,
296
+ negative_prompt,
297
+ control_strength,
298
+ seed,
299
+ randomize_seed,
300
+ guidance_scale,
301
+ num_inference_steps,
302
+ num_samples,
303
+ blending,
304
+ category,
305
+ target_prompt,
306
+ resize_and_crop):
307
+
308
+ # import ipdb; ipdb.set_trace()
309
+ if original_image is None:
310
+ raise gr.Error('Please upload the input image')
311
+ if prompt is None:
312
+ raise gr.Error("Please input your instructions, e.g., remove the xxx")
313
+
314
+
315
+ alpha_mask = input_image["layers"][0].split()[3]
316
+ input_mask = np.asarray(alpha_mask)
317
+ if resize_and_crop:
318
+ original_image = crop_and_resize(Image.fromarray(original_image), target_width=640, target_height=640)
319
+ input_mask = crop_and_resize(Image.fromarray(input_mask), target_width=640, target_height=640)
320
+ original_image = np.array(original_image)
321
+ input_mask = np.array(input_mask)
322
+
323
+ if input_mask.max() == 0:
324
+ original_mask = original_mask
325
+ else:
326
+ original_mask = input_mask[:,:,None]
327
+
328
+ # load example image
329
+ # if isinstance(original_image, str):
330
+ # # image_name = image_examples[original_image][0]
331
+ # # original_image = cv2.imread(image_name)
332
+ # # original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
333
+ # original_image = input_image
334
+ # num_samples = 1
335
+ # blending = True
336
+
337
+ if category is not None:
338
+ pass
339
+ else:
340
+ category = vlm_response_editing_type(vlm, original_image, prompt)
341
+
342
+
343
+ if original_mask is not None:
344
+ original_mask = np.clip(original_mask, 0, 255).astype(np.uint8)
345
+ else:
346
+ object_wait_for_edit = vlm_response_object_wait_for_edit(vlm,
347
+ category,
348
+ prompt)
349
+ original_mask = vlm_response_mask(vlm,
350
+ category,
351
+ original_image,
352
+ prompt,
353
+ object_wait_for_edit,
354
+ sam,
355
+ sam_predictor,
356
+ sam_automask_generator,
357
+ groundingdino_model,
358
+ )[:,:,None]
359
+
360
+
361
+ if len(target_prompt) <= 1:
362
+ prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(vlm,
363
+ original_image,
364
+ prompt)
365
+ else:
366
+ prompt_after_apply_instruction = target_prompt
367
+
368
+ generator = torch.Generator("cuda").manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)
369
+
370
+
371
+
372
+ image, mask_image = BrushEdit_Pipeline(pipe,
373
+ prompt_after_apply_instruction,
374
+ original_mask,
375
+ original_image,
376
+ generator,
377
+ num_inference_steps,
378
+ guidance_scale,
379
+ control_strength,
380
+ negative_prompt,
381
+ num_samples,
382
+ blending)
383
+
384
+ masked_image = original_image * (1 - (original_mask>0))
385
+ masked_image = masked_image.astype(np.uint8)
386
+ masked_image = Image.fromarray(masked_image)
387
+
388
+ import uuid
389
+ edit_id = str(uuid.uuid4())  # avoid shadowing the uuid module
390
+ os.makedirs("outputs", exist_ok=True)
391
+ # save every generated sample rather than assuming exactly four
392
+ for idx, img in enumerate(image):
393
+ img.save(f"outputs/image_edit_{edit_id}_{idx}.png")
394
+ mask_image.save(f"outputs/mask_{edit_id}.png")
395
+ masked_image.save(f"outputs/masked_image_{edit_id}.png")
396
+ return image, [mask_image], [masked_image], ''
397
+
398
+
399
+ def generate_target_prompt(input_image,
400
+ original_image,
401
+ prompt):
402
+ # load example image
403
+ if isinstance(original_image, str):
404
+ original_image = input_image
405
+
406
+ prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(vlm,
407
+ original_image,
408
+ prompt)
409
+ return prompt_after_apply_instruction
410
+
411
+
412
+ def process_mask(input_image,
413
+ original_image,
414
+ prompt,
415
+ resize_and_crop):
416
+ if original_image is None:
417
+ raise gr.Error('Please upload the input image')
418
+ if prompt is None:
419
+ raise gr.Error("Please input your instructions, e.g., remove the xxx")
420
+
421
+ ## load mask
422
+ alpha_mask = input_image["layers"][0].split()[3]
423
+ input_mask = np.array(alpha_mask)
424
+
425
+ # load example image
426
+ if isinstance(original_image, str):
427
+ original_image = input_image["background"]
428
+
429
+ if resize_and_crop:
430
+ original_image = crop_and_resize(Image.fromarray(original_image), target_width=640, target_height=640)
431
+ input_mask = crop_and_resize(Image.fromarray(input_mask), target_width=640, target_height=640)
432
+ original_image = np.array(original_image)
433
+ input_mask = np.array(input_mask)
434
+
435
+
436
+ if input_mask.max() == 0:
437
+ category = vlm_response_editing_type(vlm, original_image, prompt)
438
+
439
+ object_wait_for_edit = vlm_response_object_wait_for_edit(vlm,
440
+ category,
441
+ prompt)
442
+ # original mask: h,w,1 [0, 255]
443
+ original_mask = vlm_response_mask(
444
+ vlm,
445
+ category,
446
+ original_image,
447
+ prompt,
448
+ object_wait_for_edit,
449
+ sam,
450
+ sam_predictor,
451
+ sam_automask_generator,
452
+ groundingdino_model,
453
+ )[:,:,None]
454
+ else:
455
+ original_mask = input_mask[:,:,None]
456
+ category = None
457
+
458
+
459
+ mask_image = Image.fromarray(original_mask.squeeze().astype(np.uint8)).convert("RGB")
460
+
461
+ masked_image = original_image * (1 - (original_mask>0))
462
+ masked_image = masked_image.astype(np.uint8)
463
+ masked_image = Image.fromarray(masked_image)
464
+
465
+ ## not work for image editor
466
+ # background = input_image["background"]
467
+ # mask_array = original_mask.squeeze()
468
+ # layer_rgba = np.array(input_image['layers'][0])
469
+ # layer_rgba[mask_array > 0] = [0, 0, 0, 255]
470
+ # layer_rgba = Image.fromarray(layer_rgba, 'RGBA')
471
+ # black_image = Image.new("RGBA", layer_rgba.size, (0, 0, 0, 255))
472
+ # composite = Image.composite(black_image, background, layer_rgba)
473
+ # output_base = {"layers": [layer_rgba], "background": background, "composite": composite}
474
+
475
+
476
+ return [masked_image], [mask_image], original_mask.astype(np.uint8), category
477
+
478
+
479
+ def process_random_mask(input_image, original_image, original_mask, resize_and_crop):
480
+
481
+ alpha_mask = input_image["layers"][0].split()[3]
482
+ input_mask = np.asarray(alpha_mask)
483
+ if resize_and_crop:
484
+ original_image = crop_and_resize(Image.fromarray(original_image), target_width=640, target_height=640)
485
+ input_mask = crop_and_resize(Image.fromarray(input_mask), target_width=640, target_height=640)
486
+ original_image = np.array(original_image)
487
+ input_mask = np.array(input_mask)
488
+
489
+
490
+ if input_mask.max() == 0:
491
+ if original_mask is None:
492
+ raise gr.Error('Please generate mask first')
493
+ original_mask = original_mask
494
+ else:
495
+ original_mask = input_mask[:,:,None]
496
+
497
+
498
+ dilation_type = np.random.choice(['bounding_box', 'bounding_ellipse'])
499
+ random_mask = random_mask_func(original_mask, dilation_type).squeeze()
500
+
501
+ mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
502
+
503
+ masked_image = original_image * (1 - (random_mask[:,:,None]>0))
504
+ masked_image = masked_image.astype(original_image.dtype)
505
+ masked_image = Image.fromarray(masked_image)
506
+
507
+
508
+ return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
509
+
510
+
511
+ def process_dilation_mask(input_image, original_image, original_mask, resize_and_crop):
512
+
513
+ alpha_mask = input_image["layers"][0].split()[3]
514
+ input_mask = np.asarray(alpha_mask)
515
+ if resize_and_crop:
516
+ original_image = crop_and_resize(Image.fromarray(original_image), target_width=640, target_height=640)
517
+ input_mask = crop_and_resize(Image.fromarray(input_mask), target_width=640, target_height=640)
518
+ original_image = np.array(original_image)
519
+ input_mask = np.array(input_mask)
520
+
521
+ if input_mask.max() == 0:
522
+ if original_mask is None:
523
+ raise gr.Error('Please generate mask first')
524
+ original_mask = original_mask
525
+ else:
526
+ original_mask = input_mask[:,:,None]
527
+
528
+ dilation_type = np.random.choice(['square_dilation'])
529
+ random_mask = random_mask_func(original_mask, dilation_type).squeeze()
530
+
531
+ mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
532
+
533
+ masked_image = original_image * (1 - (random_mask[:,:,None]>0))
534
+ masked_image = masked_image.astype(original_image.dtype)
535
+ masked_image = Image.fromarray(masked_image)
536
+
537
+ return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
538
+
539
+
540
+ def process_erosion_mask(input_image, original_image, original_mask, resize_and_crop):
541
+ alpha_mask = input_image["layers"][0].split()[3]
542
+ input_mask = np.asarray(alpha_mask)
543
+ if resize_and_crop:
544
+ original_image = crop_and_resize(Image.fromarray(original_image), target_width=640, target_height=640)
545
+ input_mask = crop_and_resize(Image.fromarray(input_mask), target_width=640, target_height=640)
546
+ original_image = np.array(original_image)
547
+ input_mask = np.array(input_mask)
548
+
549
+ if input_mask.max() == 0:
550
+ if original_mask is None:
551
+ raise gr.Error('Please generate mask first')
552
+ original_mask = original_mask
553
+ else:
554
+ original_mask = input_mask[:,:,None]
555
+
556
+ dilation_type = np.random.choice(['square_erosion'])
557
+ random_mask = random_mask_func(original_mask, dilation_type).squeeze()
558
+
559
+ mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
560
+
561
+ masked_image = original_image * (1 - (random_mask[:,:,None]>0))
562
+ masked_image = masked_image.astype(original_image.dtype)
563
+ masked_image = Image.fromarray(masked_image)
564
+
565
+
566
+ return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
567
+
568
+
569
+ def move_mask_left(input_image, original_image, original_mask, moving_pixels, resize_and_crop):
570
+
571
+ alpha_mask = input_image["layers"][0].split()[3]
572
+ input_mask = np.asarray(alpha_mask)
573
+ if resize_and_crop:
574
+ original_image = crop_and_resize(Image.fromarray(original_image), target_width=640, target_height=640)
575
+ input_mask = crop_and_resize(Image.fromarray(input_mask), target_width=640, target_height=640)
576
+ original_image = np.array(original_image)
577
+ input_mask = np.array(input_mask)
578
+
579
+ if input_mask.max() == 0:
580
+ if original_mask is None:
581
+ raise gr.Error('Please generate mask first')
582
+ original_mask = original_mask
583
+ else:
584
+ original_mask = input_mask[:,:,None]
585
+
586
+ moved_mask = move_mask_func(original_mask, 'left', int(moving_pixels)).squeeze()
587
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
588
+
589
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
590
+ masked_image = masked_image.astype(original_image.dtype)
591
+ masked_image = Image.fromarray(masked_image)
592
+
593
+ if moved_mask.max() <= 1:
594
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
595
+ original_mask = moved_mask
596
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
597
+
598
+
599
+ def move_mask_right(input_image, original_image, original_mask, moving_pixels, resize_and_crop):
600
+ alpha_mask = input_image["layers"][0].split()[3]
601
+ input_mask = np.asarray(alpha_mask)
602
+ if resize_and_crop:
603
+ original_image = crop_and_resize(Image.fromarray(original_image), target_width=640, target_height=640)
604
+ input_mask = crop_and_resize(Image.fromarray(input_mask), target_width=640, target_height=640)
605
+ original_image = np.array(original_image)
606
+ input_mask = np.array(input_mask)
607
+
608
+ if input_mask.max() == 0:
609
+ if original_mask is None:
610
+ raise gr.Error('Please generate mask first')
611
+ original_mask = original_mask
612
+ else:
613
+ original_mask = input_mask[:,:,None]
614
+
615
+ moved_mask = move_mask_func(original_mask, 'right', int(moving_pixels)).squeeze()
616
+
617
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
618
+
619
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
620
+ masked_image = masked_image.astype(original_image.dtype)
621
+ masked_image = Image.fromarray(masked_image)
622
+
623
+
624
+ if moved_mask.max() <= 1:
625
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
626
+ original_mask = moved_mask
627
+
628
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
629
+
630
+
631
+ def move_mask_up(input_image, original_image, original_mask, moving_pixels, resize_and_crop):
632
+ alpha_mask = input_image["layers"][0].split()[3]
633
+ input_mask = np.asarray(alpha_mask)
634
+ if resize_and_crop:
635
+ original_image = crop_and_resize(Image.fromarray(original_image), target_width=640, target_height=640)
636
+ input_mask = crop_and_resize(Image.fromarray(input_mask), target_width=640, target_height=640)
637
+ original_image = np.array(original_image)
638
+ input_mask = np.array(input_mask)
639
+
640
+ if input_mask.max() == 0:
641
+ if original_mask is None:
642
+ raise gr.Error('Please generate mask first')
643
+ original_mask = original_mask
644
+ else:
645
+ original_mask = input_mask[:,:,None]
646
+
647
+ moved_mask = move_mask_func(original_mask, 'up', int(moving_pixels)).squeeze()
648
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
649
+
650
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
651
+ masked_image = masked_image.astype(original_image.dtype)
652
+ masked_image = Image.fromarray(masked_image)
653
+
654
+ if moved_mask.max() <= 1:
655
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
656
+ original_mask = moved_mask
657
+
658
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
659
+
660
+
661
+ def move_mask_down(input_image, original_image, original_mask, moving_pixels, resize_and_crop):
662
+ alpha_mask = input_image["layers"][0].split()[3]
663
+ input_mask = np.asarray(alpha_mask)
664
+ if resize_and_crop:
665
+ original_image = crop_and_resize(Image.fromarray(original_image), target_width=640, target_height=640)
666
+ input_mask = crop_and_resize(Image.fromarray(input_mask), target_width=640, target_height=640)
667
+ original_image = np.array(original_image)
668
+ input_mask = np.array(input_mask)
669
+
670
+ if input_mask.max() == 0:
671
+ if original_mask is None:
672
+ raise gr.Error('Please generate mask first')
673
+ original_mask = original_mask
674
+ else:
675
+ original_mask = input_mask[:,:,None]
676
+
677
+ moved_mask = move_mask_func(original_mask, 'down', int(moving_pixels)).squeeze()
678
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
679
+
680
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
681
+ masked_image = masked_image.astype(original_image.dtype)
682
+ masked_image = Image.fromarray(masked_image)
683
+
684
+ if moved_mask.max() <= 1:
685
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
686
+ original_mask = moved_mask
687
+
688
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
689
+
690
+
691
+ def store_img(base):
692
+ # import ipdb; ipdb.set_trace()
693
+ image_pil = base["background"].convert("RGB")
694
+ original_image = np.array(image_pil)
695
+ # import ipdb; ipdb.set_trace()
696
+ if max(original_image.shape[0], original_image.shape[1]) * 1.0 / min(original_image.shape[0], original_image.shape[1])>2.0:
697
+ raise gr.Error('image aspect ratio cannot be larger than 2.0')
698
+ return base, original_image, None, "", None, None, None, None  # one value per output of input_image.upload
699
+
700
+
701
+ def reset_func(input_image, original_image, original_mask, prompt, target_prompt):
702
+ input_image = None
703
+ original_image = None
704
+ original_mask = None
705
+ prompt = ''
706
+ mask_gallery = []
707
+ masked_gallery = []
708
+ result_gallery = []
709
+ target_prompt = ''
710
+ return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt
711
+
712
+
713
+ block = gr.Blocks(
714
+ theme=gr.themes.Soft(
715
+ radius_size=gr.themes.sizes.radius_none,
716
+ text_size=gr.themes.sizes.text_md
717
+ )
718
+ ).queue()
719
+ with block as demo:
720
+ with gr.Row():
721
+ with gr.Column():
722
+ gr.HTML(head)
723
+
724
+ gr.Markdown(descriptions)
725
+
726
+ with gr.Accordion(label="🧭 Instructions:", open=True, elem_id="accordion"):
727
+ with gr.Row(equal_height=True):
728
+ gr.Markdown(instructions)
729
+
730
+ original_image = gr.State(value=None)
731
+ original_mask = gr.State(value=None)
732
+ category = gr.State(value=None)
733
+
734
+ with gr.Row():
735
+ with gr.Column():
736
+ with gr.Row():
737
+ input_image = gr.ImageEditor(
738
+ label="Input Image",
739
+ type="pil",
740
+ brush=gr.Brush(colors=["#000000"], default_size = 30, color_mode="fixed"),
741
+ layers = False,
742
+ interactive=True,
743
+ height=800,
744
+ # transforms=("crop"),
745
+ # crop_size=(640, 640),
746
+ )
747
+
748
+ prompt = gr.Textbox(label="Prompt", placeholder="Please input your instruction.",value='',lines=1)
749
+
750
+ with gr.Row():
751
+ mask_button = gr.Button("Generate Mask")
752
+ random_mask_button = gr.Button("Random Generated Mask")
753
+ with gr.Row():
754
+ dilation_mask_button = gr.Button("Dilation Generated Mask")
755
+ erosion_mask_button = gr.Button("Erosion Generated Mask")
756
+
757
+ with gr.Row():
758
+ generate_target_prompt_button = gr.Button("Generate Target Prompt")
759
+ run_button = gr.Button("Run")
760
+
761
+
762
+ target_prompt = gr.Text(
763
+ label="Target prompt",
764
+ max_lines=5,
765
+ placeholder="VLM-generated target prompt; you can generate it first and then modify it (optional)",
766
+ value='',
767
+ lines=2
768
+ )
769
+
770
+ resize_and_crop = gr.Checkbox(label="Resize and Crop (640 x 640)", value=False)
771
+
772
+ with gr.Accordion("More input params (highly-recommended)", open=False, elem_id="accordion1"):
773
+ negative_prompt = gr.Text(
774
+ label="Negative Prompt",
775
+ max_lines=5,
776
+ placeholder="Please input your negative prompt",
777
+ value='ugly, low quality',lines=1
778
+ )
779
+
780
+ control_strength = gr.Slider(
781
+ label="Control Strength: ", show_label=True, minimum=0, maximum=1.1, value=1, step=0.01
782
+ )
783
+ with gr.Group():
784
+ seed = gr.Slider(
785
+ label="Seed: ", minimum=0, maximum=2147483647, step=1, value=648464818
786
+ )
787
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
788
+
789
+ blending = gr.Checkbox(label="Blending mode", value=True)
790
+
791
+
792
+ num_samples = gr.Slider(
793
+ label="Num samples", minimum=0, maximum=4, step=1, value=4
794
+ )
795
+
796
+ with gr.Group():
797
+ with gr.Row():
798
+ guidance_scale = gr.Slider(
799
+ label="Guidance scale",
800
+ minimum=1,
801
+ maximum=12,
802
+ step=0.1,
803
+ value=7.5,
804
+ )
805
+ num_inference_steps = gr.Slider(
806
+ label="Number of inference steps",
807
+ minimum=1,
808
+ maximum=50,
809
+ step=1,
810
+ value=50,
811
+ )
812
+
813
+
814
+ with gr.Column():
815
+ with gr.Row():
816
+ with gr.Tabs(elem_classes=["feedback"]):
817
+ with gr.TabItem("Mask"):
818
+ mask_gallery = gr.Gallery(label='Mask', show_label=False, elem_id="gallery", preview=True, height=360)
819
+ with gr.Tabs(elem_classes=["feedback"]):
820
+ with gr.TabItem("Masked Image"):
821
+ masked_gallery = gr.Gallery(label='Masked Image', show_label=False, elem_id="gallery", preview=True, height=360)
822
+
823
+ moving_pixels = gr.Slider(
824
+ label="Moving pixels:", show_label=True, minimum=0, maximum=50, value=4, step=1
825
+ )
826
+ with gr.Row():
827
+ move_left_button = gr.Button("Move Left")
828
+ move_right_button = gr.Button("Move Right")
829
+ with gr.Row():
830
+ move_up_button = gr.Button("Move Up")
831
+ move_down_button = gr.Button("Move Down")
832
+
833
+ with gr.Tabs(elem_classes=["feedback"]):
834
+ with gr.TabItem("Outputs"):
835
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", preview=True, height=360)
836
+
837
+ reset_button = gr.Button("Reset")
838
+
839
+
840
+ with gr.Row():
841
+ # # example = gr.Examples(
842
+ # # label="Quick Example",
843
+ # # examples=EXAMPLES,
844
+ # # inputs=[prompt, seed, result_gallery, mask_gallery, masked_gallery],
845
+ # # examples_per_page=10,
846
+ # # cache_examples=False,
847
+ # # )
848
+ example = gr.Examples(
849
+ label="Quick Example",
850
+ examples=EXAMPLES,
851
+ inputs=[input_image, prompt],
852
+ examples_per_page=10,
853
+ cache_examples=False,
854
+ )
855
+ # def process_example(prompt, seed, eg_output):
856
+ # import ipdb; ipdb.set_trace()
857
+ # eg_output_path = os.path.join("assets/", eg_output)
858
+ # return prompt, seed, [Image.open(eg_output_path)]
859
+ # example = gr.Examples(
860
+ # label="Quick Example",
861
+ # examples=EXAMPLES,
862
+ # inputs=[prompt, seed, eg_output],
863
+ # outputs=[prompt, seed, result_gallery],
864
+ # fn=process_example,
865
+ # examples_per_page=10,
866
+ # run_on_click=True,
867
+ # cache_examples=False,
868
+ # )
869
+
870
+ input_image.upload(
871
+ store_img,
872
+ [input_image],
873
+ [input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt]
874
+ )
875
+
876
+
877
+ ips=[input_image,
878
+ original_image,
879
+ original_mask,
880
+ prompt,
881
+ negative_prompt,
882
+ control_strength,
883
+ seed,
884
+ randomize_seed,
885
+ guidance_scale,
886
+ num_inference_steps,
887
+ num_samples,
888
+ blending,
889
+ category,
890
+ target_prompt,
891
+ resize_and_crop]
892
+
893
+ ## run brushedit
894
+ run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, target_prompt])
895
+
896
+ ## mask func
897
+ mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_and_crop], outputs=[masked_gallery, mask_gallery, original_mask, category])
898
+ random_mask_button.click(fn=process_random_mask, inputs=[input_image, original_image, original_mask, resize_and_crop], outputs=[masked_gallery, mask_gallery, original_mask])
899
+ dilation_mask_button.click(fn=process_dilation_mask, inputs=[input_image, original_image, original_mask, resize_and_crop], outputs=[ masked_gallery, mask_gallery, original_mask])
900
+ erosion_mask_button.click(fn=process_erosion_mask, inputs=[input_image, original_image, original_mask, resize_and_crop], outputs=[ masked_gallery, mask_gallery, original_mask])
901
+
902
+ ## move mask func
903
+ move_left_button.click(fn=move_mask_left, inputs=[input_image, original_image, original_mask, moving_pixels, resize_and_crop], outputs=[masked_gallery, mask_gallery, original_mask])
904
+ move_right_button.click(fn=move_mask_right, inputs=[input_image, original_image, original_mask, moving_pixels, resize_and_crop], outputs=[masked_gallery, mask_gallery, original_mask])
905
+ move_up_button.click(fn=move_mask_up, inputs=[input_image, original_image, original_mask, moving_pixels, resize_and_crop], outputs=[masked_gallery, mask_gallery, original_mask])
906
+ move_down_button.click(fn=move_mask_down, inputs=[input_image, original_image, original_mask, moving_pixels, resize_and_crop], outputs=[masked_gallery, mask_gallery, original_mask])
907
+
908
+ ## prompt func
909
+ generate_target_prompt_button.click(fn=generate_target_prompt, inputs=[input_image, original_image, prompt], outputs=[target_prompt])
910
+
911
+ ## reset func
912
+ reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, target_prompt], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt])
913
+
914
+ demo.launch(server_name="0.0.0.0")
app/gpt4_o/instructions.py ADDED
@@ -0,0 +1,106 @@
1
+ def create_editing_category_messages(editing_prompt):
2
+ messages = [{
3
+ "role": "system",
4
+ "content": [
5
+ {
6
+ "type": "text",
7
+ "text": "I will give you an image and an editing instruction for the image. Please output which editing category the instruction belongs to. You can choose from the following categories: \n\
8
+ 1. Addition: Adding new objects within the images, e.g., add a bird to the image \n\
9
+ 2. Remove: Removing objects, e.g., remove the mask \n\
10
+ 3. Local: Replace local parts of an object, alter the object's attributes (e.g., make it smile), or alter an object's visual appearance without affecting its structure (e.g., change the cat to a dog) \n\
11
+ 4. Global: Edit the entire image, e.g., let's see it in winter \n\
12
+ 5. Background: Change the scene's background, e.g., have her walk on water, change the background to a beach, make the hedgehog in France, etc.",
13
+ },]
14
+ },
15
+ {
16
+ "role": "user",
17
+ "content": [
18
+ {
19
+ "type": "text",
20
+ "text": editing_prompt
21
+ },
22
+ ]
23
+ }]
24
+ return messages
25
+
26
+
27
+ def create_ori_object_messages(editing_prompt):
28
+
29
+ messages = [
30
+ {
31
+ "role": "system",
32
+ "content": [
33
+ {
34
+ "type": "text",
35
+ "text": "I will give you an editing instruction of the image. Please output the object that needs to be edited. You only need to output a basic description of the object in no more than 5 words. The output should contain only one noun. \n \
36
+ For example, if the editing instruction is 'Change the white cat to a black dog', then you need to output: 'white cat'. Only output the object description. Do not output anything else."
37
+ },]
38
+ },
39
+ {
40
+ "role": "user",
41
+ "content": [
42
+ {
43
+ "type": "text",
44
+ "text": editing_prompt
45
+ }
46
+ ]
47
+ }
48
+ ]
49
+ return messages
50
+
51
+
52
+ def create_add_object_messages(editing_prompt, base64_image, height=640, width=640):
53
+
54
+ size_str = f"The image size is height {height}px and width {width}px. The top-left corner is at coordinate [0, 0]. The bottom-right corner is at coordinate [{height}, {width}]. "
55
+
56
+ messages = [
57
+ {
58
+ "role": "user",
59
+ "content": [
60
+ {
61
+ "type": "text",
62
+ "text": "I need to add an object to the image following the instruction: " + editing_prompt + ". " + size_str + " \n \
63
+ Can you give me a possible bounding box for the added object? Please output it in the format [top-left x coordinate, top-left y coordinate, box width, box height]. You should only output the bounding box position and nothing else. Please refer to the examples below for the desired format.\n\
64
+ [Examples]\n \
65
+ [19, 101, 32, 153]\n \
66
+ [54, 12, 242, 96]"
67
+ },
68
+ {
69
+ "type": "image_url",
70
+ "image_url": {
71
+ "url":f"data:image/jpeg;base64,{base64_image}"
72
+ },
73
+ }
74
+ ]
75
+ }
76
+ ]
77
+ return messages
78
+
79
+
80
+ def create_apply_editing_messages(editing_prompt, base64_image):
81
+ messages = [
82
+ {
83
+ "role": "system",
84
+ "content": [
85
+ {
86
+ "type": "text",
87
+ "text": "I will provide an image along with an editing instruction. Please describe the new content that should be present in the image after applying the instruction. \n \
88
+ For example, if the original image content shows a grandmother wearing a mask and the instruction is 'remove the mask', your output should be: 'a grandmother'. The output should only include elements that remain in the image after the edit and should not mention elements that have been changed or removed, such as 'mask' in this example. Do not output 'sorry, xxx'; even if you have to guess, directly output the answer you think is correct."
89
+ },]
90
+ },
91
+ {
92
+ "role": "user",
93
+ "content": [
94
+ {
95
+ "type": "text",
96
+ "text": editing_prompt
97
+ },
98
+ {"type": "image_url",
99
+ "image_url": {
100
+ "url":f"data:image/jpeg;base64,{base64_image}"
101
+ },
102
+ },
103
+ ]
104
+ }
105
+ ]
106
+ return messages
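The message builders above only assemble chat-completion message lists; they do not call the VLM themselves. Below is a minimal sketch of sending one of them through the OpenAI client, mirroring run_gpt4o_vl_inference in app/gpt4_o/vlm_pipeline.py; the instruction string is a placeholder and OPENAI_API_KEY is assumed to be set in the environment.

from openai import OpenAI

from app.gpt4_o.instructions import create_editing_category_messages

client = OpenAI()  # reads OPENAI_API_KEY from the environment
messages = create_editing_category_messages("remove the mask")
response = client.chat.completions.create(model="gpt-4o-2024-08-06", messages=messages)
print(response.choices[0].message.content)  # expected to name one of the five categories, e.g. "Remove"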
app/gpt4_o/requirements.txt ADDED
@@ -0,0 +1,18 @@
1
+ torchvision
2
+ transformers>=4.25.1
3
+ ftfy
4
+ tensorboard
5
+ datasets
6
+ Pillow==9.5.0
7
+ opencv-python
8
+ imgaug
9
+ accelerate==0.20.3
10
+ image-reward
11
+ hpsv2
12
+ torchmetrics
13
+ open-clip-torch
14
+ clip
15
+ # gradio==4.44.1
16
+ gradio==4.38.1
17
+ segment_anything
18
+ openai
app/gpt4_o/run_app.sh ADDED
@@ -0,0 +1,5 @@
1
+ export PYTHONPATH=.:$PYTHONPATH
2
+
3
+ export CUDA_VISIBLE_DEVICES=0
4
+
5
+ python app/gpt4_o/brushedit_app.py
app/gpt4_o/vlm_pipeline.py ADDED
@@ -0,0 +1,138 @@
1
+ import base64
2
+ import re
3
+ import torch
4
+
5
+ from PIL import Image
6
+ from io import BytesIO
7
+ import numpy as np
8
+ import gradio as gr
9
+
10
+
11
+ from app.gpt4_o.instructions import (
12
+ create_editing_category_messages,
13
+ create_ori_object_messages,
14
+ create_add_object_messages,
15
+ create_apply_editing_messages)
16
+
17
+ from app.utils.utils import run_grounded_sam
18
+
19
+
20
+ def encode_image(img):
21
+ img = Image.fromarray(img.astype('uint8'))
22
+ buffered = BytesIO()
23
+ img.save(buffered, format="PNG")
24
+ img_bytes = buffered.getvalue()
25
+ return base64.b64encode(img_bytes).decode('utf-8')
26
+
27
+
28
+ def run_gpt4o_vl_inference(vlm,
29
+ messages):
30
+ response = vlm.chat.completions.create(
31
+ model="gpt-4o-2024-08-06",
32
+ messages=messages
33
+ )
34
+ response_str = response.choices[0].message.content
35
+ return response_str
36
+
37
+
38
+ def vlm_response_editing_type(vlm,
39
+ image,
40
+ editing_prompt):
41
+
42
+ base64_image = encode_image(image)
43
+
44
+ messages = create_editing_category_messages(editing_prompt)
45
+
46
+ response_str = run_gpt4o_vl_inference(vlm, messages)
47
+
48
+ for category_name in ["Addition","Remove","Local","Global","Background"]:
49
+ if category_name.lower() in response_str.lower():
50
+ return category_name
51
+ raise ValueError("Please input correct commands, including add, delete, and modify commands.")
52
+
53
+
54
+ def vlm_response_object_wait_for_edit(vlm,
55
+ category,
56
+ editing_prompt):
57
+ if category in ["Background", "Global", "Addition"]:
58
+ edit_object = "nan"
59
+ return edit_object
60
+
61
+ messages = create_ori_object_messages(editing_prompt)
62
+
63
+ response_str = run_gpt4o_vl_inference(vlm, messages)
64
+ return response_str
65
+
66
+
67
+ def vlm_response_mask(vlm,
68
+ category,
69
+ image,
70
+ editing_prompt,
71
+ object_wait_for_edit,
72
+ sam=None,
73
+ sam_predictor=None,
74
+ sam_automask_generator=None,
75
+ groundingdino_model=None,
76
+ ):
77
+ mask = None
78
+ if editing_prompt is None or len(editing_prompt)==0:
79
+ raise gr.Error("Please input the editing instruction!")
80
+ height, width = image.shape[:2]
81
+ if category=="Addition":
82
+ base64_image = encode_image(image)
83
+ messages = create_add_object_messages(editing_prompt, base64_image, height=height, width=width)
84
+ try:
85
+ response_str = run_gpt4o_vl_inference(vlm, messages)
86
+ pattern = r'\[\d{1,3}(?:,\s*\d{1,3}){3}\]'
87
+ box = re.findall(pattern, response_str)
88
+ box = box[0][1:-1].split(",")
89
+ for i in range(len(box)):
90
+ box[i] = int(box[i])
91
+ cus_mask = np.zeros((height, width))
92
+ cus_mask[box[1]: box[1]+box[3], box[0]: box[0]+box[2]]=255
93
+ mask = cus_mask
94
+ except Exception:
95
+ raise gr.Error("Please set the mask manually; the MLLM could not produce a valid bounding box!")
96
+
97
+ elif category=="Background":
98
+ labels = "background"
99
+ elif category=="Global":
100
+ mask = np.zeros((height, width))
101
+ else:
102
+ labels = object_wait_for_edit
103
+
104
+ if mask is None:
105
+ for thresh in [0.3,0.25,0.2,0.15,0.1,0.05,0]:
106
+ try:
107
+ device = "cuda" if torch.cuda.is_available() else "cpu"
108
+ detections = run_grounded_sam(
109
+ input_image={"image":Image.fromarray(image.astype('uint8')),
110
+ "mask":None},
111
+ text_prompt=labels,
112
+ task_type="seg",
113
+ box_threshold=thresh,
114
+ text_threshold=0.25,
115
+ iou_threshold=0.5,
116
+ scribble_mode="split",
117
+ sam=sam,
118
+ sam_predictor=sam_predictor,
119
+ sam_automask_generator=sam_automask_generator,
120
+ groundingdino_model=groundingdino_model,
121
+ device=device,
122
+ )
123
+ mask = np.array(detections[0,0,...].cpu()) * 255
124
+ break
125
+ except Exception:
126
+ print(f"Grounded-SAM failed at box_threshold={thresh}; retrying with a lower threshold")
127
+ continue
128
+ return mask
129
+
130
+
131
+ def vlm_response_prompt_after_apply_instruction(vlm,
132
+ image,
133
+ editing_prompt):
134
+ base64_image = encode_image(image)
135
+ messages = create_apply_editing_messages(editing_prompt, base64_image)
136
+
137
+ response_str = run_gpt4o_vl_inference(vlm, messages)
138
+ return response_str
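
Taken end to end, the four helpers above chain together roughly as in this hedged sketch; the checkpoint and config paths are placeholders (they are not shipped in this commit), and the real wiring lives in app/gpt4_o/brushedit_app.py:

    import numpy as np
    from PIL import Image
    from openai import OpenAI
    from segment_anything import build_sam

    from app.utils.utils import load_grounding_dino_model
    from app.gpt4_o.vlm_pipeline import (
        vlm_response_editing_type,
        vlm_response_object_wait_for_edit,
        vlm_response_mask,
        vlm_response_prompt_after_apply_instruction)

    # Placeholder checkpoint/config paths (assumptions; adjust to your setup).
    sam = build_sam(checkpoint="checkpoints/sam_vit_h_4b8939.pth").to("cuda")
    groundingdino_model = load_grounding_dino_model(
        "checkpoints/GroundingDINO_SwinT_OGC.py",
        "checkpoints/groundingdino_swint_ogc.pth",
        device="cuda")

    vlm = OpenAI()  # assumes OPENAI_API_KEY is set in the environment
    image = np.array(Image.open("assets/hedgehog_rm_fg/hedgehog.png").convert("RGB"))
    prompt = "remove the hedgehog."

    category = vlm_response_editing_type(vlm, image, prompt)           # e.g. "Remove"
    target = vlm_response_object_wait_for_edit(vlm, category, prompt)  # e.g. "hedgehog"
    mask = vlm_response_mask(vlm, category, image, prompt, target,
                             sam=sam, groundingdino_model=groundingdino_model)
    target_prompt = vlm_response_prompt_after_apply_instruction(vlm, image, prompt)
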
app/utils/utils.py ADDED
@@ -0,0 +1,197 @@
1
+ import numpy as np
2
+ import torch
3
+ import torchvision
4
+
5
+ from scipy import ndimage
6
+
7
+ # BLIP
8
+ from transformers import BlipProcessor, BlipForConditionalGeneration
9
+
10
+ # SAM
11
+ from segment_anything import build_sam, SamPredictor, SamAutomaticMaskGenerator
12
+
13
+ # GroundingDINO
14
+ from groundingdino.datasets import transforms as T
15
+ from groundingdino.models import build_model
16
+ from groundingdino.util.slconfig import SLConfig
17
+ from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
18
+
19
+ # Module-level globals referenced by run_grounded_sam's 'automatic' branch;
+ # they are filled in lazily the first time BLIP captioning is needed.
+ blip_processor = None
+ blip_model = None
+ inpaint_pipeline = None
+
20
+ def load_grounding_dino_model(model_config_path, model_checkpoint_path, device):
21
+ args = SLConfig.fromfile(model_config_path)
22
+ args.device = device
23
+ model = build_model(args)
24
+ checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
25
+ load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
26
+ print(load_res)
27
+ _ = model.eval()
28
+ return model
29
+
30
+
31
+ def generate_caption(processor, blip_model, raw_image, device):
32
+ # unconditional image captioning
33
+ inputs = processor(raw_image, return_tensors="pt").to(device, torch.float16)
34
+ out = blip_model.generate(**inputs)
35
+ caption = processor.decode(out[0], skip_special_tokens=True)
36
+ return caption
37
+
38
+
39
+
40
+ def transform_image(image_pil):
41
+
42
+ transform = T.Compose(
43
+ [
44
+ T.RandomResize([800], max_size=1333),
45
+ T.ToTensor(),
46
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
47
+ ]
48
+ )
49
+ image, _ = transform(image_pil, None) # 3, h, w
50
+ return image
51
+
52
+
53
+ def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True):
54
+ caption = caption.lower()
55
+ caption = caption.strip()
56
+ if not caption.endswith("."):
57
+ caption = caption + "."
58
+
59
+ with torch.no_grad():
60
+ outputs = model(image[None], captions=[caption])
61
+ logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
62
+ boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
64
+
65
+ # filter output
66
+ logits_filt = logits.clone()
67
+ boxes_filt = boxes.clone()
68
+ filt_mask = logits_filt.max(dim=1)[0] > box_threshold
69
+ logits_filt = logits_filt[filt_mask] # num_filt, 256
70
+ boxes_filt = boxes_filt[filt_mask] # num_filt, 4
72
+
73
+ # get phrase
74
+ tokenizer = model.tokenizer
75
+ tokenized = tokenizer(caption)
76
+ # build pred
77
+ pred_phrases = []
78
+ scores = []
79
+ for logit, box in zip(logits_filt, boxes_filt):
80
+ pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer)
81
+ if with_logits:
82
+ pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
83
+ else:
84
+ pred_phrases.append(pred_phrase)
85
+ scores.append(logit.max().item())
86
+
87
+ return boxes_filt, torch.Tensor(scores), pred_phrases
88
+
89
+
90
+
91
+ def run_grounded_sam(input_image,
92
+ text_prompt,
93
+ task_type,
94
+ box_threshold,
95
+ text_threshold,
96
+ iou_threshold,
97
+ scribble_mode,
98
+ sam,
99
+ groundingdino_model,
100
+ sam_predictor=None,
101
+ sam_automask_generator=None,
102
+ device="cuda"):
103
+
104
+ global blip_processor, blip_model, inpaint_pipeline
105
+
106
+ # load image
107
+ image = input_image["image"]
108
+ scribble = input_image["mask"]
109
+ size = image.size # w, h
110
+
111
+ if sam_predictor is None:
112
+ sam_predictor = SamPredictor(sam)
113
+ sam_automask_generator = SamAutomaticMaskGenerator(sam)
114
+
115
+ image_pil = image.convert("RGB")
116
+ image = np.array(image_pil)
117
+
118
+ if task_type == 'scribble':
119
+ sam_predictor.set_image(image)
120
+ scribble = scribble.convert("RGB")
121
+ scribble = np.array(scribble)
122
+ scribble = scribble.transpose(2, 1, 0)[0]
123
+
124
+ # label the connected components of the scribble
125
+ labeled_array, num_features = ndimage.label(scribble >= 255)
126
+
127
+ # compute the centroid of each connected component
128
+ centers = ndimage.center_of_mass(scribble, labeled_array, range(1, num_features+1))
129
+ centers = np.array(centers)
130
+
131
+ point_coords = torch.from_numpy(centers)
132
+ point_coords = sam_predictor.transform.apply_coords_torch(point_coords, image.shape[:2])
133
+ point_coords = point_coords.unsqueeze(0).to(device)
134
+ point_labels = torch.from_numpy(np.array([1] * len(centers))).unsqueeze(0).to(device)
135
+ if scribble_mode == 'split':
136
+ point_coords = point_coords.permute(1, 0, 2)
137
+ point_labels = point_labels.permute(1, 0)
138
+ masks, _, _ = sam_predictor.predict_torch(
139
+ point_coords=point_coords if len(point_coords) > 0 else None,
140
+ point_labels=point_labels if len(point_coords) > 0 else None,
141
+ mask_input = None,
142
+ boxes = None,
143
+ multimask_output = False,
144
+ )
145
+ elif task_type == 'automask':
146
+ masks = sam_automask_generator.generate(image)
147
+ else:
148
+ transformed_image = transform_image(image_pil)
149
+
150
+ if task_type == 'automatic':
151
+ # generate caption and tags
152
+ # use Tag2Text can generate better captions
153
+ # https://huggingface.co/spaces/xinyu1205/Tag2Text
154
+ # but there are some bugs...
155
+ blip_processor = blip_processor or BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
156
+ blip_model = blip_model or BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to(device)
157
+ text_prompt = generate_caption(blip_processor, blip_model, image_pil, device)
158
+ print(f"Caption: {text_prompt}")
159
+
160
+ # run grounding dino model
161
+ boxes_filt, scores, pred_phrases = get_grounding_output(
162
+ groundingdino_model, transformed_image, text_prompt, box_threshold, text_threshold
163
+ )
164
+
165
+ # process boxes
166
+ H, W = size[1], size[0]
167
+ for i in range(boxes_filt.size(0)):
168
+ boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
169
+ boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
170
+ boxes_filt[i][2:] += boxes_filt[i][:2]
171
+
172
+ boxes_filt = boxes_filt.cpu()
173
+
174
+
175
+ if task_type == 'seg' or task_type == 'inpainting' or task_type == 'automatic':
176
+ sam_predictor.set_image(image)
177
+
178
+ if task_type == 'automatic':
179
+ # use NMS to handle overlapped boxes
180
+ print(f"Before NMS: {boxes_filt.shape[0]} boxes")
181
+ nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
182
+ boxes_filt = boxes_filt[nms_idx]
183
+ pred_phrases = [pred_phrases[idx] for idx in nms_idx]
184
+ print(f"After NMS: {boxes_filt.shape[0]} boxes")
185
+ print(f"Revise caption with number: {text_prompt}")
186
+
187
+ transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)
188
+
189
+ masks, _, _ = sam_predictor.predict_torch(
190
+ point_coords = None,
191
+ point_labels = None,
192
+ boxes = transformed_boxes,
193
+ multimask_output = False,
194
+ )
195
+ return masks
196
+ else:
197
+ print("task_type:{} error!".format(task_type))
assets/hedgehog_rm_fg/hedgehog.png ADDED

Git LFS Details

  • SHA256: e64da50164ce3136e269a7f02db37a375373ffa1133c4c1d0345762f725ad7b6
  • Pointer size: 131 Bytes
  • Size of remote file: 735 kB
assets/hedgehog_rm_fg/image_edit_82314e18-c64c-4003-9ef9-52cebf254b2f_2.png ADDED

Git LFS Details

  • SHA256: 691bc84fa24f8abec7efcb93fe5ca3d111288e2790e7182ca618894bd8026005
  • Pointer size: 131 Bytes
  • Size of remote file: 712 kB
assets/hedgehog_rm_fg/mask_82314e18-c64c-4003-9ef9-52cebf254b2f.png ADDED

Git LFS Details

  • SHA256: 2e466f5080433b6b2b3cb0cdbdd882a39b071e9c2eb8a118cbc836253280911f
  • Pointer size: 129 Bytes
  • Size of remote file: 3.03 kB
assets/hedgehog_rm_fg/masked_image_82314e18-c64c-4003-9ef9-52cebf254b2f.png ADDED

Git LFS Details

  • SHA256: a5d6a5e7a66ed91b581c58e59240a9b911193b8300d0c8a44ef353a3826cedb2
  • Pointer size: 131 Bytes
  • Size of remote file: 653 kB
assets/hedgehog_rm_fg/prompt.txt ADDED
@@ -0,0 +1 @@
1
+ 648464818: remove the hedgehog.
assets/hedgehog_rp_bg/hedgehog.png ADDED

Git LFS Details

  • SHA256: e64da50164ce3136e269a7f02db37a375373ffa1133c4c1d0345762f725ad7b6
  • Pointer size: 131 Bytes
  • Size of remote file: 735 kB
assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png ADDED

Git LFS Details

  • SHA256: 42864ee8965bff065c457bef8f0764b95c5ac488d7c1d253adec9965b15e687e
  • Pointer size: 131 Bytes
  • Size of remote file: 816 kB
assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png ADDED

Git LFS Details

  • SHA256: 6744128de344ecbf3974fc759aab14f25ad41f63eeb977984b9d6c169a01df41
  • Pointer size: 129 Bytes
  • Size of remote file: 9.55 kB
assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png ADDED

Git LFS Details

  • SHA256: 71b3b37b57c2543450389a64605b6369c7b6d37c12bd01393c01488228db2c08
  • Pointer size: 131 Bytes
  • Size of remote file: 565 kB
assets/hedgehog_rp_bg/prompt.txt ADDED
@@ -0,0 +1 @@
1
+ 648464818: make the hedgehog in Italy.
assets/hedgehog_rp_fg/hedgehog.png ADDED

Git LFS Details

  • SHA256: e64da50164ce3136e269a7f02db37a375373ffa1133c4c1d0345762f725ad7b6
  • Pointer size: 131 Bytes
  • Size of remote file: 735 kB
assets/hedgehog_rp_fg/image_edit_5cab3448-5a3a-459c-9144-35cca3d34273_0.png ADDED

Git LFS Details

  • SHA256: 5eae70b41f84e33a6256cef1022487e5ade532e8d6e6042d0d2a66c7de8c5491
  • Pointer size: 131 Bytes
  • Size of remote file: 698 kB
assets/hedgehog_rp_fg/mask_5cab3448-5a3a-459c-9144-35cca3d34273.png ADDED

Git LFS Details

  • SHA256: 2e466f5080433b6b2b3cb0cdbdd882a39b071e9c2eb8a118cbc836253280911f
  • Pointer size: 129 Bytes
  • Size of remote file: 3.03 kB
assets/hedgehog_rp_fg/masked_image_5cab3448-5a3a-459c-9144-35cca3d34273.png ADDED

Git LFS Details

  • SHA256: a5d6a5e7a66ed91b581c58e59240a9b911193b8300d0c8a44ef353a3826cedb2
  • Pointer size: 131 Bytes
  • Size of remote file: 653 kB
assets/hedgehog_rp_fg/prompt.txt ADDED
@@ -0,0 +1 @@
1
+ 648464818: replace the hedgehog to flamingo.
assets/mona_lisa/image_edit_aae09152-4495-4332-b691-a0c7bff524be_2.png ADDED

Git LFS Details

  • SHA256: 165d04619b4b24b5261b3c016236b3539d13c744fb302ea72f649fb04d4378f9
  • Pointer size: 132 Bytes
  • Size of remote file: 1.08 MB
assets/mona_lisa/mask_aae09152-4495-4332-b691-a0c7bff524be.png ADDED

Git LFS Details

  • SHA256: 8f2c982378902f5f78aa9f58ed3739eb43474a3202b9b74688419ddf87c09600
  • Pointer size: 129 Bytes
  • Size of remote file: 3.33 kB
assets/mona_lisa/masked_image_aae09152-4495-4332-b691-a0c7bff524be.png ADDED

Git LFS Details

  • SHA256: 5718c352ebcd1a6526fc84f5190f54a71e6edb4f8b628c0e633b283a7ef964f7
  • Pointer size: 132 Bytes
  • Size of remote file: 1.1 MB
assets/mona_lisa/mona_lisa.png ADDED

Git LFS Details

  • SHA256: f176865dad43ad3d4358b9bfcdcabdc17ac25b6744461420a8d0b13634d5b048
  • Pointer size: 132 Bytes
  • Size of remote file: 1.38 MB
assets/mona_lisa/prompt.txt ADDED
@@ -0,0 +1 @@
1
+ 648464818: add a shining necklace.
assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png ADDED

Git LFS Details

  • SHA256: 6fe94816353deaebf737b9a92cc3b25c15d78d13adbb1295138b5a87bc56126c
  • Pointer size: 132 Bytes
  • Size of remote file: 1.79 MB
assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png ADDED

Git LFS Details

  • SHA256: a2f7070129faf02815c736915648fb63d3461022c5babbce11966ff925f65271
  • Pointer size: 129 Bytes
  • Size of remote file: 7.08 kB
assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png ADDED

Git LFS Details

  • SHA256: 16c08cfb84ccc64c3e32617772ce9b9822b9de152784a74076e1ef93806f7e0f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.76 MB
assets/sunflower_girl/prompt.txt ADDED
@@ -0,0 +1 @@
1
+ 648464818: add a wreath on head..
assets/sunflower_girl/sunflower_girl.png ADDED

Git LFS Details

  • SHA256: 67cfa6af6126d774a32a355266311acaa9088f7d799aa28ecae4b95d784ce936
  • Pointer size: 132 Bytes
  • Size of remote file: 1.8 MB
requirements.txt ADDED
@@ -0,0 +1,20 @@
1
+ torch
2
+ torchvision
3
+ torchaudio
4
+ transformers>=4.25.1
5
+ gradio==4.38.1
6
+ ftfy
7
+ tensorboard
8
+ datasets
9
+ Pillow==9.5.0
10
+ opencv-python
11
+ imgaug
12
+ accelerate==0.20.3
13
+ image-reward
14
+ hpsv2
15
+ torchmetrics
16
+ open-clip-torch
17
+ clip
18
+ segment_anything
19
+ git+https://github.com/liyaowei-stu/BrushEdit.git
20
+ git+https://github.com/IDEA-Research/Grounded-Segment-Anything.git#subdirectory=GroundingDINO