import gradio as gr
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import JSONResponse, RedirectResponse
import asyncio
import json
import logging
from typing import Dict, List, Optional
import os
from datetime import datetime
import httpx
import websockets
from fastrtc import RTCComponent

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class Config:
    def __init__(self):
        # URLs should not include an http/https prefix; the scheme is added contextually
        self.hf_space_url = os.getenv("HF_SPACE_URL", "androidguy-speaker-diarization.hf.space")
        self.render_url = os.getenv("RENDER_URL", "render-signal-audio.onrender.com")
        self.default_threshold = float(os.getenv("DEFAULT_THRESHOLD", "0.7"))
        self.default_max_speakers = int(os.getenv("DEFAULT_MAX_SPEAKERS", "4"))
        self.max_speakers_limit = int(os.getenv("MAX_SPEAKERS_LIMIT", "8"))

config = Config()
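# The defaults above can be overridden per deployment via environment variables,
# for example (values are illustrative; hosts are given without an http/https prefix):
#
#   HF_SPACE_URL=my-diarization-space.hf.space
#   RENDER_URL=my-signal-server.onrender.com
#   DEFAULT_THRESHOLD=0.6
#   DEFAULT_MAX_SPEAKERS=4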
class ConnectionManager:
    """Manage WebSocket connections."""

    def __init__(self):
        self.active_connections: List[WebSocket] = []
        self.conversation_history: List[Dict] = []

    async def connect(self, websocket: WebSocket):
        await websocket.accept()
        self.active_connections.append(websocket)
        logger.info(f"Client connected. Total connections: {len(self.active_connections)}")

    def disconnect(self, websocket: WebSocket):
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)
            logger.info(f"Client disconnected. Total connections: {len(self.active_connections)}")

    async def send_personal_message(self, message: str, websocket: WebSocket):
        try:
            await websocket.send_text(message)
        except Exception as e:
            logger.error(f"Error sending message: {e}")
            self.disconnect(websocket)

    async def broadcast(self, message: str):
        """Send a message to all connected clients."""
        disconnected = []
        for connection in self.active_connections:
            try:
                await connection.send_text(message)
            except Exception as e:
                logger.error(f"Error broadcasting to connection: {e}")
                disconnected.append(connection)
        # Clean up disconnected clients
        for conn in disconnected:
            self.disconnect(conn)

manager = ConnectionManager()
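# Usage sketch (not wired up by default): any coroutine in this module can push a
# status update to every connected client; the browser-side JS below renders it via
# its 'status' message handler:
#
#   await manager.broadcast(json.dumps({
#       "type": "status", "status": "connected", "message": "Server ready"
#   }))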
def create_gradio_app():
    """Create the Gradio interface."""

    def get_client_js():
        """Return the client-side JavaScript."""
        return f"""
        <script>
        class SpeakerDiarizationClient {{
            constructor() {{
                this.ws = null;
                this.mediaStream = null;
                this.mediaRecorder = null;
                this.isRecording = false;
                this.baseUrl = 'https://{config.hf_space_url}';
                this.wsUrl = 'wss://{config.hf_space_url}/ws';
                this.renderUrl = 'wss://{config.render_url}/stream';
                this.rtcComponentSynced = false;
            }}

            async startRecording() {{
                try {{
                    this.isRecording = true;
                    this.updateStatus('connecting', 'Connecting to server...');
                    // Connect to the WebSocket that delivers transcription updates
                    await this.connectWebSocket();
                    // Audio streaming itself is handled by Gradio's WebRTC component
                    this.updateStatus('connected', 'Connected and listening');
                }} catch (error) {{
                    console.error('Error starting recording:', error);
                    this.updateStatus('error', `Failed to start: ${{error.message}}`);
                }}
            }}

            async connectWebSocket() {{
                return new Promise((resolve, reject) => {{
                    // Only connect to the conversation-updates WebSocket
                    this.ws = new WebSocket('wss://{config.hf_space_url}/ws_transcription');
                    this.ws.onopen = () => {{
                        console.log('WebSocket connected for transcription updates');
                        resolve();
                    }};
                    this.ws.onmessage = (event) => {{
                        try {{
                            const data = JSON.parse(event.data);
                            this.handleServerMessage(data);
                        }} catch (e) {{
                            console.error('Error parsing message:', e);
                        }}
                    }};
                    this.ws.onerror = (error) => {{
                        console.error('WebSocket error:', error);
                        reject(error);
                    }};
                    this.ws.onclose = () => {{
                        console.log('WebSocket closed');
                        if (this.isRecording) {{
                            // Try to reconnect after a delay
                            setTimeout(() => this.connectWebSocket(), 3000);
                        }}
                    }};
                }});
            }}

            handleServerMessage(data) {{
                switch (data.type) {{
                    case 'transcription':
                        this.updateConversation(data.conversation_html);
                        break;
                    case 'speaker_update':
                        this.updateStatus('connected', `Active: ${{data.speaker}}`);
                        break;
                    case 'error':
                        this.updateStatus('error', data.message);
                        break;
                    case 'status':
                        this.updateStatus(data.status, data.message);
                        break;
                }}
            }}

            stopRecording() {{
                this.isRecording = false;
                if (this.mediaRecorder && this.mediaRecorder.state !== 'inactive') {{
                    this.mediaRecorder.stop();
                }}
                if (this.mediaStream) {{
                    this.mediaStream.getTracks().forEach(track => track.stop());
                    this.mediaStream = null;
                }}
                if (this.ws) {{
                    this.ws.close();
                    this.ws = null;
                }}
                this.updateStatus('disconnected', 'Recording stopped');
            }}

            async clearConversation() {{
                try {{
                    const response = await fetch(`${{this.baseUrl}}/clear`, {{
                        method: 'POST'
                    }});
                    if (response.ok) {{
                        this.updateConversation('<i>Conversation cleared. Start speaking...</i>');
                    }}
                }} catch (error) {{
                    console.error('Error clearing conversation:', error);
                }}
            }}

            updateConversation(html) {{
                const elem = document.getElementById('conversation');
                if (elem) {{
                    elem.innerHTML = html;
                    elem.scrollTop = elem.scrollHeight;
                }}
            }}

            updateStatus(status, message = '') {{
                const statusText = document.getElementById('status-text');
                const statusIcon = document.getElementById('status-icon');
                if (!statusText || !statusIcon) return;
                const colors = {{
                    'connected': '#4CAF50',
                    'connecting': '#FFC107',
                    'disconnected': '#9E9E9E',
                    'error': '#F44336',
                    'warning': '#FF9800'
                }};
                const labels = {{
                    'connected': 'Connected',
                    'connecting': 'Connecting...',
                    'disconnected': 'Disconnected',
                    'error': 'Error',
                    'warning': 'Warning'
                }};
                statusText.textContent = message ? `${{labels[status]}}: ${{message}}` : labels[status];
                statusIcon.style.backgroundColor = colors[status] || '#9E9E9E';
            }}
        }}

        // Global client instance
        window.diarizationClient = new SpeakerDiarizationClient();

        // Button event handlers
        function startListening() {{
            window.diarizationClient.startRecording();
        }}

        function stopListening() {{
            window.diarizationClient.stopRecording();
        }}

        function clearConversation() {{
            window.diarizationClient.clearConversation();
        }}

        // Initialize on page load
        document.addEventListener('DOMContentLoaded', () => {{
            window.diarizationClient.updateStatus('disconnected');
        }});
        </script>
        """
    with gr.Blocks(
        title="Real-time Speaker Diarization",
        theme=gr.themes.Soft(),
        css="""
        .status-indicator { margin: 10px 0; }
        .conversation-display {
            background: #f8f9fa;
            border: 1px solid #dee2e6;
            border-radius: 8px;
            padding: 20px;
            min-height: 400px;
            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
            overflow-y: auto;
        }
        """
    ) as demo:
        # Inject client-side JavaScript
        gr.HTML(get_client_js())

        # Header
        gr.Markdown("# 🎤 Real-time Speaker Diarization")
        gr.Markdown("Advanced speech recognition with automatic speaker identification")

        # Status indicator
        gr.HTML("""
        <div class="status-indicator">
            <span id="status-text" style="color:#666;">Ready to connect</span>
            <span id="status-icon" style="width:12px; height:12px; display:inline-block;
                  background-color:#9E9E9E; border-radius:50%; margin-left:8px;"></span>
        </div>
        """)

        with gr.Row():
            with gr.Column(scale=2):
                # Conversation display
                gr.HTML("""
                <div id="conversation" class="conversation-display">
                    <i>Click 'Start Listening' to begin real-time transcription...</i>
                </div>
                """)

                # WebRTC component (audio element hidden, but functional)
                webrtc = RTCComponent(
                    url=f"wss://{config.render_url}/stream",
                    streaming=False,
                    modality="audio",
                    mode="send-receive",
                    audio_html_attrs="style='display:none;'",  # Hide the audio element
                    visible=True,  # Keep the component visible but hide its audio element
                    elements=["video", "start", "stop"]  # Don't include the audio element
                )

                # Control buttons
                with gr.Row():
                    start_btn = gr.Button(
                        "▶️ Start Listening",
                        variant="primary",
                        size="lg",
                        elem_id="start-btn"
                    )
                    stop_btn = gr.Button(
                        "⏹️ Stop",
                        variant="stop",
                        size="lg",
                        elem_id="stop-btn"
                    )
                    clear_btn = gr.Button(
                        "🗑️ Clear",
                        variant="secondary",
                        size="lg",
                        elem_id="clear-btn"
                    )

                # WebRTC control functions
                def start_webrtc():
                    return {
                        webrtc: gr.update(streaming=True)
                    }

                def stop_webrtc():
                    return {
                        webrtc: gr.update(streaming=False)
                    }
                # Connect the buttons to both the WebRTC component and the JS client;
                # the js argument expects a JS function, so wrap the calls in arrow functions
                start_btn.click(fn=start_webrtc, outputs=[webrtc], js="() => startListening()")
                stop_btn.click(fn=stop_webrtc, outputs=[webrtc], js="() => stopListening()")
                clear_btn.click(fn=None, js="() => clearConversation()")
            with gr.Column(scale=1):
                gr.Markdown("## ⚙️ Settings")

                threshold_slider = gr.Slider(
                    minimum=0.3,
                    maximum=0.9,
                    step=0.05,
                    value=config.default_threshold,
                    label="Speaker Change Sensitivity",
                    info="Lower = more sensitive to speaker changes"
                )

                max_speakers_slider = gr.Slider(
                    minimum=2,
                    maximum=config.max_speakers_limit,
                    step=1,
                    value=config.default_max_speakers,
                    label="Maximum Speakers"
                )

                # Instructions
                gr.Markdown("""
                ## 📋 How to Use
                1. **Start Listening** - Grant microphone access
                2. **Speak** - The system transcribes and identifies speakers
                3. **Stop** when finished
                4. **Clear** to reset the conversation

                ## 🎨 Speaker Colors
                - 🔴 Speaker 1 - 🟢 Speaker 2 - 🔵 Speaker 3 - 🟡 Speaker 4
                - ⭐ Speaker 5 - 🟣 Speaker 6 - 🟤 Speaker 7 - 🟠 Speaker 8
                """)

    return demo
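# The Blocks UI above is normally served mounted under FastAPI (see below), but while
# iterating on layout it can also be launched standalone (a sketch, not used here):
#
#   create_gradio_app().launch(server_name="0.0.0.0", server_port=7860)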
def create_fastapi_app():
    """Create the FastAPI backend."""
    app = FastAPI(title="Speaker Diarization API")

    # Route paths below are assumptions inferred from the client-side JavaScript
    # (it opens wss://<host>/ws and POSTs to /clear); adjust if your deployment differs.
    @app.websocket("/ws")
    async def websocket_endpoint(websocket: WebSocket):
        await manager.connect(websocket)
        try:
            while True:
                # Receive audio data
                data = await websocket.receive_bytes()
                # Process the audio data here; this is where the actual
                # speaker-diarization model is integrated
                result = await process_audio_chunk(data)
                # Send the result back to the client
                await manager.send_personal_message(json.dumps(result), websocket)
        except WebSocketDisconnect:
            manager.disconnect(websocket)
        except Exception as e:
            logger.error(f"WebSocket error: {e}")
            manager.disconnect(websocket)

    @app.post("/clear")
    async def clear_conversation():
        """Clear the conversation history."""
        manager.conversation_history.clear()
        await manager.broadcast(json.dumps({"type": "conversation_cleared"}))
        return {"status": "cleared"}

    @app.get("/health")
    async def health_check():
        """Health check endpoint."""
        return {
            "status": "healthy",
            "timestamp": datetime.now().isoformat(),
            "active_connections": len(manager.active_connections)
        }

    @app.get("/status")
    async def get_status():
        """Get system status."""
        return {
            "status": "online",
            "connections": len(manager.active_connections),
            "conversation_length": len(manager.conversation_history)
        }

    return app
async def process_audio_chunk(audio_data: bytes) -> dict:
    """
    Process an audio chunk by forwarding it to the backend.

    This function is only used for the direct WebSocket API, not for the WebRTC
    component. In production you should primarily use the WebRTC component,
    which has its own audio-processing flow through the Render backend.
    """
    try:
        # Connect to the speaker-diarization backend via WebSocket
        websocket_url = f"wss://{config.hf_space_url}/ws_inference"
        logger.info(f"Forwarding audio to diarization backend at {websocket_url}")

        async with websockets.connect(websocket_url) as websocket:
            # Send the audio data
            await websocket.send(audio_data)
            # Receive the response
            response = await websocket.recv()

            # Parse the response
            try:
                result = json.loads(response)
                # Add to the conversation history if it's a transcription update
                if result.get("type") in ("transcription", "conversation_update"):
                    if "conversation_html" in result:
                        manager.conversation_history.append({
                            "timestamp": datetime.now().isoformat(),
                            "html": result["conversation_html"]
                        })
                return result
            except json.JSONDecodeError:
                logger.error(f"Invalid JSON response: {response}")
                return {
                    "type": "error",
                    "error": "Invalid response from backend",
                    "timestamp": datetime.now().isoformat()
                }
    except Exception as e:
        logger.exception(f"Error processing audio chunk: {e}")
        return {
            "type": "error",
            "error": str(e),
            "timestamp": datetime.now().isoformat()
        }
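# A minimal sketch of exercising process_audio_chunk directly (e.g. from a test
# script); the payload is 100 ms of 16 kHz 16-bit mono silence, purely illustrative
# since the expected audio format depends on the backend:
#
#   result = asyncio.run(process_audio_chunk(b"\x00" * 3200))
#   print(result.get("type"))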
# Create both apps
fastapi_app = create_fastapi_app()
gradio_app = create_gradio_app()

# Root redirect - keep this simple
@fastapi_app.get("/")
def root():
    """Redirect the root URL to the Gradio UI."""
    return RedirectResponse(url="/ui/")  # The trailing slash is important

# Mount the Gradio app onto FastAPI, using whichever mounting method the
# installed Gradio version supports
try:
    # Preferred method for current Gradio versions
    from gradio.routes import mount_gradio_app
    fastapi_app = mount_gradio_app(fastapi_app, gradio_app, path="/ui")
    logger.info("Mounted Gradio app using mount_gradio_app")
except Exception as e:
    logger.error(f"mount_gradio_app failed: {e}")
    # As a last resort, mount the Blocks' underlying ASGI app directly
    try:
        fastapi_app.mount("/ui", gradio_app.app)
    except Exception as e2:
        logger.error(f"Failed to mount Gradio app: {e2}")
# Diagnostic endpoint to inspect the backend configuration (route path is an assumption)
@fastapi_app.get("/check-backend")
async def check_backend():
    """Report the configured connection details for the Render backend."""
    try:
        # The WebSocket endpoint on Render that the WebRTC component streams to
        websocket_url = f"wss://{config.render_url}/stream"
        logger.info(f"Checking connection to Render backend at {websocket_url}")
        # Don't actually connect, just report the configured endpoints
        return {
            "status": "configured",
            "render_backend_url": websocket_url,
            "hf_space_url": f"wss://{config.hf_space_url}/ws_inference",
            "rtc_component_config": {
                "url": f"wss://{config.render_url}/stream",
                "modality": "audio",
                "mode": "send-receive"
            }
        }
    except Exception as e:
        logger.error(f"Error checking backend: {e}")
        return {
            "status": "error",
            "error": str(e)
        }
# Log configuration and test backend connectivity on startup
@fastapi_app.on_event("startup")
async def log_configuration():
    logger.info("Starting UI with configuration:")
    logger.info(f"- HF Space URL: {config.hf_space_url}")
    logger.info(f"- Render URL: {config.render_url}")
    logger.info(f"- WebRTC URL: wss://{config.render_url}/stream")
    logger.info(f"- WebSocket URL: wss://{config.hf_space_url}/ws_inference")
    logger.info("Note: Audio will be streamed through the Render backend using WebRTC")

    # Test the connection to the Render backend
    try:
        async with websockets.connect(f"wss://{config.render_url}/stream",
                                      ping_interval=None, ping_timeout=None):
            logger.info("Successfully connected to Render backend WebSocket")
    except Exception as e:
        logger.error(f"Failed to connect to Render backend: {e}")

    # Test the connection to the HF Space backend
    try:
        async with websockets.connect(f"wss://{config.hf_space_url}/ws_inference",
                                      ping_interval=None, ping_timeout=None):
            logger.info("Successfully connected to HF Space WebSocket")
    except Exception as e:
        logger.error(f"Failed to connect to HF Space: {e}")
if __name__ == "__main__":
    import uvicorn

    # Hugging Face Spaces expects the app on port 7860
    port = int(os.environ.get("PORT", 7860))
    logger.info(f"Starting server on port {port}")
    uvicorn.run(fastapi_app, host="0.0.0.0", port=port)
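# A minimal sketch (assuming the /health and /ws routes registered above) of how an
# external client could check the service and stream raw audio bytes to it:
#
#   import asyncio, json, httpx, websockets
#
#   async def demo():
#       print(httpx.get("http://localhost:7860/health").json())
#       async with websockets.connect("ws://localhost:7860/ws") as ws:
#           await ws.send(b"\x00" * 3200)          # one chunk of silence
#           print(json.loads(await ws.recv()))     # diarization result
#
#   asyncio.run(demo())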