Spaces:
Paused
Paused
| """ | |
| Event Bus Implementation for Flare | |
| ================================== | |
| Provides async event publishing and subscription mechanism | |
| """ | |
| import asyncio | |
| from typing import Dict, List, Callable, Any, Optional | |
| from enum import Enum | |
| from dataclasses import dataclass | |
| from datetime import datetime | |
| import traceback | |
| from collections import defaultdict | |
| import sys | |
| from utils.logger import log_info, log_error, log_debug, log_warning | |
class EventType(Enum):
    """All event types in the system.

    Each member's string value is what ``Event.to_dict()`` emits under the
    ``"type"`` key, while EventBus subscriptions are keyed by the member
    itself.
    """
    # Lifecycle events
    SESSION_STARTED = "session_started"
    SESSION_ENDED = "session_ended"
    # STT events
    STT_STARTED = "stt_started"
    STT_STOPPED = "stt_stopped"
    STT_RESULT = "stt_result"
    STT_ERROR = "stt_error"
    STT_READY = "stt_ready"
    # TTS events
    TTS_STARTED = "tts_started"
    TTS_CHUNK_READY = "tts_chunk_ready"
    TTS_COMPLETED = "tts_completed"
    TTS_ERROR = "tts_error"
    # Audio events
    AUDIO_PLAYBACK_STARTED = "audio_playback_started"
    AUDIO_PLAYBACK_COMPLETED = "audio_playback_completed"
    AUDIO_BUFFER_LOW = "audio_buffer_low"
    AUDIO_CHUNK_RECEIVED = "audio_chunk_received"
    # LLM events
    LLM_PROCESSING_STARTED = "llm_processing_started"
    LLM_RESPONSE_READY = "llm_response_ready"
    LLM_ERROR = "llm_error"
    # Error events (publish_error picks CRITICAL_ERROR only for error_type == "critical")
    CRITICAL_ERROR = "critical_error"
    RECOVERABLE_ERROR = "recoverable_error"
    # State events
    STATE_TRANSITION = "state_transition"
    STATE_ROLLBACK = "state_rollback"
    # WebSocket events
    WEBSOCKET_CONNECTED = "websocket_connected"
    WEBSOCKET_DISCONNECTED = "websocket_disconnected"
    WEBSOCKET_MESSAGE = "websocket_message"
    WEBSOCKET_ERROR = "websocket_error"
