Spaces:
Sleeping
Sleeping
| import logging | |
| import os | |
| from contextlib import asynccontextmanager | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from inference_server.models import list_supported_policies | |
| from inference_server.session_manager import SessionManager | |
# Configure logging
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

# Global session manager: one process-wide instance used by the lifespan
# handler and by every session/debug endpoint in this module.
session_manager = SessionManager()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Handle app startup and shutdown.

    Passed to ``FastAPI(lifespan=...)``. Code before ``yield`` runs at startup
    and code after it runs at shutdown. It is wrapped with
    ``asynccontextmanager`` because FastAPI expects the lifespan to be an
    async context manager factory, not a bare async generator.

    Args:
        app: The application instance (not used by this handler).
    """
    logger.info("🚀 Inference Server starting up...")
    try:
        yield
    finally:
        # ``finally`` so that sessions are still cleaned up if an exception
        # is raised into the lifespan at the ``yield`` point.
        logger.info("🔄 Inference Server shutting down...")
        await session_manager.cleanup_all_sessions()
        logger.info("✅ Inference Server shutdown complete")
# FastAPI app
app = FastAPI(
    lifespan=lifespan,  # startup/shutdown handling
    title="Inference Server",
    version="1.0.0",
    description="Multi-Policy Model Inference Server for Real-time Robot Control",
)

# Add CORS middleware
# NOTE: this accepts any origin, method and header and allows credentials.
# In production, specify actual origins instead of the "*" wildcard.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Request/Response models
class CreateSessionRequest(BaseModel):
    """Request body for creating an inference session.

    Only ``session_id`` and ``policy_path`` are required; every other field
    has a default.
    """

    session_id: str
    policy_path: str
    # Support multiple cameras; one name per camera feed
    camera_names: list[str] = ["front"]
    arena_server_url: str = "http://localhost:8000"
    # Optional workspace ID; when omitted a new workspace is generated
    workspace_id: str | None = None
    # Policy type: act, pi0, pi0fast, smolvla, diffusion
    policy_type: str = "act"
    # For vision-language policies
    language_instruction: str | None = None
class CreateSessionResponse(BaseModel):
    """Room identifiers returned once a session has been created.

    All of the session's rooms (cameras and joints) belong to ``workspace_id``.
    """

    workspace_id: str
    # Maps each camera name to the ID of its room: {camera_name: room_id}
    camera_room_ids: dict[str, str]
    joint_input_room_id: str
    joint_output_room_id: str
class SessionStatusResponse(BaseModel):
    """Status snapshot of a single inference session."""

    session_id: str
    status: str
    policy_path: str
    policy_type: str
    camera_names: list[str]  # every camera attached to the session
    workspace_id: str
    # Free-form dicts; their schema comes from SessionManager (not visible here)
    rooms: dict
    stats: dict
    # Optional extras that default to None when not supplied
    inference_stats: dict | None = None
    error_message: str | None = None
# Health check
async def root():
    """Minimal health check confirming that the server is running."""
    payload = {"message": "Inference Server is running", "status": "healthy"}
    return payload
async def health_check():
    """Detailed health check: status plus a count and list of the active session IDs."""
    sessions = session_manager.sessions
    return dict(
        status="healthy",
        active_sessions=len(sessions),
        session_ids=list(sessions.keys()),
    )
async def list_policies():
    """Return the policy types this server can load, with a short description."""
    supported = list_supported_policies()
    return {
        "supported_policies": supported,
        "description": "Available policy types for inference",
    }
# Session management endpoints
async def create_session(request: CreateSessionRequest):
    """
    Create a new inference session.

    If workspace_id is provided, all rooms will be created in that workspace.
    If workspace_id is not provided, a new workspace will be generated automatically.
    All rooms for a session (cameras + joints) are always created in the same workspace.

    Raises:
        HTTPException: 400 if the session manager rejects the request with a
            ``ValueError``; 500 for any other failure, including a manager
            result that does not fit ``CreateSessionResponse``.
    """
    try:
        room_ids = await session_manager.create_session(
            session_id=request.session_id,
            policy_path=request.policy_path,
            camera_names=request.camera_names,
            arena_server_url=request.arena_server_url,
            workspace_id=request.workspace_id,
            policy_type=request.policy_type,
            language_instruction=request.language_instruction,
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception("Failed to create session %s", request.session_id)
        raise HTTPException(
            status_code=500, detail=f"Failed to create session: {e!s}"
        ) from e

    # Build the response outside the ValueError handler above. pydantic's
    # ValidationError subclasses ValueError, so a malformed manager result
    # would otherwise be misreported to the client as a 400.
    try:
        return CreateSessionResponse(**room_ids)
    except Exception as e:
        logger.exception("Failed to create session %s", request.session_id)
        raise HTTPException(
            status_code=500, detail=f"Failed to create session: {e!s}"
        ) from e
async def list_sessions():
    """Return every session's status, each wrapped in a SessionStatusResponse."""
    all_sessions = await session_manager.list_sessions()
    return [SessionStatusResponse(**entry) for entry in all_sessions]
