Speaker-Diarization / inference.py
Saiyaswanth007's picture
changing from /
7c64918
raw
history blame
4.34 kB
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from shared import RealtimeSpeakerDiarization
import numpy as np
import uvicorn
import logging
import asyncio
# Set up logging
# Configure the root logger at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize FastAPI app
# `app` is the application object that the routes and middleware below attach to.
app = FastAPI()
# Add CORS middleware for browser compatibility
# Every origin, method and header is accepted, with credentials enabled.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive — confirm this is intended for anything beyond a demo deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Initialize the diarization system
# Runs once at import time: build the engine, load its models and start
# recording only when the models loaded successfully.
logger.info("Initializing diarization system...")
diart = RealtimeSpeakerDiarization()
success = diart.initialize_models()
logger.info(f"Models initialized: {success}")
if success:
    diart.start_recording()
else:
    # Keep the module importable, but make the degraded state visible
    # instead of silently skipping the recording start.
    logger.error("Model initialization failed; recording was not started")
# Track active WebSocket connections
# Sockets are added and removed by the WebSocket endpoint, pruned by the
# periodic broadcaster when a send fails, and counted by the health check.
active_connections = set()
# Periodic status update function
async def _broadcast_conversation(conversation_html):
    """Send *conversation_html* to every tracked socket, dropping failed ones."""
    # Iterate over a copy: sockets may be discarded while we loop.
    for client in active_connections.copy():
        try:
            await client.send_text(conversation_html)
        except Exception as e:
            logger.error(f"Error sending to WebSocket: {e}")
            active_connections.discard(client)


async def send_conversation_updates():
    """Push the formatted conversation to all connected clients forever.

    Every 500 ms, if at least one client is connected, the current
    conversation is fetched from the diarization engine and sent to each
    socket. A socket whose send fails is logged and removed from
    ``active_connections``; any other failure in a cycle is logged and the
    loop carries on.
    """
    while True:
        if active_connections:
            try:
                await _broadcast_conversation(diart.get_formatted_conversation())
            except Exception as e:
                logger.error(f"Error in conversation update: {e}")
        # Pause between cycles whether or not anything was sent.
        await asyncio.sleep(0.5)  # 500ms update interval
@app.on_event("startup")
async def startup_event():
    """Start background tasks when the app starts.

    Launches ``send_conversation_updates`` as an asyncio task and keeps a
    strong reference to it on ``app.state``. The event loop only holds weak
    references to tasks, so a task whose handle is discarded may be
    garbage-collected before it finishes.
    """
    app.state.conversation_update_task = asyncio.create_task(
        send_conversation_updates()
    )
@app.get("/health")
async def health_check():
    """Return a health payload.

    The payload carries a fixed ``"healthy"`` status, the diarization
    engine's ``is_running`` attribute and the number of tracked WebSocket
    connections.
    """
    payload = {"status": "healthy"}
    payload["system_running"] = diart.is_running
    payload["active_connections"] = len(active_connections)
    return payload
@app.websocket("/ws_inference")
async def ws_inference(ws: WebSocket):
    """Serve one WebSocket client that streams audio to the diarization engine.

    The socket is accepted and registered in ``active_connections``, the
    current conversation is sent once, then every non-empty binary frame is
    handed to ``diart.process_audio_chunk``. A chunk that fails to process is
    logged and skipped. The socket is always unregistered on exit.
    """
    await ws.accept()
    active_connections.add(ws)
    logger.info(f"WebSocket connection established. Total connections: {len(active_connections)}")
    try:
        # Bring the new client up to date before reading any audio.
        await ws.send_text(diart.get_formatted_conversation())
        async for audio_bytes in ws.iter_bytes():
            if not audio_bytes:
                continue  # empty frame: nothing to process
            try:
                # Processing a chunk updates the engine's internal conversation state.
                diart.process_audio_chunk(audio_bytes)
            except Exception as e:
                logger.error(f"Error processing audio chunk: {e}")
    except WebSocketDisconnect:
        logger.info("WebSocket disconnected")
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
    finally:
        active_connections.discard(ws)
        logger.info(f"WebSocket connection closed. Remaining connections: {len(active_connections)}")
@app.get("/conversation")
async def get_conversation():
    """Return the engine's formatted conversation under the ``conversation`` key."""
    rendered = diart.get_formatted_conversation()
    return {"conversation": rendered}
@app.get("/status")
async def get_status():
    """Return the engine's status information under the ``status`` key."""
    status_info = diart.get_status_info()
    return {"status": status_info}
@app.post("/settings")
async def update_settings(threshold: float, max_speakers: int):
    """Forward new speaker-detection settings to the diarization engine.

    Args:
        threshold: Passed as the first argument of ``diart.update_settings``.
        max_speakers: Passed as the second argument of
            ``diart.update_settings``.

    Returns:
        A dict whose ``result`` key holds whatever the engine returned.
    """
    return {"result": diart.update_settings(threshold, max_speakers)}
@app.post("/clear")
async def clear_conversation():
    """Ask the engine to clear the conversation and return its result.

    Returns:
        A dict whose ``result`` key holds the value returned by
        ``diart.clear_conversation``.
    """
    return {"result": diart.clear_conversation()}
# Import UI module to mount the Gradio app
# The UI is optional: when it cannot be imported, the module continues
# without it and only the routes defined in this file are served.
try:
    import ui
    ui.mount_ui(app)
    logger.info("Gradio UI mounted successfully")
except ImportError as import_error:
    # An ImportError can also come from a dependency of `ui`, not only from
    # `ui` itself being absent, so report the underlying cause.
    logger.warning(
        "UI module not found, running in API-only mode (%s)", import_error
    )
if __name__ == "__main__":
    # Run the app with uvicorn on all interfaces, port 7860, when executed as a script.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )