Spaces:
Paused
Paused
| """ | |
| State Orchestrator for Flare Realtime Chat | |
| ========================================== | |
| Central state machine and flow control | |
| """ | |
| import asyncio | |
| from typing import Dict, Optional, Set, Any | |
| from enum import Enum | |
| from datetime import datetime | |
| import traceback | |
| from dataclasses import dataclass, field | |
| from event_bus import EventBus, Event, EventType, publish_state_transition, publish_error | |
| from session import Session | |
| from utils.logger import log_info, log_error, log_debug, log_warning | |
class ConversationState(Enum):
    """Lifecycle states of a single realtime conversation.

    Each member's value is the lowercase string form of its name; the
    orchestrator puts that string into the state-transition events it
    publishes.
    """

    # Session exists but no conversation is running
    IDLE = "idle"
    INITIALIZING = "initializing"

    # Welcome message: synthesis, then playback
    PREPARING_WELCOME = "preparing_welcome"
    PLAYING_WELCOME = "playing_welcome"

    # Main loop: listen -> process -> synthesize -> play
    LISTENING = "listening"
    PROCESSING_SPEECH = "processing_speech"
    PREPARING_RESPONSE = "preparing_response"
    PLAYING_RESPONSE = "playing_response"

    # Failure and terminal states
    ERROR = "error"
    ENDED = "ended"
@dataclass
class SessionContext:
    """Context for a conversation session.

    Declared as a dataclass: the fields below rely on ``field(...)`` defaults
    and the orchestrator constructs instances with keyword arguments, both of
    which require the generated ``__init__``.
    """

    # Required identity of the session
    session_id: str
    session: Session

    # Current position in the conversation state machine
    state: ConversationState = ConversationState.IDLE

    # Per-session resource handles; all start unset
    stt_instance: Optional[Any] = None
    tts_instance: Optional[Any] = None
    llm_context: Optional[Any] = None
    audio_buffer: Optional[Any] = None
    websocket_connection: Optional[Any] = None

    # Naive UTC timestamps (datetime.utcnow)
    created_at: datetime = field(default_factory=datetime.utcnow)
    last_activity: datetime = field(default_factory=datetime.utcnow)

    # Free-form per-session data; a fresh dict per instance
    metadata: Dict[str, Any] = field(default_factory=dict)

    def update_activity(self):
        """Update last activity timestamp to the current UTC time."""
        self.last_activity = datetime.utcnow()

    async def cleanup(self):
        """Cleanup all session resources.

        Currently only logs; the actual resource release is expected to be
        implemented by the resource managers.
        """
        # Cleanup will be implemented by resource managers
        log_debug(f"🧹 Cleaning up session context", session_id=self.session_id)
