# Scraped file-view page header (not part of the program):
# Saiyaswanth007's picture
# Experiment 4
# ed08f62
# raw
# history blame
# 12.5 kB
import gradio as gr
import json
import time
import os
import asyncio
import websockets
import logging
from fastrtc import RTCComponent
import threading
# Logging: INFO level on the root logger, plus a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Service endpoints. Both can be overridden through environment variables;
# the literals below are the fallbacks used when the variables are unset.
RENDER_SIGNALING_URL = os.environ.get(
    "RENDER_SIGNALING_URL",
    "wss://render-signal-audio.onrender.com/stream",
)
HF_SPACE_URL = os.environ.get(
    "HF_SPACE_URL",
    "androidguy-speaker-diarization.hf.space",
)
# HF_SPACE_URL is a bare host name; the transcription socket lives under it.
WS_TRANSCRIPTION_URL = "wss://" + HF_SPACE_URL + "/ws_transcription"
class TranscriptionClient:
    """Client to handle WebSocket connection to transcription service.

    ``connect()`` starts a daemon thread that creates its own asyncio event
    loop and runs ``connect_async()`` on it. All callbacks are therefore
    invoked on that worker thread, not on the caller's thread.

    Args:
        url: WebSocket URL to connect to.
        on_message: Callable invoked with every message received.
        on_error: Optional callable invoked with the exception when the
            connection or a message handler fails. Defaults to logging it.
        on_close: Optional zero-argument callable invoked once the receive
            loop has ended. Defaults to logging.
    """

    # Upper bound on how long disconnect() waits for the worker thread to
    # finish after asking it to close the socket.
    _CLOSE_TIMEOUT = 5.0

    def __init__(self, url, on_message, on_error=None, on_close=None):
        self.url = url
        self.on_message = on_message
        self.on_error = on_error or (lambda e: logger.error(f"WebSocket error: {e}"))
        self.on_close = on_close or (lambda: logger.info("WebSocket closed"))
        self.ws = None
        self.running = False
        self.connected = False
        self.reconnect_task = None
        self.thread = None
        # Event loop owned by the worker thread; needed so that disconnect()
        # can close the socket on the loop the socket belongs to.
        self._loop = None

    async def connect_async(self):
        """Connect to the WebSocket server and dispatch incoming messages.

        Runs until ``self.running`` is cleared, the peer closes the
        connection, or an error occurs. The socket is always closed on exit.
        """
        ws = None
        try:
            ws = await websockets.connect(self.url)
            self.ws = ws
            self.connected = True
            logger.info(f"Connected to {self.url}")
            # Start listening for messages
            while self.running:
                try:
                    message = await ws.recv()
                    self.on_message(message)
                except websockets.exceptions.ConnectionClosed:
                    logger.warning("Connection closed")
                    self.connected = False
                    break
                except Exception as e:
                    self.on_error(e)
                    break
            # Handle connection closed
            self.connected = False
            self.on_close()
        except Exception as e:
            logger.error(f"Connection error: {e}")
            self.connected = False
            self.on_error(e)
        finally:
            # Release the socket on every exit path (handler error, stop
            # requested before the connection completed, ...).
            if ws is not None:
                try:
                    await ws.close()
                except Exception as e:
                    logger.debug(f"Error while closing WebSocket: {e}")
                if self.ws is ws:
                    self.ws = None

    def connect(self):
        """Start connection in a separate thread.

        Does nothing if a connection attempt is already running.
        """
        if self.running:
            return
        self.running = True

        def run_async_loop():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            self._loop = loop
            try:
                loop.run_until_complete(self.connect_async())
            finally:
                loop.close()
                # Only reset shared state if no newer worker has replaced
                # this one, so connect() can be called again afterwards.
                if self.thread is threading.current_thread():
                    self._loop = None
                    self.running = False

        self.thread = threading.Thread(target=run_async_loop, daemon=True)
        self.thread.start()

    def disconnect(self):
        """Disconnect from WebSocket server.

        The close is scheduled on the worker thread's event loop because the
        socket is bound to that loop; running it on a fresh loop in the
        calling thread is not safe.
        """
        self.running = False
        self.connected = False
        ws, loop, thread = self.ws, self._loop, self.thread
        self.ws = None
        if ws is None or loop is None or not loop.is_running():
            # Nothing to close here: either no socket was opened yet, or the
            # worker already finished and closed it itself.
            return
        closing = ws.close()
        try:
            asyncio.run_coroutine_threadsafe(closing, loop)
        except RuntimeError:
            # The loop shut down between the check and the call; the worker
            # has closed the socket on its way out.
            closing.close()
            return
        # Closing the socket wakes the pending recv(); wait briefly for the
        # worker to wind down, unless we are being called from the worker.
        if thread is not None and thread is not threading.current_thread():
            thread.join(timeout=self._CLOSE_TIMEOUT)
