David310 commited on
Commit
c132dda
·
verified ·
1 Parent(s): 4611a01

Update app.py

Browse files

disabled the 10s video restriction, added a button to download the mask pickle

Files changed (1) hide show
  1. app.py +52 -6
app.py CHANGED
@@ -1,6 +1,8 @@
1
  import subprocess
2
  import re
3
  from typing import List, Tuple, Optional
 
 
4
 
5
  # Define the command to be executed
6
  command = ["python", "setup.py", "build_ext", "--inplace"]
@@ -38,7 +40,7 @@ def get_video_properties(video_path):
38
  if not cap.isOpened():
39
  print("Error: Could not open video.")
40
  return None
41
-
42
  # Get the FPS of the video
43
  fps = cap.get(cv2.CAP_PROP_FPS)
44
 
@@ -79,20 +81,27 @@ def preprocess_video_in(video_path):
79
  if not cap.isOpened():
80
  print("Error: Could not open video.")
81
  return None
82
-
83
  # Get the frames per second (FPS) of the video
84
  fps = cap.get(cv2.CAP_PROP_FPS)
85
-
86
  # Calculate the number of frames to process (10 seconds of video)
 
87
  max_frames = int(fps * 10)
88
-
 
89
  frame_number = 0
90
  first_frame = None
91
 
92
  while True:
93
  ret, frame = cap.read()
 
 
94
  if not ret or frame_number >= max_frames:
95
  break
 
 
 
96
 
97
  # Format the frame filename as '00000.jpg'
98
  frame_filename = os.path.join(extracted_frames_output_dir, f'{frame_number:05d}.jpg')
@@ -439,8 +448,15 @@ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_
439
  video_writer.release()
440
  print(f"Mask Video saved at {mask_video_filename}")
441
 
 
 
 
 
 
 
 
 
442
 
443
- return gr.update(value=None), gr.update(value=final_vid_output_path), working_frame, available_frames_to_check, gr.update(visible=True), mask_video_filename
444
 
445
  def update_ui(vis_frame_type):
446
  if vis_frame_type == "check":
@@ -475,6 +491,33 @@ div#component-18, div#component-25, div#component-35, div#component-41{
475
  }
476
  """
477
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
478
  with gr.Blocks(css=css) as demo:
479
  first_frame_path = gr.State()
480
  tracking_points = gr.State([])
@@ -532,6 +575,9 @@ with gr.Blocks(css=css) as demo:
532
  <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-lg-dark.svg" alt="Duplicate this Space" />
533
  </a> to skip queue and avoid OOM errors from heavy public load
534
  """)
 
 
 
535
 
536
  with gr.Column():
537
 
@@ -654,7 +700,7 @@ with gr.Blocks(css=css) as demo:
654
  ).then(
655
  fn = propagate_to_all,
656
  inputs = [video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame],
657
- outputs = [output_propagated, output_video, working_frame, available_frames_to_check, reset_prpgt_brn, mask_final_output]
658
  )
659
 
660
  demo.launch(show_api=False, show_error=True)
 
1
  import subprocess
2
  import re
3
  from typing import List, Tuple, Optional
4
+ import pickle
5
+ from datetime import datetime
6
 
7
  # Define the command to be executed
8
  command = ["python", "setup.py", "build_ext", "--inplace"]
 
40
  if not cap.isOpened():
41
  print("Error: Could not open video.")
42
  return None
43
+ p
44
  # Get the FPS of the video
45
  fps = cap.get(cv2.CAP_PROP_FPS)
46
 
 
81
  if not cap.isOpened():
82
  print("Error: Could not open video.")
83
  return None
84
+
85
  # Get the frames per second (FPS) of the video
86
  fps = cap.get(cv2.CAP_PROP_FPS)
87
+ """
88
  # Calculate the number of frames to process (10 seconds of video)
89
+ # 需解除10s的视频限制
90
  max_frames = int(fps * 10)
91
+ """
92
+
93
  frame_number = 0
94
  first_frame = None
95
 
96
  while True:
97
  ret, frame = cap.read()
98
+ """
99
+ 解除10s的视频限制
100
  if not ret or frame_number >= max_frames:
101
  break
102
+ """
103
+ if not ret:
104
+ break
105
 
106
  # Format the frame filename as '00000.jpg'
107
  frame_filename = os.path.join(extracted_frames_output_dir, f'{frame_number:05d}.jpg')
 
448
  video_writer.release()
449
  print(f"Mask Video saved at {mask_video_filename}")
450
 
451
+ # 在函数末尾添加导出mask数据
452
+ mask_data_file = export_mask_data(video_segments, video_frames_dir, frame_names)
453
+
454
+ if vis_frame_type == "check":
455
+ return gr.update(value=jpeg_images), gr.update(value=None), gr.update(choices=available_frames_to_check, value=working_frame, visible=True), available_frames_to_check, gr.update(visible=True), None, mask_data_file
456
+ elif vis_frame_type == "render":
457
+ # ... 现有代码 ...
458
+ return gr.update(value=None), gr.update(value=final_vid_output_path), working_frame, available_frames_to_check, gr.update(visible=True), mask_video_filename, mask_data_file
459
 
 
460
 
461
  def update_ui(vis_frame_type):
462
  if vis_frame_type == "check":
 
491
  }
492
  """
493
 
494
+ def export_mask_data(video_segments, video_frames_dir, scanned_frames):
495
+ if not video_segments:
496
+ return None
497
+
498
+ # 准备导出数据
499
+ export_data = {
500
+ 'video_dir': video_frames_dir,
501
+ 'frame_names': scanned_frames,
502
+ 'masks': {}
503
+ }
504
+
505
+ # 将每一帧的mask数据添加到导出数据中
506
+ for frame_idx, masks in video_segments.items():
507
+ export_data['masks'][frame_idx] = {
508
+ obj_id: mask.tolist() for obj_id, mask in masks.items()
509
+ }
510
+
511
+ # 生成唯一的文件名
512
+ timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
513
+ output_file = f'mask_data_{timestamp}.pkl'
514
+
515
+ # 保存数据
516
+ with open(output_file, 'wb') as f:
517
+ pickle.dump(export_data, f)
518
+
519
+ return output_file
520
+
521
  with gr.Blocks(css=css) as demo:
522
  first_frame_path = gr.State()
523
  tracking_points = gr.State([])
 
575
  <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-lg-dark.svg" alt="Duplicate this Space" />
576
  </a> to skip queue and avoid OOM errors from heavy public load
577
  """)
578
+
579
+ # mask数据下载
580
+ mask_data_file = gr.File(label="Download Mask Data", visible=True)
581
 
582
  with gr.Column():
583
 
 
700
  ).then(
701
  fn = propagate_to_all,
702
  inputs = [video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame],
703
+ outputs = [output_propagated, output_video, working_frame, available_frames_to_check, reset_prpgt_brn, mask_final_output, mask_data_file]
704
  )
705
 
706
  demo.launch(show_api=False, show_error=True)