CuriousDolphin dhkim2810 commited on
Commit
188a6cf
·
0 Parent(s):

Duplicate from dhkim2810/MobileSAM

Browse files

Co-authored-by: Donghoon Kim <[email protected]>

.gitattributes ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/sa_1309.jpg filter=lfs diff=lfs merge=lfs -text
37
+ assets/sa_192.jpg filter=lfs diff=lfs merge=lfs -text
38
+ assets/sa_414.jpg filter=lfs diff=lfs merge=lfs -text
39
+ assets/sa_862.jpg filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: MobileSAM
3
+ emoji: 🐠
4
+ colorFrom: indigo
5
+ colorTo: yellow
6
+ sdk: gradio
7
+ python_version: 3.8.10
8
+ sdk_version: 3.35.2
9
+ app_file: app.py
10
+ pinned: false
11
+ license: apache-2.0
12
+ duplicated_from: dhkim2810/MobileSAM
13
+ ---
14
+
15
+ # Faster Segment Anything(MobileSAM)
16
+
17
+ Official PyTorch implementation of <a href="https://github.com/ChaoningZhang/MobileSAM">MobileSAM</a>.
18
+
19
+
20
+ **MobileSAM** performs on par with the original SAM (at least visually) and keeps exactly the same pipeline as the original SAM except for a change to the image encoder.
21
+ Specifically, we replace the original heavyweight ViT-H encoder (632M) with a much smaller Tiny-ViT (5M). On a single GPU, MobileSAM runs around 12ms per image: 8ms on the image encoder and 4ms on the mask decoder.
22
+
23
+
24
+ ## License
25
+
26
+ The model is licensed under the [Apache 2.0 license](LICENSE).
27
+
28
+
29
+ ## Acknowledgement
30
+
31
+ - [Segment Anything](https://segment-anything.com/) provides the SA-1B dataset and the base code.
32
+ - [TinyViT](https://github.com/microsoft/Cream/tree/main/TinyViT) provides code and pre-trained models.
33
+
34
+ ## Citing MobileSAM
35
+
36
+ If you find this project useful for your research, please consider citing the following BibTeX entry.
37
+
38
+ ```bibtex
39
+ @article{mobile_sam,
40
+ title={Faster Segment Anything: Towards Lightweight SAM for Mobile Applications},
41
+ author={Zhang, Chaoning and Han, Dongshen and Qiao, Yu and Kim, Jung Uk and Bae, Sung Ho and Lee, Seungkyu and Hong, Choong Seon},
42
+ journal={arXiv preprint arXiv:2306.14289},
43
+ year={2023}
44
+ }
45
+ ```
app.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+ import os
5
+ from mobile_sam import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
6
+ from PIL import ImageDraw
7
+ from utils.tools import box_prompt, format_results, point_prompt
8
+ from utils.tools_gradio import fast_process
9
+
10
+ # Most of our demo code is from [FastSAM Demo](https://huggingface.co/spaces/An-619/FastSAM). Huge thanks for AN-619.
11
+
12
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pre-trained model
sam_checkpoint = "./mobile_sam.pt"
model_type = "vit_t"

# Build the model selected by `model_type` from the registry, move it to
# `device`, and switch it to evaluation mode.
mobile_sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
mobile_sam = mobile_sam.to(device=device)
mobile_sam.eval()

# Module-level helpers that wrap the same model instance; the segmentation
# callbacks below use them directly.
mask_generator = SamAutomaticMaskGenerator(mobile_sam)
predictor = SamPredictor(mobile_sam)
24
+
25
+ # Description
26
# HTML heading rendered through gr.Markdown; the <font> element is now closed
# with </font> (it was previously "closed" by a second opening <font> tag).
title = "<center><strong><font size='8'>Faster Segment Anything(MobileSAM)</font></strong></center>"
27
+
28
+ description_e = """This is a demo of [Faster Segment Anything(MobileSAM) Model](https://github.com/ChaoningZhang/MobileSAM).
29
+
30
+ We will provide box mode soon.
31
+
32
+ Enjoy!
33
+
34
+ """
35
+
36
+ description_p = """ # Instructions for point mode
37
+
38
+ 0. Restart by click the Restart button
39
+ 1. Select a point with Add Mask for the foreground (Must)
40
+ 2. Select a point with Remove Area for the background (Optional)
41
+ 3. Click the Start Segmenting.
42
+
43
+ """
44
+
45
+ examples = [
46
+ ["assets/picture3.jpg"],
47
+ ["assets/picture4.jpg"],
48
+ ["assets/picture5.jpg"],
49
+ ["assets/picture6.jpg"],
50
+ ["assets/picture1.jpg"],
51
+ ["assets/picture2.jpg"],
52
+ ]
53
+
54
+ default_example = examples[0]
55
+
56
+ css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
57
+
58
+
59
@torch.no_grad()
def segment_everything(
    image,
    input_size=1024,
    better_quality=False,
    withContours=True,
    use_retina=True,
    mask_random_color=True,
):
    """Segment every object in ``image`` and return the rendered overlay.

    Args:
        image: Input image; it must expose ``.size`` and ``.resize`` (the UI
            components in this file are declared with ``type="pil"``).
        input_size: Target length of the longer image side, in pixels.
        better_quality: Forwarded to ``fast_process``.
        withContours: Forwarded to ``fast_process``.
        use_retina: Forwarded to ``fast_process``.
        mask_random_color: Forwarded to ``fast_process``.

    Returns:
        The image produced by ``fast_process``, or the resized input image
        unchanged when the mask generator returns no masks.
    """
    input_size = int(input_size)
    w, h = image.size
    scale = input_size / max(w, h)
    # Clamp to one pixel so an extreme aspect ratio cannot collapse a side to 0.
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))
    image = image.resize((new_w, new_h))

    nd_image = np.array(image)
    annotations = mask_generator.generate(nd_image)
    if len(annotations) == 0:
        # fast_process indexes annotations[0]; with nothing to draw, return
        # the plain resized image instead of failing.
        return image

    fig = fast_process(
        annotations=annotations,
        image=image,
        device=device,
        # fast_process floor-divides by this value, so it must stay >= 1
        # (1024 // input_size is 0 for any input_size above 1024).
        scale=max(1, 1024 // input_size),
        better_quality=better_quality,
        mask_random_color=mask_random_color,
        bbox=None,
        use_retina=use_retina,
        withContours=withContours,
    )
    return fig
92
+
93
+
94
def segment_with_points(
    image,
    input_size=1024,
    better_quality=False,
    withContours=True,
    use_retina=True,
    mask_random_color=True,
):
    """Segment ``image`` using the points collected by ``get_points_with_draw``.

    The clicked points are read from the module-level ``global_points`` /
    ``global_point_label`` lists and are reset once a segmentation completes.
    NOTE(review): these lists are module globals, so they appear to be shared
    by every session of the app -- confirm whether per-session state is needed.

    Args:
        image: Input image exposing ``.size`` and ``.resize``; may be ``None``.
        input_size: Target length of the longer image side, in pixels.
        better_quality: Forwarded to ``fast_process``.
        withContours: Forwarded to ``fast_process``.
        use_retina: Forwarded to ``fast_process``.
        mask_random_color: Forwarded to ``fast_process``.

    Returns:
        A ``(segmented_image, input_image)`` tuple. When there is no image or
        no point has been selected, ``(None, image)`` is returned with
        ``image`` unchanged.
    """
    global global_points
    global global_point_label

    # Without an image or at least one clicked point there is nothing to
    # prompt the predictor with; bail out instead of crashing further down.
    if image is None or not global_points:
        global_points = []
        global_point_label = []
        return None, image

    input_size = int(input_size)
    w, h = image.size
    scale = input_size / max(w, h)
    # Clamp to one pixel so an extreme aspect ratio cannot collapse a side to 0.
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))
    image = image.resize((new_w, new_h))

    scaled_points = np.array(
        [[int(x * scale) for x in point] for point in global_points]
    )
    # Both the point and the image size are truncated independently, so a
    # click on the last row/column can land one pixel outside the resized
    # image; keep every point inside it.
    scaled_points = np.clip(scaled_points, 0, [new_w - 1, new_h - 1])
    scaled_point_label = np.array(global_point_label)

    nd_image = np.array(image)
    predictor.set_image(nd_image)
    masks, scores, logits = predictor.predict(
        point_coords=scaled_points,
        point_labels=scaled_point_label,
        multimask_output=True,
    )

    results = format_results(masks, scores, logits, 0)

    annotations, _ = point_prompt(
        results, scaled_points, scaled_point_label, new_h, new_w
    )
    annotations = np.array([annotations])

    fig = fast_process(
        annotations=annotations,
        image=image,
        device=device,
        # fast_process floor-divides by this value, so it must stay >= 1.
        scale=max(1, 1024 // input_size),
        better_quality=better_quality,
        mask_random_color=mask_random_color,
        bbox=None,
        use_retina=use_retina,
        withContours=withContours,
    )

    # The prompt has been consumed; start the next one from scratch.
    global_points = []
    global_point_label = []
    return fig, image
146
+
147
+
148
def get_points_with_draw(image, label, evt: gr.SelectData):
    """Record a clicked point and draw it on ``image``.

    The click position is appended to ``global_points`` and its label
    (1 for "Add Mask", 0 otherwise) to ``global_point_label``. A filled
    circle is drawn on ``image`` in place and the same image is returned.
    """
    global global_points
    global global_point_label

    is_foreground = label == "Add Mask"
    x, y = evt.index[0], evt.index[1]
    point_radius = 15
    point_color = (255, 255, 0) if is_foreground else (255, 0, 255)

    global_points.append([x, y])
    global_point_label.append(1 if is_foreground else 0)

    print(x, y, is_foreground)

    # Create an object that can draw on the image.
    top_left = (x - point_radius, y - point_radius)
    bottom_right = (x + point_radius, y + point_radius)
    ImageDraw.Draw(image).ellipse([top_left, bottom_right], fill=point_color)
    return image
170
+
171
+
172
# Input images; both start out showing the first example picture.
cond_img_e = gr.Image(label="Input", value=default_example[0], type="pil")
cond_img_p = gr.Image(label="Input with points", value=default_example[0], type="pil")

# Read-only output images for the segmentation results.
segm_img_e = gr.Image(label="Segmented Image", interactive=False, type="pil")
segm_img_p = gr.Image(
    label="Segmented Image with points", interactive=False, type="pil"
)

# Clicked points and their labels (1 = "Add Mask", 0 = otherwise), appended by
# get_points_with_draw and consumed by segment_with_points.
# NOTE(review): module-level state, so presumably shared between sessions.
global_points = []
global_point_label = []

# Slider for the longer image side used when resizing the input.
input_size_slider = gr.components.Slider(
    minimum=512,
    maximum=1024,
    value=1024,
    step=64,
    label="Input_size",
    info="Our model was trained on a size of 1024",
)
191
+
192
+ with gr.Blocks(css=css, title="Faster Segment Anything(MobileSAM)") as demo:
193
+ with gr.Row():
194
+ with gr.Column(scale=1):
195
+ # Title
196
+ gr.Markdown(title)
197
+
198
+ # with gr.Tab("Everything mode"):
199
+ # # Images
200
+ # with gr.Row(variant="panel"):
201
+ # with gr.Column(scale=1):
202
+ # cond_img_e.render()
203
+ #
204
+ # with gr.Column(scale=1):
205
+ # segm_img_e.render()
206
+ #
207
+ # # Submit & Clear
208
+ # with gr.Row():
209
+ # with gr.Column():
210
+ # input_size_slider.render()
211
+ #
212
+ # with gr.Row():
213
+ # contour_check = gr.Checkbox(
214
+ # value=True,
215
+ # label="withContours",
216
+ # info="draw the edges of the masks",
217
+ # )
218
+ #
219
+ # with gr.Column():
220
+ # segment_btn_e = gr.Button(
221
+ # "Segment Everything", variant="primary"
222
+ # )
223
+ # clear_btn_e = gr.Button("Clear", variant="secondary")
224
+ #
225
+ # gr.Markdown("Try some of the examples below ⬇️")
226
+ # gr.Examples(
227
+ # examples=examples,
228
+ # inputs=[cond_img_e],
229
+ # outputs=segm_img_e,
230
+ # fn=segment_everything,
231
+ # cache_examples=True,
232
+ # examples_per_page=4,
233
+ # )
234
+ #
235
+ # with gr.Column():
236
+ # with gr.Accordion("Advanced options", open=False):
237
+ # # text_box = gr.Textbox(label="text prompt")
238
+ # with gr.Row():
239
+ # mor_check = gr.Checkbox(
240
+ # value=False,
241
+ # label="better_visual_quality",
242
+ # info="better quality using morphologyEx",
243
+ # )
244
+ # with gr.Column():
245
+ # retina_check = gr.Checkbox(
246
+ # value=True,
247
+ # label="use_retina",
248
+ # info="draw high-resolution segmentation masks",
249
+ # )
250
+ # # Description
251
+ # gr.Markdown(description_e)
252
+ #
253
+ with gr.Tab("Point mode"):
254
+ # Images
255
+ with gr.Row(variant="panel"):
256
+ with gr.Column(scale=1):
257
+ cond_img_p.render()
258
+
259
+ with gr.Column(scale=1):
260
+ segm_img_p.render()
261
+
262
+ # Submit & Clear
263
+ with gr.Row():
264
+ with gr.Column():
265
+ with gr.Row():
266
+ add_or_remove = gr.Radio(
267
+ ["Add Mask", "Remove Area"],
268
+ value="Add Mask",
269
+ )
270
+
271
+ with gr.Column():
272
+ segment_btn_p = gr.Button(
273
+ "Start segmenting!", variant="primary"
274
+ )
275
+ clear_btn_p = gr.Button("Restart", variant="secondary")
276
+
277
+ gr.Markdown("Try some of the examples below ⬇️")
278
+ gr.Examples(
279
+ examples=examples,
280
+ inputs=[cond_img_p],
281
+ # outputs=segm_img_p,
282
+ # fn=segment_with_points,
283
+ # cache_examples=True,
284
+ examples_per_page=4,
285
+ )
286
+
287
+ with gr.Column():
288
+ # Description
289
+ gr.Markdown(description_p)
290
+
291
+ cond_img_p.select(get_points_with_draw, [cond_img_p, add_or_remove], cond_img_p)
292
+
293
+ # segment_btn_e.click(
294
+ # segment_everything,
295
+ # inputs=[
296
+ # cond_img_e,
297
+ # input_size_slider,
298
+ # mor_check,
299
+ # contour_check,
300
+ # retina_check,
301
+ # ],
302
+ # outputs=segm_img_e,
303
+ # )
304
+
305
+ segment_btn_p.click(
306
+ segment_with_points, inputs=[cond_img_p], outputs=[segm_img_p, cond_img_p]
307
+ )
308
+
309
def clear():
    """Reset the point-mode tab: forget the clicked points and blank both images.

    The collected points are dropped as well, so clicks made before the
    restart cannot leak into the next segmentation.

    Returns:
        ``(None, None)`` -- one value for each output image component.
    """
    global global_points
    global global_point_label

    global_points = []
    global_point_label = []
    return None, None
311
+
312
def clear_text():
    """Return three ``None`` values, one per component to be blanked."""
    return (None,) * 3
314
+
315
+ # clear_btn_e.click(clear, outputs=[cond_img_e, segm_img_e])
316
+ clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p])
317
+
318
+ demo.queue()
319
+ demo.launch()
assets/picture1.jpg ADDED
assets/picture2.jpg ADDED
assets/picture3.jpg ADDED
assets/picture4.jpg ADDED
assets/picture5.jpg ADDED
assets/picture6.jpg ADDED
mobile_sam.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6dbb90523a35330fedd7f1d3dfc66f995213d81b29a5ca8108dbcdd4e37d6c2f
3
+ size 40728226
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ timm
4
+ opencv-python
5
+ git+https://github.com/dhkim2810/MobileSAM.git
utils/__init__.py ADDED
File without changes
utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (170 Bytes). View file
 
