Update app.py
Browse filesprevious change invalid
app.py
CHANGED
@@ -2,7 +2,6 @@ import subprocess
|
|
2 |
import re
|
3 |
from typing import List, Tuple, Optional
|
4 |
import pickle
|
5 |
-
from datetime import datetime
|
6 |
|
7 |
# Define the command to be executed
|
8 |
command = ["python", "setup.py", "build_ext", "--inplace"]
|
@@ -40,7 +39,7 @@ def get_video_properties(video_path):
|
|
40 |
if not cap.isOpened():
|
41 |
print("Error: Could not open video.")
|
42 |
return None
|
43 |
-
|
44 |
# Get the FPS of the video
|
45 |
fps = cap.get(cv2.CAP_PROP_FPS)
|
46 |
|
@@ -81,27 +80,26 @@ def preprocess_video_in(video_path):
|
|
81 |
if not cap.isOpened():
|
82 |
print("Error: Could not open video.")
|
83 |
return None
|
84 |
-
|
85 |
# Get the frames per second (FPS) of the video
|
86 |
fps = cap.get(cv2.CAP_PROP_FPS)
|
87 |
-
|
88 |
# Calculate the number of frames to process (10 seconds of video)
|
89 |
-
# 需解除10s的视频限制
|
90 |
max_frames = int(fps * 10)
|
91 |
-
|
92 |
-
|
93 |
frame_number = 0
|
94 |
first_frame = None
|
95 |
|
96 |
while True:
|
97 |
ret, frame = cap.read()
|
98 |
"""
|
99 |
-
|
100 |
if not ret or frame_number >= max_frames:
|
101 |
break
|
102 |
"""
|
103 |
if not ret:
|
104 |
break
|
|
|
105 |
|
106 |
# Format the frame filename as '00000.jpg'
|
107 |
frame_filename = os.path.join(extracted_frames_output_dir, f'{frame_number:05d}.jpg')
|
@@ -351,6 +349,8 @@ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_
|
|
351 |
masks_images = []
|
352 |
|
353 |
# run propagation throughout the video and collect the results in a dict
|
|
|
|
|
354 |
video_segments = {} # video_segments contains the per-frame segmentation results
|
355 |
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
|
356 |
video_segments[out_frame_idx] = {
|
@@ -448,15 +448,8 @@ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_
|
|
448 |
video_writer.release()
|
449 |
print(f"Mask Video saved at {mask_video_filename}")
|
450 |
|
451 |
-
# 在函数末尾添加导出mask数据
|
452 |
-
mask_data_file = export_mask_data(video_segments, video_frames_dir, frame_names)
|
453 |
-
|
454 |
-
if vis_frame_type == "check":
|
455 |
-
return gr.update(value=jpeg_images), gr.update(value=None), gr.update(choices=available_frames_to_check, value=working_frame, visible=True), available_frames_to_check, gr.update(visible=True), None, mask_data_file
|
456 |
-
elif vis_frame_type == "render":
|
457 |
-
# ... 现有代码 ...
|
458 |
-
return gr.update(value=None), gr.update(value=final_vid_output_path), working_frame, available_frames_to_check, gr.update(visible=True), mask_video_filename, mask_data_file
|
459 |
|
|
|
460 |
|
461 |
def update_ui(vis_frame_type):
|
462 |
if vis_frame_type == "check":
|
@@ -485,39 +478,25 @@ def reset_propagation(first_frame_path, predictor, stored_inference_state):
|
|
485 |
# print(f"RESET State: {stored_inference_state} ")
|
486 |
return first_frame_path, gr.State([]), gr.State([]), gr.update(value=None, visible=False), stored_inference_state, None, ["frame_0.jpg"], first_frame_path, "frame_0.jpg", gr.update(visible=False)
|
487 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
488 |
css="""
|
489 |
div#component-18, div#component-25, div#component-35, div#component-41{
|
490 |
align-items: stretch!important;
|
491 |
}
|
492 |
"""
|
493 |
|
494 |
-
def export_mask_data(video_segments, video_frames_dir, scanned_frames):
|
495 |
-
if not video_segments:
|
496 |
-
return None
|
497 |
-
|
498 |
-
# 准备导出数据
|
499 |
-
export_data = {
|
500 |
-
'video_dir': video_frames_dir,
|
501 |
-
'frame_names': scanned_frames,
|
502 |
-
'masks': {}
|
503 |
-
}
|
504 |
-
|
505 |
-
# 将每一帧的mask数据添加到导出数据中
|
506 |
-
for frame_idx, masks in video_segments.items():
|
507 |
-
export_data['masks'][frame_idx] = {
|
508 |
-
obj_id: mask.tolist() for obj_id, mask in masks.items()
|
509 |
-
}
|
510 |
-
|
511 |
-
# 生成唯一的文件名
|
512 |
-
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
513 |
-
output_file = f'mask_data_{timestamp}.pkl'
|
514 |
-
|
515 |
-
# 保存数据
|
516 |
-
with open(output_file, 'wb') as f:
|
517 |
-
pickle.dump(export_data, f)
|
518 |
-
|
519 |
-
return output_file
|
520 |
-
|
521 |
with gr.Blocks(css=css) as demo:
|
522 |
first_frame_path = gr.State()
|
523 |
tracking_points = gr.State([])
|
@@ -575,9 +554,6 @@ with gr.Blocks(css=css) as demo:
|
|
575 |
<img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-lg-dark.svg" alt="Duplicate this Space" />
|
576 |
</a> to skip queue and avoid OOM errors from heavy public load
|
577 |
""")
|
578 |
-
|
579 |
-
# mask数据下载
|
580 |
-
mask_data_file = gr.File(label="Download Mask Data", visible=True)
|
581 |
|
582 |
with gr.Column():
|
583 |
|
@@ -598,6 +574,9 @@ with gr.Blocks(css=css) as demo:
|
|
598 |
output_video = gr.Video(visible=False)
|
599 |
mask_final_output = gr.Video(label="Mask Video")
|
600 |
# output_result_mask = gr.Image()
|
|
|
|
|
|
|
601 |
|
602 |
|
603 |
|
@@ -700,7 +679,16 @@ with gr.Blocks(css=css) as demo:
|
|
700 |
).then(
|
701 |
fn = propagate_to_all,
|
702 |
inputs = [video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame],
|
703 |
-
outputs = [output_propagated, output_video, working_frame, available_frames_to_check, reset_prpgt_brn, mask_final_output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
704 |
)
|
705 |
|
706 |
demo.launch(show_api=False, show_error=True)
|
|
|
2 |
import re
|
3 |
from typing import List, Tuple, Optional
|
4 |
import pickle
|
|
|
5 |
|
6 |
# Define the command to be executed
|
7 |
command = ["python", "setup.py", "build_ext", "--inplace"]
|
|
|
39 |
if not cap.isOpened():
|
40 |
print("Error: Could not open video.")
|
41 |
return None
|
42 |
+
|
43 |
# Get the FPS of the video
|
44 |
fps = cap.get(cv2.CAP_PROP_FPS)
|
45 |
|
|
|
80 |
if not cap.isOpened():
|
81 |
print("Error: Could not open video.")
|
82 |
return None
|
83 |
+
|
84 |
# Get the frames per second (FPS) of the video
|
85 |
fps = cap.get(cv2.CAP_PROP_FPS)
|
86 |
+
|
87 |
# Calculate the number of frames to process (10 seconds of video)
|
|
|
88 |
max_frames = int(fps * 10)
|
89 |
+
|
|
|
90 |
frame_number = 0
|
91 |
first_frame = None
|
92 |
|
93 |
while True:
|
94 |
ret, frame = cap.read()
|
95 |
"""
|
96 |
+
不再将视频裁剪至10s
|
97 |
if not ret or frame_number >= max_frames:
|
98 |
break
|
99 |
"""
|
100 |
if not ret:
|
101 |
break
|
102 |
+
|
103 |
|
104 |
# Format the frame filename as '00000.jpg'
|
105 |
frame_filename = os.path.join(extracted_frames_output_dir, f'{frame_number:05d}.jpg')
|
|
|
349 |
masks_images = []
|
350 |
|
351 |
# run propagation throughout the video and collect the results in a dict
|
352 |
+
# 添加全局变量保存mask数据
|
353 |
+
global video_segments
|
354 |
video_segments = {} # video_segments contains the per-frame segmentation results
|
355 |
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
|
356 |
video_segments[out_frame_idx] = {
|
|
|
448 |
video_writer.release()
|
449 |
print(f"Mask Video saved at {mask_video_filename}")
|
450 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
451 |
|
452 |
+
return gr.update(value=None), gr.update(value=final_vid_output_path), working_frame, available_frames_to_check, gr.update(visible=True), mask_video_filename
|
453 |
|
454 |
def update_ui(vis_frame_type):
|
455 |
if vis_frame_type == "check":
|
|
|
478 |
# print(f"RESET State: {stored_inference_state} ")
|
479 |
return first_frame_path, gr.State([]), gr.State([]), gr.update(value=None, visible=False), stored_inference_state, None, ["frame_0.jpg"], first_frame_path, "frame_0.jpg", gr.update(visible=False)
|
480 |
|
481 |
+
def export_masks():
|
482 |
+
# 导出mask数据
|
483 |
+
global video_segments
|
484 |
+
if not video_segments:
|
485 |
+
raise gr.Error("No mask data available. Please run propagation first!")
|
486 |
+
|
487 |
+
# 保存为pickle文件
|
488 |
+
filename = "mask_data.pkl"
|
489 |
+
with open(filename, 'wb') as f:
|
490 |
+
pickle.dump(video_segments, f)
|
491 |
+
|
492 |
+
return filename
|
493 |
+
|
494 |
css="""
|
495 |
div#component-18, div#component-25, div#component-35, div#component-41{
|
496 |
align-items: stretch!important;
|
497 |
}
|
498 |
"""
|
499 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
500 |
with gr.Blocks(css=css) as demo:
|
501 |
first_frame_path = gr.State()
|
502 |
tracking_points = gr.State([])
|
|
|
554 |
<img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-lg-dark.svg" alt="Duplicate this Space" />
|
555 |
</a> to skip queue and avoid OOM errors from heavy public load
|
556 |
""")
|
|
|
|
|
|
|
557 |
|
558 |
with gr.Column():
|
559 |
|
|
|
574 |
output_video = gr.Video(visible=False)
|
575 |
mask_final_output = gr.Video(label="Mask Video")
|
576 |
# output_result_mask = gr.Image()
|
577 |
+
# 在输出部分添加导出按钮
|
578 |
+
export_btn = gr.Button("Export Mask Data")
|
579 |
+
mask_download = gr.File(label="Download Mask Data", visible=False)
|
580 |
|
581 |
|
582 |
|
|
|
679 |
).then(
|
680 |
fn = propagate_to_all,
|
681 |
inputs = [video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame],
|
682 |
+
outputs = [output_propagated, output_video, working_frame, available_frames_to_check, reset_prpgt_brn, mask_final_output]
|
683 |
+
)
|
684 |
+
|
685 |
+
# 添加导出按钮的事件绑定
|
686 |
+
export_btn.click(
|
687 |
+
fn=export_masks,
|
688 |
+
outputs=[mask_download]
|
689 |
+
).then(
|
690 |
+
fn=lambda: gr.update(visible=True),
|
691 |
+
outputs=[mask_download]
|
692 |
)
|
693 |
|
694 |
demo.launch(show_api=False, show_error=True)
|