padmanabhbosamia committed on
Commit
1bf7e5e
·
1 Parent(s): e7edd3d

app_gradio.py

Browse files
Files changed (1) hide show
  1. app_gradio.py +370 -0
app_gradio.py ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ultralytics import YOLO
2
+ import gradio as gr
3
+ import torch
4
+ from utils.tools_gradio import fast_process
5
+ from utils.tools import format_results, box_prompt, point_prompt, text_prompt
6
+ from PIL import ImageDraw
7
+ import numpy as np
8
+
9
# Load the pre-trained model from the local checkpoint file.
model = YOLO('./FastSAM.pt')

# Choose the inference device in order of preference: CUDA, then MPS, then CPU.
if torch.cuda.is_available():
    _device_name = "cuda"
elif torch.backends.mps.is_available():
    _device_name = "mps"
else:
    _device_name = "cpu"
device = torch.device(_device_name)
19
+
20
# Static page content: these strings are handed to gr.Markdown components.
title = "<center><strong><font size='8'>🏃 Fast Segment Anything 🤗</font></strong></center>"

# Changelog text.
news = """ # 📖 News
🔥 2023/07/14: Add a "wider result" button in text mode (Thanks for [gaoxinge](https://github.com/CASIA-IVA-Lab/FastSAM/pull/95)).
🔥 2023/06/29: Support the text mode (Thanks for [gaoxinge](https://github.com/CASIA-IVA-Lab/FastSAM/pull/47)).
🔥 2023/06/26: Support the points mode. (Better and faster interaction will come soon!)
🔥 2023/06/24: Add the 'Advanced options" in Everything mode to get a more detailed adjustment.
"""

# General description (used by the Everything and Text tabs).
description_e = """This is a demo on Github project 🏃 [Fast Segment Anything Model](https://github.com/CASIA-IVA-Lab/FastSAM). Welcome to give a star ⭐️ to it.

🎯 Upload an Image, segment it with Fast Segment Anything (Everything mode). The other modes will come soon.

⌛️ It takes about 6~ seconds to generate segment results. The concurrency_count of queue is 1, please wait for a moment when it is crowded.

🚀 To get faster results, you can use a smaller input size and leave high_visual_quality unchecked.

📣 You can also obtain the segmentation results of any Image through this Colab: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1oX14f6IneGGw612WgVlAiy91UHwFAvr9?usp=sharing)

😚 A huge thanks goes out to the @HuggingFace Team for supporting us with GPU grant.

🏠 Check out our [Model Card 🏃](https://huggingface.co/An-619/FastSAM)

"""

# Step-by-step instructions for the Points tab.
description_p = """ # 🎯 Instructions for points mode
This is a demo on Github project 🏃 [Fast Segment Anything Model](https://github.com/CASIA-IVA-Lab/FastSAM). Welcome to give a star ⭐️ to it.

1. Upload an image or choose an example.

2. Choose the point label do('Add mask' means a positive point. 'Remove' Area means a negative point that is not segmented).

3. Add points one by one on the image.

4. Click the 'Segment with points prompt' button to get the segmentation results.

**5. If you get Error, click the 'Clear points' button and try again may help.**

"""

# Example rows for gr.Examples: one single-element row (the image path) each.
examples = [[f"examples/{stem}.jpg"] for stem in ("dogs", "fruits", "flowers")]

# The first example row doubles as the initial content of the input images.
default_example = examples[0]