utils/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (163 Bytes). View file
 
utils/__pycache__/tools.cpython-310.pyc ADDED
Binary file (10.5 kB). View file
 
utils/__pycache__/tools.cpython-38.pyc ADDED
Binary file (10.8 kB). View file
 
utils/__pycache__/tools_gradio.cpython-310.pyc ADDED
Binary file (4.17 kB). View file
 
utils/__pycache__/tools_gradio.cpython-38.pyc ADDED
Binary file (4.22 kB). View file
 
utils/tools.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ import cv2
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import torch
8
+ from PIL import Image
9
+
10
+
11
def convert_box_xywh_to_xyxy(box):
    """Convert ``[x, y, width, height]`` into ``[x1, y1, x2, y2]``."""
    left, top = box[0], box[1]
    return [left, top, left + box[2], top + box[3]]
17
+
18
+
19
def segment_image(image, bbox):
    """Return a copy of ``image`` in which only the ``bbox`` region is kept.

    ``bbox`` is ``(x1, y1, x2, y2)``. The result is an RGB image of the same
    size whose pixels outside the box are white.
    """
    pixels = np.array(image)
    x1, y1, x2, y2 = bbox
    region = (slice(y1, y2), slice(x1, x2))

    # Copy of the source pixels with everything outside the box zeroed.
    kept = np.zeros_like(pixels)
    kept[region] = pixels[region]

    # Paste mask: fully opaque inside the box, transparent elsewhere.
    opacity = np.zeros((pixels.shape[0], pixels.shape[1]), dtype=np.uint8)
    opacity[region] = 255

    canvas = Image.new("RGB", image.size, (255, 255, 255))
    canvas.paste(Image.fromarray(kept), mask=Image.fromarray(opacity, mode="L"))
    return canvas
