Spaces:
Paused
Paused
| """ | |
| WebSocket Handler for Real-time STT/TTS with Barge-in Support | |
| """ | |
| from fastapi import WebSocket, WebSocketDisconnect | |
| from typing import Dict, Any, Optional | |
| import json | |
| import asyncio | |
| import base64 | |
| from datetime import datetime | |
| from collections import deque | |
| from enum import Enum | |
| import numpy as np | |
| import traceback | |
| from session import Session, session_store | |
| from config_provider import ConfigProvider | |
| from chat_handler import handle_new_message, handle_parameter_followup | |
| from stt_factory import STTFactory | |
| from tts_factory import TTSFactory | |
| from logger import log_info, log_error, log_debug, log_warning | |
# ========================= CONSTANTS =========================
# Fallback values; RealtimeSession replaces them with entries from the STT
# provider settings when those are present.
DEFAULT_SILENCE_THRESHOLD_MS = 2_000  # silence (ms) that must elapse before input is processed
DEFAULT_AUDIO_CHUNK_SIZE = 4_096  # length of each slice of synthesized audio sent to the client
DEFAULT_ENERGY_THRESHOLD = 0.01  # normalized RMS below which a chunk counts as silence
DEFAULT_AUDIO_BUFFER_MAX_SIZE = 1_000  # maximum number of chunks AudioBuffer retains
| # ========================= ENUMS ========================= | |
class ConversationState(Enum):
    """Phases a real-time conversation moves through.

    The string values are the ones exchanged with the client in
    ``state_change`` messages.
    """

    # No turn in progress.
    IDLE = "idle"
    # Receiving user audio.
    LISTENING = "listening"
    # A complete utterance was detected and is being finalized.
    PROCESSING_STT = "processing_stt"
    # Waiting for the chat handler's reply.
    PROCESSING_LLM = "processing_llm"
    # Synthesizing the reply audio.
    PROCESSING_TTS = "processing_tts"
    # Reply audio is being streamed to / played by the client.
    PLAYING_AUDIO = "playing_audio"
| # ========================= CLASSES ========================= | |
class AudioBuffer:
    """Bounded FIFO of decoded audio chunks, guarded by an asyncio lock.

    The lock serializes coroutines on the event loop; it does not protect
    against access from other OS threads. Once ``max_size`` chunks are held,
    appending another one discards the oldest (``deque(maxlen=...)``).
    """

    def __init__(self, max_size: int = DEFAULT_AUDIO_BUFFER_MAX_SIZE):
        self.lock = asyncio.Lock()
        self.buffer = deque(maxlen=max_size)

    async def add_chunk(self, chunk_data: str):
        """Decode one base64 audio chunk and append the raw bytes."""
        async with self.lock:
            self.buffer.append(base64.b64decode(chunk_data))

    async def get_all_audio(self) -> bytes:
        """Return every buffered chunk joined into a single bytes object."""
        async with self.lock:
            chunks = tuple(self.buffer)
            return b"".join(chunks)

    async def clear(self):
        """Drop all buffered chunks."""
        async with self.lock:
            self.buffer.clear()

    def size(self) -> int:
        """Return the number of chunks (not bytes) currently buffered."""
        chunk_count = len(self.buffer)
        return chunk_count
