Spaces:
Sleeping
Sleeping
File size: 4,015 Bytes
10b8972 97a4ae5 10b8972 97a4ae5 4641c1c 97a4ae5 10b8972 97a4ae5 4641c1c 97a4ae5 4641c1c 10b8972 4641c1c 97a4ae5 4641c1c 97a4ae5 4641c1c 10b8972 97a4ae5 4641c1c 10b8972 4641c1c 97a4ae5 10b8972 4641c1c 10b8972 4641c1c 97a4ae5 4641c1c 10b8972 97a4ae5 10b8972 97a4ae5 10b8972 97a4ae5 10b8972 97a4ae5 10b8972 97a4ae5 10b8972 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Response
from fastapi.middleware.cors import CORSMiddleware
from shared import RealtimeSpeakerDiarization
import os
import uvicorn
import logging
import asyncio
# Set up logging: the root logger is configured at INFO level
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module; used by every handler below
logger = logging.getLogger(__name__)
# Initialize FastAPI app; all routes and middleware below register on this instance
app = FastAPI()
# Register the same handler for HEAD / (kept out of the schema) and GET /,
# so port scanners probing with HEAD get a 200 instead of a 405.
@app.head("/", include_in_schema=False)
@app.get("/")
async def root():
    """Return a short JSON banner identifying this service."""
    banner = {"message": "Speaker Diarization Signaling Server"}
    return banner
# Add CORS middleware for browser compatibility.
# NOTE(review): FastAPI's CORS documentation advises against combining
# wildcard origins/methods/headers with allow_credentials=True -- confirm
# whether credentialed cross-origin requests are actually required here.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# Initialize the diarization system at import time so that every handler
# below shares the single `diart` instance.
logger.info("Initializing diarization system...")
diart = RealtimeSpeakerDiarization()
success = diart.initialize_models()
logger.info(f"Models initialized: {success}")
if success:
    diart.start_recording()
else:
    # Surface the failure at ERROR level: the server keeps running, but
    # recording was never started.
    logger.error("Diarization models failed to initialize; recording not started")

# Track active WebSocket connections (added/removed by the WebSocket
# endpoint, read by the periodic broadcaster and the health check)
active_connections = set()
# Periodic status update function
async def send_conversation_updates(interval: float = 0.5):
    """Periodically send conversation updates to all connected clients.

    Loops until the task is cancelled. On each pass, if at least one client
    is registered in ``active_connections``, the formatted conversation is
    fetched once from ``diart`` and sent as text to every socket. A socket
    whose send raises is logged and removed from ``active_connections``.

    Args:
        interval: Seconds to sleep between passes. Defaults to 0.5 (500 ms).
    """
    while True:
        if active_connections:
            try:
                conversation_html = diart.get_formatted_conversation()
            except Exception:
                # A failure while building the payload must not terminate the
                # broadcaster: log it and try again on the next pass.
                logger.exception("Error building conversation update")
            else:
                # Iterate over a snapshot because the set is mutated both
                # here and by the WebSocket endpoint while we await.
                for ws in list(active_connections):
                    try:
                        await ws.send_text(conversation_html)
                    except Exception as e:
                        logger.error(f"Error sending to WebSocket: {e}")
                        active_connections.discard(ws)
        # Sleep on every pass, including idle ones, so the loop never spins.
        await asyncio.sleep(interval)
@app.on_event("startup")
async def startup_event():
    """Start background tasks when the app starts.

    Launches the periodic conversation broadcaster and stores the task on
    ``app.state.conversation_updates_task``.
    """
    # The event loop keeps only weak references to tasks, so an unreferenced
    # task may be garbage-collected mid-execution; hold a strong reference.
    app.state.conversation_updates_task = asyncio.create_task(
        send_conversation_updates()
    )
@app.get("/health")
async def health_check():
    """Return a fixed "healthy" status with ``diart.is_running`` and the
    number of tracked WebSocket connections."""
    running = diart.is_running
    connection_count = len(active_connections)
    report = {
        "status": "healthy",
        "system_running": running,
        "active_connections": connection_count,
    }
    return report
@app.websocket("/ws_inference")
async def ws_inference(ws: WebSocket):
    """Accept a WebSocket, register it, and feed its binary frames to ``diart``.

    The current formatted conversation is sent once right after the
    connection is registered. Each non-empty binary message is then passed
    to ``diart.process_audio_chunk``; a failure on one chunk is logged and
    the loop continues. The socket is always removed from
    ``active_connections`` on exit.
    """
    await ws.accept()
    active_connections.add(ws)
    logger.info(f"WebSocket connected. Total: {len(active_connections)}")

    try:
        # Send initial conversation state
        initial_html = diart.get_formatted_conversation()
        await ws.send_text(initial_html)

        # Process incoming audio chunks, skipping empty frames
        async for chunk in ws.iter_bytes():
            if not chunk:
                continue
            try:
                diart.process_audio_chunk(chunk)
            except Exception as chunk_err:
                logger.error(f"Error processing chunk: {chunk_err}")
    except WebSocketDisconnect:
        logger.info("WebSocket disconnected")
    except Exception as ws_err:
        logger.error(f"WebSocket error: {ws_err}")
    finally:
        # Unregister regardless of how the handler ended
        active_connections.discard(ws)
        logger.info(f"WebSocket closed. Remaining: {len(active_connections)}")
@app.get("/conversation")
async def get_conversation():
    """Return the output of ``diart.get_formatted_conversation()`` under the
    ``"conversation"`` key."""
    conversation = diart.get_formatted_conversation()
    return {"conversation": conversation}
@app.get("/status")
async def get_status():
    """Return the output of ``diart.get_status_info()`` under the
    ``"status"`` key."""
    status_info = diart.get_status_info()
    return {"status": status_info}
@app.post("/settings")
async def update_settings(threshold: float, max_speakers: int):
    """Update speaker detection settings.

    Both arguments are forwarded unchanged to ``diart.update_settings``;
    whatever it returns is wrapped under the ``"result"`` key.
    """
    outcome = diart.update_settings(threshold, max_speakers)
    return {"result": outcome}
@app.post("/clear")
async def clear_conversation():
    """Call ``diart.clear_conversation()`` and return its result under the
    ``"result"`` key."""
    outcome = diart.clear_conversation()
    return {"result": outcome}
# Mount Gradio UI at /ui so it doesn't override API/WebSocket routes.
# The UI is optional: if it cannot be imported the server runs API-only.
try:
    import ui
    ui.mount_ui(app, path="/ui")
    logger.info("Gradio UI mounted at /ui")
except ImportError as exc:
    # ImportError is also raised when a dependency of `ui` is missing, so
    # include the reason instead of discarding it.
    logger.warning("UI module not found, running in API-only mode (%s)", exc)
if __name__ == "__main__":
    # Listen port comes from the PORT environment variable (default 10000)
    port = int(os.environ.get("PORT", 10000))
    uvicorn.run(
        "backend:app",
        host="0.0.0.0",
        port=port,
    )
|