Speaker-Diarization / inference.py
Saiyaswanth007's picture
changing from /
10b8972
raw
history blame
4.02 kB
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Response
from fastapi.middleware.cors import CORSMiddleware
from shared import RealtimeSpeakerDiarization
import os
import uvicorn
import logging
import asyncio
# Set up logging: configure the root logger at INFO level and create the
# module-scoped logger used for every message emitted by this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize FastAPI app; the routes, middleware and optional UI mount
# further down are all registered on this instance.
app = FastAPI()
# Answer HEAD / with a 200 so port scanners don't see a 405
@app.head("/", include_in_schema=False)
@app.get("/")
async def root():
    """Return a short banner identifying this service."""
    banner = {"message": "Speaker Diarization Signaling Server"}
    return banner
# Add CORS middleware for browser compatibility.
# NOTE(review): every origin, method and header is allowed here while
# allow_credentials=True. The FastAPI CORS documentation advises against
# combining wildcards with credentials -- confirm this is intended, or list
# the allowed origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Initialize the diarization system at import time: create the shared
# RealtimeSpeakerDiarization instance and call initialize_models() on it.
logger.info("Initializing diarization system...")
diart = RealtimeSpeakerDiarization()
success = diart.initialize_models()
logger.info(f"Models initialized: {success}")
# Recording is started only when initialize_models() returned a truthy
# value; otherwise the app still comes up, with start_recording() never called.
if success:
    diart.start_recording()
# Track active WebSocket connections; this set is read and mutated by the
# WebSocket endpoint and the broadcast loop, and counted by the health check.
active_connections = set()
# Periodic status update function
async def send_conversation_updates(interval: float = 0.5):
    """Broadcast the formatted conversation to every connected client, forever.

    On each cycle, if at least one client is connected, the current
    conversation is fetched from ``diart`` and sent as text to every socket
    in ``active_connections``. A socket whose send fails is logged and
    dropped from the set. The coroutine never returns on its own.

    Args:
        interval: Seconds to sleep between broadcast cycles. Defaults to
            0.5 (a 500ms update interval).
    """
    while True:
        if active_connections:
            try:
                conversation_html = diart.get_formatted_conversation()
            except Exception:
                # A formatting failure must not end this task: nothing would
                # restart it and every client would silently stop updating.
                logger.exception("Error formatting conversation")
            else:
                # Iterate over a snapshot: the set is mutated below and by the
                # WebSocket endpoint while this coroutine is suspended.
                for ws in tuple(active_connections):
                    try:
                        await ws.send_text(conversation_html)
                    except Exception as e:
                        logger.error(f"Error sending to WebSocket: {e}")
                        active_connections.discard(ws)
        # Sleep even when nobody is connected so the loop always yields
        # control back to the event loop.
        await asyncio.sleep(interval)
# Strong references to fire-and-forget tasks. The event loop itself only
# keeps weak references, so a task nobody else holds may be garbage-collected
# before it finishes.
_background_tasks = set()


@app.on_event("startup")
async def startup_event():
    """Start background tasks when the app starts.

    Launches the periodic conversation broadcast and keeps a reference to
    the resulting task for as long as it is running.
    """
    task = asyncio.create_task(send_conversation_updates())
    _background_tasks.add(task)
    # Release the reference once the task ends (cancelled or failed).
    task.add_done_callback(_background_tasks.discard)
@app.get("/health")
async def health_check():
    """Report a fixed "healthy" status, the diarizer's running flag and the client count."""
    running = diart.is_running
    connection_count = len(active_connections)
    return dict(
        status="healthy",
        system_running=running,
        active_connections=connection_count,
    )
@app.websocket("/ws_inference")
async def ws_inference(ws: WebSocket):
    """Accept a client socket, register it, and pass its audio chunks to the diarizer.

    The socket is added to ``active_connections`` right after the handshake
    and is always removed again when the handler exits, whatever the reason.
    """
    await ws.accept()
    active_connections.add(ws)
    logger.info(f"WebSocket connected. Total: {len(active_connections)}")
    try:
        # A new client immediately receives the current conversation
        initial_state = diart.get_formatted_conversation()
        await ws.send_text(initial_state)
        # Consume binary frames for as long as the client keeps sending
        async for audio in ws.iter_bytes():
            if not audio:
                continue  # nothing to process in an empty frame
            try:
                diart.process_audio_chunk(audio)
            except Exception as e:
                # A failing chunk is logged and skipped; the stream goes on
                logger.error(f"Error processing chunk: {e}")
    except WebSocketDisconnect:
        logger.info("WebSocket disconnected")
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
    finally:
        active_connections.discard(ws)
        logger.info(f"WebSocket closed. Remaining: {len(active_connections)}")
@app.get("/conversation")
async def get_conversation():
    """Return the diarizer's formatted conversation under the "conversation" key."""
    rendered = diart.get_formatted_conversation()
    return {"conversation": rendered}
@app.get("/status")
async def get_status():
    """Return whatever ``diart.get_status_info()`` reports, under the "status" key."""
    info = diart.get_status_info()
    return {"status": info}
@app.post("/settings")
async def update_settings(threshold: float, max_speakers: int):
    """Forward new speaker detection settings to the diarizer.

    Args:
        threshold: Passed as the first argument of ``diart.update_settings``.
        max_speakers: Passed as the second argument of ``diart.update_settings``.

    Returns:
        A dict holding the diarizer's return value under "result".
    """
    outcome = diart.update_settings(threshold, max_speakers)
    return {"result": outcome}
@app.post("/clear")
async def clear_conversation():
    """Ask the diarizer to clear the conversation and relay its return value."""
    outcome = diart.clear_conversation()
    return {"result": outcome}
# Mount Gradio UI at /ui so it doesn't override API/WebSocket routes.
# The UI is optional: when importing `ui` raises ImportError, a warning is
# logged and the server keeps running with only the routes defined above.
# NOTE(review): an ImportError raised from inside `ui` itself (e.g. one of
# its own dependencies is missing) is caught here too and is reported with
# the same "not found" message.
try:
    import ui
    ui.mount_ui(app, path="/ui")
    logger.info("Gradio UI mounted at /ui")
except ImportError:
    logger.warning("UI module not found, running in API-only mode")
if __name__ == "__main__":
    # PORT comes from the environment; fall back to 10000 when it is unset.
    port = int(os.getenv("PORT", "10000"))
    # Serve the app object defined in this file instead of the "backend:app"
    # import string. The string form makes uvicorn import a module named
    # `backend`: that fails when no such module is importable, and otherwise
    # runs that module's top-level code (model initialization included) in
    # addition to what this script has already executed.
    uvicorn.run(app, host="0.0.0.0", port=port)