# Page-header residue from the hosting site ("Spaces: Sleeping"); not part of the program.
import gradio as gr | |
import asyncio | |
import websockets | |
import json | |
import logging | |
import time | |
from typing import Dict, Any, Optional | |
import threading | |
from queue import Queue | |
import base64 | |
# Configure the root logger at INFO level and create this module's logger,
# which every function and method below uses for diagnostics.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class TranscriptionInterface:
    """Interface for real-time transcription and speaker diarization.

    Tracks the connected WebSocket clients, keeps a bounded in-memory
    conversation history, and provides the coroutines that receive client
    messages (binary audio or JSON text) and broadcast results.

    Attributes:
        connected_clients: Set of currently connected WebSocket objects.
        message_queue: ``queue.Queue`` created at construction; not used by
            the methods of this class.
        is_running: True while the WebSocket server is serving.
        websocket_server: Object returned by ``websockets.serve``, or None
            before the server has been started.
        current_transcript: String reset to "" whenever history is cleared.
        conversation_history: List of transcript entry dicts, oldest first,
            capped at ``MAX_HISTORY_ENTRIES``.
        last_activity: ``time.time()`` of the most recent client connection
            or incoming message (construction time if there was none yet).
    """

    # Maximum number of transcript entries kept in ``conversation_history``.
    MAX_HISTORY_ENTRIES = 100

    def __init__(self):
        self.connected_clients = set()
        self.message_queue = Queue()
        self.is_running = False
        self.websocket_server = None
        self.current_transcript = ""
        self.conversation_history = []
        self.last_activity = time.time()
        # Per-instance connection counter; keeps client ids unique even when
        # several clients connect within the same second.
        self._client_seq = 0

    async def handle_client(self, websocket, path=None):
        """Serve one WebSocket connection until it closes.

        Sends a connection confirmation, then dispatches every incoming
        frame: ``bytes`` frames go to ``process_audio_chunk``, text frames
        are parsed as JSON and go to ``handle_message``. A failure while
        handling one frame is logged and does not end the connection.

        Args:
            websocket: The connection object supplied by ``websockets``.
            path: Unused. Optional because older ``websockets`` releases
                pass the request path as a second argument while newer
                ones call the handler with the connection only.
        """
        self._client_seq += 1
        client_id = f"client_{int(time.time())}_{self._client_seq}"
        self.connected_clients.add(websocket)
        self.last_activity = time.time()
        logger.info(
            "Client connected: %s. Total clients: %d",
            client_id,
            len(self.connected_clients),
        )
        try:
            # Send connection confirmation
            await websocket.send(json.dumps({
                "type": "connection",
                "status": "connected",
                "timestamp": time.time(),
                "client_id": client_id,
            }))
            async for message in websocket:
                self.last_activity = time.time()
                try:
                    if isinstance(message, bytes):
                        # Handle binary audio data
                        await self.process_audio_chunk(message, websocket)
                    else:
                        # Handle text messages
                        data = json.loads(message)
                        await self.handle_message(data, websocket)
                except json.JSONDecodeError:
                    logger.warning("Invalid JSON received from client: %s", message)
                except websockets.exceptions.ConnectionClosed:
                    # Let the outer handler report the disconnect instead of
                    # logging it as a message-processing error.
                    raise
                except Exception as e:
                    logger.error("Error processing message: %s", e)
        except websockets.exceptions.ConnectionClosed:
            logger.info("Client %s disconnected", client_id)
        except Exception as e:
            logger.error("Client handler error: %s", e)
        finally:
            self.connected_clients.discard(websocket)
            logger.info(
                "Client removed. Remaining clients: %d",
                len(self.connected_clients),
            )

    async def process_audio_chunk(self, audio_data: bytes, websocket):
        """Run inference on one audio chunk and broadcast the result.

        A truthy result is broadcast to every connected client as a
        ``processing_result`` message and, when it carries a
        ``"transcription"`` key, appended to the conversation history.
        If the ``inference`` module cannot be imported the chunk is dropped
        with a warning; any other error is logged and reported back to the
        sending client as an ``error`` message.

        Args:
            audio_data: Raw bytes received from the client.
            websocket: Connection that sent the chunk (used for error replies).
        """
        try:
            # Imported lazily so this module loads without the inference
            # package being installed.
            from inference import process_audio_for_transcription

            # Assumes process_audio_for_transcription is a coroutine function
            # returning a dict-like result -- confirm against the inference module.
            result = await process_audio_for_transcription(audio_data)
            if result:
                # Broadcast result to all clients
                await self.broadcast_result({
                    "type": "processing_result",
                    "timestamp": time.time(),
                    "data": result,
                })
                # Update conversation history
                if "transcription" in result:
                    self.update_conversation(result)
        except ImportError:
            logger.warning("Inference module not found - audio processing disabled")
        except Exception as e:
            logger.error("Error processing audio chunk: %s", e)
            await websocket.send(json.dumps({
                "type": "error",
                "message": f"Audio processing error: {str(e)}",
                "timestamp": time.time(),
            }))

    async def handle_message(self, data: Dict[str, Any], websocket):
        """Handle a decoded JSON (non-audio) message from a client.

        Supported ``type`` values:
            * ``config`` - logged only.
            * ``request_history`` - replies to the sender with the history.
            * ``clear_history`` - clears the history and notifies all clients.
        Anything else is logged as unknown.

        Args:
            data: Decoded JSON payload; expected to be a JSON object.
            websocket: Connection that sent the message.
        """
        if not isinstance(data, dict):
            # Valid JSON that is not an object (list, number, string, ...)
            # has no "type" field to dispatch on.
            logger.warning(
                "Ignoring non-object JSON message of type %s", type(data).__name__
            )
            return

        message_type = data.get("type", "unknown")
        if message_type == "config":
            # Handle configuration updates
            logger.info("Configuration update: %s", data)
        elif message_type == "request_history":
            # Send conversation history to client
            await websocket.send(json.dumps({
                "type": "conversation_history",
                "data": self.conversation_history,
                "timestamp": time.time(),
            }))
        elif message_type == "clear_history":
            # Clear conversation history
            self.conversation_history = []
            self.current_transcript = ""
            await self.broadcast_result({
                "type": "conversation_update",
                "action": "cleared",
                "timestamp": time.time(),
            })
        else:
            logger.warning("Unknown message type: %s", message_type)

    async def broadcast_result(self, result: Dict[str, Any]):
        """Send ``result`` as JSON to every connected client.

        Clients whose send fails are removed from ``connected_clients``.

        Args:
            result: JSON-serializable payload to broadcast.
        """
        if not self.connected_clients:
            return
        message = json.dumps(result)
        disconnected = set()
        # Iterate over a snapshot: handlers may add or remove clients while
        # this coroutine is suspended in ``send``.
        for client in self.connected_clients.copy():
            try:
                await client.send(message)
            except Exception as e:
                logger.warning("Failed to send to client: %s", e)
                disconnected.add(client)
        # Clean up disconnected clients
        self.connected_clients -= disconnected

    def update_conversation(self, result: Dict[str, Any]):
        """Append a transcription result to the conversation history.

        Results without a ``"transcription"`` key are ignored. Missing
        ``speaker`` / ``confidence`` fields default to ``"Unknown"`` / 0.0.
        Only the newest ``MAX_HISTORY_ENTRIES`` entries are kept.

        Args:
            result: Inference result for one audio chunk.
        """
        if "transcription" not in result:
            return
        self.conversation_history.append({
            "timestamp": time.time(),
            "text": result["transcription"],
            "speaker": result.get("speaker", "Unknown"),
            "confidence": result.get("confidence", 0.0),
        })
        # Keep only the newest entries to prevent unbounded memory growth.
        if len(self.conversation_history) > self.MAX_HISTORY_ENTRIES:
            self.conversation_history = (
                self.conversation_history[-self.MAX_HISTORY_ENTRIES:]
            )

    async def start_websocket_server(self, host="0.0.0.0", port=7860):
        """Start the WebSocket server and block until it is closed.

        Errors are logged rather than raised. ``is_running`` is True only
        while the server is serving and is reset when this coroutine exits
        for any reason.

        Args:
            host: Interface to bind.
            port: TCP port to bind.
        """
        try:
            # NOTE(review): ``path`` is not a documented keyword of
            # ``websockets.serve``; extra keywords are forwarded to the event
            # loop's ``create_server``, which is expected to reject it, so
            # this call probably fails -- confirm against the installed
            # websockets version. It is left unchanged because this module
            # also launches Gradio on the same default port (7860): making
            # this call succeed needs a coordinated port/routing change.
            self.websocket_server = await websockets.serve(
                self.handle_client,
                host,
                port,
                path="/ws_inference",
            )
            self.is_running = True
            logger.info("WebSocket server started on %s:%s", host, port)
            # Keep server running
            await self.websocket_server.wait_closed()
        except Exception as e:
            logger.error("WebSocket server error: %s", e)
        finally:
            # Also covers a normal shutdown, where no exception is raised.
            self.is_running = False

    def get_status(self):
        """Return a snapshot of the service state.

        Returns:
            Dict with ``connected_clients`` (int), ``is_running`` (bool),
            ``conversation_entries`` (int) and ``last_activity`` (float,
            ``time.time()`` of the last client connection or message).
        """
        return {
            "connected_clients": len(self.connected_clients),
            "is_running": self.is_running,
            "conversation_entries": len(self.conversation_history),
            "last_activity": self.last_activity,
        }
