Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | @@ -299,7 +299,7 @@ def get_mask_sam_process( | |
| 299 | 
             
                    available_frames_to_check.append(working_frame)
         | 
| 300 | 
             
                    print(available_frames_to_check)
         | 
| 301 |  | 
| 302 | 
            -
                return "output_first_frame.jpg", frame_names, inference_state, gr.update(choices=available_frames_to_check, value=working_frame, visible=True)
         | 
| 303 |  | 
| 304 | 
             
            def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, progress=gr.Progress(track_tqdm=True)):   
         | 
| 305 | 
             
                #### PROPAGATION ####
         | 
| @@ -392,12 +392,19 @@ def switch_working_frame(working_frame, scanned_frames, video_frames_dir): | |
| 392 | 
             
                        new_working_frame = os.path.join(video_frames_dir, scanned_frames[ann_frame_idx])
         | 
| 393 | 
             
                        return new_working_frame, gr.State([]), gr.State([]), new_working_frame, new_working_frame, new_working_frame
         | 
| 394 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 395 | 
             
            with gr.Blocks() as demo:
         | 
| 396 | 
             
                first_frame_path = gr.State()
         | 
| 397 | 
             
                tracking_points = gr.State([])
         | 
| 398 | 
             
                trackings_input_label = gr.State([])
         | 
| 399 | 
             
                video_frames_dir = gr.State()
         | 
| 400 | 
             
                scanned_frames = gr.State()
         | 
|  | |
| 401 | 
             
                stored_inference_state = gr.State()
         | 
| 402 | 
             
                stored_frame_names = gr.State()
         | 
| 403 | 
             
                available_frames_to_check = gr.State([])
         | 
| @@ -442,6 +449,7 @@ with gr.Blocks() as demo: | |
| 442 | 
             
                            with gr.Row():
         | 
| 443 | 
             
                                vis_frame_type = gr.Radio(label="Propagation level", choices=["check", "render"], value="check", scale=2)
         | 
| 444 | 
             
                                propagate_btn = gr.Button("Propagate", scale=1)
         | 
|  | |
| 445 | 
             
                            output_propagated = gr.Gallery(label="Propagated Mask samples gallery", visible=False)
         | 
| 446 | 
             
                            output_video = gr.Video(visible=False)
         | 
| 447 | 
             
                            # output_result_mask = gr.Image()
         | 
| @@ -524,11 +532,19 @@ with gr.Blocks() as demo: | |
| 524 | 
             
                    outputs = [
         | 
| 525 | 
             
                        output_result, 
         | 
| 526 | 
             
                        stored_frame_names, 
         | 
|  | |
| 527 | 
             
                        stored_inference_state,
         | 
| 528 | 
             
                        working_frame,
         | 
| 529 | 
             
                    ]
         | 
| 530 | 
             
                )
         | 
| 531 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 532 | 
             
                propagate_btn.click(
         | 
| 533 | 
             
                    fn = update_ui,
         | 
| 534 | 
             
                    inputs = [vis_frame_type],
         | 
|  | |
| 299 | 
             
                    available_frames_to_check.append(working_frame)
         | 
| 300 | 
             
                    print(available_frames_to_check)
         | 
| 301 |  | 
| 302 | 
            +
                return "output_first_frame.jpg", frame_names, predictor, inference_state, gr.update(choices=available_frames_to_check, value=working_frame, visible=True)
         | 
| 303 |  | 
| 304 | 
             
            def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, progress=gr.Progress(track_tqdm=True)):   
         | 
| 305 | 
             
                #### PROPAGATION ####
         | 
|  | |
| 392 | 
             
                        new_working_frame = os.path.join(video_frames_dir, scanned_frames[ann_frame_idx])
         | 
| 393 | 
             
                        return new_working_frame, gr.State([]), gr.State([]), new_working_frame, new_working_frame, new_working_frame
         | 
| 394 |  | 
| 395 | 
            +
            def reset_propagation(predictor, stored_inference_state):
         | 
| 396 | 
            +
                
         | 
| 397 | 
            +
                predictor.reset_state(stored_inference_state)
         | 
| 398 | 
            +
                print(f"RESET State: {stored_inference_state} ")
         | 
| 399 | 
            +
                return stored_inference_state
         | 
| 400 | 
            +
                
         | 
| 401 | 
             
            with gr.Blocks() as demo:
         | 
| 402 | 
             
                first_frame_path = gr.State()
         | 
| 403 | 
             
                tracking_points = gr.State([])
         | 
| 404 | 
             
                trackings_input_label = gr.State([])
         | 
| 405 | 
             
                video_frames_dir = gr.State()
         | 
| 406 | 
             
                scanned_frames = gr.State()
         | 
| 407 | 
            +
                loaded_predictor = gr.State()
         | 
| 408 | 
             
                stored_inference_state = gr.State()
         | 
| 409 | 
             
                stored_frame_names = gr.State()
         | 
| 410 | 
             
                available_frames_to_check = gr.State([])
         | 
|  | |
| 449 | 
             
                            with gr.Row():
         | 
| 450 | 
             
                                vis_frame_type = gr.Radio(label="Propagation level", choices=["check", "render"], value="check", scale=2)
         | 
| 451 | 
             
                                propagate_btn = gr.Button("Propagate", scale=1)
         | 
| 452 | 
            +
                                reset_prpgt_brn = gr.Button("Reset", scale=0.75)
         | 
| 453 | 
             
                            output_propagated = gr.Gallery(label="Propagated Mask samples gallery", visible=False)
         | 
| 454 | 
             
                            output_video = gr.Video(visible=False)
         | 
| 455 | 
             
                            # output_result_mask = gr.Image()
         | 
|  | |
| 532 | 
             
                    outputs = [
         | 
| 533 | 
             
                        output_result, 
         | 
| 534 | 
             
                        stored_frame_names, 
         | 
| 535 | 
            +
                        loaded_predictor,
         | 
| 536 | 
             
                        stored_inference_state,
         | 
| 537 | 
             
                        working_frame,
         | 
| 538 | 
             
                    ]
         | 
| 539 | 
             
                )
         | 
| 540 |  | 
| 541 | 
            +
                reset_prpgt_brn.click(
         | 
| 542 | 
            +
                    fn = reset_propagation,
         | 
| 543 | 
            +
                    inputs = [loaded_predictor, stored_inference_state],
         | 
| 544 | 
            +
                    outputs = [stored_inference_state],
         | 
| 545 | 
            +
                    queue=False
         | 
| 546 | 
            +
                )
         | 
| 547 | 
            +
             | 
| 548 | 
             
                propagate_btn.click(
         | 
| 549 | 
             
                    fn = update_ui,
         | 
| 550 | 
             
                    inputs = [vis_frame_type],
         | 
