David310 commited on
Commit
9b37148
·
verified ·
1 Parent(s): c132dda

Update app.py

Browse files

previous change invalid

Files changed (1) hide show
  1. app.py +35 -47
app.py CHANGED
@@ -2,7 +2,6 @@ import subprocess
2
  import re
3
  from typing import List, Tuple, Optional
4
  import pickle
5
- from datetime import datetime
6
 
7
  # Define the command to be executed
8
  command = ["python", "setup.py", "build_ext", "--inplace"]
@@ -40,7 +39,7 @@ def get_video_properties(video_path):
40
  if not cap.isOpened():
41
  print("Error: Could not open video.")
42
  return None
43
- p
44
  # Get the FPS of the video
45
  fps = cap.get(cv2.CAP_PROP_FPS)
46
 
@@ -81,27 +80,26 @@ def preprocess_video_in(video_path):
81
  if not cap.isOpened():
82
  print("Error: Could not open video.")
83
  return None
84
-
85
  # Get the frames per second (FPS) of the video
86
  fps = cap.get(cv2.CAP_PROP_FPS)
87
- """
88
  # Calculate the number of frames to process (10 seconds of video)
89
- # 需解除10s的视频限制
90
  max_frames = int(fps * 10)
91
- """
92
-
93
  frame_number = 0
94
  first_frame = None
95
 
96
  while True:
97
  ret, frame = cap.read()
98
  """
99
- 解除10s的视频限制
100
  if not ret or frame_number >= max_frames:
101
  break
102
  """
103
  if not ret:
104
  break
 
105
 
106
  # Format the frame filename as '00000.jpg'
107
  frame_filename = os.path.join(extracted_frames_output_dir, f'{frame_number:05d}.jpg')
@@ -351,6 +349,8 @@ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_
351
  masks_images = []
352
 
353
  # run propagation throughout the video and collect the results in a dict
 
 
354
  video_segments = {} # video_segments contains the per-frame segmentation results
355
  for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