# Module-level singleton: the WebSocket server thread and the Gradio
# callbacks below all read and write this one instance.
transcription_interface = TranscriptionInterface()
def create_gradio_interface():
    """Build the Gradio dashboard for the transcription service.

    The dashboard has four tabs: server status, the most recent conversation
    entries, WebSocket usage notes, and API documentation. Live data is read
    from the module-level ``transcription_interface``.

    Returns:
        The constructed ``gr.Blocks`` application (not yet launched).
    """
    page_title = "Real-time Audio Transcription & Speaker Diarization"

    # Static Markdown bodies, kept apart from the layout code below.
    websocket_info_md = """
### WebSocket Endpoint
Connect to this Space's WebSocket endpoint for real-time audio processing:
**WebSocket URL:** `wss://your-space-name.hf.space/ws_inference`
### Message Format
**Audio Data:** Send raw audio bytes directly to the WebSocket
**Text Messages:** JSON format
```json
{
"type": "config",
"settings": {
"language": "en",
"enable_diarization": true
}
}
```
### Response Format
```json
{
"type": "processing_result",
"timestamp": 1234567890.123,
"data": {
"transcription": "Hello world",
"speaker": "Speaker_1",
"confidence": 0.95
}
}
```
"""

    api_docs_md = """
### Available Endpoints
- **WebSocket:** `/ws_inference` - Main endpoint for real-time audio processing
- **HTTP:** `/health` - Check server health status
- **HTTP:** `/stats` - Get detailed statistics
### Integration Example
```javascript
const ws = new WebSocket('wss://your-space-name.hf.space/ws_inference');
ws.onopen = function() {
console.log('Connected to transcription service');
};
ws.onmessage = function(event) {
const data = JSON.parse(event.data);
if (data.type === 'processing_result') {
console.log('Transcription:', data.data.transcription);
console.log('Speaker:', data.data.speaker);
}
};
// Send audio data
ws.send(audioBuffer);
```
"""

    def get_server_status():
        """Render the current service status as a Markdown snippet."""
        info = transcription_interface.get_status()
        server_state = 'Running' if info['is_running'] else 'Stopped'
        last_seen = time.ctime(info['last_activity'])
        return f"""
**Server Status:**
- WebSocket Server: {server_state}
- Connected Clients: {info['connected_clients']}
- Conversation Entries: {info['conversation_entries']}
- Last Activity: {last_seen}
"""

    def get_conversation_history():
        """Render the ten newest history entries as Markdown."""
        history = transcription_interface.conversation_history
        if not history:
            return "No conversation history available."
        return "\n".join(
            f"**[{time.ctime(item['timestamp'])}] {item.get('speaker', 'Unknown')}** "
            f"(confidence: {item.get('confidence', 0.0):.2f})\n{item.get('text', '')}\n"
            for item in history[-10:]  # newest ten entries only
        )

    def clear_conversation():
        """Drop all stored history and return a confirmation message."""
        transcription_interface.conversation_history = []
        transcription_interface.current_transcript = ""
        return "Conversation history cleared."

    # Component and event creation order is kept stable on purpose: Gradio
    # numbers components and event handlers in creation order.
    with gr.Blocks(title=page_title) as demo:
        gr.Markdown(f"# {page_title}")
        gr.Markdown("This Hugging Face Space provides WebSocket endpoints for real-time audio processing.")

        with gr.Tab("Server Status"):
            # Initial text is computed once, while the UI is being built.
            status_md = gr.Markdown(get_server_status())
            status_refresh = gr.Button("Refresh Status")
            status_refresh.click(get_server_status, outputs=status_md)

        with gr.Tab("Live Transcription"):
            gr.Markdown("### Live Conversation")
            history_md = gr.Markdown(get_conversation_history())
            with gr.Row():
                history_refresh = gr.Button("Refresh Conversation")
                history_clear = gr.Button("Clear History", variant="secondary")
            history_refresh.click(get_conversation_history, outputs=history_md)
            history_clear.click(clear_conversation, outputs=history_md)

        with gr.Tab("WebSocket Info"):
            gr.Markdown(websocket_info_md)

        with gr.Tab("API Documentation"):
            gr.Markdown(api_docs_md)

    return demo
def run_websocket_server():
    """Thread target that serves WebSocket clients on its own event loop.

    Creates a fresh asyncio event loop, installs it as the current loop of
    the calling thread, and runs
    ``transcription_interface.start_websocket_server()`` on it until that
    coroutine finishes. Any exception is logged; the loop is always closed.
    """
    event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)
    try:
        serve_until_closed = transcription_interface.start_websocket_server()
        event_loop.run_until_complete(serve_until_closed)
    except Exception as exc:
        logger.error(f"WebSocket server thread error: {exc}")
    finally:
        event_loop.close()
# Start the WebSocket server on a daemon thread as soon as the module is
# loaded; being a daemon, the thread does not keep the process alive.
# NOTE(review): start_websocket_server() defaults to port 7860, the same port
# the Gradio launch below requests -- confirm the two are not meant to share it.
websocket_thread = threading.Thread(daemon=True, target=run_websocket_server)
websocket_thread.start()

# Build and serve the Gradio UI only when this file is run as a script.
if __name__ == "__main__":
    demo = create_gradio_interface()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True)