class SpeakerDiarizationUI:
    """Main UI for speaker diarization.

    Holds the state shown in the Gradio page (conversation HTML, system
    status, listening state) and provides the event handlers wired to the
    page's buttons and periodic refresh.

    The handlers return their updates positionally, in the order of the
    ``outputs`` lists they are registered with. The Gradio components are
    locals of ``create_interface`` and are not visible from this class, so
    they cannot be used as dictionary keys here.
    """

    # Seconds to wait for the HTTP API before giving up, so a stalled
    # request cannot block a UI handler forever.
    _REQUEST_TIMEOUT = 10

    def __init__(self):
        self.transcription_client = None
        self.conversation_html = ""
        self.system_status = {"status": "disconnected"}
        self.webrtc_state = "stopped"

    def _apply_status(self, status):
        """Store a status payload while keeping ``system_status`` a dict."""
        if isinstance(status, dict):
            self.system_status = status
        else:
            # A bare value (e.g. a string) must not replace the dict, since
            # other methods index into ``system_status``.
            self.system_status["status"] = status

    def handle_transcription_message(self, message):
        """Handle incoming transcription messages.

        Args:
            message: JSON text with a ``type`` field of
                ``conversation_update``, ``connection`` or ``error``.
                Other types are ignored; invalid JSON is logged.
        """
        try:
            data = json.loads(message)
            message_type = data.get("type", "unknown")
            if message_type == "conversation_update":
                # Update the conversation display
                self.conversation_html = data.get("conversation_html", "")
                # Update system status if available
                if "status" in data:
                    self._apply_status(data["status"])
            elif message_type == "connection":
                # Handle connection status update
                status = data.get("status", "unknown")
                logger.info(f"Connection status: {status}")
                if "hf_space_status" in data:
                    self.system_status["hf_space_status"] = data["hf_space_status"]
            elif message_type == "error":
                # Handle error message
                error_msg = data.get("message", "Unknown error")
                logger.error(f"Error from server: {error_msg}")
        except json.JSONDecodeError:
            logger.warning(f"Received invalid JSON: {message}")
        except Exception as e:
            logger.error(f"Error handling message: {e}")

    def handle_transcription_error(self, error):
        """Handle WebSocket errors"""
        logger.error(f"WebSocket error: {error}")

    def handle_transcription_close(self):
        """Handle WebSocket connection closure"""
        logger.info("WebSocket connection closed")
        self.system_status["status"] = "disconnected"

    def connect_to_transcription(self):
        """Connect to transcription WebSocket.

        Does nothing when a connected client already exists. A stale client
        (still connecting, or already dropped) is shut down before it is
        replaced so its worker does not linger.
        """
        existing = self.transcription_client
        if existing and existing.connected:
            return
        if existing:
            try:
                existing.disconnect()
            except Exception as e:
                logger.warning(f"Error disconnecting stale client: {e}")
        self.transcription_client = TranscriptionClient(
            url=WS_TRANSCRIPTION_URL,
            on_message=self.handle_transcription_message,
            on_error=self.handle_transcription_error,
            on_close=self.handle_transcription_close,
        )
        self.transcription_client.connect()

    def disconnect_from_transcription(self):
        """Disconnect from transcription WebSocket"""
        if self.transcription_client:
            self.transcription_client.disconnect()
            self.transcription_client = None

    def start_listening(self):
        """Start listening to audio and connect to services.

        Returns:
            Updates for, in order: the WebRTC component, the status display,
            the start button, the stop button and the clear button.
        """
        self.connect_to_transcription()
        self.webrtc_state = "started"
        return (
            gr.update(streaming=True),
            gr.update(value="Status: Connected and listening"),
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=True),
        )

    def stop_listening(self):
        """Stop listening to audio and disconnect from services.

        Returns:
            Updates for, in order: the WebRTC component, the status display,
            the start button, the stop button and the clear button.
        """
        self.disconnect_from_transcription()
        self.webrtc_state = "stopped"
        return (
            gr.update(streaming=False),
            gr.update(value="Status: Disconnected"),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=True),
        )

    def clear_conversation(self):
        """Clear the conversation on the server and in the local display.

        A failed server call is logged; the local display is cleared either
        way.

        Returns:
            Update for the conversation display.
        """
        # Call API to clear conversation
        import requests
        try:
            response = requests.post(
                f"https://{HF_SPACE_URL}/clear",
                timeout=self._REQUEST_TIMEOUT,
            )
            if response.status_code == 200:
                logger.info("Conversation cleared")
            else:
                logger.error(f"Failed to clear conversation: {response.status_code}")
        except Exception as e:
            logger.error(f"Error clearing conversation: {e}")
        # Clear local display
        self.conversation_html = ""
        return gr.update(
            value="<div class='conversation-container'><p>Conversation cleared</p></div>"
        )

    def update_display(self):
        """Update conversation display - called periodically.

        Returns:
            Updates for, in order: the conversation display and the status
            display.
        """
        status_text = "Status: "
        if self.webrtc_state == "started":
            status_text += "Connected and listening"
        else:
            status_text += "Disconnected"
        if self.system_status.get("hf_space_status"):
            status_text += f" | HF Space: {self.system_status['hf_space_status']}"
        conversation = self.conversation_html or (
            "<div class='conversation-container'><p>No conversation yet. "
            "Start speaking to begin transcription.</p></div>"
        )
        return (
            gr.update(value=conversation),
            gr.update(value=status_text),
        )
