"""Gradio front end for real-time speaker diarization.

The module wires three pieces together:

* ``TranscriptionClient`` - a background-thread WebSocket client that receives
  JSON messages from the transcription service.
* ``SpeakerDiarizationUI`` - UI state plus the Gradio event handlers.
* ``create_interface`` / ``mount_ui`` - build the Blocks app and optionally
  mount it into a FastAPI application.
"""

import asyncio
import json
import logging
import os
import threading
import time

import gradio as gr
import websockets
from fastrtc import RTCComponent

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

RENDER_SIGNALING_URL = os.getenv(
    "RENDER_SIGNALING_URL", "wss://render-signal-audio.onrender.com/stream"
)
HF_SPACE_URL = os.getenv("HF_SPACE_URL", "androidguy-speaker-diarization.hf.space")
WS_TRANSCRIPTION_URL = f"wss://{HF_SPACE_URL}/ws_transcription"

# Upper bound for blocking HTTP calls made from UI handlers, so a stalled
# server cannot hang a Gradio worker forever.
HTTP_TIMEOUT_SECONDS = 10
# How long disconnect() waits for the WebSocket close handshake.
CLOSE_TIMEOUT_SECONDS = 5

# Placeholder content for the conversation panel.
# NOTE(review): these contain only plain text surrounded by blank lines; if
# they were meant to carry HTML wrappers (e.g. the .conversation-container
# class styled below), restore the markup here.
NO_CONVERSATION_HTML = (
    "\n\nNo conversation yet. Start speaking to begin transcription.\n\n"
)
CONVERSATION_CLEARED_HTML = "\n\nConversation cleared\n\n"


class TranscriptionClient:
    """Client to handle WebSocket connection to transcription service.

    The connection runs on its own daemon thread with a private asyncio event
    loop. Incoming messages are delivered to ``on_message`` on that thread.

    Args:
        url: WebSocket URL to connect to.
        on_message: Callable invoked with every received message.
        on_error: Optional callable invoked with the exception on failure.
            Defaults to logging the error.
        on_close: Optional zero-argument callable invoked when the receive
            loop ends. Defaults to logging an info message.
    """

    def __init__(self, url, on_message, on_error=None, on_close=None):
        self.url = url
        self.on_message = on_message
        self.on_error = on_error or (lambda e: logger.error(f"WebSocket error: {e}"))
        self.on_close = on_close or (lambda: logger.info("WebSocket closed"))
        self.ws = None
        self.running = False
        self.connected = False
        self.reconnect_task = None
        self.thread = None
        # Event loop owned by the worker thread; needed so disconnect() can
        # close the socket on the loop it actually belongs to.
        self._loop = None

    async def connect_async(self):
        """Connect to WebSocket server asynchronously and pump messages.

        Runs until ``running`` becomes false, the peer closes the connection,
        or an error occurs. The socket is always closed on exit.
        """
        ws = None
        try:
            ws = await websockets.connect(self.url)
            self.ws = ws
            self.connected = True
            logger.info("Connected to %s", self.url)

            # Start listening for messages
            while self.running:
                try:
                    message = await ws.recv()
                    self.on_message(message)
                except websockets.exceptions.ConnectionClosed:
                    logger.warning("Connection closed")
                    break
                except Exception as e:
                    self.on_error(e)
                    break

            # Handle connection closed
            self.connected = False
            self.on_close()
        except Exception as e:
            logger.error("Connection error: %s", e)
            self.connected = False
            self.on_error(e)
        finally:
            self.connected = False
            if ws is not None:
                # Covers error exits and the case where disconnect() ran
                # before the handshake finished and therefore saw no socket.
                try:
                    await ws.close()
                except Exception as e:
                    logger.debug("Ignoring error while closing WebSocket: %s", e)
                if self.ws is ws:
                    self.ws = None

    def connect(self):
        """Start connection in a separate thread.

        Does nothing if a connection attempt or session is already active.
        """
        if self.running:
            return

        self.running = True

        def run_async_loop():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            self._loop = loop
            try:
                loop.run_until_complete(self.connect_async())
            finally:
                # Reset state so connect() can be called again later.
                self._loop = None
                self.running = False
                loop.close()

        self.thread = threading.Thread(target=run_async_loop, daemon=True)
        self.thread.start()

    def disconnect(self):
        """Disconnect from WebSocket server.

        Safe to call from any thread. The close is scheduled on the worker
        thread's event loop because the socket is bound to that loop; running
        it on a fresh loop (``asyncio.run``) would not be valid.
        """
        self.running = False
        ws, loop = self.ws, self._loop
        self.ws = None
        self.connected = False

        if ws is None or loop is None or not loop.is_running():
            return

        close_coro = ws.close()
        try:
            future = asyncio.run_coroutine_threadsafe(close_coro, loop)
        except RuntimeError:
            # The loop was closed between the check above and scheduling.
            close_coro.close()
            return

        if threading.current_thread() is self.thread:
            # Called from a callback on the worker thread: waiting here would
            # block the very loop that has to run the close.
            return

        try:
            future.result(timeout=CLOSE_TIMEOUT_SECONDS)
        except Exception as e:
            logger.warning("Error while closing WebSocket: %s", e)


