Spaces:
Paused
Paused
| """ | |
| WebSocket Manager for Flare | |
| =========================== | |
| Manages WebSocket connections and message routing | |
| """ | |
| import base64 | |
| import struct | |
| import asyncio | |
| from typing import Dict, Optional, Set | |
| from fastapi import WebSocket, WebSocketDisconnect | |
| import json | |
| from datetime import datetime | |
| import traceback | |
| from event_bus import EventBus, Event, EventType | |
| from utils.logger import log_info, log_error, log_debug, log_warning | |
class WebSocketConnection:
    """Wrapper around one client WebSocket plus bookkeeping metadata.

    Tracks when the socket was accepted, when it last successfully sent or
    received a message, and whether it is still considered usable
    (``is_active``). Once ``is_active`` is False, :meth:`send_json` is a no-op.
    """

    def __init__(self, websocket: WebSocket, session_id: str):
        self.websocket = websocket
        self.session_id = session_id
        # Naive UTC timestamps (datetime.utcnow), kept as-is for existing readers.
        self.connected_at = datetime.utcnow()
        self.last_activity = datetime.utcnow()
        # Set to False on close() or on any send/receive failure.
        self.is_active = True

    async def send_json(self, data: dict):
        """Send ``data`` to the client as JSON.

        Does nothing if the connection is no longer active.

        Raises:
            Exception: whatever ``websocket.send_json`` raises; the connection
                is marked inactive and the error logged before re-raising.
        """
        try:
            if self.is_active:
                await self.websocket.send_json(data)
                self.last_activity = datetime.utcnow()
        except Exception as e:
            log_error(
                f"❌ Failed to send message",
                session_id=self.session_id,
                error=str(e)
            )
            self.is_active = False
            raise

    async def receive_json(self) -> dict:
        """Wait for the next message from the client and decode it as JSON.

        Returns:
            The decoded payload. NOTE(review): the client controls the payload,
            so it is not guaranteed to be a dict despite the annotation.

        Raises:
            WebSocketDisconnect: the client disconnected (connection marked
                inactive, not logged as an error).
            Exception: any other receive/decode failure (connection marked
                inactive and the error logged before re-raising).
        """
        try:
            data = await self.websocket.receive_json()
            self.last_activity = datetime.utcnow()
            return data
        except WebSocketDisconnect:
            self.is_active = False
            raise
        except Exception as e:
            log_error(
                f"❌ Failed to receive message",
                session_id=self.session_id,
                error=str(e)
            )
            self.is_active = False
            raise

    async def close(self):
        """Close the underlying socket on a best-effort basis.

        The connection is marked inactive first so senders stop using it.
        Errors raised while closing (e.g. the socket is already closed) are
        logged at debug level and ignored. Cancellation and other
        ``BaseException``s are deliberately NOT swallowed (the previous bare
        ``except:`` did swallow them).
        """
        self.is_active = False
        try:
            await self.websocket.close()
        except Exception as e:
            log_debug(
                "Ignoring error while closing WebSocket",
                session_id=self.session_id,
                error=str(e)
            )
