xiaoyuxi committed on
Commit
151b615
·
1 Parent(s): 9193cab

add online

Browse files
Files changed (1) hide show
  1. app.py +62 -26
app.py CHANGED
@@ -43,7 +43,9 @@ except ImportError as e:
43
  raise
44
 
45
  # Constants
46
- MAX_FRAMES = 80
 
 
47
  COLORS = [(0, 0, 255), (0, 255, 255)] # BGR: Red for negative, Yellow for positive
48
  MARKERS = [1, 5] # Cross for negative, Star for positive
49
  MARKER_SIZE = 8
@@ -88,8 +90,10 @@ vggt4track_model = vggt4track_model.to("cuda")
88
 
89
  # Global model initialization
90
  print("🚀 Initializing local models...")
91
- tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
92
- tracker_model.eval()
 
 
93
  predictor = get_sam_predictor()
94
  print("✅ Models loaded successfully!")
95
 
@@ -128,9 +132,13 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
128
  if tracker_model_arg is None or tracker_viser_arg is None:
129
  print("Initializing tracker models inside GPU function...")
130
  out_dir = os.path.join(temp_dir, "results")
131
- os.makedirs(out_dir, exist_ok=True)
132
- tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points,
133
- tracker_model=tracker_model.cuda())
 
 
 
 
134
 
135
  # Setup paths
136
  video_path = os.path.join(temp_dir, f"{video_name}.mp4")
@@ -148,7 +156,10 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
148
  if scale < 1:
149
  new_h, new_w = int(h * scale), int(w * scale)
150
  video_tensor = T.Resize((new_h, new_w))(video_tensor)
151
- video_tensor = video_tensor[::fps].float()[:MAX_FRAMES]
 
 
 
152
 
153
  # Move to GPU
154
  video_tensor = video_tensor.cuda()
@@ -526,7 +537,7 @@ def reset_points(original_img: str, sel_pix):
526
  print(f"❌ Error in reset_points: {e}")
527
  return None, []
528
 
529
- def launch_viz(grid_size, vo_points, fps, original_image_state, mode="offline"):
530
  """Launch visualization with user-specific temp directory"""
531
  if original_image_state is None:
532
  return None, None, None
@@ -538,7 +549,7 @@ def launch_viz(grid_size, vo_points, fps, original_image_state, mode="offline"):
538
  video_name = frame_data.get('video_name', 'video')
539
 
540
  print(f"🚀 Starting tracking for video: {video_name}")
541
- print(f"📊 Parameters: grid_size={grid_size}, vo_points={vo_points}, fps={fps}")
542
 
543
  # Check for mask files
544
  mask_files = glob.glob(os.path.join(temp_dir, "*.png"))
@@ -552,11 +563,11 @@ def launch_viz(grid_size, vo_points, fps, original_image_state, mode="offline"):
552
  mask_path = mask_files[0] if mask_files else None
553
 
554
  # Run tracker
555
- print("🎯 Running tracker...")
556
  out_dir = os.path.join(temp_dir, "results")
557
  os.makedirs(out_dir, exist_ok=True)
558
 
559
- gpu_run_tracker(None, None, temp_dir, video_name, grid_size, vo_points, fps, mode=mode)
560
 
561
  # Process results
562
  npz_path = os.path.join(out_dir, "result.npz")
@@ -609,6 +620,7 @@ def clear_all_with_download():
609
  gr.update(value=50),
610
  gr.update(value=756),
611
  gr.update(value=3),
 
612
  None, # tracking_video_download
613
  None) # HTML download component
614
 
@@ -641,6 +653,13 @@ def get_video_settings(video_name):
641
 
642
  return video_settings.get(video_name, (50, 756, 3))
643
 
 
 
 
 
 
 
 
644
  # Create the Gradio interface
645
  print("🎨 Creating Gradio interface...")
646
 
