# Real-time speaker diarization UI (Gradio front-end, mounted on FastAPI).
# Browser-side WebRTC/WebSocket logic is embedded as an HTML <script> block below.
import gradio as gr
from fastapi import FastAPI
from shared import DEFAULT_CHANGE_THRESHOLD, DEFAULT_MAX_SPEAKERS, ABSOLUTE_MAX_SPEAKERS, FINAL_TRANSCRIPTION_MODEL, REALTIME_TRANSCRIPTION_MODEL
import os
# Connection configuration (separate signaling server from model server)
# These will be replaced with environment variables or defaults
# WebSocket endpoint of the FastRTC signaling server (Render deployment).
RENDER_SIGNALING_URL = os.environ.get("RENDER_SIGNALING_URL", "wss://render-signal-audio.onrender.com/stream")
# HTTP base URL of the Hugging Face Space that hosts the diarization backend.
HF_SPACE_URL = os.environ.get("HF_SPACE_URL", "https://androidguy-speaker-diarization.hf.space")
def build_ui():
    """Build the Gradio Blocks UI for real-time speaker diarization.

    The page embeds a large ``<script>`` block that runs entirely in the
    browser and handles: microphone capture, a WebRTC connection to the
    Render signaling server, a WebSocket relay for conversation updates,
    periodic backend health checks, and the status indicator.

    Returns:
        gr.Blocks: the assembled (not yet launched) Gradio app.
    """
    with gr.Blocks(title="Real-time Speaker Diarization", theme=gr.themes.Soft()) as demo:
        # Expose connection endpoints and model names to the embedded JS
        # via window globals (read by the script in conversation_display).
        gr.HTML(
            f"""
            <!-- Configuration parameters -->
            <script>
            window.RENDER_SIGNALING_URL = "{RENDER_SIGNALING_URL}";
            window.HF_SPACE_URL = "{HF_SPACE_URL}";
            window.FINAL_TRANSCRIPTION_MODEL = "{FINAL_TRANSCRIPTION_MODEL}";
            window.REALTIME_TRANSCRIPTION_MODEL = "{REALTIME_TRANSCRIPTION_MODEL}";
            </script>
            """
        )
        # Header and description
        gr.Markdown("# 🎀 Live Speaker Diarization")
        gr.Markdown("Real-time speech recognition with automatic speaker identification")
        # Add transcription model info
        gr.Markdown(f"**Using Models:** Final: {FINAL_TRANSCRIPTION_MODEL}, Realtime: {REALTIME_TRANSCRIPTION_MODEL}")
        # Status indicator (text + colored dot), updated by the JS updateStatus().
        connection_status = gr.HTML(
            """<div class="status-indicator">
                <span id="status-text" style="color:#888;">Waiting to connect...</span>
                <span id="status-icon" style="width:10px; height:10px; display:inline-block;
                    background-color:#888; border-radius:50%; margin-left:5px;"></span>
            </div>""",
            elem_id="connection-status"
        )
        with gr.Row():
            with gr.Column(scale=2):
                # Conversation display with embedded JavaScript for WebRTC and audio handling.
                # NOTE: plain (non-f) string — it contains JS `${...}` template literals
                # and `{`/`}` braces that must reach the browser untouched.
                conversation_display = gr.HTML(
                    """
                    <div class='output' id="conversation" style='padding:20px; background:#111; border-radius:10px;
                        min-height:400px; font-family:Arial; font-size:16px; line-height:1.5; overflow-y:auto;'>
                        <i>Click 'Start Listening' to begin...</i>
                    </div>

                    <script>
                    // Global variables
                    let rtcConnection;
                    let mediaStream;
                    let wsConnection;
                    let statusUpdateInterval;
                    let isOfflineMode = false;

                    // Check connection to HF space with timeout
                    async function checkHfConnection() {
                        try {
                            const controller = new AbortController();
                            const timeoutId = setTimeout(() => controller.abort(), 5000);
                            const response = await fetch(`${window.HF_SPACE_URL}/health`, {
                                signal: controller.signal
                            });
                            clearTimeout(timeoutId);
                            return response.ok;
                        } catch (err) {
                            console.warn("HF Space connection failed:", err);
                            return false;
                        }
                    }

                    // Start the connection and audio streaming with robust error handling
                    async function startStreaming() {
                        try {
                            // Update status
                            updateStatus('connecting');

                            // First check backend connectivity
                            const backendAvailable = await checkHfConnection();
                            isOfflineMode = !backendAvailable;

                            // Request microphone access - this works even offline
                            try {
                                mediaStream = await navigator.mediaDevices.getUserMedia({audio: {
                                    echoCancellation: true,
                                    noiseSuppression: true,
                                    autoGainControl: true
                                }});
                            } catch (micErr) {
                                console.error('Microphone access error:', micErr);
                                updateStatus('error', 'Microphone access denied: ' + micErr.message);
                                return;
                            }

                            if (backendAvailable) {
                                // Try WebRTC connection
                                try {
                                    await setupWebRTC();
                                } catch (rtcErr) {
                                    console.error("WebRTC setup failed:", rtcErr);
                                    // Continue even if WebRTC fails
                                }

                                // Try WebSocket connection
                                try {
                                    setupWebSocket();
                                } catch (wsErr) {
                                    console.error("WebSocket setup failed:", wsErr);
                                    // Continue even if WebSocket fails
                                }

                                updateStatus('connected');
                                document.getElementById("conversation").innerHTML = "<i>Connected! Start speaking...</i>";
                            } else {
                                updateStatus('warning', 'Running in offline mode - limited functionality');
                                document.getElementById("conversation").innerHTML =
                                    "<i>Backend connection failed. Microphone active but transcription unavailable.</i>";
                            }

                            // Start status update interval regardless
                            statusUpdateInterval = setInterval(updateConnectionInfo, 5000);
                        } catch (err) {
                            console.error('Error starting stream:', err);
                            updateStatus('error', err.message);
                        }
                    }

                    // Set up WebRTC connection to Render signaling server
                    async function setupWebRTC() {
                        try {
                            if (rtcConnection) {
                                rtcConnection.close();
                            }

                            // Use FastRTC's connection approach
                            const pc = new RTCPeerConnection({
                                iceServers: [
                                    { urls: 'stun:stun.l.google.com:19302' },
                                    { urls: 'stun:stun1.l.google.com:19302' }
                                ]
                            });

                            // Add audio track
                            mediaStream.getAudioTracks().forEach(track => {
                                pc.addTrack(track, mediaStream);
                            });

                            // Connect to FastRTC signaling via WebSocket with timeout
                            const signalWs = new WebSocket(window.RENDER_SIGNALING_URL);

                            // Set connection timeout
                            const connectionTimeout = setTimeout(() => {
                                if (signalWs.readyState !== WebSocket.OPEN) {
                                    signalWs.close();
                                    throw new Error("WebRTC signaling connection timeout");
                                }
                            }, 10000);

                            // Wait for connection to open
                            await new Promise((resolve, reject) => {
                                signalWs.onopen = () => {
                                    clearTimeout(connectionTimeout);
                                    resolve();
                                };
                                signalWs.onerror = (err) => {
                                    clearTimeout(connectionTimeout);
                                    reject(new Error("WebRTC signaling connection failed"));
                                };
                            });

                            // Handle signaling messages
                            signalWs.onmessage = async (event) => {
                                try {
                                    const message = JSON.parse(event.data);
                                    if (message.type === 'offer') {
                                        await pc.setRemoteDescription(new RTCSessionDescription(message));
                                        const answer = await pc.createAnswer();
                                        await pc.setLocalDescription(answer);
                                        signalWs.send(JSON.stringify(pc.localDescription));
                                    } else if (message.type === 'candidate') {
                                        if (message.candidate) {
                                            await pc.addIceCandidate(new RTCIceCandidate(message));
                                        }
                                    }
                                } catch (err) {
                                    console.error("Error processing signaling message:", err);
                                }
                            };

                            // Send ICE candidates
                            pc.onicecandidate = (event) => {
                                if (event.candidate) {
                                    signalWs.send(JSON.stringify({
                                        type: 'candidate',
                                        candidate: event.candidate
                                    }));
                                }
                            };

                            // Keep connection reference
                            rtcConnection = pc;

                            // Wait for connection to be established with timeout
                            await new Promise((resolve, reject) => {
                                const timeout = setTimeout(() => reject(new Error("WebRTC connection timeout")), 15000);
                                pc.onconnectionstatechange = () => {
                                    console.log("WebRTC connection state:", pc.connectionState);
                                    if (pc.connectionState === 'connected') {
                                        clearTimeout(timeout);
                                        resolve();
                                    } else if (pc.connectionState === 'failed' || pc.connectionState === 'disconnected' || pc.connectionState === 'closed') {
                                        clearTimeout(timeout);
                                        reject(new Error("WebRTC connection failed"));
                                    }
                                };
                                // Also check ice connection state as fallback
                                pc.oniceconnectionstatechange = () => {
                                    console.log("ICE connection state:", pc.iceConnectionState);
                                    if (pc.iceConnectionState === 'connected' || pc.iceConnectionState === 'completed') {
                                        clearTimeout(timeout);
                                        resolve();
                                    } else if (pc.iceConnectionState === 'failed' || pc.iceConnectionState === 'disconnected' || pc.iceConnectionState === 'closed') {
                                        clearTimeout(timeout);
                                        reject(new Error("ICE connection failed"));
                                    }
                                };
                            });

                            updateStatus('connected');
                            console.log("WebRTC connection established successfully");
                        } catch (err) {
                            console.error('WebRTC setup error:', err);
                            updateStatus('warning', 'WebRTC setup issue: ' + err.message);
                            throw err;
                        }
                    }

                    // Set up WebSocket connection to HF Space for conversation updates
                    function setupWebSocket() {
                        try {
                            // Close existing connection if any
                            if (wsConnection) {
                                wsConnection.close();
                            }

                            const wsUrl = window.RENDER_SIGNALING_URL.replace('stream', 'ws_relay');
                            wsConnection = new WebSocket(wsUrl);

                            // Set connection timeout
                            const connectionTimeout = setTimeout(() => {
                                if (wsConnection.readyState !== WebSocket.OPEN) {
                                    wsConnection.close();
                                    throw new Error("WebSocket connection timeout");
                                }
                            }, 10000);

                            wsConnection.onopen = () => {
                                clearTimeout(connectionTimeout);
                                console.log('WebSocket connection established');
                            };

                            wsConnection.onmessage = (event) => {
                                try {
                                    // Parse the JSON message
                                    const message = JSON.parse(event.data);

                                    // Process different message types
                                    switch(message.type) {
                                        case 'transcription':
                                            // Handle transcription data
                                            if (message && message.data && typeof message.data === 'object') {
                                                document.getElementById("conversation").innerHTML = message.data.conversation_html ||
                                                    JSON.stringify(message.data);
                                            }
                                            break;

                                        case 'processing_result':
                                            // Handle individual audio chunk processing result
                                            console.log('Processing result:', message.data);

                                            // Update status info if needed
                                            if (message.data && message.data.status === "processed") {
                                                const statusElem = document.getElementById('status-text');
                                                if (statusElem) {
                                                    const speakerId = message.data.speaker_id !== undefined ?
                                                        `Speaker ${message.data.speaker_id + 1}` : '';
                                                    if (speakerId) {
                                                        statusElem.textContent = `Connected - ${speakerId} active`;
                                                    }
                                                }
                                            } else if (message.data && message.data.status === "error") {
                                                updateStatus('error', message.data.message || 'Processing error');
                                            }
                                            break;

                                        case 'connection':
                                            console.log('Connection status:', message.status);
                                            updateStatus(message.status === 'connected' ? 'connected' : 'warning');
                                            break;

                                        case 'connection_established':
                                            console.log('Connection established:', message);
                                            updateStatus('connected');

                                            // If initial conversation is provided, display it
                                            if (message.conversation) {
                                                document.getElementById("conversation").innerHTML = message.conversation;
                                            }
                                            break;

                                        case 'conversation_update':
                                            if (message.conversation_html) {
                                                document.getElementById("conversation").innerHTML = message.conversation_html;
                                            }
                                            break;

                                        case 'conversation_cleared':
                                            document.getElementById("conversation").innerHTML =
                                                "<i>Conversation cleared. Start speaking again...</i>";
                                            break;

                                        case 'error':
                                            console.error('Error message from server:', message.message);
                                            updateStatus('warning', message.message);
                                            break;

                                        default:
                                            // If it's just HTML content without proper JSON structure (legacy format)
                                            document.getElementById("conversation").innerHTML = event.data;
                                    }

                                    // Auto-scroll to bottom
                                    const container = document.getElementById("conversation");
                                    container.scrollTop = container.scrollHeight;
                                } catch (e) {
                                    // Fallback for non-JSON messages (legacy format)
                                    document.getElementById("conversation").innerHTML = event.data;

                                    // Auto-scroll to bottom
                                    const container = document.getElementById("conversation");
                                    container.scrollTop = container.scrollHeight;
                                }
                            };

                            wsConnection.onerror = (error) => {
                                clearTimeout(connectionTimeout);
                                console.error('WebSocket error:', error);
                                updateStatus('warning', 'WebSocket error');
                            };

                            wsConnection.onclose = () => {
                                console.log('WebSocket connection closed');
                                // Try to reconnect after a delay if not in offline mode
                                if (!isOfflineMode) {
                                    setTimeout(() => {
                                        try {
                                            setupWebSocket();
                                        } catch (e) {
                                            console.error("Failed to reconnect WebSocket:", e);
                                        }
                                    }, 3000);
                                }
                            };
                        } catch (err) {
                            console.error("WebSocket setup error:", err);
                            throw err;
                        }
                    }

                    // Update connection info in the UI with better error handling
                    async function updateConnectionInfo() {
                        try {
                            const hfConnected = await checkHfConnection();
                            if (!hfConnected) {
                                // If we were online but now offline, update mode
                                if (!isOfflineMode) {
                                    isOfflineMode = true;
                                    updateStatus('warning', 'Backend unavailable - limited functionality');
                                }
                            } else {
                                // If we were offline but now online, update mode
                                if (isOfflineMode) {
                                    isOfflineMode = false;
                                    // Try to reconnect services
                                    try {
                                        if (!rtcConnection || rtcConnection.connectionState !== 'connected') {
                                            await setupWebRTC();
                                        }
                                        if (!wsConnection || wsConnection.readyState !== WebSocket.OPEN) {
                                            setupWebSocket();
                                        }
                                        updateStatus('connected');
                                    } catch (e) {
                                        console.warn("Failed to reconnect services:", e);
                                        updateStatus('warning', 'Connection partially restored');
                                    }
                                } else if (rtcConnection?.connectionState === 'connected' ||
                                           rtcConnection?.iceConnectionState === 'connected') {
                                    updateStatus('connected');
                                } else {
                                    updateStatus('warning', 'Connection unstable');
                                    // Try to reconnect if needed
                                    if (!rtcConnection ||
                                        rtcConnection.connectionState === 'failed' ||
                                        rtcConnection.connectionState === 'disconnected') {
                                        try {
                                            await setupWebRTC();
                                        } catch (e) {
                                            console.warn("Failed to reconnect WebRTC:", e);
                                        }
                                    }
                                    if (!wsConnection || wsConnection.readyState !== WebSocket.OPEN) {
                                        try {
                                            setupWebSocket();
                                        } catch (e) {
                                            console.warn("Failed to reconnect WebSocket:", e);
                                        }
                                    }
                                }
                            }
                        } catch (err) {
                            console.error('Error updating connection info:', err);
                            // Don't update status here to avoid flickering
                        }
                    }

                    // Update status indicator
                    function updateStatus(status, message = '') {
                        const statusText = document.getElementById('status-text');
                        const statusIcon = document.getElementById('status-icon');

                        if (!statusText || !statusIcon) return;

                        switch(status) {
                            case 'connected':
                                statusText.textContent = 'Connected';
                                statusIcon.style.backgroundColor = '#4CAF50';
                                break;
                            case 'connecting':
                                statusText.textContent = 'Connecting...';
                                statusIcon.style.backgroundColor = '#FFC107';
                                break;
                            case 'disconnected':
                                statusText.textContent = 'Disconnected';
                                statusIcon.style.backgroundColor = '#9E9E9E';
                                break;
                            case 'error':
                                statusText.textContent = 'Error: ' + message;
                                statusIcon.style.backgroundColor = '#F44336';
                                break;
                            case 'warning':
                                statusText.textContent = 'Warning: ' + message;
                                statusIcon.style.backgroundColor = '#FF9800';
                                break;
                            default:
                                statusText.textContent = 'Unknown';
                                statusIcon.style.backgroundColor = '#9E9E9E';
                        }
                    }

                    // Stop streaming and clean up
                    function stopStreaming() {
                        // Close WebRTC connection
                        if (rtcConnection) {
                            rtcConnection.close();
                            rtcConnection = null;
                        }

                        // Close WebSocket
                        if (wsConnection) {
                            wsConnection.close();
                            wsConnection = null;
                        }

                        // Stop all tracks in media stream
                        if (mediaStream) {
                            mediaStream.getTracks().forEach(track => track.stop());
                            mediaStream = null;
                        }

                        // Clear interval
                        if (statusUpdateInterval) {
                            clearInterval(statusUpdateInterval);
                            statusUpdateInterval = null;
                        }

                        // Update status
                        updateStatus('disconnected');
                    }

                    // Clear conversation with better error handling and offline mode support
                    function clearConversation() {
                        // First update the UI immediately regardless of backend availability
                        document.getElementById("conversation").innerHTML =
                            "<i>Conversation cleared. Start speaking again...</i>";

                        // Then try to update on the backend if available
                        if (!isOfflineMode) {
                            checkHfConnection().then(isConnected => {
                                if (isConnected) {
                                    return fetch(`${window.HF_SPACE_URL}/clear`, {
                                        method: 'POST'
                                    });
                                } else {
                                    throw new Error("Backend unavailable");
                                }
                            })
                            .then(resp => resp.json())
                            .then(data => {
                                console.log("Backend conversation cleared successfully");
                            })
                            .catch(err => {
                                console.warn("Backend clear API failed:", err);
                                // No need to update UI again as we already did it above
                            });
                        }
                    }

                    // Update settings with better error handling and offline mode support
                    function updateSettings() {
                        const threshold = document.querySelector('input[data-testid="threshold-slider"]')?.value ||
                                          document.getElementById('threshold-slider')?.value;
                        const maxSpeakers = document.querySelector('input[data-testid="speakers-slider"]')?.value ||
                                            document.getElementById('speakers-slider')?.value;

                        if (!threshold || !maxSpeakers) {
                            console.error("Could not find slider values");
                            return;
                        }

                        // First update the UI immediately regardless of API success
                        const statusOutput = document.getElementById('status-output');
                        if (statusOutput) {
                            statusOutput.innerHTML = `
                                <h2>System Status</h2>
                                <p>Settings updated:</p>
                                <ul>
                                    <li>Threshold: ${threshold}</li>
                                    <li>Max Speakers: ${maxSpeakers}</li>
                                </ul>
                                <p>Transcription Models:</p>
                                <ul>
                                    <li>Final: ${window.FINAL_TRANSCRIPTION_MODEL || "distil-large-v3"}</li>
                                    <li>Realtime: ${window.REALTIME_TRANSCRIPTION_MODEL || "distil-small.en"}</li>
                                </ul>
                            `;
                        }

                        // Then try to update on the backend if available and not in offline mode
                        if (!isOfflineMode) {
                            checkHfConnection().then(isConnected => {
                                if (isConnected) {
                                    return fetch(`${window.HF_SPACE_URL}/settings?threshold=${threshold}&max_speakers=${maxSpeakers}`, {
                                        method: 'POST'
                                    });
                                } else {
                                    throw new Error("Backend unavailable");
                                }
                            })
                            .then(resp => resp.json())
                            .then(data => {
                                console.log("Backend settings updated successfully:", data);
                            })
                            .catch(err => {
                                console.warn("Backend settings update failed:", err);
                                // No need to update UI again as we already did it above
                            });
                        }
                    }

                    // Set up event listeners when the DOM is loaded
                    document.addEventListener('DOMContentLoaded', () => {
                        updateStatus('disconnected');

                        // Function to find and bind buttons with retries
                        function findAndBindButtons() {
                            // Try to find buttons by ID first (most reliable)
                            let startBtn = document.getElementById('btn-start');
                            let stopBtn = document.getElementById('btn-stop');
                            let clearBtn = document.getElementById('btn-clear');
                            let updateBtn = document.getElementById('btn-update');

                            // Fallback to aria-label if IDs aren't found
                            if (!startBtn) startBtn = document.querySelector('button[aria-label="Start Listening"]');
                            if (!stopBtn) stopBtn = document.querySelector('button[aria-label="Stop"]');
                            // Bug fix: previously assigned the Clear button to stopBtn.
                            if (!clearBtn) clearBtn = document.querySelector('button[aria-label="Clear"]');
                            if (!updateBtn) updateBtn = document.querySelector('button[aria-label="Update Settings"]');

                            // Fallback to text content as last resort
                            if (!startBtn) startBtn = Array.from(document.querySelectorAll('button')).find(btn => btn.textContent.includes('Start'));
                            if (!stopBtn) stopBtn = Array.from(document.querySelectorAll('button')).find(btn => btn.textContent.includes('Stop'));
                            if (!clearBtn) clearBtn = Array.from(document.querySelectorAll('button')).find(btn => btn.textContent.includes('Clear'));
                            if (!updateBtn) updateBtn = Array.from(document.querySelectorAll('button')).find(btn => btn.textContent.includes('Update'));

                            // Check if all buttons are found
                            const buttonsFound = startBtn && stopBtn && clearBtn && updateBtn;

                            if (buttonsFound) {
                                console.log("All buttons found, binding events");

                                // Bind event handlers
                                startBtn.onclick = () => startStreaming();
                                stopBtn.onclick = () => stopStreaming();
                                clearBtn.onclick = () => clearConversation();
                                updateBtn.onclick = () => updateSettings();

                                // Add data attributes to make it clear these are bound
                                startBtn.setAttribute('data-bound', 'true');
                                stopBtn.setAttribute('data-bound', 'true');
                                clearBtn.setAttribute('data-bound', 'true');
                                updateBtn.setAttribute('data-bound', 'true');

                                return true;
                            } else {
                                console.log("Not all buttons found, will retry");
                                return false;
                            }
                        }

                        // Try to bind immediately
                        if (!findAndBindButtons()) {
                            // If not successful, set up a retry mechanism
                            let retryCount = 0;
                            const maxRetries = 20; // More retries, longer interval
                            const retryInterval = 300; // 300ms between retries

                            const retryBinding = setInterval(() => {
                                if (findAndBindButtons() || ++retryCount >= maxRetries) {
                                    clearInterval(retryBinding);
                                    if (retryCount >= maxRetries) {
                                        console.warn("Failed to find all buttons after maximum retries");
                                    }
                                }
                            }, retryInterval);
                        }
                    });
                    </script>
                    """,
                    label="Live Conversation"
                )

                # Control buttons with elem_id for reliable selection
                with gr.Row():
                    start_btn = gr.Button("▢️ Start Listening", variant="primary", size="lg", elem_id="btn-start")
                    stop_btn = gr.Button("⏹️ Stop", variant="stop", size="lg", elem_id="btn-stop")
                    clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary", size="lg", elem_id="btn-clear")

                # Status display with elem_id for reliable selection
                status_output = gr.Markdown(
                    """
                    ## System Status
                    Waiting to connect...
                    *Click Start Listening to begin*
                    """,
                    label="Status Information",
                    elem_id="status-output"
                )

            with gr.Column(scale=1):
                # Settings
                gr.Markdown("## βš™οΈ Settings")
                threshold_slider = gr.Slider(
                    minimum=0.3,
                    maximum=0.9,
                    step=0.05,
                    value=DEFAULT_CHANGE_THRESHOLD,
                    label="Speaker Change Sensitivity",
                    info="Lower = more sensitive (more speaker changes)",
                    elem_id="threshold-slider"
                )
                max_speakers_slider = gr.Slider(
                    minimum=2,
                    maximum=ABSOLUTE_MAX_SPEAKERS,
                    step=1,
                    value=DEFAULT_MAX_SPEAKERS,
                    label="Maximum Speakers",
                    elem_id="speakers-slider"
                )
                update_btn = gr.Button("Update Settings", variant="secondary", elem_id="btn-update")

                # Instructions
                gr.Markdown("""
                ## πŸ“‹ Instructions
                1. **Start Listening** - allows browser to access microphone
                2. **Speak** - system will transcribe and identify speakers
                3. **Stop** when finished
                4. **Clear** to reset conversation
                ## 🎨 Speaker Colors
                - πŸ”΄ Speaker 1 (Red)
                - 🟒 Speaker 2 (Teal)
                - πŸ”΅ Speaker 3 (Blue)
                - 🟑 Speaker 4 (Green)
                - ⭐ Speaker 5 (Yellow)
                - 🟣 Speaker 6 (Plum)
                - 🟀 Speaker 7 (Mint)
                - 🟠 Speaker 8 (Gold)
                """)

        # Function to get backend status (for periodic updates)
        def get_status():
            """Poll the backend /status endpoint; return a short markdown string."""
            import requests
            try:
                # Use a short timeout to prevent UI hanging
                resp = requests.get(f"{HF_SPACE_URL}/status", timeout=2)
                if resp.status_code == 200:
                    return resp.json().get('formatted_text', 'No status information')
                return "Error getting status"
            except Exception:
                # Best-effort: the UI still works without the status panel.
                return "Status update unavailable: Backend may be offline"

        # Set up periodic status updates with shorter interval and error handling
        status_timer = gr.Timer(10)  # 10 seconds between updates
        status_timer.tick(fn=get_status, outputs=status_output)

    return demo
# Create Gradio interface
# Built once at import time so mount_ui() and the __main__ guard share one app.
demo = build_ui()
def mount_ui(app: FastAPI):
    """Mount the Gradio UI onto an existing FastAPI app under /ui.

    Uses gradio's official ``mount_gradio_app`` helper when available
    (stable public API); falls back to the legacy ``demo.app`` attribute
    for older gradio versions so existing deployments keep working.
    """
    if hasattr(gr, "mount_gradio_app"):
        gr.mount_gradio_app(app, demo, path="/ui")
    else:
        # Legacy gradio: Blocks exposed its internal FastAPI app directly.
        app.mount("/ui", demo.app)
# For standalone testing (run this file directly instead of mounting on FastAPI)
if __name__ == "__main__":
    demo.launch()