34
+
35
+
36
def format_results(masks, scores, logits, filter=0):
    """Convert predictor output into a list of annotation dicts.

    Args:
        masks: Sequence of 2-D masks, indexed in step with ``scores``.
        scores: One score per mask.
        logits: Unused; kept for interface compatibility.
        filter: Masks whose pixel sum is below this value are skipped.

    Returns:
        A list of dicts with the keys ``id`` (index of the mask in the
        input), ``segmentation`` (the mask), ``bbox``
        (``[x_min, y_min, x_max, y_max]`` of the non-zero pixels, inclusive;
        ``[0, 0, 0, 0]`` for an empty mask), ``score`` and ``area``.
    """
    annotations = []
    for i in range(len(scores)):
        mask = masks[i]
        if np.sum(mask) < filter:
            continue

        nonzero = np.where(mask != 0)
        rows, cols = nonzero[0], nonzero[1]
        if rows.size == 0:
            # np.min/np.max raise on empty input; an empty mask has no extent.
            bbox = [0, 0, 0, 0]
        else:
            # np.where returns (rows, cols), i.e. (y, x); emit the box in
            # x/y order like the other bbox helpers in this module.
            bbox = [np.min(cols), np.min(rows), np.max(cols), np.max(rows)]

        annotations.append(
            {
                "id": i,
                "segmentation": mask,
                "bbox": bbox,
                "score": scores[i],
                "area": mask.sum(),
            }
        )
    return annotations