class SpeakerDiarizationUI:
    """Main UI for speaker diarization.

    Holds the state shown in the interface and provides the Gradio event
    handlers. Handlers return component updates positionally, in the same
    order as the ``outputs`` lists registered in ``create_interface``.
    """

    def __init__(self):
        self.transcription_client = None
        self.conversation_html = ""
        self.system_status = {"status": "disconnected"}
        self.webrtc_state = "stopped"

    def handle_transcription_message(self, message):
        """Handle incoming transcription messages.

        Args:
            message: Raw JSON text received from the transcription WebSocket.
        """
        try:
            data = json.loads(message)
            message_type = data.get("type", "unknown")

            if message_type == "conversation_update":
                # Update the conversation display
                self.conversation_html = data.get("conversation_html", "")
                # Update system status if available
                if "status" in data:
                    status = data["status"]
                    if isinstance(status, dict):
                        self.system_status = status
                    else:
                        # Keep system_status a dict; other code indexes it
                        # by key and calls .get() on it.
                        self.system_status["status"] = status

            elif message_type == "connection":
                # Handle connection status update
                status = data.get("status", "unknown")
                logger.info("Connection status: %s", status)
                if "hf_space_status" in data:
                    self.system_status["hf_space_status"] = data["hf_space_status"]

            elif message_type == "error":
                # Handle error message
                error_msg = data.get("message", "Unknown error")
                logger.error("Error from server: %s", error_msg)

        except json.JSONDecodeError:
            logger.warning("Received invalid JSON: %s", message)
        except Exception as e:
            logger.error("Error handling message: %s", e)

    def handle_transcription_error(self, error):
        """Handle WebSocket errors."""
        logger.error("WebSocket error: %s", error)

    def handle_transcription_close(self):
        """Handle WebSocket connection closure."""
        logger.info("WebSocket connection closed")
        self.system_status["status"] = "disconnected"

    def connect_to_transcription(self):
        """Connect to transcription WebSocket.

        Reuses the current client while it is connected or still connecting,
        so repeated calls do not spawn parallel connections.
        """
        client = self.transcription_client
        if client is not None and (client.connected or client.running):
            return

        self.transcription_client = TranscriptionClient(
            url=WS_TRANSCRIPTION_URL,
            on_message=self.handle_transcription_message,
            on_error=self.handle_transcription_error,
            on_close=self.handle_transcription_close,
        )
        self.transcription_client.connect()

    def disconnect_from_transcription(self):
        """Disconnect from transcription WebSocket."""
        if self.transcription_client:
            self.transcription_client.disconnect()
            self.transcription_client = None

    def _listening_updates(self, listening, status_text):
        """Build updates for (webrtc, status, start, stop, clear) in that order."""
        return (
            gr.update(streaming=listening),
            gr.update(value=status_text),
            gr.update(visible=not listening),
            gr.update(visible=listening),
            gr.update(visible=True),
        )

    def start_listening(self):
        """Start listening to audio and connect to services.

        Returns:
            Updates for webrtc, status_display, start_button, stop_button and
            clear_button, in that order.
        """
        self.connect_to_transcription()
        self.webrtc_state = "started"
        return self._listening_updates(True, "Status: Connected and listening")

    def stop_listening(self):
        """Stop listening to audio and disconnect from services.

        Returns:
            Updates for webrtc, status_display, start_button, stop_button and
            clear_button, in that order.
        """
        self.disconnect_from_transcription()
        self.webrtc_state = "stopped"
        return self._listening_updates(False, "Status: Disconnected")

    def clear_conversation(self):
        """Clear the conversation on the server and in the local display.

        Server-side failures are logged; the local display is cleared either
        way.

        Returns:
            Update for conversation_display.
        """
        # Call API to clear conversation
        import requests

        try:
            response = requests.post(
                f"https://{HF_SPACE_URL}/clear", timeout=HTTP_TIMEOUT_SECONDS
            )
            if response.status_code == 200:
                logger.info("Conversation cleared")
            else:
                logger.error("Failed to clear conversation: %s", response.status_code)
        except Exception as e:
            logger.error("Error clearing conversation: %s", e)

        # Clear local display
        self.conversation_html = ""
        return gr.update(value=CONVERSATION_CLEARED_HTML)

    def update_display(self):
        """Update conversation display - called periodically.

        Returns:
            Updates for conversation_display and status_display, in that order.
        """
        if self.webrtc_state == "started":
            status_text = "Status: Connected and listening"
        else:
            status_text = "Status: Disconnected"

        hf_space_status = self.system_status.get("hf_space_status")
        if hf_space_status:
            status_text += f" | HF Space: {hf_space_status}"

        return (
            gr.update(value=self.conversation_html or NO_CONVERSATION_HTML),
            gr.update(value=status_text),
        )


