# blanchon's picture
# Update
# f1df768
# raw
# history blame
# 12.7 kB
import logging
import threading
import time
import gradio as gr
import uvicorn
from fastapi.middleware.cors import CORSMiddleware
# Import our existing components
from inference_server.main import app as fastapi_app
from inference_server.main import session_manager
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Configuration
# Default port used by launch_simple_integrated_app().
DEFAULT_PORT = 7860
# Pre-filled value of the "Arena Server URL" textbox in the Gradio UI.
DEFAULT_ARENA_SERVER_URL = "http://localhost:8000"
# Global server thread
# Thread object created by start_api_server_thread(); None until that
# function has been called.
server_thread = None
# Flag maintained by start_api_server_thread(); its server thread resets it
# to False when uvicorn.run() returns or raises.
server_started = False
def start_api_server_thread(port: int = 8001):
    """Start the API server in a background daemon thread.

    Does nothing when a previously started server thread is still alive.
    Otherwise a new daemon thread runs ``uvicorn.run`` for the inference
    server app on ``localhost:port``. The module-level ``server_started``
    flag is then set to whether that thread is still running after a grace
    period of at most two seconds.

    Args:
        port: TCP port passed to ``uvicorn.run``.
    """
    global server_thread, server_started
    if server_thread and server_thread.is_alive():
        return

    def run_server():
        global server_started
        from inference_server.main import app

        try:
            uvicorn.run(app, host="localhost", port=port, log_level="warning")
        except Exception as e:
            logger.exception("API server error: %s", e)
        finally:
            # The server loop has ended (normally or not): it is no longer up.
            server_started = False

    server_thread = threading.Thread(target=run_server, daemon=True)
    server_thread.start()
    # Wait up to two seconds for the server to come up. join() returns early
    # if the thread dies during start-up (e.g. the port is already in use),
    # so a failed start is not reported as a running server.
    server_thread.join(timeout=2)
    server_started = server_thread.is_alive()
class SimpleServerManager:
    """Direct access to the session manager without HTTP calls.

    Every coroutine returns a Markdown-formatted string intended for display.
    Failures are reported inside the returned text rather than raised.
    Session IDs are stripped of surrounding whitespace before they are used,
    so the value that is validated is also the value that is looked up.
    """

    def __init__(self):
        # Shared session manager imported at module level.
        self.session_manager = session_manager

    async def create_and_start_session(
        self,
        session_id: str,
        policy_path: str,
        camera_names: str,
        arena_server_url: str,
    ) -> str:
        """Create a session and immediately start inference on it.

        Args:
            session_id: Name of the session; surrounding whitespace is ignored.
            policy_path: Passed through to ``session_manager.create_session``.
            camera_names: Comma-separated camera names. Blank entries are
                dropped; if nothing remains, ``["front"]`` is used.
            arena_server_url: Passed through to ``create_session``.

        Returns:
            A Markdown summary with the workspace ID, the camera room IDs and
            the joint input/output room IDs, or a warning/error message.
        """
        session_id = session_id.strip()
        if not session_id:
            return "โš ๏ธ No session ID provided"
        try:
            cameras = [name.strip() for name in camera_names.split(",") if name.strip()]
            if not cameras:
                cameras = ["front"]
            # Create session directly
            room_info = await self.session_manager.create_session(
                session_id=session_id,
                policy_path=policy_path,
                camera_names=cameras,
                arena_server_url=arena_server_url,
            )
            # Start the session
            await self.session_manager.start_inference(session_id)
            # Format camera rooms more clearly
            camera_info = [
                f" ๐Ÿ“น **{camera_name.title()}**: `{room_id}`"
                for camera_name, room_id in room_info["camera_room_ids"].items()
            ]
            return f"""## โœ… Session Created Successfully!
**Session ID**: `{session_id}`
**Status**: ๐ŸŸข **RUNNING** (inference active)
### ๐Ÿ“ก Arena Connection Details:
**Workspace ID**: `{room_info["workspace_id"]}`
**Camera Rooms**:
{chr(10).join(camera_info)}
**Joint Communication**:
๐Ÿ“ฅ **Input**: `{room_info["joint_input_room_id"]}`
๐Ÿ“ค **Output**: `{room_info["joint_output_room_id"]}`
---
## ๐Ÿค– Ready for Robot Control!
Your robot can now connect to these rooms to start AI-powered control."""
        except Exception as e:
            # UI boundary: report the failure as text instead of raising.
            return f"โŒ Error: {e!s}"

    async def control_session(self, session_id: str, action: str) -> str:
        """Start, stop or restart inference for an existing session.

        Args:
            session_id: Session to control; surrounding whitespace is ignored.
            action: One of ``"start"``, ``"stop"`` or ``"restart"``.

        Returns:
            A one-line message describing the outcome. Starting a session
            whose status is ``"running"``, or stopping one whose status is
            ``"stopped"`` or ``"ready"``, is reported as a no-op.
        """
        session_id = session_id.strip()
        if not session_id:
            return "โš ๏ธ No session ID provided"
        try:
            # Check current status
            if session_id not in self.session_manager.sessions:
                return f"โŒ Session '{session_id}' not found"
            current_status = self.session_manager.sessions[session_id].status
            # Smart action handling
            if action == "start" and current_status == "running":
                return f"โ„น๏ธ Session '{session_id}' is already running"
            if action == "stop" and current_status in {"stopped", "ready"}:
                return f"โ„น๏ธ Session '{session_id}' is already stopped"
            # Perform action
            if action == "start":
                await self.session_manager.start_inference(session_id)
                return f"โœ… Inference started for session {session_id}"
            if action == "stop":
                await self.session_manager.stop_inference(session_id)
                return f"โœ… Inference stopped for session {session_id}"
            if action == "restart":
                await self.session_manager.restart_inference(session_id)
                return f"โœ… Inference restarted for session {session_id}"
            return f"โŒ Unknown action: {action}"
        except Exception as e:
            # UI boundary: report the failure as text instead of raising.
            return f"โŒ Error: {e!s}"

    async def get_session_status(self, session_id: str) -> str:
        """Render the status of a session as Markdown.

        Args:
            session_id: Session to inspect; surrounding whitespace is ignored.

        Returns:
            A Markdown report built from the mapping returned by
            ``session_manager.get_session_status`` (keys read here:
            ``status``, ``policy_path``, ``camera_names`` and ``stats``),
            or a warning/error message.
        """
        session_id = session_id.strip()
        if not session_id:
            return "โš ๏ธ No session ID provided"
        try:
            if session_id not in self.session_manager.sessions:
                return f"โŒ Session '{session_id}' not found"
            session = await self.session_manager.get_session_status(session_id)
            status_emoji = {
                "running": "๐ŸŸข",
                "ready": "๐ŸŸก",
                "stopped": "๐Ÿ”ด",
                "initializing": "๐ŸŸ ",
            }.get(session["status"], "โšช")
            # Smart suggestions
            suggestions = ""
            if session["status"] == "running":
                suggestions = "\n\n### ๐Ÿ’ก Smart Suggestion:\n**Session is active!** Use the 'โธ๏ธ Stop' button to pause inference."
            elif session["status"] in {"ready", "stopped"}:
                suggestions = "\n\n### ๐Ÿ’ก Smart Suggestion:\n**Session is ready!** Use the 'โ–ถ๏ธ Start' button to begin inference."
            # Format camera names nicely
            camera_list = [f"**{cam.title()}**" for cam in session["camera_names"]]
            return f"""## {status_emoji} Session Status
**Session ID**: `{session_id}`
**Status**: {status_emoji} **{session["status"].upper()}**
**Model**: `{session["policy_path"]}`
**Cameras**: {", ".join(camera_list)}
### ๐Ÿ“Š Performance Metrics:
| Metric | Value |
|--------|-------|
| ๐Ÿง  **Inferences** | {session["stats"]["inference_count"]} |
| ๐Ÿ“ค **Commands Sent** | {session["stats"]["commands_sent"]} |
| ๐Ÿ“‹ **Queue Length** | {session["stats"]["actions_in_queue"]} actions |
| โŒ **Errors** | {session["stats"]["errors"]} |
### ๐Ÿ”ง Data Flow:
- ๐Ÿ“น **Images Received**: {session["stats"]["images_received"]}
- ๐Ÿค– **Joint States Received**: {session["stats"]["joints_received"]}
{suggestions}"""
        except Exception as e:
            # UI boundary: report the failure as text instead of raising.
            return f"โŒ Error: {e!s}"
