File size: 4,015 Bytes
10b8972
97a4ae5
 
10b8972
97a4ae5
 
4641c1c
97a4ae5
 
 
 
 
 
 
 
10b8972
 
 
 
 
 
97a4ae5
 
 
 
 
 
 
 
 
 
4641c1c
97a4ae5
 
 
4641c1c
 
 
 
 
 
 
 
 
 
 
10b8972
 
 
 
 
 
 
4641c1c
 
 
 
 
 
97a4ae5
 
 
4641c1c
 
 
 
 
 
97a4ae5
 
 
 
 
4641c1c
10b8972
97a4ae5
4641c1c
10b8972
4641c1c
97a4ae5
10b8972
 
4641c1c
10b8972
 
4641c1c
 
97a4ae5
 
 
4641c1c
10b8972
97a4ae5
 
 
 
 
 
 
 
 
 
 
 
 
 
10b8972
97a4ae5
 
 
 
10b8972
97a4ae5
10b8972
97a4ae5
 
10b8972
 
97a4ae5
 
 
 
10b8972
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Response
from fastapi.middleware.cors import CORSMiddleware
from shared import RealtimeSpeakerDiarization
import os
import uvicorn
import logging
import asyncio

# Set up logging: configure the root logger at INFO and create this module's
# logger, which every handler below uses.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app. All HTTP routes, the WebSocket endpoint, and the
# middleware below are registered on this single instance.
app = FastAPI()

# HEAD / is registered alongside GET / so port scanners get a 200 rather than
# a 405; the HEAD variant is kept out of the generated schema.
@app.head("/", include_in_schema=False)
@app.get("/")
async def root():
    """Return a short banner identifying this service."""
    banner = {"message": "Speaker Diarization Signaling Server"}
    return banner

# Add CORS middleware for browser compatibility.
# Every origin, method, and header is allowed, with credentials enabled.
# NOTE(review): the FastAPI CORS docs say wildcard origins/methods/headers
# should not be combined with allow_credentials=True — confirm whether
# credentials are actually needed, or list the allowed origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize the diarization system.
# This runs at import time, so model loading happens before the app serves
# any request. Recording is started only when initialize_models() reports
# success; on failure the module still imports and `diart` stays available
# to the endpoints below.
logger.info("Initializing diarization system...")
diart = RealtimeSpeakerDiarization()
success = diart.initialize_models()
logger.info(f"Models initialized: {success}")
if success:
    diart.start_recording()

# Track active WebSocket connections.
# Sockets are added in ws_inference() after accept() and discarded either in
# its `finally` block or when a broadcast send fails.
active_connections = set()

# Periodic status update function
async def send_conversation_updates():
    """Broadcast the formatted conversation to all connected clients, forever.

    Every 500 ms, if at least one WebSocket is registered in
    ``active_connections``, the current conversation is fetched once from
    ``diart`` and sent as text to each socket. A socket whose send fails is
    logged and removed from the set.

    A failure while fetching the conversation is logged and that tick is
    skipped. Without this guard a single exception would end this coroutine,
    and with it all broadcasts until the process restarts.

    This coroutine never returns; it stops only when its task is cancelled.
    """
    while True:
        if active_connections:
            try:
                conversation_html = diart.get_formatted_conversation()
            except Exception:
                # Keep the broadcaster alive; try again on the next tick.
                logger.exception("Error building conversation update")
            else:
                # Iterate over a copy: the set is mutated below on send
                # failure and by ws_inference() while we await.
                for ws in list(active_connections):
                    try:
                        await ws.send_text(conversation_html)
                    except Exception as e:
                        logger.error(f"Error sending to WebSocket: {e}")
                        active_connections.discard(ws)
        await asyncio.sleep(0.5)  # 500ms update interval

@app.on_event("startup")
async def startup_event():
    """Start background tasks when the app starts.

    Launches the periodic conversation broadcaster and stores the resulting
    task on ``app.state.conversation_updates_task``.
    """
    # The event loop only keeps a weak reference to tasks, so a task whose
    # handle is dropped may be garbage-collected mid-execution. Holding it on
    # app.state keeps the broadcaster alive for the application's lifetime.
    app.state.conversation_updates_task = asyncio.create_task(
        send_conversation_updates()
    )

@app.get("/health")
async def health_check():
    """Report service health.

    Returns a dict with a fixed ``"healthy"`` status, the diarization
    system's ``is_running`` flag, and the number of registered WebSockets.
    """
    report = {"status": "healthy"}
    report["system_running"] = diart.is_running
    report["active_connections"] = len(active_connections)
    return report

@app.websocket("/ws_inference")
async def ws_inference(ws: WebSocket):
    """Stream a client's audio into the diarizer over a WebSocket.

    The socket is accepted and registered in ``active_connections`` so the
    periodic broadcaster also sends to it. The current conversation is sent
    once, then each non-empty binary frame is passed to
    ``diart.process_audio_chunk``. The socket is always unregistered on exit.
    """
    await ws.accept()
    active_connections.add(ws)
    logger.info(f"WebSocket connected. Total: {len(active_connections)}")
    try:
        # Bring the new client up to date right away.
        snapshot = diart.get_formatted_conversation()
        await ws.send_text(snapshot)

        async for audio in ws.iter_bytes():
            if not audio:
                continue  # ignore empty frames
            try:
                diart.process_audio_chunk(audio)
            except Exception as chunk_err:
                # A bad chunk is logged and dropped; the stream keeps going.
                logger.error(f"Error processing chunk: {chunk_err}")
    except WebSocketDisconnect:
        logger.info("WebSocket disconnected")
    except Exception as ws_err:
        logger.error(f"WebSocket error: {ws_err}")
    finally:
        active_connections.discard(ws)
        logger.info(f"WebSocket closed. Remaining: {len(active_connections)}")

@app.get("/conversation")
async def get_conversation():
    """Return the current formatted conversation under the ``conversation`` key."""
    rendered = diart.get_formatted_conversation()
    return {"conversation": rendered}

@app.get("/status")
async def get_status():
    """Return the diarization system's status info under the ``status`` key."""
    info = diart.get_status_info()
    return {"status": info}

@app.post("/settings")
async def update_settings(threshold: float, max_speakers: int):
    """Apply new speaker-detection settings.

    Both arguments are passed unchanged to ``diart.update_settings``; its
    return value is sent back under the ``result`` key.
    """
    outcome = diart.update_settings(threshold, max_speakers)
    return {"result": outcome}

@app.post("/clear")
async def clear_conversation():
    """Clear the conversation and return the diarizer's result under ``result``."""
    outcome = diart.clear_conversation()
    return {"result": outcome}

# Mount Gradio UI at /ui so it doesn't override API/WebSocket routes.
# The UI is optional: if the `ui` module cannot be imported, the server logs
# a warning and runs with the API only. Only ImportError is handled here;
# any other exception raised by mount_ui() propagates and aborts the import.
try:
    import ui
    ui.mount_ui(app, path="/ui")
    logger.info("Gradio UI mounted at /ui")
except ImportError:
    logger.warning("UI module not found, running in API-only mode")

if __name__ == "__main__":
    # Direct launch: listen on all interfaces, on $PORT or 10000 if unset.
    uvicorn.run(
        "backend:app",
        host="0.0.0.0",
        port=int(os.environ.get("PORT", 10000)),
    )