class SilenceDetector:
    """Track how long the incoming audio stream has been silent.

    A chunk counts as silence when its normalized RMS energy is below
    ``energy_threshold``. Chunk bytes are interpreted as 16-bit PCM samples.

    NOTE(review): ``RealtimeSession.initialize_stt`` declares the stream as
    ``WEBM_OPUS``; if chunks really are Opus-encoded, this PCM energy measure
    is not meaningful -- confirm the wire format.
    """

    # Bytes per sample for the 16-bit interpretation used by is_silence().
    _SAMPLE_WIDTH = 2

    def __init__(self, threshold_ms: int = DEFAULT_SILENCE_THRESHOLD_MS, energy_threshold: float = DEFAULT_ENERGY_THRESHOLD):
        self.threshold_ms = threshold_ms
        self.energy_threshold = energy_threshold
        self.silence_start = None  # datetime of the first silent chunk in the current run
        self.sample_rate = 16000

    def update(self, audio_chunk: bytes) -> int:
        """Feed one chunk and return the current silence duration in ms.

        Returns 0 whenever the chunk is not silent, which also restarts the
        silence timer.
        """
        if not self.is_silence(audio_chunk):
            self.silence_start = None
            return 0

        now = datetime.now()
        if self.silence_start is None:
            self.silence_start = now
        return int((now - self.silence_start).total_seconds() * 1000)

    def is_silence(self, audio_chunk: bytes) -> bool:
        """Return True when the chunk's normalized RMS is below the threshold.

        Only whole 16-bit samples are analyzed: a trailing odd byte is
        ignored instead of making ``np.frombuffer`` raise (which previously
        caused odd-length chunks to be reported as non-silent). A chunk with
        no complete sample is treated as silence. Any unexpected error is
        logged and reported as "not silence".
        """
        try:
            sample_count = len(audio_chunk) // self._SAMPLE_WIDTH
            if sample_count == 0:
                return True

            samples = np.frombuffer(audio_chunk, dtype=np.int16, count=sample_count)
            rms = np.sqrt(np.mean(samples.astype(float) ** 2))
            normalized_rms = rms / 32768.0  # full scale of signed 16-bit audio
            return bool(normalized_rms < self.energy_threshold)
        except Exception as e:
            log_warning(f"Silence detection error: {e}")
            return False

    def reset(self):
        """Forget any silence run in progress."""
        self.silence_start = None
class BargeInHandler:
    """Own the running TTS task and cancel it when the user interrupts.

    Args:
        reset_delay_s: How long ``is_interrupting`` stays set after an
            interruption before it is cleared again.
    """

    def __init__(self, reset_delay_s: float = 0.5):
        self.active_tts_task: Optional[asyncio.Task] = None
        self.is_interrupting = False
        self.lock = asyncio.Lock()
        self.reset_delay_s = reset_delay_s
        # Incremented per interruption so that only the most recent call
        # clears ``is_interrupting`` when interruptions overlap.
        self._interruption_seq = 0

    async def _cancel_active_task(self) -> bool:
        """Cancel the active TTS task and wait for it; caller holds the lock.

        Returns True when a running task was cancelled. A task that ends with
        an error other than ``CancelledError`` is logged instead of
        propagating, so the caller can continue.
        """
        task = self.active_tts_task
        if task is None or task.done():
            return False

        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
        except Exception as exc:
            log_warning("TTS task failed while being cancelled", error=str(exc))
        return True

    async def start_tts_task(self, coro):
        """Start ``coro`` as the active TTS task, cancelling any previous one.

        Returns the created ``asyncio.Task``. This method is itself a
        coroutine, so callers must ``await`` it to obtain the task.
        """
        async with self.lock:
            await self._cancel_active_task()
            self.active_tts_task = asyncio.create_task(coro)
            return self.active_tts_task

    async def handle_interruption(self, current_state: "ConversationState"):
        """Cancel active TTS and keep ``is_interrupting`` set for a short while.

        ``current_state`` is accepted for interface compatibility and is not
        used. The lock is held only while cancelling, not during the delay,
        so ``start_tts_task`` is not blocked. The flag is cleared in a
        ``finally`` so it cannot stay set after an error or cancellation.
        """
        self._interruption_seq += 1
        seq = self._interruption_seq
        self.is_interrupting = True
        try:
            async with self.lock:
                task = self.active_tts_task
                if task is not None and not task.done():
                    log_info("Barge-in: Cancelling active TTS")
                    await self._cancel_active_task()

            # Reset flag after short delay
            await asyncio.sleep(self.reset_delay_s)
        finally:
            if seq == self._interruption_seq:
                self.is_interrupting = False
