Saiyaswanth007 commited on
Commit
10b8972
·
1 Parent(s): cf7649e

changing from /

Browse files
Files changed (2) hide show
  1. inference.py +29 -38
  2. ui.py +2 -2
inference.py CHANGED
@@ -1,7 +1,7 @@
1
- from fastapi import FastAPI, WebSocket, WebSocketDisconnect
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from shared import RealtimeSpeakerDiarization
4
- import numpy as np
5
  import uvicorn
6
  import logging
7
  import asyncio
@@ -13,6 +13,12 @@ logger = logging.getLogger(__name__)
13
  # Initialize FastAPI app
14
  app = FastAPI()
15
 
 
 
 
 
 
 
16
  # Add CORS middleware for browser compatibility
17
  app.add_middleware(
18
  CORSMiddleware,
@@ -38,21 +44,13 @@ async def send_conversation_updates():
38
  """Periodically send conversation updates to all connected clients"""
39
  while True:
40
  if active_connections:
41
- try:
42
- # Get current conversation HTML
43
- conversation_html = diart.get_formatted_conversation()
44
-
45
- # Send to all active connections
46
- for ws in active_connections.copy():
47
- try:
48
- await ws.send_text(conversation_html)
49
- except Exception as e:
50
- logger.error(f"Error sending to WebSocket: {e}")
51
- active_connections.discard(ws)
52
- except Exception as e:
53
- logger.error(f"Error in conversation update: {e}")
54
-
55
- # Wait before sending next update
56
  await asyncio.sleep(0.5) # 500ms update interval
57
 
58
  @app.on_event("startup")
@@ -74,30 +72,24 @@ async def ws_inference(ws: WebSocket):
74
  """WebSocket endpoint for real-time audio processing"""
75
  await ws.accept()
76
  active_connections.add(ws)
77
- logger.info(f"WebSocket connection established. Total connections: {len(active_connections)}")
78
-
79
  try:
80
  # Send initial conversation state
81
- conversation_html = diart.get_formatted_conversation()
82
- await ws.send_text(conversation_html)
83
-
84
  # Process incoming audio chunks
85
  async for chunk in ws.iter_bytes():
86
- try:
87
- # Process raw audio bytes
88
- if chunk:
89
- # Process audio data - this updates the internal conversation state
90
  diart.process_audio_chunk(chunk)
91
- except Exception as e:
92
- logger.error(f"Error processing audio chunk: {e}")
93
-
94
  except WebSocketDisconnect:
95
  logger.info("WebSocket disconnected")
96
  except Exception as e:
97
  logger.error(f"WebSocket error: {e}")
98
  finally:
99
  active_connections.discard(ws)
100
- logger.info(f"WebSocket connection closed. Remaining connections: {len(active_connections)}")
101
 
102
  @app.get("/conversation")
103
  async def get_conversation():
@@ -112,22 +104,21 @@ async def get_status():
112
  @app.post("/settings")
113
  async def update_settings(threshold: float, max_speakers: int):
114
  """Update speaker detection settings"""
115
- result = diart.update_settings(threshold, max_speakers)
116
- return {"result": result}
117
 
118
  @app.post("/clear")
119
  async def clear_conversation():
120
  """Clear the conversation"""
121
- result = diart.clear_conversation()
122
- return {"result": result}
123
 
124
- # Import UI module to mount the Gradio app
125
  try:
126
  import ui
127
- ui.mount_ui(app)
128
- logger.info("Gradio UI mounted successfully")
129
  except ImportError:
130
  logger.warning("UI module not found, running in API-only mode")
131
 
132
  if __name__ == "__main__":
133
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
1
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Response
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from shared import RealtimeSpeakerDiarization
4
+ import os
5
  import uvicorn
6
  import logging
7
  import asyncio
 
13
  # Initialize FastAPI app
14
  app = FastAPI()
15
 
16
+ # Respond to HEAD / with a 200 so port scanners don’t see a 405
17
+ @app.head("/", include_in_schema=False)
18
+ @app.get("/")
19
+ async def root():
20
+ return {"message": "Speaker Diarization Signaling Server"}
21
+
22
  # Add CORS middleware for browser compatibility
23
  app.add_middleware(
24
  CORSMiddleware,
 
44
  """Periodically send conversation updates to all connected clients"""
45
  while True:
46
  if active_connections:
47
+ conversation_html = diart.get_formatted_conversation()
48
+ for ws in list(active_connections):
49
+ try:
50
+ await ws.send_text(conversation_html)
51
+ except Exception as e:
52
+ logger.error(f"Error sending to WebSocket: {e}")
53
+ active_connections.discard(ws)
 
 
 
 
 
 
 
 
54
  await asyncio.sleep(0.5) # 500ms update interval
55
 
56
  @app.on_event("startup")
 
72
  """WebSocket endpoint for real-time audio processing"""
73
  await ws.accept()
74
  active_connections.add(ws)
75
+ logger.info(f"WebSocket connected. Total: {len(active_connections)}")
 
76
  try:
77
  # Send initial conversation state
78
+ await ws.send_text(diart.get_formatted_conversation())
 
 
79
  # Process incoming audio chunks
80
  async for chunk in ws.iter_bytes():
81
+ if chunk:
82
+ try:
 
 
83
  diart.process_audio_chunk(chunk)
84
+ except Exception as e:
85
+ logger.error(f"Error processing chunk: {e}")
 
86
  except WebSocketDisconnect:
87
  logger.info("WebSocket disconnected")
88
  except Exception as e:
89
  logger.error(f"WebSocket error: {e}")
90
  finally:
91
  active_connections.discard(ws)
92
+ logger.info(f"WebSocket closed. Remaining: {len(active_connections)}")
93
 
94
  @app.get("/conversation")
95
  async def get_conversation():
 
104
  @app.post("/settings")
105
  async def update_settings(threshold: float, max_speakers: int):
106
  """Update speaker detection settings"""
107
+ return {"result": diart.update_settings(threshold, max_speakers)}
 
108
 
109
  @app.post("/clear")
110
  async def clear_conversation():
111
  """Clear the conversation"""
112
+ return {"result": diart.clear_conversation()}
 
113
 
114
+ # Mount Gradio UI at /ui so it doesn't override API/WebSocket routes
115
  try:
116
  import ui
117
+ ui.mount_ui(app, path="/ui")
118
+ logger.info("Gradio UI mounted at /ui")
119
  except ImportError:
120
  logger.warning("UI module not found, running in API-only mode")
121
 
122
  if __name__ == "__main__":
123
+ port = int(os.getenv("PORT", 10000))
124
+ uvicorn.run("backend:app", host="0.0.0.0", port=port)
ui.py CHANGED
@@ -203,7 +203,7 @@ def build_ui():
203
  } else {
204
  updateStatus('warning', 'Connection unstable');
205
  }
206
- } catch (err) {
207
  console.error('Error updating connection info:', err);
208
  }
209
  }
@@ -406,7 +406,7 @@ def build_ui():
406
  status_timer = gr.Timer(5)
407
  status_timer.tick(fn=get_status, outputs=status_output)
408
 
409
-
410
  return demo
411
 
412
  # Create Gradio interface
 
203
  } else {
204
  updateStatus('warning', 'Connection unstable');
205
  }
206
+ } catch (err) {
207
  console.error('Error updating connection info:', err);
208
  }
209
  }
 
406
  status_timer = gr.Timer(5)
407
  status_timer.tick(fn=get_status, outputs=status_output)
408
 
409
+
410
  return demo
411
 
412
  # Create Gradio interface