class StateOrchestrator:
    """Central state machine for conversation flow.

    Keeps one ``SessionContext`` per session id, subscribes its handlers to
    the event bus on construction, and reacts to session / conversation /
    STT / TTS / LLM events by validating state changes against
    ``VALID_TRANSITIONS`` and publishing follow-up events.
    """

    # Valid state transitions: current state -> set of allowed next states.
    # Besides the happy path, every live state may move to IDLE (conversation
    # ended, session kept alive) and to ENDED (session torn down), because
    # the handlers below request those transitions from whatever state the
    # session is currently in.
    VALID_TRANSITIONS = {
        ConversationState.IDLE: {
            ConversationState.INITIALIZING,
            ConversationState.ENDED,
        },
        ConversationState.INITIALIZING: {
            ConversationState.PREPARING_WELCOME,
            ConversationState.LISTENING,
            ConversationState.ERROR,
            ConversationState.IDLE,
            ConversationState.ENDED,
        },
        ConversationState.PREPARING_WELCOME: {
            ConversationState.PLAYING_WELCOME,
            ConversationState.ERROR,
            ConversationState.IDLE,
            ConversationState.ENDED,
        },
        ConversationState.PLAYING_WELCOME: {
            ConversationState.LISTENING,
            ConversationState.ERROR,
            ConversationState.IDLE,
            ConversationState.ENDED,
        },
        ConversationState.LISTENING: {
            ConversationState.PROCESSING_SPEECH,
            ConversationState.ERROR,
            ConversationState.IDLE,
            ConversationState.ENDED,
        },
        ConversationState.PROCESSING_SPEECH: {
            ConversationState.PREPARING_RESPONSE,
            ConversationState.ERROR,
            ConversationState.IDLE,
            ConversationState.ENDED,
        },
        ConversationState.PREPARING_RESPONSE: {
            ConversationState.PLAYING_RESPONSE,
            ConversationState.ERROR,
            ConversationState.IDLE,
            ConversationState.ENDED,
        },
        ConversationState.PLAYING_RESPONSE: {
            ConversationState.LISTENING,
            ConversationState.ERROR,
            ConversationState.IDLE,
            ConversationState.ENDED,
        },
        ConversationState.ERROR: {
            ConversationState.LISTENING,
            ConversationState.IDLE,
            ConversationState.ENDED,
        },
        ConversationState.ENDED: set(),  # No transitions from ENDED
    }

    def __init__(self, event_bus: EventBus):
        """Store the event bus, start with no sessions, and subscribe handlers."""
        self.event_bus = event_bus
        self.sessions: Dict[str, SessionContext] = {}
        self._setup_event_handlers()

    def _setup_event_handlers(self):
        """Subscribe to relevant events."""
        subscribe = self.event_bus.subscribe

        # Conversation events
        subscribe(EventType.CONVERSATION_STARTED, self._handle_conversation_started)
        subscribe(EventType.CONVERSATION_ENDED, self._handle_conversation_ended)

        # Session lifecycle
        subscribe(EventType.SESSION_STARTED, self._handle_session_started)
        subscribe(EventType.SESSION_ENDED, self._handle_session_ended)

        # STT events
        subscribe(EventType.STT_READY, self._handle_stt_ready)
        subscribe(EventType.STT_RESULT, self._handle_stt_result)
        subscribe(EventType.STT_ERROR, self._handle_stt_error)

        # TTS events
        subscribe(EventType.TTS_COMPLETED, self._handle_tts_completed)
        subscribe(EventType.TTS_ERROR, self._handle_tts_error)

        # Audio events
        subscribe(EventType.AUDIO_PLAYBACK_COMPLETED, self._handle_audio_playback_completed)

        # LLM events
        subscribe(EventType.LLM_RESPONSE_READY, self._handle_llm_response_ready)
        subscribe(EventType.LLM_ERROR, self._handle_llm_error)

        # Error events
        subscribe(EventType.CRITICAL_ERROR, self._handle_critical_error)

    # ------------------------------------------------------------------
    # Conversation / session lifecycle
    # ------------------------------------------------------------------

    async def _handle_conversation_started(self, event: Event) -> None:
        """Handle conversation start within an existing session.

        Moves IDLE -> INITIALIZING, then either requests TTS for the welcome
        message or goes straight to LISTENING and starts STT.
        """
        session_id = event.session_id
        context = self.sessions.get(session_id)
        if not context:
            log_error(f"❌ Session not found for conversation start | session_id={session_id}")
            return

        log_info(f"🎤 Conversation started | session_id={session_id}")

        # First leave IDLE. If that is refused a conversation is already in
        # progress, so do not start a second STT/TTS pipeline on top of it.
        if not await self.transition_to(session_id, ConversationState.INITIALIZING):
            log_warning(
                f"⚠️ Conversation start ignored | session_id={session_id}, state={context.state.value}"
            )
            return

        # If a welcome message is configured, synthesize it first
        if context.metadata.get("has_welcome") and context.metadata.get("welcome_text"):
            await self.transition_to(session_id, ConversationState.PREPARING_WELCOME)

            # Request TTS for welcome message
            await self.event_bus.publish(Event(
                type=EventType.TTS_STARTED,
                session_id=session_id,
                data={
                    "text": context.metadata.get("welcome_text", ""),
                    "is_welcome": True
                }
            ))
        else:
            # No welcome message: go directly to LISTENING
            await self.transition_to(session_id, ConversationState.LISTENING)

            # Start STT
            await self.event_bus.publish(
                Event(
                    type=EventType.STT_STARTED,
                    data={},
                    session_id=session_id
                )
            )

    async def _handle_conversation_ended(self, event: Event) -> None:
        """Handle conversation end - but keep the session alive."""
        session_id = event.session_id
        context = self.sessions.get(session_id)
        if not context:
            log_warning(f"⚠️ Session not found for conversation end | session_id={session_id}")
            return

        log_info(f"🔚 Conversation ended | session_id={session_id}")

        # Stop STT if running
        await self.event_bus.publish(Event(
            type=EventType.STT_STOPPED,
            session_id=session_id,
            data={"reason": "conversation_ended"}
        ))

        # Stop any ongoing TTS
        await self.event_bus.publish(Event(
            type=EventType.TTS_STOPPED,
            session_id=session_id,
            data={"reason": "conversation_ended"}
        ))

        # Transition back to IDLE - session still alive!
        if context.state != ConversationState.IDLE:
            await self.transition_to(session_id, ConversationState.IDLE)

        # Only claim readiness when the session really is idle
        # (an ENDED session cannot go back to IDLE).
        if context.state == ConversationState.IDLE:
            log_info(f"💤 Session back to IDLE, ready for new conversation | session_id={session_id}")

    async def _handle_session_started(self, event: Event):
        """Handle session start by registering a new ``SessionContext``."""
        session_id = event.session_id
        session_data = event.data

        log_info(f"🎬 Session started", session_id=session_id)

        # Create session context
        context = SessionContext(
            session_id=session_id,
            session=session_data.get("session"),
            metadata={
                "has_welcome": session_data.get("has_welcome", False),
                "welcome_text": session_data.get("welcome_text", "")
            }
        )
        self.sessions[session_id] = context

        # The session stays in IDLE (SessionContext's default state) until a
        # conversation is started.
        log_info(f"📍 Session created in IDLE state | session_id={session_id}")

    async def _handle_session_ended(self, event: Event):
        """Handle session end - complete cleanup."""
        session_id = event.session_id
        log_info(f"🏁 Session ended | session_id={session_id}")

        # Get context for cleanup
        context = self.sessions.get(session_id)
        if context:
            # Move to ENDED unless already there (e.g. after a critical
            # error). transition_to reports failure via its return value.
            if context.state != ConversationState.ENDED:
                if not await self.transition_to(session_id, ConversationState.ENDED):
                    log_warning(f"Could not transition to ENDED state: {context.state.value}")

            # Stop all components
            await self.event_bus.publish(Event(
                type=EventType.STT_STOPPED,
                session_id=session_id,
                data={"reason": "session_ended"}
            ))
            await self.event_bus.publish(Event(
                type=EventType.TTS_STOPPED,
                session_id=session_id,
                data={"reason": "session_ended"}
            ))

            # Cleanup session context
            await context.cleanup()

        # Remove session
        self.sessions.pop(session_id, None)

        # Clear event bus session data
        self.event_bus.clear_session_data(session_id)

        log_info(f"✅ Session fully cleaned up | session_id={session_id}")

    # ------------------------------------------------------------------
    # STT / LLM / TTS / audio pipeline
    # ------------------------------------------------------------------

    async def _handle_stt_ready(self, event: Event):
        """Handle STT ready signal (currently only logged)."""
        session_id = event.session_id
        current_state = self.get_state(session_id)

        log_debug(f"🎤 STT ready", session_id=session_id, current_state=current_state)

        # Only process if we're expecting STT to be ready
        if current_state in [ConversationState.LISTENING, ConversationState.PLAYING_WELCOME]:
            # STT is ready, we're already in the right state
            pass

    async def _handle_stt_result(self, event: Event):
        """Handle STT transcription result.

        Interim results are only logged. A non-empty final result received
        while LISTENING stops STT, moves to PROCESSING_SPEECH and hands the
        text to the LLM.
        """
        session_id = event.session_id
        context = self.sessions.get(session_id)
        if not context:
            return

        current_state = context.state
        result_data = event.data
        is_final = result_data.get("is_final", False)

        # Interim results never change state; they are only logged here
        if not is_final:
            text = result_data.get("text", "").strip()
            if text:
                log_debug(f"📝 Interim transcription: '{text}'", session_id=session_id)
            return

        # Final result handling
        text = result_data.get("text", "").strip()
        if not text:
            log_warning(f"⚠️ Empty final transcription", session_id=session_id)
            return

        if current_state != ConversationState.LISTENING:
            log_warning(
                f"⚠️ STT result in unexpected state",
                session_id=session_id,
                state=current_state.value
            )
            return

        log_info(f"💬 Final transcription: '{text}'", session_id=session_id)

        # Stop STT automatically once the utterance is complete
        await self.event_bus.publish(Event(
            type=EventType.STT_STOPPED,
            session_id=session_id,
            data={"reason": "utterance_completed"}
        ))

        # Transition to processing
        await self.transition_to(session_id, ConversationState.PROCESSING_SPEECH)

        # Send to LLM
        await self.event_bus.publish(Event(
            type=EventType.LLM_PROCESSING_STARTED,
            session_id=session_id,
            data={"text": text}
        ))

    async def _handle_llm_response_ready(self, event: Event):
        """Handle LLM response: move to PREPARING_RESPONSE and request TTS."""
        session_id = event.session_id
        current_state = self.get_state(session_id)

        if current_state != ConversationState.PROCESSING_SPEECH:
            log_warning(
                f"⚠️ LLM response in unexpected state",
                session_id=session_id,
                state=current_state
            )
            return

        response_text = event.data.get("text", "")
        log_info(f"🤖 LLM response ready", session_id=session_id, length=len(response_text))

        # Transition to preparing response
        await self.transition_to(session_id, ConversationState.PREPARING_RESPONSE)

        # Request TTS
        await self.event_bus.publish(Event(
            type=EventType.TTS_STARTED,
            session_id=session_id,
            data={"text": response_text}
        ))

    async def _handle_tts_completed(self, event: Event):
        """Handle TTS completion by moving from PREPARING_* to PLAYING_*."""
        session_id = event.session_id
        context = self.sessions.get(session_id)
        if not context:
            return

        current_state = context.state
        log_info(f"🔊 TTS completed", session_id=session_id, state=current_state.value)

        if current_state == ConversationState.PREPARING_WELCOME:
            await self.transition_to(session_id, ConversationState.PLAYING_WELCOME)
            # The welcome audio is played by the frontend; only the state is
            # updated here. The frontend reports the end of playback with an
            # audio-playback-completed event.
        elif current_state == ConversationState.PREPARING_RESPONSE:
            await self.transition_to(session_id, ConversationState.PLAYING_RESPONSE)

    async def _handle_audio_playback_completed(self, event: Event):
        """Handle audio playback completion: go to LISTENING and start STT."""
        session_id = event.session_id
        context = self.sessions.get(session_id)
        if not context:
            return

        current_state = context.state
        log_info(f"🎵 Audio playback completed", session_id=session_id, state=current_state.value)

        if current_state in [ConversationState.PLAYING_WELCOME, ConversationState.PLAYING_RESPONSE]:
            # Transition to listening
            await self.transition_to(session_id, ConversationState.LISTENING)

            # Start STT in single-utterance mode
            locale = context.metadata.get("locale", "tr")
            await self.event_bus.publish(Event(
                type=EventType.STT_STARTED,
                session_id=session_id,
                data={
                    "locale": locale,
                    "single_utterance": True,   # single-utterance mode
                    "interim_results": False,   # final results only
                    "speech_timeout_ms": 2000   # 2 seconds of silence
                }
            ))

            # Send STT ready signal to frontend
            await self.event_bus.publish(Event(
                type=EventType.STT_READY,
                session_id=session_id,
                data={}
            ))

    # ------------------------------------------------------------------
    # Error handlers
    # ------------------------------------------------------------------

    async def _return_to_listening(self, session_id: str, stt_data: Optional[Dict[str, Any]] = None) -> bool:
        """Bring a session back to LISTENING and restart STT.

        States that cannot reach LISTENING directly take a detour through
        ERROR. STT is only started when the session actually ends up in
        LISTENING, so idle or ended sessions are left alone.

        Returns True when STT_STARTED was published, False otherwise.
        """
        context = self.sessions.get(session_id)
        if not context:
            return False

        if context.state != ConversationState.LISTENING:
            allowed = self.VALID_TRANSITIONS.get(context.state, set())
            if ConversationState.LISTENING not in allowed:
                if not await self.transition_to(session_id, ConversationState.ERROR):
                    return False
            if not await self.transition_to(session_id, ConversationState.LISTENING):
                return False

        await self.event_bus.publish(Event(
            type=EventType.STT_STARTED,
            session_id=session_id,
            data={} if stt_data is None else stt_data
        ))
        return True

    async def _handle_stt_error(self, event: Event):
        """Handle STT errors: enter ERROR, wait, then retry listening."""
        session_id = event.session_id
        error_data = event.data

        log_error(
            f"❌ STT error",
            session_id=session_id,
            error=error_data.get("message")
        )

        context = self.sessions.get(session_id)
        if not context or context.state == ConversationState.ENDED:
            return

        # Nothing to recover if the session cannot enter ERROR from here
        if not await self.transition_to(session_id, ConversationState.ERROR):
            return

        # Try recovery after delay
        await asyncio.sleep(2.0)

        # The state may have changed during the delay (e.g. session ended)
        if self.get_state(session_id) == ConversationState.ERROR:
            await self._return_to_listening(session_id, {"retry": True})

    async def _handle_tts_error(self, event: Event):
        """Handle TTS errors: skip the audio and go back to listening."""
        session_id = event.session_id
        error_data = event.data

        log_error(
            f"❌ TTS error",
            session_id=session_id,
            error=error_data.get("message")
        )

        # Skip TTS and go to listening
        current_state = self.get_state(session_id)
        if current_state in [ConversationState.PREPARING_WELCOME, ConversationState.PREPARING_RESPONSE]:
            await self._return_to_listening(session_id)

    async def _handle_llm_error(self, event: Event):
        """Handle LLM errors: go back to listening and restart STT."""
        session_id = event.session_id
        error_data = event.data

        log_error(
            f"❌ LLM error",
            session_id=session_id,
            error=error_data.get("message")
        )

        # Go back to listening and start STT
        await self._return_to_listening(session_id)

    async def _handle_critical_error(self, event: Event):
        """Handle critical errors by ending the session."""
        session_id = event.session_id
        error_data = event.data

        log_error(
            f"💥 Critical error",
            session_id=session_id,
            error=error_data.get("message")
        )

        # End session
        await self.transition_to(session_id, ConversationState.ENDED)

        # Publish session end event
        await self.event_bus.publish(Event(
            type=EventType.SESSION_ENDED,
            session_id=session_id,
            data={"reason": "critical_error"}
        ))

    # ------------------------------------------------------------------
    # State access
    # ------------------------------------------------------------------

    async def transition_to(self, session_id: str, new_state: ConversationState) -> bool:
        """Transition to a new state with validation.

        Returns True when the session's state was changed to ``new_state``;
        False when the session is unknown or the transition is not listed in
        ``VALID_TRANSITIONS``. A failure while publishing the
        STATE_TRANSITION notification is logged but does not undo or hide
        the state change.
        """
        # Get session context
        context = self.sessions.get(session_id)
        if not context:
            log_warning(f"❌ Session not found for state transition | session_id={session_id}")
            return False

        # Check if transition is valid
        old_state = context.state
        if new_state not in self.VALID_TRANSITIONS.get(old_state, set()):
            log_warning(f"❌ Invalid state transition | session_id={session_id}, current={old_state.value}, requested={new_state.value}")
            return False

        # Update state
        context.state = new_state
        context.last_activity = datetime.utcnow()

        log_info(f"✅ State transition | session_id={session_id}, {old_state.value} → {new_state.value}")

        # Emit state transition event with correct field names
        try:
            await self.event_bus.publish(
                Event(
                    type=EventType.STATE_TRANSITION,
                    data={
                        "old_state": old_state.value,  # Backend uses old_state/new_state
                        "new_state": new_state.value,
                        "timestamp": datetime.utcnow().isoformat()
                    },
                    session_id=session_id
                )
            )
        except Exception as e:
            log_error(f"❌ State transition error | session_id={session_id}", e)

        return True

    def get_state(self, session_id: str) -> Optional[ConversationState]:
        """Get current state for a session, or None if the session is unknown."""
        context = self.sessions.get(session_id)
        return context.state if context else None

    def get_session_data(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Get session data, or None if the session is unknown.

        NOTE(review): this returns the context's ``metadata`` dict, the only
        per-session data mapping this class stores; confirm against callers
        that this is what they expect.
        """
        context = self.sessions.get(session_id)
        return context.metadata if context else None

    # ------------------------------------------------------------------
    # Explicit recovery API
    # ------------------------------------------------------------------

    async def handle_error_recovery(self, session_id: str, error_type: str):
        """Handle error recovery strategies.

        Dispatches on ``error_type`` ("stt_error", "tts_error", "llm_error",
        "websocket_error"); any other value falls back to ERROR -> LISTENING
        after a one-second delay. Does nothing for unknown or ended sessions.
        """
        context = self.sessions.get(session_id)
        if not context or context.state == ConversationState.ENDED:
            return

        log_info(
            f"🔧 Attempting error recovery",
            session_id=session_id,
            error_type=error_type,
            current_state=context.state.value
        )

        # Update activity
        context.update_activity()

        # Define recovery strategies
        recovery_strategies = {
            "stt_error": self._recover_from_stt_error,
            "tts_error": self._recover_from_tts_error,
            "llm_error": self._recover_from_llm_error,
            "websocket_error": self._recover_from_websocket_error
        }

        strategy = recovery_strategies.get(error_type)
        if strategy:
            await strategy(session_id)
        else:
            # Default recovery: go to error state then back to listening
            await self.transition_to(session_id, ConversationState.ERROR)
            await asyncio.sleep(1.0)
            await self.transition_to(session_id, ConversationState.LISTENING)

    async def _recover_from_stt_error(self, session_id: str):
        """Recover from STT error: stop STT, wait, then restart it."""
        await self.event_bus.publish(Event(
            type=EventType.STT_STOPPED,
            session_id=session_id,
            data={"reason": "error_recovery"}
        ))

        await asyncio.sleep(2.0)

        await self._return_to_listening(session_id, {"retry": True})

    async def _recover_from_tts_error(self, session_id: str):
        """Recover from TTS error: skip TTS and go to listening."""
        await self._return_to_listening(session_id)

    async def _recover_from_llm_error(self, session_id: str):
        """Recover from LLM error: go back to listening."""
        await self._return_to_listening(session_id)

    async def _recover_from_websocket_error(self, session_id: str):
        """Recover from WebSocket error by ending the session cleanly."""
        await self.transition_to(session_id, ConversationState.ENDED)

        await self.event_bus.publish(Event(
            type=EventType.SESSION_ENDED,
            session_id=session_id,
            data={"reason": "websocket_error"}
        ))