class RealtimeSession:
    """State and resources for one real-time (WebSocket) conversation.

    Wraps the persistent ``Session`` with the conversation state, the audio
    buffer, the silence detector, the barge-in handler and the STT provider.
    """

    def __init__(self, session: Session):
        self.session = session
        self.state = ConversationState.IDLE

        # Get settings from config; fall back to an empty mapping so the
        # defaults below apply when no settings are configured.
        config = ConfigProvider.get().global_config.stt_provider.settings or {}

        # Initialize with config values or defaults
        silence_threshold = config.get("speech_timeout_ms", DEFAULT_SILENCE_THRESHOLD_MS)
        energy_threshold = config.get("energy_threshold", DEFAULT_ENERGY_THRESHOLD)
        buffer_max_size = config.get("audio_buffer_max_size", DEFAULT_AUDIO_BUFFER_MAX_SIZE)

        self.audio_buffer = AudioBuffer(max_size=buffer_max_size)
        self.silence_detector = SilenceDetector(
            threshold_ms=silence_threshold,
            energy_threshold=energy_threshold,
        )
        self.barge_in_handler = BargeInHandler()
        self.stt_manager = None
        self.current_transcription = ""
        self.is_streaming = False
        self.lock = asyncio.Lock()

        # Store config for later use
        self.audio_chunk_size = config.get("audio_chunk_size", DEFAULT_AUDIO_CHUNK_SIZE)
        self.silence_threshold_ms = silence_threshold

    async def initialize_stt(self) -> bool:
        """Create the STT provider and start its streaming session.

        Returns True on success. Returns False when no provider is available
        or when starting the stream fails; in the failure case
        ``stt_manager`` is reset to None so later audio chunks are not
        streamed to a provider that never started.
        """
        try:
            self.stt_manager = STTFactory.create_provider()
            if not self.stt_manager:
                log_warning("No STT provider available", session_id=self.session.session_id)
                return False

            config = ConfigProvider.get().global_config.stt_provider.settings or {}
            await self.stt_manager.start_streaming({
                "language": config.get("language", "tr-TR"),
                "interim_results": config.get("interim_results", True),
                "single_utterance": False,
                "enable_punctuation": config.get("enable_punctuation", True),
                "sample_rate": 16000,
                "encoding": "WEBM_OPUS",
            })
            log_info("STT manager initialized", session_id=self.session.session_id)
            return True
        except Exception as e:
            self.stt_manager = None
            log_error("Failed to initialize STT", error=str(e), session_id=self.session.session_id)
            return False

    async def change_state(self, new_state: ConversationState):
        """Switch to ``new_state`` under the session lock and log the change."""
        async with self.lock:
            old_state = self.state
            self.state = new_state
            log_debug(
                f"State change: {old_state.value} → {new_state.value}",
                session_id=self.session.session_id,
            )

    async def handle_barge_in(self):
        """Cancel any active TTS and return to the listening state."""
        await self.barge_in_handler.handle_interruption(self.state)
        await self.change_state(ConversationState.LISTENING)

    async def reset_for_new_utterance(self):
        """Clear buffered audio, silence tracking and the last transcription."""
        await self.audio_buffer.clear()
        self.silence_detector.reset()
        self.current_transcription = ""

    async def cleanup(self):
        """Stop STT streaming; safe to call more than once.

        The provider reference is dropped before stopping, so a repeated call
        (for example from the ``end_session`` handler and again from the
        endpoint's ``finally``) does not stop the stream twice. Errors are
        logged, not raised.
        """
        manager, self.stt_manager = self.stt_manager, None
        try:
            if manager:
                await manager.stop_streaming()
            log_info("Cleaned up realtime session", session_id=self.session.session_id)
        except Exception as e:
            log_warning("Cleanup error", error=str(e), session_id=self.session.session_id)
