Spaces:
Sleeping
Sleeping
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Response | |
from fastapi.middleware.cors import CORSMiddleware | |
from shared import RealtimeSpeakerDiarization | |
import os | |
import uvicorn | |
import logging | |
import asyncio | |
# Configure root logging at INFO and create the logger used throughout this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The FastAPI application that the middleware and handlers below attach to.
app = FastAPI()
# Answer both GET / and HEAD / with a 200 so port scanners don't see a 405.
@app.get("/")
@app.head("/")
async def root():
    """Return a short banner identifying this service.

    Returns:
        A dict with a single ``"message"`` key naming the server.
    """
    return {"message": "Speaker Diarization Signaling Server"}
# Add CORS middleware for browser compatibility.
# Credentials are disabled on purpose: a wildcard origin combined with
# allow_credentials=True is not a valid CORS configuration and would let any
# website make credentialed requests to this API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Build the diarization system at import time and load its models.
logger.info("Initializing diarization system...")
diart = RealtimeSpeakerDiarization()
success = diart.initialize_models()
logger.info(f"Models initialized: {success}")

# Recording is started only when model initialization reported success.
if success:
    diart.start_recording()

# Registry of currently open WebSocket connections; used by the broadcaster
# and the health check.
active_connections = set()
async def _broadcast_conversation():
    """Send the current formatted conversation to every connected client once."""
    if not active_connections:
        return

    try:
        conversation_html = diart.get_formatted_conversation()
    except Exception:
        # A failure to format one snapshot must not kill the broadcast loop;
        # log it with a traceback and try again on the next tick.
        logger.exception("Error formatting conversation for broadcast")
        return

    # Iterate over a snapshot: the set is mutated below and may also change
    # while this coroutine is suspended in send_text().
    for ws in list(active_connections):
        try:
            await ws.send_text(conversation_html)
        except Exception as e:
            logger.error("Error sending to WebSocket: %s", e)
            # Drop connections that can no longer be written to.
            active_connections.discard(ws)


# Periodic status update function
async def send_conversation_updates(interval: float = 0.5):
    """Periodically send conversation updates to all connected clients.

    Runs until cancelled. On every tick the formatted conversation is pushed
    to each connection in ``active_connections``; connections whose send
    fails are removed from the registry.

    Args:
        interval: Seconds to sleep between broadcasts (default 0.5, i.e. a
            500 ms update interval).
    """
    while True:
        await _broadcast_conversation()
        await asyncio.sleep(interval)
# Strong references to running background tasks. The event loop only keeps
# weak references, so an unreferenced task can be garbage-collected before
# it finishes.
_background_tasks = set()


@app.on_event("startup")
async def startup_event():
    """Start background tasks when the app starts.

    Schedules the periodic conversation broadcaster and keeps a reference to
    the task until it completes.
    """
    task = asyncio.create_task(send_conversation_updates())
    _background_tasks.add(task)
    # Release the reference once the task is done.
    task.add_done_callback(_background_tasks.discard)
# NOTE(review): no route decorator is visible on this handler here; confirm
# it is registered on `app`.
async def health_check():
    """Report service health.

    Returns:
        A dict with a fixed ``"status"`` of ``"healthy"``, the diarization
        system's ``is_running`` flag, and the number of open WebSocket
        connections.
    """
    report = {"status": "healthy"}
    report["system_running"] = diart.is_running
    report["active_connections"] = len(active_connections)
    return report
# NOTE(review): no WebSocket route decorator is visible on this handler here;
# confirm it is registered on `app`.
async def ws_inference(ws: WebSocket):
    """WebSocket endpoint for real-time audio processing.

    Accepts the connection, registers it in ``active_connections``, sends the
    current formatted conversation once, then feeds every non-empty binary
    frame to ``diart.process_audio_chunk``. The connection is always removed
    from the registry when the handler exits.

    Args:
        ws: The incoming WebSocket connection.
    """
    await ws.accept()
    active_connections.add(ws)
    logger.info("WebSocket connected. Total: %d", len(active_connections))

    try:
        # Send initial conversation state
        await ws.send_text(diart.get_formatted_conversation())

        # Process incoming audio chunks
        async for chunk in ws.iter_bytes():
            if not chunk:
                continue
            try:
                diart.process_audio_chunk(chunk)
            except Exception as e:
                # One bad chunk must not end the session; keep the traceback
                # so the failure can be diagnosed.
                logger.exception("Error processing chunk: %s", e)
    except WebSocketDisconnect:
        logger.info("WebSocket disconnected")
    except Exception as e:
        logger.exception("WebSocket error: %s", e)
    finally:
        active_connections.discard(ws)
        logger.info("WebSocket closed. Remaining: %d", len(active_connections))
# NOTE(review): no route decorator is visible on this handler here; confirm
# it is registered on `app`.
async def get_conversation():
    """Return the current conversation, formatted as HTML, under ``"conversation"``."""
    conversation = diart.get_formatted_conversation()
    return {"conversation": conversation}
# NOTE(review): no route decorator is visible on this handler here; confirm
# it is registered on `app`.
async def get_status():
    """Return the diarization system's status information under ``"status"``."""
    status_info = diart.get_status_info()
    return {"status": status_info}
# NOTE(review): no route decorator is visible on this handler here; confirm
# it is registered on `app`.
async def update_settings(threshold: float, max_speakers: int):
    """Update speaker detection settings.

    Args:
        threshold: Forwarded unchanged to ``diart.update_settings``.
        max_speakers: Forwarded unchanged to ``diart.update_settings``.

    Returns:
        A dict whose ``"result"`` is whatever ``diart.update_settings`` returned.
    """
    outcome = diart.update_settings(threshold, max_speakers)
    return {"result": outcome}
# NOTE(review): no route decorator is visible on this handler here; confirm
# it is registered on `app`.
async def clear_conversation():
    """Clear the conversation and return the diarizer's result under ``"result"``."""
    outcome = diart.clear_conversation()
    return {"result": outcome}
# Mount Gradio UI at /ui so it doesn't override API/WebSocket routes.
# The UI is optional: any ImportError leaves the server running API-only.
try:
    import ui

    ui.mount_ui(app, path="/ui")
    logger.info("Gradio UI mounted at /ui")
except ImportError as exc:
    # Include the underlying error: the failure may come from one of the UI
    # module's own dependencies rather than from the module being absent.
    logger.warning("UI module not found, running in API-only mode (%s)", exc)
if __name__ == "__main__":
    # Listen on the port given by the PORT environment variable, defaulting to 10000.
    port = int(os.getenv("PORT", "10000"))
    # Pass the app object rather than an import string: a string target makes
    # uvicorn import this file a second time under another module name, which
    # would repeat the module-level model initialization and recording start,
    # and would only work if the file is named exactly as the string says.
    uvicorn.run(app, host="0.0.0.0", port=port)