# Page-level CSS passed to gr.Blocks.
css = (
    "h1 { text-align: center } "
    ".about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
)
66
+
67
+
68
def segment_everything(
    input,
    input_size=1024,
    iou_threshold=0.7,
    conf_threshold=0.25,
    better_quality=False,
    withContours=True,
    use_retina=True,
    text="",
    wider=False,
    mask_random_color=True,
):
    """Segment an image with the model, optionally narrowed by a text prompt.

    The image is resized so its longer side equals ``input_size``, run through
    the module-level ``model``, and the resulting masks are rendered with
    ``fast_process``.

    Args:
        input: Image exposing ``.size`` and ``.resize`` (the Gradio inputs
            wired to this function use ``type='pil'``).
        input_size: Target length of the longer image side; also passed to the
            model as ``imgsz``. Coerced to ``int``.
        iou_threshold: Forwarded to the model as ``iou``.
        conf_threshold: Forwarded to the model as ``conf``.
        better_quality: Forwarded to ``fast_process``.
        withContours: Forwarded to ``fast_process``.
        use_retina: Forwarded to ``fast_process`` only; the model call itself
            always uses ``retina_masks=True``.
        text: Text prompt. When empty or ``None`` every mask is rendered;
            otherwise masks are selected through ``text_prompt``.
        wider: Forwarded to ``text_prompt``.
        mask_random_color: Forwarded to ``fast_process``.

    Returns:
        Whatever ``fast_process`` returns for the rendered result.

    Raises:
        gr.Error: If no image was supplied or the model produced no masks.
    """
    # The image component delivers None once it has been cleared; report that
    # clearly instead of failing on ``input.size``.
    if input is None:
        raise gr.Error("Please upload an image first.")

    input_size = int(input_size)  # make sure imgsz is an integer
    # Thanks for the suggestion by hysts in HuggingFace.
    w, h = input.size
    scale = input_size / max(w, h)
    new_w = int(w * scale)
    new_h = int(h * scale)
    input = input.resize((new_w, new_h))

    results = model(input,
                    device=device,
                    retina_masks=True,
                    iou=iou_threshold,
                    conf=conf_threshold,
                    imgsz=input_size,)

    # Ultralytics leaves ``masks`` as None when nothing was detected; both
    # branches below need mask data, so stop here with a readable message.
    if results[0].masks is None:
        raise gr.Error("No objects were detected. Try a lower conf threshold or another image.")

    # Truthiness (rather than len()) also covers a None text value.
    if text:
        results = format_results(results[0], 0)
        annotations, _ = text_prompt(results, text, input, device=device, wider=wider)
        annotations = np.array([annotations])
    else:
        annotations = results[0].masks.data

    fig = fast_process(annotations=annotations,
                       image=input,
                       device=device,
                       scale=(1024 // input_size),
                       better_quality=better_quality,
                       mask_random_color=mask_random_color,
                       bbox=None,
                       use_retina=use_retina,
                       withContours=withContours,)
    return fig
112
+
113
+
114
def segment_with_points(
    input,
    input_size=1024,
    iou_threshold=0.7,
    conf_threshold=0.25,
    better_quality=False,
    withContours=True,
    use_retina=True,
    mask_random_color=True,
):
    """Segment an image using the points collected in the module-level state.

    The image is resized so its longer side equals ``input_size``; the points
    stored in ``global_points`` are scaled by the same factor and passed, with
    ``global_point_label``, to ``point_prompt``. After a successful run both
    lists are reset.

    Args:
        input: Image exposing ``.size`` and ``.resize`` (the Gradio input
            wired to this function uses ``type='pil'``).
        input_size: Target length of the longer image side; also passed to the
            model as ``imgsz``. Coerced to ``int``.
        iou_threshold: Forwarded to the model as ``iou``.
        conf_threshold: Forwarded to the model as ``conf``.
        better_quality: Forwarded to ``fast_process``.
        withContours: Forwarded to ``fast_process``.
        use_retina: Forwarded to ``fast_process`` only; the model call itself
            always uses ``retina_masks=True``.
        mask_random_color: Forwarded to ``fast_process``.

    Returns:
        A ``(fig, None)`` tuple: the ``fast_process`` result, and ``None`` for
        the second output (wired to the input image, which it clears).

    Raises:
        gr.Error: If no image was supplied.
    """
    global global_points, global_point_label

    # This function itself clears the input image on success, so a second
    # click arrives with None; report it instead of failing on ``input.size``.
    if input is None:
        raise gr.Error("Please upload an image and add points first.")

    input_size = int(input_size)  # make sure imgsz is an integer
    # Thanks for the suggestion by hysts in HuggingFace.
    w, h = input.size
    scale = input_size / max(w, h)
    new_w = int(w * scale)
    new_h = int(h * scale)
    input = input.resize((new_w, new_h))

    # Clicks were recorded in original-image coordinates; map them onto the
    # resized image.
    scaled_points = [[int(x * scale) for x in point] for point in global_points]

    results = model(input,
                    device=device,
                    retina_masks=True,
                    iou=iou_threshold,
                    conf=conf_threshold,
                    imgsz=input_size,)

    results = format_results(results[0], 0)
    annotations, _ = point_prompt(results, scaled_points, global_point_label, new_h, new_w)
    annotations = np.array([annotations])

    fig = fast_process(annotations=annotations,
                       image=input,
                       device=device,
                       scale=(1024 // input_size),
                       better_quality=better_quality,
                       mask_random_color=mask_random_color,
                       bbox=None,
                       use_retina=use_retina,
                       withContours=withContours,)

    # Start the next interaction with an empty set of points.
    global_points = []
    global_point_label = []
    return fig, None
161
+
162
+
163
def get_points_with_draw(image, label, evt: gr.SelectData):
    """Record a clicked point and paint a marker for it onto the image.

    Args:
        image: Image to draw on; it is modified in place and returned.
        label: Radio value; ``'Add Mask'`` is stored as label 1, anything
            else as label 0.
        evt: Select event; ``evt.index`` supplies the clicked x and y.

    Returns:
        The same ``image`` object with a filled circle at the clicked spot.
    """
    global global_points, global_point_label

    x, y = evt.index[0], evt.index[1]
    is_positive = label == 'Add Mask'

    global_points.append([x, y])
    global_point_label.append(1 if is_positive else 0)

    print(x, y, is_positive)

    # Yellow marks a positive point, magenta a negative one.
    radius = 15
    marker_color = (255, 255, 0) if is_positive else (255, 0, 255)
    top_left = (x - radius, y - radius)
    bottom_right = (x + radius, y + radius)

    # Create a drawing object for the image and paint the marker.
    ImageDraw.Draw(image).ellipse([top_left, bottom_right], fill=marker_color)
    return image
178
+
179
+
180
# Input images, one per tab, created up front and rendered inside the layout.
cond_img_e = gr.Image(
    label="Input",
    value=default_example[0],
    type='pil',
)
cond_img_p = gr.Image(
    label="Input with points",
    value=default_example[0],
    type='pil',
)
cond_img_t = gr.Image(
    label="Input with text",
    value="examples/dogs.jpg",
    type='pil',
)

# Read-only result images, one per tab.
segm_img_e = gr.Image(
    label="Segmented Image",
    interactive=False,
    type='pil',
)
segm_img_p = gr.Image(
    label="Segmented Image with points",
    interactive=False,
    type='pil',
)
segm_img_t = gr.Image(
    label="Segmented Image with text",
    interactive=False,
    type='pil',
)

# Module-level click state for points mode: coordinates and their 1/0 labels.
global_points = []
global_point_label = []

# Input-size slider rendered in the Everything tab.
input_size_slider = gr.components.Slider(
    minimum=512,
    maximum=1024,
    value=1024,
    step=64,
    label='Input_size',
    info='Our model was trained on a size of 1024',
)
197
+
198
# Build the three-tab UI (Everything / Points / Text), wire the event handlers,
# then start the app.
with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
    with gr.Row():
        with gr.Column(scale=1):
            # Title
            gr.Markdown(title)

        with gr.Column(scale=1):
            # News
            gr.Markdown(news)

    with gr.Tab("Everything mode"):
        # Images
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                cond_img_e.render()

            with gr.Column(scale=1):
                segm_img_e.render()

        # Submit & Clear
        with gr.Row():
            with gr.Column():
                input_size_slider.render()

                with gr.Row():
                    contour_check = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')

                    with gr.Column():
                        segment_btn_e = gr.Button("Segment Everything", variant='primary')
                        clear_btn_e = gr.Button("Clear", variant="secondary")

                gr.Markdown("Try some of the examples below ⬇️")
                gr.Examples(examples=examples,
                            inputs=[cond_img_e],
                            outputs=segm_img_e,
                            fn=segment_everything,
                            cache_examples=True,
                            examples_per_page=4)

            with gr.Column():
                with gr.Accordion("Advanced options", open=False):
                    iou_threshold = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou', info='iou threshold for filtering the annotations')
                    conf_threshold = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf', info='object confidence threshold')
                    with gr.Row():
                        mor_check = gr.Checkbox(value=False, label='better_visual_quality', info='better quality using morphologyEx')
                        with gr.Column():
                            retina_check = gr.Checkbox(value=True, label='use_retina', info='draw high-resolution segmentation masks')

                # Description
                gr.Markdown(description_e)

    # Wired before the Text tab rebinds contour_check / iou_threshold / etc.,
    # so this handler keeps the Everything-tab components.
    segment_btn_e.click(segment_everything,
                        inputs=[
                            cond_img_e,
                            input_size_slider,
                            iou_threshold,
                            conf_threshold,
                            mor_check,
                            contour_check,
                            retina_check,
                        ],
                        outputs=segm_img_e)

    with gr.Tab("Points mode"):
        # Images
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                cond_img_p.render()

            with gr.Column(scale=1):
                segm_img_p.render()

        # Submit & Clear
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    add_or_remove = gr.Radio(["Add Mask", "Remove Area"], value="Add Mask", label="Point_label (foreground/background)")

                    with gr.Column():
                        segment_btn_p = gr.Button("Segment with points prompt", variant='primary')
                        clear_btn_p = gr.Button("Clear points", variant='secondary')

                gr.Markdown("Try some of the examples below ⬇️")
                # No fn/outputs: selecting an example only fills the input image.
                gr.Examples(examples=examples,
                            inputs=[cond_img_p],
                            examples_per_page=4)

            with gr.Column():
                # Description
                gr.Markdown(description_p)

    # Each click on the input image records a point and redraws the image.
    cond_img_p.select(get_points_with_draw, [cond_img_p, add_or_remove], cond_img_p)

    segment_btn_p.click(segment_with_points,
                        inputs=[cond_img_p],
                        outputs=[segm_img_p, cond_img_p])

    with gr.Tab("Text mode"):
        # Images
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                cond_img_t.render()

            with gr.Column(scale=1):
                segm_img_t.render()

        # Submit & Clear
        with gr.Row():
            with gr.Column():
                input_size_slider_t = gr.components.Slider(minimum=512,
                                                           maximum=1024,
                                                           value=1024,
                                                           step=64,
                                                           label='Input_size',
                                                           info='Our model was trained on a size of 1024')
                with gr.Row():
                    with gr.Column():
                        contour_check = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')
                        text_box = gr.Textbox(label="text prompt", value="a black dog")

                    with gr.Column():
                        segment_btn_t = gr.Button("Segment with text", variant='primary')
                        clear_btn_t = gr.Button("Clear", variant="secondary")

                gr.Markdown("Try some of the examples below ⬇️")
                # No fn/outputs: selecting an example only fills the input image.
                gr.Examples(examples=examples,
                            inputs=[cond_img_t],
                            examples_per_page=4)

            with gr.Column():
                with gr.Accordion("Advanced options", open=False):
                    iou_threshold = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou', info='iou threshold for filtering the annotations')
                    conf_threshold = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf', info='object confidence threshold')
                    with gr.Row():
                        mor_check = gr.Checkbox(value=False, label='better_visual_quality', info='better quality using morphologyEx')
                        retina_check = gr.Checkbox(value=True, label='use_retina', info='draw high-resolution segmentation masks')
                        wider_check = gr.Checkbox(value=False, label='wider', info='wider result')

                # Description
                gr.Markdown(description_e)

    segment_btn_t.click(segment_everything,
                        inputs=[
                            cond_img_t,
                            input_size_slider_t,
                            iou_threshold,
                            conf_threshold,
                            mor_check,
                            contour_check,
                            retina_check,
                            text_box,
                            wider_check,
                        ],
                        outputs=segm_img_t)

    def clear():
        """Reset an (input image, output image) pair."""
        return None, None

    def clear_points():
        """Reset the Points-tab images and forget every recorded click."""
        global global_points, global_point_label
        # Without this reset, stale clicks would be reused by the next
        # segment_with_points call.
        global_points = []
        global_point_label = []
        return None, None

    def clear_text():
        """Reset the Text-tab images and empty the prompt box."""
        # An empty string (not None) keeps the textbox value a str.
        return None, None, ""

    clear_btn_e.click(clear, outputs=[cond_img_e, segm_img_e])
    clear_btn_p.click(clear_points, outputs=[cond_img_p, segm_img_p])
    # The Text-tab button must clear the Text-tab images, not the Points-tab ones.
    clear_btn_t.click(clear_text, outputs=[cond_img_t, segm_img_t, text_box])

demo.queue()
demo.launch()