| # ========================= MAIN HANDLER ========================= | |
async def websocket_endpoint(websocket: WebSocket, session_id: str):
    """Main WebSocket endpoint for real-time conversation.

    Accepts the connection, looks up the session, initializes STT and then
    dispatches incoming JSON messages by their ``type`` (``audio_chunk``,
    ``control``, ``ping``) until the client disconnects or sends an
    ``end_session`` control message. The realtime session is always cleaned
    up on exit.
    """
    await websocket.accept()
    log_info("WebSocket connected", session_id=session_id)

    # Get session
    session = session_store.get_session(session_id)
    if not session:
        await websocket.send_json({
            "type": "error",
            "message": "Session not found"
        })
        await websocket.close()
        return

    # Mark as realtime session
    session.is_realtime_session = True
    session_store.update_session(session)

    # Initialize conversation
    realtime_session = RealtimeSession(session)

    try:
        # Initialize STT. A failure is reported to the client, but the
        # connection stays open. This runs inside the try so that a failed
        # send still reaches the cleanup below.
        stt_initialized = await realtime_session.initialize_stt()
        if not stt_initialized:
            await websocket.send_json({
                "type": "error",
                "message": "STT initialization failed"
            })

        while True:
            # Receive message
            message = await websocket.receive_json()

            # Valid JSON that is not an object cannot carry a "type"; reject
            # it without tearing down the connection.
            if not isinstance(message, dict):
                await websocket.send_json({
                    "type": "error",
                    "message": "Invalid message format"
                })
                continue

            message_type = message.get("type")

            if message_type == "audio_chunk":
                await handle_audio_chunk(websocket, realtime_session, message)
            elif message_type == "control":
                await handle_control_message(websocket, realtime_session, message)
                # The end_session handler closes the socket; leave the loop
                # instead of calling receive on a closed connection.
                if message.get("action") == "end_session":
                    break
            elif message_type == "ping":
                # Keep-alive ping
                await websocket.send_json({"type": "pong"})

    except WebSocketDisconnect:
        log_info("WebSocket disconnected", session_id=session_id)
    except Exception as e:
        log_error(
            "WebSocket error",
            error=str(e),
            traceback=traceback.format_exc(),
            session_id=session_id
        )
        # Best effort: the socket may already be unusable at this point.
        try:
            await websocket.send_json({
                "type": "error",
                "message": str(e)
            })
        except Exception as send_error:
            log_debug(
                "Could not send error to client",
                error=str(send_error),
                session_id=session_id
            )
    finally:
        await realtime_session.cleanup()
| # ========================= MESSAGE HANDLERS ========================= | |
async def handle_audio_chunk(websocket: WebSocket, session: RealtimeSession, message: Dict[str, Any]):
    """Handle one incoming audio chunk, with barge-in support.

    Steps: interrupt TTS/playback if it is in progress, move from idle to
    listening, buffer the chunk, update silence tracking, forward the audio
    to STT while listening, and start processing once the silence threshold
    is exceeded and a final transcription exists.

    Args:
        websocket: Client connection used for replies.
        session: Realtime session owning state, buffer and STT provider.
        message: Incoming message; ``message["data"]`` is the base64 audio.

    Raises:
        WebSocketDisconnect: Propagated unchanged so the endpoint can end the
            connection instead of reporting a processing error to a socket
            that is gone.
    """
    try:
        audio_data = message.get("data")
        if not audio_data:
            return

        # Check for barge-in during TTS/audio playback
        if session.state in (ConversationState.PLAYING_AUDIO, ConversationState.PROCESSING_TTS):
            # Capture the state first: handle_barge_in() switches to
            # LISTENING, so reading it afterwards would always log "listening".
            interrupted_state = session.state.value
            await session.handle_barge_in()
            await websocket.send_json({
                "type": "control",
                "action": "stop_playback"
            })
            log_info("Barge-in detected", session_id=session.session.session_id, state=interrupted_state)

        # Change state to listening if idle
        if session.state == ConversationState.IDLE:
            await session.change_state(ConversationState.LISTENING)
            await websocket.send_json({
                "type": "state_change",
                "from": "idle",
                "to": "listening"
            })

        # Add to buffer
        await session.audio_buffer.add_chunk(audio_data)

        # Decode for processing
        decoded_audio = base64.b64decode(audio_data)

        # Check silence
        silence_duration = session.silence_detector.update(decoded_audio)

        # Stream to STT if available
        if session.stt_manager and session.state == ConversationState.LISTENING:
            async for result in session.stt_manager.stream_audio(decoded_audio):
                # Send transcription updates
                await websocket.send_json({
                    "type": "transcription",
                    "text": result.text,
                    "is_final": result.is_final,
                    "confidence": result.confidence
                })
                if result.is_final:
                    session.current_transcription = result.text

        # Process if silence detected and we have transcription
        if silence_duration > session.silence_threshold_ms and session.current_transcription:
            log_info(
                "User stopped speaking",
                session_id=session.session.session_id,
                silence_ms=silence_duration,
                text=session.current_transcription
            )
            await process_user_input(websocket, session)

    except WebSocketDisconnect:
        # Not an audio error; let the endpoint handle the disconnect.
        raise
    except Exception as e:
        log_error(
            "Audio chunk handling error",
            error=str(e),
            traceback=traceback.format_exc(),
            session_id=session.session.session_id
        )
        await websocket.send_json({
            "type": "error",
            "message": f"Audio processing error: {str(e)}"
        })