@@ -846,7 +865,7 @@ with gr.Blocks(
846
  """)
847
 
848
  # Status indicator
849
- gr.Markdown("**Status:** 🟢 Local Processing Mode")
850
 
851
  # Main content area - video upload left, 3D visualization right
852
  with gr.Row():
@@ -945,18 +964,29 @@ with gr.Blocks(
945
  with gr.Row():
946
  gr.Markdown("### ⚙️ Tracking Parameters")
947
  with gr.Row():
948
- grid_size = gr.Slider(
949
- minimum=10, maximum=100, step=10, value=50,
950
- label="Grid Size", info="Tracking detail level"
951
- )
952
- vo_points = gr.Slider(
953
- minimum=100, maximum=2000, step=50, value=756,
954
- label="VO Points", info="Motion accuracy"
955
- )
956
- fps = gr.Slider(
957
- minimum=1, maximum=20, step=1, value=3,
958
- label="FPS", info="Processing speed"
959
- )
 
 
 
 
 
 
 
 
 
 
 
960
 
961
  # Advanced Point Selection with SAM - Collapsed by default
962
  with gr.Row():
@@ -1082,6 +1112,12 @@ with gr.Blocks(
1082
  outputs=[original_image_state, interactive_frame, selected_points, grid_size, vo_points, fps]
1083
  )
1084
 
 
 
 
 
 
 
1085
  interactive_frame.select(
1086
  fn=select_point,
1087
  inputs=[original_image_state, selected_points, point_type],
@@ -1096,12 +1132,12 @@ with gr.Blocks(
1096
 
1097
  clear_all_btn.click(
1098
  fn=clear_all_with_download,
1099
- outputs=[video_input, interactive_frame, selected_points, grid_size, vo_points, fps, tracking_video_download, html_download]
1100
  )
1101
 
1102
  launch_btn.click(
1103
  fn=launch_viz,
1104
- inputs=[grid_size, vo_points, fps, original_image_state],
1105
  outputs=[viz_html, tracking_video_download, html_download]
1106
  )
1107
 
 
43
  raise
44
 
45
  # Constants
46
+ MAX_FRAMES_OFFLINE = 80
47
+ MAX_FRAMES_ONLINE = 300
48
+
49
  COLORS = [(0, 0, 255), (0, 255, 255)] # BGR: Red for negative, Yellow for positive
50
  MARKERS = [1, 5] # Cross for negative, Star for positive
51
  MARKER_SIZE = 8
 
90
 
91
  # Global model initialization
92
  print("🚀 Initializing local models...")
93
+ tracker_model_offline = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
94
+ tracker_model_offline.eval()
95
+ tracker_model_online = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Online")
96
+ tracker_model_online.eval()
97
  predictor = get_sam_predictor()
98
  print("✅ Models loaded successfully!")
99
 
 
132
  if tracker_model_arg is None or tracker_viser_arg is None:
133
  print("Initializing tracker models inside GPU function...")
134
  out_dir = os.path.join(temp_dir, "results")
135
+ os.makedirs(out_dir, exist_ok=True)
136
+ if mode == "offline":
137
+ tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points,
138
+ tracker_model=tracker_model_offline.cuda())
139
+ else:
140
+ tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points,
141
+ tracker_model=tracker_model_online.cuda())
142
 
143
  # Setup paths
144
  video_path = os.path.join(temp_dir, f"{video_name}.mp4")
 
156
  if scale < 1:
157
  new_h, new_w = int(h * scale), int(w * scale)
158
  video_tensor = T.Resize((new_h, new_w))(video_tensor)
159
+ if mode == "offline":
160
+ video_tensor = video_tensor[::fps].float()[:MAX_FRAMES_OFFLINE]
161
+ else:
162
+ video_tensor = video_tensor[::fps].float()[:MAX_FRAMES_ONLINE]
163
 
164
  # Move to GPU
165
  video_tensor = video_tensor.cuda()
 
537
  print(f"❌ Error in reset_points: {e}")
538
  return None, []
539
 
540
+ def launch_viz(grid_size, vo_points, fps, original_image_state, processing_mode):
541
  """Launch visualization with user-specific temp directory"""
542
  if original_image_state is None:
543
  return None, None, None
 
549
  video_name = frame_data.get('video_name', 'video')
550
 
551
  print(f"🚀 Starting tracking for video: {video_name}")
552
+ print(f"📊 Parameters: grid_size={grid_size}, vo_points={vo_points}, fps={fps}, mode={processing_mode}")
553
 
554
  # Check for mask files
555
  mask_files = glob.glob(os.path.join(temp_dir, "*.png"))
 
563
  mask_path = mask_files[0] if mask_files else None
564
 
565
  # Run tracker
566
+ print(f"🎯 Running tracker in {processing_mode} mode...")
567
  out_dir = os.path.join(temp_dir, "results")
568
  os.makedirs(out_dir, exist_ok=True)
569
 
570
+ gpu_run_tracker(None, None, temp_dir, video_name, grid_size, vo_points, fps, mode=processing_mode)
571
 
572
  # Process results
573
  npz_path = os.path.join(out_dir, "result.npz")
 
620
  gr.update(value=50),
621
  gr.update(value=756),
622
  gr.update(value=3),
623
+ gr.update(value="offline"), # processing_mode
624
  None, # tracking_video_download
625
  None) # HTML download component
626
 
 
653
 
654
  return video_settings.get(video_name, (50, 756, 3))
655
 
656
+ def update_status_indicator(processing_mode):
657
+ """Update status indicator based on processing mode"""
658
+ if processing_mode == "offline":
659
+ return "**Status:** 🟢 Local Processing Mode (Offline)"
660
+ else:
661
+ return "**Status:** 🔵 Cloud Processing Mode (Online)"
662
+
663
  # Create the Gradio interface
664
  print("🎨 Creating Gradio interface...")
665
 
 
865
  """)
