AAAAAAAyq committed on
Commit e4f0b1d · Parent(s): 5778432
Add a wider result

Files changed:
- app_gradio.py +50 -46
- utils/tools.py +29 -16
- utils/tools_gradio.py +4 -4
    	
app_gradio.py  CHANGED

@@ -21,6 +21,8 @@ device = torch.device(
 title = "<center><strong><font size='8'>🏃 Fast Segment Anything 🤗</font></strong></center>"
 
 news = """ # 📖 News
+        🔥 2023/07/14: Add a "wider result" button in text mode (Thanks for [gaoxinge](https://github.com/CASIA-IVA-Lab/FastSAM/pull/95)).
+
         🔥 2023/06/29: Support the text mode (Thanks for [gaoxinge](https://github.com/CASIA-IVA-Lab/FastSAM/pull/47)).
 
         🔥 2023/06/26: Support the points mode. (Better and faster interaction will come soon!)

@@ -76,6 +78,7 @@ def segment_everything(
     withContours=True,
     use_retina=True,
     text="",
+    wider=False,
     mask_random_color=True,
 ):
     input_size = int(input_size)  # make sure imgsz is an integer

@@ -95,7 +98,7 @@ def segment_everything(
 
     if len(text) > 0:
         results = format_results(results[0], 0)
-        annotations, _ = text_prompt(results, text, input, device=device)
+        annotations, _ = text_prompt(results, text, input, device=device, wider=wider)
         annotations = np.array([annotations])
     else:
         annotations = results[0].masks.data

@@ -189,7 +192,7 @@ segm_img_t = gr.Image(label="Segmented Image with text", interactive=False, type
 global_points = []
 global_point_label = []
 
-input_size_slider_e = gr.components.Slider(minimum=512,
+input_size_slider = gr.components.Slider(minimum=512,
                                          maximum=1024,
                                          value=1024,
                                          step=64,

@@ -218,10 +221,10 @@ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
         # Submit & Clear
         with gr.Row():
             with gr.Column():
-                input_size_slider_e.render()
+                input_size_slider.render()
 
                 with gr.Row():
-                    contour_check_e = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')
+                    contour_check = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')
 
                     with gr.Column():
                         segment_btn_e = gr.Button("Segment Everything", variant='primary')

@@ -237,16 +240,28 @@ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
 
             with gr.Column():
                 with gr.Accordion("Advanced options", open=False):
-                    iou_threshold_e = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou', info='iou threshold for filtering the annotations')
-                    conf_threshold_e = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf', info='object confidence threshold')
+                    iou_threshold = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou', info='iou threshold for filtering the annotations')
+                    conf_threshold = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf', info='object confidence threshold')
                     with gr.Row():
-                        mor_check_e = gr.Checkbox(value=False, label='better_visual_quality', info='better quality using morphologyEx')
+                        mor_check = gr.Checkbox(value=False, label='better_visual_quality', info='better quality using morphologyEx')
                         with gr.Column():
-                            retina_check_e = gr.Checkbox(value=True, label='use_retina', info='draw high-resolution segmentation masks')
+                            retina_check = gr.Checkbox(value=True, label='use_retina', info='draw high-resolution segmentation masks')
+
                 # Description
                 gr.Markdown(description_e)
 
+    segment_btn_e.click(segment_everything,
+                        inputs=[
+                            cond_img_e,
+                            input_size_slider,
+                            iou_threshold,
+                            conf_threshold,
+                            mor_check,
+                            contour_check,
+                            retina_check,
+                        ],
+                        outputs=segm_img_e)
+
     with gr.Tab("Points mode"):
         # Images
         with gr.Row(variant="panel"):

@@ -277,7 +292,13 @@ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
             with gr.Column():
                 # Description
                 gr.Markdown(description_p)
-
+
+    cond_img_p.select(get_points_with_draw, [cond_img_p, add_or_remove], cond_img_p)
+
+    segment_btn_p.click(segment_with_points,
+                        inputs=[cond_img_p],
+                        outputs=[segm_img_p, cond_img_p])
+
     with gr.Tab("Text mode"):
         # Images
         with gr.Row(variant="panel"):

@@ -291,14 +312,14 @@ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
         with gr.Row():
             with gr.Column():
                 input_size_slider_t = gr.components.Slider(minimum=512,
-                                                          maximum=1024,
-                                                          value=1024,
-                                                          step=64,
-                                                          label='Input_size',
-                                                          info='Our model was trained on a size of 1024')
+                                                           maximum=1024,
+                                                           value=1024,
+                                                           step=64,
+                                                           label='Input_size',
+                                                           info='Our model was trained on a size of 1024')
                 with gr.Row():
                     with gr.Column():
-                        contour_check_t = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')
+                        contour_check = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')
                         text_box = gr.Textbox(label="text prompt", value="a black dog")

@@ -306,7 +327,7 @@ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
                         clear_btn_t = gr.Button("Clear", variant="secondary")
 
                 gr.Markdown("Try some of the examples below ⬇️")
-                gr.Examples(examples=["examples/dogs.jpg"],
+                gr.Examples(examples=[["examples/dogs.jpg"]] + examples,
                             inputs=[cond_img_e],
                             # outputs=segm_img_e,
                             # fn=segment_everything,

@@ -315,44 +336,27 @@ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
 
         with gr.Column():
             with gr.Accordion("Advanced options", open=False):
-                iou_threshold_t = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou', info='iou threshold for filtering the annotations')
-                conf_threshold_t = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf', info='object confidence threshold')
+                iou_threshold = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou', info='iou threshold for filtering the annotations')
+                conf_threshold = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf', info='object confidence threshold')
                 with gr.Row():
-                    mor_check_t = gr.Checkbox(value=False, label='better_visual_quality', info='better quality using morphologyEx')
-                    retina_check_t = gr.Checkbox(value=True, label='use_retina', info='draw high-resolution segmentation masks')
-
+                    mor_check = gr.Checkbox(value=False, label='better_visual_quality', info='better quality using morphologyEx')
+                    retina_check = gr.Checkbox(value=True, label='use_retina', info='draw high-resolution segmentation masks')
+                    wider_check = gr.Checkbox(value=False, label='wider', info='wider result')
 
             # Description
             gr.Markdown(description_e)
-
-    cond_img_p.select(get_points_with_draw, [cond_img_p, add_or_remove], cond_img_p)
-
-    segment_btn_e.click(segment_everything,
-                        inputs=[
-                            cond_img_e,
-                            input_size_slider_e,
-                            iou_threshold_e,
-                            conf_threshold_e,
-                            mor_check_e,
-                            contour_check_e,
-                            retina_check_e,
-                        ],
-                        outputs=segm_img_e)
-
-    segment_btn_p.click(segment_with_points,
-                        inputs=[cond_img_p],
-                        outputs=[segm_img_p, cond_img_p])
 
     segment_btn_t.click(segment_everything,
                         inputs=[
                             cond_img_t,
                             input_size_slider_t,
-                            iou_threshold_t,
-                            conf_threshold_t,
-                            mor_check_t,
-                            contour_check_t,
-                            retina_check_t,
+                            iou_threshold,
+                            conf_threshold,
+                            mor_check,
+                            contour_check,
+                            retina_check,
                             text_box,
+                            wider_check,
                         ],
                         outputs=segm_img_t)

@@ -361,7 +365,7 @@ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
 
     def clear_text():
         return None, None, None
-
+
     clear_btn_e.click(clear, outputs=[cond_img_e, segm_img_e])
     clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p])
     clear_btn_t.click(clear_text, outputs=[cond_img_p, segm_img_p, text_box])
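A note on the wiring above: Gradio passes every component listed in inputs= to the handler positionally, so the new wider_check checkbox arrives in segment_everything as a plain Python bool named wider. A minimal, self-contained sketch of the same pattern (the run function and its components are illustrative stand-ins, not code from this repo):

import gradio as gr

# Stand-in for segment_everything: each component in inputs= arrives
# positionally, so the checkbox shows up as a plain Python bool.
def run(text, wider):
    return f"text={text!r}, wider={wider}"

with gr.Blocks() as demo:
    text_box = gr.Textbox(label="text prompt", value="a black dog")
    wider_check = gr.Checkbox(value=False, label="wider", info="wider result")
    result = gr.Textbox(label="result")
    run_btn = gr.Button("Run")
    # Same pattern as segment_btn_t.click(...) in the diff above.
    run_btn.click(run, inputs=[text_box, wider_check], outputs=result)

demo.launch()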
    	
utils/tools.py  CHANGED

@@ -9,11 +9,14 @@ import clip
 
 
 def convert_box_xywh_to_xyxy(box):
-    x1 = box[0]
-    y1 = box[1]
-    x2 = box[0] + box[2]
-    y2 = box[1] + box[3]
-    return [x1, y1, x2, y2]
+    if len(box) == 4:
+        return [box[0], box[1], box[0] + box[2], box[1] + box[3]]
+    else:
+        result = []
+        for b in box:
+            b = convert_box_xywh_to_xyxy(b)
+            result.append(b)
+    return result
 
 
 def segment_image(image, bbox):

@@ -323,8 +326,8 @@ def fast_show_mask_gpu(
 # clip
 @torch.no_grad()
 def retriev(
-    model, preprocess, elements, search_text: str, device
-) -> int:
+    model, preprocess, elements: [Image.Image], search_text: str, device
+):
     preprocessed_images = [preprocess(image).to(device) for image in elements]
     tokenized_text = clip.tokenize([search_text]).to(device)
     stacked_images = torch.stack(preprocessed_images)

@@ -348,19 +351,16 @@ def crop_image(annotations, image_like):
     cropped_boxes = []
     cropped_images = []
     not_crop = []
-    filter_id = []
-    # annotations, _ = filter_masks(annotations)
-    # filter_id = list(_)
+    origin_id = []
     for _, mask in enumerate(annotations):
         if np.sum(mask["segmentation"]) <= 100:
-            filter_id.append(_)
             continue
+        origin_id.append(_)
         bbox = get_bbox_from_mask(mask["segmentation"])  # bbox of the mask
         cropped_boxes.append(segment_image(image, bbox))  # save the cropped image
         # cropped_boxes.append(segment_image(image,mask["segmentation"]))
         cropped_images.append(bbox)  # save the bbox of the cropped image
-
-    return cropped_boxes, cropped_images, not_crop, filter_id, annotations
+    return cropped_boxes, cropped_images, not_crop, origin_id, annotations
 
 
 def box_prompt(masks, bbox, target_height, target_width):

@@ -415,8 +415,8 @@ def point_prompt(masks, points, point_label, target_height, target_width):  # nu
     return onemask, 0
 
 
-def text_prompt(annotations, text, img_path, device):
-    cropped_boxes, cropped_images, not_crop, filter_id, annotations_ = crop_image(
+def text_prompt(annotations, text, img_path, device, wider=False, threshold=0.9):
+    cropped_boxes, cropped_images, not_crop, origin_id, annotations_ = crop_image(
         annotations, img_path
     )
     clip_model, preprocess = clip.load("./weights/CLIP_ViT_B_32.pt", device=device)

@@ -425,5 +425,18 @@ def text_prompt(annotations, text, img_path, device):
     )
     max_idx = scores.argsort()
     max_idx = max_idx[-1]
-    max_idx += sum(np.array(filter_id) <= int(max_idx))
+    max_idx = origin_id[int(max_idx)]
+
+    # find the biggest mask which contains the mask with max score
+    if wider:
+        mask0 = annotations_[max_idx]["segmentation"]
+        area0 = np.sum(mask0)
+        areas = [(i, np.sum(mask["segmentation"])) for i, mask in enumerate(annotations_) if i in origin_id]
+        areas = sorted(areas, key=lambda area: area[1], reverse=True)
+        indices = [area[0] for area in areas]
+        for index in indices:
+            if index == max_idx or np.sum(annotations_[index]["segmentation"] & mask0) / area0 > threshold:
+                max_idx = index
+                break
+
     return annotations_[max_idx]["segmentation"], max_idx
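Two behavioural changes here are worth spelling out. First, crop_image now records the indices it keeps (origin_id) rather than the ones it drops (filter_id), so text_prompt can map the CLIP argmax, which is computed over the kept crops only, back to an index into the full annotations_ list. Second, with wider=True the function walks the kept masks from largest to smallest and returns the first one that either is the CLIP winner itself or covers more than threshold (90%) of it. A toy reproduction of that selection rule on synthetic masks (pick_wider and the variable names are illustrative, not repo code):

import numpy as np

def pick_wider(masks, max_idx, threshold=0.9):
    # masks: list of boolean HxW arrays; max_idx: index of the CLIP winner.
    mask0 = masks[max_idx]
    area0 = np.sum(mask0)
    # Walk candidates from largest area to smallest.
    order = sorted(range(len(masks)), key=lambda i: np.sum(masks[i]), reverse=True)
    for i in order:
        # The first (i.e. biggest) mask covering >90% of the winner is chosen;
        # the winner itself acts as the fallback once the loop reaches it.
        if i == max_idx or np.sum(masks[i] & mask0) / area0 > threshold:
            return i
    return max_idx

small = np.zeros((8, 8), dtype=bool); small[2:4, 2:4] = True   # the CLIP winner
big = np.zeros((8, 8), dtype=bool);   big[1:6, 1:6] = True     # fully contains it
other = np.zeros((8, 8), dtype=bool); other[6:8, 6:8] = True   # unrelated region

print(pick_wider([small, big, other], max_idx=0))  # -> 1, the wider mask

With the default wider=False the new code reduces to returning the CLIP winner unchanged, so existing callers keep their old behaviour.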
    	
utils/tools_gradio.py  CHANGED

@@ -103,7 +103,7 @@ def fast_show_mask(
     annotation = annotation[sorted_indices]
 
     index = (annotation != 0).argmax(axis=0)
-    if random_color == True:
+    if random_color:
         color = np.random.random((mask_sum, 1, 1, 3))
     else:
         color = np.ones((mask_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255])

@@ -121,7 +121,7 @@ def fast_show_mask(
         x1, y1, x2, y2 = bbox
         ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
 
-    if retinamask == False:
+    if not retinamask:
         mask = cv2.resize(mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
 
     return mask

@@ -145,7 +145,7 @@ def fast_show_mask_gpu(
     annotation = annotation[sorted_indices]
     # find the index of the first non-zero value at each position
     index = (annotation != 0).to(torch.long).argmax(dim=0)
-    if random_color == True:
+    if random_color:
         color = torch.rand((mask_sum, 1, 1, 3)).to(device)
     else:
         color = torch.ones((mask_sum, 1, 1, 3)).to(device) * torch.tensor(

@@ -168,7 +168,7 @@ def fast_show_mask_gpu(
             (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
         )
     )
-    if retinamask == False:
+    if not retinamask:
         mask_cpu = cv2.resize(
             mask_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST
         )
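These four hunks are style-only. The deleted sides are truncated in this rendering, but the +4/-4 shape matches replacing explicit == True / == False comparisons with truthiness checks, which does not change behaviour for bool arguments. The surviving branch is the nearest-neighbour upscale that runs when use_retina is off: the mask is built at the model's working resolution and only resized at the end. A minimal sketch of that step on a toy mask (shapes and values are illustrative, not repo code):

import cv2
import numpy as np

retinamask = False  # i.e. the use_retina checkbox is unticked
target_width, target_height = 1024, 1024

# Toy low-resolution mask image (H x W x 4, like an RGBA overlay).
mask = (np.random.rand(256, 256, 4) > 0.5).astype(np.float32)

if not retinamask:
    # cv2.resize takes (width, height); INTER_NEAREST avoids blending
    # mask values when upsampling, so the mask stays hard-edged.
    mask = cv2.resize(mask, (target_width, target_height),
                      interpolation=cv2.INTER_NEAREST)

print(mask.shape)  # (1024, 1024, 4)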