@dataclass
class Event:
    """A single message carried by the EventBus.

    Must be a dataclass: callers build events with keyword arguments
    (``Event(type=..., session_id=..., data=..., priority=...)``), and
    ``__post_init__`` is only invoked by the dataclass-generated ``__init__``.
    """

    # Kind of event; serialized through its ``.value`` in to_dict().
    type: "EventType"
    # Owning session; EventBus.publish routes events with a falsy id to its global queue.
    session_id: str
    # Event payload; included as-is (not copied) by to_dict().
    data: Dict[str, Any]
    # Creation time; filled in by __post_init__ when omitted or None.
    timestamp: Optional[datetime] = None
    priority: int = 0  # Higher priority = processed first

    def __post_init__(self):
        # Stamp the event at construction time when the caller did not supply one.
        # NOTE(review): datetime.utcnow() returns a naive datetime and is deprecated
        # since Python 3.12; kept so the isoformat() output of to_dict() is unchanged.
        if self.timestamp is None:
            self.timestamp = datetime.utcnow()

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary for serialization.

        Returns:
            A dict with the enum's string value under ``"type"``, the
            timestamp rendered via ``isoformat()``, and ``session_id``,
            ``data`` and ``priority`` passed through unchanged.
        """
        return {
            "type": self.type.value,
            "session_id": self.session_id,
            "data": self.data,
            "timestamp": self.timestamp.isoformat(),
            "priority": self.priority,
        }
class EventBus:
    """Central event bus for component communication with session isolation.

    Events with a ``session_id`` go to a per-session priority queue drained by
    its own processor task, so sessions are handled concurrently. Events
    without one go to a single global queue. Within a queue, a higher
    ``Event.priority`` is dispatched first.

    Queue entries are ``(-priority, sequence, event)`` tuples. The strictly
    increasing ``sequence`` breaks priority ties, so the heap never has to
    compare two ``Event`` objects (which define no ordering). It also keeps
    equal-priority events in publish order. An entry whose event is ``None``
    is a stop sentinel.
    """

    def __init__(self):
        self._subscribers: Dict[EventType, List[Callable]] = defaultdict(list)
        self._session_handlers: Dict[str, Dict[EventType, List[Callable]]] = defaultdict(lambda: defaultdict(list))

        # Session-specific queues for parallel processing
        self._session_queues: Dict[str, asyncio.PriorityQueue] = {}
        self._session_processors: Dict[str, asyncio.Task] = {}

        # Global queue for non-session events (recreated on every start())
        self._global_queue: asyncio.PriorityQueue = asyncio.PriorityQueue()
        self._global_processor: Optional[asyncio.Task] = None

        self._running = False
        # Tie-breaker counter for queue entries; see class docstring.
        self._sequence = 0

        self._event_history: List[Event] = []
        self._max_history_size = 1000

    def _next_sequence(self) -> int:
        """Return a fresh, strictly increasing tie-breaker for a queue entry."""
        self._sequence += 1
        return self._sequence

    async def start(self):
        """Start the global event processor.

        A fresh global queue is created on every start. A stop sentinel or
        undispatched events left over from a previous run therefore cannot
        terminate or confuse the new processor.
        """
        if self._running:
            log_warning("EventBus already running")
            return

        self._running = True
        self._global_queue = asyncio.PriorityQueue()
        self._global_processor = asyncio.create_task(self._process_global_events())

        log_info("β EventBus started")

    async def stop(self):
        """Stop all session processors and the global processor.

        Session processors are cancelled, waiting up to 2 s for each. The
        global processor is woken with a sentinel and given up to 5 s before
        it is cancelled. Events still waiting in a queue are not guaranteed to
        be dispatched.
        """
        self._running = False

        # Stop all session processors
        for task in list(self._session_processors.values()):
            task.cancel()
            try:
                await asyncio.wait_for(task, timeout=2.0)
            except (asyncio.TimeoutError, asyncio.CancelledError):
                pass
        # A task cancelled before it ever ran never executes its own cleanup.
        # Drop all bookkeeping so a later start() recreates processors instead
        # of routing events into queues nobody drains.
        self._session_queues.clear()
        self._session_processors.clear()

        # Stop global processor
        if self._global_processor:
            # Sentinel with the lowest priority (largest sort key).
            await self._global_queue.put((999, self._next_sequence(), None))
            try:
                await asyncio.wait_for(self._global_processor, timeout=5.0)
            except asyncio.TimeoutError:
                log_warning("EventBus global processor timeout, cancelling")
                self._global_processor.cancel()
            self._global_processor = None

        log_info("β EventBus stopped")

    async def publish(self, event: Event):
        """Publish an event to the bus.

        The event is recorded in the bounded history and then queued on its
        session's queue (created on demand) or on the global queue when
        ``event.session_id`` is falsy. Publishing while the bus is not
        running logs an error and drops the event.
        """
        if not self._running:
            log_error("EventBus not running, cannot publish event", event_type=event.type.value)
            return

        # Add to history, keeping only the most recent _max_history_size events
        self._event_history.append(event)
        if len(self._event_history) > self._max_history_size:
            self._event_history.pop(0)

        entry = (-event.priority, self._next_sequence(), event)

        # Route to appropriate queue
        if event.session_id:
            if event.session_id not in self._session_queues:
                await self._create_session_processor(event.session_id)
            await self._session_queues[event.session_id].put(entry)
        else:
            await self._global_queue.put(entry)

        log_debug(
            f"π€ Event published",
            event_type=event.type.value,
            session_id=event.session_id,
            priority=event.priority
        )

    async def _create_session_processor(self, session_id: str):
        """Create the queue and processor task for a session, if not present."""
        if session_id in self._session_queues:
            return

        queue: asyncio.PriorityQueue = asyncio.PriorityQueue()
        # The queue is handed to the task so it never has to look itself up.
        task = asyncio.create_task(self._process_session_events(session_id, queue))
        self._session_queues[session_id] = queue
        self._session_processors[session_id] = task

        log_debug(f"π Created session processor", session_id=session_id)

    async def _process_session_events(self, session_id: str, queue: Optional[asyncio.PriorityQueue] = None):
        """Dispatch events for one session until it should stop.

        The loop exits when the bus stops, when a ``None`` sentinel is
        received, or after 60 s with no events if the session has no
        session-specific subscriptions left.

        Args:
            session_id: Session whose events this task dispatches.
            queue: The session's queue; looked up by ``session_id`` when omitted.
        """
        if queue is None:
            queue = self._session_queues[session_id]
        log_info(f"π Session event processor started", session_id=session_id)

        try:
            while self._running:
                try:
                    # Wait for event with timeout
                    _, _, event = await asyncio.wait_for(
                        queue.get(),
                        timeout=60.0  # Longer timeout for sessions
                    )

                    # Check for session cleanup
                    if event is None:
                        break

                    await self._dispatch_event(event)

                except asyncio.TimeoutError:
                    # Check if session is still active
                    if session_id not in self._session_handlers:
                        log_info(f"Session inactive, stopping processor", session_id=session_id)
                        break
                    continue
                except Exception as e:
                    log_error(
                        f"β Error processing session event",
                        session_id=session_id,
                        error=str(e),
                        traceback=traceback.format_exc()
                    )
        finally:
            # Runs on cancellation too. Only remove entries that still belong to
            # this processor: clear_session_data() or a newer processor for the
            # same session id may already have replaced them.
            if self._session_queues.get(session_id) is queue:
                self._session_queues.pop(session_id, None)
                self._session_processors.pop(session_id, None)

        log_info(f"π Session event processor stopped", session_id=session_id)

    async def _process_global_events(self):
        """Dispatch global (no session_id) events until stopped or a sentinel arrives."""
        # Bind the queue this processor was started for; start() replaces it.
        queue = self._global_queue
        log_info("π Global event processor started")

        while self._running:
            try:
                # Short timeout so the loop re-checks self._running regularly.
                _, _, event = await asyncio.wait_for(
                    queue.get(),
                    timeout=1.0
                )

                if event is None:  # Sentinel
                    break

                await self._dispatch_event(event)

            except asyncio.TimeoutError:
                continue
            except Exception as e:
                log_error(
                    "β Error processing global event",
                    error=str(e),
                    traceback=traceback.format_exc()
                )

        log_info("π Global event processor stopped")

    def subscribe(self, event_type: EventType, handler: Callable):
        """Subscribe ``handler`` to ``event_type`` for events of every session."""
        self._subscribers[event_type].append(handler)
        log_debug(f"π Global subscription added", event_type=event_type.value)

    def subscribe_session(self, session_id: str, event_type: EventType, handler: Callable):
        """Subscribe ``handler`` to ``event_type`` for events of ``session_id`` only."""
        self._session_handlers[session_id][event_type].append(handler)
        log_debug(
            f"π Session subscription added",
            event_type=event_type.value,
            session_id=session_id
        )

    def unsubscribe(self, event_type: EventType, handler: Callable):
        """Remove one global subscription; a handler that is not subscribed is ignored."""
        if handler in self._subscribers[event_type]:
            self._subscribers[event_type].remove(handler)
            log_debug(f"π Global subscription removed", event_type=event_type.value)

    def unsubscribe_session(self, session_id: str, event_type: EventType = None):
        """Remove a session's handlers for ``event_type``, or all of them when omitted."""
        if event_type:
            # Remove specific event type for session
            if session_id in self._session_handlers and event_type in self._session_handlers[session_id]:
                del self._session_handlers[session_id][event_type]
        else:
            # Remove all handlers for session
            if session_id in self._session_handlers:
                del self._session_handlers[session_id]
                log_debug(f"π All session subscriptions removed", session_id=session_id)

    @staticmethod
    def _handler_name(handler: Callable) -> str:
        """Best-effort printable name for a handler (partials/callable objects lack __name__)."""
        return getattr(handler, "__name__", None) or repr(handler)

    @staticmethod
    async def _invoke_handler(handler: Callable, event: Event) -> Any:
        """Run one handler for ``event``.

        Coroutine functions are awaited directly. Other callables run in a
        worker thread via ``asyncio.to_thread``. If such a callable still
        returns a coroutine (e.g. an object with ``async def __call__``, which
        ``asyncio.iscoroutinefunction`` does not report as a coroutine
        function), that coroutine is awaited rather than silently dropped.
        """
        if asyncio.iscoroutinefunction(handler):
            return await handler(event)
        result = await asyncio.to_thread(handler, event)
        if asyncio.iscoroutine(result):
            return await result
        return result

    async def _dispatch_event(self, event: Event):
        """Run every global and session-specific handler for ``event`` concurrently.

        Handler exceptions are logged individually and never propagate.
        """
        try:
            # Global handlers first, then those registered for the event's session.
            # .get() avoids inserting keys into the defaultdicts.
            handlers = list(self._subscribers.get(event.type, ()))
            session_handlers = self._session_handlers.get(event.session_id)
            if session_handlers:
                handlers.extend(session_handlers.get(event.type, ()))

            if not handlers:
                log_debug(
                    f"π No handlers for event",
                    event_type=event.type.value,
                    session_id=event.session_id
                )
                return

            log_debug(
                f"π¨ Dispatching event to {len(handlers)} handlers",
                event_type=event.type.value,
                session_id=event.session_id
            )

            tasks = [asyncio.create_task(self._invoke_handler(handler, event)) for handler in handlers]
            results = await asyncio.gather(*tasks, return_exceptions=True)

            # Log any exceptions
            for handler, result in zip(handlers, results):
                if isinstance(result, Exception):
                    log_error(
                        f"β Handler error",
                        handler=self._handler_name(handler),
                        event_type=event.type.value,
                        error=str(result),
                        traceback=traceback.format_exception(type(result), result, result.__traceback__)
                    )

        except Exception as e:
            log_error(
                f"β Error dispatching event",
                event_type=event.type.value,
                error=str(e),
                traceback=traceback.format_exc()
            )

    def get_event_history(self, session_id: Optional[str] = None, event_type: Optional[EventType] = None) -> List[Event]:
        """Return recorded events, oldest first, optionally filtered.

        At most ``_max_history_size`` (1000 by default) of the most recent
        events are kept. The result is always a new list, so callers cannot
        mutate the bus's internal history through it.
        """
        return [
            e for e in self._event_history
            if (not session_id or e.session_id == session_id)
            and (not event_type or e.type == event_type)
        ]

    def clear_session_data(self, session_id: str):
        """Drop a session's handlers, processor, queue and history entries.

        The processor task is cancelled but not awaited.
        """
        # Remove session handlers
        self.unsubscribe_session(session_id)

        # Stop session processor
        task = self._session_processors.get(session_id)
        if task is not None:
            task.cancel()

        # Clear queues
        self._session_queues.pop(session_id, None)
        self._session_processors.pop(session_id, None)

        # Remove session events from history
        self._event_history = [e for e in self._event_history if e.session_id != session_id]

        log_debug(f"π§Ή Session data cleared", session_id=session_id)