58
+
59
+
60
def filter_masks(annotations):
    """Drop masks that are mostly covered by a strictly larger mask.

    ``annotations`` is sorted in place by descending ``area``. A mask is
    dropped when more than 80% of its pixels overlap a mask with a larger
    area that precedes it in the sorted list.

    Returns:
        ``(kept, dropped)`` where ``kept`` is the filtered list and
        ``dropped`` is the set of removed indices into the sorted list.
    """
    annotations.sort(key=lambda item: item["area"], reverse=True)
    dropped = set()
    for outer_idx, larger in enumerate(annotations):
        for inner_idx in range(outer_idx + 1, len(annotations)):
            if inner_idx in dropped:
                continue
            smaller = annotations[inner_idx]
            if not smaller["area"] < larger["area"]:
                continue
            overlap = (larger["segmentation"] & smaller["segmentation"]).sum()
            if overlap / smaller["segmentation"].sum() > 0.8:
                dropped.add(inner_idx)

    kept = [item for idx, item in enumerate(annotations) if idx not in dropped]
    return kept, dropped
76
+
77
+
78
def get_bbox_from_mask(mask):
    """Return the bounding box ``[x1, y1, x2, y2]`` of the non-zero pixels.

    ``x2`` and ``y2`` are exclusive (one past the last non-zero column/row),
    which is what merging the per-contour bounding rectangles produced. The
    merged box of all outer contours equals the extent of all non-zero
    pixels, so it is computed directly from the pixel coordinates.

    Args:
        mask: 2-D array; it is cast to ``uint8`` before use.

    Returns:
        A list of four Python ints, or ``[0, 0, 0, 0]`` when the mask has no
        non-zero pixel (previously this raised ``IndexError``).
    """
    mask = mask.astype(np.uint8)
    rows, cols = np.nonzero(mask)
    if rows.size == 0:
        return [0, 0, 0, 0]
    return [
        int(cols.min()),
        int(rows.min()),
        int(cols.max()) + 1,
        int(rows.max()) + 1,
    ]
