goryhon commited on
Commit
b1cd907
·
verified ·
1 Parent(s): 42dd46f

Update web-demos/hugging_face/app.py

Browse files
Files changed (1) hide show
  1. web-demos/hugging_face/app.py +31 -72
web-demos/hugging_face/app.py CHANGED
@@ -156,23 +156,6 @@ def get_end_number(track_pause_number_slider, video_state, interactive_state):
156
 
157
  return video_state["painted_images"][track_pause_number_slider],interactive_state, operation_log, operation_log
158
 
159
- def draw_refine(video_state, template_frame):
160
- if isinstance(template_frame, dict) and "mask" in template_frame:
161
- mask = template_frame["mask"]
162
- if mask is not None:
163
- mask = (np.array(mask) > 127).astype(np.uint8)
164
- painted_image = mask_painter(
165
- video_state["origin_images"][video_state["select_frame_number"]].copy(),
166
- mask,
167
- mask_color=1
168
- )
169
- video_state["masks"][video_state["select_frame_number"]] = mask
170
- video_state["painted_images"][video_state["select_frame_number"]] = painted_image
171
- operation_log = [("Manual mask added by drawing", "Normal")]
172
- return painted_image, video_state, operation_log, operation_log
173
- operation_log = [("No mask drawn", "Error")]
174
- return video_state["painted_images"][video_state["select_frame_number"]], video_state, operation_log, operation_log
175
-
176
  # use sam to get the mask
177
  def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData):
178
  """
@@ -209,26 +192,17 @@ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr
209
  ("[Optional]", "Click image"), (": Try to click the image shown in step2 if you want to generate more masks.\n", None)]
210
  return painted_image, video_state, interactive_state, operation_log, operation_log
211
 
212
- def add_multi_mask(video_state, interactive_state, mask_dropdown, mask_input_mode, template_frame):
213
  try:
214
- if mask_input_mode == "Click":
215
- # 🟢 Стандартная логика SAM
216
- mask = video_state["masks"][video_state["select_frame_number"]]
217
- else:
218
- # ⚪ Рисованная маска
219
- drawn_mask = template_frame # это уже numpy-массив из gr.Sketchpad
220
- if drawn_mask is None:
221
- raise ValueError("No mask was drawn.")
222
- mask = (np.array(drawn_mask)[:, :, 0] > 127).astype(np.uint8) # бинаризация
223
-
224
  interactive_state["multi_mask"]["masks"].append(mask)
225
  interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
226
  mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
227
 
228
  select_frame, _, _ = show_mask(video_state, interactive_state, mask_dropdown)
229
  operation_log = [("",""),("Added a mask, use the mask select for target tracking or inpainting.","Normal")]
230
- except Exception as e:
231
- operation_log = [(f"Error: {str(e)}", "Error"), ("","")]
232
 
233
  return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]], operation_log, operation_log
234
 
@@ -635,65 +609,50 @@ with gr.Blocks(theme=gr.themes.Monochrome(), css=css) as iface:
635
  value=10,)
636
 
637
  with gr.Column():
638
- # Step 1: Upload video
639
  gr.Markdown("## Step1: Upload video")
640
  with gr.Row(equal_height=True):
641
  with gr.Column(scale=2):
642
  video_input = gr.Video(elem_classes="video")
643
  extract_frames_button = gr.Button(value="Get video info", interactive=True, variant="primary")
644
  with gr.Column(scale=2):
645
- run_status = gr.HighlightedText(value=[("",""), ("Try to upload your video and click the Get video info button to get started!", "Normal")],
646
  color_map={"Normal": "green", "Error": "red", "Clear clicks": "gray", "Add mask": "green", "Remove mask": "red"})
647
  video_info = gr.Textbox(label="Video Info")
648
-
649
- # Step 2: Add masks
 
650
  step2_title = gr.Markdown("---\n## Step2: Add masks", visible=False)
651
  with gr.Row(equal_height=True):
652
  with gr.Column(scale=2):
653
  brush_settings = gr.Brush(
654
- default_size=10,
655
- colors=["#FFFFFF"],
656
  default_color="#FFFFFF"
657
  )
658
- template_frame = gr.Sketchpad(label="Draw mask manually or click to add",canvas_size=(1920, 1080),type="numpy",interactive=True,visible=False,tool="brush", brush=4,elem_id="template_frame",elem_classes="image")
659
-
660
-
661
  image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track start frame", visible=False)
662
  track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frame", visible=False)
663
  with gr.Column(scale=2, elem_classes="jc_center"):
664
- run_status2 = gr.HighlightedText(value=[("",""), ("Use click or draw to select mask, then track or inpaint.", "Normal")],
665
  color_map={"Normal": "green", "Error": "red", "Clear clicks": "gray", "Add mask": "green", "Remove mask": "red"},
666
  visible=False)
667
- mask_input_mode = gr.Radio(
668
- choices=["Click", "Draw"],
669
- value="Click",
670
- label="Select Mask Input Mode",
671
- interactive=True,
672
- visible=True
673
- )
674
- mask_mode_selector = gr.Radio(
675
- choices=["Click", "Draw"],
676
- value="Click",
677
- label="Mask input mode",
678
- interactive=True,
679
- visible=False
680
- )
681
- point_prompt = gr.Radio(
682
- choices=["Positive", "Negative"],
683
- value="Positive",
684
- label="Point prompt",
685
- interactive=True,
686
- visible=False,
687
- min_width=100,
688
- scale=1
689
- )
690
- with gr.Row(elem_classes="mask_button_group"):
691
- Add_mask_button = gr.Button(value="Add mask", interactive=True, visible=False, elem_classes="add_button")
692
- remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False, elem_classes="remove_button")
693
- clear_button_click = gr.Button(value="Clear clicks", interactive=True, visible=False, elem_classes="clear_button")
694
  mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask selection", info=".", visible=False)
695
-
696
- # Step 3: Tracking & Inpainting
697
  step3_title = gr.Markdown("---\n## Step3: Track masks and get the inpainting result", visible=False)
698
  with gr.Row(equal_height=True):
699
  with gr.Column(scale=2):
@@ -724,9 +683,9 @@ with gr.Blocks(theme=gr.themes.Monochrome(), css=css) as iface:
724
 
725
  # click select image to get mask using sam
726
  template_frame.select(
727
- fn=lambda *args: sam_refine(*args) if args[4] == "Click" else draw_refine(args[1], args[0]),
728
- inputs=[template_frame, video_state, point_prompt, click_state, mask_mode_selector, interactive_state],
729
- outputs=[template_frame, video_state, run_status, run_status2]
730
  )
731
 
732
 
 
156
 
157
  return video_state["painted_images"][track_pause_number_slider],interactive_state, operation_log, operation_log
158
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  # use sam to get the mask
160
  def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData):
161
  """
 