356
  video_segments[out_frame_idx] = {
@@ -448,15 +448,8 @@ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_
448
  video_writer.release()
449
  print(f"Mask Video saved at {mask_video_filename}")
450
 
451
- # 在函数末尾添加导出mask数据
452
- mask_data_file = export_mask_data(video_segments, video_frames_dir, frame_names)
453
-
454
- if vis_frame_type == "check":
455
- return gr.update(value=jpeg_images), gr.update(value=None), gr.update(choices=available_frames_to_check, value=working_frame, visible=True), available_frames_to_check, gr.update(visible=True), None, mask_data_file
456
- elif vis_frame_type == "render":
457
- # ... 现有代码 ...
458
- return gr.update(value=None), gr.update(value=final_vid_output_path), working_frame, available_frames_to_check, gr.update(visible=True), mask_video_filename, mask_data_file
459
 
 
460
 
461
  def update_ui(vis_frame_type):
462
  if vis_frame_type == "check":
@@ -485,39 +478,25 @@ def reset_propagation(first_frame_path, predictor, stored_inference_state):
485
  # print(f"RESET State: {stored_inference_state} ")
486
  return first_frame_path, gr.State([]), gr.State([]), gr.update(value=None, visible=False), stored_inference_state, None, ["frame_0.jpg"], first_frame_path, "frame_0.jpg", gr.update(visible=False)
487
 
 
 
 
 
 
 
 
 
 
 
 
 
 
488
  css="""
489
  div#component-18, div#component-25, div#component-35, div#component-41{
490
  align-items: stretch!important;
491
  }
492
  """
493
 
494
- def export_mask_data(video_segments, video_frames_dir, scanned_frames):
495
- if not video_segments:
496
- return None
497
-
498
- # 准备导出数据
499
- export_data = {
500
- 'video_dir': video_frames_dir,
501
- 'frame_names': scanned_frames,
502
- 'masks': {}
503
- }
504
-
505
- # 将每一帧的mask数据添加到导出数据中
506
- for frame_idx, masks in video_segments.items():
507
- export_data['masks'][frame_idx] = {
508
- obj_id: mask.tolist() for obj_id, mask in masks.items()
509
- }
510
-
511
- # 生成唯一的文件名
512
- timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
513
- output_file = f'mask_data_{timestamp}.pkl'
514
-
515
- # 保存数据
516
- with open(output_file, 'wb') as f:
517
- pickle.dump(export_data, f)
518
-
519
- return output_file
520
-
521
  with gr.Blocks(css=css) as demo:
522
  first_frame_path = gr.State()
523
  tracking_points = gr.State([])
@@ -575,9 +554,6 @@ with gr.Blocks(css=css) as demo:
575
  <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-lg-dark.svg" alt="Duplicate this Space" />
576
  </a> to skip queue and avoid OOM errors from heavy public load
577
  """)
578
-
579
- # mask数据下载
580
- mask_data_file = gr.File(label="Download Mask Data", visible=True)
581
 
582
  with gr.Column():
583
 
@@ -598,6 +574,9 @@ with gr.Blocks(css=css) as demo:
598
  output_video = gr.Video(visible=False)
599
  mask_final_output = gr.Video(label="Mask Video")
600
  # output_result_mask = gr.Image()
 
 
 
601
 
602
 
603
 
@@ -700,7 +679,16 @@ with gr.Blocks(css=css) as demo:
700
  ).then(
701
  fn = propagate_to_all,
702
  inputs = [video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame],
703
- outputs = [output_propagated, output_video, working_frame, available_frames_to_check, reset_prpgt_brn, mask_final_output, mask_data_file]
 
 
 
 
 
 
 
 
 
704
  )
705
 
706
  demo.launch(show_api=False, show_error=True)
 
2
  import re
3
  from typing import List, Tuple, Optional
4
  import pickle
 
5
 
6
  # Define the command to be executed
7
  command = ["python", "setup.py", "build_ext", "--inplace"]
 
39
  if not cap.isOpened():
40
  print("Error: Could not open video.")
41
  return None
42
+
43
  # Get the FPS of the video
44
  fps = cap.get(cv2.CAP_PROP_FPS)
45
 
 
80
  if not cap.isOpened():
81
  print("Error: Could not open video.")
82
  return None
83
+
84
  # Get the frames per second (FPS) of the video
85
  fps = cap.get(cv2.CAP_PROP_FPS)
86
+
87
  # Calculate the number of frames to process (10 seconds of video)
 
88
  max_frames = int(fps * 10)
89
+
 
90
  frame_number = 0
91
  first_frame = None
92
 
93
  while True:
94
  ret, frame = cap.read()
95
  """
96
+ 不再将视频裁剪至10s
97
  if not ret or frame_number >= max_frames:
98
  break
99
  """
100
  if not ret:
101
  break
102
+
103
 
104
  # Format the frame filename as '00000.jpg'
105
  frame_filename = os.path.join(extracted_frames_output_dir, f'{frame_number:05d}.jpg')
 
349
  masks_images = []
350
 
351
  # run propagation throughout the video and collect the results in a dict
352
+ # 添加全局变量保存mask数据
353
+ global video_segments
354
  video_segments = {} # video_segments contains the per-frame segmentation results
355
  for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
356
  video_segments[out_frame_idx] = {
 
448
  video_writer.release()
449
  print(f"Mask Video saved at {mask_video_filename}")
450
 
 
 
 
 
 
 
 
 
451
 
452
+ return gr.update(value=None), gr.update(value=final_vid_output_path), working_frame, available_frames_to_check, gr.update(visible=True), mask_video_filename
453
 
454
  def update_ui(vis_frame_type):
455
  if vis_frame_type == "check":
 
478
  # print(f"RESET State: {stored_inference_state} ")
479
  return first_frame_path, gr.State([]), gr.State([]), gr.update(value=None, visible=False), stored_inference_state, None, ["frame_0.jpg"], first_frame_path, "frame_0.jpg", gr.update(visible=False)
480
 
481
+ def export_masks():
482
+ # 导出mask数据
483
+ global video_segments
484
+ if not video_segments:
485
+ raise gr.Error("No mask data available. Please run propagation first!")
486
+
487
+ # 保存为pickle文件
488
+ filename = "mask_data.pkl"
489
+ with open(filename, 'wb') as f:
490
+ pickle.dump(video_segments, f)
491
+
492
+ return filename
493
+
494
  css="""
495
  div#component-18, div#component-25, div#component-35, div#component-41{
496
  align-items: stretch!important;
497
  }
498
  """
499
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
500
  with gr.Blocks(css=css) as demo:
501
  first_frame_path = gr.State()
502
  tracking_points = gr.State([])
 
554
  <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-lg-dark.svg" alt="Duplicate this Space" />
555
  </a> to skip queue and avoid OOM errors from heavy public load
556
  """)
 
 
 
557
 
558
  with gr.Column():
559
 
 
574
  output_video = gr.Video(visible=False)
575
  mask_final_output = gr.Video(label="Mask Video")
576
  # output_result_mask = gr.Image()
577
+ # 在输出部分添加导出按钮
578
+ export_btn = gr.Button("Export Mask Data")
579
+ mask_download = gr.File(label="Download Mask Data", visible=False)
580
 
581
 
582
 
 
679
  ).then(
680
  fn = propagate_to_all,
681
  inputs = [video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame],
682
+ outputs = [output_propagated, output_video, working_frame, available_frames_to_check, reset_prpgt_brn, mask_final_output]
683
+ )
684
+
685
+ # 添加导出按钮的事件绑定
686
+ export_btn.click(
687
+ fn=export_masks,
688
+ outputs=[mask_download]
689
+ ).then(
690
+ fn=lambda: gr.update(visible=True),
691
+ outputs=[mask_download]
692
  )
693
 
694
  demo.launch(show_api=False, show_error=True)