async def get_session_status(session_id: str):
    """Get status of a specific session.

    Raises:
        HTTPException: 404 if the session manager raises ``KeyError`` for
            ``session_id``.
    """
    try:
        status = await session_manager.get_session_status(session_id)
    except KeyError:
        # The KeyError is an expected "not found" signal; don't chain it.
        raise HTTPException(
            status_code=404, detail=f"Session {session_id} not found"
        ) from None
    return SessionStatusResponse(**status)
async def start_inference(session_id: str):
    """Start inference for a session.

    Raises:
        HTTPException: 404 if the session manager raises ``KeyError``;
            500 (logged with traceback) for any other failure.
    """
    try:
        await session_manager.start_inference(session_id)
    except KeyError:
        raise HTTPException(
            status_code=404, detail=f"Session {session_id} not found"
        ) from None
    except Exception as e:
        logger.exception("Failed to start inference for session %s", session_id)
        raise HTTPException(
            status_code=500, detail=f"Failed to start inference: {e!s}"
        ) from e
    return {"message": f"Inference started for session {session_id}"}
async def stop_inference(session_id: str):
    """Stop inference for a session.

    Raises:
        HTTPException: 404 if the session manager raises ``KeyError``;
            500 (logged with traceback) for any other failure, matching
            ``start_inference`` and ``restart_inference``.
    """
    try:
        await session_manager.stop_inference(session_id)
    except KeyError:
        raise HTTPException(
            status_code=404, detail=f"Session {session_id} not found"
        ) from None
    except Exception as e:
        logger.exception("Failed to stop inference for session %s", session_id)
        raise HTTPException(
            status_code=500, detail=f"Failed to stop inference: {e!s}"
        ) from e
    return {"message": f"Inference stopped for session {session_id}"}
async def restart_inference(session_id: str):
    """Restart inference for a session.

    Raises:
        HTTPException: 404 if the session manager raises ``KeyError``;
            500 (logged with traceback) for any other failure.
    """
    try:
        await session_manager.restart_inference(session_id)
    except KeyError:
        raise HTTPException(
            status_code=404, detail=f"Session {session_id} not found"
        ) from None
    except Exception as e:
        logger.exception("Failed to restart inference for session %s", session_id)
        raise HTTPException(
            status_code=500, detail=f"Failed to restart inference: {e!s}"
        ) from e
    return {"message": f"Inference restarted for session {session_id}"}
async def delete_session(session_id: str):
    """Delete a session.

    Raises:
        HTTPException: 404 if the session manager raises ``KeyError``;
            500 (logged with traceback) for any other failure, matching the
            other session endpoints.
    """
    try:
        await session_manager.delete_session(session_id)
    except KeyError:
        raise HTTPException(
            status_code=404, detail=f"Session {session_id} not found"
        ) from None
    except Exception as e:
        logger.exception("Failed to delete session %s", session_id)
        raise HTTPException(
            status_code=500, detail=f"Failed to delete session: {e!s}"
        ) from e
    return {"message": f"Session {session_id} deleted"}
