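# Real-time speaker diarization UI (Gradio front end).
# NOTE: this listing was captured from a Hugging Face Spaces page; the Space
# status shown at capture time was "Sleeping".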
import gradio as gr | |
import json | |
import time | |
import os | |
import asyncio | |
import websockets | |
import logging | |
from fastrtc import RTCComponent | |
import threading | |
# Logging: INFO level on the root logger, plus a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Service endpoints. Each can be overridden through an environment variable of
# the same name; the literals below are the fallbacks.
RENDER_SIGNALING_URL = os.environ.get(
    "RENDER_SIGNALING_URL",
    "wss://render-signal-audio.onrender.com/stream",
)
HF_SPACE_URL = os.environ.get(
    "HF_SPACE_URL",
    "androidguy-speaker-diarization.hf.space",
)
# HF_SPACE_URL is a bare host name (no scheme); the WebSocket URL is built from it.
WS_TRANSCRIPTION_URL = "wss://" + HF_SPACE_URL + "/ws_transcription"
class TranscriptionClient:
    """Client to handle WebSocket connection to transcription service.

    The connection runs on its own event loop inside a daemon thread started
    by :meth:`connect`. Incoming messages are passed to ``on_message``;
    failures go to ``on_error`` and a finished receive loop triggers
    ``on_close``.

    Args:
        url: WebSocket URL to connect to.
        on_message: Callable invoked with each received message.
        on_error: Optional callable invoked with the exception on failure.
            Defaults to logging the error.
        on_close: Optional zero-argument callable invoked after the receive
            loop ends. Defaults to logging an info message.
    """

    def __init__(self, url, on_message, on_error=None, on_close=None):
        self.url = url
        self.on_message = on_message
        self.on_error = on_error or (lambda e: logger.error(f"WebSocket error: {e}"))
        self.on_close = on_close or (lambda: logger.info("WebSocket closed"))
        self.ws = None
        self.running = False
        self.connected = False
        self.reconnect_task = None
        self.thread = None
        # Event loop owned by the worker thread; needed so disconnect() can
        # close the socket on the loop that created it.
        self.loop = None

    async def connect_async(self):
        """Connect to WebSocket server asynchronously and pump messages.

        Runs until ``self.running`` becomes false, the server closes the
        connection, or an error occurs. On every exit path the socket is
        closed and ``running``/``connected`` are reset so that
        :meth:`connect` can be called again.
        """
        ws = None
        try:
            ws = await websockets.connect(self.url)
            self.ws = ws
            self.connected = True
            logger.info(f"Connected to {self.url}")
            # Use the local ``ws``: disconnect() may set self.ws to None from
            # another thread while we are waiting in recv().
            while self.running:
                try:
                    message = await ws.recv()
                    self.on_message(message)
                except websockets.exceptions.ConnectionClosed:
                    logger.warning("Connection closed")
                    self.connected = False
                    break
                except Exception as e:
                    self.on_error(e)
                    break
            # Handle connection closed
            self.connected = False
            self.on_close()
        except Exception as e:
            logger.error(f"Connection error: {e}")
            self.connected = False
            self.on_error(e)
        finally:
            # Reset state so a later connect() is not silently ignored.
            self.running = False
            self.connected = False
            if ws is not None:
                try:
                    await ws.close()
                except Exception as e:
                    logger.debug(f"Error while closing WebSocket: {e}")
                if self.ws is ws:
                    self.ws = None

    def connect(self):
        """Start connection in a separate daemon thread.

        Does nothing if a connection attempt is already running.
        """
        if self.running:
            return
        self.running = True

        def run_async_loop():
            loop = asyncio.new_event_loop()
            self.loop = loop
            asyncio.set_event_loop(loop)
            try:
                loop.run_until_complete(self.connect_async())
            finally:
                # Always release the loop, even if connect_async raised.
                self.loop = None
                loop.close()

        self.thread = threading.Thread(target=run_async_loop, daemon=True)
        self.thread.start()

    def disconnect(self):
        """Disconnect from WebSocket server.

        The close is scheduled on the worker thread's event loop (the loop
        that owns the socket) instead of a fresh loop. When called from a
        different thread, waits up to five seconds for the close to finish.
        """
        self.running = False
        ws, self.ws = self.ws, None
        loop = self.loop
        if ws is not None and loop is not None and loop.is_running():
            close_coro = ws.close()
            try:
                future = asyncio.run_coroutine_threadsafe(close_coro, loop)
            except RuntimeError:
                # Loop was closed between the check and the call; the worker
                # closes the socket itself on exit.
                close_coro.close()
            else:
                # Waiting from inside the loop thread would deadlock.
                if threading.current_thread() is not self.thread:
                    try:
                        future.result(timeout=5)
                    except Exception as e:
                        logger.warning(f"Error closing WebSocket: {e}")
        self.connected = False
class SpeakerDiarizationUI:
    """Main UI for speaker diarization.

    Holds the state shown in the Gradio page (conversation HTML, system
    status, WebRTC state) and the handlers wired to its buttons. Handlers
    return their component updates positionally, in the same order as the
    ``outputs`` lists they are registered with.
    """

    def __init__(self):
        self.transcription_client = None
        self.conversation_html = ""
        self.system_status = {"status": "disconnected"}
        self.webrtc_state = "stopped"

    def handle_transcription_message(self, message):
        """Handle incoming transcription messages.

        Args:
            message: JSON text with a ``type`` field of
                ``conversation_update``, ``connection`` or ``error``.
                Other types are ignored; invalid JSON is logged.
        """
        try:
            data = json.loads(message)
            message_type = data.get("type", "unknown")
            if message_type == "conversation_update":
                # Update the conversation display
                self.conversation_html = data.get("conversation_html", "")
                # Update system status if available
                if "status" in data:
                    status = data["status"]
                    if isinstance(status, dict):
                        self.system_status = status
                    else:
                        # Keep system_status a dict: other methods call
                        # .get() and assign keys on it.
                        self.system_status["status"] = status
            elif message_type == "connection":
                # Handle connection status update
                status = data.get("status", "unknown")
                logger.info(f"Connection status: {status}")
                if "hf_space_status" in data:
                    self.system_status["hf_space_status"] = data["hf_space_status"]
            elif message_type == "error":
                # Handle error message
                error_msg = data.get("message", "Unknown error")
                logger.error(f"Error from server: {error_msg}")
        except json.JSONDecodeError:
            logger.warning(f"Received invalid JSON: {message}")
        except Exception as e:
            logger.error(f"Error handling message: {e}")

    def handle_transcription_error(self, error):
        """Handle WebSocket errors by logging them."""
        logger.error(f"WebSocket error: {error}")

    def handle_transcription_close(self):
        """Handle WebSocket connection closure."""
        logger.info("WebSocket connection closed")
        self.system_status["status"] = "disconnected"

    def connect_to_transcription(self):
        """Connect to transcription WebSocket unless already connected."""
        if self.transcription_client and self.transcription_client.connected:
            return
        self.transcription_client = TranscriptionClient(
            url=WS_TRANSCRIPTION_URL,
            on_message=self.handle_transcription_message,
            on_error=self.handle_transcription_error,
            on_close=self.handle_transcription_close,
        )
        self.transcription_client.connect()

    def disconnect_from_transcription(self):
        """Disconnect from transcription WebSocket and drop the client."""
        if self.transcription_client:
            self.transcription_client.disconnect()
            self.transcription_client = None

    def start_listening(self):
        """Start listening to audio and connect to services.

        Returns:
            Updates for (webrtc, status_display, start_button, stop_button,
            clear_button), in that order.
        """
        self.connect_to_transcription()
        self.webrtc_state = "started"
        # Positional return: the component objects are local to
        # create_interface() and cannot be used as dict keys here.
        return (
            gr.update(streaming=True),
            gr.update(value="Status: Connected and listening"),
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=True),
        )

    def stop_listening(self):
        """Stop listening to audio and disconnect from services.

        Returns:
            Updates for (webrtc, status_display, start_button, stop_button,
            clear_button), in that order.
        """
        self.disconnect_from_transcription()
        self.webrtc_state = "stopped"
        return (
            gr.update(streaming=False),
            gr.update(value="Status: Disconnected"),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=True),
        )

    def clear_conversation(self):
        """Clear the conversation on the backend and in the local display.

        Backend failures are logged; the local display is cleared regardless.

        Returns:
            Update for conversation_display.
        """
        # Call API to clear conversation
        import requests

        try:
            # Timeout so an unresponsive backend cannot hang the handler.
            response = requests.post(f"https://{HF_SPACE_URL}/clear", timeout=10)
            if response.status_code == 200:
                logger.info("Conversation cleared")
            else:
                logger.error(f"Failed to clear conversation: {response.status_code}")
        except Exception as e:
            logger.error(f"Error clearing conversation: {e}")
        # Clear local display
        self.conversation_html = ""
        return gr.update(
            value="<div class='conversation-container'><p>Conversation cleared</p></div>"
        )

    def update_display(self):
        """Update conversation display - called periodically.

        Returns:
            Updates for (conversation_display, status_display), in that order.
        """
        status_text = "Status: "
        if self.webrtc_state == "started":
            status_text += "Connected and listening"
        else:
            status_text += "Disconnected"
        if self.system_status.get("hf_space_status"):
            status_text += f" | HF Space: {self.system_status['hf_space_status']}"
        conversation = (
            self.conversation_html
            if self.conversation_html
            else "<div class='conversation-container'><p>No conversation yet. Start speaking to begin transcription.</p></div>"
        )
        return (
            gr.update(value=conversation),
            gr.update(value=status_text),
        )
# Create the single module-level UI state object; its bound methods are
# registered as event handlers in create_interface().
ui = SpeakerDiarizationUI()
# Custom CSS passed to gr.Blocks(css=...). The class names match the
# 'conversation-container' wrapper used in the HTML placeholders and the
# "status-display" elem_classes set on the status Markdown components.
css = """
.conversation-container {
border: 1px solid #ddd;
border-radius: 10px;
padding: 15px;
margin-bottom: 10px;
max-height: 500px;
overflow-y: auto;
background-color: white;
}
.speaker {
margin-bottom: 12px;
border-radius: 8px;
padding: 8px 12px;
}
.speaker-label {
font-weight: bold;
margin-bottom: 5px;
}
.status-display {
margin-top: 10px;
padding: 5px 10px;
background-color: #f0f0f0;
border-radius: 5px;
font-size: 0.9rem;
}
"""
# Create Gradio interface as a function to avoid clashing
def create_interface():
    """Build the Gradio Blocks page and wire its event handlers.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched or mounted).
    """
    with gr.Blocks(css=css) as interface:
        gr.Markdown("# Real-Time Speaker Diarization")
        gr.Markdown("This app performs real-time speaker diarization on your audio. It automatically transcribes speech and identifies different speakers.")
        with gr.Row():
            with gr.Column(scale=2):
                conversation_display = gr.HTML("<div class='conversation-container'><p>No conversation yet. Start speaking to begin transcription.</p></div>")
            with gr.Column(scale=1):
                status_display = gr.Markdown("Status: Disconnected", elem_classes=["status-display"])
                webrtc = RTCComponent(url=RENDER_SIGNALING_URL, streaming=False, modality="audio", mode="send-receive")
                with gr.Row():
                    start_button = gr.Button("Start Listening", variant="primary")
                    stop_button = gr.Button("Stop Listening", variant="secondary", visible=False)
                    clear_button = gr.Button("Clear Conversation", visible=True)
                with gr.Accordion("Advanced Settings", open=False):
                    speaker_threshold = gr.Slider(0.5, 0.9, value=0.65, label="Speaker Change Threshold")
                    max_speakers = gr.Slider(2, 8, value=4, step=1, label="Maximum Number of Speakers")

                    def update_settings(threshold, speakers):
                        """POST the slider values to the backend; return a status update."""
                        import requests

                        try:
                            # Timeout so a stalled backend cannot hang the handler.
                            response = requests.post(
                                f"https://{HF_SPACE_URL}/settings",
                                params={"threshold": threshold, "max_speakers": speakers},
                                timeout=10,
                            )
                            if response.status_code == 200:
                                return gr.update(value=f"Settings updated: Threshold={threshold}, Max Speakers={speakers}")
                            else:
                                return gr.update(value=f"Failed to update settings: {response.status_code}")
                        except Exception as e:
                            return gr.update(value=f"Error updating settings: {e}")

                    settings_button = gr.Button("Update Settings")
                    settings_status = gr.Markdown("", elem_classes=["status-display"])
                    settings_button.click(
                        update_settings,
                        [speaker_threshold, max_speakers],
                        [settings_status],
                    )
        # Set up event handlers
        start_button.click(
            ui.start_listening,
            [],
            [webrtc, status_display, start_button, stop_button, clear_button],
        )
        stop_button.click(
            ui.stop_listening,
            [],
            [webrtc, status_display, start_button, stop_button, clear_button],
        )
        clear_button.click(
            ui.clear_conversation,
            [],
            [conversation_display],
        )
        # Periodic update every 0.5 seconds
        interface.load(
            ui.update_display,
            [],
            [conversation_display, status_display],
            every=0.5,
        )
    return interface
# Global interface instance. Stays None until the page is built, either by
# the script entry point below or by a caller that assigns it.
interface = None
# Launch the app when this file is executed directly (not when imported).
if __name__ == "__main__":
    interface = create_interface()
    interface.launch()
# Add mount_ui function for integration with FastAPI
def mount_ui(app):
    """Mount the Gradio interface at /ui path of the FastAPI app.

    Args:
        app: The FastAPI application to mount the UI into.

    Returns:
        The application with the Gradio UI mounted at ``/ui``.
    """
    global interface
    # Create interface if it doesn't exist yet
    if interface is None:
        interface = create_interface()
    # Mount Gradio app at /ui path. Blocks has no mount_in_app() method;
    # gr.mount_gradio_app() is the documented way to attach it to FastAPI.
    return gr.mount_gradio_app(app, interface, path="/ui")