# Source: Hugging Face Space file (author: Saiyaswanth007, commit a905808,
# "Experiment 3"). The web-viewer chrome (raw/history/blame, 42.3 kB) was
# removed so this file is valid Python.
import gradio as gr
from fastapi import FastAPI
import requests
import json
from shared import DEFAULT_CHANGE_THRESHOLD, DEFAULT_MAX_SPEAKERS, ABSOLUTE_MAX_SPEAKERS, FINAL_TRANSCRIPTION_MODEL, REALTIME_TRANSCRIPTION_MODEL
# Log the Gradio version at startup — useful when debugging deployment mismatches.
print(gr.__version__)

# Connection configuration (separate signaling server from model server).
# These will be replaced at deployment time with the correct URLs.
RENDER_SIGNALING_URL = "wss://render-signal-audio.onrender.com/stream"  # WebRTC signaling endpoint
HF_SPACE_URL = "https://androidguy-speaker-diarization.hf.space"  # model/transcription backend
class ResourceManager:
    """Tracks timers and cleanup callbacks so they can be released on shutdown.

    Cleanup is deliberately best-effort: one failing timer or callback must
    not prevent the remaining resources from being released.
    """

    def __init__(self):
        self.timers = []              # objects exposing a .stop() method (e.g. gr.Timer)
        self.cleanup_callbacks = []   # zero-argument callables run at shutdown

    def add_timer(self, timer):
        """Register a timer to be stopped during cleanup."""
        self.timers.append(timer)

    def add_cleanup_callback(self, callback):
        """Register a zero-argument callable to run during cleanup."""
        self.cleanup_callbacks.append(callback)

    def cleanup(self):
        """Stop all registered timers, then run all callbacks.

        Individual failures are swallowed so the rest of the resources are
        still released. Uses ``except Exception`` (not a bare ``except:``)
        so SystemExit/KeyboardInterrupt still propagate during shutdown.
        """
        for timer in self.timers:
            try:
                timer.stop()
            except Exception:
                pass  # best-effort: keep stopping the remaining timers
        for callback in self.cleanup_callbacks:
            try:
                callback()
            except Exception:
                pass  # best-effort: keep running the remaining callbacks


