Saiyaswanth007 committed on
Commit
8acaa5d
·
1 Parent(s): 1e3a43d

Added logging in inference

Browse files
Files changed (1) hide show
  1. inference.py +313 -60
inference.py CHANGED
@@ -1,17 +1,24 @@
1
- from fastapi import FastAPI, WebSocket, WebSocketDisconnect
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from shared import RealtimeSpeakerDiarization
4
  import numpy as np
5
  import uvicorn
6
  import logging
7
  import asyncio
 
 
 
 
8
 
9
  # Set up logging
10
- logging.basicConfig(level=logging.INFO)
 
 
 
11
  logger = logging.getLogger(__name__)
12
 
13
  # Initialize FastAPI app
14
- app = FastAPI()
15
 
16
  # Add CORS middleware for browser compatibility
17
  app.add_middleware(
@@ -22,115 +29,361 @@ app.add_middleware(
22
  allow_headers=["*"],
23
  )
24
 
25
- # Initialize the diarization system
26
- logger.info("Initializing diarization system...")
27
- diart = RealtimeSpeakerDiarization()
28
- success = diart.initialize_models()
29
- logger.info(f"Models initialized: {success}")
30
- if success:
31
- diart.start_recording()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
- # Track active WebSocket connections
34
- active_connections = set()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- # Periodic status update function
37
  async def send_conversation_updates():
38
  """Periodically send conversation updates to all connected clients"""
 
 
 
39
  while True:
40
- if active_connections:
41
- try:
42
- # Get current conversation HTML
43
  conversation_html = diart.get_formatted_conversation()
44
 
45
- # Send to all active connections
46
- for ws in active_connections.copy():
47
- try:
48
- await ws.send_text(conversation_html)
49
- except Exception as e:
50
- logger.error(f"Error sending to WebSocket: {e}")
51
- active_connections.discard(ws)
52
- except Exception as e:
53
- logger.error(f"Error in conversation update: {e}")
 
 
 
 
 
 
 
 
54
 
55
- # Wait before sending next update
56
- await asyncio.sleep(0.5) # 500ms update interval
57
 
58
  @app.on_event("startup")
59
  async def startup_event():
60
- """Start background tasks when the app starts"""
 
 
 
 
 
 
 
 
61
  asyncio.create_task(send_conversation_updates())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  @app.get("/health")
64
  @app.head("/health")
65
  async def health_check():
66
- """Health check endpoint"""
 
 
67
  return {
68
- "status": "healthy",
69
- "system_running": diart.is_running,
70
- "active_connections": len(active_connections)
 
 
71
  }
72
 
73
  @app.websocket("/ws_inference")
74
- async def ws_inference(ws: WebSocket):
75
  """WebSocket endpoint for real-time audio processing"""
76
- await ws.accept()
77
- active_connections.add(ws)
78
- logger.info(f"WebSocket connection established. Total connections: {len(active_connections)}")
79
 
80
  try:
81
- # Send initial conversation state
82
- conversation_html = diart.get_formatted_conversation()
83
- await ws.send_text(conversation_html)
 
 
 
 
 
 
 
84
 
85
- # Process incoming audio chunks
86
- async for chunk in ws.iter_bytes():
87
  try:
88
- # Process raw audio bytes
89
- if chunk:
90
- # Process audio data - this updates the internal conversation state
91
- diart.process_audio_chunk(chunk)
 
 
 
 
 
 
 
 
 
 
 
92
  except Exception as e:
93
  logger.error(f"Error processing audio chunk: {e}")
 
 
 
 
 
 
 
94
 
95
  except WebSocketDisconnect:
96
- logger.info("WebSocket disconnected")
97
  except Exception as e:
98
- logger.error(f"WebSocket error: {e}")
99
  finally:
100
- active_connections.discard(ws)
101
- logger.info(f"WebSocket connection closed. Remaining connections: {len(active_connections)}")
102
 
103
  @app.get("/conversation")
104
  @app.head("/conversation")
105
  async def get_conversation():
106
  """Get the current conversation as HTML"""
107
- return {"conversation": diart.get_formatted_conversation()}
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
  @app.get("/status")
110
  @app.head("/status")
111
  async def get_status():
112
- """Get system status information"""
113
- return {"status": diart.get_status_info()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
  @app.post("/settings")
116
- async def update_settings(threshold: float, max_speakers: int):
117
  """Update speaker detection settings"""
118
- result = diart.update_settings(threshold, max_speakers)
119
- return {"result": result}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
  @app.post("/clear")
122
  async def clear_conversation():
123
- """Clear the conversation"""
124
- result = diart.clear_conversation()
125
- return {"result": result}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
- # Import UI module to mount the Gradio app
128
  try:
129
  import ui
130
  ui.mount_ui(app)
131
  logger.info("Gradio UI mounted successfully")
132
  except ImportError:
133
  logger.warning("UI module not found, running in API-only mode")
 
 
 
 
 
134
 
135
  if __name__ == "__main__":
136
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from shared import RealtimeSpeakerDiarization
4
  import numpy as np
5
  import uvicorn
6
  import logging
7
  import asyncio
8
+ import json
9
+ import time
10
+ from typing import Set, Dict, Any
11
+ import traceback
12
 
13
  # Set up logging
14
+ logging.basicConfig(
15
+ level=logging.INFO,
16
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
17
+ )
18
  logger = logging.getLogger(__name__)
19
 
20
  # Initialize FastAPI app
21
+ app = FastAPI(title="Real-time Speaker Diarization API", version="1.0.0")
22
 
23
  # Add CORS middleware for browser compatibility
24
  app.add_middleware(
 
29
  allow_headers=["*"],
30
  )
31
 
32
# --- Module-level shared state ---------------------------------------------
# Diarization engine handle; None until initialize_diarization_system() sets it.
diart = None

# Module-level socket set (ConnectionManager also keeps a set of its own).
active_connections: Set[WebSocket] = set()

# Process-wide counters surfaced by the /health, /status and /stats endpoints.
connection_stats: Dict[str, Any] = dict(
    total_connections=0,
    current_connections=0,
    last_audio_received=None,
    total_audio_chunks=0,
)
41
+
42
class ConnectionManager:
    """Registry of live WebSocket clients with fan-out messaging.

    Every accepted socket is tracked together with a small metadata record
    (client id, connect time, messages sent). The live connection count is
    mirrored into the module-level ``connection_stats`` dict.
    """

    def __init__(self):
        self.active_connections: Set[WebSocket] = set()
        self.connection_metadata: Dict[WebSocket, Dict] = {}

    async def connect(self, websocket: WebSocket, client_id: str = None):
        """Accept *websocket*, register it, and bump the global counters.

        When *client_id* is falsy, an id of the form ``client_<unix seconds>``
        is generated.
        """
        await websocket.accept()
        self.active_connections.add(websocket)

        record = {
            "client_id": client_id or f"client_{int(time.time())}",
            "connected_at": time.time(),
            "messages_sent": 0,
        }
        self.connection_metadata[websocket] = record

        live = len(self.active_connections)
        connection_stats["current_connections"] = live
        connection_stats["total_connections"] += 1

        logger.info(f"WebSocket connected: {record['client_id']}. "
                    f"Total connections: {live}")

    def disconnect(self, websocket: WebSocket):
        """Forget *websocket*; a socket that is not registered is ignored."""
        if websocket not in self.active_connections:
            return

        meta = self.connection_metadata.pop(websocket, {})
        label = meta.get("client_id", "unknown")
        self.active_connections.remove(websocket)

        remaining = len(self.active_connections)
        connection_stats["current_connections"] = remaining

        logger.info(f"WebSocket disconnected: {label}. "
                    f"Remaining connections: {remaining}")

    async def broadcast(self, message: str):
        """Send *message* to every registered socket, pruning those that fail."""
        failed = set()

        # Walk a snapshot: the live set can change while a send is awaited.
        for client in set(self.active_connections):
            try:
                await client.send_text(message)
                if client in self.connection_metadata:
                    self.connection_metadata[client]["messages_sent"] += 1
            except Exception as exc:
                logger.warning(f"Failed to send message to client: {exc}")
                failed.add(client)

        # Drop every socket whose send raised.
        for client in failed:
            self.disconnect(client)

    def get_stats(self):
        """Return the live connection count and metadata keyed by position."""
        indexed_meta = dict(enumerate(self.connection_metadata.values()))
        return {
            "active_connections": len(self.active_connections),
            "connection_metadata": indexed_meta,
        }


# Single shared manager used by the endpoints and the background broadcaster.
manager = ConnectionManager()
111
 
112
async def initialize_diarization_system():
    """Build the global diarization engine and start recording.

    Returns:
        True when the models loaded and recording was started; False when
        model initialization reported failure or any step raised.
    """
    global diart

    try:
        logger.info("Initializing diarization system...")
        diart = RealtimeSpeakerDiarization()

        # Guard clause: stop here when the models could not be loaded.
        if not diart.initialize_models():
            logger.error("Failed to initialize models")
            return False

        logger.info("Models initialized successfully")
        diart.start_recording()
        logger.info("Recording started")
        return True

    except Exception as exc:
        logger.error(f"Error initializing diarization system: {exc}")
        logger.error(traceback.format_exc())
        return False
134
 
 
135
async def send_conversation_updates():
    """Push the rendered conversation to all clients whenever it changes.

    Runs forever, polling every 500 ms. A broadcast happens only while the
    engine is running, at least one client is connected, and the hash of the
    conversation HTML differs from the last one that was sent.
    """
    poll_seconds = 0.5
    previous_digest = None

    while True:
        try:
            ready = diart and diart.is_running and manager.active_connections
            if ready:
                html = diart.get_formatted_conversation()
                digest = hash(html)

                # Skip the broadcast when nothing changed, to save bandwidth.
                if digest != previous_digest:
                    sent_at = time.time()
                    if hasattr(diart, 'get_status_info'):
                        status = diart.get_status_info()
                    else:
                        status = {}

                    payload = {
                        "type": "conversation_update",
                        "timestamp": sent_at,
                        "conversation_html": html,
                        "status": status,
                    }
                    await manager.broadcast(json.dumps(payload))
                    previous_digest = digest

        except Exception as exc:
            logger.error(f"Error in conversation update: {exc}")

        await asyncio.sleep(poll_seconds)
 
165
 
166
@app.on_event("startup")
async def startup_event():
    """Initialize the diarization engine and launch the update broadcaster.

    A failed initialization is logged but does not abort startup; the
    broadcaster is started either way.
    """
    logger.info("Starting Real-time Speaker Diarization Service")

    # Initialize diarization system
    success = await initialize_diarization_system()
    if not success:
        logger.error("Failed to initialize diarization system!")

    # Keep a strong reference to the background task. The event loop only
    # holds weak references to tasks, so a task whose handle is discarded may
    # be garbage-collected before it finishes.
    app.state.conversation_update_task = asyncio.create_task(
        send_conversation_updates()
    )
    logger.info("Background tasks started")
179
+
180
@app.on_event("shutdown")
async def shutdown_event():
    """Stop recording on shutdown; errors while stopping are only logged."""
    logger.info("Shutting down...")

    # Nothing to stop when the engine was never created.
    if not diart:
        return

    try:
        diart.stop_recording()
        logger.info("Recording stopped")
    except Exception as exc:
        logger.error(f"Error stopping recording: {exc}")
190
+
191
@app.get("/")
async def root():
    """Describe the service: name, version, run state, and endpoint map."""
    if diart and diart.is_running:
        state = "running"
    else:
        state = "initializing"

    endpoint_map = {
        "websocket": "/ws_inference",
        "health": "/health",
        "conversation": "/conversation",
        "status": "/status",
    }

    return {
        "service": "Real-time Speaker Diarization API",
        "version": "1.0.0",
        "status": state,
        "endpoints": endpoint_map,
    }
205
 
206
@app.get("/health")
@app.head("/health")
async def health_check():
    """Report service health, connection counters, and engine status.

    ``status`` is "healthy" only when the diarization engine exists and
    reports ``is_running``; otherwise it is "unhealthy". ``system_running``
    is always a boolean.
    """
    # bool() matters: `diart and diart.is_running` alone evaluates to None
    # while the engine is uninitialized, which would serialize as null.
    system_healthy = bool(diart and diart.is_running)

    if diart and hasattr(diart, 'get_status_info'):
        diarization_status = diart.get_status_info()
    else:
        diarization_status = {}

    return {
        "status": "healthy" if system_healthy else "unhealthy",
        "system_running": system_healthy,
        "active_connections": len(manager.active_connections),
        "connection_stats": connection_stats,
        "diarization_status": diarization_status,
    }
219
 
220
@app.websocket("/ws_inference")
async def ws_inference(websocket: WebSocket):
    """WebSocket endpoint for real-time audio processing.

    Protocol, as implemented here:
      * on connect the server sends a JSON ``connection_established`` message
        with the client id, the system status and the current conversation;
      * every binary frame received while the engine is running is passed to
        ``diart.process_audio_chunk``;
      * a failure while handling a frame is logged and reported to the client
        as a JSON ``error`` message, and the receive loop continues.

    The socket is always unregistered from the manager on exit.
    """
    # Nanosecond resolution keeps ids practically distinct for clients that
    # connect within the same second (second resolution produced duplicates).
    client_id = f"client_{time.time_ns()}"

    try:
        await manager.connect(websocket, client_id)

        # Send initial connection confirmation
        initial_message = json.dumps({
            "type": "connection_established",
            "client_id": client_id,
            "system_status": "ready" if diart and diart.is_running else "initializing",
            "conversation": diart.get_formatted_conversation() if diart else "",
        })
        await websocket.send_text(initial_message)

        # Process incoming audio data
        async for data in websocket.iter_bytes():
            try:
                if data and diart and diart.is_running:
                    # Update statistics
                    connection_stats["last_audio_received"] = time.time()
                    connection_stats["total_audio_chunks"] += 1

                    # The return value is not used here.
                    diart.process_audio_chunk(data)

                    # Log progress every 100 chunks only, to limit noise.
                    if connection_stats["total_audio_chunks"] % 100 == 0:
                        logger.debug(f"Processed {connection_stats['total_audio_chunks']} audio chunks")

                elif not diart:
                    logger.warning("Received audio data but diarization system not initialized")

            except Exception as e:
                logger.error(f"Error processing audio chunk: {e}")
                # Tell the client; if this send fails too, the outer handlers
                # below take over and the connection is cleaned up.
                error_message = json.dumps({
                    "type": "error",
                    "message": "Error processing audio",
                    "timestamp": time.time(),
                })
                await websocket.send_text(error_message)

    except WebSocketDisconnect:
        logger.info(f"WebSocket {client_id} disconnected normally")
    except Exception as e:
        logger.error(f"WebSocket {client_id} error: {e}")
    finally:
        manager.disconnect(websocket)
 
271
 
272
@app.get("/conversation")
@app.head("/conversation")
async def get_conversation():
    """Return the current conversation HTML with a timestamp and engine status.

    Raises:
        HTTPException: 503 when the engine has not been created, 500 when
            reading the conversation or the status fails.
    """
    if not diart:
        raise HTTPException(status_code=503, detail="Diarization system not initialized")

    try:
        html = diart.get_formatted_conversation()
        stamp = time.time()
        if hasattr(diart, 'get_status_info'):
            engine_status = diart.get_status_info()
        else:
            engine_status = {}
        payload = {
            "conversation": html,
            "timestamp": stamp,
            "system_status": engine_status,
        }
    except Exception as exc:
        logger.error(f"Error getting conversation: {exc}")
        raise HTTPException(status_code=500, detail="Error retrieving conversation")

    return payload
289
 
290
@app.get("/status")
@app.head("/status")
async def get_status():
    """Return the engine status merged with connection counters and uptime.

    Answers ``{"status": "system_not_initialized"}`` while the engine does not
    exist, and ``{"status": "error", "message": ...}`` when building the
    report raises.
    """
    if not diart:
        return {"status": "system_not_initialized"}

    try:
        if hasattr(diart, 'get_status_info'):
            base_status = diart.get_status_info()
        else:
            base_status = {}

        # Start from the engine's own fields, then layer service data on top.
        report = {**base_status}
        report.update({
            "connection_stats": connection_stats,
            "active_connections": len(manager.active_connections),
            "system_uptime": time.time() - connection_stats.get("system_start_time", time.time()),
        })
        return report
    except Exception as exc:
        logger.error(f"Error getting status: {exc}")
        return {"status": "error", "message": str(exc)}
309
 
310
@app.post("/settings")
async def update_settings(threshold: float = None, max_speakers: int = None):
    """Update speaker detection settings.

    Args:
        threshold: New threshold; when given it must lie in [0, 1].
        max_speakers: New speaker limit; when given it must lie in [1, 20].

    Returns:
        The engine's result plus an echo of the values that were submitted.

    Raises:
        HTTPException: 503 when the engine has not been created, 400 when a
            supplied value is out of range, 500 when the engine call fails.
    """
    if not diart:
        raise HTTPException(status_code=503, detail="Diarization system not initialized")

    # Validate BEFORE the try block: HTTPException derives from Exception, so
    # raising it inside the try would be caught below and reported as a 500
    # instead of the intended 400. The chained comparison also rejects NaN.
    if threshold is not None and not 0 <= threshold <= 1:
        raise HTTPException(status_code=400, detail="Threshold must be between 0 and 1")

    if max_speakers is not None and not 1 <= max_speakers <= 20:
        raise HTTPException(status_code=400, detail="Max speakers must be between 1 and 20")

    try:
        # NOTE(review): an omitted parameter is forwarded as None — confirm
        # that diart.update_settings accepts None for either argument.
        result = diart.update_settings(threshold, max_speakers)
    except Exception as e:
        logger.error(f"Error updating settings: {e}")
        raise HTTPException(status_code=500, detail="Error updating settings") from e

    return {
        "result": result,
        "updated_settings": {
            "threshold": threshold,
            "max_speakers": max_speakers,
        },
    }
335
 
336
@app.post("/clear")
async def clear_conversation():
    """Clear the conversation history and notify every connected client.

    Raises:
        HTTPException: 503 when the engine has not been created, 500 when
            clearing or broadcasting fails.
    """
    if not diart:
        raise HTTPException(status_code=503, detail="Diarization system not initialized")

    try:
        outcome = diart.clear_conversation()

        # Let connected clients know the history was cleared.
        notice = {
            "type": "conversation_cleared",
            "timestamp": time.time(),
        }
        await manager.broadcast(json.dumps(notice))

        return {"result": outcome, "message": "Conversation cleared successfully"}
    except Exception as exc:
        logger.error(f"Error clearing conversation: {exc}")
        raise HTTPException(status_code=500, detail="Error clearing conversation")
356
+
357
@app.get("/stats")
async def get_connection_stats():
    """Return global counters, manager statistics, and a short system summary."""
    manager_stats = manager.get_stats()

    if diart:
        engine_running = diart.is_running
    else:
        engine_running = False

    system_info = {
        "diarization_running": engine_running,
        "total_active_connections": len(manager.active_connections),
    }

    return {
        "connection_stats": connection_stats,
        "manager_stats": manager_stats,
        "system_info": system_info,
    }
368
 
369
# Optional Gradio front end: the API keeps working when the module is missing
# or when mounting it fails.
try:
    import ui
    ui.mount_ui(app)
    logger.info("Gradio UI mounted successfully")
except ImportError:
    logger.warning("UI module not found, running in API-only mode")
except Exception as exc:
    logger.error(f"Error mounting UI: {exc}")

# Reference point for the uptime figure reported by /status.
connection_stats["system_start_time"] = time.time()

if __name__ == "__main__":
    server_options = dict(
        host="0.0.0.0",
        port=7860,
        log_level="info",
        access_log=True,
    )
    uvicorn.run(app, **server_options)