Spaces:
Running
on
Zero
Running
on
Zero
xiaoyuxi
commited on
Commit
·
151b615
1
Parent(s):
9193cab
add online
Browse files
app.py
CHANGED
@@ -43,7 +43,9 @@ except ImportError as e:
|
|
43 |
raise
|
44 |
|
45 |
# Constants
|
46 |
-
|
|
|
|
|
47 |
COLORS = [(0, 0, 255), (0, 255, 255)] # BGR: Red for negative, Yellow for positive
|
48 |
MARKERS = [1, 5] # Cross for negative, Star for positive
|
49 |
MARKER_SIZE = 8
|
@@ -88,8 +90,10 @@ vggt4track_model = vggt4track_model.to("cuda")
|
|
88 |
|
89 |
# Global model initialization
|
90 |
print("🚀 Initializing local models...")
|
91 |
-
|
92 |
-
|
|
|
|
|
93 |
predictor = get_sam_predictor()
|
94 |
print("✅ Models loaded successfully!")
|
95 |
|
@@ -128,9 +132,13 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
|
|
128 |
if tracker_model_arg is None or tracker_viser_arg is None:
|
129 |
print("Initializing tracker models inside GPU function...")
|
130 |
out_dir = os.path.join(temp_dir, "results")
|
131 |
-
os.makedirs(out_dir, exist_ok=True)
|
132 |
-
|
133 |
-
|
|
|
|
|
|
|
|
|
134 |
|
135 |
# Setup paths
|
136 |
video_path = os.path.join(temp_dir, f"{video_name}.mp4")
|
@@ -148,7 +156,10 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
|
|
148 |
if scale < 1:
|
149 |
new_h, new_w = int(h * scale), int(w * scale)
|
150 |
video_tensor = T.Resize((new_h, new_w))(video_tensor)
|
151 |
-
|
|
|
|
|
|
|
152 |
|
153 |
# Move to GPU
|
154 |
video_tensor = video_tensor.cuda()
|
@@ -526,7 +537,7 @@ def reset_points(original_img: str, sel_pix):
|
|
526 |
print(f"❌ Error in reset_points: {e}")
|
527 |
return None, []
|
528 |
|
529 |
-
def launch_viz(grid_size, vo_points, fps, original_image_state,
|
530 |
"""Launch visualization with user-specific temp directory"""
|
531 |
if original_image_state is None:
|
532 |
return None, None, None
|
@@ -538,7 +549,7 @@ def launch_viz(grid_size, vo_points, fps, original_image_state, mode="offline"):
|
|
538 |
video_name = frame_data.get('video_name', 'video')
|
539 |
|
540 |
print(f"🚀 Starting tracking for video: {video_name}")
|
541 |
-
print(f"📊 Parameters: grid_size={grid_size}, vo_points={vo_points}, fps={fps}")
|
542 |
|
543 |
# Check for mask files
|
544 |
mask_files = glob.glob(os.path.join(temp_dir, "*.png"))
|
@@ -552,11 +563,11 @@ def launch_viz(grid_size, vo_points, fps, original_image_state, mode="offline"):
|
|
552 |
mask_path = mask_files[0] if mask_files else None
|
553 |
|
554 |
# Run tracker
|
555 |
-
print("🎯 Running tracker...")
|
556 |
out_dir = os.path.join(temp_dir, "results")
|
557 |
os.makedirs(out_dir, exist_ok=True)
|
558 |
|
559 |
-
gpu_run_tracker(None, None, temp_dir, video_name, grid_size, vo_points, fps, mode=
|
560 |
|
561 |
# Process results
|
562 |
npz_path = os.path.join(out_dir, "result.npz")
|
@@ -609,6 +620,7 @@ def clear_all_with_download():
|
|
609 |
gr.update(value=50),
|
610 |
gr.update(value=756),
|
611 |
gr.update(value=3),
|
|
|
612 |
None, # tracking_video_download
|
613 |
None) # HTML download component
|
614 |
|
@@ -641,6 +653,13 @@ def get_video_settings(video_name):
|
|
641 |
|
642 |
return video_settings.get(video_name, (50, 756, 3))
|
643 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
644 |
# Create the Gradio interface
|
645 |
print("🎨 Creating Gradio interface...")
|
646 |
|
@@ -846,7 +865,7 @@ with gr.Blocks(
|
|
846 |
""")
|
847 |
|
848 |
# Status indicator
|
849 |
-
gr.Markdown("**Status:** 🟢 Local Processing Mode")
|
850 |
|
851 |
# Main content area - video upload left, 3D visualization right
|
852 |
with gr.Row():
|
@@ -945,18 +964,29 @@ with gr.Blocks(
|
|
945 |
with gr.Row():
|
946 |
gr.Markdown("### ⚙️ Tracking Parameters")
|
947 |
with gr.Row():
|
948 |
-
|
949 |
-
|
950 |
-
|
951 |
-
|
952 |
-
|
953 |
-
|
954 |
-
|
955 |
-
|
956 |
-
|
957 |
-
|
958 |
-
|
959 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
960 |
|
961 |
# Advanced Point Selection with SAM - Collapsed by default
|
962 |
with gr.Row():
|
@@ -1082,6 +1112,12 @@ with gr.Blocks(
|
|
1082 |
outputs=[original_image_state, interactive_frame, selected_points, grid_size, vo_points, fps]
|
1083 |
)
|
1084 |
|
|
|
|
|
|
|
|
|
|
|
|
|
1085 |
interactive_frame.select(
|
1086 |
fn=select_point,
|
1087 |
inputs=[original_image_state, selected_points, point_type],
|
@@ -1096,12 +1132,12 @@ with gr.Blocks(
|
|
1096 |
|
1097 |
clear_all_btn.click(
|
1098 |
fn=clear_all_with_download,
|
1099 |
-
outputs=[video_input, interactive_frame, selected_points, grid_size, vo_points, fps, tracking_video_download, html_download]
|
1100 |
)
|
1101 |
|
1102 |
launch_btn.click(
|
1103 |
fn=launch_viz,
|
1104 |
-
inputs=[grid_size, vo_points, fps, original_image_state],
|
1105 |
outputs=[viz_html, tracking_video_download, html_download]
|
1106 |
)
|
1107 |
|
|
|
43 |
raise
|
44 |
|
45 |
# Constants
|
46 |
+
MAX_FRAMES_OFFLINE = 80
|
47 |
+
MAX_FRAMES_ONLINE = 300
|
48 |
+
|
49 |
COLORS = [(0, 0, 255), (0, 255, 255)] # BGR: Red for negative, Yellow for positive
|
50 |
MARKERS = [1, 5] # Cross for negative, Star for positive
|
51 |
MARKER_SIZE = 8
|
|
|
90 |
|
91 |
# Global model initialization
|
92 |
print("🚀 Initializing local models...")
|
93 |
+
tracker_model_offline = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
|
94 |
+
tracker_model_offline.eval()
|
95 |
+
tracker_model_online = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Online")
|
96 |
+
tracker_model_online.eval()
|
97 |
predictor = get_sam_predictor()
|
98 |
print("✅ Models loaded successfully!")
|
99 |
|
|
|
132 |
if tracker_model_arg is None or tracker_viser_arg is None:
|
133 |
print("Initializing tracker models inside GPU function...")
|
134 |
out_dir = os.path.join(temp_dir, "results")
|
135 |
+
os.makedirs(out_dir, exist_ok=True)
|
136 |
+
if mode == "offline":
|
137 |
+
tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points,
|
138 |
+
tracker_model=tracker_model_offline.cuda())
|
139 |
+
else:
|
140 |
+
tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points,
|
141 |
+
tracker_model=tracker_model_online.cuda())
|
142 |
|
143 |
# Setup paths
|
144 |
video_path = os.path.join(temp_dir, f"{video_name}.mp4")
|
|
|
156 |
if scale < 1:
|
157 |
new_h, new_w = int(h * scale), int(w * scale)
|
158 |
video_tensor = T.Resize((new_h, new_w))(video_tensor)
|
159 |
+
if mode == "offline":
|
160 |
+
video_tensor = video_tensor[::fps].float()[:MAX_FRAMES_OFFLINE]
|
161 |
+
else:
|
162 |
+
video_tensor = video_tensor[::fps].float()[:MAX_FRAMES_ONLINE]
|
163 |
|
164 |
# Move to GPU
|
165 |
video_tensor = video_tensor.cuda()
|
|
|
537 |
print(f"❌ Error in reset_points: {e}")
|
538 |
return None, []
|
539 |
|
540 |
+
def launch_viz(grid_size, vo_points, fps, original_image_state, processing_mode):
|
541 |
"""Launch visualization with user-specific temp directory"""
|
542 |
if original_image_state is None:
|
543 |
return None, None, None
|
|
|
549 |
video_name = frame_data.get('video_name', 'video')
|
550 |
|
551 |
print(f"🚀 Starting tracking for video: {video_name}")
|
552 |
+
print(f"📊 Parameters: grid_size={grid_size}, vo_points={vo_points}, fps={fps}, mode={processing_mode}")
|
553 |
|
554 |
# Check for mask files
|
555 |
mask_files = glob.glob(os.path.join(temp_dir, "*.png"))
|
|
|
563 |
mask_path = mask_files[0] if mask_files else None
|
564 |
|
565 |
# Run tracker
|
566 |
+
print(f"🎯 Running tracker in {processing_mode} mode...")
|
567 |
out_dir = os.path.join(temp_dir, "results")
|
568 |
os.makedirs(out_dir, exist_ok=True)
|
569 |
|
570 |
+
gpu_run_tracker(None, None, temp_dir, video_name, grid_size, vo_points, fps, mode=processing_mode)
|
571 |
|
572 |
# Process results
|
573 |
npz_path = os.path.join(out_dir, "result.npz")
|
|
|
620 |
gr.update(value=50),
|
621 |
gr.update(value=756),
|
622 |
gr.update(value=3),
|
623 |
+
gr.update(value="offline"), # processing_mode
|
624 |
None, # tracking_video_download
|
625 |
None) # HTML download component
|
626 |
|
|
|
653 |
|
654 |
return video_settings.get(video_name, (50, 756, 3))
|
655 |
|
656 |
+
def update_status_indicator(processing_mode):
|
657 |
+
"""Update status indicator based on processing mode"""
|
658 |
+
if processing_mode == "offline":
|
659 |
+
return "**Status:** 🟢 Local Processing Mode (Offline)"
|
660 |
+
else:
|
661 |
+
return "**Status:** 🔵 Cloud Processing Mode (Online)"
|
662 |
+
|
663 |
# Create the Gradio interface
|
664 |
print("🎨 Creating Gradio interface...")
|
665 |
|
|
|
865 |
""")
|
866 |
|
867 |
# Status indicator
|
868 |
+
status_indicator = gr.Markdown("**Status:** 🟢 Local Processing Mode (Offline)")
|
869 |
|
870 |
# Main content area - video upload left, 3D visualization right
|
871 |
with gr.Row():
|
|
|
964 |
with gr.Row():
|
965 |
gr.Markdown("### ⚙️ Tracking Parameters")
|
966 |
with gr.Row():
|
967 |
+
# 添加模式选择器
|
968 |
+
with gr.Column(scale=1):
|
969 |
+
processing_mode = gr.Radio(
|
970 |
+
choices=["offline", "online"],
|
971 |
+
value="offline",
|
972 |
+
label="Processing Mode",
|
973 |
+
info="Offline: default mode | Online: Sliding Window Mode"
|
974 |
+
)
|
975 |
+
with gr.Column(scale=1):
|
976 |
+
grid_size = gr.Slider(
|
977 |
+
minimum=10, maximum=100, step=10, value=50,
|
978 |
+
label="Grid Size", info="Tracking detail level"
|
979 |
+
)
|
980 |
+
with gr.Column(scale=1):
|
981 |
+
vo_points = gr.Slider(
|
982 |
+
minimum=100, maximum=2000, step=50, value=756,
|
983 |
+
label="VO Points", info="Motion accuracy"
|
984 |
+
)
|
985 |
+
with gr.Column(scale=1):
|
986 |
+
fps = gr.Slider(
|
987 |
+
minimum=1, maximum=20, step=1, value=3,
|
988 |
+
label="FPS", info="Processing speed"
|
989 |
+
)
|
990 |
|
991 |
# Advanced Point Selection with SAM - Collapsed by default
|
992 |
with gr.Row():
|
|
|
1112 |
outputs=[original_image_state, interactive_frame, selected_points, grid_size, vo_points, fps]
|
1113 |
)
|
1114 |
|
1115 |
+
processing_mode.change(
|
1116 |
+
fn=update_status_indicator,
|
1117 |
+
inputs=[processing_mode],
|
1118 |
+
outputs=[status_indicator]
|
1119 |
+
)
|
1120 |
+
|
1121 |
interactive_frame.select(
|
1122 |
fn=select_point,
|
1123 |
inputs=[original_image_state, selected_points, point_type],
|
|
|
1132 |
|
1133 |
clear_all_btn.click(
|
1134 |
fn=clear_all_with_download,
|
1135 |
+
outputs=[video_input, interactive_frame, selected_points, grid_size, vo_points, fps, processing_mode, tracking_video_download, html_download]
|
1136 |
)
|
1137 |
|
1138 |
launch_btn.click(
|
1139 |
fn=launch_viz,
|
1140 |
+
inputs=[grid_size, vo_points, fps, original_image_state, processing_mode],
|
1141 |
outputs=[viz_html, tracking_video_download, html_download]
|
1142 |
)
|
1143 |
|