async def handle_control_message(websocket: WebSocket, session: RealtimeSession, message: Dict[str, Any]):
    """Handle a ``control`` message from the client.

    Supported ``action`` values:
      * ``start_session`` - reply with the session id and effective settings.
      * ``end_session``   - clean up the realtime session and close the socket.
      * ``interrupt``     - explicit barge-in; cancel TTS and acknowledge.
      * ``reset``         - clear utterance data and return to idle.
      * ``audio_ended``   - client finished playback; return to idle if the
        session was in the playing state.

    Unknown actions are ignored.
    """
    action = message.get("action")

    log_debug("Control message", action=action, session_id=session.session.session_id)

    if action == "start_session":
        # Session configuration
        await websocket.send_json({
            "type": "session_started",
            "session_id": session.session.session_id,
            "config": {
                "silence_threshold_ms": session.silence_threshold_ms,
                "audio_chunk_size": session.audio_chunk_size,
                "supports_barge_in": True
            }
        })

    elif action == "end_session":
        # Clean up and close
        await session.cleanup()
        await websocket.close()

    elif action == "interrupt":
        # Handle explicit interrupt
        await session.handle_barge_in()
        await websocket.send_json({
            "type": "control",
            "action": "interrupt_acknowledged"
        })

    elif action == "reset":
        # Remember the state being left BEFORE switching; reading
        # session.state after change_state() would always report "idle".
        previous_state = session.state.value
        await session.reset_for_new_utterance()
        await session.change_state(ConversationState.IDLE)
        await websocket.send_json({
            "type": "state_change",
            "from": previous_state,
            "to": "idle"
        })

    elif action == "audio_ended":
        # Audio playback ended on client
        if session.state == ConversationState.PLAYING_AUDIO:
            await session.change_state(ConversationState.IDLE)
            await websocket.send_json({
                "type": "state_change",
                "from": "playing_audio",
                "to": "idle"
            })
| # ========================= PROCESSING FUNCTIONS ========================= | |
async def process_user_input(websocket: WebSocket, session: RealtimeSession):
    """Process a complete user utterance: chat reply, then TTS.

    Walks the session through processing_stt -> processing_llm ->
    processing_tts (or idle when no TTS provider is available), notifying the
    client of every transition, and finally resets the per-utterance data.

    The TTS task is started in the background and is NOT awaited here: this
    function runs inside the WebSocket receive loop, and that loop must keep
    reading audio chunks for barge-in to be able to cancel the task.

    On any error the client is notified (best effort), the utterance data is
    reset and the session returns to idle.
    """
    try:
        user_text = session.current_transcription
        if not user_text:
            await session.reset_for_new_utterance()
            await session.change_state(ConversationState.IDLE)
            return

        log_info("Processing user input", text=user_text, session_id=session.session.session_id)

        # State: STT Processing
        await session.change_state(ConversationState.PROCESSING_STT)
        await websocket.send_json({
            "type": "state_change",
            "from": "listening",
            "to": "processing_stt"
        })

        # Send final transcription
        await websocket.send_json({
            "type": "transcription",
            "text": user_text,
            "is_final": True,
            "confidence": 0.95
        })

        # State: LLM Processing
        await session.change_state(ConversationState.PROCESSING_LLM)
        await websocket.send_json({
            "type": "state_change",
            "from": "processing_stt",
            "to": "processing_llm"
        })

        # Add to chat history
        session.session.add_message("user", user_text)

        # Get LLM response based on session state
        if session.session.state == "collect_params":
            response_text = await handle_parameter_followup(session.session, user_text)
        else:
            response_text = await handle_new_message(session.session, user_text)

        # Add response to history
        session.session.add_message("assistant", response_text)

        # Send text response
        await websocket.send_json({
            "type": "assistant_response",
            "text": response_text
        })

        # Generate TTS if enabled
        tts_provider = TTSFactory.create_provider()
        if tts_provider:
            await session.change_state(ConversationState.PROCESSING_TTS)
            await websocket.send_json({
                "type": "state_change",
                "from": "processing_llm",
                "to": "processing_tts"
            })

            # start_tts_task() is a coroutine that returns the created Task.
            # Await it to schedule the task, then leave the task running in
            # the background; BargeInHandler keeps the reference and cancels
            # it on interruption. (The former `except CancelledError` around
            # this await never saw TTS cancellation and could only swallow a
            # cancellation of this handler itself, so it was removed.)
            await session.barge_in_handler.start_tts_task(
                generate_and_stream_tts(websocket, session, tts_provider, response_text)
            )
        else:
            # No TTS, go back to idle
            await session.change_state(ConversationState.IDLE)
            await websocket.send_json({
                "type": "state_change",
                "from": "processing_llm",
                "to": "idle"
            })

        # Reset for next input
        await session.reset_for_new_utterance()

    except Exception as e:
        log_error(
            "Error processing user input",
            error=str(e),
            traceback=traceback.format_exc(),
            session_id=session.session.session_id
        )
        try:
            await websocket.send_json({
                "type": "error",
                "message": f"Processing error: {str(e)}"
            })
        except Exception as send_error:
            # The socket may be the thing that failed; the reset below must
            # still run.
            log_warning(
                "Could not report processing error to client",
                error=str(send_error),
                session_id=session.session.session_id
            )
        await session.reset_for_new_utterance()
        await session.change_state(ConversationState.IDLE)
