Spaces:
Sleeping
Sleeping
code
Browse files- Dockerfile +22 -0
- app.py +108 -0
- requirements.txt +15 -0
- vis_st4rtrack.py +780 -0
- viser_proxy_manager.py +223 -0
Dockerfile
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
FROM python:3.12-slim

WORKDIR /app

# Install system dependencies for OpenCV and build tools.
# NOTE: libgl1-mesa-glx was removed in Debian bookworm (the base of
# python:3.12-slim); libgl1 is the package that now provides the GL
# runtime OpenCV needs. --no-install-recommends keeps the image small.
RUN apt-get update && apt-get install -y --no-install-recommends \
    libgl1 \
    libglib2.0-0 \
    build-essential \
    git \
    && rm -rf /var/lib/apt/lists/*

# Install Python dependencies before copying the application so this
# (slow) layer is cached when only application code changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

# Make port 7860 available to the world outside the container
EXPOSE 7860

# Command to run when the container starts
CMD ["python", "app.py"]
app.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import threading
|
| 3 |
+
import psutil
|
| 4 |
+
import fastapi
|
| 5 |
+
import gradio as gr
|
| 6 |
+
import uvicorn
|
| 7 |
+
|
| 8 |
+
from viser_proxy_manager import ViserProxyManager
|
| 9 |
+
from vis_st4rtrack import visualize_st4rtrack, load_trajectory_data, log_memory_usage
|
| 10 |
+
|
| 11 |
+
# Global cache for loaded data
|
| 12 |
+
global_data_cache = None
|
| 13 |
+
|
| 14 |
+
def check_ram_usage(threshold_percent=90):
|
| 15 |
+
"""Check if RAM usage is above the threshold.
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
threshold_percent: Maximum RAM usage percentage allowed
|
| 19 |
+
|
| 20 |
+
Returns:
|
| 21 |
+
bool: True if RAM usage is below threshold, False otherwise
|
| 22 |
+
"""
|
| 23 |
+
ram_percent = psutil.virtual_memory().percent
|
| 24 |
+
print(f"Current RAM usage: {ram_percent}%")
|
| 25 |
+
return ram_percent < threshold_percent
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def main() -> None:
|
| 29 |
+
# Load data once at startup using the function from vis_st4rtrack.py
|
| 30 |
+
global global_data_cache
|
| 31 |
+
global_data_cache = load_trajectory_data(use_float16=True, max_frames=32)
|
| 32 |
+
|
| 33 |
+
app = fastapi.FastAPI()
|
| 34 |
+
viser_manager = ViserProxyManager(app)
|
| 35 |
+
|
| 36 |
+
# Create a Gradio interface with title, iframe, and buttons
|
| 37 |
+
with gr.Blocks(title="Viser Viewer") as demo:
|
| 38 |
+
# Add the iframe with a border
|
| 39 |
+
iframe_html = gr.HTML("")
|
| 40 |
+
status_text = gr.Markdown("") # Add status text component
|
| 41 |
+
|
| 42 |
+
@demo.load(outputs=[iframe_html, status_text])
|
| 43 |
+
def start_server(request: gr.Request):
|
| 44 |
+
assert request.session_hash is not None
|
| 45 |
+
|
| 46 |
+
# Check RAM usage before starting visualization
|
| 47 |
+
if not check_ram_usage(threshold_percent=100):
|
| 48 |
+
return """
|
| 49 |
+
<div style="text-align: center; padding: 20px; background-color: #ffeeee; border-radius: 5px;">
|
| 50 |
+
<h2>⚠️ Server is currently under high load</h2>
|
| 51 |
+
<p>Please try again later when resources are available.</p>
|
| 52 |
+
</div>
|
| 53 |
+
""", "**System Status:** High memory usage detected. Visualization not loaded to prevent server overload."
|
| 54 |
+
|
| 55 |
+
viser_manager.start_server(request.session_hash)
|
| 56 |
+
|
| 57 |
+
# Use the request's base URL if available
|
| 58 |
+
host = request.headers["host"]
|
| 59 |
+
|
| 60 |
+
# Determine protocol (use HTTPS for HuggingFace Spaces or other secure environments)
|
| 61 |
+
protocol = (
|
| 62 |
+
"https"
|
| 63 |
+
if request.headers.get("x-forwarded-proto") == "https"
|
| 64 |
+
else "http"
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
# Add visualization in a separate thread
|
| 68 |
+
server = viser_manager.get_server(request.session_hash)
|
| 69 |
+
threading.Thread(
|
| 70 |
+
target=visualize_st4rtrack,
|
| 71 |
+
kwargs={
|
| 72 |
+
"server": server,
|
| 73 |
+
"use_float16": True,
|
| 74 |
+
"preloaded_data": global_data_cache, # Pass the preloaded data
|
| 75 |
+
"color_code": "jet",
|
| 76 |
+
"blue_rgb": (0.0, 0.149, 0.463), # #002676
|
| 77 |
+
"red_rgb": (0.769, 0.510, 0.055), # #FDB515
|
| 78 |
+
"blend_ratio": 0.7,
|
| 79 |
+
"max_frames": 100,
|
| 80 |
+
"traj_path": "480p_bear",
|
| 81 |
+
"mask_folder": "bear",
|
| 82 |
+
},
|
| 83 |
+
daemon=True
|
| 84 |
+
).start()
|
| 85 |
+
|
| 86 |
+
return f"""
|
| 87 |
+
<iframe
|
| 88 |
+
src="{protocol}://{host}/viser/{request.session_hash}/"
|
| 89 |
+
width="100%"
|
| 90 |
+
height="500px"
|
| 91 |
+
frameborder="0"
|
| 92 |
+
style="display: block;"
|
| 93 |
+
allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture"
|
| 94 |
+
loading="lazy"
|
| 95 |
+
></iframe>
|
| 96 |
+
""", "**System Status:** Visualization loaded successfully."
|
| 97 |
+
|
| 98 |
+
@demo.unload
|
| 99 |
+
def stop(request: gr.Request):
|
| 100 |
+
assert request.session_hash is not None
|
| 101 |
+
viser_manager.stop_server(request.session_hash)
|
| 102 |
+
|
| 103 |
+
gr.mount_gradio_app(app, demo, "/")
|
| 104 |
+
uvicorn.run(app, host="0.0.0.0", port=7860)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
if __name__ == "__main__":
|
| 108 |
+
main()
|
requirements.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# git+https://github.com/nerfstudio-project/viser.git
|
| 2 |
+
viser>=0.2.23
|
| 3 |
+
gradio==5.23.1
|
| 4 |
+
fastapi==0.115.11
|
| 5 |
+
uvicorn==0.34.0
|
| 6 |
+
httpx==0.27.2
|
| 7 |
+
websockets==15.0.1
|
| 8 |
+
tyro==0.4.1
|
| 9 |
+
numpy>=1.20.0
|
| 10 |
+
tqdm>=4.62.0
|
| 11 |
+
opencv-python>=4.5.0
|
| 12 |
+
imageio>=2.25.0
|
| 13 |
+
matplotlib>=3.5.0
|
| 14 |
+
pyliblzfse>=0.1.0
|
| 15 |
+
psutil>=5.9.0
|
vis_st4rtrack.py
ADDED
|
@@ -0,0 +1,780 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Record3D visualizer
|
| 2 |
+
|
| 3 |
+
Parse and stream record3d captures. To get the demo data, see `./assets/download_record3d_dance.sh`.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import time
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import numpy as onp
|
| 10 |
+
import tyro
|
| 11 |
+
import cv2
|
| 12 |
+
from tqdm.auto import tqdm
|
| 13 |
+
|
| 14 |
+
import viser
|
| 15 |
+
import viser.extras
|
| 16 |
+
import viser.transforms as tf
|
| 17 |
+
|
| 18 |
+
from glob import glob
|
| 19 |
+
import numpy as np
|
| 20 |
+
import imageio.v3 as iio
|
| 21 |
+
import matplotlib.pyplot as plt
|
| 22 |
+
import psutil
|
| 23 |
+
|
| 24 |
+
def log_memory_usage(message=""):
|
| 25 |
+
"""Log current memory usage with an optional message."""
|
| 26 |
+
process = psutil.Process()
|
| 27 |
+
memory_info = process.memory_info()
|
| 28 |
+
memory_mb = memory_info.rss / (1024 * 1024) # Convert to MB
|
| 29 |
+
print(f"Memory usage {message}: {memory_mb:.2f} MB")
|
| 30 |
+
|
| 31 |
+
def load_trajectory_data(traj_path="results", use_float16=True, max_frames=None, mask_folder='./train'):
|
| 32 |
+
"""Load trajectory data from files.
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
traj_path: Path to the directory containing trajectory data
|
| 36 |
+
use_float16: Whether to convert data to float16 to save memory
|
| 37 |
+
max_frames: Maximum number of frames to load (None for all)
|
| 38 |
+
mask_folder: Path to the directory containing mask images
|
| 39 |
+
|
| 40 |
+
Returns:
|
| 41 |
+
A dictionary containing loaded data
|
| 42 |
+
"""
|
| 43 |
+
log_memory_usage("before loading data")
|
| 44 |
+
|
| 45 |
+
data_cache = {
|
| 46 |
+
'traj_3d_head1': None,
|
| 47 |
+
'traj_3d_head2': None,
|
| 48 |
+
'conf_mask_head1': None,
|
| 49 |
+
'conf_mask_head2': None,
|
| 50 |
+
'masks': None,
|
| 51 |
+
'raw_video': None,
|
| 52 |
+
'loaded': False
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
# Load masks
|
| 56 |
+
masks_paths = sorted(glob(mask_folder + '/*.jpg'))
|
| 57 |
+
masks = None
|
| 58 |
+
|
| 59 |
+
if masks_paths:
|
| 60 |
+
masks = [iio.imread(p) for p in masks_paths]
|
| 61 |
+
masks = np.stack(masks, axis=0)
|
| 62 |
+
# Convert masks to binary (0 or 1)
|
| 63 |
+
masks = (masks < 1).astype(np.float32)
|
| 64 |
+
masks = masks.sum(axis=-1) > 2 # Combine all channels, True where any channel was 1
|
| 65 |
+
print(f"Original masks shape: {masks.shape}")
|
| 66 |
+
else:
|
| 67 |
+
print("No masks found. Will create default masks when needed.")
|
| 68 |
+
|
| 69 |
+
data_cache['masks'] = masks
|
| 70 |
+
|
| 71 |
+
if Path(traj_path).is_dir():
|
| 72 |
+
# Find all trajectory files
|
| 73 |
+
traj_3d_paths_head1 = sorted(glob(traj_path + '/pts3d1_p*.npy'),
|
| 74 |
+
key=lambda x: int(x.split('_p')[-1].split('.')[0]))
|
| 75 |
+
conf_paths_head1 = sorted(glob(traj_path + '/conf1_p*.npy'),
|
| 76 |
+
key=lambda x: int(x.split('_p')[-1].split('.')[0]))
|
| 77 |
+
|
| 78 |
+
traj_3d_paths_head2 = sorted(glob(traj_path + '/pts3d2_p*.npy'),
|
| 79 |
+
key=lambda x: int(x.split('_p')[-1].split('.')[0]))
|
| 80 |
+
conf_paths_head2 = sorted(glob(traj_path + '/conf2_p*.npy'),
|
| 81 |
+
key=lambda x: int(x.split('_p')[-1].split('.')[0]))
|
| 82 |
+
|
| 83 |
+
# Limit number of frames if specified
|
| 84 |
+
if max_frames is not None:
|
| 85 |
+
traj_3d_paths_head1 = traj_3d_paths_head1[:max_frames]
|
| 86 |
+
conf_paths_head1 = conf_paths_head1[:max_frames] if conf_paths_head1 else []
|
| 87 |
+
traj_3d_paths_head2 = traj_3d_paths_head2[:max_frames]
|
| 88 |
+
conf_paths_head2 = conf_paths_head2[:max_frames] if conf_paths_head2 else []
|
| 89 |
+
|
| 90 |
+
# Process head1
|
| 91 |
+
if traj_3d_paths_head1:
|
| 92 |
+
if use_float16:
|
| 93 |
+
traj_3d_head1 = onp.stack([onp.load(p).astype(onp.float16) for p in traj_3d_paths_head1], axis=0)
|
| 94 |
+
else:
|
| 95 |
+
traj_3d_head1 = onp.stack([onp.load(p) for p in traj_3d_paths_head1], axis=0)
|
| 96 |
+
|
| 97 |
+
log_memory_usage("after loading head1 data")
|
| 98 |
+
|
| 99 |
+
h, w, _ = traj_3d_head1.shape[1:]
|
| 100 |
+
num_frames = traj_3d_head1.shape[0]
|
| 101 |
+
|
| 102 |
+
# If masks is None, create default masks (all ones)
|
| 103 |
+
if masks is None:
|
| 104 |
+
masks = np.ones((num_frames, h, w), dtype=bool)
|
| 105 |
+
print(f"Created default masks with shape: {masks.shape}")
|
| 106 |
+
data_cache['masks'] = masks
|
| 107 |
+
else:
|
| 108 |
+
# Resize masks to match trajectory dimensions using nearest neighbor interpolation
|
| 109 |
+
masks_resized = np.zeros((masks.shape[0], h, w), dtype=bool)
|
| 110 |
+
for i in range(masks.shape[0]):
|
| 111 |
+
masks_resized[i] = cv2.resize(
|
| 112 |
+
masks[i].astype(np.uint8),
|
| 113 |
+
(w, h),
|
| 114 |
+
interpolation=cv2.INTER_NEAREST
|
| 115 |
+
).astype(bool)
|
| 116 |
+
|
| 117 |
+
print(f"Resized masks shape: {masks_resized.shape}")
|
| 118 |
+
data_cache['masks'] = masks_resized
|
| 119 |
+
|
| 120 |
+
# Reshape trajectory data
|
| 121 |
+
traj_3d_head1 = traj_3d_head1.reshape(traj_3d_head1.shape[0], -1, 6)
|
| 122 |
+
data_cache['traj_3d_head1'] = traj_3d_head1
|
| 123 |
+
|
| 124 |
+
if conf_paths_head1:
|
| 125 |
+
conf_head1 = onp.stack([onp.load(p).astype(onp.float16) for p in conf_paths_head1], axis=0)
|
| 126 |
+
conf_head1 = conf_head1.reshape(conf_head1.shape[0], -1)
|
| 127 |
+
conf_head1 = conf_head1.mean(axis=0)
|
| 128 |
+
# repeat the conf_head1 to match the number of frames in the dimension 0
|
| 129 |
+
conf_head1 = np.tile(conf_head1, (num_frames, 1))
|
| 130 |
+
# Convert to float32 before calculating percentile to avoid overflow
|
| 131 |
+
conf_thre = np.percentile(conf_head1.astype(np.float32), 1) # Default percentile
|
| 132 |
+
conf_mask_head1 = conf_head1 > conf_thre
|
| 133 |
+
data_cache['conf_mask_head1'] = conf_mask_head1
|
| 134 |
+
|
| 135 |
+
# Process head2
|
| 136 |
+
if traj_3d_paths_head2:
|
| 137 |
+
if use_float16:
|
| 138 |
+
traj_3d_head2 = onp.stack([onp.load(p).astype(onp.float16) for p in traj_3d_paths_head2], axis=0)
|
| 139 |
+
else:
|
| 140 |
+
traj_3d_head2 = onp.stack([onp.load(p) for p in traj_3d_paths_head2], axis=0)
|
| 141 |
+
|
| 142 |
+
log_memory_usage("after loading head2 data")
|
| 143 |
+
|
| 144 |
+
# Store raw video data
|
| 145 |
+
raw_video = traj_3d_head2[:, :, :, 3:6] # [num_frames, h, w, 3]
|
| 146 |
+
data_cache['raw_video'] = raw_video
|
| 147 |
+
|
| 148 |
+
traj_3d_head2 = traj_3d_head2.reshape(traj_3d_head2.shape[0], -1, 6)
|
| 149 |
+
data_cache['traj_3d_head2'] = traj_3d_head2
|
| 150 |
+
|
| 151 |
+
if conf_paths_head2:
|
| 152 |
+
conf_head2 = onp.stack([onp.load(p).astype(onp.float16) for p in conf_paths_head2], axis=0)
|
| 153 |
+
conf_head2 = conf_head2.reshape(conf_head2.shape[0], -1)
|
| 154 |
+
# set conf thre to be 1 percentile of the conf_head2, for each frame
|
| 155 |
+
conf_thre = np.percentile(conf_head2.astype(np.float32), 1, axis=1)
|
| 156 |
+
conf_mask_head2 = conf_head2 > conf_thre[:, None]
|
| 157 |
+
data_cache['conf_mask_head2'] = conf_mask_head2
|
| 158 |
+
|
| 159 |
+
data_cache['loaded'] = True
|
| 160 |
+
log_memory_usage("after loading all data")
|
| 161 |
+
return data_cache
|
| 162 |
+
|
| 163 |
+
def visualize_st4rtrack(
|
| 164 |
+
traj_path: str = "results",
|
| 165 |
+
up_dir: str = "-z", # should be +z or -z
|
| 166 |
+
max_frames: int = 32,
|
| 167 |
+
share: bool = False,
|
| 168 |
+
point_size: float = 0.0015,
|
| 169 |
+
downsample_factor: int = 3,
|
| 170 |
+
num_traj_points: int = 100,
|
| 171 |
+
conf_thre_percentile: float = 1,
|
| 172 |
+
traj_end_frame: int = 100,
|
| 173 |
+
traj_start_frame: int = 0,
|
| 174 |
+
traj_line_width: float = 3.,
|
| 175 |
+
fixed_length_traj: int = 15,
|
| 176 |
+
server: viser.ViserServer = None,
|
| 177 |
+
use_float16: bool = True,
|
| 178 |
+
preloaded_data: dict = None, # Add this parameter to accept preloaded data
|
| 179 |
+
color_code: str = "jet",
|
| 180 |
+
# Updated hex colors: #002676 for blue and #FDB515 for red/gold
|
| 181 |
+
blue_rgb: tuple[float, float, float] = (0.0, 0.149, 0.463), # #002676
|
| 182 |
+
red_rgb: tuple[float, float, float] = (0.769, 0.510, 0.055), # #FDB515
|
| 183 |
+
blend_ratio: float = 0.7,
|
| 184 |
+
mask_folder: str = None,
|
| 185 |
+
mid_anchor: bool = False,
|
| 186 |
+
video_width: int = 320, # Video display width
|
| 187 |
+
video_height: int = 180, # Video display height
|
| 188 |
+
) -> None:
|
| 189 |
+
log_memory_usage("at start of visualization")
|
| 190 |
+
|
| 191 |
+
if server is None:
|
| 192 |
+
server = viser.ViserServer()
|
| 193 |
+
if share:
|
| 194 |
+
server.request_share_url()
|
| 195 |
+
|
| 196 |
+
@server.on_client_connect
|
| 197 |
+
def _(client: viser.ClientHandle) -> None:
|
| 198 |
+
client.camera.position = (1e-3, 0.6, -0.08)
|
| 199 |
+
client.camera.look_at = (0, 0, 0)
|
| 200 |
+
|
| 201 |
+
# Configure the GUI panel size and layout
|
| 202 |
+
server.gui.configure_theme(
|
| 203 |
+
control_layout="collapsible",
|
| 204 |
+
control_width="small",
|
| 205 |
+
dark_mode=False,
|
| 206 |
+
show_logo=False,
|
| 207 |
+
show_share_button=True
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
# Add video preview to the GUI panel - placed at the top
|
| 211 |
+
video_preview = server.gui.add_image(
|
| 212 |
+
np.zeros((video_height, video_width, 3), dtype=np.uint8), # Initial blank image
|
| 213 |
+
format="jpeg"
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
# Use preloaded data if available
|
| 217 |
+
if preloaded_data and preloaded_data.get('loaded', False):
|
| 218 |
+
traj_3d_head1 = preloaded_data.get('traj_3d_head1')
|
| 219 |
+
traj_3d_head2 = preloaded_data.get('traj_3d_head2')
|
| 220 |
+
conf_mask_head1 = preloaded_data.get('conf_mask_head1')
|
| 221 |
+
conf_mask_head2 = preloaded_data.get('conf_mask_head2')
|
| 222 |
+
masks = preloaded_data.get('masks')
|
| 223 |
+
raw_video = preloaded_data.get('raw_video')
|
| 224 |
+
print("Using preloaded data!")
|
| 225 |
+
else:
|
| 226 |
+
# Load data using the shared function
|
| 227 |
+
print("No preloaded data available, loading from files...")
|
| 228 |
+
data = load_trajectory_data(traj_path, use_float16, max_frames, mask_folder)
|
| 229 |
+
traj_3d_head1 = data.get('traj_3d_head1')
|
| 230 |
+
traj_3d_head2 = data.get('traj_3d_head2')
|
| 231 |
+
conf_mask_head1 = data.get('conf_mask_head1')
|
| 232 |
+
conf_mask_head2 = data.get('conf_mask_head2')
|
| 233 |
+
masks = data.get('masks')
|
| 234 |
+
raw_video = data.get('raw_video')
|
| 235 |
+
|
| 236 |
+
def process_video_frame(frame_idx):
|
| 237 |
+
if raw_video is None:
|
| 238 |
+
return np.zeros((video_height, video_width, 3), dtype=np.uint8)
|
| 239 |
+
|
| 240 |
+
# Get the original frame
|
| 241 |
+
raw_frame = raw_video[frame_idx]
|
| 242 |
+
|
| 243 |
+
# Adjust value range to 0-255
|
| 244 |
+
if raw_frame.max() <= 1.0:
|
| 245 |
+
frame = (raw_frame * 255).astype(np.uint8)
|
| 246 |
+
else:
|
| 247 |
+
frame = raw_frame.astype(np.uint8)
|
| 248 |
+
|
| 249 |
+
# Resize to fit the preview window
|
| 250 |
+
h, w = frame.shape[:2]
|
| 251 |
+
# Calculate size while maintaining aspect ratio
|
| 252 |
+
if h/w > video_height/video_width: # Height limited
|
| 253 |
+
new_h = video_height
|
| 254 |
+
new_w = int(w * (new_h / h))
|
| 255 |
+
else: # Width limited
|
| 256 |
+
new_w = video_width
|
| 257 |
+
new_h = int(h * (new_w / w))
|
| 258 |
+
|
| 259 |
+
# Resize
|
| 260 |
+
resized_frame = cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_AREA)
|
| 261 |
+
|
| 262 |
+
# Create a black background
|
| 263 |
+
display_frame = np.zeros((video_height, video_width, 3), dtype=np.uint8)
|
| 264 |
+
|
| 265 |
+
# Place the resized frame in the center
|
| 266 |
+
y_offset = (video_height - new_h) // 2
|
| 267 |
+
x_offset = (video_width - new_w) // 2
|
| 268 |
+
display_frame[y_offset:y_offset+new_h, x_offset:x_offset+new_w] = resized_frame
|
| 269 |
+
|
| 270 |
+
return display_frame
|
| 271 |
+
|
| 272 |
+
server.scene.set_up_direction(up_dir)
|
| 273 |
+
print("Setting up visualization!")
|
| 274 |
+
|
| 275 |
+
# Add visualization controls
|
| 276 |
+
with server.gui.add_folder("Visualization"):
|
| 277 |
+
gui_show_head1 = server.gui.add_checkbox("Tracking Points", True)
|
| 278 |
+
gui_show_head2 = server.gui.add_checkbox("Recon Points", True)
|
| 279 |
+
gui_show_trajectories = server.gui.add_checkbox("Trajectories", True)
|
| 280 |
+
gui_use_color_tint = server.gui.add_checkbox("Use Color Tint", True)
|
| 281 |
+
|
| 282 |
+
# Process and center point clouds
|
| 283 |
+
center_point = None
|
| 284 |
+
if traj_3d_head1 is not None:
|
| 285 |
+
xyz_head1 = traj_3d_head1[:, :, :3]
|
| 286 |
+
rgb_head1 = traj_3d_head1[:, :, 3:6]
|
| 287 |
+
if center_point is None:
|
| 288 |
+
center_point = onp.mean(xyz_head1, axis=(0, 1), keepdims=True)
|
| 289 |
+
xyz_head1 -= center_point
|
| 290 |
+
if rgb_head1.sum(axis=(-1)).max() > 125:
|
| 291 |
+
rgb_head1 /= 255.0
|
| 292 |
+
|
| 293 |
+
if traj_3d_head2 is not None:
|
| 294 |
+
xyz_head2 = traj_3d_head2[:, :, :3]
|
| 295 |
+
rgb_head2 = traj_3d_head2[:, :, 3:6]
|
| 296 |
+
if center_point is None:
|
| 297 |
+
center_point = onp.mean(xyz_head2, axis=(0, 1), keepdims=True)
|
| 298 |
+
xyz_head2 -= center_point
|
| 299 |
+
if rgb_head2.sum(axis=(-1)).max() > 125:
|
| 300 |
+
rgb_head2 /= 255.0
|
| 301 |
+
|
| 302 |
+
# Determine number of frames
|
| 303 |
+
F = max(
|
| 304 |
+
traj_3d_head1.shape[0] if traj_3d_head1 is not None else 0,
|
| 305 |
+
traj_3d_head2.shape[0] if traj_3d_head2 is not None else 0
|
| 306 |
+
)
|
| 307 |
+
num_frames = min(max_frames, F)
|
| 308 |
+
traj_end_frame = min(traj_end_frame, num_frames)
|
| 309 |
+
print(f"Number of frames: {num_frames}")
|
| 310 |
+
xyz_head1 = xyz_head1[:num_frames]
|
| 311 |
+
xyz_head2 = xyz_head2[:num_frames]
|
| 312 |
+
rgb_head1 = rgb_head1[:num_frames]
|
| 313 |
+
rgb_head2 = rgb_head2[:num_frames]
|
| 314 |
+
|
| 315 |
+
# Add playback UI.
|
| 316 |
+
with server.gui.add_folder("Playback"):
|
| 317 |
+
gui_timestep = server.gui.add_slider(
|
| 318 |
+
"Timestep",
|
| 319 |
+
min=0,
|
| 320 |
+
max=num_frames - 1,
|
| 321 |
+
step=1,
|
| 322 |
+
initial_value=0,
|
| 323 |
+
disabled=True,
|
| 324 |
+
)
|
| 325 |
+
gui_next_frame = server.gui.add_button("Next Frame", disabled=True)
|
| 326 |
+
gui_prev_frame = server.gui.add_button("Prev Frame", disabled=True)
|
| 327 |
+
gui_playing = server.gui.add_checkbox("Playing", True)
|
| 328 |
+
gui_framerate = server.gui.add_slider(
|
| 329 |
+
"FPS", min=1, max=60, step=0.1, initial_value=20
|
| 330 |
+
)
|
| 331 |
+
gui_framerate_options = server.gui.add_button_group(
|
| 332 |
+
"FPS options", ("10", "20", "30")
|
| 333 |
+
)
|
| 334 |
+
gui_show_all_frames = server.gui.add_checkbox("Show all frames", False)
|
| 335 |
+
gui_stride = server.gui.add_slider(
|
| 336 |
+
"Stride",
|
| 337 |
+
min=1,
|
| 338 |
+
max=num_frames,
|
| 339 |
+
step=1,
|
| 340 |
+
initial_value=5,
|
| 341 |
+
disabled=True, # Initially disabled
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
# Frame step buttons.
|
| 345 |
+
@gui_next_frame.on_click
|
| 346 |
+
def _(_) -> None:
|
| 347 |
+
gui_timestep.value = (gui_timestep.value + 1) % num_frames
|
| 348 |
+
|
| 349 |
+
@gui_prev_frame.on_click
|
| 350 |
+
def _(_) -> None:
|
| 351 |
+
gui_timestep.value = (gui_timestep.value - 1) % num_frames
|
| 352 |
+
|
| 353 |
+
# Disable frame controls when we're playing.
|
| 354 |
+
@gui_playing.on_update
|
| 355 |
+
def _(_) -> None:
|
| 356 |
+
gui_timestep.disabled = gui_playing.value or gui_show_all_frames.value
|
| 357 |
+
gui_next_frame.disabled = gui_playing.value or gui_show_all_frames.value
|
| 358 |
+
gui_prev_frame.disabled = gui_playing.value or gui_show_all_frames.value
|
| 359 |
+
|
| 360 |
+
# Set the framerate when we click one of the options.
|
| 361 |
+
@gui_framerate_options.on_click
|
| 362 |
+
def _(_) -> None:
|
| 363 |
+
gui_framerate.value = int(gui_framerate_options.value)
|
| 364 |
+
|
| 365 |
+
prev_timestep = gui_timestep.value
|
| 366 |
+
|
| 367 |
+
# Toggle frame visibility when the timestep slider changes.
|
| 368 |
+
@gui_timestep.on_update
|
| 369 |
+
def _(_) -> None:
|
| 370 |
+
nonlocal prev_timestep
|
| 371 |
+
current_timestep = gui_timestep.value
|
| 372 |
+
if not gui_show_all_frames.value:
|
| 373 |
+
with server.atomic():
|
| 374 |
+
if gui_show_head1.value:
|
| 375 |
+
frame_nodes_head1[current_timestep].visible = True
|
| 376 |
+
frame_nodes_head1[prev_timestep].visible = False
|
| 377 |
+
if gui_show_head2.value:
|
| 378 |
+
frame_nodes_head2[current_timestep].visible = True
|
| 379 |
+
frame_nodes_head2[prev_timestep].visible = False
|
| 380 |
+
prev_timestep = current_timestep
|
| 381 |
+
server.flush() # Optional!
|
| 382 |
+
|
| 383 |
+
# Show or hide all frames based on the checkbox.
|
| 384 |
+
@gui_show_all_frames.on_update
|
| 385 |
+
def _(_) -> None:
|
| 386 |
+
gui_stride.disabled = not gui_show_all_frames.value # Enable/disable stride slider
|
| 387 |
+
if gui_show_all_frames.value:
|
| 388 |
+
# Show frames with stride
|
| 389 |
+
stride = gui_stride.value
|
| 390 |
+
with server.atomic():
|
| 391 |
+
for i, (node1, node2) in enumerate(zip(frame_nodes_head1, frame_nodes_head2)):
|
| 392 |
+
node1.visible = gui_show_head1.value and (i % stride == 0)
|
| 393 |
+
node2.visible = gui_show_head2.value and (i % stride == 0)
|
| 394 |
+
# Disable playback controls
|
| 395 |
+
gui_playing.disabled = True
|
| 396 |
+
gui_timestep.disabled = True
|
| 397 |
+
gui_next_frame.disabled = True
|
| 398 |
+
gui_prev_frame.disabled = True
|
| 399 |
+
else:
|
| 400 |
+
# Show only the current frame
|
| 401 |
+
current_timestep = gui_timestep.value
|
| 402 |
+
with server.atomic():
|
| 403 |
+
for i, (node1, node2) in enumerate(zip(frame_nodes_head1, frame_nodes_head2)):
|
| 404 |
+
node1.visible = gui_show_head1.value and (i == current_timestep)
|
| 405 |
+
node2.visible = gui_show_head2.value and (i == current_timestep)
|
| 406 |
+
# Re-enable playback controls
|
| 407 |
+
gui_playing.disabled = False
|
| 408 |
+
gui_timestep.disabled = gui_playing.value
|
| 409 |
+
gui_next_frame.disabled = gui_playing.value
|
| 410 |
+
gui_prev_frame.disabled = gui_playing.value
|
| 411 |
+
|
| 412 |
+
# Update frame visibility when the stride changes.
|
| 413 |
+
@gui_stride.on_update
|
| 414 |
+
def _(_) -> None:
|
| 415 |
+
if gui_show_all_frames.value:
|
| 416 |
+
# Update frame visibility based on new stride
|
| 417 |
+
stride = gui_stride.value
|
| 418 |
+
with server.atomic():
|
| 419 |
+
for i, (node1, node2) in enumerate(zip(frame_nodes_head1, frame_nodes_head2)):
|
| 420 |
+
node1.visible = gui_show_head1.value and (i % stride == 0)
|
| 421 |
+
node2.visible = gui_show_head2.value and (i % stride == 0)
|
| 422 |
+
|
| 423 |
+
# Load in frames.
|
| 424 |
+
server.scene.add_frame(
|
| 425 |
+
"/frames",
|
| 426 |
+
wxyz=tf.SO3.exp(onp.array([onp.pi / 2.0, 0.0, 0.0])).wxyz,
|
| 427 |
+
position=(0, 0, 0),
|
| 428 |
+
show_axes=False,
|
| 429 |
+
)
|
| 430 |
+
frame_nodes_head1: list[viser.FrameHandle] = []
|
| 431 |
+
frame_nodes_head2: list[viser.FrameHandle] = []
|
| 432 |
+
|
| 433 |
+
# Extract RGB components for tinting
|
| 434 |
+
blue_r, blue_g, blue_b = blue_rgb
|
| 435 |
+
red_r, red_g, red_b = red_rgb
|
| 436 |
+
|
| 437 |
+
# Create frames for each timestep
|
| 438 |
+
frame_nodes_head1 = []
|
| 439 |
+
frame_nodes_head2 = []
|
| 440 |
+
for i in tqdm(range(num_frames)):
|
| 441 |
+
# Process head1
|
| 442 |
+
if traj_3d_head1 is not None:
|
| 443 |
+
frame_nodes_head1.append(server.scene.add_frame(f"/frames/t{i}/head1", show_axes=False))
|
| 444 |
+
position = xyz_head1[i]
|
| 445 |
+
color = rgb_head1[i]
|
| 446 |
+
if conf_mask_head1 is not None:
|
| 447 |
+
position = position[conf_mask_head1[i]]
|
| 448 |
+
color = color[conf_mask_head1[i]]
|
| 449 |
+
|
| 450 |
+
# Add point cloud for head1 with optional blue tint
|
| 451 |
+
color_head1 = color.copy()
|
| 452 |
+
if gui_use_color_tint.value:
|
| 453 |
+
color_head1 *= blend_ratio
|
| 454 |
+
color_head1[:, 0] = onp.clip(color_head1[:, 0] + blue_r * (1 - blend_ratio), 0, 1) # R
|
| 455 |
+
color_head1[:, 1] = onp.clip(color_head1[:, 1] + blue_g * (1 - blend_ratio), 0, 1) # G
|
| 456 |
+
color_head1[:, 2] = onp.clip(color_head1[:, 2] + blue_b * (1 - blend_ratio), 0, 1) # B
|
| 457 |
+
|
| 458 |
+
server.scene.add_point_cloud(
|
| 459 |
+
name=f"/frames/t{i}/head1/point_cloud",
|
| 460 |
+
points=position[::downsample_factor],
|
| 461 |
+
colors=color_head1[::downsample_factor],
|
| 462 |
+
point_size=point_size,
|
| 463 |
+
point_shape="rounded",
|
| 464 |
+
)
|
| 465 |
+
|
| 466 |
+
# Process head2
|
| 467 |
+
if traj_3d_head2 is not None:
|
| 468 |
+
frame_nodes_head2.append(server.scene.add_frame(f"/frames/t{i}/head2", show_axes=False))
|
| 469 |
+
position = xyz_head2[i]
|
| 470 |
+
color = rgb_head2[i]
|
| 471 |
+
if conf_mask_head2 is not None:
|
| 472 |
+
position = position[conf_mask_head2[i]]
|
| 473 |
+
color = color[conf_mask_head2[i]]
|
| 474 |
+
|
| 475 |
+
# Add point cloud for head2 with optional red tint
|
| 476 |
+
color_head2 = color.copy()
|
| 477 |
+
if gui_use_color_tint.value:
|
| 478 |
+
color_head2 *= blend_ratio
|
| 479 |
+
color_head2[:, 0] = onp.clip(color_head2[:, 0] + red_r * (1 - blend_ratio), 0, 1) # R
|
| 480 |
+
color_head2[:, 1] = onp.clip(color_head2[:, 1] + red_g * (1 - blend_ratio), 0, 1) # G
|
| 481 |
+
color_head2[:, 2] = onp.clip(color_head2[:, 2] + red_b * (1 - blend_ratio), 0, 1) # B
|
| 482 |
+
|
| 483 |
+
server.scene.add_point_cloud(
|
| 484 |
+
name=f"/frames/t{i}/head2/point_cloud",
|
| 485 |
+
points=position[::downsample_factor],
|
| 486 |
+
colors=color_head2[::downsample_factor],
|
| 487 |
+
point_size=point_size,
|
| 488 |
+
point_shape="rounded",
|
| 489 |
+
)
|
| 490 |
+
|
| 491 |
+
# Update visibility based on checkboxes
|
| 492 |
+
@gui_show_head1.on_update
|
| 493 |
+
def _(_) -> None:
|
| 494 |
+
with server.atomic():
|
| 495 |
+
for frame_node in frame_nodes_head1:
|
| 496 |
+
frame_node.visible = gui_show_head1.value and (
|
| 497 |
+
gui_show_all_frames.value
|
| 498 |
+
or (not gui_show_all_frames.value )
|
| 499 |
+
)
|
| 500 |
+
|
| 501 |
+
@gui_show_head2.on_update
|
| 502 |
+
def _(_) -> None:
|
| 503 |
+
with server.atomic():
|
| 504 |
+
for frame_node in frame_nodes_head2:
|
| 505 |
+
frame_node.visible = gui_show_head2.value and (
|
| 506 |
+
gui_show_all_frames.value
|
| 507 |
+
or (not gui_show_all_frames.value )
|
| 508 |
+
)
|
| 509 |
+
|
| 510 |
+
# Initial visibility
|
| 511 |
+
for i, (node1, node2) in enumerate(zip(frame_nodes_head1, frame_nodes_head2)):
|
| 512 |
+
if gui_show_all_frames.value:
|
| 513 |
+
node1.visible = gui_show_head1.value and (i % gui_stride.value == 0)
|
| 514 |
+
node2.visible = gui_show_head2.value and (i % gui_stride.value == 0)
|
| 515 |
+
else:
|
| 516 |
+
node1.visible = gui_show_head1.value and (i == gui_timestep.value)
|
| 517 |
+
node2.visible = gui_show_head2.value and (i == gui_timestep.value)
|
| 518 |
+
|
| 519 |
+
# Process and visualize trajectories for head1
|
| 520 |
+
if traj_3d_head1 is not None:
|
| 521 |
+
# Get points over time
|
| 522 |
+
xyz_head1_centered = xyz_head1.copy()
|
| 523 |
+
|
| 524 |
+
# Select points to visualize
|
| 525 |
+
num_points = xyz_head1.shape[1]
|
| 526 |
+
points_to_visualize = min(num_points, num_traj_points)
|
| 527 |
+
|
| 528 |
+
# Get the mask for the first frame and reshape it to match point cloud dimensions
|
| 529 |
+
if mid_anchor:
|
| 530 |
+
first_frame_mask = masks[num_frames//2].reshape(-1)
|
| 531 |
+
else:
|
| 532 |
+
first_frame_mask = masks[0].reshape(-1) #[#points, h]
|
| 533 |
+
|
| 534 |
+
# Calculate trajectory lengths for each point
|
| 535 |
+
trajectories = xyz_head1_centered[traj_start_frame:traj_end_frame] # Shape: (num_frames, num_points, 3)
|
| 536 |
+
traj_diffs = np.diff(trajectories, axis=0) # Differences between consecutive frames
|
| 537 |
+
traj_lengths = np.sum(np.sqrt(np.sum(traj_diffs**2, axis=-1)), axis=0) # Sum of distances for each point
|
| 538 |
+
|
| 539 |
+
# Get points that are within the mask
|
| 540 |
+
valid_indices = np.where(first_frame_mask)[0]
|
| 541 |
+
|
| 542 |
+
if len(valid_indices) > 0:
|
| 543 |
+
# Calculate average trajectory length for masked points
|
| 544 |
+
masked_traj_lengths = traj_lengths[valid_indices]
|
| 545 |
+
avg_traj_length = np.mean(masked_traj_lengths)
|
| 546 |
+
|
| 547 |
+
if mask_folder is not None:
|
| 548 |
+
# do not filter points by trajectory length
|
| 549 |
+
long_traj_indices = valid_indices
|
| 550 |
+
else:
|
| 551 |
+
# Filter points by trajectory length
|
| 552 |
+
long_traj_indices = valid_indices[masked_traj_lengths >= avg_traj_length]
|
| 553 |
+
|
| 554 |
+
# Randomly sample from the filtered points
|
| 555 |
+
if len(long_traj_indices) > 0:
|
| 556 |
+
# Random sampling without replacement
|
| 557 |
+
selected_indices = np.random.choice(
|
| 558 |
+
len(long_traj_indices),
|
| 559 |
+
min(points_to_visualize, len(long_traj_indices)),
|
| 560 |
+
replace=False
|
| 561 |
+
)
|
| 562 |
+
# Get the actual indices in their original order
|
| 563 |
+
valid_point_indices = long_traj_indices[np.sort(selected_indices)]
|
| 564 |
+
else:
|
| 565 |
+
valid_point_indices = np.array([])
|
| 566 |
+
else:
|
| 567 |
+
valid_point_indices = np.array([])
|
| 568 |
+
|
| 569 |
+
if len(valid_point_indices) > 0:
|
| 570 |
+
# Get trajectories for all valid points
|
| 571 |
+
trajectories = xyz_head1_centered[traj_start_frame:traj_end_frame, valid_point_indices]
|
| 572 |
+
N_point = trajectories.shape[1]
|
| 573 |
+
if color_code == "rainbow":
|
| 574 |
+
point_colors = plt.cm.rainbow(np.linspace(0, 1, N_point))[:, :3]
|
| 575 |
+
elif color_code == "jet":
|
| 576 |
+
point_colors = plt.cm.jet(np.linspace(0, 1, N_point))[:, :3]
|
| 577 |
+
# Modify the loop to handle frames less than fixed_length_traj
|
| 578 |
+
for i in range(traj_end_frame - traj_start_frame):
|
| 579 |
+
# Calculate the actual trajectory length for this frame
|
| 580 |
+
actual_length = min(fixed_length_traj, i + 1)
|
| 581 |
+
|
| 582 |
+
if actual_length > 1: # Need at least 2 points to form a line
|
| 583 |
+
# Get the appropriate slice of trajectory data
|
| 584 |
+
start_idx = max(0, i - actual_length + 1)
|
| 585 |
+
end_idx = i + 1
|
| 586 |
+
|
| 587 |
+
# Create line segments between consecutive frames
|
| 588 |
+
traj_slice = trajectories[start_idx:end_idx]
|
| 589 |
+
line_points = np.stack([traj_slice[:-1], traj_slice[1:]], axis=2)
|
| 590 |
+
line_points = line_points.reshape(-1, 2, 3)
|
| 591 |
+
|
| 592 |
+
# Create corresponding colors
|
| 593 |
+
line_colors = np.tile(point_colors, (actual_length-1, 1))
|
| 594 |
+
line_colors = np.stack([line_colors, line_colors], axis=1)
|
| 595 |
+
|
| 596 |
+
# Add line segments
|
| 597 |
+
server.scene.add_line_segments(
|
| 598 |
+
name=f"/frames/t{i+traj_start_frame}/head1/trajectory",
|
| 599 |
+
points=line_points,
|
| 600 |
+
colors=line_colors,
|
| 601 |
+
line_width=traj_line_width,
|
| 602 |
+
visible=gui_show_trajectories.value
|
| 603 |
+
)
|
| 604 |
+
|
| 605 |
+
# Add trajectory controls functionality
|
| 606 |
+
@gui_show_trajectories.on_update
|
| 607 |
+
def _(_) -> None:
|
| 608 |
+
with server.atomic():
|
| 609 |
+
# Remove all existing trajectories
|
| 610 |
+
for i in range(num_frames):
|
| 611 |
+
try:
|
| 612 |
+
server.scene.remove_by_name(f"/frames/t{i}/head1/trajectory")
|
| 613 |
+
except KeyError:
|
| 614 |
+
pass
|
| 615 |
+
|
| 616 |
+
# Create new trajectories if enabled
|
| 617 |
+
if gui_show_trajectories.value and traj_3d_head1 is not None:
|
| 618 |
+
# Get the mask for the last frame and reshape it
|
| 619 |
+
last_frame_mask = masks[traj_end_frame-1].reshape(-1)
|
| 620 |
+
|
| 621 |
+
# Calculate trajectory lengths
|
| 622 |
+
trajectories = xyz_head1_centered[traj_start_frame:traj_end_frame]
|
| 623 |
+
traj_diffs = np.diff(trajectories, axis=0)
|
| 624 |
+
traj_lengths = np.sum(np.sqrt(np.sum(traj_diffs**2, axis=-1)), axis=0)
|
| 625 |
+
|
| 626 |
+
# Get points that are within the mask
|
| 627 |
+
valid_indices = np.where(last_frame_mask)[0]
|
| 628 |
+
|
| 629 |
+
if len(valid_indices) > 0:
|
| 630 |
+
# Filter by trajectory length
|
| 631 |
+
masked_traj_lengths = traj_lengths[valid_indices]
|
| 632 |
+
avg_traj_length = np.mean(masked_traj_lengths)
|
| 633 |
+
long_traj_indices = valid_indices[masked_traj_lengths >= avg_traj_length]
|
| 634 |
+
|
| 635 |
+
# Randomly sample from the filtered points
|
| 636 |
+
if len(long_traj_indices) > 0:
|
| 637 |
+
# Random sampling without replacement
|
| 638 |
+
selected_indices = np.random.choice(
|
| 639 |
+
len(long_traj_indices),
|
| 640 |
+
min(points_to_visualize, len(long_traj_indices)),
|
| 641 |
+
replace=False
|
| 642 |
+
)
|
| 643 |
+
# Get the actual indices in their original order
|
| 644 |
+
valid_point_indices = long_traj_indices[np.sort(selected_indices)]
|
| 645 |
+
else:
|
| 646 |
+
valid_point_indices = np.array([])
|
| 647 |
+
else:
|
| 648 |
+
valid_point_indices = np.array([])
|
| 649 |
+
|
| 650 |
+
if len(valid_point_indices) > 0:
|
| 651 |
+
# Get trajectories for all valid points
|
| 652 |
+
trajectories = xyz_head1_centered[traj_start_frame:traj_end_frame, valid_point_indices]
|
| 653 |
+
N_point = trajectories.shape[1]
|
| 654 |
+
|
| 655 |
+
if color_code == "rainbow":
|
| 656 |
+
point_colors = plt.cm.rainbow(np.linspace(0, 1, N_point))[:, :3]
|
| 657 |
+
elif color_code == "jet":
|
| 658 |
+
point_colors = plt.cm.jet(np.linspace(0, 1, N_point))[:, :3]
|
| 659 |
+
|
| 660 |
+
# Modify the loop to handle frames less than fixed_length_traj
|
| 661 |
+
for i in range(traj_end_frame - traj_start_frame):
|
| 662 |
+
# Calculate the actual trajectory length for this frame
|
| 663 |
+
actual_length = min(fixed_length_traj, i + 1)
|
| 664 |
+
|
| 665 |
+
if actual_length > 1: # Need at least 2 points to form a line
|
| 666 |
+
# Get the appropriate slice of trajectory data
|
| 667 |
+
start_idx = max(0, i - actual_length + 1)
|
| 668 |
+
end_idx = i + 1
|
| 669 |
+
|
| 670 |
+
# Create line segments between consecutive frames
|
| 671 |
+
traj_slice = trajectories[start_idx:end_idx]
|
| 672 |
+
line_points = np.stack([traj_slice[:-1], traj_slice[1:]], axis=2)
|
| 673 |
+
line_points = line_points.reshape(-1, 2, 3)
|
| 674 |
+
|
| 675 |
+
# Create corresponding colors
|
| 676 |
+
line_colors = np.tile(point_colors, (actual_length-1, 1))
|
| 677 |
+
line_colors = np.stack([line_colors, line_colors], axis=1)
|
| 678 |
+
|
| 679 |
+
# Add line segments
|
| 680 |
+
server.scene.add_line_segments(
|
| 681 |
+
name=f"/frames/t{i+traj_start_frame}/head1/trajectory",
|
| 682 |
+
points=line_points,
|
| 683 |
+
colors=line_colors,
|
| 684 |
+
line_width=traj_line_width,
|
| 685 |
+
visible=True
|
| 686 |
+
)
|
| 687 |
+
|
| 688 |
+
# Update color tinting when the checkbox changes
|
| 689 |
+
@gui_use_color_tint.on_update
|
| 690 |
+
def _(_) -> None:
|
| 691 |
+
with server.atomic():
|
| 692 |
+
for i in range(num_frames):
|
| 693 |
+
# Update head1 point cloud
|
| 694 |
+
if traj_3d_head1 is not None:
|
| 695 |
+
position = xyz_head1[i]
|
| 696 |
+
color = rgb_head1[i]
|
| 697 |
+
if conf_mask_head1 is not None:
|
| 698 |
+
position = position[conf_mask_head1[i]]
|
| 699 |
+
color = color[conf_mask_head1[i]]
|
| 700 |
+
|
| 701 |
+
color_head1 = color.copy()
|
| 702 |
+
if gui_use_color_tint.value:
|
| 703 |
+
color_head1 *= blend_ratio
|
| 704 |
+
color_head1[:, 0] = onp.clip(color_head1[:, 0] + blue_r * (1 - blend_ratio), 0, 1) # R
|
| 705 |
+
color_head1[:, 1] = onp.clip(color_head1[:, 1] + blue_g * (1 - blend_ratio), 0, 1) # G
|
| 706 |
+
color_head1[:, 2] = onp.clip(color_head1[:, 2] + blue_b * (1 - blend_ratio), 0, 1) # B
|
| 707 |
+
|
| 708 |
+
server.scene.remove_by_name(f"/frames/t{i}/head1/point_cloud")
|
| 709 |
+
server.scene.add_point_cloud(
|
| 710 |
+
name=f"/frames/t{i}/head1/point_cloud",
|
| 711 |
+
points=position[::downsample_factor],
|
| 712 |
+
colors=color_head1[::downsample_factor],
|
| 713 |
+
point_size=point_size,
|
| 714 |
+
point_shape="rounded",
|
| 715 |
+
)
|
| 716 |
+
|
| 717 |
+
# Update head2 point cloud
|
| 718 |
+
if traj_3d_head2 is not None:
|
| 719 |
+
position = xyz_head2[i]
|
| 720 |
+
color = rgb_head2[i]
|
| 721 |
+
if conf_mask_head2 is not None:
|
| 722 |
+
position = position[conf_mask_head2[i]]
|
| 723 |
+
color = color[conf_mask_head2[i]]
|
| 724 |
+
|
| 725 |
+
color_head2 = color.copy()
|
| 726 |
+
if gui_use_color_tint.value:
|
| 727 |
+
color_head2 *= blend_ratio
|
| 728 |
+
color_head2[:, 0] = onp.clip(color_head2[:, 0] + red_r * (1 - blend_ratio), 0, 1) # R
|
| 729 |
+
color_head2[:, 1] = onp.clip(color_head2[:, 1] + red_g * (1 - blend_ratio), 0, 1) # G
|
| 730 |
+
color_head2[:, 2] = onp.clip(color_head2[:, 2] + red_b * (1 - blend_ratio), 0, 1) # B
|
| 731 |
+
|
| 732 |
+
server.scene.remove_by_name(f"/frames/t{i}/head2/point_cloud")
|
| 733 |
+
server.scene.add_point_cloud(
|
| 734 |
+
name=f"/frames/t{i}/head2/point_cloud",
|
| 735 |
+
points=position[::downsample_factor],
|
| 736 |
+
colors=color_head2[::downsample_factor],
|
| 737 |
+
point_size=point_size,
|
| 738 |
+
point_shape="rounded",
|
| 739 |
+
)
|
| 740 |
+
|
| 741 |
+
# Initialize video preview
|
| 742 |
+
if raw_video is not None:
|
| 743 |
+
video_preview.image = process_video_frame(0)
|
| 744 |
+
|
| 745 |
+
# Update video preview when timestep changes
|
| 746 |
+
@gui_timestep.on_update
|
| 747 |
+
def _(_) -> None:
|
| 748 |
+
current_timestep = gui_timestep.value
|
| 749 |
+
if raw_video is not None:
|
| 750 |
+
video_preview.image = process_video_frame(current_timestep)
|
| 751 |
+
|
| 752 |
+
# Playback update loop.
|
| 753 |
+
log_memory_usage("before starting playback loop")
|
| 754 |
+
|
| 755 |
+
prev_timestep = gui_timestep.value
|
| 756 |
+
while True:
|
| 757 |
+
current_timestep = gui_timestep.value
|
| 758 |
+
|
| 759 |
+
# If timestep changes, update frame visibility
|
| 760 |
+
if current_timestep != prev_timestep:
|
| 761 |
+
with server.atomic():
|
| 762 |
+
# ... existing code ...
|
| 763 |
+
|
| 764 |
+
# Update video preview
|
| 765 |
+
if raw_video is not None:
|
| 766 |
+
video_preview.image = process_video_frame(current_timestep)
|
| 767 |
+
|
| 768 |
+
# Update in playback mode
|
| 769 |
+
if gui_playing.value and not gui_show_all_frames.value:
|
| 770 |
+
gui_timestep.value = (gui_timestep.value + 1) % num_frames
|
| 771 |
+
|
| 772 |
+
# Update video preview in playback mode
|
| 773 |
+
if raw_video is not None:
|
| 774 |
+
video_preview.image = process_video_frame(gui_timestep.value)
|
| 775 |
+
|
| 776 |
+
time.sleep(1.0 / gui_framerate.value)
|
| 777 |
+
|
| 778 |
+
|
| 779 |
+
if __name__ == "__main__":
    # Expose visualize_st4rtrack's signature as a CLI; tyro turns its
    # parameters into command-line flags and invokes it with the parsed values.
    tyro.cli(visualize_st4rtrack)
|
viser_proxy_manager.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
|
| 3 |
+
import httpx
|
| 4 |
+
import viser
|
| 5 |
+
import websockets
|
| 6 |
+
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
|
| 7 |
+
from fastapi.responses import Response
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class ViserProxyManager:
    """Manages Viser server instances for Gradio applications.

    This class handles the creation, retrieval, and cleanup of Viser server
    instances, and proxies HTTP and WebSocket requests from the FastAPI app
    to the appropriate locally-bound Viser server.

    Args:
        app: The FastAPI application to which the proxy routes will be added.
        min_local_port: Minimum local port number to use for Viser servers.
            These ports are used only for internal communication and don't
            need to be publicly exposed. Defaults to 8000.
        max_local_port: Maximum local port number to use for Viser servers.
            These ports are used only for internal communication and don't
            need to be publicly exposed. Defaults to 9000.
        max_message_size: Maximum WebSocket message size in bytes accepted
            from the backing Viser server. Defaults to 300MB.
    """

    def __init__(
        self,
        app: FastAPI,
        min_local_port: int = 8000,
        max_local_port: int = 9000,
        max_message_size: int = 300 * 1024 * 1024,  # 300MB default
    ) -> None:
        self._min_port = min_local_port
        self._max_port = max_local_port
        self._max_message_size = max_message_size
        # Maps the per-session server ID to its running Viser server.
        self._server_from_session_hash: dict[str, viser.ViserServer] = {}
        # Round-robin starting point for the port scan in start_server().
        self._last_port = self._min_port - 1

        @app.get("/viser/{server_id}/{proxy_path:path}")
        async def proxy(request: Request, server_id: str, proxy_path: str):
            """Proxy HTTP requests to the appropriate Viser server."""
            server = self._server_from_session_hash.get(server_id)
            if server is None:
                return Response(content="Server not found", status_code=404)

            # Build the target URL on the locally-bound Viser server.
            path_suffix = f"/{proxy_path}" if proxy_path else "/"
            target_url = f"http://127.0.0.1:{server.get_port()}{path_suffix}"
            if request.url.query:
                target_url += f"?{request.url.query}"

            # Forward the request body and headers to the backing server.
            async with httpx.AsyncClient() as client:
                # Forward the original headers, minus ones that would
                # confuse the upstream server.
                headers = dict(request.headers)
                headers.pop("host", None)  # Remove host header to avoid conflicts
                headers["accept-encoding"] = "identity"  # Disable compression

                proxied_req = client.build_request(
                    method=request.method,
                    url=target_url,
                    headers=headers,
                    content=await request.body(),
                )
                proxied_resp = await client.send(proxied_req, stream=True)
                try:
                    content = await proxied_resp.aread()
                finally:
                    # Release the streamed response even if reading fails.
                    await proxied_resp.aclose()

                return Response(
                    content=content,
                    status_code=proxied_resp.status_code,
                    headers=dict(proxied_resp.headers),
                )

        # WebSocket Proxy
        @app.websocket("/viser/{server_id}")
        async def websocket_proxy(websocket: WebSocket, server_id: str):
            """Proxy WebSocket connections to the appropriate Viser server."""
            try:
                await websocket.accept()

                server = self._server_from_session_hash.get(server_id)
                if server is None:
                    await websocket.close(code=1008, reason="Not Found")
                    return

                target_ws_url = f"ws://127.0.0.1:{server.get_port()}"

                try:
                    # Connect with an increased message-size limit and
                    # keepalive pings so long-lived sessions stay open.
                    async with websockets.connect(
                        target_ws_url,
                        max_size=self._max_message_size,
                        ping_interval=30,  # Send ping every 30 seconds
                        ping_timeout=10,  # Wait 10 seconds for pong response
                        close_timeout=5,  # Wait 5 seconds for close handshake
                    ) as ws_target:

                        async def forward_to_target():
                            """Forward messages from the client to the target WebSocket."""
                            try:
                                while True:
                                    data = await websocket.receive_bytes()
                                    await ws_target.send(data, text=False)
                            except WebSocketDisconnect:
                                try:
                                    await ws_target.close()
                                except RuntimeError:
                                    pass  # Target side already closed.

                        async def forward_from_target():
                            """Forward messages from the target WebSocket to the client."""
                            try:
                                while True:
                                    data = await ws_target.recv(decode=False)
                                    await websocket.send_bytes(data)
                            except websockets.exceptions.ConnectionClosed:
                                try:
                                    await websocket.close()
                                except RuntimeError:
                                    pass  # Client side already closed.

                        # Pump both directions until either side disconnects.
                        forward_task = asyncio.create_task(forward_to_target())
                        backward_task = asyncio.create_task(forward_from_target())

                        done, pending = await asyncio.wait(
                            [forward_task, backward_task],
                            return_when=asyncio.FIRST_COMPLETED,
                        )

                        # One direction finished; tear down the other.
                        for task in pending:
                            task.cancel()

                except websockets.exceptions.ConnectionClosedError as e:
                    print(f"WebSocket connection closed with error: {e}")
                    await websocket.close(code=1011, reason="Connection to target closed")

            except Exception as e:
                print(f"WebSocket proxy error: {e}")
                try:
                    # Close reasons are capped at 123 bytes by the protocol.
                    await websocket.close(code=1011, reason=str(e)[:120])
                except Exception:
                    pass  # Already closed

    def start_server(self, server_id: str) -> viser.ViserServer:
        """Start a new Viser server and associate it with the given server ID.

        Scans for a free port within the configured min_local_port and
        max_local_port range, starting just after the last port handed out so
        consecutive calls don't hammer the same port. These ports are used
        only for internal communication and don't need to be publicly exposed.

        Args:
            server_id: The unique identifier to associate with the new server.

        Returns:
            The newly created Viser server instance.

        Raises:
            RuntimeError: If no free ports are available in the configured range.
        """
        import socket

        port_range_size = self._max_port - self._min_port + 1
        # Start searching from the last port + 1 (with wraparound).
        start_port = (
            (self._last_port + 1 - self._min_port) % port_range_size
        ) + self._min_port

        # Try each port in the range exactly once.
        for offset in range(port_range_size):
            port = (
                (start_port - self._min_port + offset) % port_range_size
            ) + self._min_port
            try:
                # Probe availability by binding, then release the probe
                # socket *before* ViserServer binds the port itself.
                # NOTE(review): this is still check-then-use — another
                # process could grab the port in between; acceptable here.
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    s.bind(("127.0.0.1", port))
                server = viser.ViserServer(port=port)
                self._server_from_session_hash[server_id] = server
                self._last_port = port
                return server
            except OSError:
                # Port is in use, try the next one.
                continue

        # If we get here, no ports were available.
        raise RuntimeError(
            f"No available local ports in range {self._min_port}-{self._max_port}"
        )

    def get_server(self, server_id: str) -> viser.ViserServer:
        """Retrieve a Viser server instance by its ID.

        Args:
            server_id: The unique identifier of the server to retrieve.

        Returns:
            The Viser server instance associated with the given ID.

        Raises:
            KeyError: If no server is registered under ``server_id``.
        """
        return self._server_from_session_hash[server_id]

    def stop_server(self, server_id: str) -> None:
        """Stop a Viser server and remove it from the manager.

        Args:
            server_id: The unique identifier of the server to stop.

        Raises:
            KeyError: If no server is registered under ``server_id``.
        """
        self._server_from_session_hash.pop(server_id).stop()
|