from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Response
from fastapi.middleware.cors import CORSMiddleware
from shared import RealtimeSpeakerDiarization
import os
import uvicorn
import logging
import asyncio

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Seconds between conversation broadcasts to connected WebSocket clients.
UPDATE_INTERVAL_SECONDS = 0.5

# Initialize FastAPI app
app = FastAPI()


# Respond to HEAD / with a 200 so port scanners don't see a 405
@app.head("/", include_in_schema=False)
@app.get("/")
async def root():
    """Return a static banner identifying this server."""
    return {"message": "Speaker Diarization Signaling Server"}


# Add CORS middleware for browser compatibility
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize the diarization system at import time so it is ready before the
# first request. Recording is only started when model initialization succeeds.
logger.info("Initializing diarization system...")
diart = RealtimeSpeakerDiarization()
success = diart.initialize_models()
logger.info("Models initialized: %s", success)
if success:
    diart.start_recording()
else:
    logger.warning("Model initialization failed; recording not started")

# Track active WebSocket connections
active_connections = set()

# Strong reference to the background broadcast task. asyncio keeps only weak
# references to running tasks, so an unreferenced task may be collected
# mid-execution.
_update_task = None


async def send_conversation_updates():
    """Periodically push the formatted conversation to every connected client.

    Loops until cancelled. A client whose send fails is removed from
    ``active_connections``. If building the conversation text raises, the
    error is logged and the loop continues. Without this, the task would die
    silently and every client would stop receiving updates.
    """
    while True:
        if active_connections:
            try:
                conversation_html = diart.get_formatted_conversation()
            except Exception:
                logger.exception("Error formatting conversation")
            else:
                # Iterate over a snapshot: ws_inference may add or remove
                # sockets while we are suspended in an await below.
                for ws in list(active_connections):
                    try:
                        await ws.send_text(conversation_html)
                    except Exception as e:
                        logger.error("Error sending to WebSocket: %s", e)
                        active_connections.discard(ws)
        await asyncio.sleep(UPDATE_INTERVAL_SECONDS)


@app.on_event("startup")
async def startup_event():
    """Start background tasks when the app starts"""
    global _update_task
    _update_task = asyncio.create_task(send_conversation_updates())


@app.on_event("shutdown")
async def shutdown_event():
    """Cancel the background broadcast task when the app shuts down."""
    if _update_task is not None:
        _update_task.cancel()
        try:
            await _update_task
        except asyncio.CancelledError:
            pass


@app.get("/health")
async def health_check():
    """Health check endpoint reporting run state and connection count."""
    return {
        "status": "healthy",
        "system_running": diart.is_running,
        "active_connections": len(active_connections),
    }


@app.websocket("/ws_inference")
async def ws_inference(ws: WebSocket):
    """WebSocket endpoint for real-time audio processing.

    On connect, the socket is registered in ``active_connections`` and sent
    the current conversation once. After that, every non-empty binary frame
    is passed to ``diart.process_audio_chunk``. The socket is always removed
    from ``active_connections`` when the handler exits.
    """
    await ws.accept()
    active_connections.add(ws)
    logger.info("WebSocket connected. Total: %d", len(active_connections))
    try:
        # Send initial conversation state
        await ws.send_text(diart.get_formatted_conversation())

        # Process incoming audio chunks
        async for chunk in ws.iter_bytes():
            if not chunk:
                continue
            try:
                # NOTE(review): this is a synchronous call on the event loop.
                # Slow processing here delays broadcasts and other clients.
                # Consider offloading it once diart's thread-safety is
                # confirmed.
                diart.process_audio_chunk(chunk)
            except Exception as e:
                # One bad chunk should not tear down the connection.
                logger.exception("Error processing chunk: %s", e)
    except WebSocketDisconnect:
        logger.info("WebSocket disconnected")
    except Exception as e:
        logger.exception("WebSocket error: %s", e)
    finally:
        active_connections.discard(ws)
        logger.info("WebSocket closed. Remaining: %d", len(active_connections))


@app.get("/conversation")
async def get_conversation():
    """Get the current conversation as HTML"""
    return {"conversation": diart.get_formatted_conversation()}


@app.get("/status")
async def get_status():
    """Get system status information"""
    return {"status": diart.get_status_info()}


@app.post("/settings")
async def update_settings(threshold: float, max_speakers: int):
    """Update speaker detection settings (passed as query parameters)."""
    return {"result": diart.update_settings(threshold, max_speakers)}


@app.post("/clear")
async def clear_conversation():
    """Clear the conversation"""
    return {"result": diart.clear_conversation()}


# Mount Gradio UI at /ui so it doesn't override API/WebSocket routes
try:
    import ui

    ui.mount_ui(app, path="/ui")
    logger.info("Gradio UI mounted at /ui")
except ImportError as e:
    # Include the cause: the ImportError may come from a dependency of `ui`
    # rather than from `ui` itself being absent.
    logger.warning("UI module not found, running in API-only mode (%s)", e)


if __name__ == "__main__":
    port = int(os.getenv("PORT", 10000))
    # Pass the app object instead of the "backend:app" import string. With the
    # string, uvicorn imports this file a second time under the module name
    # "backend". That re-runs model initialization and start_recording().
    uvicorn.run(app, host="0.0.0.0", port=port)