# Module-level UI state object; its bound methods are used as the Gradio
# event handlers registered in create_interface().
ui = SpeakerDiarizationUI()
# Custom CSS passed to gr.Blocks(css=...). It styles the classes used by this
# page: .conversation-container (scrollable transcript box), .speaker and
# .speaker-label (per-speaker entries), and .status-display (status lines).
css = """
.conversation-container {
border: 1px solid #ddd;
border-radius: 10px;
padding: 15px;
margin-bottom: 10px;
max-height: 500px;
overflow-y: auto;
background-color: white;
}
.speaker {
margin-bottom: 12px;
border-radius: 8px;
padding: 8px 12px;
}
.speaker-label {
font-weight: bold;
margin-bottom: 5px;
}
.status-display {
margin-top: 10px;
padding: 5px 10px;
background-color: #f0f0f0;
border-radius: 5px;
font-size: 0.9rem;
}
"""
# Create Gradio interface as a function to avoid clashing
def create_interface():
    """Build the Gradio Blocks page for real-time speaker diarization.

    The page has a conversation pane, a status line, the WebRTC audio
    component, start/stop/clear buttons and an "Advanced Settings" accordion.
    Button clicks are routed to the module-level ``ui`` object, and the
    conversation and status displays are refreshed every 0.5 seconds.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    # Seconds to wait for the settings endpoint; without a timeout a stalled
    # request would block the handler indefinitely.
    request_timeout = 10

    def update_settings(threshold, speakers):
        """Send the diarization settings to the server and report the result."""
        import requests
        try:
            response = requests.post(
                f"https://{HF_SPACE_URL}/settings",
                params={"threshold": threshold, "max_speakers": speakers},
                timeout=request_timeout,
            )
            if response.status_code == 200:
                return gr.update(value=f"Settings updated: Threshold={threshold}, Max Speakers={speakers}")
            else:
                return gr.update(value=f"Failed to update settings: {response.status_code}")
        except Exception as e:
            return gr.update(value=f"Error updating settings: {e}")

    with gr.Blocks(css=css) as interface:
        gr.Markdown("# Real-Time Speaker Diarization")
        gr.Markdown("This app performs real-time speaker diarization on your audio. It automatically transcribes speech and identifies different speakers.")
        with gr.Row():
            with gr.Column(scale=2):
                conversation_display = gr.HTML("<div class='conversation-container'><p>No conversation yet. Start speaking to begin transcription.</p></div>")
            with gr.Column(scale=1):
                status_display = gr.Markdown("Status: Disconnected", elem_classes=["status-display"])
                webrtc = RTCComponent(url=RENDER_SIGNALING_URL, streaming=False, modality="audio", mode="send-receive")
                with gr.Row():
                    start_button = gr.Button("Start Listening", variant="primary")
                    stop_button = gr.Button("Stop Listening", variant="secondary", visible=False)
                clear_button = gr.Button("Clear Conversation", visible=True)
                with gr.Accordion("Advanced Settings", open=False):
                    speaker_threshold = gr.Slider(0.5, 0.9, value=0.65, label="Speaker Change Threshold")
                    max_speakers = gr.Slider(2, 8, value=4, step=1, label="Maximum Number of Speakers")
                    settings_button = gr.Button("Update Settings")
                    settings_status = gr.Markdown("", elem_classes=["status-display"])
                    settings_button.click(
                        update_settings,
                        [speaker_threshold, max_speakers],
                        [settings_status]
                    )
        # Set up event handlers
        start_button.click(
            ui.start_listening,
            [],
            [webrtc, status_display, start_button, stop_button, clear_button]
        )
        stop_button.click(
            ui.stop_listening,
            [],
            [webrtc, status_display, start_button, stop_button, clear_button]
        )
        clear_button.click(
            ui.clear_conversation,
            [],
            [conversation_display]
        )
        # Periodic update every 0.5 seconds
        interface.load(
            ui.update_display,
            [],
            [conversation_display, status_display],
            every=0.5
        )
    return interface
# Module-level handle to the Gradio interface. It stays None until the script
# entry point below builds it, or until mount_ui() creates it on demand.
interface = None

# Script entry point: build the page and serve it with Gradio's own server.
if __name__ == "__main__":
    interface = create_interface()
    interface.launch()
# Add mount_ui function for integration with FastAPI
def mount_ui(app):
    """Mount the Gradio interface at /ui path of the FastAPI app.

    The interface is created on first use and cached in the module-level
    ``interface`` variable, so repeated calls reuse the same Blocks object.

    Args:
        app: The FastAPI application to mount the UI on.

    Returns:
        The application with the Gradio UI mounted at ``/ui``.
    """
    global interface
    # Create interface if it doesn't exist yet
    if interface is None:
        interface = create_interface()
    # Mount through gr.mount_gradio_app, Gradio's documented entry point for
    # attaching a Blocks app to an existing FastAPI application.
    return gr.mount_gradio_app(app, interface, path="/ui")