96
+
97
+
98
def _figure_to_rgb(fig):
    """Render ``fig`` and return its pixels as a ``(rows, cols, 3)`` uint8 array."""
    canvas = fig.canvas
    canvas.draw()
    if hasattr(canvas, "tostring_rgb"):
        cols, rows = canvas.get_width_height()
        # np.frombuffer replaces the deprecated binary mode of np.fromstring.
        return np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8).reshape(
            rows, cols, 3
        )
    # Canvases without tostring_rgb(): drop the alpha channel of the RGBA
    # buffer instead. NOTE(review): assumes an Agg-based canvas -- confirm.
    return np.ascontiguousarray(np.asarray(canvas.buffer_rgba())[..., :3])


def fast_process(
    annotations, args, mask_random_color, bbox=None, points=None, edges=False
):
    """Draw the masks over the image at ``args.img_path`` and save the result.

    Args:
        annotations: Masks, or dicts holding a mask under ``"segmentation"``.
        args: Options object; the attributes read here are ``img_path``,
            ``better_quality``, ``device``, ``point_label``, ``retina``,
            ``withContours`` and ``output``.
        mask_random_color: Whether each mask gets a random colour.
        bbox: Optional ``(x1, y1, x2, y2)`` box to outline.
        points: Optional prompt points to plot.
        edges: Unused; kept for interface compatibility.

    The rendered figure is written to ``args.output`` under the base name of
    ``args.img_path``. Nothing is returned.
    """
    if isinstance(annotations[0], dict):
        annotations = [annotation["segmentation"] for annotation in annotations]
    result_name = os.path.basename(args.img_path)
    image = cv2.imread(args.img_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    original_h = image.shape[0]
    original_w = image.shape[1]
    if sys.platform == "darwin":
        plt.switch_backend("TkAgg")
    plt.figure(figsize=(original_w / 100, original_h / 100))
    # Add subplot with no margin.
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    plt.margins(0, 0)
    plt.gca().xaxis.set_major_locator(plt.NullLocator())
    plt.gca().yaxis.set_major_locator(plt.NullLocator())
    plt.imshow(image)

    if args.better_quality:
        if isinstance(annotations[0], torch.Tensor):
            annotations = np.array(annotations.cpu())
        # Close small holes, then remove small specks, mask by mask.
        for i, mask in enumerate(annotations):
            mask = cv2.morphologyEx(
                mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8)
            )
            annotations[i] = cv2.morphologyEx(
                mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8)
            )

    # Compare the device *type* so that both the string "cpu" and a
    # torch.device("cpu") select the CPU path.
    if getattr(args.device, "type", args.device) == "cpu":
        annotations = np.array(annotations)
        fast_show_mask(
            annotations,
            plt.gca(),
            random_color=mask_random_color,
            bbox=bbox,
            points=points,
            point_label=args.point_label,
            retinamask=args.retina,
            target_height=original_h,
            target_width=original_w,
        )
    else:
        if isinstance(annotations[0], np.ndarray):
            # torch.from_numpy needs one ndarray, not a list of them.
            annotations = torch.from_numpy(np.array(annotations))
        fast_show_mask_gpu(
            annotations,
            plt.gca(),
            # Honour the explicit argument, exactly like the CPU branch.
            random_color=mask_random_color,
            bbox=bbox,
            points=points,
            point_label=args.point_label,
            retinamask=args.retina,
            target_height=original_h,
            target_width=original_w,
        )
    if isinstance(annotations, torch.Tensor):
        annotations = annotations.cpu().numpy()

    if args.withContours:
        contour_all = []
        temp = np.zeros((original_h, original_w, 1))
        for mask in annotations:
            if isinstance(mask, dict):
                mask = mask["segmentation"]
            annotation = mask.astype(np.uint8)
            if not args.retina:
                annotation = cv2.resize(
                    annotation,
                    (original_w, original_h),
                    interpolation=cv2.INTER_NEAREST,
                )
            contours, _ = cv2.findContours(
                annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
            )
            contour_all.extend(contours)
        cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
        color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
        contour_mask = temp / 255 * color.reshape(1, 1, -1)
        plt.imshow(contour_mask)

    save_path = args.output
    os.makedirs(save_path, exist_ok=True)
    plt.axis("off")
    fig = plt.gcf()

    img_array = _figure_to_rgb(fig)
    cv2.imwrite(
        os.path.join(save_path, result_name),
        cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR),
    )
