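"""
Real-time speaker diarization front end.

Combines a Gradio UI with a FastAPI backend: browser audio is captured by an
injected JavaScript client (plus an optional fastrtc WebRTC component), streamed
over WebSocket, and forwarded to a separate diarization backend that performs the
actual transcription and speaker identification.
"""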
import gradio as gr
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import JSONResponse
import asyncio
import json
import logging
from typing import Dict, List, Optional
import os
from datetime import datetime
import httpx
import websockets
from fastrtc import RTCComponent

class Config:
    def __init__(self):
        self.hf_space_url = os.getenv("HF_SPACE_URL", "androidguy-speaker-diarization.hf.space")
        self.render_url = os.getenv("RENDER_URL", "render-signal-audio.onrender.com")
        self.default_threshold = float(os.getenv("DEFAULT_THRESHOLD", "0.7"))
        self.default_max_speakers = int(os.getenv("DEFAULT_MAX_SPEAKERS", "4"))
        self.max_speakers_limit = int(os.getenv("MAX_SPEAKERS_LIMIT", "8"))


config = Config()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ConnectionManager:
    """Manage WebSocket connections"""

    def __init__(self):
        self.active_connections: List[WebSocket] = []
        self.conversation_history: List[Dict] = []

    async def connect(self, websocket: WebSocket):
        await websocket.accept()
        self.active_connections.append(websocket)
        logger.info(f"Client connected. Total connections: {len(self.active_connections)}")

    def disconnect(self, websocket: WebSocket):
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)
            logger.info(f"Client disconnected. Total connections: {len(self.active_connections)}")

    async def send_personal_message(self, message: str, websocket: WebSocket):
        try:
            await websocket.send_text(message)
        except Exception as e:
            logger.error(f"Error sending message: {e}")
            self.disconnect(websocket)

    async def broadcast(self, message: str):
        """Send message to all connected clients"""
        disconnected = []
        for connection in self.active_connections:
            try:
                await connection.send_text(message)
            except Exception as e:
                logger.error(f"Error broadcasting to connection: {e}")
                disconnected.append(connection)
        # Clean up disconnected clients
        for conn in disconnected:
            self.disconnect(conn)


manager = ConnectionManager()
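# `manager` is shared by the FastAPI WebSocket endpoint and the audio-forwarding
# helper below, so both see the same connections and the same conversation history.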

def create_gradio_app():
    """Create the Gradio interface"""

    def get_client_js():
        """Return the client-side JavaScript"""
        return f"""
        <script>
        class SpeakerDiarizationClient {{
            constructor() {{
                this.ws = null;
                this.mediaStream = null;
                this.mediaRecorder = null;
                this.isRecording = false;
                this.baseUrl = 'https://{config.hf_space_url}';
                this.wsUrl = 'wss://{config.hf_space_url}/ws';
                this.renderUrl = 'wss://{config.render_url}/stream';
            }}

            async startRecording() {{
                try {{
                    // Request microphone access
                    this.mediaStream = await navigator.mediaDevices.getUserMedia({{
                        audio: {{
                            echoCancellation: true,
                            noiseSuppression: true,
                            autoGainControl: true,
                            sampleRate: 16000
                        }}
                    }});

                    // Set up WebSocket connection
                    await this.connectWebSocket();

                    // Set up MediaRecorder for audio chunks
                    this.mediaRecorder = new MediaRecorder(this.mediaStream, {{
                        mimeType: 'audio/webm;codecs=opus'
                    }});

                    this.mediaRecorder.ondataavailable = (event) => {{
                        if (event.data.size > 0 && this.ws && this.ws.readyState === WebSocket.OPEN) {{
                            // Send audio chunk to server
                            this.ws.send(event.data);
                        }}
                    }};

                    // Start recording with chunks every 1 second
                    this.mediaRecorder.start(1000);
                    this.isRecording = true;
                    this.updateStatus('connected', 'Recording started');
                }} catch (error) {{
                    console.error('Error starting recording:', error);
                    this.updateStatus('error', `Failed to start: ${{error.message}}`);
                }}
            }}
            async connectWebSocket() {{
                return new Promise((resolve, reject) => {{
                    this.ws = new WebSocket(this.wsUrl);

                    this.ws.onopen = () => {{
                        console.log('WebSocket connected');
                        resolve();
                    }};

                    this.ws.onmessage = (event) => {{
                        try {{
                            const data = JSON.parse(event.data);
                            this.handleServerMessage(data);
                        }} catch (e) {{
                            console.error('Error parsing message:', e);
                        }}
                    }};

                    this.ws.onerror = (error) => {{
                        console.error('WebSocket error:', error);
                        reject(error);
                    }};

                    this.ws.onclose = () => {{
                        console.log('WebSocket closed');
                        if (this.isRecording) {{
                            // Try to reconnect after a delay
                            setTimeout(() => this.connectWebSocket(), 3000);
                        }}
                    }};
                }});
            }}

            handleServerMessage(data) {{
                switch (data.type) {{
                    case 'transcription':
                        this.updateConversation(data.conversation_html);
                        break;
                    case 'speaker_update':
                        this.updateStatus('connected', `Active: ${{data.speaker}}`);
                        break;
                    case 'error':
                        this.updateStatus('error', data.message);
                        break;
                    case 'status':
                        this.updateStatus(data.status, data.message);
                        break;
                }}
            }}
            stopRecording() {{
                this.isRecording = false;

                if (this.mediaRecorder && this.mediaRecorder.state !== 'inactive') {{
                    this.mediaRecorder.stop();
                }}

                if (this.mediaStream) {{
                    this.mediaStream.getTracks().forEach(track => track.stop());
                    this.mediaStream = null;
                }}

                if (this.ws) {{
                    this.ws.close();
                    this.ws = null;
                }}

                this.updateStatus('disconnected', 'Recording stopped');
            }}

            async clearConversation() {{
                try {{
                    const response = await fetch(`${{this.baseUrl}}/clear`, {{
                        method: 'POST'
                    }});
                    if (response.ok) {{
                        this.updateConversation('<i>Conversation cleared. Start speaking...</i>');
                    }}
                }} catch (error) {{
                    console.error('Error clearing conversation:', error);
                }}
            }}

            updateConversation(html) {{
                const elem = document.getElementById('conversation');
                if (elem) {{
                    elem.innerHTML = html;
                    elem.scrollTop = elem.scrollHeight;
                }}
            }}

            updateStatus(status, message = '') {{
                const statusText = document.getElementById('status-text');
                const statusIcon = document.getElementById('status-icon');
                if (!statusText || !statusIcon) return;

                const colors = {{
                    'connected': '#4CAF50',
                    'connecting': '#FFC107',
                    'disconnected': '#9E9E9E',
                    'error': '#F44336',
                    'warning': '#FF9800'
                }};
                const labels = {{
                    'connected': 'Connected',
                    'connecting': 'Connecting...',
                    'disconnected': 'Disconnected',
                    'error': 'Error',
                    'warning': 'Warning'
                }};

                statusText.textContent = message ? `${{labels[status]}}: ${{message}}` : labels[status];
                statusIcon.style.backgroundColor = colors[status] || '#9E9E9E';
            }}
        }}

        // Global client instance
        window.diarizationClient = new SpeakerDiarizationClient();

        // Button event handlers
        function startListening() {{
            window.diarizationClient.startRecording();
        }}

        function stopListening() {{
            window.diarizationClient.stopRecording();
        }}

        function clearConversation() {{
            window.diarizationClient.clearConversation();
        }}

        // Initialize on page load
        document.addEventListener('DOMContentLoaded', () => {{
            window.diarizationClient.updateStatus('disconnected');
        }});
        </script>
        """
    with gr.Blocks(
        title="Real-time Speaker Diarization",
        theme=gr.themes.Soft(),
        css="""
        .status-indicator { margin: 10px 0; }
        .conversation-display {
            background: #f8f9fa;
            border: 1px solid #dee2e6;
            border-radius: 8px;
            padding: 20px;
            min-height: 400px;
            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
            overflow-y: auto;
        }
        """
    ) as demo:
        # Inject client-side JavaScript
        gr.HTML(get_client_js())
        # Header
        gr.Markdown("# 🎤 Real-time Speaker Diarization")
        gr.Markdown("Advanced speech recognition with automatic speaker identification")

        # Status indicator
        gr.HTML("""
        <div class="status-indicator">
            <span id="status-text" style="color:#666;">Ready to connect</span>
            <span id="status-icon" style="width:12px; height:12px; display:inline-block;
                  background-color:#9E9E9E; border-radius:50%; margin-left:8px;"></span>
        </div>
        """)
        with gr.Row():
            with gr.Column(scale=2):
                # Conversation display
                gr.HTML("""
                <div id="conversation" class="conversation-display">
                    <i>Click 'Start Listening' to begin real-time transcription...</i>
                </div>
                """)

                # WebRTC component (hidden, but functional)
                webrtc = RTCComponent(
                    url=f"wss://{config.render_url}/stream",
                    streaming=False,
                    modality="audio",
                    mode="send-receive",
                    visible=False  # Hidden but functional
                )
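                # Two audio paths exist side by side: the hidden RTCComponent streams
                # through the Render signalling server (render_url/stream), while the
                # injected JS client sends MediaRecorder chunks to this app's /ws endpoint.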
                # Control buttons
                with gr.Row():
                    start_btn = gr.Button(
                        "▶️ Start Listening",
                        variant="primary",
                        size="lg",
                        elem_id="start-btn"
                    )
                    stop_btn = gr.Button(
                        "⏹️ Stop",
                        variant="stop",
                        size="lg",
                        elem_id="stop-btn"
                    )
                    clear_btn = gr.Button(
                        "🗑️ Clear",
                        variant="secondary",
                        size="lg",
                        elem_id="clear-btn"
                    )

                # WebRTC control functions
                def start_webrtc():
                    return {
                        webrtc: gr.update(streaming=True)
                    }

                def stop_webrtc():
                    return {
                        webrtc: gr.update(streaming=False)
                    }

                # Connect buttons to both WebRTC and JavaScript functions
                start_btn.click(fn=start_webrtc, outputs=[webrtc], js="startListening()")
                stop_btn.click(fn=stop_webrtc, outputs=[webrtc], js="stopListening()")
                clear_btn.click(fn=None, js="clearConversation()")
            with gr.Column(scale=1):
                gr.Markdown("## ⚙️ Settings")

                threshold_slider = gr.Slider(
                    minimum=0.3,
                    maximum=0.9,
                    step=0.05,
                    value=config.default_threshold,
                    label="Speaker Change Sensitivity",
                    info="Lower = more sensitive to speaker changes"
                )

                max_speakers_slider = gr.Slider(
                    minimum=2,
                    maximum=config.max_speakers_limit,
                    step=1,
                    value=config.default_max_speakers,
                    label="Maximum Speakers"
                )
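                # Note: neither slider is wired to an event handler in this file; they
                # show the Config defaults until a change handler pushes the values to
                # the diarization backend.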
                # Instructions
                gr.Markdown("""
                ## 📋 How to Use
                1. **Start Listening** - Grant microphone access
                2. **Speak** - System transcribes and identifies speakers
                3. **Stop** when finished
                4. **Clear** to reset conversation

                ## 🎨 Speaker Colors
                - 🔴 Speaker 1 - 🟢 Speaker 2 - 🔵 Speaker 3 - 🟡 Speaker 4
                - ⚫ Speaker 5 - 🟣 Speaker 6 - 🟤 Speaker 7 - 🟠 Speaker 8
                """)

    return demo

def create_fastapi_app():
    """Create the FastAPI backend"""
    app = FastAPI(title="Speaker Diarization API")

    @app.websocket("/ws")
    async def websocket_endpoint(websocket: WebSocket):
        await manager.connect(websocket)
        try:
            while True:
                # Receive audio data
                data = await websocket.receive_bytes()

                # Process audio data here
                # This is where you'd integrate your actual speaker diarization model
                result = await process_audio_chunk(data)

                # Send result back to client
                await manager.send_personal_message(
                    json.dumps(result),
                    websocket
                )
        except WebSocketDisconnect:
            manager.disconnect(websocket)
        except Exception as e:
            logger.error(f"WebSocket error: {e}")
            manager.disconnect(websocket)
    @app.post("/clear")
    async def clear_conversation():
        """Clear the conversation history"""
        manager.conversation_history.clear()
        await manager.broadcast(json.dumps({
            "type": "conversation_cleared"
        }))
        return {"status": "cleared"}
    @app.get("/health")  # path assumed; not referenced by the client JS
    async def health_check():
        """Health check endpoint"""
        return {
            "status": "healthy",
            "timestamp": datetime.now().isoformat(),
            "active_connections": len(manager.active_connections)
        }
    @app.get("/status")  # path assumed from the function name
    async def get_status():
        """Get system status"""
        return {
            "status": "online",
            "connections": len(manager.active_connections),
            "conversation_length": len(manager.conversation_history)
        }

    return app

async def process_audio_chunk(audio_data: bytes) -> dict:
    """
    Process an audio chunk by forwarding it to the diarization backend.

    This function is only used for the direct WebSocket API, not for the WebRTC
    component. In production, prefer the WebRTC component, which has its own
    audio processing flow through the Render backend.
    """
    try:
        # Connect to the speaker diarization backend via WebSocket
        websocket_url = f"wss://{config.hf_space_url}/ws_inference"
        logger.info(f"Forwarding audio to diarization backend at {websocket_url}")

        async with websockets.connect(websocket_url) as websocket:
            # Send audio data
            await websocket.send(audio_data)

            # Receive the response
            response = await websocket.recv()

            # Parse the response
            try:
                result = json.loads(response)

                # Add to conversation history if it's a transcription
                if result.get("type") in ("transcription", "conversation_update"):
                    if "conversation_html" in result:
                        manager.conversation_history.append({
                            "timestamp": datetime.now().isoformat(),
                            "html": result["conversation_html"]
                        })

                return result
            except json.JSONDecodeError:
                logger.error(f"Invalid JSON response: {response}")
                return {
                    "type": "error",
                    "error": "Invalid response from backend",
                    "timestamp": datetime.now().isoformat()
                }
    except Exception as e:
        logger.exception(f"Error processing audio chunk: {e}")
        return {
            "type": "error",
            "error": str(e),
            "timestamp": datetime.now().isoformat()
        }

# Create both apps
fastapi_app = create_fastapi_app()
gradio_app = create_gradio_app()

# Mount the Gradio UI onto the FastAPI app; the API routes registered above are
# matched first, so /ws, /clear, /health, and /status keep working.
fastapi_app = gr.mount_gradio_app(fastapi_app, gradio_app, path="/")

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(fastapi_app, host="0.0.0.0", port=7860)
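# Quick local checks (a sketch, assuming the routes above and the default port 7860):
#   curl http://localhost:7860/health          -> {"status": "healthy", ...}
#   curl http://localhost:7860/status          -> {"status": "online", ...}
#   curl -X POST http://localhost:7860/clear   -> {"status": "cleared"}
# Browser clients stream binary audio chunks to ws://localhost:7860/ws and receive
# JSON messages of type "transcription", "speaker_update", "error", or "status".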