866
 
867
  # Status indicator
868
+ status_indicator = gr.Markdown("**Status:** 🟢 Local Processing Mode (Offline)")
869
 
870
  # Main content area - video upload left, 3D visualization right
871
  with gr.Row():
 
964
  with gr.Row():
965
  gr.Markdown("### ⚙️ Tracking Parameters")
966
  with gr.Row():
967
+ # Add the processing-mode selector
968
+ with gr.Column(scale=1):
969
+ processing_mode = gr.Radio(
970
+ choices=["offline", "online"],
971
+ value="offline",
972
+ label="Processing Mode",
973
+ info="Offline: default mode | Online: Sliding Window Mode"
974
+ )
975
+ with gr.Column(scale=1):
976
+ grid_size = gr.Slider(
977
+ minimum=10, maximum=100, step=10, value=50,
978
+ label="Grid Size", info="Tracking detail level"
979
+ )
980
+ with gr.Column(scale=1):
981
+ vo_points = gr.Slider(
982
+ minimum=100, maximum=2000, step=50, value=756,
983
+ label="VO Points", info="Motion accuracy"
984
+ )
985
+ with gr.Column(scale=1):
986
+ fps = gr.Slider(
987
+ minimum=1, maximum=20, step=1, value=3,
988
+ label="FPS", info="Processing speed"
989
+ )
990
 
991
  # Advanced Point Selection with SAM - Collapsed by default
992
  with gr.Row():
 
1112
  outputs=[original_image_state, interactive_frame, selected_points, grid_size, vo_points, fps]
1113
  )
1114
 
1115
+ processing_mode.change(
1116
+ fn=update_status_indicator,
1117
+ inputs=[processing_mode],
1118
+ outputs=[status_indicator]
1119
+ )
1120
+
1121
  interactive_frame.select(
1122
  fn=select_point,
1123
  inputs=[original_image_state, selected_points, point_type],
 
1132
 
1133
  clear_all_btn.click(
1134
  fn=clear_all_with_download,
1135
+ outputs=[video_input, interactive_frame, selected_points, grid_size, vo_points, fps, processing_mode, tracking_video_download, html_download]
1136
  )
1137
 
1138
  launch_btn.click(
1139
  fn=launch_viz,
1140
+ inputs=[grid_size, vo_points, fps, original_image_state, processing_mode],
1141
  outputs=[viz_html, tracking_video_download, html_download]
1142
  )
1143