async def generate_and_stream_tts(
    websocket: WebSocket,
    session: RealtimeSession,
    tts_provider,
    text: str
):
    """Synthesize ``text`` and stream the audio to the client in chunks.

    Intended to run as a task that barge-in can cancel. Cancellation arrives
    as ``asyncio.CancelledError`` at one of the awaits below and is re-raised
    after logging; no explicit polling is needed (the previous
    ``current_task().cancelled()`` check could never be true while the task
    was still running).

    If synthesis yields no audio, or any other error occurs, the session is
    returned to idle so it is not left in a TTS/playing state that no
    ``audio_ended`` message will ever clear.

    Raises:
        asyncio.CancelledError: When the task is cancelled.
    """
    session_id = session.session.session_id
    try:
        # Generate audio
        audio_data = await tts_provider.synthesize(text)

        if audio_data is None or len(audio_data) == 0:
            # Nothing to play: skip the playing state entirely.
            log_warning("TTS produced no audio", session_id=session_id)
            await session.change_state(ConversationState.IDLE)
            await websocket.send_json({
                "type": "state_change",
                "from": "processing_tts",
                "to": "idle"
            })
            return

        # Change state to playing
        await session.change_state(ConversationState.PLAYING_AUDIO)
        await websocket.send_json({
            "type": "state_change",
            "from": "processing_tts",
            "to": "playing_audio"
        })

        # Stream audio in chunks. A non-positive or non-integer configured
        # size would make range() raise, so fall back to the default.
        chunk_size = session.audio_chunk_size
        if not isinstance(chunk_size, int) or chunk_size <= 0:
            chunk_size = DEFAULT_AUDIO_CHUNK_SIZE
        total_chunks = (len(audio_data) + chunk_size - 1) // chunk_size

        for chunk_index, offset in enumerate(range(0, len(audio_data), chunk_size)):
            chunk = audio_data[offset:offset + chunk_size]
            await websocket.send_json({
                "type": "tts_audio",
                "data": base64.b64encode(chunk).decode('utf-8'),
                "chunk_index": chunk_index,
                "total_chunks": total_chunks,
                "is_last": chunk_index == total_chunks - 1
            })
            # Small delay to prevent overwhelming the client
            await asyncio.sleep(0.01)

        log_info(
            "TTS streaming completed",
            session_id=session_id,
            text_length=len(text),
            audio_size=len(audio_data)
        )

    except asyncio.CancelledError:
        log_info("TTS streaming cancelled", session_id=session_id)
        raise
    except Exception as e:
        log_error(
            "TTS generation error",
            error=str(e),
            session_id=session_id
        )
        # Only undo the states this task is responsible for.
        failed_state = session.state
        stuck_in_tts = failed_state in (
            ConversationState.PROCESSING_TTS,
            ConversationState.PLAYING_AUDIO,
        )
        if stuck_in_tts:
            await session.change_state(ConversationState.IDLE)
        try:
            await websocket.send_json({
                "type": "error",
                "message": f"TTS error: {str(e)}"
            })
            if stuck_in_tts:
                await websocket.send_json({
                    "type": "state_change",
                    "from": failed_state.value,
                    "to": "idle"
                })
        except Exception as send_error:
            log_warning(
                "Could not report TTS error to client",
                error=str(send_error),
                session_id=session_id
            )