192
  ("[Optional]", "Click image"), (": Try to click the image shown in step2 if you want to generate more masks.\n", None)]
193
  return painted_image, video_state, interactive_state, operation_log, operation_log
194
 
195
+ def add_multi_mask(video_state, interactive_state, mask_dropdown):
196
  try:
197
+ mask = video_state["masks"][video_state["select_frame_number"]]
 
 
 
 
 
 
 
 
 
198
  interactive_state["multi_mask"]["masks"].append(mask)
199
  interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
200
  mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
201
 
202
  select_frame, _, _ = show_mask(video_state, interactive_state, mask_dropdown)
203
  operation_log = [("",""),("Added a mask, use the mask select for target tracking or inpainting.","Normal")]
204
+ except:
205
+ operation_log = [("Please click the image in step2 to generate masks.", "Error"), ("","")]
206
 
207
  return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]], operation_log, operation_log
208
 
 
609
  value=10,)
610
 
611
  with gr.Column():
612
+ # input video
613
  gr.Markdown("## Step1: Upload video")
614
  with gr.Row(equal_height=True):
615
  with gr.Column(scale=2):
616
  video_input = gr.Video(elem_classes="video")
617
  extract_frames_button = gr.Button(value="Get video info", interactive=True, variant="primary")
618
  with gr.Column(scale=2):
619
+ run_status = gr.HighlightedText(value=[("",""), ("Try to upload your video and click the Get video info button to get started! (Kindly ensure that the uploaded video consists of fewer than 500 frames in total)", "Normal")],
620
  color_map={"Normal": "green", "Error": "red", "Clear clicks": "gray", "Add mask": "green", "Remove mask": "red"})
621
  video_info = gr.Textbox(label="Video Info")
622
+
623
+
624
+ # add masks
625
  step2_title = gr.Markdown("---\n## Step2: Add masks", visible=False)
626
  with gr.Row(equal_height=True):
627
  with gr.Column(scale=2):
628
  brush_settings = gr.Brush(
629
+ default_size=2,
630
+ colors=["#FFFFFF"], # Белый цвет кисти
631
  default_color="#FFFFFF"
632
  )
633
+ template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame", visible=False, elem_classes="image")
 
 
634
  image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track start frame", visible=False)
635
  track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frame", visible=False)
636
  with gr.Column(scale=2, elem_classes="jc_center"):
637
+ run_status2 = gr.HighlightedText(value=[("",""), ("Try to upload your video and click the Get video info button to get started! (Kindly ensure that the uploaded video consists of fewer than 500 frames in total)", "Normal")],
638
  color_map={"Normal": "green", "Error": "red", "Clear clicks": "gray", "Add mask": "green", "Remove mask": "red"},
639
  visible=False)
640
+ with gr.Column():
641
+ point_prompt = gr.Radio(
642
+ choices=["Positive", "Negative"],
643
+ value="Positive",
644
+ label="Point prompt",
645
+ interactive=True,
646
+ visible=False,
647
+ min_width=100,
648
+ scale=1,)
649
+ with gr.Row(elem_classes="mask_button_group"):
650
+ Add_mask_button = gr.Button(value="Add mask", interactive=True, visible=False, elem_classes="add_button")
651
+ remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False, elem_classes="remove_button")
652
+ clear_button_click = gr.Button(value="Clear clicks", interactive=True, visible=False, elem_classes="clear_button")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
653
  mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask selection", info=".", visible=False)
654
+
655
+ # output video
656
  step3_title = gr.Markdown("---\n## Step3: Track masks and get the inpainting result", visible=False)
657
  with gr.Row(equal_height=True):
658
  with gr.Column(scale=2):
 
683
 
684
  # click select image to get mask using sam
685
  template_frame.select(
686
+ fn=sam_refine,
687
+ inputs=[video_state, point_prompt, click_state, interactive_state],
688
+ outputs=[template_frame, video_state, interactive_state, run_status, run_status2]
689
  )
690
 
691