Spaces: runtime error
Commit 6444ed9
Parent(s): init

Files changed:
- .gitattributes +62 -0
- .gitignore +178 -0
- README.md +11 -0
- app/gpt4_o/brushedit_all_in_one_pipeline.py +80 -0
- app/gpt4_o/brushedit_app.py +914 -0
- app/gpt4_o/instructions.py +106 -0
- app/gpt4_o/requirements.txt +18 -0
- app/gpt4_o/run_app.sh +5 -0
- app/gpt4_o/vlm_pipeline.py +138 -0
- app/utils/utils.py +197 -0
- assets/hedgehog_rm_fg/hedgehog.png +3 -0
- assets/hedgehog_rm_fg/image_edit_82314e18-c64c-4003-9ef9-52cebf254b2f_2.png +3 -0
- assets/hedgehog_rm_fg/mask_82314e18-c64c-4003-9ef9-52cebf254b2f.png +3 -0
- assets/hedgehog_rm_fg/masked_image_82314e18-c64c-4003-9ef9-52cebf254b2f.png +3 -0
- assets/hedgehog_rm_fg/prompt.txt +1 -0
- assets/hedgehog_rp_bg/hedgehog.png +3 -0
- assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png +3 -0
- assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png +3 -0
- assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png +3 -0
- assets/hedgehog_rp_bg/prompt.txt +1 -0
- assets/hedgehog_rp_fg/hedgehog.png +3 -0
- assets/hedgehog_rp_fg/image_edit_5cab3448-5a3a-459c-9144-35cca3d34273_0.png +3 -0
- assets/hedgehog_rp_fg/mask_5cab3448-5a3a-459c-9144-35cca3d34273.png +3 -0
- assets/hedgehog_rp_fg/masked_image_5cab3448-5a3a-459c-9144-35cca3d34273.png +3 -0
- assets/hedgehog_rp_fg/prompt.txt +1 -0
- assets/mona_lisa/image_edit_aae09152-4495-4332-b691-a0c7bff524be_2.png +3 -0
- assets/mona_lisa/mask_aae09152-4495-4332-b691-a0c7bff524be.png +3 -0
- assets/mona_lisa/masked_image_aae09152-4495-4332-b691-a0c7bff524be.png +3 -0
- assets/mona_lisa/mona_lisa.png +3 -0
- assets/mona_lisa/prompt.txt +1 -0
- assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png +3 -0
- assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png +3 -0
- assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png +3 -0
- assets/sunflower_girl/prompt.txt +1 -0
- assets/sunflower_girl/sunflower_girl.png +3 -0
- requirements.txt +20 -0
.gitattributes
ADDED
@@ -0,0 +1,62 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
*.png filter=lfs diff=lfs merge=lfs -text
*.jpg filter=lfs diff=lfs merge=lfs -text
*.jpeg filter=lfs diff=lfs merge=lfs -text
*.webp filter=lfs diff=lfs merge=lfs -text
*.gif filter=lfs diff=lfs merge=lfs -text
*.bmp filter=lfs diff=lfs merge=lfs -text
*.tiff filter=lfs diff=lfs merge=lfs -text
assets/hedgehog_rm_fg/hedgehog.png filter=lfs diff=lfs merge=lfs -text
assets/hedgehog_rm_fg/image_edit_82314e18-c64c-4003-9ef9-52cebf254b2f_2.png filter=lfs diff=lfs merge=lfs -text
assets/hedgehog_rm_fg/mask_82314e18-c64c-4003-9ef9-52cebf254b2f.png filter=lfs diff=lfs merge=lfs -text
assets/hedgehog_rm_fg/masked_image_82314e18-c64c-4003-9ef9-52cebf254b2f.png filter=lfs diff=lfs merge=lfs -text
assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png filter=lfs diff=lfs merge=lfs -text
assets/hedgehog_rp_bg/hedgehog.png filter=lfs diff=lfs merge=lfs -text
assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png filter=lfs diff=lfs merge=lfs -text
assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png filter=lfs diff=lfs merge=lfs -text
assets/hedgehog_rp_fg/hedgehog.png filter=lfs diff=lfs merge=lfs -text
assets/hedgehog_rp_fg/image_edit_5cab3448-5a3a-459c-9144-35cca3d34273_0.png filter=lfs diff=lfs merge=lfs -text
assets/hedgehog_rp_fg/mask_5cab3448-5a3a-459c-9144-35cca3d34273.png filter=lfs diff=lfs merge=lfs -text
assets/hedgehog_rp_fg/masked_image_5cab3448-5a3a-459c-9144-35cca3d34273.png filter=lfs diff=lfs merge=lfs -text
assets/mona_lisa/image_edit_aae09152-4495-4332-b691-a0c7bff524be_2.png filter=lfs diff=lfs merge=lfs -text
assets/mona_lisa/mask_aae09152-4495-4332-b691-a0c7bff524be.png filter=lfs diff=lfs merge=lfs -text
assets/mona_lisa/masked_image_aae09152-4495-4332-b691-a0c7bff524be.png filter=lfs diff=lfs merge=lfs -text
assets/mona_lisa/mona_lisa.png filter=lfs diff=lfs merge=lfs -text
assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png filter=lfs diff=lfs merge=lfs -text
assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png filter=lfs diff=lfs merge=lfs -text
assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png filter=lfs diff=lfs merge=lfs -text
assets/sunflower_girl/sunflower_girl.png filter=lfs diff=lfs merge=lfs -text
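Every image and checkpoint pattern above is routed through Git LFS, which is why the asset files in this commit appear as "+3 -0": only the three-line LFS pointer is stored in Git. As a small illustrative check (not part of the commit; the path is one of the assets listed above), a clone that has not run `git lfs pull` can be probed like this:

```python
# Probe whether a checked-out asset is still an LFS pointer or the real binary.
# Pointer files begin with the LFS spec line and are only a few lines long.
from pathlib import Path

path = Path("assets/mona_lisa/mona_lisa.png")
head = path.read_bytes()[:64]
if head.startswith(b"version https://git-lfs.github.com/spec/"):
    print(f"{path} is an LFS pointer; run `git lfs pull` to fetch the image")
else:
    print(f"{path} is the real file ({path.stat().st_size} bytes)")
```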
.gitignore
ADDED
@@ -0,0 +1,178 @@
# Initially taken from GitHub's Python gitignore file

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# tests and logs
tests/fixtures/cached_*_text.txt
logs/
lightning_logs/
lang_code_data/

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a Python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# vscode
.vs
.vscode

# Pycharm
.idea

# TF code
tensorflow_code

# Models
proc_data

# examples
runs
/runs_old
/wandb
/examples/runs
/examples/**/*.args
/examples/rag/sweep

# data
/data
serialization_dir

# emacs
*.*~
debug.env

# vim
.*.swp

# ctags
tags

# pre-commit
.pre-commit*

# .lock
*.lock

# DS_Store (MacOS)
.DS_Store

# RL pipelines may produce mp4 outputs
*.mp4

# dependencies
/transformers

# ruff
.ruff_cache

# wandb
wandb
README.md
ADDED
@@ -0,0 +1,11 @@
---
title: BrushEdit
emoji: 🤠
colorFrom: indigo
colorTo: gray
sdk: gradio
sdk_version: 4.38.1
app_file: app/gpt4_o/brushedit_app.py
pinned: false
python_version: 3.1
---
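The front matter above is the Hugging Face Spaces configuration: the Space runs `app/gpt4_o/brushedit_app.py` under the Gradio SDK 4.38.1. A minimal local-run sketch (not part of the commit) might look like the following; it assumes the repository root as the working directory so the `app.gpt4_o.*` imports resolve, a CUDA GPU, the packages from requirements.txt, and the two secrets the app reads via `os.getenv`:

```python
# Minimal local-launch sketch; an assumption about how to reproduce the Space locally,
# not a command documented in this commit.
import os
import subprocess

for var in ("OPENAI_API_KEY", "HF_TOKEN"):  # secrets the app file reads at import time
    if not os.getenv(var):
        raise SystemExit(f"set {var} before launching")

subprocess.run(["python", "app/gpt4_o/brushedit_app.py"], check=True)
```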
app/gpt4_o/brushedit_all_in_one_pipeline.py
ADDED
@@ -0,0 +1,80 @@
from PIL import Image, ImageEnhance
from diffusers.image_processor import VaeImageProcessor

import numpy as np
import cv2


def BrushEdit_Pipeline(pipe,
                       prompts,
                       mask_np,
                       original_image,
                       generator,
                       num_inference_steps,
                       guidance_scale,
                       control_strength,
                       negative_prompt,
                       num_samples,
                       blending):
    if mask_np.ndim != 3:
        mask_np = mask_np[:, :, np.newaxis]

    mask_np = mask_np / 255
    height, width = mask_np.shape[0], mask_np.shape[1]
    # back/foreground
    # if mask_np[94:547,94:546].sum() < mask_np.sum() - mask_np[94:547,94:546].sum() and mask_np[0,:].sum()>0 and mask_np[-1,:].sum()>0 and mask_np[:,0].sum()>0 and mask_np[:,-1].sum()>0 and mask_np[1,:].sum()>0 and mask_np[-2,:].sum()>0 and mask_np[:,1].sum()>0 and mask_np[:,-2].sum()>0 :
    #     mask_np = 1 - mask_np

    ## resize the mask and original image to the same size which is divisible by vae_scale_factor
    image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
    height_new, width_new = image_processor.get_default_height_width(original_image, height, width)
    mask_np = cv2.resize(mask_np, (width_new, height_new))[:,:,np.newaxis]
    mask_blurred = cv2.GaussianBlur(mask_np*255, (21, 21), 0)/255
    mask_blurred = mask_blurred[:, :, np.newaxis]

    original_image = cv2.resize(original_image, (width_new, height_new))

    init_image = original_image * (1 - mask_np)
    init_image = Image.fromarray(init_image.astype(np.uint8)).convert("RGB")
    mask_image = Image.fromarray((mask_np.repeat(3, -1) * 255).astype(np.uint8)).convert("RGB")

    brushnet_conditioning_scale = float(control_strength)

    images = pipe(
        [prompts] * num_samples,
        init_image,
        mask_image,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
        brushnet_conditioning_scale=brushnet_conditioning_scale,
        negative_prompt=[negative_prompt]*num_samples,
        height=height_new,
        width=width_new,
    ).images

    if blending:

        mask_blurred = mask_blurred * 0.5 + 0.5

        ## convert to vae shape format, must be divisible by 8
        original_image_pil = Image.fromarray(original_image).convert("RGB")
        init_image_np = np.array(image_processor.preprocess(original_image_pil, height=height_new, width=width_new).squeeze())
        init_image_np = ((init_image_np.transpose(1,2,0) + 1.) / 2.) * 255
        init_image_np = init_image_np.astype(np.uint8)
        image_all = []
        for image_i in images:
            image_np = np.array(image_i)
            ## blending
            image_pasted = init_image_np * (1 - mask_blurred) + mask_blurred * image_np
            image_pasted = image_pasted.astype(np.uint8)
            image = Image.fromarray(image_pasted)
            image_all.append(image)
    else:
        image_all = images

    return image_all, mask_image
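The final blending branch is the part of this function that is easy to reason about in isolation: a Gaussian-blurred copy of the mask softly pastes each inpainted sample back onto the VAE-aligned original. A toy sketch with stand-in arrays, not actual diffusion output, reproduces the same arithmetic:

```python
# Toy illustration of the blending step: blur the mask, then alpha-composite the
# "edited" result over the original. Stand-in arrays replace the pipe(...) output.
import numpy as np
import cv2
from PIL import Image

original = np.full((256, 256, 3), 40, dtype=np.uint8)    # stand-in for the source photo
edited = np.full((256, 256, 3), 220, dtype=np.uint8)     # stand-in for one inpainted sample
mask = np.zeros((256, 256, 1), dtype=np.uint8)
mask[96:160, 96:160] = 255                                # region to repaint

mask_blurred = cv2.GaussianBlur(mask, (21, 21), 0) / 255  # soft edges, shape (H, W)
mask_blurred = mask_blurred[:, :, None] * 0.5 + 0.5       # same softening as above

blended = original * (1 - mask_blurred) + mask_blurred * edited
Image.fromarray(blended.astype(np.uint8)).save("blend_demo.png")
```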
app/gpt4_o/brushedit_app.py
ADDED
@@ -0,0 +1,914 @@
1 |
+
##!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import os, random
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
|
7 |
+
import gradio as gr
|
8 |
+
import spaces
|
9 |
+
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
|
13 |
+
from huggingface_hub import hf_hub_download
|
14 |
+
|
15 |
+
from segment_anything import SamPredictor, build_sam, SamAutomaticMaskGenerator
|
16 |
+
from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
|
17 |
+
from scipy.ndimage import binary_dilation, binary_erosion
|
18 |
+
|
19 |
+
from app.gpt4_o.vlm_pipeline import (
|
20 |
+
vlm_response_editing_type,
|
21 |
+
vlm_response_object_wait_for_edit,
|
22 |
+
vlm_response_mask,
|
23 |
+
vlm_response_prompt_after_apply_instruction
|
24 |
+
)
|
25 |
+
from app.gpt4_o.brushedit_all_in_one_pipeline import BrushEdit_Pipeline
|
26 |
+
from app.utils.utils import load_grounding_dino_model
|
27 |
+
|
28 |
+
|
29 |
+
#### Description ####
|
30 |
+
head = r"""
|
31 |
+
<div style="text-align: center;">
|
32 |
+
<h1> BrushEdit: All-In-One Image Inpainting and Editing</h1>
|
33 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
34 |
+
<a href='https://tencentarc.github.io/BrushNet/'><img src='https://img.shields.io/badge/Project_Page-BrushNet-green' alt='Project Page'></a>
|
35 |
+
<a href='https://github.com/TencentARC/BrushNet/blob/main/InstructionGuidedEditing/CVPR2024workshop_technique_report.pdf'><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
|
36 |
+
<a href='https://github.com/TencentARC/BrushNet'><img src='https://img.shields.io/badge/Code-Github-orange'></a>
|
37 |
+
|
38 |
+
</div>
|
39 |
+
</br>
|
40 |
+
</div>
|
41 |
+
"""
|
42 |
+
descriptions = r"""
|
43 |
+
Official Gradio Demo for <a href='https://tencentarc.github.io/BrushNet/'><b>BrushEdit: All-In-One Image Inpainting and Editing</b></a><br>
|
44 |
+
🧙 BrushEdit enables precise, user-friendly instruction-based image editing via a inpainting model.<br>
|
45 |
+
"""
|
46 |
+
|
47 |
+
instructions = r"""
|
48 |
+
Currently, we support two modes: <b>fully automated command editing</b> and <b>interactive command editing</b>.
|
49 |
+
|
50 |
+
🛠️ <b>Fully automated instruction-based editing</b>:
|
51 |
+
<ul>
|
52 |
+
<li> ⭐️ <b>step1:</b> Upload or select one image from Example. </li>
|
53 |
+
<li> ⭐️ <b>step2:</b> Input the instructions (supports addition, deletion, and modification), e.g. remove xxx .</li>
|
54 |
+
<li> ⭐️ <b>step3:</b> Click <b>Run</b> button to automatic edit image.</li>
|
55 |
+
</ul>
|
56 |
+
|
57 |
+
🛠️ <b>Interactive instruction-based editing</b>:
|
58 |
+
<ul>
|
59 |
+
<li> ⭐️ <b>step1:</b> Upload or select one image from Example. </li>
|
60 |
+
<li> ⭐️ <b>step2:</b> Use a brush to outline the area you want to edit. </li>
|
61 |
+
<li> ⭐️ <b>step3:</b> Input the instructions. </li>
|
62 |
+
<li> ⭐️ <b>step4:</b> Click <b>Run</b> button to automatic edit image. </li>
|
63 |
+
</ul>
|
64 |
+
|
65 |
+
💡 <b>Some tips</b>:
|
66 |
+
<ul>
|
67 |
+
<li> 🤠 After input the instructions, you can click the <b>Generate Mask</b> button. The mask generated by VLM will be displayed in the preview panel on the right side. </li>
|
68 |
+
<li> 🤠 After generating the mask or when you use the brush to draw the mask, you can perform operations such as <b>randomization</b>, <b>dilation</b>, <b>erosion</b>, and <b>movement</b>. </li>
|
69 |
+
<li> 🤠 After input the instructions, you can click the <b>Generate Target Prompt</b> button. The target prompt will be displayed in the text box, and you can modify it according to your ideas. </li>
|
70 |
+
</ul>
|
71 |
+
|
72 |
+
☕️ Have fun!
|
73 |
+
"""
|
74 |
+
|
75 |
+
|
76 |
+
# - - - - - examples - - - - - #
|
77 |
+
EXAMPLES = [
|
78 |
+
# [
|
79 |
+
# {"background": Image.open("assets/mona_lisa/image_edit_aae09152-4495-4332-b691-a0c7bff524be_2.png").convert("RGBA"),
|
80 |
+
# "layers": [Image.new("RGBA", (Image.open("assets/mona_lisa/image_edit_aae09152-4495-4332-b691-a0c7bff524be_2.png").width, Image.open("assets/mona_lisa/image_edit_aae09152-4495-4332-b691-a0c7bff524be_2.png").height), (0, 0, 0, 0))],
|
81 |
+
# "composite": Image.open("assets/mona_lisa/image_edit_aae09152-4495-4332-b691-a0c7bff524be_2.png").convert("RGBA")},
|
82 |
+
# # Image.open("assets/mona_lisa/image_edit_aae09152-4495-4332-b691-a0c7bff524be_2.png").convert("RGBA"),
|
83 |
+
# "add a shining necklace",
|
84 |
+
# # [Image.open("assets/mona_lisa/image_edit_aae09152-4495-4332-b691-a0c7bff524be_2.jpg")],
|
85 |
+
# # [Image.open("assets/mona_lisa/mask_aae09152-4495-4332-b691-a0c7bff524be.png")],
|
86 |
+
# # [Image.open("assets/mona_lisa/masked_image_aae09152-4495-4332-b691-a0c7bff524be.png")]
|
87 |
+
# ],
|
88 |
+
|
89 |
+
[
|
90 |
+
# load_image_from_url("https://github.com/liyaowei-stu/BrushEdit/blob/main/assets/mona_lisa/mona_lisa.png"),
|
91 |
+
Image.open("assets/mona_lisa/mona_lisa.png").convert("RGBA"),
|
92 |
+
"add a shining necklace",
|
93 |
+
# [Image.open("assets/mona_lisa/image_edit_aae09152-4495-4332-b691-a0c7bff524be_2.jpg")],
|
94 |
+
# [Image.open("assets/mona_lisa/mask_aae09152-4495-4332-b691-a0c7bff524be.png")],
|
95 |
+
# [Image.open("assets/mona_lisa/masked_image_aae09152-4495-4332-b691-a0c7bff524be.png")]
|
96 |
+
],
|
97 |
+
|
98 |
+
|
99 |
+
|
100 |
+
|
101 |
+
]
|
102 |
+
|
103 |
+
|
104 |
+
## init VLM
|
105 |
+
from openai import OpenAI
|
106 |
+
|
107 |
+
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
|
108 |
+
os.environ['OPENAI_API_KEY'] = OPENAI_API_KEY
|
109 |
+
vlm = OpenAI(base_url="http://v2.open.venus.oa.com/llmproxy")
|
110 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
111 |
+
|
112 |
+
|
113 |
+
|
114 |
+
# download hf models
|
115 |
+
base_model_path = hf_hub_download(
|
116 |
+
repo_id="Yw22/BrushEdit",
|
117 |
+
subfolder="base_model/realisticVisionV60B1_v51VAE",
|
118 |
+
token=os.getenv("HF_TOKEN"),
|
119 |
+
)
|
120 |
+
|
121 |
+
|
122 |
+
brushnet_path = hf_hub_download(
|
123 |
+
repo_id="Yw22/BrushEdit",
|
124 |
+
subfolder="brushnetX",
|
125 |
+
token=os.getenv("HF_TOKEN"),
|
126 |
+
)
|
127 |
+
|
128 |
+
sam_path = hf_hub_download(
|
129 |
+
repo_id="Yw22/BrushEdit",
|
130 |
+
subfolder="sam",
|
131 |
+
filename="sam_vit_h_4b8939.pth",
|
132 |
+
token=os.getenv("HF_TOKEN"),
|
133 |
+
)
|
134 |
+
|
135 |
+
groundingdino_path = hf_hub_download(
|
136 |
+
repo_id="Yw22/BrushEdit",
|
137 |
+
subfolder="grounding_dino",
|
138 |
+
filename="groundingdino_swint_ogc.pth",
|
139 |
+
token=os.getenv("HF_TOKEN"),
|
140 |
+
)
|
141 |
+
|
142 |
+
|
143 |
+
# input brushnetX ckpt path
|
144 |
+
brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch.float16)
|
145 |
+
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
|
146 |
+
base_model_path, brushnet=brushnet, torch_dtype=torch.float16, low_cpu_mem_usage=False
|
147 |
+
)
|
148 |
+
# speed up diffusion process with faster scheduler and memory optimization
|
149 |
+
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
150 |
+
# remove following line if xformers is not installed or when using Torch 2.0.
|
151 |
+
# pipe.enable_xformers_memory_efficient_attention()
|
152 |
+
pipe.enable_model_cpu_offload()
|
153 |
+
|
154 |
+
|
155 |
+
## init SAM
|
156 |
+
sam = build_sam(checkpoint=sam_path)
|
157 |
+
sam.to(device=device)
|
158 |
+
sam_predictor = SamPredictor(sam)
|
159 |
+
sam_automask_generator = SamAutomaticMaskGenerator(sam)
|
160 |
+
|
161 |
+
## init groundingdino_model
|
162 |
+
config_file = 'third_party/Grounded-Segment-Anything/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
|
163 |
+
groundingdino_model = load_grounding_dino_model(config_file, groundingdino_path, device=device)
|
164 |
+
|
165 |
+
## Ordinary function
|
166 |
+
def crop_and_resize(image: Image.Image,
|
167 |
+
target_width: int,
|
168 |
+
target_height: int) -> Image.Image:
|
169 |
+
"""
|
170 |
+
Crops and resizes an image while preserving the aspect ratio.
|
171 |
+
|
172 |
+
Args:
|
173 |
+
image (Image.Image): Input PIL image to be cropped and resized.
|
174 |
+
target_width (int): Target width of the output image.
|
175 |
+
target_height (int): Target height of the output image.
|
176 |
+
|
177 |
+
Returns:
|
178 |
+
Image.Image: Cropped and resized image.
|
179 |
+
"""
|
180 |
+
# Original dimensions
|
181 |
+
original_width, original_height = image.size
|
182 |
+
original_aspect = original_width / original_height
|
183 |
+
target_aspect = target_width / target_height
|
184 |
+
|
185 |
+
# Calculate crop box to maintain aspect ratio
|
186 |
+
if original_aspect > target_aspect:
|
187 |
+
# Crop horizontally
|
188 |
+
new_width = int(original_height * target_aspect)
|
189 |
+
new_height = original_height
|
190 |
+
left = (original_width - new_width) / 2
|
191 |
+
top = 0
|
192 |
+
right = left + new_width
|
193 |
+
bottom = original_height
|
194 |
+
else:
|
195 |
+
# Crop vertically
|
196 |
+
new_width = original_width
|
197 |
+
new_height = int(original_width / target_aspect)
|
198 |
+
left = 0
|
199 |
+
top = (original_height - new_height) / 2
|
200 |
+
right = original_width
|
201 |
+
bottom = top + new_height
|
202 |
+
|
203 |
+
# Crop and resize
|
204 |
+
cropped_image = image.crop((left, top, right, bottom))
|
205 |
+
resized_image = cropped_image.resize((target_width, target_height), Image.NEAREST)
|
206 |
+
|
207 |
+
return resized_image
|
208 |
+
|
209 |
+
|
210 |
+
def move_mask_func(mask, direction, units):
|
211 |
+
binary_mask = mask.squeeze()>0
|
212 |
+
rows, cols = binary_mask.shape
|
213 |
+
|
214 |
+
moved_mask = np.zeros_like(binary_mask, dtype=bool)
|
215 |
+
|
216 |
+
if direction == 'down':
|
217 |
+
# move down
|
218 |
+
moved_mask[max(0, units):, :] = binary_mask[:rows - units, :]
|
219 |
+
|
220 |
+
elif direction == 'up':
|
221 |
+
# move up
|
222 |
+
moved_mask[:rows - units, :] = binary_mask[units:, :]
|
223 |
+
|
224 |
+
elif direction == 'right':
|
225 |
+
# move left
|
226 |
+
moved_mask[:, max(0, units):] = binary_mask[:, :cols - units]
|
227 |
+
|
228 |
+
elif direction == 'left':
|
229 |
+
# move right
|
230 |
+
moved_mask[:, :cols - units] = binary_mask[:, units:]
|
231 |
+
|
232 |
+
return moved_mask
|
233 |
+
|
234 |
+
|
235 |
+
def random_mask_func(mask, dilation_type='square'):
|
236 |
+
# Randomly select the size of dilation
|
237 |
+
dilation_size = np.random.randint(20, 40) # Randomly select the size of dilation
|
238 |
+
binary_mask = mask.squeeze()>0
|
239 |
+
|
240 |
+
if dilation_type == 'square_dilation':
|
241 |
+
structure = np.ones((dilation_size, dilation_size), dtype=bool)
|
242 |
+
dilated_mask = binary_dilation(binary_mask, structure=structure)
|
243 |
+
elif dilation_type == 'square_erosion':
|
244 |
+
structure = np.ones((dilation_size, dilation_size), dtype=bool)
|
245 |
+
dilated_mask = binary_erosion(binary_mask, structure=structure)
|
246 |
+
elif dilation_type == 'bounding_box':
|
247 |
+
# find the most left top and left bottom point
|
248 |
+
rows, cols = np.where(binary_mask)
|
249 |
+
if len(rows) == 0 or len(cols) == 0:
|
250 |
+
return mask # return original mask if no valid points
|
251 |
+
|
252 |
+
min_row = np.min(rows)
|
253 |
+
max_row = np.max(rows)
|
254 |
+
min_col = np.min(cols)
|
255 |
+
max_col = np.max(cols)
|
256 |
+
|
257 |
+
# create a bounding box
|
258 |
+
dilated_mask = np.zeros_like(binary_mask, dtype=bool)
|
259 |
+
dilated_mask[min_row:max_row + 1, min_col:max_col + 1] = True
|
260 |
+
|
261 |
+
elif dilation_type == 'bounding_ellipse':
|
262 |
+
# find the most left top and left bottom point
|
263 |
+
rows, cols = np.where(binary_mask)
|
264 |
+
if len(rows) == 0 or len(cols) == 0:
|
265 |
+
return mask # return original mask if no valid points
|
266 |
+
|
267 |
+
min_row = np.min(rows)
|
268 |
+
max_row = np.max(rows)
|
269 |
+
min_col = np.min(cols)
|
270 |
+
max_col = np.max(cols)
|
271 |
+
|
272 |
+
# calculate the center and axis length of the ellipse
|
273 |
+
center = ((min_col + max_col) // 2, (min_row + max_row) // 2)
|
274 |
+
a = (max_col - min_col) // 2 # half long axis
|
275 |
+
b = (max_row - min_row) // 2 # half short axis
|
276 |
+
|
277 |
+
# create a bounding ellipse
|
278 |
+
y, x = np.ogrid[:mask.shape[0], :mask.shape[1]]
|
279 |
+
ellipse_mask = ((x - center[0])**2 / a**2 + (y - center[1])**2 / b**2) <= 1
|
280 |
+
dilated_mask = np.zeros_like(binary_mask, dtype=bool)
|
281 |
+
dilated_mask[ellipse_mask] = True
|
282 |
+
else:
|
283 |
+
raise ValueError("dilation_type must be 'square' or 'ellipse'")
|
284 |
+
|
285 |
+
# use binary dilation
|
286 |
+
dilated_mask = np.uint8(dilated_mask[:,:,np.newaxis]) * 255
|
287 |
+
return dilated_mask
|
288 |
+
|
289 |
+
|
290 |
+
## Gradio component function
|
291 |
+
@spaces.GPU(duration=180)
|
292 |
+
def process(input_image,
|
293 |
+
original_image,
|
294 |
+
original_mask,
|
295 |
+
prompt,
|
296 |
+
negative_prompt,
|
297 |
+
control_strength,
|
298 |
+
seed,
|
299 |
+
randomize_seed,
|
300 |
+
guidance_scale,
|
301 |
+
num_inference_steps,
|
302 |
+
num_samples,
|
303 |
+
blending,
|
304 |
+
category,
|
305 |
+
target_prompt,
|
306 |
+
resize_and_crop):
|
307 |
+
|
308 |
+
import ipdb; ipdb.set_trace()
|
309 |
+
if original_image is None:
|
310 |
+
raise gr.Error('Please upload the input image')
|
311 |
+
if prompt is None:
|
312 |
+
raise gr.Error("Please input your instructions, e.g., remove the xxx")
|
313 |
+
|
314 |
+
|
315 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
316 |
+
input_mask = np.asarray(alpha_mask)
|
317 |
+
if resize_and_crop:
|
318 |
+
original_image = crop_and_resize(Image.fromarray(original_image), target_width=640, target_height=640)
|
319 |
+
input_mask = crop_and_resize(Image.fromarray(input_mask), target_width=640, target_height=640)
|
320 |
+
original_image = np.array(original_image)
|
321 |
+
input_mask = np.array(input_mask)
|
322 |
+
|
323 |
+
if input_mask.max() == 0:
|
324 |
+
original_mask = original_mask
|
325 |
+
else:
|
326 |
+
original_mask = input_mask[:,:,None]
|
327 |
+
|
328 |
+
# load example image
|
329 |
+
# if isinstance(original_image, str):
|
330 |
+
# # image_name = image_examples[original_image][0]
|
331 |
+
# # original_image = cv2.imread(image_name)
|
332 |
+
# # original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
|
333 |
+
# original_image = input_image
|
334 |
+
# num_samples = 1
|
335 |
+
# blending = True
|
336 |
+
|
337 |
+
if category is not None:
|
338 |
+
pass
|
339 |
+
else:
|
340 |
+
category = vlm_response_editing_type(vlm, original_image, prompt)
|
341 |
+
|
342 |
+
|
343 |
+
if original_mask is not None:
|
344 |
+
original_mask = np.clip(original_mask, 0, 255).astype(np.uint8)
|
345 |
+
else:
|
346 |
+
object_wait_for_edit = vlm_response_object_wait_for_edit(vlm,
|
347 |
+
category,
|
348 |
+
prompt)
|
349 |
+
original_mask = vlm_response_mask(vlm,
|
350 |
+
category,
|
351 |
+
original_image,
|
352 |
+
prompt,
|
353 |
+
object_wait_for_edit,
|
354 |
+
sam,
|
355 |
+
sam_predictor,
|
356 |
+
sam_automask_generator,
|
357 |
+
groundingdino_model,
|
358 |
+
)[:,:,None]
|
359 |
+
|
360 |
+
|
361 |
+
if len(target_prompt) <= 1:
|
362 |
+
prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(vlm,
|
363 |
+
original_image,
|
364 |
+
prompt)
|
365 |
+
else:
|
366 |
+
prompt_after_apply_instruction = target_prompt
|
367 |
+
|
368 |
+
generator = torch.Generator("cuda").manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)
|
369 |
+
|
370 |
+
|
371 |
+
|
372 |
+
image, mask_image = BrushEdit_Pipeline(pipe,
|
373 |
+
prompt_after_apply_instruction,
|
374 |
+
original_mask,
|
375 |
+
original_image,
|
376 |
+
generator,
|
377 |
+
num_inference_steps,
|
378 |
+
guidance_scale,
|
379 |
+
control_strength,
|
380 |
+
negative_prompt,
|
381 |
+
num_samples,
|
382 |
+
blending)
|
383 |
+
|
384 |
+
masked_image = original_image * (1 - (original_mask>0))
|
385 |
+
masked_image = masked_image.astype(np.uint8)
|
386 |
+
masked_image = Image.fromarray(masked_image)
|
387 |
+
|
388 |
+
import uuid
|
389 |
+
uuid = str(uuid.uuid4())
|
390 |
+
image[0].save(f"outputs/image_edit_{uuid}_0.png")
|
391 |
+
image[1].save(f"outputs/image_edit_{uuid}_1.png")
|
392 |
+
image[2].save(f"outputs/image_edit_{uuid}_2.png")
|
393 |
+
image[3].save(f"outputs/image_edit_{uuid}_3.png")
|
394 |
+
mask_image.save(f"outputs/mask_{uuid}.png")
|
395 |
+
masked_image.save(f"outputs/masked_image_{uuid}.png")
|
396 |
+
return image, [mask_image], [masked_image], ''
|
397 |
+
|
398 |
+
|
399 |
+
def generate_target_prompt(input_image,
|
400 |
+
original_image,
|
401 |
+
prompt):
|
402 |
+
# load example image
|
403 |
+
if isinstance(original_image, str):
|
404 |
+
original_image = input_image
|
405 |
+
|
406 |
+
prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(vlm,
|
407 |
+
original_image,
|
408 |
+
prompt)
|
409 |
+
return prompt_after_apply_instruction
|
410 |
+
|
411 |
+
|
412 |
+
def process_mask(input_image,
|
413 |
+
original_image,
|
414 |
+
prompt,
|
415 |
+
resize_and_crop):
|
416 |
+
if original_image is None:
|
417 |
+
raise gr.Error('Please upload the input image')
|
418 |
+
if prompt is None:
|
419 |
+
raise gr.Error("Please input your instructions, e.g., remove the xxx")
|
420 |
+
|
421 |
+
## load mask
|
422 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
423 |
+
input_mask = np.array(alpha_mask)
|
424 |
+
|
425 |
+
# load example image
|
426 |
+
if isinstance(original_image, str):
|
427 |
+
original_image = input_image["background"]
|
428 |
+
|
429 |
+
if resize_and_crop:
|
430 |
+
original_image = crop_and_resize(Image.fromarray(original_image), target_width=640, target_height=640)
|
431 |
+
input_mask = crop_and_resize(Image.fromarray(input_mask), target_width=640, target_height=640)
|
432 |
+
original_image = np.array(original_image)
|
433 |
+
input_mask = np.array(input_mask)
|
434 |
+
|
435 |
+
|
436 |
+
if input_mask.max() == 0:
|
437 |
+
category = vlm_response_editing_type(vlm, original_image, prompt)
|
438 |
+
|
439 |
+
object_wait_for_edit = vlm_response_object_wait_for_edit(vlm,
|
440 |
+
category,
|
441 |
+
prompt)
|
442 |
+
# original mask: h,w,1 [0, 255]
|
443 |
+
original_mask = vlm_response_mask(
|
444 |
+
vlm,
|
445 |
+
category,
|
446 |
+
original_image,
|
447 |
+
prompt,
|
448 |
+
object_wait_for_edit,
|
449 |
+
sam,
|
450 |
+
sam_predictor,
|
451 |
+
sam_automask_generator,
|
452 |
+
groundingdino_model,
|
453 |
+
)[:,:,None]
|
454 |
+
else:
|
455 |
+
original_mask = input_mask[:,:,None]
|
456 |
+
category = None
|
457 |
+
|
458 |
+
|
459 |
+
mask_image = Image.fromarray(original_mask.squeeze().astype(np.uint8)).convert("RGB")
|
460 |
+
|
461 |
+
masked_image = original_image * (1 - (original_mask>0))
|
462 |
+
masked_image = masked_image.astype(np.uint8)
|
463 |
+
masked_image = Image.fromarray(masked_image)
|
464 |
+
|
465 |
+
## not work for image editor
|
466 |
+
# background = input_image["background"]
|
467 |
+
# mask_array = original_mask.squeeze()
|
468 |
+
# layer_rgba = np.array(input_image['layers'][0])
|
469 |
+
# layer_rgba[mask_array > 0] = [0, 0, 0, 255]
|
470 |
+
# layer_rgba = Image.fromarray(layer_rgba, 'RGBA')
|
471 |
+
# black_image = Image.new("RGBA", layer_rgba.size, (0, 0, 0, 255))
|
472 |
+
# composite = Image.composite(black_image, background, layer_rgba)
|
473 |
+
# output_base = {"layers": [layer_rgba], "background": background, "composite": composite}
|
474 |
+
|
475 |
+
|
476 |
+
return [masked_image], [mask_image], original_mask.astype(np.uint8), category
|
477 |
+
|
478 |
+
|
479 |
+
def process_random_mask(input_image, original_image, original_mask, resize_and_crop):
|
480 |
+
|
481 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
482 |
+
input_mask = np.asarray(alpha_mask)
|
483 |
+
if resize_and_crop:
|
484 |
+
original_image = crop_and_resize(Image.fromarray(original_image), target_width=640, target_height=640)
|
485 |
+
input_mask = crop_and_resize(Image.fromarray(input_mask), target_width=640, target_height=640)
|
486 |
+
original_image = np.array(original_image)
|
487 |
+
input_mask = np.array(input_mask)
|
488 |
+
|
489 |
+
|
490 |
+
if input_mask.max() == 0:
|
491 |
+
if original_mask is None:
|
492 |
+
raise gr.Error('Please generate mask first')
|
493 |
+
original_mask = original_mask
|
494 |
+
else:
|
495 |
+
original_mask = input_mask[:,:,None]
|
496 |
+
|
497 |
+
|
498 |
+
dilation_type = np.random.choice(['bounding_box', 'bounding_ellipse'])
|
499 |
+
random_mask = random_mask_func(original_mask, dilation_type).squeeze()
|
500 |
+
|
501 |
+
mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
|
502 |
+
|
503 |
+
masked_image = original_image * (1 - (random_mask[:,:,None]>0))
|
504 |
+
masked_image = masked_image.astype(original_image.dtype)
|
505 |
+
masked_image = Image.fromarray(masked_image)
|
506 |
+
|
507 |
+
|
508 |
+
return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
|
509 |
+
|
510 |
+
|
511 |
+
def process_dilation_mask(input_image, original_image, original_mask, resize_and_crop):
|
512 |
+
|
513 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
514 |
+
input_mask = np.asarray(alpha_mask)
|
515 |
+
if resize_and_crop:
|
516 |
+
original_image = crop_and_resize(Image.fromarray(original_image), target_width=640, target_height=640)
|
517 |
+
input_mask = crop_and_resize(Image.fromarray(input_mask), target_width=640, target_height=640)
|
518 |
+
original_image = np.array(original_image)
|
519 |
+
input_mask = np.array(input_mask)
|
520 |
+
|
521 |
+
if input_mask.max() == 0:
|
522 |
+
if original_mask is None:
|
523 |
+
raise gr.Error('Please generate mask first')
|
524 |
+
original_mask = original_mask
|
525 |
+
else:
|
526 |
+
original_mask = input_mask[:,:,None]
|
527 |
+
|
528 |
+
dilation_type = np.random.choice(['square_dilation'])
|
529 |
+
random_mask = random_mask_func(original_mask, dilation_type).squeeze()
|
530 |
+
|
531 |
+
mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
|
532 |
+
|
533 |
+
masked_image = original_image * (1 - (random_mask[:,:,None]>0))
|
534 |
+
masked_image = masked_image.astype(original_image.dtype)
|
535 |
+
masked_image = Image.fromarray(masked_image)
|
536 |
+
|
537 |
+
return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
|
538 |
+
|
539 |
+
|
540 |
+
def process_erosion_mask(input_image, original_image, original_mask, resize_and_crop):
|
541 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
542 |
+
input_mask = np.asarray(alpha_mask)
|
543 |
+
if resize_and_crop:
|
544 |
+
original_image = crop_and_resize(Image.fromarray(original_image), target_width=640, target_height=640)
|
545 |
+
input_mask = crop_and_resize(Image.fromarray(input_mask), target_width=640, target_height=640)
|
546 |
+
original_image = np.array(original_image)
|
547 |
+
input_mask = np.array(input_mask)
|
548 |
+
|
549 |
+
if input_mask.max() == 0:
|
550 |
+
if original_mask is None:
|
551 |
+
raise gr.Error('Please generate mask first')
|
552 |
+
original_mask = original_mask
|
553 |
+
else:
|
554 |
+
original_mask = input_mask[:,:,None]
|
555 |
+
|
556 |
+
dilation_type = np.random.choice(['square_erosion'])
|
557 |
+
random_mask = random_mask_func(original_mask, dilation_type).squeeze()
|
558 |
+
|
559 |
+
mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
|
560 |
+
|
561 |
+
masked_image = original_image * (1 - (random_mask[:,:,None]>0))
|
562 |
+
masked_image = masked_image.astype(original_image.dtype)
|
563 |
+
masked_image = Image.fromarray(masked_image)
|
564 |
+
|
565 |
+
|
566 |
+
return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
|
567 |
+
|
568 |
+
|
569 |
+
def move_mask_left(input_image, original_image, original_mask, moving_pixels, resize_and_crop):
|
570 |
+
|
571 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
572 |
+
input_mask = np.asarray(alpha_mask)
|
573 |
+
if resize_and_crop:
|
574 |
+
original_image = crop_and_resize(Image.fromarray(original_image), target_width=640, target_height=640)
|
575 |
+
input_mask = crop_and_resize(Image.fromarray(input_mask), target_width=640, target_height=640)
|
576 |
+
original_image = np.array(original_image)
|
577 |
+
input_mask = np.array(input_mask)
|
578 |
+
|
579 |
+
if input_mask.max() == 0:
|
580 |
+
if original_mask is None:
|
581 |
+
raise gr.Error('Please generate mask first')
|
582 |
+
original_mask = original_mask
|
583 |
+
else:
|
584 |
+
original_mask = input_mask[:,:,None]
|
585 |
+
|
586 |
+
moved_mask = move_mask_func(original_mask, 'left', int(moving_pixels)).squeeze()
|
587 |
+
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
|
588 |
+
|
589 |
+
masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
|
590 |
+
masked_image = masked_image.astype(original_image.dtype)
|
591 |
+
masked_image = Image.fromarray(masked_image)
|
592 |
+
|
593 |
+
if moved_mask.max() <= 1:
|
594 |
+
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
|
595 |
+
original_mask = moved_mask
|
596 |
+
return [masked_image], [mask_image], original_mask.astype(np.uint8)
|
597 |
+
|
598 |
+
|
599 |
+
def move_mask_right(input_image, original_image, original_mask, moving_pixels, resize_and_crop):
|
600 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
601 |
+
input_mask = np.asarray(alpha_mask)
|
602 |
+
if resize_and_crop:
|
603 |
+
original_image = crop_and_resize(Image.fromarray(original_image), target_width=640, target_height=640)
|
604 |
+
input_mask = crop_and_resize(Image.fromarray(input_mask), target_width=640, target_height=640)
|
605 |
+
original_image = np.array(original_image)
|
606 |
+
input_mask = np.array(input_mask)
|
607 |
+
|
608 |
+
if input_mask.max() == 0:
|
609 |
+
if original_mask is None:
|
610 |
+
raise gr.Error('Please generate mask first')
|
611 |
+
original_mask = original_mask
|
612 |
+
else:
|
613 |
+
original_mask = input_mask[:,:,None]
|
614 |
+
|
615 |
+
moved_mask = move_mask_func(original_mask, 'right', int(moving_pixels)).squeeze()
|
616 |
+
|
617 |
+
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
|
618 |
+
|
619 |
+
masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
|
620 |
+
masked_image = masked_image.astype(original_image.dtype)
|
621 |
+
masked_image = Image.fromarray(masked_image)
|
622 |
+
|
623 |
+
|
624 |
+
if moved_mask.max() <= 1:
|
625 |
+
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
|
626 |
+
original_mask = moved_mask
|
627 |
+
|
628 |
+
return [masked_image], [mask_image], original_mask.astype(np.uint8)
|
629 |
+
|
630 |
+
|
631 |
+
def move_mask_up(input_image, original_image, original_mask, moving_pixels, resize_and_crop):
|
632 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
633 |
+
input_mask = np.asarray(alpha_mask)
|
634 |
+
if resize_and_crop:
|
635 |
+
original_image = crop_and_resize(Image.fromarray(original_image), target_width=640, target_height=640)
|
636 |
+
input_mask = crop_and_resize(Image.fromarray(input_mask), target_width=640, target_height=640)
|
637 |
+
original_image = np.array(original_image)
|
638 |
+
input_mask = np.array(input_mask)
|
639 |
+
|
640 |
+
if input_mask.max() == 0:
|
641 |
+
if original_mask is None:
|
642 |
+
raise gr.Error('Please generate mask first')
|
643 |
+
original_mask = original_mask
|
644 |
+
else:
|
645 |
+
original_mask = input_mask[:,:,None]
|
646 |
+
|
647 |
+
moved_mask = move_mask_func(original_mask, 'up', int(moving_pixels)).squeeze()
|
648 |
+
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
|
649 |
+
|
650 |
+
masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
|
651 |
+
masked_image = masked_image.astype(original_image.dtype)
|
652 |
+
masked_image = Image.fromarray(masked_image)
|
653 |
+
|
654 |
+
if moved_mask.max() <= 1:
|
655 |
+
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
|
656 |
+
original_mask = moved_mask
|
657 |
+
|
658 |
+
return [masked_image], [mask_image], original_mask.astype(np.uint8)
|
659 |
+
|
660 |
+
|
661 |
+
def move_mask_down(input_image, original_image, original_mask, moving_pixels, resize_and_crop):
|
662 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
663 |
+
input_mask = np.asarray(alpha_mask)
|
664 |
+
if resize_and_crop:
|
665 |
+
original_image = crop_and_resize(Image.fromarray(original_image), target_width=640, target_height=640)
|
666 |
+
input_mask = crop_and_resize(Image.fromarray(input_mask), target_width=640, target_height=640)
|
667 |
+
original_image = np.array(original_image)
|
668 |
+
input_mask = np.array(input_mask)
|
669 |
+
|
670 |
+
if input_mask.max() == 0:
|
671 |
+
if original_mask is None:
|
672 |
+
raise gr.Error('Please generate mask first')
|
673 |
+
original_mask = original_mask
|
674 |
+
else:
|
675 |
+
original_mask = input_mask[:,:,None]
|
676 |
+
|
677 |
+
moved_mask = move_mask_func(original_mask, 'down', int(moving_pixels)).squeeze()
|
678 |
+
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
|
679 |
+
|
680 |
+
masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
|
681 |
+
masked_image = masked_image.astype(original_image.dtype)
|
682 |
+
masked_image = Image.fromarray(masked_image)
|
683 |
+
|
684 |
+
if moved_mask.max() <= 1:
|
685 |
+
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
|
686 |
+
original_mask = moved_mask
|
687 |
+
|
688 |
+
return [masked_image], [mask_image], original_mask.astype(np.uint8)
|
689 |
+
|
690 |
+
|
691 |
+
def store_img(base):
|
692 |
+
import ipdb; ipdb.set_trace()
|
693 |
+
image_pil = base["background"].convert("RGB")
|
694 |
+
original_image = np.array(image_pil)
|
695 |
+
# import ipdb; ipdb.set_trace()
|
696 |
+
if max(original_image.shape[0], original_image.shape[1]) * 1.0 / min(original_image.shape[0], original_image.shape[1])>2.0:
|
697 |
+
raise gr.Error('image aspect ratio cannot be larger than 2.0')
|
698 |
+
return base, original_image, None, "", None, None, None, None, None
|
699 |
+
|
700 |
+
|
701 |
+
def reset_func(input_image, original_image, original_mask, prompt, target_prompt):
|
702 |
+
input_image = None
|
703 |
+
original_image = None
|
704 |
+
original_mask = None
|
705 |
+
prompt = ''
|
706 |
+
mask_gallery = []
|
707 |
+
masked_gallery = []
|
708 |
+
result_gallery = []
|
709 |
+
target_prompt = ''
|
710 |
+
return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt
|
711 |
+
|
712 |
+
|
713 |
+
block = gr.Blocks(
|
714 |
+
theme=gr.themes.Soft(
|
715 |
+
radius_size=gr.themes.sizes.radius_none,
|
716 |
+
text_size=gr.themes.sizes.text_md
|
717 |
+
)
|
718 |
+
).queue()
|
719 |
+
with block as demo:
|
720 |
+
with gr.Row():
|
721 |
+
with gr.Column():
|
722 |
+
gr.HTML(head)
|
723 |
+
|
724 |
+
gr.Markdown(descriptions)
|
725 |
+
|
726 |
+
with gr.Accordion(label="🧭 Instructions:", open=True, elem_id="accordion"):
|
727 |
+
with gr.Row(equal_height=True):
|
728 |
+
gr.Markdown(instructions)
|
729 |
+
|
730 |
+
original_image = gr.State(value=None)
|
731 |
+
original_mask = gr.State(value=None)
|
732 |
+
category = gr.State(value=None)
|
733 |
+
|
734 |
+
with gr.Row():
|
735 |
+
with gr.Column():
|
736 |
+
with gr.Row():
|
737 |
+
input_image = gr.ImageEditor(
|
738 |
+
label="Input Image",
|
739 |
+
type="pil",
|
740 |
+
brush=gr.Brush(colors=["#000000"], default_size = 30, color_mode="fixed"),
|
741 |
+
layers = False,
|
742 |
+
interactive=True,
|
743 |
+
height=800,
|
744 |
+
# transforms=("crop"),
|
745 |
+
# crop_size=(640, 640),
|
746 |
+
)
|
747 |
+
|
748 |
+
prompt = gr.Textbox(label="Prompt", placeholder="Please input your instruction.",value='',lines=1)
|
749 |
+
|
750 |
+
with gr.Row():
|
751 |
+
mask_button = gr.Button("Generate Mask")
|
752 |
+
random_mask_button = gr.Button("Random Generated Mask")
|
753 |
+
with gr.Row():
|
754 |
+
dilation_mask_button = gr.Button("Dilation Generated Mask")
|
755 |
+
erosion_mask_button = gr.Button("Erosion Generated Mask")
|
756 |
+
|
757 |
+
with gr.Row():
|
758 |
+
generate_target_prompt_button = gr.Button("Generate Target Prompt")
|
759 |
+
run_button = gr.Button("Run")
|
760 |
+
|
761 |
+
|
762 |
+
target_prompt = gr.Text(
|
763 |
+
label="Target prompt",
|
764 |
+
max_lines=5,
|
765 |
+
placeholder="VLM-generated target prompt, you can first generate if and then modify it (optional)",
|
766 |
+
value='',
|
767 |
+
lines=2
|
768 |
+
)
|
769 |
+
|
770 |
+
resize_and_crop = gr.Checkbox(label="Resize and Crop (640 x 640)", value=False)
|
771 |
+
|
772 |
+
with gr.Accordion("More input params (highly-recommended)", open=False, elem_id="accordion1"):
|
773 |
+
negative_prompt = gr.Text(
|
774 |
+
label="Negative Prompt",
|
775 |
+
max_lines=5,
|
776 |
+
placeholder="Please input your negative prompt",
|
777 |
+
value='ugly, low quality',lines=1
|
778 |
+
)
|
779 |
+
|
780 |
+
control_strength = gr.Slider(
|
781 |
+
label="Control Strength: ", show_label=True, minimum=0, maximum=1.1, value=1, step=0.01
|
782 |
+
)
|
783 |
+
with gr.Group():
|
784 |
+
seed = gr.Slider(
|
785 |
+
label="Seed: ", minimum=0, maximum=2147483647, step=1, value=648464818
|
786 |
+
)
|
787 |
+
randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
|
788 |
+
|
789 |
+
blending = gr.Checkbox(label="Blending mode", value=True)
|
790 |
+
|
791 |
+
|
792 |
+
num_samples = gr.Slider(
|
793 |
+
label="Num samples", minimum=0, maximum=4, step=1, value=4
|
794 |
+
)
|
795 |
+
|
796 |
+
with gr.Group():
|
797 |
+
with gr.Row():
|
798 |
+
guidance_scale = gr.Slider(
|
799 |
+
label="Guidance scale",
|
800 |
+
minimum=1,
|
801 |
+
maximum=12,
|
802 |
+
step=0.1,
|
803 |
+
value=7.5,
|
804 |
+
)
|
805 |
+
num_inference_steps = gr.Slider(
|
806 |
+
label="Number of inference steps",
|
807 |
+
minimum=1,
|
808 |
+
maximum=50,
|
809 |
+
step=1,
|
810 |
+
value=50,
|
811 |
+
)
|
812 |
+
|
813 |
+
|
814 |
+
with gr.Column():
|
815 |
+
with gr.Row():
|
816 |
+
with gr.Tabs(elem_classes=["feedback"]):
|
817 |
+
with gr.TabItem("Mask"):
|
818 |
+
mask_gallery = gr.Gallery(label='Mask', show_label=False, elem_id="gallery", preview=True, height=360)
|
819 |
+
with gr.Tabs(elem_classes=["feedback"]):
|
820 |
+
with gr.TabItem("Masked Image"):
|
821 |
+
masked_gallery = gr.Gallery(label='Masked Image', show_label=False, elem_id="gallery", preview=True, height=360)
|
822 |
+
|
823 |
+
moving_pixels = gr.Slider(
|
824 |
+
label="Moving pixels:", show_label=True, minimum=0, maximum=50, value=4, step=1
|
825 |
+
)
|
826 |
+
with gr.Row():
|
                    move_left_button = gr.Button("Move Left")
                    move_right_button = gr.Button("Move Right")
                with gr.Row():
                    move_up_button = gr.Button("Move Up")
                    move_down_button = gr.Button("Move Down")

            with gr.Tabs(elem_classes=["feedback"]):
                with gr.TabItem("Outputs"):
                    result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", preview=True, height=360)

            reset_button = gr.Button("Reset")


    with gr.Row():
        # # example = gr.Examples(
        # #     label="Quick Example",
        # #     examples=EXAMPLES,
        # #     inputs=[prompt, seed, result_gallery, mask_gallery, masked_gallery],
        # #     examples_per_page=10,
        # #     cache_examples=False,
        # # )
        example = gr.Examples(
            label="Quick Example",
            examples=EXAMPLES,
            inputs=[input_image, prompt],
            examples_per_page=10,
            cache_examples=False,
        )
        # def process_example(prompt, seed, eg_output):
        #     import ipdb; ipdb.set_trace()
        #     eg_output_path = os.path.join("assets/", eg_output)
        #     return prompt, seed, [Image.open(eg_output_path)]
        # example = gr.Examples(
        #     label="Quick Example",
        #     examples=EXAMPLES,
        #     inputs=[prompt, seed, eg_output],
        #     outputs=[prompt, seed, result_gallery],
        #     fn=process_example,
        #     examples_per_page=10,
        #     run_on_click=True,
        #     cache_examples=False,
        # )

    input_image.upload(
        store_img,
        [input_image],
        [input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt]
    )

    ips = [input_image,
           original_image,
           original_mask,
           prompt,
           negative_prompt,
           control_strength,
           seed,
           randomize_seed,
           guidance_scale,
           num_inference_steps,
           num_samples,
           blending,
           category,
           target_prompt,
           resize_and_crop]

    ## run brushedit
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, target_prompt])

    ## mask func
    mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_and_crop], outputs=[masked_gallery, mask_gallery, original_mask, category])
    random_mask_button.click(fn=process_random_mask, inputs=[input_image, original_image, original_mask, resize_and_crop], outputs=[masked_gallery, mask_gallery, original_mask])
    dilation_mask_button.click(fn=process_dilation_mask, inputs=[input_image, original_image, original_mask, resize_and_crop], outputs=[masked_gallery, mask_gallery, original_mask])
    erosion_mask_button.click(fn=process_erosion_mask, inputs=[input_image, original_image, original_mask, resize_and_crop], outputs=[masked_gallery, mask_gallery, original_mask])

    ## move mask func
    move_left_button.click(fn=move_mask_left, inputs=[input_image, original_image, original_mask, moving_pixels, resize_and_crop], outputs=[masked_gallery, mask_gallery, original_mask])
    move_right_button.click(fn=move_mask_right, inputs=[input_image, original_image, original_mask, moving_pixels, resize_and_crop], outputs=[masked_gallery, mask_gallery, original_mask])
    move_up_button.click(fn=move_mask_up, inputs=[input_image, original_image, original_mask, moving_pixels, resize_and_crop], outputs=[masked_gallery, mask_gallery, original_mask])
    move_down_button.click(fn=move_mask_down, inputs=[input_image, original_image, original_mask, moving_pixels, resize_and_crop], outputs=[masked_gallery, mask_gallery, original_mask])

    ## prompt func
    generate_target_prompt_button.click(fn=generate_target_prompt, inputs=[input_image, original_image, prompt], outputs=[target_prompt])

    ## reset func
    reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, target_prompt], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt])

demo.launch(server_name="0.0.0.0")
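The wiring above follows Gradio's standard Blocks pattern: each Button.click call binds a handler to an ordered list of input components and output components, and the handler must return one value per output. A minimal, self-contained sketch of that pattern (the component and handler names here are illustrative, not the ones defined in this app):

# Minimal sketch of the Blocks + click-wiring pattern; names are illustrative only.
import gradio as gr

def echo_edit(image, instruction):
    # A real handler would run the editing pipeline; here we just echo the inputs back.
    return [image], f"instruction received: {instruction}"

with gr.Blocks() as demo_sketch:
    input_image = gr.Image(label="Input")
    prompt = gr.Textbox(label="Instruction")
    run_button = gr.Button("Run")
    gallery = gr.Gallery(label="Output")
    status = gr.Textbox(label="Status")

    # The handler receives the listed inputs in order and returns one value per output.
    run_button.click(fn=echo_edit, inputs=[input_image, prompt], outputs=[gallery, status])

if __name__ == "__main__":
    demo_sketch.launch()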
app/gpt4_o/instructions.py
ADDED
@@ -0,0 +1,106 @@
def create_editing_category_messages(editing_prompt):
    messages = [{
        "role": "system",
        "content": [
            {
                "type": "text",
                "text": "I will give you an image and an editing instruction of the image. Please output which type of editing category it is in. You can choose from the following categories: \n\
                    1. Addition: Adding new objects within the images, e.g., add a bird to the image \n\
                    2. Remove: Removing objects, e.g., remove the mask \n\
                    3. Local: Replace local parts of an object and alter the object's attributes (e.g., make it smile) or alter an object's visual appearance without affecting its structure (e.g., change the cat to a dog) \n\
                    4. Global: Edit the entire image, e.g., let's see it in winter \n\
                    5. Background: Change the scene's background, e.g., have her walk on water, change the background to a beach, make the hedgehog in France, etc.",
            },]
    },
    {
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": editing_prompt
            },
        ]
    }]
    return messages


def create_ori_object_messages(editing_prompt):

    messages = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "I will give you an editing instruction of the image. Please output the object needed to be edited. You only need to output the basic description of the object in no more than 5 words. The output should only contain one noun. \n \
                        For example, the editing instruction is 'Change the white cat to a black dog'. Then you need to output: 'white cat'. Only output the new content. Do not output anything else."
                },]
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": editing_prompt
                }
            ]
        }
    ]
    return messages


def create_add_object_messages(editing_prompt, base64_image, height=640, width=640):

    size_str = f"The image size is height {height}px and width {width}px. The top-left corner is coordinate [0, 0]. The bottom-right corner is coordinate [{height}, {width}]. "

    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "I need to add an object to the image following the instruction: " + editing_prompt + ". " + size_str + " \n \
                        Can you give me a possible bounding box of the location for the added object? Please output with the format of [top-left x coordinate, top-left y coordinate, box width, box height]. You should only output the bounding box position and nothing else. Please refer to the example below for the desired format.\n\
                        [Examples]\n \
                        [19, 101, 32, 153]\n \
                        [54, 12, 242, 96]"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{base64_image}"
                    },
                }
            ]
        }
    ]
    return messages


def create_apply_editing_messages(editing_prompt, base64_image):
    messages = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "I will provide an image along with an editing instruction. Please describe the new content that should be present in the image after applying the instruction. \n \
                        For example, if the original image content shows a grandmother wearing a mask and the instruction is 'remove the mask', your output should be: 'a grandmother'. The output should only include elements that remain in the image after the edit and should not mention elements that have been changed or removed, such as 'mask' in this example. Do not output 'sorry, xxx', even if it's a guess, directly output the answer you think is correct."
                },]
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": editing_prompt
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{base64_image}"
                    },
                },
            ]
        }
    ]
    return messages
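These four builders only assemble message lists in the OpenAI chat format; the actual model call happens elsewhere. A rough usage sketch with the openai Python client (the client construction and the example instruction are illustrative assumptions; the model name is the one used in vlm_pipeline.py below):

# Rough usage sketch, assuming OPENAI_API_KEY is set in the environment.
from openai import OpenAI

from app.gpt4_o.instructions import create_editing_category_messages

client = OpenAI()  # reads OPENAI_API_KEY from the environment

messages = create_editing_category_messages("remove the hedgehog")
response = client.chat.completions.create(
    model="gpt-4o-2024-08-06",  # same model name as in vlm_pipeline.py
    messages=messages,
)
print(response.choices[0].message.content)  # expected to name one category, e.g. "Remove"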
app/gpt4_o/requirements.txt
ADDED
@@ -0,0 +1,18 @@
torchvision
transformers>=4.25.1
ftfy
tensorboard
datasets
Pillow==9.5.0
opencv-python
imgaug
accelerate==0.20.3
image-reward
hpsv2
torchmetrics
open-clip-torch
clip
# gradio==4.44.1
gradio==4.38.1
segment_anything
openai
app/gpt4_o/run_app.sh
ADDED
@@ -0,0 +1,5 @@
export PYTHONPATH=.:$PYTHONPATH

export CUDA_VISIBLE_DEVICES=0

python app/gpt4_o/brushedit_app.py
app/gpt4_o/vlm_pipeline.py
ADDED
@@ -0,0 +1,138 @@
import base64
import re
import torch

from PIL import Image
from io import BytesIO
import numpy as np
import gradio as gr


from app.gpt4_o.instructions import (
    create_editing_category_messages,
    create_ori_object_messages,
    create_add_object_messages,
    create_apply_editing_messages)

from app.utils.utils import run_grounded_sam


def encode_image(img):
    img = Image.fromarray(img.astype('uint8'))
    buffered = BytesIO()
    img.save(buffered, format="PNG")
    img_bytes = buffered.getvalue()
    return base64.b64encode(img_bytes).decode('utf-8')


def run_gpt4o_vl_inference(vlm, messages):
    response = vlm.chat.completions.create(
        model="gpt-4o-2024-08-06",
        messages=messages
    )
    response_str = response.choices[0].message.content
    return response_str


def vlm_response_editing_type(vlm, image, editing_prompt):

    base64_image = encode_image(image)

    messages = create_editing_category_messages(editing_prompt)

    response_str = run_gpt4o_vl_inference(vlm, messages)

    for category_name in ["Addition", "Remove", "Local", "Global", "Background"]:
        if category_name.lower() in response_str.lower():
            return category_name
    raise ValueError("Please input correct commands, including add, delete, and modify commands.")


def vlm_response_object_wait_for_edit(vlm, category, editing_prompt):
    if category in ["Background", "Global", "Addition"]:
        edit_object = "nan"
        return edit_object

    messages = create_ori_object_messages(editing_prompt)

    response_str = run_gpt4o_vl_inference(vlm, messages)
    return response_str


def vlm_response_mask(vlm,
                      category,
                      image,
                      editing_prompt,
                      object_wait_for_edit,
                      sam=None,
                      sam_predictor=None,
                      sam_automask_generator=None,
                      groundingdino_model=None,
                      ):
    mask = None
    if editing_prompt is None or len(editing_prompt) == 0:
        raise gr.Error("Please input the editing instruction!")
    height, width = image.shape[:2]
    if category == "Addition":
        # ask the MLLM for a bounding box and turn it into a rectangular mask
        base64_image = encode_image(image)
        messages = create_add_object_messages(editing_prompt, base64_image, height=height, width=width)
        try:
            response_str = run_gpt4o_vl_inference(vlm, messages)
            pattern = r'\[\d{1,3}(?:,\s*\d{1,3}){3}\]'
            box = re.findall(pattern, response_str)
            box = box[0][1:-1].split(",")
            for i in range(len(box)):
                box[i] = int(box[i])
            cus_mask = np.zeros((height, width))
            cus_mask[box[1]: box[1] + box[3], box[0]: box[0] + box[2]] = 255
            mask = cus_mask
        except:
            raise gr.Error("Please set the mask manually, MLLM cannot output the mask!")

    elif category == "Background":
        labels = "background"
    elif category == "Global":
        mask = 255 * np.zeros((height, width))
    else:
        labels = object_wait_for_edit

    if mask is None:
        # progressively lower the box threshold until Grounded-SAM returns a detection
        for thresh in [0.3, 0.25, 0.2, 0.15, 0.1, 0.05, 0]:
            try:
                device = "cuda" if torch.cuda.is_available() else "cpu"
                detections = run_grounded_sam(
                    input_image={"image": Image.fromarray(image.astype('uint8')),
                                 "mask": None},
                    text_prompt=labels,
                    task_type="seg",
                    box_threshold=thresh,
                    text_threshold=0.25,
                    iou_threshold=0.5,
                    scribble_mode="split",
                    sam=sam,
                    sam_predictor=sam_predictor,
                    sam_automask_generator=sam_automask_generator,
                    groundingdino_model=groundingdino_model,
                    device=device,
                )
                mask = np.array(detections[0, 0, ...].cpu()) * 255
                break
            except:
                print(f"wrong in threshold: {thresh}, continue")
                continue
    return mask


def vlm_response_prompt_after_apply_instruction(vlm, image, editing_prompt):
    base64_image = encode_image(image)
    messages = create_apply_editing_messages(editing_prompt, base64_image)

    response_str = run_gpt4o_vl_inference(vlm, messages)
    return response_str
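Together these helpers form the VLM stage of the app: classify the instruction, name the object to edit, derive a mask, and rewrite the instruction into a target caption. A condensed sketch of how they might be chained (the vlm client, the image array, and the pre-loaded SAM/GroundingDINO models are assumed to be created elsewhere, e.g. in brushedit_app.py; the wrapper function below is illustrative, not part of the repo):

# Condensed sketch of the VLM stage; `vlm`, `image` (H x W x 3 uint8 numpy array),
# `sam`, `sam_predictor`, `sam_automask_generator` and `groundingdino_model`
# are assumed to be created by the app setup, which is not shown here.
from app.gpt4_o.vlm_pipeline import (
    vlm_response_editing_type,
    vlm_response_object_wait_for_edit,
    vlm_response_mask,
    vlm_response_prompt_after_apply_instruction,
)

def run_vlm_stage(vlm, image, editing_prompt, sam, sam_predictor,
                  sam_automask_generator, groundingdino_model):
    # 1) which kind of edit is requested
    category = vlm_response_editing_type(vlm, image, editing_prompt)
    # 2) which object the edit targets ("nan" for global/background/addition edits)
    edit_object = vlm_response_object_wait_for_edit(vlm, category, editing_prompt)
    # 3) a binary mask for the edited region
    mask = vlm_response_mask(vlm, category, image, editing_prompt, edit_object,
                             sam=sam, sam_predictor=sam_predictor,
                             sam_automask_generator=sam_automask_generator,
                             groundingdino_model=groundingdino_model)
    # 4) the caption describing the image after the edit
    target_prompt = vlm_response_prompt_after_apply_instruction(vlm, image, editing_prompt)
    return category, edit_object, mask, target_prompt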
app/utils/utils.py
ADDED
@@ -0,0 +1,197 @@
import numpy as np
import torch
import torchvision

from scipy import ndimage

# BLIP
from transformers import BlipProcessor, BlipForConditionalGeneration

# SAM
from segment_anything import build_sam, SamPredictor, SamAutomaticMaskGenerator

# GroundingDINO
from groundingdino.datasets import transforms as T
from groundingdino.models import build_model
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap


def load_grounding_dino_model(model_config_path, model_checkpoint_path, device):
    args = SLConfig.fromfile(model_config_path)
    args.device = device
    model = build_model(args)
    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
    load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
    print(load_res)
    _ = model.eval()
    return model


def generate_caption(processor, blip_model, raw_image, device):
    # unconditional image captioning
    inputs = processor(raw_image, return_tensors="pt").to(device, torch.float16)
    out = blip_model.generate(**inputs)
    caption = processor.decode(out[0], skip_special_tokens=True)
    return caption


def transform_image(image_pil):

    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image, _ = transform(image_pil, None)  # 3, h, w
    return image


def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True):
    caption = caption.lower()
    caption = caption.strip()
    if not caption.endswith("."):
        caption = caption + "."

    with torch.no_grad():
        outputs = model(image[None], captions=[caption])
    logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
    boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)

    # filter output
    logits_filt = logits.clone()
    boxes_filt = boxes.clone()
    filt_mask = logits_filt.max(dim=1)[0] > box_threshold
    logits_filt = logits_filt[filt_mask]  # num_filt, 256
    boxes_filt = boxes_filt[filt_mask]  # num_filt, 4

    # get phrase
    tokenlizer = model.tokenizer
    tokenized = tokenlizer(caption)
    # build pred
    pred_phrases = []
    scores = []
    for logit, box in zip(logits_filt, boxes_filt):
        pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
        if with_logits:
            pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
        else:
            pred_phrases.append(pred_phrase)
        scores.append(logit.max().item())

    return boxes_filt, torch.Tensor(scores), pred_phrases


def run_grounded_sam(input_image,
                     text_prompt,
                     task_type,
                     box_threshold,
                     text_threshold,
                     iou_threshold,
                     scribble_mode,
                     sam,
                     groundingdino_model,
                     sam_predictor=None,
                     sam_automask_generator=None,
                     device="cuda"):

    global blip_processor, blip_model, inpaint_pipeline

    # load image
    image = input_image["image"]
    scribble = input_image["mask"]
    size = image.size  # w, h

    if sam_predictor is None:
        sam_predictor = SamPredictor(sam)
        sam_automask_generator = SamAutomaticMaskGenerator(sam)

    image_pil = image.convert("RGB")
    image = np.array(image_pil)

    if task_type == 'scribble':
        sam_predictor.set_image(image)
        scribble = scribble.convert("RGB")
        scribble = np.array(scribble)
        scribble = scribble.transpose(2, 1, 0)[0]

        # label the connected components of the scribble
        labeled_array, num_features = ndimage.label(scribble >= 255)

        # compute the centroid of each connected component
        centers = ndimage.center_of_mass(scribble, labeled_array, range(1, num_features + 1))
        centers = np.array(centers)

        point_coords = torch.from_numpy(centers)
        point_coords = sam_predictor.transform.apply_coords_torch(point_coords, image.shape[:2])
        point_coords = point_coords.unsqueeze(0).to(device)
        point_labels = torch.from_numpy(np.array([1] * len(centers))).unsqueeze(0).to(device)
        if scribble_mode == 'split':
            point_coords = point_coords.permute(1, 0, 2)
            point_labels = point_labels.permute(1, 0)
        masks, _, _ = sam_predictor.predict_torch(
            point_coords=point_coords if len(point_coords) > 0 else None,
            point_labels=point_labels if len(point_coords) > 0 else None,
            mask_input=None,
            boxes=None,
            multimask_output=False,
        )
    elif task_type == 'automask':
        masks = sam_automask_generator.generate(image)
    else:
        transformed_image = transform_image(image_pil)

        if task_type == 'automatic':
            # generate caption and tags
            # using Tag2Text can generate better captions
            # https://huggingface.co/spaces/xinyu1205/Tag2Text
            # but there are some bugs...
            blip_processor = blip_processor or BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
            blip_model = blip_model or BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to(device)
            text_prompt = generate_caption(blip_processor, blip_model, image_pil, device)
            print(f"Caption: {text_prompt}")

        # run grounding dino model
        boxes_filt, scores, pred_phrases = get_grounding_output(
            groundingdino_model, transformed_image, text_prompt, box_threshold, text_threshold
        )

        # process boxes: cxcywh in [0, 1] -> xyxy in pixels
        H, W = size[1], size[0]
        for i in range(boxes_filt.size(0)):
            boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
            boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
            boxes_filt[i][2:] += boxes_filt[i][:2]

        boxes_filt = boxes_filt.cpu()


    if task_type == 'seg' or task_type == 'inpainting' or task_type == 'automatic':
        sam_predictor.set_image(image)

        if task_type == 'automatic':
            # use NMS to handle overlapped boxes
            print(f"Before NMS: {boxes_filt.shape[0]} boxes")
            nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
            boxes_filt = boxes_filt[nms_idx]
            pred_phrases = [pred_phrases[idx] for idx in nms_idx]
            print(f"After NMS: {boxes_filt.shape[0]} boxes")
            print(f"Revise caption with number: {text_prompt}")

        transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)

        masks, _, _ = sam_predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )
        return masks
    else:
        print("task_type:{} error!".format(task_type))
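run_grounded_sam is the entry point the VLM pipeline uses in 'seg' mode: GroundingDINO grounds the text label into boxes, then SAM predicts masks inside those boxes. A minimal calling sketch, assuming the GroundingDINO config/checkpoint and the SAM checkpoint have already been downloaded (all paths below are placeholders, not files shipped in this commit):

# Minimal calling sketch; checkpoint and config paths are placeholders.
import torch
from PIL import Image
from segment_anything import build_sam

from app.utils.utils import load_grounding_dino_model, run_grounded_sam

device = "cuda" if torch.cuda.is_available() else "cpu"

groundingdino_model = load_grounding_dino_model(
    "path/to/GroundingDINO_SwinT_OGC.py",   # placeholder config path
    "path/to/groundingdino_swint_ogc.pth",  # placeholder checkpoint path
    device,
)
sam = build_sam(checkpoint="path/to/sam_vit_h_4b8939.pth").to(device)

masks = run_grounded_sam(
    input_image={"image": Image.open("assets/hedgehog_rm_fg/hedgehog.png"), "mask": None},
    text_prompt="hedgehog",
    task_type="seg",
    box_threshold=0.3,
    text_threshold=0.25,
    iou_threshold=0.5,
    scribble_mode="split",
    sam=sam,
    groundingdino_model=groundingdino_model,
    device=device,
)
print(masks.shape)  # one (1, H, W) mask per detected box, as returned by SAM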
assets/hedgehog_rm_fg/hedgehog.png
ADDED
assets/hedgehog_rm_fg/image_edit_82314e18-c64c-4003-9ef9-52cebf254b2f_2.png
ADDED
assets/hedgehog_rm_fg/mask_82314e18-c64c-4003-9ef9-52cebf254b2f.png
ADDED
assets/hedgehog_rm_fg/masked_image_82314e18-c64c-4003-9ef9-52cebf254b2f.png
ADDED
assets/hedgehog_rm_fg/prompt.txt
ADDED
@@ -0,0 +1 @@
648464818: remove the hedgehog.
assets/hedgehog_rp_bg/hedgehog.png
ADDED
assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png
ADDED
assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png
ADDED
assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png
ADDED
assets/hedgehog_rp_bg/prompt.txt
ADDED
@@ -0,0 +1 @@
648464818: make the hedgehog in Italy.
assets/hedgehog_rp_fg/hedgehog.png
ADDED
assets/hedgehog_rp_fg/image_edit_5cab3448-5a3a-459c-9144-35cca3d34273_0.png
ADDED
assets/hedgehog_rp_fg/mask_5cab3448-5a3a-459c-9144-35cca3d34273.png
ADDED
assets/hedgehog_rp_fg/masked_image_5cab3448-5a3a-459c-9144-35cca3d34273.png
ADDED
assets/hedgehog_rp_fg/prompt.txt
ADDED
@@ -0,0 +1 @@
648464818: replace the hedgehog to flamingo.
assets/mona_lisa/image_edit_aae09152-4495-4332-b691-a0c7bff524be_2.png
ADDED
assets/mona_lisa/mask_aae09152-4495-4332-b691-a0c7bff524be.png
ADDED
assets/mona_lisa/masked_image_aae09152-4495-4332-b691-a0c7bff524be.png
ADDED
assets/mona_lisa/mona_lisa.png
ADDED
assets/mona_lisa/prompt.txt
ADDED
@@ -0,0 +1 @@
648464818: add a shining necklace.
assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png
ADDED
assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png
ADDED
assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png
ADDED
assets/sunflower_girl/prompt.txt
ADDED
@@ -0,0 +1 @@
648464818: add a wreath on head..
assets/sunflower_girl/sunflower_girl.png
ADDED
requirements.txt
ADDED
@@ -0,0 +1,20 @@
torch
torchvision
torchaudio
transformers>=4.25.1
gradio==4.38.1
ftfy
tensorboard
datasets
Pillow==9.5.0
opencv-python
imgaug
accelerate==0.20.3
image-reward
hpsv2
torchmetrics
open-clip-torch
clip
segment_anything
git+https://github.com/liyaowei-stu/BrushEdit.git
git+https://github.com/IDEA-Research/Grounded-Segment-Anything.git#subdirectory=GroundingDINO