Spaces:
Sleeping
Sleeping
Commit
·
10b8972
1
Parent(s):
cf7649e
changing from /
Browse files- inference.py +29 -38
- ui.py +2 -2
inference.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
-
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
2 |
from fastapi.middleware.cors import CORSMiddleware
|
3 |
from shared import RealtimeSpeakerDiarization
|
4 |
-
import
|
5 |
import uvicorn
|
6 |
import logging
|
7 |
import asyncio
|
@@ -13,6 +13,12 @@ logger = logging.getLogger(__name__)
|
|
13 |
# Initialize FastAPI app
|
14 |
app = FastAPI()
|
15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
# Add CORS middleware for browser compatibility
|
17 |
app.add_middleware(
|
18 |
CORSMiddleware,
|
@@ -38,21 +44,13 @@ async def send_conversation_updates():
|
|
38 |
"""Periodically send conversation updates to all connected clients"""
|
39 |
while True:
|
40 |
if active_connections:
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
await ws.send_text(conversation_html)
|
49 |
-
except Exception as e:
|
50 |
-
logger.error(f"Error sending to WebSocket: {e}")
|
51 |
-
active_connections.discard(ws)
|
52 |
-
except Exception as e:
|
53 |
-
logger.error(f"Error in conversation update: {e}")
|
54 |
-
|
55 |
-
# Wait before sending next update
|
56 |
await asyncio.sleep(0.5) # 500ms update interval
|
57 |
|
58 |
@app.on_event("startup")
|
@@ -74,30 +72,24 @@ async def ws_inference(ws: WebSocket):
|
|
74 |
"""WebSocket endpoint for real-time audio processing"""
|
75 |
await ws.accept()
|
76 |
active_connections.add(ws)
|
77 |
-
logger.info(f"WebSocket
|
78 |
-
|
79 |
try:
|
80 |
# Send initial conversation state
|
81 |
-
|
82 |
-
await ws.send_text(conversation_html)
|
83 |
-
|
84 |
# Process incoming audio chunks
|
85 |
async for chunk in ws.iter_bytes():
|
86 |
-
|
87 |
-
|
88 |
-
if chunk:
|
89 |
-
# Process audio data - this updates the internal conversation state
|
90 |
diart.process_audio_chunk(chunk)
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
except WebSocketDisconnect:
|
95 |
logger.info("WebSocket disconnected")
|
96 |
except Exception as e:
|
97 |
logger.error(f"WebSocket error: {e}")
|
98 |
finally:
|
99 |
active_connections.discard(ws)
|
100 |
-
logger.info(f"WebSocket
|
101 |
|
102 |
@app.get("/conversation")
|
103 |
async def get_conversation():
|
@@ -112,22 +104,21 @@ async def get_status():
|
|
112 |
@app.post("/settings")
|
113 |
async def update_settings(threshold: float, max_speakers: int):
|
114 |
"""Update speaker detection settings"""
|
115 |
-
result
|
116 |
-
return {"result": result}
|
117 |
|
118 |
@app.post("/clear")
|
119 |
async def clear_conversation():
|
120 |
"""Clear the conversation"""
|
121 |
-
result
|
122 |
-
return {"result": result}
|
123 |
|
124 |
-
#
|
125 |
try:
|
126 |
import ui
|
127 |
-
ui.mount_ui(app)
|
128 |
-
logger.info("Gradio UI mounted
|
129 |
except ImportError:
|
130 |
logger.warning("UI module not found, running in API-only mode")
|
131 |
|
132 |
if __name__ == "__main__":
|
133 |
-
|
|
|
|
1 |
+
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Response
|
2 |
from fastapi.middleware.cors import CORSMiddleware
|
3 |
from shared import RealtimeSpeakerDiarization
|
4 |
+
import os
|
5 |
import uvicorn
|
6 |
import logging
|
7 |
import asyncio
|
|
|
13 |
# Initialize FastAPI app
|
14 |
app = FastAPI()
|
15 |
|
16 |
+
# Respond to HEAD / with a 200 so port scanners don’t see a 405
|
17 |
+
@app.head("/", include_in_schema=False)
|
18 |
+
@app.get("/")
|
19 |
+
async def root():
|
20 |
+
return {"message": "Speaker Diarization Signaling Server"}
|
21 |
+
|
22 |
# Add CORS middleware for browser compatibility
|
23 |
app.add_middleware(
|
24 |
CORSMiddleware,
|
|
|
44 |
"""Periodically send conversation updates to all connected clients"""
|
45 |
while True:
|
46 |
if active_connections:
|
47 |
+
conversation_html = diart.get_formatted_conversation()
|
48 |
+
for ws in list(active_connections):
|
49 |
+
try:
|
50 |
+
await ws.send_text(conversation_html)
|
51 |
+
except Exception as e:
|
52 |
+
logger.error(f"Error sending to WebSocket: {e}")
|
53 |
+
active_connections.discard(ws)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
await asyncio.sleep(0.5) # 500ms update interval
|
55 |
|
56 |
@app.on_event("startup")
|
|
|
72 |
"""WebSocket endpoint for real-time audio processing"""
|
73 |
await ws.accept()
|
74 |
active_connections.add(ws)
|
75 |
+
logger.info(f"WebSocket connected. Total: {len(active_connections)}")
|
|
|
76 |
try:
|
77 |
# Send initial conversation state
|
78 |
+
await ws.send_text(diart.get_formatted_conversation())
|
|
|
|
|
79 |
# Process incoming audio chunks
|
80 |
async for chunk in ws.iter_bytes():
|
81 |
+
if chunk:
|
82 |
+
try:
|
|
|
|
|
83 |
diart.process_audio_chunk(chunk)
|
84 |
+
except Exception as e:
|
85 |
+
logger.error(f"Error processing chunk: {e}")
|
|
|
86 |
except WebSocketDisconnect:
|
87 |
logger.info("WebSocket disconnected")
|
88 |
except Exception as e:
|
89 |
logger.error(f"WebSocket error: {e}")
|
90 |
finally:
|
91 |
active_connections.discard(ws)
|
92 |
+
logger.info(f"WebSocket closed. Remaining: {len(active_connections)}")
|
93 |
|
94 |
@app.get("/conversation")
|
95 |
async def get_conversation():
|
|
|
104 |
@app.post("/settings")
|
105 |
async def update_settings(threshold: float, max_speakers: int):
|
106 |
"""Update speaker detection settings"""
|
107 |
+
return {"result": diart.update_settings(threshold, max_speakers)}
|
|
|
108 |
|
109 |
@app.post("/clear")
|
110 |
async def clear_conversation():
|
111 |
"""Clear the conversation"""
|
112 |
+
return {"result": diart.clear_conversation()}
|
|
|
113 |
|
114 |
+
# Mount Gradio UI at /ui so it doesn't override API/WebSocket routes
|
115 |
try:
|
116 |
import ui
|
117 |
+
ui.mount_ui(app, path="/ui")
|
118 |
+
logger.info("Gradio UI mounted at /ui")
|
119 |
except ImportError:
|
120 |
logger.warning("UI module not found, running in API-only mode")
|
121 |
|
122 |
if __name__ == "__main__":
|
123 |
+
port = int(os.getenv("PORT", 10000))
|
124 |
+
uvicorn.run("backend:app", host="0.0.0.0", port=port)
|
ui.py
CHANGED
@@ -203,7 +203,7 @@ def build_ui():
|
|
203 |
} else {
|
204 |
updateStatus('warning', 'Connection unstable');
|
205 |
}
|
206 |
-
|
207 |
console.error('Error updating connection info:', err);
|
208 |
}
|
209 |
}
|
@@ -406,7 +406,7 @@ def build_ui():
|
|
406 |
status_timer = gr.Timer(5)
|
407 |
status_timer.tick(fn=get_status, outputs=status_output)
|
408 |
|
409 |
-
|
410 |
return demo
|
411 |
|
412 |
# Create Gradio interface
|
|
|
203 |
} else {
|
204 |
updateStatus('warning', 'Connection unstable');
|
205 |
}
|
206 |
+
} catch (err) {
|
207 |
console.error('Error updating connection info:', err);
|
208 |
}
|
209 |
}
|
|
|
406 |
status_timer = gr.Timer(5)
|
407 |
status_timer.tick(fn=get_status, outputs=status_output)
|
408 |
|
409 |
+
|
410 |
return demo
|
411 |
|
412 |
# Create Gradio interface
|