197
+
198
+
199
+ # CPU post process
200
def fast_show_mask(
    annotation,
    ax,
    random_color=False,
    bbox=None,
    points=None,
    point_label=None,
    retinamask=True,
    target_height=960,
    target_width=960,
):
    """Compose the masks into one RGBA overlay and show it on ``ax`` (NumPy path).

    ``annotation`` is indexed as (mask, row, column). At each pixel the
    colour of the smallest-area mask covering it is used; uncovered pixels
    stay fully transparent. ``bbox`` and ``points`` are drawn when given.
    """
    dims = annotation.shape
    mask_count, rows, cols = dims[0], dims[1], dims[2]

    # Sort the masks by area, smallest first.
    annotation = annotation[np.argsort(np.sum(annotation, axis=(1, 2)))]

    # Per pixel: position (in sorted order) of the first non-zero mask.
    first_hit = (annotation != 0).argmax(axis=0)

    if random_color == True:
        rgb = np.random.random((mask_count, 1, 1, 3))
    else:
        rgb = np.ones((mask_count, 1, 1, 3)) * np.array(
            [30 / 255, 144 / 255, 255 / 255]
        )
    alpha = np.ones((mask_count, 1, 1, 1)) * 0.6
    rgba = np.concatenate([rgb, alpha], axis=-1)
    layers = np.expand_dims(annotation, -1) * rgba

    # Vectorised gather of the chosen layer at every pixel.
    row_idx, col_idx = np.indices((rows, cols))
    show = np.zeros((rows, cols, 4))
    show[row_idx, col_idx, :] = layers[first_hit[row_idx, col_idx], row_idx, col_idx, :]

    if bbox is not None:
        x1, y1, x2, y2 = bbox
        outline = plt.Rectangle(
            (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
        )
        ax.add_patch(outline)

    # Draw the prompt points: label 1 in yellow, then label 0 in magenta.
    if points is not None:
        for wanted_label, marker_color in ((1, "y"), (0, "m")):
            plt.scatter(
                [pt[0] for k, pt in enumerate(points) if point_label[k] == wanted_label],
                [pt[1] for k, pt in enumerate(points) if point_label[k] == wanted_label],
                s=20,
                c=marker_color,
            )

    if retinamask == False:
        show = cv2.resize(
            show, (target_width, target_height), interpolation=cv2.INTER_NEAREST
        )
    ax.imshow(show)
264
+
265
+
266
def fast_show_mask_gpu(
    annotation,
    ax,
    random_color=False,
    bbox=None,
    points=None,
    point_label=None,
    retinamask=True,
    target_height=960,
    target_width=960,
):
    """Compose the masks into one RGBA overlay and show it on ``ax`` (torch path).

    ``annotation`` is a tensor indexed as (mask, row, column). At each pixel
    the colour of the smallest-area mask covering it is used; uncovered
    pixels stay fully transparent. ``bbox`` and ``points`` are drawn when given.
    """
    dev = annotation.device
    dims = annotation.shape
    mask_count, rows, cols = dims[0], dims[1], dims[2]

    # Sort the masks by area, smallest first.
    order = torch.argsort(torch.sum(annotation, dim=(1, 2)), descending=False)
    annotation = annotation[order]

    # Find, for every position, the index of the first non-zero value.
    first_hit = (annotation != 0).to(torch.long).argmax(dim=0)

    if random_color == True:
        rgb = torch.rand((mask_count, 1, 1, 3)).to(dev)
    else:
        rgb = torch.ones((mask_count, 1, 1, 3)).to(dev) * torch.tensor(
            [30 / 255, 144 / 255, 255 / 255]
        ).to(dev)
    alpha = torch.ones((mask_count, 1, 1, 1)).to(dev) * 0.6
    rgba = torch.cat([rgb, alpha], dim=-1)
    layers = annotation.unsqueeze(-1) * rgba

    # Gather by index: first_hit says which mask layer to take at each
    # position, collapsing the stack of layers into a single image.
    show = torch.zeros((rows, cols, 4)).to(dev)
    row_idx, col_idx = torch.meshgrid(
        torch.arange(rows), torch.arange(cols), indexing="ij"
    )
    # Update the overlay with vectorised indexing.
    show[row_idx, col_idx, :] = layers[first_hit[row_idx, col_idx], row_idx, col_idx, :]
    show_cpu = show.cpu().numpy()

    if bbox is not None:
        x1, y1, x2, y2 = bbox
        outline = plt.Rectangle(
            (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
        )
        ax.add_patch(outline)

    # Draw the prompt points: label 1 in yellow, then label 0 in magenta.
    if points is not None:
        for wanted_label, marker_color in ((1, "y"), (0, "m")):
            plt.scatter(
                [pt[0] for k, pt in enumerate(points) if point_label[k] == wanted_label],
                [pt[1] for k, pt in enumerate(points) if point_label[k] == wanted_label],
                s=20,
                c=marker_color,
            )

    if retinamask == False:
        show_cpu = cv2.resize(
            show_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST
        )
    ax.imshow(show_cpu)
329
+
330
+
331
def crop_image(annotations, image_like):
    """Cut the bounding-box region of every sufficiently large mask out of an image.

    ``image_like`` is either a path (opened with PIL) or an image object. The
    image is resized to the mask resolution when the two differ. Masks whose
    pixel sum is 100 or less are skipped and their indices collected.

    Returns:
        ``(cropped_boxes, cropped_images, not_crop, filter_id, annotations)``:
        the cut-out images, their bounding boxes, an always-empty list, the
        indices of the skipped masks, and the input annotations.
    """
    image = Image.open(image_like) if isinstance(image_like, str) else image_like
    ori_w, ori_h = image.size
    mask_h, mask_w = annotations[0]["segmentation"].shape
    if (ori_w, ori_h) != (mask_w, mask_h):
        image = image.resize((mask_w, mask_h))

    cropped_boxes, cropped_images, not_crop, filter_id = [], [], [], []
    for idx, entry in enumerate(annotations):
        segmentation = entry["segmentation"]
        if np.sum(segmentation) <= 100:
            filter_id.append(idx)
            continue
        bbox = get_bbox_from_mask(segmentation)  # bounding box of the mask
        cropped_boxes.append(segment_image(image, bbox))  # store the cut-out image
        cropped_images.append(bbox)  # store the bbox of the cut-out image

    return cropped_boxes, cropped_images, not_crop, filter_id, annotations
356
+
357
+
358
def box_prompt(masks, bbox, target_height, target_width):
    """Pick the mask that best matches ``bbox`` by intersection-over-union.

    ``masks`` is a tensor indexed as (mask, row, column). ``bbox`` is
    ``[x1, y1, x2, y2]`` in ``target_width`` x ``target_height`` coordinates;
    it is rescaled to the mask resolution when the sizes differ, then rounded
    and clamped to the mask bounds. When no rescaling happens the caller's
    list is updated in place.

    Returns:
        ``(mask, index)``: the best mask as a NumPy array and its index tensor.
    """
    mask_h, mask_w = masks.shape[1], masks.shape[2]
    if (mask_h, mask_w) != (target_height, target_width):
        sizes = (mask_w, mask_h, mask_w, mask_h)
        targets = (target_width, target_height, target_width, target_height)
        bbox = [
            int(coord * size / target)
            for coord, size, target in zip(bbox, sizes, targets)
        ]

    # Round to whole pixels and keep the box inside the mask.
    bbox[0] = max(round(bbox[0]), 0)
    bbox[1] = max(round(bbox[1]), 0)
    bbox[2] = min(round(bbox[2]), mask_w)
    bbox[3] = min(round(bbox[3]), mask_h)

    box_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
    inside_area = masks[:, bbox[1] : bbox[3], bbox[0] : bbox[2]].sum(dim=(1, 2))
    total_area = masks.sum(dim=(1, 2))

    ious = inside_area / (box_area + total_area - inside_area)
    best = torch.argmax(ious)
    return masks[best].cpu().numpy(), best
384
+
385
+
386
def point_prompt(masks, points, point_label, target_height, target_width):
    """Merge the masks selected by labelled points into one boolean mask (NumPy).

    ``points`` are ``[x, y]`` pairs in ``target_width`` x ``target_height``
    coordinates and are rescaled to the mask resolution when the sizes
    differ. Every mask equal to 1 at a point with label 1 is added, and every
    mask equal to 1 at a point with label 0 is subtracted.

    Returns:
        ``(mask, 0)`` where ``mask`` is True wherever the accumulated sum is
        at least 1.
    """
    reference = masks[0]["segmentation"]
    mask_h = reference.shape[0]
    mask_w = reference.shape[1]
    if mask_h != target_height or mask_w != target_width:
        points = [
            [int(pt[0] * mask_w / target_width), int(pt[1] * mask_h / target_height)]
            for pt in points
        ]

    accumulated = np.zeros((mask_h, mask_w))
    for entry in masks:
        mask = entry["segmentation"] if type(entry) == dict else entry
        for idx, pt in enumerate(points):
            if not mask[pt[1], pt[0]] == 1:
                continue
            label = point_label[idx]
            if label == 1:
                accumulated += mask
            elif label == 0:
                accumulated -= mask

    return accumulated >= 1, 0
utils/tools_gradio.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ import torch
5
+ from PIL import Image
6
+
7
+
8
def fast_process(
    annotations,
    image,
    device,
    scale,
    better_quality=False,
    mask_random_color=True,
    bbox=None,
    use_retina=True,
    withContours=True,
):
    """Overlay the masks (and optionally their contours) on ``image``.

    Args:
        annotations: Masks, or dicts holding a mask under ``"segmentation"``.
        image: Image exposing ``.height``, ``.width``, ``.convert`` and
            ``.paste``.
        device: Device name or ``torch.device``; a CPU device selects the
            NumPy compositor, anything else the torch one.
        scale: Divisor applied to the contour thickness (``2 // scale``).
        better_quality: Clean every mask with morphological close/open first.
        mask_random_color: Whether each mask gets a random colour.
        bbox: Optional ``(x1, y1, x2, y2)`` box, forwarded to the compositor.
        use_retina: When false, masks are resized to the image size.
        withContours: Whether to draw the mask contours.

    Returns:
        An RGBA copy of ``image`` with the overlays pasted on top.
    """
    if isinstance(annotations[0], dict):
        annotations = [annotation["segmentation"] for annotation in annotations]

    original_h = image.height
    original_w = image.width
    if better_quality:
        if isinstance(annotations[0], torch.Tensor):
            annotations = np.array(annotations.cpu())
        # Close small holes, then remove small specks, mask by mask.
        for i, mask in enumerate(annotations):
            mask = cv2.morphologyEx(
                mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8)
            )
            annotations[i] = cv2.morphologyEx(
                mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8)
            )

    # Compare the device *type* so that both the string "cpu" and a
    # torch.device("cpu") select the CPU path.
    if getattr(device, "type", device) == "cpu":
        annotations = np.array(annotations)
        inner_mask = fast_show_mask(
            annotations,
            plt.gca(),
            random_color=mask_random_color,
            bbox=bbox,
            retinamask=use_retina,
            target_height=original_h,
            target_width=original_w,
        )
    else:
        if isinstance(annotations[0], np.ndarray):
            annotations = np.array(annotations)
            annotations = torch.from_numpy(annotations)
        inner_mask = fast_show_mask_gpu(
            annotations,
            plt.gca(),
            random_color=mask_random_color,
            bbox=bbox,
            retinamask=use_retina,
            target_height=original_h,
            target_width=original_w,
        )
    if isinstance(annotations, torch.Tensor):
        annotations = annotations.cpu().numpy()

    if withContours:
        contour_all = []
        temp = np.zeros((original_h, original_w, 1))
        for mask in annotations:
            if isinstance(mask, dict):
                mask = mask["segmentation"]
            annotation = mask.astype(np.uint8)
            if not use_retina:
                annotation = cv2.resize(
                    annotation,
                    (original_w, original_h),
                    interpolation=cv2.INTER_NEAREST,
                )
            contours, _ = cv2.findContours(
                annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
            )
            contour_all.extend(contours)
        # Keep the line at least 1 px thick and avoid dividing by a zero scale.
        thickness = max(1, int(2 // scale)) if scale else 2
        cv2.drawContours(temp, contour_all, -1, (255, 255, 255), thickness)
        color = np.array([0 / 255, 0 / 255, 255 / 255, 0.9])
        contour_mask = temp / 255 * color.reshape(1, 1, -1)

    image = image.convert("RGBA")
    overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), "RGBA")
    image.paste(overlay_inner, (0, 0), overlay_inner)

    if withContours:
        overlay_contour = Image.fromarray((contour_mask * 255).astype(np.uint8), "RGBA")
        image.paste(overlay_contour, (0, 0), overlay_contour)

    return image