# Global resource manager shared across the module.
resource_manager = ResourceManager()
def build_ui():
    """Build the Gradio UI for real-time speaker diarization.

    Returns:
        gr.Blocks: the assembled interface. All streaming logic (microphone
        capture, WebRTC signaling, WebSocket relay, reconnection, status
        display) lives in the embedded JavaScript below; the Python side only
        lays out components and polls the backend /status endpoint on a timer.
    """
    with gr.Blocks(title="Real-time Speaker Diarization", theme=gr.themes.Soft()) as demo:
        # Inject configuration as window-level JS globals; the large script in
        # conversation_display reads these (URLs and model names) at runtime.
        gr.HTML(
            f"""
<!-- Configuration parameters -->
<script>
window.RENDER_SIGNALING_URL = "{RENDER_SIGNALING_URL}";
window.HF_SPACE_URL = "{HF_SPACE_URL}";
window.FINAL_TRANSCRIPTION_MODEL = "{FINAL_TRANSCRIPTION_MODEL}";
window.REALTIME_TRANSCRIPTION_MODEL = "{REALTIME_TRANSCRIPTION_MODEL}";
</script>
"""
        )
        # Header and description
        gr.Markdown("# 🎀 Live Speaker Diarization")
        gr.Markdown(f"Real-time speech recognition with automatic speaker identification")
        # Add transcription model info
        gr.Markdown(f"**Using Models:** Final: {FINAL_TRANSCRIPTION_MODEL}, Realtime: {REALTIME_TRANSCRIPTION_MODEL}")
        # Status indicator: the JS StatusManager updates #status-text/#status-icon.
        connection_status = gr.HTML(
            """<div class="status-indicator">
<span id="status-text" style="color:#888;">Waiting to connect...</span>
<span id="status-icon" style="width:10px; height:10px; display:inline-block;
background-color:#888; border-radius:50%; margin-left:5px;"></span>
</div>"""
        )
        with gr.Row():
            with gr.Column(scale=2):
                # Conversation display with embedded JavaScript for WebRTC and
                # audio handling. The string is left byte-for-byte verbatim:
                # every character is runtime behavior (shipped to the browser).
                # NOTE(review): StreamController.stop() resets isCleaningUp to
                # false synchronously, but the WebSocket onclose handler fires
                # later and re-checks that flag — a reconnect can still be
                # scheduled after Stop is pressed. Confirm before relying on it.
                conversation_display = gr.HTML(
                    """
<div class='output' id="conversation" style='padding:20px; background:#111; border-radius:10px;
min-height:400px; font-family:Arial; font-size:16px; line-height:1.5; overflow-y:auto;'>
<i>Click 'Start Listening' to begin...</i>
</div>
<script>
// Global state management
const AppState = {
rtcConnection: null,
mediaStream: null,
wsConnection: null,
statusUpdateInterval: null,
wsReconnectAttempts: 0,
maxReconnectAttempts: 5,
isConnecting: false,
isCleaningUp: false
};
// Utility functions
const Utils = {
// Check connection to HF space with timeout
async checkHfConnection(timeout = 5000) {
try {
const controller = new AbortController();
const timeoutId = setTimeout(() => controller.abort(), timeout);
const response = await fetch(`${window.HF_SPACE_URL}/health`, {
signal: controller.signal,
method: 'GET',
cache: 'no-cache'
});
clearTimeout(timeoutId);
return response.ok;
} catch (err) {
console.warn("HF Space connection failed:", err);
return false;
}
},
// Safe JSON parse
safeJsonParse(str) {
try {
return JSON.parse(str);
} catch (e) {
return null;
}
},
// Debounce function
debounce(func, wait) {
let timeout;
return function executedFunction(...args) {
const later = () => {
clearTimeout(timeout);
func(...args);
};
clearTimeout(timeout);
timeout = setTimeout(later, wait);
};
}
};
// Main streaming control
const StreamController = {
async start() {
if (AppState.isConnecting || AppState.isCleaningUp) {
console.log("Already connecting or cleaning up, ignoring start request");
return;
}
AppState.isConnecting = true;
try {
// Update status
StatusManager.update('connecting');
// Request microphone access with proper error handling
await this.setupMediaStream();
// Check backend availability
const backendAvailable = await Utils.checkHfConnection();
// Setup connections
if (backendAvailable) {
await Promise.allSettled([
this.setupWebRTC(),
this.setupWebSocket()
]);
StatusManager.update('connected');
document.getElementById("conversation").innerHTML = "<i>Connected! Start speaking...</i>";
} else {
StatusManager.update('warning', 'Backend unavailable - limited functionality');
document.getElementById("conversation").innerHTML =
"<i>Backend connection failed. Microphone active but transcription unavailable.</i>";
}
// Start status monitoring
AppState.statusUpdateInterval = setInterval(() => {
ConnectionMonitor.updateConnectionInfo();
}, 5000);
} catch (err) {
console.error('Error starting stream:', err);
StatusManager.update('error', err.message);
this.cleanup();
} finally {
AppState.isConnecting = false;
}
},
async setupMediaStream() {
try {
AppState.mediaStream = await navigator.mediaDevices.getUserMedia({
audio: {
echoCancellation: true,
noiseSuppression: true,
autoGainControl: true,
sampleRate: 16000 // Specify sample rate for consistency
}
});
} catch (err) {
let errorMessage;
switch (err.name) {
case 'NotAllowedError':
errorMessage = 'Microphone access denied. Please allow microphone access and try again.';
break;
case 'NotFoundError':
errorMessage = 'No microphone found. Please connect a microphone and try again.';
break;
case 'NotReadableError':
errorMessage = 'Microphone is being used by another application.';
break;
case 'OverconstrainedError':
errorMessage = 'Microphone constraints cannot be satisfied.';
break;
default:
errorMessage = `Microphone error: ${err.message}`;
}
throw new Error(errorMessage);
}
},
async setupWebRTC() {
try {
// Close existing connection
if (AppState.rtcConnection) {
AppState.rtcConnection.close();
}
const pc = new RTCPeerConnection({
iceServers: [
{ urls: 'stun:stun.l.google.com:19302' },
{ urls: 'stun:stun1.l.google.com:19302' }
]
});
// Add audio track
if (AppState.mediaStream) {
AppState.mediaStream.getAudioTracks().forEach(track => {
pc.addTrack(track, AppState.mediaStream);
});
}
// Connect to signaling server
const signalWs = new WebSocket(window.RENDER_SIGNALING_URL);
// Handle signaling messages
signalWs.onmessage = async (event) => {
const message = Utils.safeJsonParse(event.data);
if (!message) return;
try {
if (message.type === 'offer') {
await pc.setRemoteDescription(new RTCSessionDescription(message));
const answer = await pc.createAnswer();
await pc.setLocalDescription(answer);
signalWs.send(JSON.stringify(pc.localDescription));
} else if (message.type === 'candidate' && message.candidate) {
await pc.addIceCandidate(new RTCIceCandidate(message));
}
} catch (err) {
console.error('Error handling signaling message:', err);
}
};
// Send ICE candidates
pc.onicecandidate = (event) => {
if (event.candidate && signalWs.readyState === WebSocket.OPEN) {
signalWs.send(JSON.stringify({
type: 'candidate',
candidate: event.candidate
}));
}
};
// Handle connection state changes
pc.onconnectionstatechange = () => {
console.log('WebRTC connection state:', pc.connectionState);
if (pc.connectionState === 'failed' || pc.connectionState === 'disconnected') {
StatusManager.update('warning', 'WebRTC connection lost');
}
};
AppState.rtcConnection = pc;
// Wait for connection with timeout
await new Promise((resolve, reject) => {
const timeout = setTimeout(() => {
reject(new Error("WebRTC connection timeout (15s)"));
}, 15000);
pc.onconnectionstatechange = () => {
if (pc.connectionState === 'connected') {
clearTimeout(timeout);
resolve();
} else if (pc.connectionState === 'failed') {
clearTimeout(timeout);
reject(new Error("WebRTC connection failed"));
}
};
});
} catch (err) {
console.error('WebRTC setup error:', err);
throw new Error(`WebRTC setup failed: ${err.message}`);
}
},
setupWebSocket() {
try {
// Close existing connection
if (AppState.wsConnection) {
AppState.wsConnection.close();
}
const wsUrl = window.RENDER_SIGNALING_URL.replace('/stream', '/ws_relay');
AppState.wsConnection = new WebSocket(wsUrl);
AppState.wsConnection.onopen = () => {
console.log('WebSocket connection established');
AppState.wsReconnectAttempts = 0; // Reset on successful connection
};
AppState.wsConnection.onmessage = (event) => {
MessageHandler.process(event.data);
};
AppState.wsConnection.onerror = (error) => {
console.error('WebSocket error:', error);
StatusManager.update('warning', 'WebSocket error');
};
AppState.wsConnection.onclose = () => {
console.log('WebSocket connection closed');
// Only attempt reconnection if not cleaning up and under limit
if (!AppState.isCleaningUp &&
AppState.wsReconnectAttempts < AppState.maxReconnectAttempts) {
AppState.wsReconnectAttempts++;
const delay = Math.min(3000 * AppState.wsReconnectAttempts, 30000); // Max 30s delay
console.log(`Attempting WebSocket reconnection ${AppState.wsReconnectAttempts}/${AppState.maxReconnectAttempts} in ${delay}ms`);
setTimeout(() => {
if (!AppState.isCleaningUp) {
this.setupWebSocket();
}
}, delay);
} else if (AppState.wsReconnectAttempts >= AppState.maxReconnectAttempts) {
StatusManager.update('error', 'Max WebSocket reconnection attempts reached');
}
};
} catch (err) {
console.error('WebSocket setup error:', err);
throw new Error(`WebSocket setup failed: ${err.message}`);
}
},
stop() {
AppState.isCleaningUp = true;
this.cleanup();
StatusManager.update('disconnected');
AppState.isCleaningUp = false;
},
cleanup() {
// Close WebRTC connection
if (AppState.rtcConnection) {
AppState.rtcConnection.close();
AppState.rtcConnection = null;
}
// Close WebSocket
if (AppState.wsConnection) {
AppState.wsConnection.close();
AppState.wsConnection = null;
}
// Stop media stream
if (AppState.mediaStream) {
AppState.mediaStream.getTracks().forEach(track => track.stop());
AppState.mediaStream = null;
}
// Clear intervals
if (AppState.statusUpdateInterval) {
clearInterval(AppState.statusUpdateInterval);
AppState.statusUpdateInterval = null;
}
// Reset reconnection attempts
AppState.wsReconnectAttempts = 0;
}
};
// Message handling
const MessageHandler = {
process(data) {
try {
const message = Utils.safeJsonParse(data);
if (message) {
this.handleStructuredMessage(message);
} else {
// Fallback for plain HTML content
this.updateConversationDisplay(data);
}
this.autoScroll();
} catch (e) {
console.error('Error processing message:', e);
this.updateConversationDisplay(data);
this.autoScroll();
}
},
handleStructuredMessage(message) {
switch(message.type) {
case 'transcription':
if (message.data && message.data.conversation_html) {
this.updateConversationDisplay(message.data.conversation_html);
}
break;
case 'processing_result':
this.handleProcessingResult(message.data);
break;
case 'connection':
StatusManager.update(message.status === 'connected' ? 'connected' : 'warning');
break;
case 'connection_established':
StatusManager.update('connected');
if (message.conversation) {
this.updateConversationDisplay(message.conversation);
}
break;
case 'conversation_update':
if (message.conversation_html) {
this.updateConversationDisplay(message.conversation_html);
}
break;
case 'conversation_cleared':
this.updateConversationDisplay("<i>Conversation cleared. Start speaking again...</i>");
break;
case 'error':
console.error('Server error:', message.message);
StatusManager.update('warning', message.message);
break;
default:
console.log('Unknown message type:', message.type);
}
},
handleProcessingResult(data) {
if (!data) return;
if (data.status === "processed" && data.speaker_id !== undefined) {
const statusElem = document.getElementById('status-text');
if (statusElem) {
const speakerId = `Speaker ${data.speaker_id + 1}`;
statusElem.textContent = `Connected - ${speakerId} active`;
}
} else if (data.status === "error") {
StatusManager.update('error', data.message || 'Processing error');
}
},
updateConversationDisplay(content) {
const container = document.getElementById("conversation");
if (container) {
container.innerHTML = content;
}
},
autoScroll() {
const container = document.getElementById("conversation");
if (container) {
container.scrollTop = container.scrollHeight;
}
}
};
// Connection monitoring
const ConnectionMonitor = {
async updateConnectionInfo() {
try {
const hfConnected = await Utils.checkHfConnection(3000);
if (!hfConnected) {
StatusManager.update('warning', 'Backend unavailable');
// Try to reconnect WebSocket only if not already trying
if (!AppState.wsConnection || AppState.wsConnection.readyState !== WebSocket.OPEN) {
if (AppState.wsReconnectAttempts < AppState.maxReconnectAttempts) {
StreamController.setupWebSocket();
}
}
} else if (AppState.rtcConnection?.connectionState === 'connected' ||
AppState.rtcConnection?.iceConnectionState === 'connected') {
StatusManager.update('connected');
} else {
StatusManager.update('warning', 'Connection unstable');
}
} catch (err) {
console.error('Error updating connection info:', err);
}
}
};
// Status management
const StatusManager = {
update(status, message = '') {
const statusText = document.getElementById('status-text');
const statusIcon = document.getElementById('status-icon');
if (!statusText || !statusIcon) return;
switch(status) {
case 'connected':
statusText.textContent = message || 'Connected';
statusIcon.style.backgroundColor = '#4CAF50';
break;
case 'connecting':
statusText.textContent = 'Connecting...';
statusIcon.style.backgroundColor = '#FFC107';
break;
case 'disconnected':
statusText.textContent = 'Disconnected';
statusIcon.style.backgroundColor = '#9E9E9E';
break;
case 'error':
statusText.textContent = `Error: ${message}`;
statusIcon.style.backgroundColor = '#F44336';
break;
case 'warning':
statusText.textContent = `Warning: ${message}`;
statusIcon.style.backgroundColor = '#FF9800';
break;
default:
statusText.textContent = 'Unknown';
statusIcon.style.backgroundColor = '#9E9E9E';
}
}
};
// API functions
const ApiManager = {
async clearConversation() {
// Update UI immediately
document.getElementById("conversation").innerHTML =
"<i>Conversation cleared. Start speaking again...</i>";
// Try backend API
try {
const isConnected = await Utils.checkHfConnection();
if (isConnected) {
const response = await fetch(`${window.HF_SPACE_URL}/clear`, {
method: 'POST',
headers: {
'Content-Type': 'application/json'
}
});
if (!response.ok) {
throw new Error(`HTTP ${response.status}`);
}
console.log("Backend conversation cleared successfully");
}
} catch (err) {
console.warn("Backend clear API failed:", err);
}
},
async updateSettings() {
const threshold = document.querySelector('input[data-testid="threshold-slider"]')?.value || 0.7;
const maxSpeakers = document.querySelector('input[data-testid="speakers-slider"]')?.value || 4;
// Update UI immediately
const statusOutput = document.getElementById('status-output');
if (statusOutput) {
statusOutput.innerHTML = `
<h2>System Status</h2>
<p>Settings updated:</p>
<ul>
<li>Threshold: ${threshold}</li>
<li>Max Speakers: ${maxSpeakers}</li>
</ul>
<p>Transcription Models:</p>
<ul>
<li>Final: ${window.FINAL_TRANSCRIPTION_MODEL || "distil-large-v3"}</li>
<li>Realtime: ${window.REALTIME_TRANSCRIPTION_MODEL || "distil-small.en"}</li>
</ul>
`;
}
// Try backend API
try {
const isConnected = await Utils.checkHfConnection();
if (isConnected) {
const response = await fetch(
`${window.HF_SPACE_URL}/settings?threshold=${threshold}&max_speakers=${maxSpeakers}`,
{
method: 'POST',
headers: {
'Content-Type': 'application/json'
}
}
);
if (!response.ok) {
throw new Error(`HTTP ${response.status}`);
}
console.log("Backend settings updated successfully");
}
} catch (err) {
console.warn("Backend settings update failed:", err);
}
}
};
// DOM initialization
const DOMManager = {
init() {
StatusManager.update('disconnected');
// Use MutationObserver for reliable button detection
const observer = new MutationObserver(Utils.debounce(() => {
if (this.bindButtons()) {
observer.disconnect();
}
}, 100));
observer.observe(document.body, {
childList: true,
subtree: true
});
// Fallback: try binding immediately
this.bindButtons();
// Cleanup on page unload
window.addEventListener('beforeunload', () => {
StreamController.cleanup();
});
},
bindButtons() {
const buttons = {
start: document.getElementById('btn-start') ||
document.querySelector('button[aria-label="Start Listening"]'),
stop: document.getElementById('btn-stop') ||
document.querySelector('button[aria-label="Stop"]'),
clear: document.getElementById('btn-clear') ||
document.querySelector('button[aria-label="Clear"]'),
update: document.getElementById('btn-update') ||
document.querySelector('button[aria-label="Update Settings"]')
};
const allFound = Object.values(buttons).every(btn => btn !== null);
if (allFound) {
// Remove existing listeners to prevent duplicates
Object.values(buttons).forEach(btn => {
if (btn.dataset.bound !== 'true') {
btn.onclick = null;
}
});
// Bind new listeners
buttons.start.onclick = () => StreamController.start();
buttons.stop.onclick = () => StreamController.stop();
buttons.clear.onclick = () => ApiManager.clearConversation();
buttons.update.onclick = () => ApiManager.updateSettings();
// Mark as bound
Object.values(buttons).forEach(btn => {
btn.dataset.bound = 'true';
});
console.log("All buttons bound successfully");
return true;
}
return false;
}
};
// Initialize when DOM is ready
if (document.readyState === 'loading') {
document.addEventListener('DOMContentLoaded', () => DOMManager.init());
} else {
DOMManager.init();
}
</script>
""",
                    label="Live Conversation"
                )
                # Control buttons with elem_id for reliable selection by the JS
                # DOMManager.bindButtons (falls back to aria-label selectors).
                with gr.Row():
                    start_btn = gr.Button("▢️ Start Listening", variant="primary", size="lg", elem_id="btn-start")
                    stop_btn = gr.Button("⏹️ Stop", variant="stop", size="lg", elem_id="btn-stop")
                    clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary", size="lg", elem_id="btn-clear")
                # Status display with elem_id for reliable selection; overwritten
                # both by the JS ApiManager.updateSettings and by the gr.Timer below.
                status_output = gr.Markdown(
                    """
## System Status
Waiting to connect...
*Click Start Listening to begin*
""",
                    label="Status Information",
                    elem_id="status-output"
                )
            with gr.Column(scale=1):
                # Settings
                gr.Markdown("## βš™οΈ Settings")
                # NOTE(review): the embedded JS reads these sliders via
                # `input[data-testid="threshold-slider"]`, but elem_id sets an
                # element *id*, not a data-testid attribute — verify the selector
                # actually matches the rendered DOM; otherwise JS falls back to
                # its hard-coded defaults (0.7 / 4).
                threshold_slider = gr.Slider(
                    minimum=0.3,
                    maximum=0.9,
                    step=0.05,
                    value=DEFAULT_CHANGE_THRESHOLD,
                    label="Speaker Change Sensitivity",
                    info="Lower = more sensitive (more speaker changes)",
                    elem_id="threshold-slider"
                )
                max_speakers_slider = gr.Slider(
                    minimum=2,
                    maximum=ABSOLUTE_MAX_SPEAKERS,
                    step=1,
                    value=DEFAULT_MAX_SPEAKERS,
                    label="Maximum Speakers",
                    elem_id="speakers-slider"
                )
                update_btn = gr.Button("Update Settings", variant="secondary", elem_id="btn-update")
                # Instructions
                gr.Markdown("""
## πŸ“‹ Instructions
1. **Start Listening** - allows browser to access microphone
2. **Speak** - system will transcribe and identify speakers
3. **Stop** when finished
4. **Clear** to reset conversation
## 🎨 Speaker Colors
- πŸ”΄ Speaker 1 (Red)
- 🟒 Speaker 2 (Teal)
- πŸ”΅ Speaker 3 (Blue)
- 🟑 Speaker 4 (Green)
- ⭐ Speaker 5 (Yellow)
- 🟣 Speaker 6 (Plum)
- 🟀 Speaker 7 (Mint)
- 🟠 Speaker 8 (Gold)
""")

        # Set up periodic status updates with proper error handling.
        def get_status():
            """Poll the backend /status endpoint; return a short status string.

            Called every 5 seconds by the gr.Timer below; network failures are
            reported as text rather than raised, so the timer keeps ticking.
            """
            try:
                resp = requests.get(f"{HF_SPACE_URL}/status", timeout=5)
                if resp.status_code == 200:
                    data = resp.json()
                    return data.get('status', 'No status information')
                return f"HTTP {resp.status_code}"
            except requests.exceptions.Timeout:
                return "Connection timeout"
            except requests.exceptions.ConnectionError:
                return "Connection error - backend unavailable"
            except Exception as e:
                return f"Error: {str(e)}"

        # Create timer and register it so cleanup_resources() can stop it.
        status_timer = gr.Timer(5)
        status_timer.tick(fn=get_status, outputs=status_output)
        resource_manager.add_timer(status_timer)
        return demo
