Saiyaswanth007 committed on
Commit
7c64918
·
1 Parent(s): 10b8972

changing from /

Browse files
Files changed (1) hide show
  1. inference.py +38 -29
inference.py CHANGED
@@ -1,7 +1,7 @@
1
- from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Response
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from shared import RealtimeSpeakerDiarization
4
- import os
5
  import uvicorn
6
  import logging
7
  import asyncio
@@ -13,12 +13,6 @@ logger = logging.getLogger(__name__)
13
  # Initialize FastAPI app
14
  app = FastAPI()
15
 
16
- # Respond to HEAD / with a 200 so port scanners don’t see a 405
17
- @app.head("/", include_in_schema=False)
18
- @app.get("/")
19
- async def root():
20
- return {"message": "Speaker Diarization Signaling Server"}
21
-
22
  # Add CORS middleware for browser compatibility
23
  app.add_middleware(
24
  CORSMiddleware,
@@ -44,13 +38,21 @@ async def send_conversation_updates():
44
  """Periodically send conversation updates to all connected clients"""
45
  while True:
46
  if active_connections:
47
- conversation_html = diart.get_formatted_conversation()
48
- for ws in list(active_connections):
49
- try:
50
- await ws.send_text(conversation_html)
51
- except Exception as e:
52
- logger.error(f"Error sending to WebSocket: {e}")
53
- active_connections.discard(ws)
 
 
 
 
 
 
 
 
54
  await asyncio.sleep(0.5) # 500ms update interval
55
 
56
  @app.on_event("startup")
@@ -72,24 +74,30 @@ async def ws_inference(ws: WebSocket):
72
  """WebSocket endpoint for real-time audio processing"""
73
  await ws.accept()
74
  active_connections.add(ws)
75
- logger.info(f"WebSocket connected. Total: {len(active_connections)}")
 
76
  try:
77
  # Send initial conversation state
78
- await ws.send_text(diart.get_formatted_conversation())
 
 
79
  # Process incoming audio chunks
80
  async for chunk in ws.iter_bytes():
81
- if chunk:
82
- try:
 
 
83
  diart.process_audio_chunk(chunk)
84
- except Exception as e:
85
- logger.error(f"Error processing chunk: {e}")
 
86
  except WebSocketDisconnect:
87
  logger.info("WebSocket disconnected")
88
  except Exception as e:
89
  logger.error(f"WebSocket error: {e}")
90
  finally:
91
  active_connections.discard(ws)
92
- logger.info(f"WebSocket closed. Remaining: {len(active_connections)}")
93
 
94
  @app.get("/conversation")
95
  async def get_conversation():
@@ -104,21 +112,22 @@ async def get_status():
104
  @app.post("/settings")
105
  async def update_settings(threshold: float, max_speakers: int):
106
  """Update speaker detection settings"""
107
- return {"result": diart.update_settings(threshold, max_speakers)}
 
108
 
109
  @app.post("/clear")
110
  async def clear_conversation():
111
  """Clear the conversation"""
112
- return {"result": diart.clear_conversation()}
 
113
 
114
- # Mount Gradio UI at /ui so it doesn't override API/WebSocket routes
115
  try:
116
  import ui
117
- ui.mount_ui(app, path="/ui")
118
- logger.info("Gradio UI mounted at /ui")
119
  except ImportError:
120
  logger.warning("UI module not found, running in API-only mode")
121
 
122
  if __name__ == "__main__":
123
- port = int(os.getenv("PORT", 10000))
124
- uvicorn.run("backend:app", host="0.0.0.0", port=port)
 
1
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from shared import RealtimeSpeakerDiarization
4
+ import numpy as np
5
  import uvicorn
6
  import logging
7
  import asyncio
 
13
  # Initialize FastAPI app
14
  app = FastAPI()
15
 
 
 
 
 
 
 
16
  # Add CORS middleware for browser compatibility
17
  app.add_middleware(
18
  CORSMiddleware,
 
38
  """Periodically send conversation updates to all connected clients"""
39
  while True:
40
  if active_connections:
41
+ try:
42
+ # Get current conversation HTML
43
+ conversation_html = diart.get_formatted_conversation()
44
+
45
+ # Send to all active connections
46
+ for ws in active_connections.copy():
47
+ try:
48
+ await ws.send_text(conversation_html)
49
+ except Exception as e:
50
+ logger.error(f"Error sending to WebSocket: {e}")
51
+ active_connections.discard(ws)
52
+ except Exception as e:
53
+ logger.error(f"Error in conversation update: {e}")
54
+
55
+ # Wait before sending next update
56
  await asyncio.sleep(0.5) # 500ms update interval
57
 
58
  @app.on_event("startup")
 
74
  """WebSocket endpoint for real-time audio processing"""
75
  await ws.accept()
76
  active_connections.add(ws)
77
+ logger.info(f"WebSocket connection established. Total connections: {len(active_connections)}")
78
+
79
  try:
80
  # Send initial conversation state
81
+ conversation_html = diart.get_formatted_conversation()
82
+ await ws.send_text(conversation_html)
83
+
84
  # Process incoming audio chunks
85
  async for chunk in ws.iter_bytes():
86
+ try:
87
+ # Process raw audio bytes
88
+ if chunk:
89
+ # Process audio data - this updates the internal conversation state
90
  diart.process_audio_chunk(chunk)
91
+ except Exception as e:
92
+ logger.error(f"Error processing audio chunk: {e}")
93
+
94
  except WebSocketDisconnect:
95
  logger.info("WebSocket disconnected")
96
  except Exception as e:
97
  logger.error(f"WebSocket error: {e}")
98
  finally:
99
  active_connections.discard(ws)
100
+ logger.info(f"WebSocket connection closed. Remaining connections: {len(active_connections)}")
101
 
102
  @app.get("/conversation")
103
  async def get_conversation():
 
112
  @app.post("/settings")
113
  async def update_settings(threshold: float, max_speakers: int):
114
  """Update speaker detection settings"""
115
+ result = diart.update_settings(threshold, max_speakers)
116
+ return {"result": result}
117
 
118
  @app.post("/clear")
119
  async def clear_conversation():
120
  """Clear the conversation"""
121
+ result = diart.clear_conversation()
122
+ return {"result": result}
123
 
124
+ # Import UI module to mount the Gradio app
125
  try:
126
  import ui
127
+ ui.mount_ui(app)
128
+ logger.info("Gradio UI mounted successfully")
129
  except ImportError:
130
  logger.warning("UI module not found, running in API-only mode")
131
 
132
  if __name__ == "__main__":
133
+ uvicorn.run(app, host="0.0.0.0", port=7860)