goryhon commited on
Commit
a8566ea
·
verified ·
1 Parent(s): f003ca6

Update web-demos/hugging_face/app.py

Browse files
Files changed (1) hide show
  1. web-demos/hugging_face/app.py +29 -7
web-demos/hugging_face/app.py CHANGED
@@ -155,6 +155,23 @@ def get_end_number(track_pause_number_slider, video_state, interactive_state):
155
 
156
  return video_state["painted_images"][track_pause_number_slider],interactive_state, operation_log, operation_log
157
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  # use sam to get the mask
159
  def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData):
160
  """
@@ -622,13 +639,17 @@ with gr.Blocks(theme=gr.themes.Monochrome(), css=css) as iface:
622
  step2_title = gr.Markdown("---\n## Step2: Add masks", visible=False)
623
  with gr.Row(equal_height=True):
624
  with gr.Column(scale=2):
625
- template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame", visible=False, elem_classes="image")
626
  image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track start frame", visible=False)
627
  track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frame", visible=False)
628
  with gr.Column(scale=2, elem_classes="jc_center"):
629
- run_status2 = gr.HighlightedText(value=[("",""), ("Try to upload your video and click the Get video info button to get started! (Kindly ensure that the uploaded video consists of fewer than 500 frames in total)", "Normal")],
630
- color_map={"Normal": "green", "Error": "red", "Clear clicks": "gray", "Add mask": "green", "Remove mask": "red"},
631
- visible=False)
 
 
 
 
632
  with gr.Column():
633
  point_prompt = gr.Radio(
634
  choices=["Positive", "Negative"],
@@ -675,11 +696,12 @@ with gr.Blocks(theme=gr.themes.Monochrome(), css=css) as iface:
675
 
676
  # click select image to get mask using sam
677
  template_frame.select(
678
- fn=sam_refine,
679
- inputs=[video_state, point_prompt, click_state, interactive_state],
680
- outputs=[template_frame, video_state, interactive_state, run_status, run_status2]
681
  )
682
 
 
683
  # add different mask
684
  Add_mask_button.click(
685
  fn=add_multi_mask,
 
155
 
156
  return video_state["painted_images"][track_pause_number_slider],interactive_state, operation_log, operation_log
157
 
158
+ def draw_refine(video_state, template_frame):
159
+ if isinstance(template_frame, dict) and "mask" in template_frame:
160
+ mask = template_frame["mask"]
161
+ if mask is not None:
162
+ mask = (np.array(mask) > 127).astype(np.uint8)
163
+ painted_image = mask_painter(
164
+ video_state["origin_images"][video_state["select_frame_number"]].copy(),
165
+ mask,
166
+ mask_color=1
167
+ )
168
+ video_state["masks"][video_state["select_frame_number"]] = mask
169
+ video_state["painted_images"][video_state["select_frame_number"]] = painted_image
170
+ operation_log = [("Manual mask added by drawing", "Normal")]
171
+ return painted_image, video_state, operation_log, operation_log
172
+ operation_log = [("No mask drawn", "Error")]
173
+ return video_state["painted_images"][video_state["select_frame_number"]], video_state, operation_log, operation_log
174
+
175
  # use sam to get the mask
176
  def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData):
177
  """
 
639
  step2_title = gr.Markdown("---\n## Step2: Add masks", visible=False)
640
  with gr.Row(equal_height=True):
641
  with gr.Column(scale=2):
642
+ template_frame = gr.Image(type="numpy",tool="sketch",brush_radius=4,label="Draw mask manually or click to add",interactive=True,visible=False,elem_id="template_frame",elem_classes="image")
643
  image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track start frame", visible=False)
644
  track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frame", visible=False)
645
  with gr.Column(scale=2, elem_classes="jc_center"):
646
+ mask_mode_selector = gr.Radio(
647
+ choices=["Click", "Draw"],
648
+ value="Click",
649
+ label="Mask input mode",
650
+ interactive=True,
651
+ visible=False
652
+ )
653
  with gr.Column():
654
  point_prompt = gr.Radio(
655
  choices=["Positive", "Negative"],
 
696
 
697
  # click select image to get mask using sam
698
  template_frame.select(
699
+ fn=lambda *args: sam_refine(*args) if args[4] == "Click" else draw_refine(args[1], args[0]),
700
+ inputs=[template_frame, video_state, point_prompt, click_state, mask_mode_selector, interactive_state],
701
+ outputs=[template_frame, video_state, run_status, run_status2]
702
  )
703
 
704
+
705
  # add different mask
706
  Add_mask_button.click(
707
  fn=add_multi_mask,