class WebSocketManager:
    """Manage client WebSocket connections and route messages to/from the event bus.

    Inbound: JSON messages from a client are translated into ``EventBus``
    events by :meth:`_route_client_message`.

    Outbound: the ``EventBus`` events subscribed in
    :meth:`_setup_event_handlers` are converted to JSON messages and put on a
    per-session ``asyncio.Queue``; :meth:`_send_messages` drains that queue
    onto the socket.

    At most one connection is kept per ``session_id``: a new connection for
    the same session closes and replaces the old one.
    """

    # Number of audio chunks per session that get detailed debug analysis.
    AUDIO_DEBUG_CHUNK_LIMIT = 5
    # Peak |sample| below this is reported as "Silent" in the debug stats.
    AUDIO_SILENCE_THRESHOLD = 100
    # Peak |sample| below this additionally triggers a warning.
    AUDIO_VERY_LOW_THRESHOLD = 50
    # Seconds with no queued outbound message before a keep-alive ping is sent.
    KEEPALIVE_INTERVAL = 30.0

    def __init__(self, event_bus: EventBus):
        self.event_bus = event_bus
        # session_id -> the single registered connection for that session
        self.connections: Dict[str, WebSocketConnection] = {}
        # session_id -> outbound queue drained by _send_messages
        self.message_queues: Dict[str, asyncio.Queue] = {}
        # session_id -> number of audio chunks already analysed for debug logs;
        # entries are removed in disconnect() so this does not grow forever.
        self.audio_debug_counters: Dict[str, int] = {}
        self._setup_event_handlers()

    def _setup_event_handlers(self):
        """Subscribe to the bus events that must be forwarded to clients."""
        # State events
        self.event_bus.subscribe(EventType.STATE_TRANSITION, self._handle_state_transition)
        # STT events
        self.event_bus.subscribe(EventType.STT_READY, self._handle_stt_ready)
        self.event_bus.subscribe(EventType.STT_RESULT, self._handle_stt_result)
        # TTS events
        self.event_bus.subscribe(EventType.TTS_STARTED, self._handle_tts_started)
        self.event_bus.subscribe(EventType.TTS_CHUNK_READY, self._handle_tts_chunk)
        self.event_bus.subscribe(EventType.TTS_COMPLETED, self._handle_tts_completed)
        # LLM events
        self.event_bus.subscribe(EventType.LLM_RESPONSE_READY, self._handle_llm_response)
        # Error events
        self.event_bus.subscribe(EventType.RECOVERABLE_ERROR, self._handle_error)
        self.event_bus.subscribe(EventType.CRITICAL_ERROR, self._handle_error)

    async def connect(self, websocket: WebSocket, session_id: str):
        """Accept ``websocket`` and register it as the connection for ``session_id``.

        Any existing connection for the same session is disconnected first.
        Creates the session's outbound queue and publishes
        ``WEBSOCKET_CONNECTED``.
        """
        await websocket.accept()

        # Only one live connection per session: replace the old one.
        if session_id in self.connections:
            log_warning(
                f"⚠️ Existing connection for session, closing old one",
                session_id=session_id
            )
            await self.disconnect(session_id)

        connection = WebSocketConnection(websocket, session_id)
        self.connections[session_id] = connection
        self.message_queues[session_id] = asyncio.Queue()

        log_info(
            f"✅ WebSocket connected",
            session_id=session_id,
            total_connections=len(self.connections)
        )

        await self.event_bus.publish(Event(
            type=EventType.WEBSOCKET_CONNECTED,
            session_id=session_id,
            data={}
        ))

    async def disconnect(self, session_id: str):
        """Close and unregister the connection for ``session_id``.

        Publishes ``WEBSOCKET_DISCONNECTED`` only when a connection was
        actually removed, so repeated or concurrent calls for the same
        session are harmless no-ops.
        """
        connection = self.connections.pop(session_id, None)
        if connection is None:
            return

        # State is unregistered *before* awaiting close(): a concurrent
        # disconnect() for this session now finds nothing to do. Previously
        # both callers could reach `del self.connections[...]` and the second
        # raised KeyError.
        self.message_queues.pop(session_id, None)
        self.audio_debug_counters.pop(session_id, None)

        await connection.close()

        log_info(
            f"🔌 WebSocket disconnected",
            session_id=session_id,
            total_connections=len(self.connections)
        )

        await self.event_bus.publish(Event(
            type=EventType.WEBSOCKET_DISCONNECTED,
            session_id=session_id,
            data={}
        ))

    async def handle_connection(self, websocket: WebSocket, session_id: str):
        """Run the full lifecycle of one client WebSocket.

        Accepts and registers the socket, then runs the receive loop and the
        send loop concurrently until either finishes. A client disconnect is
        only logged. Any other error from connecting or from either loop is
        logged and published as ``WEBSOCKET_ERROR``. On exit the connection is
        unregistered, unless a newer connection for the same session has
        already replaced it.
        """
        tasks = []
        try:
            await self.connect(websocket, session_id)

            tasks = [
                asyncio.create_task(self._receive_messages(session_id)),
                asyncio.create_task(self._send_messages(session_id)),
            ]

            # Stop as soon as either direction ends (disconnect or error).
            done, _pending = await asyncio.wait(
                tasks,
                return_when=asyncio.FIRST_COMPLETED
            )
            await self._cancel_tasks(tasks)

            # asyncio.wait() never re-raises task exceptions. Retrieve them all
            # (so asyncio does not warn about unretrieved exceptions) and raise
            # the first one, so the handlers below actually see it.
            errors = [
                task.exception()
                for task in done
                if not task.cancelled() and task.exception() is not None
            ]
            if errors:
                raise errors[0]

        except WebSocketDisconnect:
            log_info(f"WebSocket disconnected normally", session_id=session_id)
        except Exception as e:
            log_error(
                f"❌ WebSocket error",
                session_id=session_id,
                error=str(e),
                traceback=traceback.format_exc()
            )
            await self.event_bus.publish(Event(
                type=EventType.WEBSOCKET_ERROR,
                session_id=session_id,
                data={
                    "error_type": "websocket_error",
                    "message": str(e)
                }
            ))
        finally:
            # This is also reached when this coroutine itself is cancelled; make
            # sure the child tasks do not outlive it.
            await self._cancel_tasks(tasks)
            # Only tear down the connection this handler owns. If a newer
            # connection replaced ours, connect() already closed ours, and
            # disconnecting by session_id alone would kill the new one.
            registered = self.connections.get(session_id)
            if registered is not None and registered.websocket is websocket:
                await self.disconnect(session_id)

    @staticmethod
    async def _cancel_tasks(tasks):
        """Cancel every unfinished task in ``tasks`` and wait for them to exit.

        ``return_exceptions=True`` collects the resulting CancelledError (or a
        late error from a task that finished concurrently) instead of raising
        it, so cleanup always runs to completion.
        """
        pending = [task for task in tasks if not task.done()]
        for task in pending:
            task.cancel()
        if pending:
            await asyncio.gather(*pending, return_exceptions=True)

    async def _receive_messages(self, session_id: str):
        """Read client messages and route them until the connection closes.

        Returns normally when the client disconnects. Messages that are valid
        JSON but not an object are logged and skipped rather than crashing
        the loop. Any other error is logged and re-raised.
        """
        connection = self.connections.get(session_id)
        if not connection:
            return

        try:
            while connection.is_active:
                message = await connection.receive_json()

                if not isinstance(message, dict):
                    log_warning(
                        "⚠️ Ignoring non-object message",
                        session_id=session_id,
                        payload_type=type(message).__name__
                    )
                    continue

                log_debug(
                    f"📨 Received message",
                    session_id=session_id,
                    message_type=message.get("type")
                )

                await self._route_client_message(session_id, message)

        except WebSocketDisconnect:
            log_info(f"Client disconnected", session_id=session_id)
        except Exception as e:
            log_error(
                f"❌ Error receiving messages",
                session_id=session_id,
                error=str(e)
            )
            raise

    async def _send_messages(self, session_id: str):
        """Drain the session's outbound queue onto the socket.

        Sends ``{"type": "ping"}`` whenever nothing has been queued for
        ``KEEPALIVE_INTERVAL`` seconds. Runs while the connection is active;
        send errors are logged and re-raised.
        """
        connection = self.connections.get(session_id)
        queue = self.message_queues.get(session_id)
        if connection is None or queue is None:
            return

        try:
            while connection.is_active:
                try:
                    message = await asyncio.wait_for(
                        queue.get(), timeout=self.KEEPALIVE_INTERVAL
                    )
                    await connection.send_json(message)
                    log_debug(
                        f"📤 Sent message",
                        session_id=session_id,
                        message_type=message.get("type")
                    )
                except asyncio.TimeoutError:
                    # Idle: keep the connection alive.
                    await connection.send_json({"type": "ping"})

        except Exception as e:
            log_error(
                f"❌ Error sending messages",
                session_id=session_id,
                error=str(e)
            )
            raise

    async def _route_client_message(self, session_id: str, message: dict):
        """Translate one client message into the matching bus event or reply.

        Handles ``audio_chunk``, ``control`` and ``ping`` messages; anything
        else is logged as unknown.
        """
        message_type = message.get("type")

        if message_type == "audio_chunk":
            audio_data_base64 = message.get("data")
            if not audio_data_base64:
                # Forwarding audio_data=None would just push the failure downstream.
                log_warning(
                    "⚠️ Empty audio_chunk message ignored",
                    session_id=session_id
                )
                return

            self._log_audio_chunk_debug(session_id, audio_data_base64)

            await self.event_bus.publish(Event(
                type=EventType.AUDIO_CHUNK_RECEIVED,
                session_id=session_id,
                data={
                    "audio_data": audio_data_base64,
                    "timestamp": message.get("timestamp")
                }
            ))

        elif message_type == "control":
            await self._handle_control_message(session_id, message)

        elif message_type == "ping":
            await self.send_message(session_id, {"type": "pong"})

        else:
            log_warning(
                f"⚠️ Unknown message type",
                session_id=session_id,
                message_type=message_type
            )

    def _log_audio_chunk_debug(self, session_id: str, audio_data_base64: str):
        """Log level diagnostics for the first few audio chunks of a session.

        Decodes the base64 payload and interprets it as little-endian signed
        16-bit samples (linear16) to report peak/average amplitude and flag
        near-silent input. Only the first ``AUDIO_DEBUG_CHUNK_LIMIT`` chunks
        per session are analysed; later chunks are not even decoded. Purely
        diagnostic: any failure is logged and swallowed.
        """
        try:
            chunk_index = self.audio_debug_counters.get(session_id, 0)
            if chunk_index >= self.AUDIO_DEBUG_CHUNK_LIMIT:
                return
            # Counted up front so a payload that fails analysis still counts
            # toward the limit.
            self.audio_debug_counters[session_id] = chunk_index + 1

            audio_data = base64.b64decode(audio_data_base64)

            log_info(f"🔊 Audio chunk analysis #{chunk_index}",
                     session_id=session_id,
                     size_bytes=len(audio_data),
                     base64_size=len(audio_data_base64))

            # Show the first 20 bytes, raw and as ten int16 samples.
            if len(audio_data) >= 20:
                log_debug(f" First 20 bytes (hex): {audio_data[:20].hex()}")
                samples = struct.unpack('<10h', audio_data[:20])
                log_debug(f" First 10 samples: {samples}")
                log_debug(f" Max amplitude (first 10): {max(abs(s) for s in samples)}")

            # Whole-chunk stats; a trailing odd byte is ignored.
            total_samples = len(audio_data) // 2
            if total_samples > 0:
                all_samples = struct.unpack(f'<{total_samples}h', audio_data[:total_samples * 2])
                max_amp = max(abs(s) for s in all_samples)
                avg_amp = sum(abs(s) for s in all_samples) / total_samples
                silent = max_amp < self.AUDIO_SILENCE_THRESHOLD
                log_info(f" Audio stats - Max: {max_amp}, Avg: {avg_amp:.1f}, Silent: {silent}")
                if max_amp < self.AUDIO_VERY_LOW_THRESHOLD:
                    log_warning(f"⚠️ Very low audio level detected! Max amplitude: {max_amp}")
        except Exception as e:
            log_error(f"Error analyzing audio chunk: {e}", session_id=session_id)

    async def _handle_control_message(self, session_id: str, message: dict):
        """Publish the bus event that corresponds to a client ``control`` action."""
        action = message.get("action")
        config = message.get("config", {})

        if action == "start_conversation":
            # Start a conversation on the already-connected session.
            log_info(f"🎤 Starting conversation for session | session_id={session_id}")
            await self.event_bus.publish(Event(
                type=EventType.CONVERSATION_STARTED,
                session_id=session_id,
                data={
                    "config": config,
                    "continuous_listening": config.get("continuous_listening", True)
                }
            ))
            # Confirm to the client.
            await self.send_message(session_id, {
                "type": "conversation_started",
                "message": "Conversation started successfully"
            })

        elif action == "stop_conversation":
            await self.event_bus.publish(Event(
                type=EventType.CONVERSATION_ENDED,
                session_id=session_id,
                data={"reason": "user_request"}
            ))

        elif action == "start_session":
            # Deprecated: still handled, as CONVERSATION_STARTED. NOTE: data is
            # the raw config here, unlike the wrapped form used by
            # start_conversation; kept as-is for legacy clients.
            log_warning(f"⚠️ Deprecated start_session action received | session_id={session_id}")
            await self.event_bus.publish(Event(
                type=EventType.CONVERSATION_STARTED,
                session_id=session_id,
                data=config
            ))

        elif action == "stop_session":
            await self.event_bus.publish(Event(
                type=EventType.CONVERSATION_ENDED,
                session_id=session_id,
                data={"reason": "user_request"}
            ))

        elif action == "end_session":
            await self.event_bus.publish(Event(
                type=EventType.SESSION_ENDED,
                session_id=session_id,
                data={"reason": "user_request"}
            ))

        elif action == "audio_ended":
            await self.event_bus.publish(Event(
                type=EventType.AUDIO_PLAYBACK_COMPLETED,
                session_id=session_id,
                data={}
            ))

        else:
            log_warning(
                f"⚠️ Unknown control action",
                session_id=session_id,
                action=action
            )

    async def send_message(self, session_id: str, message: dict):
        """Queue ``message`` for delivery by the session's send loop.

        Logs a warning and drops the message if the session has no queue
        (i.e. no registered connection).
        """
        queue = self.message_queues.get(session_id)
        if queue:
            await queue.put(message)
        else:
            log_warning(
                f"⚠️ No queue for session",
                session_id=session_id
            )

    async def broadcast_to_session(self, session_id: str, message: dict):
        """Send ``message`` to one session immediately, bypassing the queue.

        Despite the name this targets a single session. Because it skips the
        queue, the message may overtake messages that are already queued.
        """
        connection = self.connections.get(session_id)
        if connection and connection.is_active:
            await connection.send_json(message)

    # Event handlers for sending messages to clients

    async def _handle_state_transition(self, event: Event):
        """Forward a state transition to the client as ``state_change``."""
        await self.send_message(event.session_id, {
            "type": "state_change",
            "from": event.data.get("old_state"),
            "to": event.data.get("new_state")
        })

    async def _handle_stt_ready(self, event: Event):
        """Tell the client that STT is ready to receive audio."""
        await self.send_message(event.session_id, {
            "type": "stt_ready",
            "message": "STT is ready to receive audio"
        })

    async def _handle_stt_result(self, event: Event):
        """Forward every STT result (interim and final) as ``transcription``."""
        await self.send_message(event.session_id, {
            "type": "transcription",
            "text": event.data.get("text", ""),
            "is_final": event.data.get("is_final", False),
            "confidence": event.data.get("confidence", 0.0)
        })

    async def _handle_tts_started(self, event: Event):
        """Send the assistant text when TTS starts, for welcome messages only."""
        if event.data.get("is_welcome"):
            await self.send_message(event.session_id, {
                "type": "assistant_response",
                "text": event.data.get("text", ""),
                "is_welcome": True
            })

    async def _handle_tts_chunk(self, event: Event):
        """Forward one TTS audio chunk to the client as ``tts_audio``."""
        await self.send_message(event.session_id, {
            "type": "tts_audio",
            "data": event.data.get("audio_data"),
            "chunk_index": event.data.get("chunk_index"),
            "total_chunks": event.data.get("total_chunks"),
            "is_last": event.data.get("is_last", False),
            "mime_type": event.data.get("mime_type", "audio/mpeg")
        })

    async def _handle_tts_completed(self, event: Event):
        """No-op: the client detects TTS completion from the chunks' ``is_last`` flag."""
        pass

    async def _handle_llm_response(self, event: Event):
        """Forward the LLM response text to the client as ``assistant_response``."""
        await self.send_message(event.session_id, {
            "type": "assistant_response",
            "text": event.data.get("text", ""),
            "is_welcome": event.data.get("is_welcome", False)
        })

    async def _handle_error(self, event: Event):
        """Forward a recoverable or critical error event to the client."""
        error_type = event.data.get("error_type", "unknown")
        message = event.data.get("message", "An error occurred")
        await self.send_message(event.session_id, {
            "type": "error",
            "error_type": error_type,
            "message": message,
            "details": event.data.get("details", {})
        })

    def get_connection_count(self) -> int:
        """Return the number of registered connections."""
        return len(self.connections)

    def get_session_connections(self) -> Set[str]:
        """Return the set of session IDs that have a registered connection."""
        return set(self.connections.keys())

    async def close_all_connections(self):
        """Disconnect every registered session.

        A failure while disconnecting one session is logged and does not stop
        the remaining sessions from being closed.
        """
        for session_id in list(self.connections):
            try:
                await self.disconnect(session_id)
            except Exception as e:
                log_error(
                    "❌ Failed to close connection",
                    session_id=session_id,
                    error=str(e)
                )