# Create the Gradio interface once at import time so mount_ui()/launch() can use it.
demo = build_ui()
def mount_ui(app: FastAPI):
    """Mount the Gradio UI onto an existing FastAPI application at /ui.

    Uses ``gr.mount_gradio_app`` instead of ``app.mount("/ui", demo.app)``:
    ``Blocks.app`` is only populated after ``launch()`` is called, so mounting
    it directly raises AttributeError when the UI has not been launched
    standalone. ``mount_gradio_app`` is Gradio's documented way to serve a
    Blocks instance inside a host FastAPI app.

    Args:
        app: the host FastAPI application to mount the UI onto.
    """
    gr.mount_gradio_app(app, demo, path="/ui")
def cleanup_resources():
    """Cleanup function to be called on app shutdown.

    Delegates to the module-level ResourceManager, which stops registered
    timers and runs registered cleanup callbacks (best-effort).
    """
    resource_manager.cleanup()
# For standalone testing: run the Gradio server directly instead of mounting
# it into a host FastAPI app via mount_ui().
if __name__ == "__main__":
    try:
        demo.launch(
            share=False,
            debug=False,
            show_error=True,
            server_name="0.0.0.0",  # listen on all interfaces (container/Space deployment)
            server_port=7860
        )
    except KeyboardInterrupt:
        print("\nShutting down...")
    except Exception as e:
        print(f"Error launching demo: {e}")
    finally:
        # Always release timers/callbacks — the original only cleaned up on the
        # two error paths, leaking resources on a normal shutdown.
        cleanup_resources()