Junyi42 committed on
Commit
c4f90e9
·
1 Parent(s): 49842e0
Files changed (5) hide show
  1. Dockerfile +22 -0
  2. app.py +108 -0
  3. requirements.txt +15 -0
  4. vis_st4rtrack.py +780 -0
  5. viser_proxy_manager.py +223 -0
Dockerfile ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.12-slim

WORKDIR /app

# Install the shared libraries OpenCV needs at import time, plus build tools
# for any requirement that has to be compiled from source.
# `libgl1` is used instead of `libgl1-mesa-glx`: the latter was only a
# transitional package and is not available on newer Debian releases, which
# makes this step fail when the slim base image moves to them.
RUN apt-get update && apt-get install -y \
    libgl1 \
    libglib2.0-0 \
    build-essential \
    git \
    && rm -rf /var/lib/apt/lists/*

# Copy and install the requirements before the rest of the sources so this
# layer is reused when only application code changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

# Make port 7860 available to the world outside the container
# (the port app.py passes to uvicorn).
EXPOSE 7860

# Command to run when the container starts
CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import threading
3
+ import psutil
4
+ import fastapi
5
+ import gradio as gr
6
+ import uvicorn
7
+
8
+ from viser_proxy_manager import ViserProxyManager
9
+ from vis_st4rtrack import visualize_st4rtrack, load_trajectory_data, log_memory_usage
10
+
11
# Trajectory data cache; filled by main() at startup and handed to each
# visualization thread.
global_data_cache = None


def check_ram_usage(threshold_percent=90):
    """Report whether system RAM usage is still under a limit.

    Prints the current usage percentage as a side effect.

    Args:
        threshold_percent: Usage percentage at or above which the check fails.

    Returns:
        bool: True when the current usage is strictly below
        ``threshold_percent``, False otherwise.
    """
    usage = psutil.virtual_memory().percent
    print(f"Current RAM usage: {usage}%")
    under_limit = usage < threshold_percent
    return under_limit
26
+
27
+
28
def main() -> None:
    """Run the web app: a Gradio page that embeds one viser viewer per session.

    Loads the trajectory data once into ``global_data_cache``, mounts a Gradio
    UI on a FastAPI app next to a ``ViserProxyManager``, and serves everything
    with uvicorn on port 7860. The call blocks inside ``uvicorn.run``.
    """
    import html  # local import: only used to escape the iframe URL below

    # Load data once at startup using the function from vis_st4rtrack.py
    global global_data_cache
    global_data_cache = load_trajectory_data(use_float16=True, max_frames=32)

    app = fastapi.FastAPI()
    viser_manager = ViserProxyManager(app)

    # Create a Gradio interface holding the viewer iframe and a status line
    with gr.Blocks(title="Viser Viewer") as demo:
        iframe_html = gr.HTML("")
        status_text = gr.Markdown("")  # Status text component

        @demo.load(outputs=[iframe_html, status_text])
        def start_server(request: gr.Request):
            """Start a viser server for this session and return (iframe HTML, status)."""
            assert request.session_hash is not None

            # Check RAM usage before starting visualization.
            # NOTE(review): check_ram_usage() tests `usage < threshold`, so a
            # threshold of 100 almost never trips; this guard looks effectively
            # disabled -- confirm that is intended.
            if not check_ram_usage(threshold_percent=100):
                return """
                <div style="text-align: center; padding: 20px; background-color: #ffeeee; border-radius: 5px;">
                    <h2>⚠️ Server is currently under high load</h2>
                    <p>Please try again later when resources are available.</p>
                </div>
                """, "**System Status:** High memory usage detected. Visualization not loaded to prevent server overload."

            viser_manager.start_server(request.session_hash)

            # Use the request's host so the iframe points back at this app
            host = request.headers["host"]

            # Determine protocol (use HTTPS for HuggingFace Spaces or other secure environments)
            protocol = (
                "https"
                if request.headers.get("x-forwarded-proto") == "https"
                else "http"
            )

            # Add visualization in a separate thread so this callback can
            # return the iframe without waiting for the scene to be built.
            server = viser_manager.get_server(request.session_hash)
            threading.Thread(
                target=visualize_st4rtrack,
                kwargs={
                    "server": server,
                    "use_float16": True,
                    "preloaded_data": global_data_cache,  # Pass the preloaded data
                    "color_code": "jet",
                    "blue_rgb": (0.0, 0.149, 0.463),  # #002676
                    # NOTE(review): originally labelled #FDB515, but these
                    # values correspond to roughly #C4820E.
                    "red_rgb": (0.769, 0.510, 0.055),
                    "blend_ratio": 0.7,
                    "max_frames": 100,
                    "traj_path": "480p_bear",
                    "mask_folder": "bear",
                },
                daemon=True,
            ).start()

            # Both the Host header and the session hash come from the client,
            # so escape them before placing them inside an HTML attribute.
            viewer_url = html.escape(
                f"{protocol}://{host}/viser/{request.session_hash}/", quote=True
            )

            return f"""
            <iframe
                src="{viewer_url}"
                width="100%"
                height="500px"
                frameborder="0"
                style="display: block;"
                allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture"
                loading="lazy"
            ></iframe>
            """, "**System Status:** Visualization loaded successfully."

        @demo.unload
        def stop(request: gr.Request):
            """Stop the viser server that belongs to the session being unloaded."""
            assert request.session_hash is not None
            viser_manager.stop_server(request.session_hash)

    gr.mount_gradio_app(app, demo, "/")
    uvicorn.run(app, host="0.0.0.0", port=7860)


if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # git+https://github.com/nerfstudio-project/viser.git
2
+ viser>=0.2.23
3
+ gradio==5.23.1
4
+ fastapi==0.115.11
5
+ uvicorn==0.34.0
6
+ httpx==0.27.2
7
+ websockets==15.0.1
8
+ tyro==0.4.1
9
+ numpy>=1.20.0
10
+ tqdm>=4.62.0
11
+ opencv-python>=4.5.0
12
+ imageio>=2.25.0
13
+ matplotlib>=3.5.0
14
+ pyliblzfse>=0.1.0
15
+ psutil>=5.9.0
vis_st4rtrack.py ADDED
@@ -0,0 +1,780 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Record3D visualizer
2
+
3
+ Parse and stream record3d captures. To get the demo data, see `./assets/download_record3d_dance.sh`.
4
+ """
5
+
6
+ import time
7
+ from pathlib import Path
8
+
9
+ import numpy as onp
10
+ import tyro
11
+ import cv2
12
+ from tqdm.auto import tqdm
13
+
14
+ import viser
15
+ import viser.extras
16
+ import viser.transforms as tf
17
+
18
+ from glob import glob
19
+ import numpy as np
20
+ import imageio.v3 as iio
21
+ import matplotlib.pyplot as plt
22
+ import psutil
23
+
24
def log_memory_usage(message=""):
    """Print this process's resident set size in MB, tagged with ``message``."""
    rss_bytes = psutil.Process().memory_info().rss
    bytes_per_mb = 1024 * 1024  # bytes -> MB
    print(f"Memory usage {message}: {rss_bytes / bytes_per_mb:.2f} MB")
30
+
31
def load_trajectory_data(traj_path="results", use_float16=True, max_frames=None, mask_folder='./train'):
    """Load trajectory, confidence and mask data from disk.

    Args:
        traj_path: Directory containing ``pts3d1_p*.npy`` / ``pts3d2_p*.npy``
            trajectory files and optional ``conf1_p*.npy`` / ``conf2_p*.npy``
            confidence files.
        use_float16: Whether to convert loaded arrays to float16 to save memory.
        max_frames: Maximum number of frames to load (None for all).
        mask_folder: Directory containing ``*.jpg`` mask images, or None to
            skip mask loading (all-True masks are then created for head1).

    Returns:
        A dictionary with the keys ``traj_3d_head1``, ``traj_3d_head2``,
        ``conf_mask_head1``, ``conf_mask_head2``, ``masks``, ``raw_video`` and
        ``loaded``. Entries whose files were not found stay None. ``loaded``
        is True only when at least one trajectory array was loaded.
    """
    log_memory_usage("before loading data")

    data_cache = {
        'traj_3d_head1': None,
        'traj_3d_head2': None,
        'conf_mask_head1': None,
        'conf_mask_head2': None,
        'masks': None,
        'raw_video': None,
        'loaded': False
    }

    # Load masks. A None folder means "no masks" rather than a TypeError on
    # string concatenation (visualize_st4rtrack passes None by default).
    masks_paths = sorted(glob(mask_folder + '/*.jpg')) if mask_folder is not None else []
    masks = None

    if masks_paths:
        masks = np.stack([iio.imread(p) for p in masks_paths], axis=0)
        # Flag zero-valued entries, then keep pixels where more than two
        # channels are zero. Assumes 3-channel images, i.e. True marks pure
        # black pixels -- TODO confirm the mask convention.
        masks = (masks < 1).astype(np.float32)
        masks = masks.sum(axis=-1) > 2
        print(f"Original masks shape: {masks.shape}")
    else:
        print("No masks found. Will create default masks when needed.")

    data_cache['masks'] = masks

    if Path(traj_path).is_dir():
        # Find all trajectory and confidence files, ordered by part index
        traj_3d_paths_head1 = _sorted_part_paths(traj_path, 'pts3d1')
        conf_paths_head1 = _sorted_part_paths(traj_path, 'conf1')
        traj_3d_paths_head2 = _sorted_part_paths(traj_path, 'pts3d2')
        conf_paths_head2 = _sorted_part_paths(traj_path, 'conf2')

        # Limit number of frames if specified
        if max_frames is not None:
            traj_3d_paths_head1 = traj_3d_paths_head1[:max_frames]
            conf_paths_head1 = conf_paths_head1[:max_frames]
            traj_3d_paths_head2 = traj_3d_paths_head2[:max_frames]
            conf_paths_head2 = conf_paths_head2[:max_frames]

        # Process head1
        if traj_3d_paths_head1:
            traj_3d_head1 = _load_stack(traj_3d_paths_head1, use_float16)
            log_memory_usage("after loading head1 data")

            num_frames, h, w, _ = traj_3d_head1.shape

            if masks is None:
                # No mask images: every pixel counts as inside the mask
                masks = np.ones((num_frames, h, w), dtype=bool)
                print(f"Created default masks with shape: {masks.shape}")
                data_cache['masks'] = masks
            else:
                # Resize masks to match trajectory dimensions using nearest
                # neighbor interpolation so they stay binary
                masks_resized = np.zeros((masks.shape[0], h, w), dtype=bool)
                for i in range(masks.shape[0]):
                    masks_resized[i] = cv2.resize(
                        masks[i].astype(np.uint8),
                        (w, h),
                        interpolation=cv2.INTER_NEAREST
                    ).astype(bool)

                print(f"Resized masks shape: {masks_resized.shape}")
                data_cache['masks'] = masks_resized

            # Flatten the spatial dimensions: (frames, h*w, 6)
            data_cache['traj_3d_head1'] = traj_3d_head1.reshape(num_frames, -1, 6)

            if conf_paths_head1:
                conf_head1 = _load_stack(conf_paths_head1, use_float16)
                # Average the confidence over frames, then repeat it so every
                # frame shares the same per-point confidence
                conf_head1 = conf_head1.reshape(conf_head1.shape[0], -1).mean(axis=0)
                conf_head1 = np.tile(conf_head1, (num_frames, 1))
                # Convert to float32 before calculating percentile to avoid overflow
                conf_thre = np.percentile(conf_head1.astype(np.float32), 1)  # Default percentile
                data_cache['conf_mask_head1'] = conf_head1 > conf_thre

        # Process head2
        if traj_3d_paths_head2:
            traj_3d_head2 = _load_stack(traj_3d_paths_head2, use_float16)
            log_memory_usage("after loading head2 data")

            # Store raw video data: the last three channels of every pixel
            data_cache['raw_video'] = traj_3d_head2[:, :, :, 3:6]  # [num_frames, h, w, 3]

            data_cache['traj_3d_head2'] = traj_3d_head2.reshape(traj_3d_head2.shape[0], -1, 6)

            if conf_paths_head2:
                conf_head2 = _load_stack(conf_paths_head2, use_float16)
                conf_head2 = conf_head2.reshape(conf_head2.shape[0], -1)
                # Threshold at the 1st percentile of conf_head2, per frame
                conf_thre = np.percentile(conf_head2.astype(np.float32), 1, axis=1)
                data_cache['conf_mask_head2'] = conf_head2 > conf_thre[:, None]
    else:
        print(f"Trajectory directory not found: {traj_path}")

    # Only report success when some trajectory data is actually present, so
    # callers can fall back to another source instead of using an empty cache.
    data_cache['loaded'] = (
        data_cache['traj_3d_head1'] is not None
        or data_cache['traj_3d_head2'] is not None
    )
    log_memory_usage("after loading all data")
    return data_cache


def _sorted_part_paths(traj_path, prefix):
    """Return the ``<traj_path>/<prefix>_p<N>.npy`` paths ordered by the integer N."""
    return sorted(
        glob(traj_path + '/' + prefix + '_p*.npy'),
        key=lambda x: int(x.split('_p')[-1].split('.')[0]),
    )


def _load_stack(paths, use_float16):
    """Load each ``.npy`` file in ``paths`` and stack the arrays along a new axis 0."""
    arrays = [onp.load(p) for p in paths]
    if use_float16:
        arrays = [a.astype(onp.float16) for a in arrays]
    return onp.stack(arrays, axis=0)
162
+
163
+ def visualize_st4rtrack(
164
+ traj_path: str = "results",
165
+ up_dir: str = "-z", # should be +z or -z
166
+ max_frames: int = 32,
167
+ share: bool = False,
168
+ point_size: float = 0.0015,
169
+ downsample_factor: int = 3,
170
+ num_traj_points: int = 100,
171
+ conf_thre_percentile: float = 1,
172
+ traj_end_frame: int = 100,
173
+ traj_start_frame: int = 0,
174
+ traj_line_width: float = 3.,
175
+ fixed_length_traj: int = 15,
176
+ server: viser.ViserServer = None,
177
+ use_float16: bool = True,
178
+ preloaded_data: dict = None, # Add this parameter to accept preloaded data
179
+ color_code: str = "jet",
180
+ # Updated hex colors: #002676 for blue and #FDB515 for red/gold
181
+ blue_rgb: tuple[float, float, float] = (0.0, 0.149, 0.463), # #002676
182
+ red_rgb: tuple[float, float, float] = (0.769, 0.510, 0.055), # #FDB515
183
+ blend_ratio: float = 0.7,
184
+ mask_folder: str = None,
185
+ mid_anchor: bool = False,
186
+ video_width: int = 320, # Video display width
187
+ video_height: int = 180, # Video display height
188
+ ) -> None:
189
+ log_memory_usage("at start of visualization")
190
+
191
+ if server is None:
192
+ server = viser.ViserServer()
193
+ if share:
194
+ server.request_share_url()
195
+
196
+ @server.on_client_connect
197
+ def _(client: viser.ClientHandle) -> None:
198
+ client.camera.position = (1e-3, 0.6, -0.08)
199
+ client.camera.look_at = (0, 0, 0)
200
+
201
+ # Configure the GUI panel size and layout
202
+ server.gui.configure_theme(
203
+ control_layout="collapsible",
204
+ control_width="small",
205
+ dark_mode=False,
206
+ show_logo=False,
207
+ show_share_button=True
208
+ )
209
+
210
+ # Add video preview to the GUI panel - placed at the top
211
+ video_preview = server.gui.add_image(
212
+ np.zeros((video_height, video_width, 3), dtype=np.uint8), # Initial blank image
213
+ format="jpeg"
214
+ )
215
+
216
+ # Use preloaded data if available
217
+ if preloaded_data and preloaded_data.get('loaded', False):
218
+ traj_3d_head1 = preloaded_data.get('traj_3d_head1')
219
+ traj_3d_head2 = preloaded_data.get('traj_3d_head2')
220
+ conf_mask_head1 = preloaded_data.get('conf_mask_head1')
221
+ conf_mask_head2 = preloaded_data.get('conf_mask_head2')
222
+ masks = preloaded_data.get('masks')
223
+ raw_video = preloaded_data.get('raw_video')
224
+ print("Using preloaded data!")
225
+ else:
226
+ # Load data using the shared function
227
+ print("No preloaded data available, loading from files...")
228
+ data = load_trajectory_data(traj_path, use_float16, max_frames, mask_folder)
229
+ traj_3d_head1 = data.get('traj_3d_head1')
230
+ traj_3d_head2 = data.get('traj_3d_head2')
231
+ conf_mask_head1 = data.get('conf_mask_head1')
232
+ conf_mask_head2 = data.get('conf_mask_head2')
233
+ masks = data.get('masks')
234
+ raw_video = data.get('raw_video')
235
+
236
+ def process_video_frame(frame_idx):
237
+ if raw_video is None:
238
+ return np.zeros((video_height, video_width, 3), dtype=np.uint8)
239
+
240
+ # Get the original frame
241
+ raw_frame = raw_video[frame_idx]
242
+
243
+ # Adjust value range to 0-255
244
+ if raw_frame.max() <= 1.0:
245
+ frame = (raw_frame * 255).astype(np.uint8)
246
+ else:
247
+ frame = raw_frame.astype(np.uint8)
248
+
249
+ # Resize to fit the preview window
250
+ h, w = frame.shape[:2]
251
+ # Calculate size while maintaining aspect ratio
252
+ if h/w > video_height/video_width: # Height limited
253
+ new_h = video_height
254
+ new_w = int(w * (new_h / h))
255
+ else: # Width limited
256
+ new_w = video_width
257
+ new_h = int(h * (new_w / w))
258
+
259
+ # Resize
260
+ resized_frame = cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_AREA)
261
+
262
+ # Create a black background
263
+ display_frame = np.zeros((video_height, video_width, 3), dtype=np.uint8)
264
+
265
+ # Place the resized frame in the center
266
+ y_offset = (video_height - new_h) // 2
267
+ x_offset = (video_width - new_w) // 2
268
+ display_frame[y_offset:y_offset+new_h, x_offset:x_offset+new_w] = resized_frame
269
+
270
+ return display_frame
271
+
272
+ server.scene.set_up_direction(up_dir)
273
+ print("Setting up visualization!")
274
+
275
+ # Add visualization controls
276
+ with server.gui.add_folder("Visualization"):
277
+ gui_show_head1 = server.gui.add_checkbox("Tracking Points", True)
278
+ gui_show_head2 = server.gui.add_checkbox("Recon Points", True)
279
+ gui_show_trajectories = server.gui.add_checkbox("Trajectories", True)
280
+ gui_use_color_tint = server.gui.add_checkbox("Use Color Tint", True)
281
+
282
+ # Process and center point clouds
283
+ center_point = None
284
+ if traj_3d_head1 is not None:
285
+ xyz_head1 = traj_3d_head1[:, :, :3]
286
+ rgb_head1 = traj_3d_head1[:, :, 3:6]
287
+ if center_point is None:
288
+ center_point = onp.mean(xyz_head1, axis=(0, 1), keepdims=True)
289
+ xyz_head1 -= center_point
290
+ if rgb_head1.sum(axis=(-1)).max() > 125:
291
+ rgb_head1 /= 255.0
292
+
293
+ if traj_3d_head2 is not None:
294
+ xyz_head2 = traj_3d_head2[:, :, :3]
295
+ rgb_head2 = traj_3d_head2[:, :, 3:6]
296
+ if center_point is None:
297
+ center_point = onp.mean(xyz_head2, axis=(0, 1), keepdims=True)
298
+ xyz_head2 -= center_point
299
+ if rgb_head2.sum(axis=(-1)).max() > 125:
300
+ rgb_head2 /= 255.0
301
+
302
+ # Determine number of frames
303
+ F = max(
304
+ traj_3d_head1.shape[0] if traj_3d_head1 is not None else 0,
305
+ traj_3d_head2.shape[0] if traj_3d_head2 is not None else 0
306
+ )
307
+ num_frames = min(max_frames, F)
308
+ traj_end_frame = min(traj_end_frame, num_frames)
309
+ print(f"Number of frames: {num_frames}")
310
+ xyz_head1 = xyz_head1[:num_frames]
311
+ xyz_head2 = xyz_head2[:num_frames]
312
+ rgb_head1 = rgb_head1[:num_frames]
313
+ rgb_head2 = rgb_head2[:num_frames]
314
+
315
+ # Add playback UI.
316
+ with server.gui.add_folder("Playback"):
317
+ gui_timestep = server.gui.add_slider(
318
+ "Timestep",
319
+ min=0,
320
+ max=num_frames - 1,
321
+ step=1,
322
+ initial_value=0,
323
+ disabled=True,
324
+ )
325
+ gui_next_frame = server.gui.add_button("Next Frame", disabled=True)
326
+ gui_prev_frame = server.gui.add_button("Prev Frame", disabled=True)
327
+ gui_playing = server.gui.add_checkbox("Playing", True)
328
+ gui_framerate = server.gui.add_slider(
329
+ "FPS", min=1, max=60, step=0.1, initial_value=20
330
+ )
331
+ gui_framerate_options = server.gui.add_button_group(
332
+ "FPS options", ("10", "20", "30")
333
+ )
334
+ gui_show_all_frames = server.gui.add_checkbox("Show all frames", False)
335
+ gui_stride = server.gui.add_slider(
336
+ "Stride",
337
+ min=1,
338
+ max=num_frames,
339
+ step=1,
340
+ initial_value=5,
341
+ disabled=True, # Initially disabled
342
+ )
343
+
344
+ # Frame step buttons.
345
+ @gui_next_frame.on_click
346
+ def _(_) -> None:
347
+ gui_timestep.value = (gui_timestep.value + 1) % num_frames
348
+
349
+ @gui_prev_frame.on_click
350
+ def _(_) -> None:
351
+ gui_timestep.value = (gui_timestep.value - 1) % num_frames
352
+
353
+ # Disable frame controls when we're playing.
354
+ @gui_playing.on_update
355
+ def _(_) -> None:
356
+ gui_timestep.disabled = gui_playing.value or gui_show_all_frames.value
357
+ gui_next_frame.disabled = gui_playing.value or gui_show_all_frames.value
358
+ gui_prev_frame.disabled = gui_playing.value or gui_show_all_frames.value
359
+
360
+ # Set the framerate when we click one of the options.
361
+ @gui_framerate_options.on_click
362
+ def _(_) -> None:
363
+ gui_framerate.value = int(gui_framerate_options.value)
364
+
365
+ prev_timestep = gui_timestep.value
366
+
367
+ # Toggle frame visibility when the timestep slider changes.
368
+ @gui_timestep.on_update
369
+ def _(_) -> None:
370
+ nonlocal prev_timestep
371
+ current_timestep = gui_timestep.value
372
+ if not gui_show_all_frames.value:
373
+ with server.atomic():
374
+ if gui_show_head1.value:
375
+ frame_nodes_head1[current_timestep].visible = True
376
+ frame_nodes_head1[prev_timestep].visible = False
377
+ if gui_show_head2.value:
378
+ frame_nodes_head2[current_timestep].visible = True
379
+ frame_nodes_head2[prev_timestep].visible = False
380
+ prev_timestep = current_timestep
381
+ server.flush() # Optional!
382
+
383
+ # Show or hide all frames based on the checkbox.
384
+ @gui_show_all_frames.on_update
385
+ def _(_) -> None:
386
+ gui_stride.disabled = not gui_show_all_frames.value # Enable/disable stride slider
387
+ if gui_show_all_frames.value:
388
+ # Show frames with stride
389
+ stride = gui_stride.value
390
+ with server.atomic():
391
+ for i, (node1, node2) in enumerate(zip(frame_nodes_head1, frame_nodes_head2)):
392
+ node1.visible = gui_show_head1.value and (i % stride == 0)
393
+ node2.visible = gui_show_head2.value and (i % stride == 0)
394
+ # Disable playback controls
395
+ gui_playing.disabled = True
396
+ gui_timestep.disabled = True
397
+ gui_next_frame.disabled = True
398
+ gui_prev_frame.disabled = True
399
+ else:
400
+ # Show only the current frame
401
+ current_timestep = gui_timestep.value
402
+ with server.atomic():
403
+ for i, (node1, node2) in enumerate(zip(frame_nodes_head1, frame_nodes_head2)):
404
+ node1.visible = gui_show_head1.value and (i == current_timestep)
405
+ node2.visible = gui_show_head2.value and (i == current_timestep)
406
+ # Re-enable playback controls
407
+ gui_playing.disabled = False
408
+ gui_timestep.disabled = gui_playing.value
409
+ gui_next_frame.disabled = gui_playing.value
410
+ gui_prev_frame.disabled = gui_playing.value
411
+
412
+ # Update frame visibility when the stride changes.
413
+ @gui_stride.on_update
414
+ def _(_) -> None:
415
+ if gui_show_all_frames.value:
416
+ # Update frame visibility based on new stride
417
+ stride = gui_stride.value
418
+ with server.atomic():
419
+ for i, (node1, node2) in enumerate(zip(frame_nodes_head1, frame_nodes_head2)):
420
+ node1.visible = gui_show_head1.value and (i % stride == 0)
421
+ node2.visible = gui_show_head2.value and (i % stride == 0)
422
+
423
+ # Load in frames.
424
+ server.scene.add_frame(
425
+ "/frames",
426
+ wxyz=tf.SO3.exp(onp.array([onp.pi / 2.0, 0.0, 0.0])).wxyz,
427
+ position=(0, 0, 0),
428
+ show_axes=False,
429
+ )
430
+ frame_nodes_head1: list[viser.FrameHandle] = []
431
+ frame_nodes_head2: list[viser.FrameHandle] = []
432
+
433
+ # Extract RGB components for tinting
434
+ blue_r, blue_g, blue_b = blue_rgb
435
+ red_r, red_g, red_b = red_rgb
436
+
437
+ # Create frames for each timestep
438
+ frame_nodes_head1 = []
439
+ frame_nodes_head2 = []
440
+ for i in tqdm(range(num_frames)):
441
+ # Process head1
442
+ if traj_3d_head1 is not None:
443
+ frame_nodes_head1.append(server.scene.add_frame(f"/frames/t{i}/head1", show_axes=False))
444
+ position = xyz_head1[i]
445
+ color = rgb_head1[i]
446
+ if conf_mask_head1 is not None:
447
+ position = position[conf_mask_head1[i]]
448
+ color = color[conf_mask_head1[i]]
449
+
450
+ # Add point cloud for head1 with optional blue tint
451
+ color_head1 = color.copy()
452
+ if gui_use_color_tint.value:
453
+ color_head1 *= blend_ratio
454
+ color_head1[:, 0] = onp.clip(color_head1[:, 0] + blue_r * (1 - blend_ratio), 0, 1) # R
455
+ color_head1[:, 1] = onp.clip(color_head1[:, 1] + blue_g * (1 - blend_ratio), 0, 1) # G
456
+ color_head1[:, 2] = onp.clip(color_head1[:, 2] + blue_b * (1 - blend_ratio), 0, 1) # B
457
+
458
+ server.scene.add_point_cloud(
459
+ name=f"/frames/t{i}/head1/point_cloud",
460
+ points=position[::downsample_factor],
461
+ colors=color_head1[::downsample_factor],
462
+ point_size=point_size,
463
+ point_shape="rounded",
464
+ )
465
+
466
+ # Process head2
467
+ if traj_3d_head2 is not None:
468
+ frame_nodes_head2.append(server.scene.add_frame(f"/frames/t{i}/head2", show_axes=False))
469
+ position = xyz_head2[i]
470
+ color = rgb_head2[i]
471
+ if conf_mask_head2 is not None:
472
+ position = position[conf_mask_head2[i]]
473
+ color = color[conf_mask_head2[i]]
474
+
475
+ # Add point cloud for head2 with optional red tint
476
+ color_head2 = color.copy()
477
+ if gui_use_color_tint.value:
478
+ color_head2 *= blend_ratio
479
+ color_head2[:, 0] = onp.clip(color_head2[:, 0] + red_r * (1 - blend_ratio), 0, 1) # R
480
+ color_head2[:, 1] = onp.clip(color_head2[:, 1] + red_g * (1 - blend_ratio), 0, 1) # G
481
+ color_head2[:, 2] = onp.clip(color_head2[:, 2] + red_b * (1 - blend_ratio), 0, 1) # B
482
+
483
+ server.scene.add_point_cloud(
484
+ name=f"/frames/t{i}/head2/point_cloud",
485
+ points=position[::downsample_factor],
486
+ colors=color_head2[::downsample_factor],
487
+ point_size=point_size,
488
+ point_shape="rounded",
489
+ )
490
+
491
+ # Update visibility based on checkboxes
492
+ @gui_show_head1.on_update
493
+ def _(_) -> None:
494
+ with server.atomic():
495
+ for frame_node in frame_nodes_head1:
496
+ frame_node.visible = gui_show_head1.value and (
497
+ gui_show_all_frames.value
498
+ or (not gui_show_all_frames.value )
499
+ )
500
+
501
+ @gui_show_head2.on_update
502
+ def _(_) -> None:
503
+ with server.atomic():
504
+ for frame_node in frame_nodes_head2:
505
+ frame_node.visible = gui_show_head2.value and (
506
+ gui_show_all_frames.value
507
+ or (not gui_show_all_frames.value )
508
+ )
509
+
510
+ # Initial visibility
511
+ for i, (node1, node2) in enumerate(zip(frame_nodes_head1, frame_nodes_head2)):
512
+ if gui_show_all_frames.value:
513
+ node1.visible = gui_show_head1.value and (i % gui_stride.value == 0)
514
+ node2.visible = gui_show_head2.value and (i % gui_stride.value == 0)
515
+ else:
516
+ node1.visible = gui_show_head1.value and (i == gui_timestep.value)
517
+ node2.visible = gui_show_head2.value and (i == gui_timestep.value)
518
+
519
+ # Process and visualize trajectories for head1
520
+ if traj_3d_head1 is not None:
521
+ # Get points over time
522
+ xyz_head1_centered = xyz_head1.copy()
523
+
524
+ # Select points to visualize
525
+ num_points = xyz_head1.shape[1]
526
+ points_to_visualize = min(num_points, num_traj_points)
527
+
528
+ # Get the mask for the first frame and reshape it to match point cloud dimensions
529
+ if mid_anchor:
530
+ first_frame_mask = masks[num_frames//2].reshape(-1)
531
+ else:
532
+ first_frame_mask = masks[0].reshape(-1) #[#points, h]
533
+
534
+ # Calculate trajectory lengths for each point
535
+ trajectories = xyz_head1_centered[traj_start_frame:traj_end_frame] # Shape: (num_frames, num_points, 3)
536
+ traj_diffs = np.diff(trajectories, axis=0) # Differences between consecutive frames
537
+ traj_lengths = np.sum(np.sqrt(np.sum(traj_diffs**2, axis=-1)), axis=0) # Sum of distances for each point
538
+
539
+ # Get points that are within the mask
540
+ valid_indices = np.where(first_frame_mask)[0]
541
+
542
+ if len(valid_indices) > 0:
543
+ # Calculate average trajectory length for masked points
544
+ masked_traj_lengths = traj_lengths[valid_indices]
545
+ avg_traj_length = np.mean(masked_traj_lengths)
546
+
547
+ if mask_folder is not None:
548
+ # do not filter points by trajectory length
549
+ long_traj_indices = valid_indices
550
+ else:
551
+ # Filter points by trajectory length
552
+ long_traj_indices = valid_indices[masked_traj_lengths >= avg_traj_length]
553
+
554
+ # Randomly sample from the filtered points
555
+ if len(long_traj_indices) > 0:
556
+ # Random sampling without replacement
557
+ selected_indices = np.random.choice(
558
+ len(long_traj_indices),
559
+ min(points_to_visualize, len(long_traj_indices)),
560
+ replace=False
561
+ )
562
+ # Get the actual indices in their original order
563
+ valid_point_indices = long_traj_indices[np.sort(selected_indices)]
564
+ else:
565
+ valid_point_indices = np.array([])
566
+ else:
567
+ valid_point_indices = np.array([])
568
+
569
+ if len(valid_point_indices) > 0:
570
+ # Get trajectories for all valid points
571
+ trajectories = xyz_head1_centered[traj_start_frame:traj_end_frame, valid_point_indices]
572
+ N_point = trajectories.shape[1]
573
+ if color_code == "rainbow":
574
+ point_colors = plt.cm.rainbow(np.linspace(0, 1, N_point))[:, :3]
575
+ elif color_code == "jet":
576
+ point_colors = plt.cm.jet(np.linspace(0, 1, N_point))[:, :3]
577
+ # Modify the loop to handle frames less than fixed_length_traj
578
+ for i in range(traj_end_frame - traj_start_frame):
579
+ # Calculate the actual trajectory length for this frame
580
+ actual_length = min(fixed_length_traj, i + 1)
581
+
582
+ if actual_length > 1: # Need at least 2 points to form a line
583
+ # Get the appropriate slice of trajectory data
584
+ start_idx = max(0, i - actual_length + 1)
585
+ end_idx = i + 1
586
+
587
+ # Create line segments between consecutive frames
588
+ traj_slice = trajectories[start_idx:end_idx]
589
+ line_points = np.stack([traj_slice[:-1], traj_slice[1:]], axis=2)
590
+ line_points = line_points.reshape(-1, 2, 3)
591
+
592
+ # Create corresponding colors
593
+ line_colors = np.tile(point_colors, (actual_length-1, 1))
594
+ line_colors = np.stack([line_colors, line_colors], axis=1)
595
+
596
+ # Add line segments
597
+ server.scene.add_line_segments(
598
+ name=f"/frames/t{i+traj_start_frame}/head1/trajectory",
599
+ points=line_points,
600
+ colors=line_colors,
601
+ line_width=traj_line_width,
602
+ visible=gui_show_trajectories.value
603
+ )
604
+
605
+ # Add trajectory controls functionality
606
+ @gui_show_trajectories.on_update
607
+ def _(_) -> None:
608
+ with server.atomic():
609
+ # Remove all existing trajectories
610
+ for i in range(num_frames):
611
+ try:
612
+ server.scene.remove_by_name(f"/frames/t{i}/head1/trajectory")
613
+ except KeyError:
614
+ pass
615
+
616
+ # Create new trajectories if enabled
617
+ if gui_show_trajectories.value and traj_3d_head1 is not None:
618
+ # Get the mask for the last frame and reshape it
619
+ last_frame_mask = masks[traj_end_frame-1].reshape(-1)
620
+
621
+ # Calculate trajectory lengths
622
+ trajectories = xyz_head1_centered[traj_start_frame:traj_end_frame]
623
+ traj_diffs = np.diff(trajectories, axis=0)
624
+ traj_lengths = np.sum(np.sqrt(np.sum(traj_diffs**2, axis=-1)), axis=0)
625
+
626
+ # Get points that are within the mask
627
+ valid_indices = np.where(last_frame_mask)[0]
628
+
629
+ if len(valid_indices) > 0:
630
+ # Filter by trajectory length
631
+ masked_traj_lengths = traj_lengths[valid_indices]
632
+ avg_traj_length = np.mean(masked_traj_lengths)
633
+ long_traj_indices = valid_indices[masked_traj_lengths >= avg_traj_length]
634
+
635
+ # Randomly sample from the filtered points
636
+ if len(long_traj_indices) > 0:
637
+ # Random sampling without replacement
638
+ selected_indices = np.random.choice(
639
+ len(long_traj_indices),
640
+ min(points_to_visualize, len(long_traj_indices)),
641
+ replace=False
642
+ )
643
+ # Get the actual indices in their original order
644
+ valid_point_indices = long_traj_indices[np.sort(selected_indices)]
645
+ else:
646
+ valid_point_indices = np.array([])
647
+ else:
648
+ valid_point_indices = np.array([])
649
+
650
+ if len(valid_point_indices) > 0:
651
+ # Get trajectories for all valid points
652
+ trajectories = xyz_head1_centered[traj_start_frame:traj_end_frame, valid_point_indices]
653
+ N_point = trajectories.shape[1]
654
+
655
+ if color_code == "rainbow":
656
+ point_colors = plt.cm.rainbow(np.linspace(0, 1, N_point))[:, :3]
657
+ elif color_code == "jet":
658
+ point_colors = plt.cm.jet(np.linspace(0, 1, N_point))[:, :3]
659
+
660
+ # Modify the loop to handle frames less than fixed_length_traj
661
+ for i in range(traj_end_frame - traj_start_frame):
662
+ # Calculate the actual trajectory length for this frame
663
+ actual_length = min(fixed_length_traj, i + 1)
664
+
665
+ if actual_length > 1: # Need at least 2 points to form a line
666
+ # Get the appropriate slice of trajectory data
667
+ start_idx = max(0, i - actual_length + 1)
668
+ end_idx = i + 1
669
+
670
+ # Create line segments between consecutive frames
671
+ traj_slice = trajectories[start_idx:end_idx]
672
+ line_points = np.stack([traj_slice[:-1], traj_slice[1:]], axis=2)
673
+ line_points = line_points.reshape(-1, 2, 3)
674
+
675
+ # Create corresponding colors
676
+ line_colors = np.tile(point_colors, (actual_length-1, 1))
677
+ line_colors = np.stack([line_colors, line_colors], axis=1)
678
+
679
+ # Add line segments
680
+ server.scene.add_line_segments(
681
+ name=f"/frames/t{i+traj_start_frame}/head1/trajectory",
682
+ points=line_points,
683
+ colors=line_colors,
684
+ line_width=traj_line_width,
685
+ visible=True
686
+ )
687
+
688
+ # Update color tinting when the checkbox changes
689
+ @gui_use_color_tint.on_update
690
+ def _(_) -> None:
691
+ with server.atomic():
692
+ for i in range(num_frames):
693
+ # Update head1 point cloud
694
+ if traj_3d_head1 is not None:
695
+ position = xyz_head1[i]
696
+ color = rgb_head1[i]
697
+ if conf_mask_head1 is not None:
698
+ position = position[conf_mask_head1[i]]
699
+ color = color[conf_mask_head1[i]]
700
+
701
+ color_head1 = color.copy()
702
+ if gui_use_color_tint.value:
703
+ color_head1 *= blend_ratio
704
+ color_head1[:, 0] = onp.clip(color_head1[:, 0] + blue_r * (1 - blend_ratio), 0, 1) # R
705
+ color_head1[:, 1] = onp.clip(color_head1[:, 1] + blue_g * (1 - blend_ratio), 0, 1) # G
706
+ color_head1[:, 2] = onp.clip(color_head1[:, 2] + blue_b * (1 - blend_ratio), 0, 1) # B
707
+
708
+ server.scene.remove_by_name(f"/frames/t{i}/head1/point_cloud")
709
+ server.scene.add_point_cloud(
710
+ name=f"/frames/t{i}/head1/point_cloud",
711
+ points=position[::downsample_factor],
712
+ colors=color_head1[::downsample_factor],
713
+ point_size=point_size,
714
+ point_shape="rounded",
715
+ )
716
+
717
+ # Update head2 point cloud
718
+ if traj_3d_head2 is not None:
719
+ position = xyz_head2[i]
720
+ color = rgb_head2[i]
721
+ if conf_mask_head2 is not None:
722
+ position = position[conf_mask_head2[i]]
723
+ color = color[conf_mask_head2[i]]
724
+
725
+ color_head2 = color.copy()
726
+ if gui_use_color_tint.value:
727
+ color_head2 *= blend_ratio
728
+ color_head2[:, 0] = onp.clip(color_head2[:, 0] + red_r * (1 - blend_ratio), 0, 1) # R
729
+ color_head2[:, 1] = onp.clip(color_head2[:, 1] + red_g * (1 - blend_ratio), 0, 1) # G
730
+ color_head2[:, 2] = onp.clip(color_head2[:, 2] + red_b * (1 - blend_ratio), 0, 1) # B
731
+
732
+ server.scene.remove_by_name(f"/frames/t{i}/head2/point_cloud")
733
+ server.scene.add_point_cloud(
734
+ name=f"/frames/t{i}/head2/point_cloud",
735
+ points=position[::downsample_factor],
736
+ colors=color_head2[::downsample_factor],
737
+ point_size=point_size,
738
+ point_shape="rounded",
739
+ )
740
+
741
+ # Initialize video preview
742
+ if raw_video is not None:
743
+ video_preview.image = process_video_frame(0)
744
+
745
+ # Update video preview when timestep changes
746
+ @gui_timestep.on_update
747
+ def _(_) -> None:
748
+ current_timestep = gui_timestep.value
749
+ if raw_video is not None:
750
+ video_preview.image = process_video_frame(current_timestep)
751
+
752
+ # Playback update loop.
753
+ log_memory_usage("before starting playback loop")
754
+
755
+ prev_timestep = gui_timestep.value
756
+ while True:
757
+ current_timestep = gui_timestep.value
758
+
759
+ # If timestep changes, update frame visibility
760
+ if current_timestep != prev_timestep:
761
+ with server.atomic():
762
+ # ... existing code ...
763
+
764
+ # Update video preview
765
+ if raw_video is not None:
766
+ video_preview.image = process_video_frame(current_timestep)
767
+
768
+ # Update in playback mode
769
+ if gui_playing.value and not gui_show_all_frames.value:
770
+ gui_timestep.value = (gui_timestep.value + 1) % num_frames
771
+
772
+ # Update video preview in playback mode
773
+ if raw_video is not None:
774
+ video_preview.image = process_video_frame(gui_timestep.value)
775
+
776
+ time.sleep(1.0 / gui_framerate.value)
777
+
778
+
779
+ if __name__ == "__main__":
780
+ tyro.cli(visualize_st4rtrack)
viser_proxy_manager.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+
3
+ import httpx
4
+ import viser
5
+ import websockets
6
+ from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
7
+ from fastapi.responses import Response
8
+
9
+
10
class ViserProxyManager:
    """Manages Viser server instances for Gradio applications.

    This class handles the creation, retrieval, and cleanup of Viser server instances,
    as well as proxying HTTP and WebSocket requests to the appropriate Viser server.

    Args:
        app: The FastAPI application to which the proxy routes will be added.
        min_local_port: Minimum local port number to use for Viser servers. Defaults to 8000.
            These ports are used only for internal communication and don't need to be publicly exposed.
        max_local_port: Maximum local port number to use for Viser servers. Defaults to 9000.
            These ports are used only for internal communication and don't need to be publicly exposed.
        max_message_size: Maximum WebSocket message size in bytes. Defaults to 300MB.
    """

    def __init__(
        self,
        app: FastAPI,
        min_local_port: int = 8000,
        max_local_port: int = 9000,
        max_message_size: int = 300 * 1024 * 1024,  # 300MB default
    ) -> None:
        self._min_port = min_local_port
        self._max_port = max_local_port
        self._max_message_size = max_message_size
        # Maps the caller-supplied server ID to its running Viser server.
        self._server_from_session_hash: dict[str, viser.ViserServer] = {}
        # Last port handed out; start_server() resumes its search after it.
        self._last_port = self._min_port - 1

        self._register_http_proxy(app)
        self._register_websocket_proxy(app)

    def _register_http_proxy(self, app: FastAPI) -> None:
        """Register the GET route that forwards HTTP requests to a Viser server."""

        @app.get("/viser/{server_id}/{proxy_path:path}")
        async def proxy(request: Request, server_id: str, proxy_path: str):
            """Proxy HTTP requests to the appropriate Viser server."""
            server = self._server_from_session_hash.get(server_id)
            if server is None:
                return Response(content="Server not found", status_code=404)

            # Build the target URL on the loopback interface, keeping the query string.
            path_suffix = f"/{proxy_path}" if proxy_path else "/"
            target_url = f"http://127.0.0.1:{server.get_port()}{path_suffix}"
            if request.url.query:
                target_url += f"?{request.url.query}"

            # Forward the original headers, but remove any problematic ones.
            headers = dict(request.headers)
            headers.pop("host", None)  # Remove host header to avoid conflicts
            headers["accept-encoding"] = "identity"  # Disable compression

            async with httpx.AsyncClient() as client:
                proxied_req = client.build_request(
                    method=request.method,
                    url=target_url,
                    headers=headers,
                    content=await request.body(),
                )
                proxied_resp = await client.send(proxied_req, stream=True)

                # Read the whole upstream body and relay it together with the
                # upstream status code and headers.
                content = await proxied_resp.aread()
                return Response(
                    content=content,
                    status_code=proxied_resp.status_code,
                    headers=dict(proxied_resp.headers),
                )

    def _register_websocket_proxy(self, app: FastAPI) -> None:
        """Register the WebSocket route that relays frames to and from a Viser server."""

        @app.websocket("/viser/{server_id}")
        async def websocket_proxy(websocket: WebSocket, server_id: str):
            """Proxy WebSocket connections to the appropriate Viser server."""
            try:
                await websocket.accept()

                server = self._server_from_session_hash.get(server_id)
                if server is None:
                    await websocket.close(code=1008, reason="Not Found")
                    return

                target_ws_url = f"ws://127.0.0.1:{server.get_port()}"

                try:
                    # Connect to the target WebSocket with increased message size and timeout
                    async with websockets.connect(
                        target_ws_url,
                        max_size=self._max_message_size,
                        ping_interval=30,  # Send ping every 30 seconds
                        ping_timeout=10,  # Wait 10 seconds for pong response
                        close_timeout=5,  # Wait 5 seconds for close handshake
                    ) as ws_target:

                        async def forward_to_target() -> None:
                            """Forward messages from the client to the target WebSocket."""
                            try:
                                while True:
                                    data = await websocket.receive_bytes()
                                    await ws_target.send(data, text=False)
                            except WebSocketDisconnect:
                                try:
                                    await ws_target.close()
                                except RuntimeError:
                                    pass

                        async def forward_from_target() -> None:
                            """Forward messages from the target WebSocket to the client."""
                            try:
                                while True:
                                    data = await ws_target.recv(decode=False)
                                    await websocket.send_bytes(data)
                            except websockets.exceptions.ConnectionClosed:
                                try:
                                    await websocket.close()
                                except RuntimeError:
                                    pass

                        # Run both forwarding directions concurrently.
                        tasks = [
                            asyncio.create_task(forward_to_target()),
                            asyncio.create_task(forward_from_target()),
                        ]

                        # Wait for either task to complete (which means a connection was closed)
                        done, pending = await asyncio.wait(
                            tasks,
                            return_when=asyncio.FIRST_COMPLETED,
                        )

                        # Cancel the remaining direction and wait for the cancellation
                        # to finish before the upstream connection is torn down.
                        for task in pending:
                            task.cancel()
                        await asyncio.gather(*pending, return_exceptions=True)

                        # Retrieve failures from the finished direction so they are
                        # reported here instead of being dropped with the task.
                        for task in done:
                            if task.cancelled():
                                continue
                            error = task.exception()
                            if error is not None:
                                print(f"WebSocket proxy error: {error}")

                except websockets.exceptions.ConnectionClosedError as e:
                    print(f"WebSocket connection closed with error: {e}")
                    await websocket.close(code=1011, reason="Connection to target closed")

            except Exception as e:
                print(f"WebSocket proxy error: {e}")
                try:
                    await websocket.close(code=1011, reason=str(e)[:120])  # Limit reason length
                except Exception:
                    pass  # Already closed

    def start_server(self, server_id: str) -> viser.ViserServer:
        """Start a new Viser server and associate it with the given server ID.

        Finds an available port within the configured min_local_port and max_local_port range.
        These ports are used only for internal communication and don't need to be publicly exposed.

        Args:
            server_id: The unique identifier to associate with the new server.

        Returns:
            The newly created Viser server instance.

        Raises:
            RuntimeError: If no free ports are available in the configured range.
        """
        import socket

        port_range_size = self._max_port - self._min_port + 1
        if port_range_size <= 0:
            # An empty range would otherwise surface as ZeroDivisionError below.
            raise RuntimeError(
                f"No available local ports in range {self._min_port}-{self._max_port}"
            )

        # Start searching from the last port + 1 (with wraparound).
        first_offset = (self._last_port + 1 - self._min_port) % port_range_size

        # Try each port once.
        for step in range(port_range_size):
            port = self._min_port + (first_offset + step) % port_range_size
            try:
                # Check if port is available by attempting to bind to it. The probe
                # socket is closed again before the server is created.
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    s.bind(("127.0.0.1", port))
                server = viser.ViserServer(port=port)
            except OSError:
                # Port is in use, try the next one
                continue

            self._server_from_session_hash[server_id] = server
            self._last_port = port
            return server

        # If we get here, no ports were available
        raise RuntimeError(
            f"No available local ports in range {self._min_port}-{self._max_port}"
        )

    def get_server(self, server_id: str) -> viser.ViserServer:
        """Retrieve a Viser server instance by its ID.

        Args:
            server_id: The unique identifier of the server to retrieve.

        Returns:
            The Viser server instance associated with the given ID.

        Raises:
            KeyError: If no server is registered under ``server_id``.
        """
        return self._server_from_session_hash[server_id]

    def stop_server(self, server_id: str) -> None:
        """Stop a Viser server and remove it from the manager.

        Args:
            server_id: The unique identifier of the server to stop.

        Raises:
            KeyError: If no server is registered under ``server_id``.
        """
        self._server_from_session_hash.pop(server_id).stop()