Speaker-Diarization / inference.py
Saiyaswanth007's picture
Code splitting
97a4ae5
raw
history blame
2.42 kB
from fastapi import FastAPI, WebSocket
from fastapi.middleware.cors import CORSMiddleware
from shared import RealtimeSpeakerDiarization
import uvicorn
import logging
# Set up logging: INFO level on the root logger, plus a module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI()

# Add CORS middleware for browser compatibility.
# Any origin, method and header is accepted.  Credentials are deliberately
# NOT allowed: the CORS specification does not permit a wildcard origin
# together with credentials, and enabling both would let any website issue
# credentialed cross-origin requests against this API.  If cookies or
# Authorization headers are ever needed cross-origin, replace the wildcard
# with an explicit list of trusted origins and re-enable credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Initialize the diarization system.
# A single module-level instance is created at import time; every endpoint
# below reads from or writes to this shared object.
diart = RealtimeSpeakerDiarization()
# The outcome of model initialization is only logged here.
success = diart.initialize_models()
logger.info(f"Models initialized: {success}")
# NOTE(review): recording is started regardless of the value of `success` --
# confirm that start_recording() copes with a failed initialization.
diart.start_recording()
@app.get("/health")
async def health_check():
    """Report that the service is up, plus the diarizer's running flag."""
    running = diart.is_running
    return {"status": "healthy", "system_running": running}
@app.websocket("/ws_inference")
async def ws_inference(ws: WebSocket):
    """WebSocket endpoint for real-time audio processing.

    Accepts the connection, then for every binary message received passes
    the bytes to ``diart.process_audio_chunk`` (with ``sample_rate=16000``)
    and sends the current ``diart.get_formatted_conversation()`` result back
    as a text frame.  Errors are logged and never propagated; the closing
    log line is always emitted.
    """
    # Imported locally so this handler does not depend on the exact contents
    # of the file-level import block.
    from fastapi import WebSocketDisconnect

    await ws.accept()
    logger.info("WebSocket connection established")
    try:
        async for chunk in ws.iter_bytes():
            # Process audio data
            diart.process_audio_chunk(chunk, sample_rate=16000)
            # Send back conversation results
            result = diart.get_formatted_conversation()
            await ws.send_text(result)
    except WebSocketDisconnect:
        # The peer can go away while a reply is being sent; that is a normal
        # end of session rather than a failure, so keep it out of the error
        # log.
        logger.info("WebSocket client disconnected")
    except Exception as exc:
        # Top-level boundary for this connection: record the full traceback
        # (logger.exception) instead of only the exception text.
        logger.exception("WebSocket error: %s", exc)
    finally:
        logger.info("WebSocket connection closed")
@app.get("/conversation")
async def get_conversation():
    """Return the current formatted conversation under the "conversation" key."""
    formatted = diart.get_formatted_conversation()
    return {"conversation": formatted}
@app.get("/status")
async def get_status():
    """Return the diarizer's status information under the "status" key."""
    info = diart.get_status_info()
    return {"status": info}
@app.post("/settings")
async def update_settings(threshold: float, max_speakers: int):
    """Apply new speaker detection settings.

    Forwards ``threshold`` and ``max_speakers`` to ``diart.update_settings``
    and returns whatever it reports under the "result" key.
    """
    return {"result": diart.update_settings(threshold, max_speakers)}
@app.post("/clear")
async def clear_conversation():
    """Clear the conversation and return the diarizer's response as "result"."""
    return {"result": diart.clear_conversation()}
# Import UI module to mount the Gradio app.
# Mounting is best-effort: any ImportError leaves the service running in
# API-only mode.
try:
    import ui

    ui.mount_ui(app)
except ImportError as exc:
    # ImportError is also raised when `ui` exists but one of its own imports
    # is missing, so include the cause to make the two cases distinguishable
    # in the log.
    logger.warning("UI module not found, running in API-only mode: %s", exc)
else:
    logger.info("Gradio UI mounted successfully")
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, port 7860.
    uvicorn.run(
        app,
        port=7860,
        host="0.0.0.0",
    )