def create_simple_gradio_interface() -> gr.Blocks:
    """Assemble the Gradio Blocks UI backed by a SimpleServerManager.

    The page contains, in creation order: a title banner, a static API
    status panel, a setup panel (session/model/camera/arena inputs, a create
    button and a result pane), a control panel (session ID box,
    start/stop/status buttons and a status pane) and a tips footer.

    Returns:
        The constructed ``gr.Blocks`` object; launching is left to the caller.
    """
    # NOTE(review): the nesting of rows/columns below was reconstructed from a
    # source whose indentation had been flattened - confirm against the
    # rendered layout.
    manager = SimpleServerManager()

    banner_md = """
# ๐Ÿค– Robot AI Control Center
**Control your robot with AI using ACT models**
*Integrated mode - FastAPI available at `/api/docs` for direct API access*
"""
    api_status_md = """
โœ… **FastAPI Server**: Available at `/api`
๐Ÿ“– **API Documentation**: Available at `/api/docs`
๐Ÿ”„ **Health Check**: Available at `/api/health`
"""
    tips_md = """
### ๐Ÿ’ก Tips:
- ๐ŸŸข **RUNNING**: Session actively doing inference
- ๐ŸŸก **READY**: Session created but not started
- ๐Ÿ”ด **STOPPED**: Session paused
### ๐Ÿš€ API Access:
- **FastAPI Docs**: Visit `/api/docs` for interactive API documentation
- **Direct API**: Use `/api/sessions` endpoints for programmatic access
- **Health Check**: `/api/health` shows server status
- **OpenAPI Schema**: Available at `/api/openapi.json`
"""

    with gr.Blocks(
        title="๐Ÿค– Robot AI Control Center",
        theme=gr.themes.Soft(),
        css=".gradio-container { max-width: 1200px !important; }",
        fill_height=True,
    ) as demo:
        gr.Markdown(banner_md)

        # --- Static panel advertising the mounted API ---
        with gr.Group():
            gr.Markdown("## ๐Ÿ“ก API Status")
            gr.Markdown(value=api_status_md, show_copy_button=True)

        # --- Session setup: inputs on the left, result on the right ---
        with gr.Group():
            gr.Markdown("## ๐Ÿค– Set Up Robot AI")
            with gr.Row():
                with gr.Column():
                    name_box = gr.Textbox(label="Session Name", value="my-robot-01")
                    model_box = gr.Textbox(
                        label="AI Model Path",
                        value="./checkpoints/act_so101_beyond",
                    )
                    cameras_box = gr.Textbox(label="Camera Names", value="front")
                    arena_box = gr.Textbox(
                        label="Arena Server URL",
                        value=DEFAULT_ARENA_SERVER_URL,
                        placeholder="http://localhost:8000",
                    )
                    create_btn = gr.Button(
                        "๐ŸŽฏ Create & Start AI Control", variant="primary"
                    )
                with gr.Column():
                    setup_pane = gr.Markdown(
                        value="Ready to create your robot AI session...",
                        show_copy_button=True,
                        container=True,
                        height=300,
                    )

        # --- Session control: ID box, action buttons, status output ---
        with gr.Group():
            gr.Markdown("## ๐ŸŽฎ Control Session")
            with gr.Row():
                session_box = gr.Textbox(label="Session ID")
                start_btn = gr.Button("โ–ถ๏ธ Start", variant="primary")
                stop_btn = gr.Button("โธ๏ธ Stop", variant="secondary")
                status_btn = gr.Button("๐Ÿ“Š Status", variant="secondary")
            status_pane = gr.Markdown(
                value="Click '๐Ÿ“Š Status' to check session information...",
                show_copy_button=True,
                container=True,
                height=400,
            )

        # Callbacks keep their original names and parameter names so that
        # anything Gradio derives from them stays the same.
        async def create_session(
            session_id, policy_path, camera_names, arena_server_url
        ):
            message = await manager.create_and_start_session(
                session_id, policy_path, camera_names, arena_server_url
            )
            # The second output copies the new ID into the control panel box.
            return message, session_id

        async def start_session(session_id):
            return await manager.control_session(session_id, "start")

        async def stop_session(session_id):
            return await manager.control_session(session_id, "stop")

        async def get_status(session_id):
            return await manager.get_session_status(session_id)

        # Wire events in the same order as before: create, start, stop, status.
        create_btn.click(
            create_session,
            inputs=[name_box, model_box, cameras_box, arena_box],
            outputs=[setup_pane, session_box],
        )
        for button, handler in (
            (start_btn, start_session),
            (stop_btn, stop_session),
            (status_btn, get_status),
        ):
            button.click(handler, inputs=[session_box], outputs=[status_pane])

        gr.Markdown(tips_md)

    return demo
def launch_simple_integrated_app(
    host: str = "localhost", port: int = DEFAULT_PORT, share: bool = False
):
    """Launch the integrated application with both FastAPI and Gradio.

    Builds the Gradio interface, adds CORS middleware to the Gradio app,
    mounts the inference server's FastAPI app under ``/api`` and then calls
    ``demo.launch`` with ``prevent_thread_lock=False``.

    Args:
        host: Passed to ``demo.launch`` as ``server_name``.
        port: Passed to ``demo.launch`` as ``server_port``.
        share: Passed to ``demo.launch`` as ``share``.
    """
    print(f"๐Ÿš€ Starting integrated app on {host}:{port}")
    print(f"๐ŸŽจ Gradio UI: http://{host}:{port}/")
    print(f"๐Ÿ“– FastAPI Docs: http://{host}:{port}/api/docs")
    print(f"๐Ÿ”„ Health Check: http://{host}:{port}/api/health")
    print("๐Ÿ”ง Direct session management + API access!")
    # Create Gradio demo
    demo = create_simple_gradio_interface()
    # Create custom FastAPI app from the Gradio demo's built-in app
    # NOTE(review): this relies on `demo.app` existing before launch() and on
    # launch() serving this same app object - confirm against the installed
    # Gradio version.
    app = demo.app
    # Add CORS middleware
    # Credentials are disabled on purpose: a wildcard origin combined with
    # allow_credentials=True is disallowed by the CORS middleware docs and
    # would let any site issue credentialed cross-origin requests.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=False,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    # Mount the FastAPI AI server under /api
    app.mount("/api", fastapi_app)
    # Launch using Gradio's queue and launch methods
    demo.queue()
    demo.launch(
        server_name=host,
        server_port=port,
        share=share,
        show_error=True,
        show_api=False,  # Hide Gradio's default API docs since we have our own
        prevent_thread_lock=False,
    )
# Script entry point: when this file is executed directly, launch the
# integrated app with its default host, port and share settings.
if __name__ == "__main__":
    launch_simple_integrated_app()