92
+
93
+
94
+ # CPU post process
95
def fast_show_mask(
    annotation,
    ax,
    random_color=False,
    bbox=None,
    retinamask=True,
    target_height=960,
    target_width=960,
):
    """Compose the masks into a single RGBA overlay array (NumPy path).

    ``annotation`` is indexed as (mask, row, column). At each pixel the
    colour of the smallest-area mask covering it is used; uncovered pixels
    stay fully transparent. When ``bbox`` is given, its outline is added to
    ``ax``.

    Returns:
        The overlay as a float array with four channels (RGBA, 0..1).
    """
    dims = annotation.shape
    mask_count, rows, cols = dims[0], dims[1], dims[2]

    # Sort the masks by area, smallest first.
    annotation = annotation[np.argsort(np.sum(annotation, axis=(1, 2)))]

    # Per pixel: position (in sorted order) of the first non-zero mask.
    first_hit = (annotation != 0).argmax(axis=0)

    if random_color == True:
        rgb = np.random.random((mask_count, 1, 1, 3))
    else:
        rgb = np.ones((mask_count, 1, 1, 3)) * np.array(
            [30 / 255, 144 / 255, 255 / 255]
        )
    alpha = np.ones((mask_count, 1, 1, 1)) * 0.6
    rgba = np.concatenate([rgb, alpha], axis=-1)
    layers = np.expand_dims(annotation, -1) * rgba

    # Vectorised gather of the chosen layer at every pixel.
    row_idx, col_idx = np.indices((rows, cols))
    mask = np.zeros((rows, cols, 4))
    mask[row_idx, col_idx, :] = layers[first_hit[row_idx, col_idx], row_idx, col_idx, :]

    if bbox is not None:
        x1, y1, x2, y2 = bbox
        outline = plt.Rectangle(
            (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
        )
        ax.add_patch(outline)

    if retinamask == False:
        mask = cv2.resize(
            mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST
        )

    return mask
145
+
146
+
147
def fast_show_mask_gpu(
    annotation,
    ax,
    random_color=False,
    bbox=None,
    retinamask=True,
    target_height=960,
    target_width=960,
):
    """Compose the masks into a single RGBA overlay array (torch path).

    Args:
        annotation: Tensor indexed as (mask, row, column).
        ax: Matplotlib axes; only used to outline ``bbox``.
        random_color: Give every mask a random colour instead of a fixed blue.
        bbox: Optional ``(x1, y1, x2, y2)`` box to outline on ``ax``.
        retinamask: When false, the overlay is resized to the target size.
        target_height: Output height used when ``retinamask`` is false.
        target_width: Output width used when ``retinamask`` is false.

    Returns:
        The overlay as a NumPy float array with four channels (RGBA, 0..1).
        At each pixel the colour of the smallest-area mask covering it is
        used; uncovered pixels stay fully transparent.
    """
    device = annotation.device
    mask_sum = annotation.shape[0]
    height = annotation.shape[1]
    weight = annotation.shape[2]
    areas = torch.sum(annotation, dim=(1, 2))
    sorted_indices = torch.argsort(areas, descending=False)
    annotation = annotation[sorted_indices]
    # Find, for every position, the index of the first non-zero value.
    index = (annotation != 0).to(torch.long).argmax(dim=0)
    if random_color:
        color = torch.rand((mask_sum, 1, 1, 3)).to(device)
    else:
        color = torch.ones((mask_sum, 1, 1, 3)).to(device) * torch.tensor(
            [30 / 255, 144 / 255, 255 / 255]
        ).to(device)
    transparency = torch.ones((mask_sum, 1, 1, 1)).to(device) * 0.6
    visual = torch.cat([color, transparency], dim=-1)
    mask_image = torch.unsqueeze(annotation, -1) * visual
    # Gather by index: `index` says which mask layer to take at each position,
    # collapsing the stack of layers into a single image.
    mask = torch.zeros((height, weight, 4)).to(device)
    # Pass the indexing mode explicitly (row/column order), matching the
    # sibling implementation in utils/tools.py; omitting it is deprecated.
    h_indices, w_indices = torch.meshgrid(
        torch.arange(height), torch.arange(weight), indexing="ij"
    )
    indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
    # Update the overlay with vectorised indexing.
    mask[h_indices, w_indices, :] = mask_image[indices]
    mask_cpu = mask.cpu().numpy()
    if bbox is not None:
        x1, y1, x2, y2 = bbox
        ax.add_patch(
            plt.Rectangle(
                (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
            )
        )
    if not retinamask:
        mask_cpu = cv2.resize(
            mask_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST
        )
    return mask_cpu