# Debug endpoints for enhanced monitoring
async def get_system_info():
    """Get system information for debugging.

    Collects CPU usage, memory and root-disk usage and, when CUDA is
    available, basic GPU memory statistics.

    Returns:
        A dict of metrics, or ``{"error": "..."}`` if anything fails,
        including a missing ``psutil``/``torch`` import.
    """
    import asyncio

    try:
        import psutil
        import torch

        # cpu_percent(interval=1) blocks for one second while sampling. Run it
        # in a worker thread so this async endpoint does not stall the event
        # loop (and every other request served by it) for that second.
        cpu_percent = await asyncio.to_thread(psutil.cpu_percent, interval=1)
        # Take a single snapshot of each so the reported fields agree.
        memory = psutil.virtual_memory()
        disk = psutil.disk_usage("/")

        system_info = {
            "cpu_percent": cpu_percent,
            "memory": {
                "total": memory.total,
                "available": memory.available,
                "percent": memory.percent,
            },
            "disk": {
                "total": disk.total,
                "used": disk.used,
                "percent": disk.percent,
            },
        }

        # GPU info if available
        if torch.cuda.is_available():
            system_info["gpu"] = {
                "device_count": torch.cuda.device_count(),
                "current_device": torch.cuda.current_device(),
                "device_name": torch.cuda.get_device_name(),
                "memory_allocated": torch.cuda.memory_allocated(),
                # Key name kept for API compatibility; the value is reserved memory.
                "memory_cached": torch.cuda.memory_reserved(),
            }
        return system_info
    except Exception as e:
        return {"error": f"Failed to get system info: {e}"}
async def get_recent_logs():
    """Placeholder debug endpoint for recent log entries.

    No log files are read yet. The response only confirms that the endpoint
    exists and reports how many sessions are active. In production you might
    want to read from actual log files.
    """
    try:
        session_count = len(session_manager.sessions)
    except Exception as e:
        return {"error": f"Failed to get logs: {e}"}
    return {
        "message": "Log endpoint available",
        "note": "Implement actual log reading if needed",
        "active_sessions": session_count,
    }
async def debug_reset_session(session_id: str):
    """Reset a session's internal state for debugging.

    Resets the inference engine (if the session has one), clears the action
    queue and marks every camera image and the joint data as not updated.

    Raises:
        HTTPException: 404 if ``session_id`` is not a known session; 500
            (logged with traceback) if the reset itself fails.
    """
    # Resolve the session outside the generic handler below. Previously the
    # 404 HTTPException was raised inside the try, caught by
    # ``except Exception`` and re-raised as a 500.
    try:
        session = session_manager.sessions[session_id]
    except KeyError:
        raise HTTPException(
            status_code=404, detail=f"Session {session_id} not found"
        ) from None

    try:
        # Reset inference engine if available
        if session.inference_engine:
            session.inference_engine.reset()
        # Clear action queue
        session.action_queue.clear()
        # Reset flags
        for camera_name in session.camera_names:
            session.images_updated[camera_name] = False
        session.joints_updated = False
    except Exception as e:
        logger.exception("Failed to reset session %s", session_id)
        raise HTTPException(
            status_code=500, detail=f"Failed to reset session: {e}"
        ) from e
    return {"message": f"Session {session_id} state reset successfully"}
async def get_session_queue_info(session_id: str):
    """Get detailed information about a session's action queue.

    Also reports which inputs have been received: joint data present, images
    present per camera, and the per-input "updated" flags.

    Raises:
        HTTPException: 404 if ``session_id`` is not a known session; 500
            (logged with traceback) if reading the session's state fails.
    """
    # Resolve the session outside the generic handler below so the 404 is not
    # caught by ``except Exception`` and re-raised as a 500.
    try:
        session = session_manager.sessions[session_id]
    except KeyError:
        raise HTTPException(
            status_code=404, detail=f"Session {session_id} not found"
        ) from None

    try:
        return {
            "session_id": session_id,
            "queue_length": len(session.action_queue),
            "queue_maxlen": session.action_queue.maxlen,
            "n_action_steps": session.n_action_steps,
            "control_frequency_hz": session.control_frequency_hz,
            "inference_frequency_hz": session.inference_frequency_hz,
            "last_queue_cleanup": session.last_queue_cleanup,
            "data_status": {
                "has_joint_data": session.latest_joint_positions is not None,
                # True for each camera that has an entry in latest_images
                "images_status": {
                    camera: camera in session.latest_images
                    for camera in session.camera_names
                },
                # Copy so the response does not alias the session's live dict
                "images_updated": session.images_updated.copy(),
                "joints_updated": session.joints_updated,
            },
        }
    except Exception as e:
        logger.exception("Failed to get queue info for session %s", session_id)
        raise HTTPException(
            status_code=500, detail=f"Failed to get queue info: {e}"
        ) from e
# Main entry point
if __name__ == "__main__":
    import uvicorn

    # The PORT environment variable overrides the default port 8001.
    server_port = int(os.getenv("PORT", "8001"))
    uvicorn_options = {
        "host": "0.0.0.0",
        "port": server_port,
        "reload": False,
        "log_level": "info",
    }
    uvicorn.run("inference_server.main:app", **uvicorn_options)