# Global event bus instance, shared by the publish_* helpers in this module.
# NOTE(review): created at import time. EventBus.__init__ builds an asyncio.PriorityQueue,
# which on Python < 3.10 binds to the event loop current at import; confirm the target runtime.
event_bus: EventBus = EventBus()
| # Helper functions for common event publishing patterns | |
async def publish_error(session_id: str, error_type: str, error_message: str, details: Optional[Dict[str, Any]] = None):
    """Publish an error event on the shared ``event_bus``.

    Args:
        session_id: Session the error belongs to (stored as ``Event.session_id``).
        error_type: Error category. Exactly ``"critical"`` yields a
            ``CRITICAL_ERROR`` event; any other value yields ``RECOVERABLE_ERROR``.
        error_message: Description stored under ``data["message"]``.
        details: Extra context stored under ``data["details"]``; an empty dict
            is stored when omitted.
    """
    event_type = EventType.CRITICAL_ERROR if error_type == "critical" else EventType.RECOVERABLE_ERROR
    event = Event(
        type=event_type,
        session_id=session_id,
        data={
            "error_type": error_type,
            "message": error_message,
            "details": details or {}
        },
        priority=10  # High priority for errors: dispatched ahead of lower-priority queued events
    )
    await event_bus.publish(event)
async def publish_state_transition(session_id: str, from_state: str, to_state: str, reason: Optional[str] = None):
    """Publish a ``STATE_TRANSITION`` event on the shared ``event_bus``.

    Args:
        session_id: Session whose state changed (stored as ``Event.session_id``).
        from_state: Previous state, stored under ``data["from_state"]``.
        to_state: New state, stored under ``data["to_state"]``.
        reason: Optional explanation, stored under ``data["reason"]`` (may be None).
    """
    payload = {
        "from_state": from_state,
        "to_state": to_state,
        "reason": reason,
    }
    event = Event(
        type=EventType.STATE_TRANSITION,
        session_id=session_id,
        data=payload,
        priority=5  # Medium priority for state changes (errors use 10, default is 0)
    )
    await event_bus.publish(event)