# Create UI instance
ui = SpeakerDiarizationUI()

# Custom CSS for better styling
css = """
.conversation-container {
    border: 1px solid #ddd;
    border-radius: 10px;
    padding: 15px;
    margin-bottom: 10px;
    max-height: 500px;
    overflow-y: auto;
    background-color: white;
}
.speaker {
    margin-bottom: 12px;
    border-radius: 8px;
    padding: 8px 12px;
}
.speaker-label {
    font-weight: bold;
    margin-bottom: 5px;
}
.status-display {
    margin-top: 10px;
    padding: 5px 10px;
    background-color: #f0f0f0;
    border-radius: 5px;
    font-size: 0.9rem;
}
"""


# Create Gradio interface as a function to avoid clashing
def create_interface():
    """Build the Gradio Blocks app and register all event handlers.

    Returns:
        The constructed ``gr.Blocks`` interface (not yet launched).
    """
    with gr.Blocks(css=css) as interface:
        gr.Markdown("# Real-Time Speaker Diarization")
        gr.Markdown(
            "This app performs real-time speaker diarization on your audio. "
            "It automatically transcribes speech and identifies different speakers."
        )

        with gr.Row():
            with gr.Column(scale=2):
                conversation_display = gr.HTML(NO_CONVERSATION_HTML)

            with gr.Column(scale=1):
                status_display = gr.Markdown(
                    "Status: Disconnected", elem_classes=["status-display"]
                )
                webrtc = RTCComponent(
                    url=RENDER_SIGNALING_URL,
                    streaming=False,
                    modality="audio",
                    mode="send-receive",
                )

                with gr.Row():
                    start_button = gr.Button("Start Listening", variant="primary")
                    stop_button = gr.Button(
                        "Stop Listening", variant="secondary", visible=False
                    )
                    clear_button = gr.Button("Clear Conversation", visible=True)

                with gr.Accordion("Advanced Settings", open=False):
                    speaker_threshold = gr.Slider(
                        0.5, 0.9, value=0.65, label="Speaker Change Threshold"
                    )
                    max_speakers = gr.Slider(
                        2, 8, value=4, step=1, label="Maximum Number of Speakers"
                    )

                    def update_settings(threshold, speakers):
                        """Push the slider values to the server; report the outcome."""
                        import requests

                        try:
                            response = requests.post(
                                f"https://{HF_SPACE_URL}/settings",
                                params={
                                    "threshold": threshold,
                                    "max_speakers": speakers,
                                },
                                timeout=HTTP_TIMEOUT_SECONDS,
                            )
                            if response.status_code == 200:
                                return gr.update(
                                    value=f"Settings updated: Threshold={threshold}, Max Speakers={speakers}"
                                )
                            return gr.update(
                                value=f"Failed to update settings: {response.status_code}"
                            )
                        except Exception as e:
                            return gr.update(value=f"Error updating settings: {e}")

                    settings_button = gr.Button("Update Settings")
                    settings_status = gr.Markdown("", elem_classes=["status-display"])

                    settings_button.click(
                        update_settings,
                        [speaker_threshold, max_speakers],
                        [settings_status],
                    )

        # Set up event handlers. The handlers return updates positionally,
        # so the order of each outputs list below is significant.
        start_button.click(
            ui.start_listening,
            [],
            [webrtc, status_display, start_button, stop_button, clear_button],
        )

        stop_button.click(
            ui.stop_listening,
            [],
            [webrtc, status_display, start_button, stop_button, clear_button],
        )

        clear_button.click(
            ui.clear_conversation,
            [],
            [conversation_display],
        )

        # Periodic update every 0.5 seconds
        interface.load(
            ui.update_display,
            [],
            [conversation_display, status_display],
            every=0.5,
        )

    return interface


# Global interface instance
interface = None


# Add mount_ui function for integration with FastAPI
def mount_ui(app):
    """Mount the Gradio interface at /ui path of the FastAPI app.

    Args:
        app: The FastAPI application to mount the UI into.

    Returns:
        The application with the Gradio interface mounted.
    """
    global interface

    # Create interface if it doesn't exist yet
    if interface is None:
        interface = create_interface()

    # Mount Gradio app at /ui path
    return gr.mount_gradio_app(app, interface, path="/ui")


# Launch the app
if __name__ == "__main__":
    interface = create_interface()
    interface.launch()