Shivdutta committed on
Commit 3eb9d86 · verified · 1 Parent(s): 1844d06

Upload 3 files

Files changed (3)
  1. FastSAM.pt +3 -0
  2. app.py +373 -0
  3. requirements.txt +6 -0
FastSAM.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c0be4e7ddbe4c15333d15a859c676d053c486d0a746a3be6a7a9790d52a9b6d7
+ size 144943063
app.py ADDED
@@ -0,0 +1,373 @@
+ # Standard Library Imports
+ import sys
+
+ # Third-Party Imports
+ import torch
+ import numpy as np
+ import gradio as gr
+ from PIL import ImageDraw
+ from ultralytics import YOLO
+ from utils.tools_gradio import fast_process
+ from utils.tools import format_results, box_prompt, point_prompt, text_prompt
+
+
+ def segment_everything(
+         input,
+         input_size=1024,
+         iou_threshold=0.7,
+         conf_threshold=0.25,
+         better_quality=False,
+         withContours=True,
+         use_retina=True,
+         text="",
+         wider=False,
+         mask_random_color=True,
+ ):
+     input_size = int(input_size)
+     w, h = input.size
+     scale = input_size / max(w, h)
+     new_w = int(w * scale)
+     new_h = int(h * scale)
+     input = input.resize((new_w, new_h))
+
+     results = model(input,
+                     device=device,
+                     retina_masks=True,
+                     iou=iou_threshold,
+                     conf=conf_threshold,
+                     imgsz=input_size, )
+
+     if len(text) > 0:
+         results = format_results(results[0], 0)
+         annotations, _ = text_prompt(results, text, input, device=device, wider=wider)
+         annotations = np.array([annotations])
+     else:
+         annotations = results[0].masks.data
+
+     fig = fast_process(annotations=annotations,
+                        image=input,
+                        device=device,
+                        scale=(1024 // input_size),
+                        better_quality=better_quality,
+                        mask_random_color=mask_random_color,
+                        bbox=None,
+                        use_retina=use_retina,
+                        withContours=withContours, )
+     return fig
+
+
+ def segment_with_points(
+         input,
+         input_size=1024,
+         iou_threshold=0.7,
+         conf_threshold=0.25,
+         better_quality=False,
+         withContours=True,
+         use_retina=True,
+         mask_random_color=True,
+ ):
+     global global_points
+     global global_point_label
+
+     input_size = int(input_size)
+     w, h = input.size
+     scale = input_size / max(w, h)
+     new_w = int(w * scale)
+     new_h = int(h * scale)
+     input = input.resize((new_w, new_h))
+
+     scaled_points = [[int(x * scale) for x in point] for point in global_points]
+
+     results = model(input,
+                     device=device,
+                     retina_masks=True,
+                     iou=iou_threshold,
+                     conf=conf_threshold,
+                     imgsz=input_size, )
+
+     results = format_results(results[0], 0)
+     annotations, _ = point_prompt(results, scaled_points, global_point_label, new_h, new_w)
+     annotations = np.array([annotations])
+
+     fig = fast_process(annotations=annotations,
+                        image=input,
+                        device=device,
+                        scale=(1024 // input_size),
+                        better_quality=better_quality,
+                        mask_random_color=mask_random_color,
+                        bbox=None,
+                        use_retina=use_retina,
+                        withContours=withContours, )
+
+     global_points = []
+     global_point_label = []
+     return fig, None
+
+
+ def get_points_with_draw(image, label, evt: gr.SelectData):
+     global global_points
+     global global_point_label
+
+     x, y = evt.index[0], evt.index[1]
+     point_radius, point_color = 15, (255, 255, 0) if label == 'Add Mask' else (255, 0, 255)
+     global_points.append([x, y])
+     global_point_label.append(1 if label == 'Add Mask' else 0)
+
+     print(x, y, label == 'Add Mask')
+
+     draw = ImageDraw.Draw(image)
+     draw.ellipse([(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)], fill=point_color)
+     return image
+
+
+ # Load the pre-trained model
+ model = YOLO('./FastSAM.pt')
+ example_dir = './'
+
+ # Select the device
+ device = torch.device(
+     "cuda"
+     if torch.cuda.is_available()
+     else "mps"
+     if torch.backends.mps.is_available()
+     else "cpu"
+ )
+
+ # Title of the App
+ title = "# Fast Segment Anything"
+
+ # Description for segmentation
+ description_e = """
+ This is a demo of the GitHub project [Fast Segment Anything Model](https://github.com/CASIA-IVA-Lab/FastSAM).
+ """
+
+ # Description for points
+ description_p = """ # Instructions for points mode
+ This is a demo of the GitHub project [Fast Segment Anything Model](https://github.com/CASIA-IVA-Lab/FastSAM). Feel free to give it a star ⭐️.
+
+ 1. Upload an image or choose an example.
+
+ 2. Choose the point label ('Add Mask' adds a positive point; 'Remove Area' adds a negative point, i.e. a region that should not be segmented).
+
+ 3. Add points one by one on the image.
+
+ 4. Click the 'Segment with points prompt' button to get the segmentation results.
+
+ **5. If you get an error, clicking the 'Clear points' button and trying again may help.**
+
+ """
+
+ # Examples
+ examples = [[example_dir + "examples/sa_8776.jpg"], [example_dir + "examples/sa_414.jpg"],
+             [example_dir + "examples/sa_1309.jpg"], [example_dir + "examples/sa_11025.jpg"],
+             [example_dir + "examples/sa_561.jpg"], [example_dir + "examples/sa_192.jpg"],
+             [example_dir + "examples/sa_10039.jpg"], [example_dir + "examples/sa_862.jpg"]]
+
+ default_example = examples[0]
+
+ # CSS for the app layout
+ css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
+
+ # Input images for the everything, points, and text-based segmentation tabs
+ cond_img_e = gr.Image(label="Input", value=default_example[0], type='pil')
+ cond_img_p = gr.Image(label="Input with points", value=default_example[0], type='pil')
+ cond_img_t = gr.Image(label="Input with text", value=example_dir + "examples/dogs.jpg", type='pil')
+
+ # Output for each tab
+ segm_img_e = gr.Image(label="Segmented Image", interactive=False, type='pil')
+ segm_img_p = gr.Image(label="Segmented Image with points", interactive=False, type='pil')
+ segm_img_t = gr.Image(label="Segmented Image with text", interactive=False, type='pil')
+
+ # Lists to accumulate clicked points and their labels
+ global_points = []
+ global_point_label = []
+
+ input_size_slider = gr.components.Slider(minimum=512,
+                                          maximum=1024,
+                                          value=1024,
+                                          step=64,
+                                          label='Input_size',
+                                          info='The model was trained on a size of 1024')
+
+ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
+     with gr.Row():
+         with gr.Column(scale=1):
+             gr.Markdown(title)  # Title
+
+     with gr.Tab("Everything mode"):
+         # Images
+         with gr.Row(variant="panel"):
+             with gr.Column(scale=1):
+                 cond_img_e.render()
+
+             with gr.Column(scale=1):
+                 segm_img_e.render()
+
+         # Submit & Clear
+         with gr.Row():
+             with gr.Column():
+                 input_size_slider.render()
+
+                 with gr.Row():
+                     contour_check = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')
+
+                     with gr.Column():
+                         segment_btn_e = gr.Button("Segment Everything", variant='primary')
+                         clear_btn_e = gr.Button("Clear", variant="secondary")
+
+                 gr.Markdown("Try some of the examples below ⬇️")
+                 gr.Examples(examples=examples,
+                             inputs=[cond_img_e],
+                             outputs=segm_img_e,
+                             fn=segment_everything,
+                             cache_examples=True,
+                             examples_per_page=4)
+
+             with gr.Column():
+                 with gr.Accordion("Advanced options", open=False):
+                     iou_threshold = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou',
+                                               info='iou threshold for filtering the annotations')
+                     conf_threshold = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf',
+                                                info='object confidence threshold')
+                     with gr.Row():
+                         mor_check = gr.Checkbox(value=False, label='better_visual_quality',
+                                                 info='better quality using morphologyEx')
+                         with gr.Column():
+                             retina_check = gr.Checkbox(value=True, label='use_retina',
+                                                        info='draw high-resolution segmentation masks')
+
+                 # Description
+                 gr.Markdown(description_e)
+
+         segment_btn_e.click(segment_everything,
+                             inputs=[
+                                 cond_img_e,
+                                 input_size_slider,
+                                 iou_threshold,
+                                 conf_threshold,
+                                 mor_check,
+                                 contour_check,
+                                 retina_check,
+                             ],
+                             outputs=segm_img_e)
+
+     with gr.Tab("Points mode"):
+         # Images
+         with gr.Row(variant="panel"):
+             with gr.Column(scale=1):
+                 cond_img_p.render()
+
+             with gr.Column(scale=1):
+                 segm_img_p.render()
+
+         # Submit & Clear
+         with gr.Row():
+             with gr.Column():
+                 with gr.Row():
+                     add_or_remove = gr.Radio(["Add Mask", "Remove Area"], value="Add Mask",
+                                              label="Point_label (foreground/background)")
+
+                     with gr.Column():
+                         segment_btn_p = gr.Button("Segment with points prompt", variant='primary')
+                         clear_btn_p = gr.Button("Clear points", variant='secondary')
+
+                 gr.Markdown("Try some of the examples below ⬇️")
+                 gr.Examples(examples=examples,
+                             inputs=[cond_img_p],
+                             # outputs=segm_img_p,
+                             # fn=segment_with_points,
+                             # cache_examples=True,
+                             examples_per_page=4)
+
+             with gr.Column():
+                 # Description
+                 gr.Markdown(description_p)
+
+     cond_img_p.select(get_points_with_draw, [cond_img_p, add_or_remove], cond_img_p)
+
+     segment_btn_p.click(segment_with_points,
+                         inputs=[cond_img_p],
+                         outputs=[segm_img_p, cond_img_p])
+
+     with gr.Tab("Text mode"):
+         # Images
+         with gr.Row(variant="panel"):
+             with gr.Column(scale=1):
+                 cond_img_t.render()
+
+             with gr.Column(scale=1):
+                 segm_img_t.render()
+
+         # Submit & Clear
+         with gr.Row():
+             with gr.Column():
+                 input_size_slider_t = gr.components.Slider(minimum=512,
+                                                            maximum=1024,
+                                                            value=1024,
+                                                            step=64,
+                                                            label='Input_size',
+                                                            info='Our model was trained on a size of 1024')
+                 with gr.Row():
+                     with gr.Column():
+                         contour_check = gr.Checkbox(value=True, label='withContours',
+                                                     info='draw the edges of the masks')
+                         text_box = gr.Textbox(label="text prompt", value="a black dog")
+
+                     with gr.Column():
+                         segment_btn_t = gr.Button("Segment with text", variant='primary')
+                         clear_btn_t = gr.Button("Clear", variant="secondary")
+
+                 gr.Markdown("Try some of the examples below ⬇️")
+                 gr.Examples(examples=[[example_dir + "examples/dogs.jpg"], [example_dir + "examples/fruits.jpg"],
+                                       [example_dir + "examples/flowers.jpg"]],
+                             inputs=[cond_img_t],
+                             # outputs=segm_img_e,
+                             # fn=segment_everything,
+                             # cache_examples=True,
+                             examples_per_page=4)
+
+             with gr.Column():
+                 with gr.Accordion("Advanced options", open=False):
+                     iou_threshold = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou',
+                                               info='iou threshold for filtering the annotations')
+                     conf_threshold = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf',
+                                                info='object confidence threshold')
+                     with gr.Row():
+                         mor_check = gr.Checkbox(value=False, label='better_visual_quality',
+                                                 info='better quality using morphologyEx')
+                         retina_check = gr.Checkbox(value=True, label='use_retina',
+                                                    info='draw high-resolution segmentation masks')
+                         wider_check = gr.Checkbox(value=False, label='wider', info='wider result')
+
+                 # Description
+                 gr.Markdown(description_e)
+
+         segment_btn_t.click(segment_everything,
+                             inputs=[
+                                 cond_img_t,
+                                 input_size_slider_t,
+                                 iou_threshold,
+                                 conf_threshold,
+                                 mor_check,
+                                 contour_check,
+                                 retina_check,
+                                 text_box,
+                                 wider_check,
+                             ],
+                             outputs=segm_img_t)
+
+
+     def clear():
+         return None, None
+
+
+     def clear_text():
+         return None, None, None
+
+
+     clear_btn_e.click(clear, outputs=[cond_img_e, segm_img_e])
+     clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p])
+     clear_btn_t.click(clear_text, outputs=[cond_img_t, segm_img_t, text_box])
+
+ demo.queue()
+ demo.launch()
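Note (not part of this commit): app.py drives the uploaded FastSAM.pt checkpoint through the ultralytics YOLO API. The minimal sketch below shows how the same inference call used in segment_everything could be run outside the Gradio UI, with the app's default thresholds; the example image path is an assumption.

import torch
from PIL import Image
from ultralytics import YOLO

# Same checkpoint path as app.py; device reduced to CUDA/CPU for brevity
device = "cuda" if torch.cuda.is_available() else "cpu"
model = YOLO("./FastSAM.pt")

# Hypothetical example image from this repo; any RGB image should work
image = Image.open("examples/sa_8776.jpg")

results = model(image,
                device=device,
                retina_masks=True,
                iou=0.7,        # app.py default iou_threshold
                conf=0.25,      # app.py default conf_threshold
                imgsz=1024)     # app.py default input_size

# One (H, W) mask per detected object; results[0].masks can be None if nothing is found
masks = results[0].masks.data
print(masks.shape)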
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ matplotlib==3.2.2
+ numpy
+ opencv-python
+ ultralytics==8.0.121
+
+ git+https://github.com/openai/CLIP.git
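Note (not part of this commit): CLIP is installed from GitHub because the text-prompt path in app.py calls text_prompt from utils.tools, which this commit does not include. A rough sketch of the underlying idea, scoring candidate mask crops against a text query with CLIP; score_crops and its exact behavior are assumptions for illustration, not the missing utils implementation.

import clip
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device)

def score_crops(crops, query):
    """Return CLIP similarity between each cropped mask region (PIL images) and the text query."""
    images = torch.stack([preprocess(c) for c in crops]).to(device)
    tokens = clip.tokenize([query]).to(device)
    with torch.no_grad():
        image_feats = clip_model.encode_image(images)
        text_feats = clip_model.encode_text(tokens)
    # Cosine similarity; the best-matching mask is the argmax of the returned scores
    image_feats = image_feats / image_feats.norm(dim=-1, keepdim=True)
    text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)
    return (image_feats @ text_feats.T).squeeze(-1)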