Spaces:
Running
Running
Upload 8 files
Browse files- audio_buffer_manager.py +296 -0
- event_bus.py +410 -0
- llm_manager.py +689 -0
- resource_manager.py +401 -0
- state_orchestrator.py +511 -0
- stt_lifecycle_manager.py +366 -0
- tts_lifecycle_manager.py +377 -0
- websocket_manager.py +408 -0
audio_buffer_manager.py
ADDED
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Audio Buffer Manager for Flare
|
3 |
+
==============================
|
4 |
+
Manages audio buffering, silence detection, and chunk processing
|
5 |
+
"""
|
6 |
+
import asyncio
import base64
import traceback
from collections import deque
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

from event_bus import EventBus, Event, EventType
from logger import log_info, log_error, log_debug, log_warning
|
17 |
+
|
18 |
+
|
19 |
+
@dataclass
class AudioChunk:
    """Audio chunk with metadata"""
    data: bytes                # raw audio payload for this chunk
    timestamp: datetime        # when the chunk was added to the buffer
    chunk_index: int           # monotonically increasing index within the session
    is_speech: bool = True     # flipped to False when silence detection marks the chunk silent
    energy_level: float = 0.0  # set by the buffer manager after silence detection
|
27 |
+
|
28 |
+
|
29 |
+
class SilenceDetector:
    """Tracks per-chunk RMS energy to decide when the speaker has gone quiet."""

    def __init__(self,
                 threshold_ms: int = 2000,
                 energy_threshold: float = 0.01,
                 sample_rate: int = 16000):
        self.threshold_ms = threshold_ms
        self.energy_threshold = energy_threshold
        self.sample_rate = sample_rate
        # Start of the current run of silence; None while speech is ongoing.
        self.silence_start: Optional[datetime] = None

    def detect_silence(self, audio_chunk: bytes) -> Tuple[bool, int]:
        """
        Classify a chunk as silence or speech.

        Returns (is_silence, silence_duration_ms), where the duration is how
        long the current uninterrupted run of silence has lasted.
        """
        try:
            # Too little data to hold even one 16-bit sample: treat as silence.
            if not audio_chunk or len(audio_chunk) < 2:
                return True, 0

            # Drop a trailing odd byte so the buffer decodes cleanly as int16.
            if len(audio_chunk) % 2:
                audio_chunk = audio_chunk[:-1]

            samples = np.frombuffer(audio_chunk, dtype=np.int16)
            if samples.size == 0:
                return True, 0

            # RMS energy normalized into [0, 1] for 16-bit PCM.
            normalized = np.sqrt(np.mean(samples.astype(float) ** 2)) / 32768.0
            quiet = normalized < self.energy_threshold

            now = datetime.utcnow()
            if not quiet:
                # Speech resets any in-progress silence run.
                self.silence_start = None
                return False, 0

            # Begin (or extend) the current silence run.
            if self.silence_start is None:
                self.silence_start = now
            elapsed_ms = int((now - self.silence_start).total_seconds() * 1000)
            return True, elapsed_ms

        except Exception as e:
            log_warning(f"Silence detection error: {e}")
            return False, 0

    def reset(self):
        """Forget any in-progress silence run."""
        self.silence_start = None
|
86 |
+
|
87 |
+
|
88 |
+
class AudioBuffer:
    """Ordered, size-bounded store of audio chunks for one session."""

    def __init__(self,
                 session_id: str,
                 max_chunks: int = 1000,
                 chunk_size_bytes: int = 4096):
        self.session_id = session_id
        self.max_chunks = max_chunks
        self.chunk_size_bytes = chunk_size_bytes
        # deque silently evicts the oldest chunk once max_chunks is reached.
        self.chunks: deque[AudioChunk] = deque(maxlen=max_chunks)
        self.chunk_counter = 0
        self.total_bytes = 0
        self.lock = asyncio.Lock()

    async def add_chunk(self, audio_data: bytes, timestamp: Optional[datetime] = None) -> AudioChunk:
        """Append raw audio to the buffer and return its AudioChunk wrapper."""
        async with self.lock:
            stamped_at = timestamp if timestamp is not None else datetime.utcnow()

            chunk = AudioChunk(
                data=audio_data,
                timestamp=stamped_at,
                chunk_index=self.chunk_counter,
            )
            self.chunks.append(chunk)

            self.chunk_counter += 1
            self.total_bytes += len(audio_data)
            return chunk

    async def get_recent_audio(self, duration_ms: int = 5000) -> bytes:
        """Concatenate every chunk younger than duration_ms, oldest first."""
        async with self.lock:
            now = datetime.utcnow()
            newest_first = []

            # Walk backwards from the newest chunk; stop at the first one
            # that falls outside the window (earlier chunks are older still).
            for chunk in reversed(self.chunks):
                age_ms = (now - chunk.timestamp).total_seconds() * 1000
                if age_ms > duration_ms:
                    break
                newest_first.append(chunk.data)

            # Restore chronological order before joining.
            return b''.join(reversed(newest_first))

    async def clear(self):
        """Drop all chunks and reset counters."""
        async with self.lock:
            self.chunks.clear()
            self.chunk_counter = 0
            self.total_bytes = 0

    def get_stats(self) -> Dict[str, Any]:
        """Snapshot of buffer size and chunk-age bounds."""
        return {
            "chunks": len(self.chunks),
            "total_bytes": self.total_bytes,
            "chunk_counter": self.chunk_counter,
            "oldest_chunk": self.chunks[0].timestamp if self.chunks else None,
            "newest_chunk": self.chunks[-1].timestamp if self.chunks else None,
        }
|
154 |
+
|
155 |
+
|
156 |
+
class AudioBufferManager:
    """Manage audio buffers for all sessions.

    Subscribes to session lifecycle and audio-chunk events on the event bus,
    keeps one AudioBuffer + SilenceDetector pair per session, and publishes
    the buffered utterance once silence exceeds the configured threshold.
    """

    def __init__(self, event_bus: EventBus):
        self.event_bus = event_bus
        # Per-session state, created on SESSION_STARTED, torn down on SESSION_ENDED.
        self.session_buffers: Dict[str, AudioBuffer] = {}
        self.silence_detectors: Dict[str, SilenceDetector] = {}
        self._setup_event_handlers()

    def _setup_event_handlers(self):
        """Subscribe to audio events"""
        self.event_bus.subscribe(EventType.SESSION_STARTED, self._handle_session_started)
        self.event_bus.subscribe(EventType.SESSION_ENDED, self._handle_session_ended)
        self.event_bus.subscribe(EventType.AUDIO_CHUNK_RECEIVED, self._handle_audio_chunk)

    async def _handle_session_started(self, event: Event):
        """Initialize buffer for new session"""
        session_id = event.session_id
        # Assumes event.data carries buffer/silence settings — TODO confirm
        # against whatever publishes SESSION_STARTED.
        config = event.data

        # Create audio buffer
        self.session_buffers[session_id] = AudioBuffer(
            session_id=session_id,
            max_chunks=config.get("max_chunks", 1000),
            chunk_size_bytes=config.get("chunk_size", 4096)
        )

        # Create silence detector
        self.silence_detectors[session_id] = SilenceDetector(
            threshold_ms=config.get("silence_threshold_ms", 2000),
            energy_threshold=config.get("energy_threshold", 0.01),
            sample_rate=config.get("sample_rate", 16000)
        )

        log_info(f"📦 Audio buffer initialized", session_id=session_id)

    async def _handle_session_ended(self, event: Event):
        """Cleanup session buffers"""
        session_id = event.session_id

        # Clear and remove buffer
        if session_id in self.session_buffers:
            await self.session_buffers[session_id].clear()
            del self.session_buffers[session_id]

        # Remove silence detector
        if session_id in self.silence_detectors:
            del self.silence_detectors[session_id]

        log_info(f"📦 Audio buffer cleaned up", session_id=session_id)

    async def _handle_audio_chunk(self, event: Event):
        """Process incoming audio chunk.

        Decodes the chunk, appends it to the session buffer, runs silence
        detection, and publishes the accumulated utterance when the silence
        threshold is crossed.
        """
        session_id = event.session_id

        buffer = self.session_buffers.get(session_id)
        detector = self.silence_detectors.get(session_id)

        # Chunks arriving before SESSION_STARTED (or after teardown) are dropped.
        if not buffer or not detector:
            log_warning(f"⚠️ No buffer for session", session_id=session_id)
            return

        try:
            # Decode audio data (expected base64-encoded in event.data)
            audio_data = base64.b64decode(event.data.get("audio_data", ""))

            # Add to buffer
            chunk = await buffer.add_chunk(audio_data)

            # Detect silence
            is_silence, silence_duration = detector.detect_silence(audio_data)

            # Update chunk metadata
            chunk.is_speech = not is_silence
            # NOTE(review): this value goes negative once silence_duration
            # exceeds threshold_ms, and it is a silence-duration proxy rather
            # than a real energy measure — confirm intended semantics.
            chunk.energy_level = 1.0 - (silence_duration / detector.threshold_ms)

            # Check for end of speech
            if silence_duration > detector.threshold_ms:
                log_info(
                    f"🔇 Speech ended (silence: {silence_duration}ms)",
                    session_id=session_id
                )

                # Get complete audio
                complete_audio = await buffer.get_recent_audio()

                # Publish speech ended event
                # NOTE(review): published as STT_RESULT but carries raw audio,
                # not recognized text — verify downstream consumers expect this.
                await self.event_bus.publish(Event(
                    type=EventType.STT_RESULT,
                    session_id=session_id,
                    data={
                        "audio_data": base64.b64encode(complete_audio).decode(),
                        "is_final": True,
                        "silence_triggered": True
                    },
                    priority=5
                ))

                # Reset for next utterance
                await self.reset_buffer(session_id)

            # Log periodically
            if chunk.chunk_index % 100 == 0:
                stats = buffer.get_stats()
                log_debug(
                    f"📊 Buffer stats",
                    session_id=session_id,
                    **stats
                )

        except Exception as e:
            log_error(
                f"❌ Error processing audio chunk",
                session_id=session_id,
                error=str(e),
                traceback=traceback.format_exc()
            )

    async def get_buffer(self, session_id: str) -> Optional[AudioBuffer]:
        """Get buffer for session"""
        return self.session_buffers.get(session_id)

    async def reset_buffer(self, session_id: str):
        """Reset buffer for new utterance"""
        buffer = self.session_buffers.get(session_id)
        detector = self.silence_detectors.get(session_id)

        if buffer:
            await buffer.clear()

        if detector:
            detector.reset()

        log_debug(f"🔄 Audio buffer reset", session_id=session_id)

    def get_all_stats(self) -> Dict[str, Dict[str, Any]]:
        """Get statistics for all buffers"""
        stats = {}
        for session_id, buffer in self.session_buffers.items():
            stats[session_id] = buffer.get_stats()
        return stats
|
event_bus.py
ADDED
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Event Bus Implementation for Flare
|
3 |
+
==================================
|
4 |
+
Provides async event publishing and subscription mechanism
|
5 |
+
"""
|
6 |
+
import asyncio
|
7 |
+
from typing import Dict, List, Callable, Any, Optional
|
8 |
+
from enum import Enum
|
9 |
+
from dataclasses import dataclass
|
10 |
+
from datetime import datetime
|
11 |
+
import traceback
|
12 |
+
from collections import defaultdict
|
13 |
+
import sys
|
14 |
+
|
15 |
+
from logger import log_info, log_error, log_debug, log_warning
|
16 |
+
|
17 |
+
|
18 |
+
class EventType(Enum):
    """All event types in the system.

    String values are stable wire identifiers (used by Event.to_dict), so
    renaming a value is a protocol change, not a refactor.
    """
    # Lifecycle events
    SESSION_STARTED = "session_started"
    SESSION_ENDED = "session_ended"

    # STT events
    STT_STARTED = "stt_started"
    STT_STOPPED = "stt_stopped"
    STT_RESULT = "stt_result"
    STT_ERROR = "stt_error"
    STT_READY = "stt_ready"

    # TTS events
    TTS_STARTED = "tts_started"
    TTS_CHUNK_READY = "tts_chunk_ready"
    TTS_COMPLETED = "tts_completed"
    TTS_ERROR = "tts_error"

    # Audio events
    AUDIO_PLAYBACK_STARTED = "audio_playback_started"
    AUDIO_PLAYBACK_COMPLETED = "audio_playback_completed"
    AUDIO_BUFFER_LOW = "audio_buffer_low"
    AUDIO_CHUNK_RECEIVED = "audio_chunk_received"

    # LLM events
    LLM_PROCESSING_STARTED = "llm_processing_started"
    LLM_RESPONSE_READY = "llm_response_ready"
    LLM_ERROR = "llm_error"

    # Error events
    CRITICAL_ERROR = "critical_error"
    RECOVERABLE_ERROR = "recoverable_error"

    # State events
    STATE_TRANSITION = "state_transition"
    STATE_ROLLBACK = "state_rollback"

    # WebSocket events
    WEBSOCKET_CONNECTED = "websocket_connected"
    WEBSOCKET_DISCONNECTED = "websocket_disconnected"
    WEBSOCKET_MESSAGE = "websocket_message"
    WEBSOCKET_ERROR = "websocket_error"
|
61 |
+
|
62 |
+
|
63 |
+
@dataclass
class Event:
    """Event data structure.

    Events are queued in asyncio.PriorityQueue as (-priority, event) tuples,
    so Event defines __lt__ as a tiebreaker for equal priorities.
    """
    type: EventType
    session_id: str
    data: Dict[str, Any]
    # Optional: filled in by __post_init__ when the caller does not supply one.
    timestamp: Optional[datetime] = None
    priority: int = 0  # Higher priority = processed first

    def __post_init__(self):
        if self.timestamp is None:
            self.timestamp = datetime.utcnow()

    def __lt__(self, other: "Event") -> bool:
        """Order events by timestamp (oldest first).

        Without this, two queued events with equal priority would make heapq
        compare the Event objects directly and raise TypeError.
        """
        if not isinstance(other, Event):
            return NotImplemented
        return self.timestamp < other.timestamp

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for serialization"""
        return {
            "type": self.type.value,
            "session_id": self.session_id,
            "data": self.data,
            "timestamp": self.timestamp.isoformat(),
            "priority": self.priority
        }
|
85 |
+
|
86 |
+
|
87 |
+
class EventBus:
    """Central event bus for component communication with session isolation.

    Each session gets its own PriorityQueue and processor task so sessions
    are processed in parallel; events without a session_id go through a
    single global queue. Call start() before publish() — events published
    while the bus is stopped are dropped with an error log.
    """

    def __init__(self):
        # Global handlers: event type -> callables invoked for any session.
        self._subscribers: Dict[EventType, List[Callable]] = defaultdict(list)
        # Per-session handlers: session_id -> event type -> callables.
        self._session_handlers: Dict[str, Dict[EventType, List[Callable]]] = defaultdict(lambda: defaultdict(list))

        # Session-specific queues for parallel processing
        self._session_queues: Dict[str, asyncio.PriorityQueue] = {}
        self._session_processors: Dict[str, asyncio.Task] = {}

        # Global queue for non-session events
        self._global_queue: asyncio.PriorityQueue = asyncio.PriorityQueue()
        self._global_processor: Optional[asyncio.Task] = None

        self._running = False
        # Bounded event log served by get_event_history(); trimmed in publish().
        self._event_history: List[Event] = []
        self._max_history_size = 1000

    async def start(self):
        """Start the event processor"""
        if self._running:
            log_warning("EventBus already running")
            return

        self._running = True

        # Start global processor
        self._global_processor = asyncio.create_task(self._process_global_events())

        log_info("✅ EventBus started")

    async def stop(self):
        """Stop the event processor"""
        self._running = False

        # Stop all session processors
        for session_id, task in list(self._session_processors.items()):
            task.cancel()
            try:
                await asyncio.wait_for(task, timeout=2.0)
            except (asyncio.TimeoutError, asyncio.CancelledError):
                pass

        # Stop global processor
        if self._global_processor:
            # Priority 999 sorts after all real entries (publish() negates
            # priorities, so real tuples start <= 0), letting pending events
            # drain before the processor sees the None sentinel.
            await self._global_queue.put((999, None))  # Sentinel
            try:
                await asyncio.wait_for(self._global_processor, timeout=5.0)
            except asyncio.TimeoutError:
                log_warning("EventBus global processor timeout, cancelling")
                self._global_processor.cancel()

        log_info("✅ EventBus stopped")

    async def publish(self, event: Event):
        """Publish an event to the bus.

        Session events are routed to that session's queue (processor created
        on demand); events without a session_id go to the global queue.
        """
        if not self._running:
            log_error("EventBus not running, cannot publish event", event_type=event.type.value)
            return

        # Add to history
        self._event_history.append(event)
        if len(self._event_history) > self._max_history_size:
            # NOTE(review): list.pop(0) is O(n); a deque(maxlen=...) would
            # bound the history without the per-event shift cost.
            self._event_history.pop(0)

        # Route to appropriate queue
        if event.session_id:
            # Ensure session queue exists
            if event.session_id not in self._session_queues:
                await self._create_session_processor(event.session_id)

            # Priority is negated because PriorityQueue pops the smallest
            # entry first.
            # NOTE(review): if two queued events share a priority, heapq
            # compares the Event objects themselves; unless Event defines an
            # ordering this raises TypeError. A monotonic counter in the
            # tuple (or Event.__lt__) avoids it.
            queue = self._session_queues[event.session_id]
            await queue.put((-event.priority, event))
        else:
            # Add to global queue
            await self._global_queue.put((-event.priority, event))

        log_debug(
            f"📤 Event published",
            event_type=event.type.value,
            session_id=event.session_id,
            priority=event.priority
        )

    async def _create_session_processor(self, session_id: str):
        """Create a processor for session-specific events"""
        if session_id in self._session_processors:
            return

        # Create queue
        self._session_queues[session_id] = asyncio.PriorityQueue()

        # Create processor task
        task = asyncio.create_task(self._process_session_events(session_id))
        self._session_processors[session_id] = task

        log_debug(f"📌 Created session processor", session_id=session_id)

    async def _process_session_events(self, session_id: str):
        """Process events for a specific session.

        Runs until the bus stops, a None sentinel arrives, or a 60s idle
        timeout finds the session's handlers gone.
        """
        queue = self._session_queues[session_id]
        log_info(f"🔄 Session event processor started", session_id=session_id)

        while self._running:
            try:
                # Wait for event with timeout
                priority, event = await asyncio.wait_for(
                    queue.get(),
                    timeout=60.0  # Longer timeout for sessions
                )

                # Check for session cleanup
                # NOTE(review): nothing visible here enqueues a None sentinel
                # on session queues; shutdown relies on task.cancel() instead.
                if event is None:
                    break

                # Process the event
                await self._dispatch_event(event)

            except asyncio.TimeoutError:
                # Check if session is still active
                if session_id not in self._session_handlers:
                    log_info(f"Session inactive, stopping processor", session_id=session_id)
                    break
                continue
            except Exception as e:
                log_error(
                    f"❌ Error processing session event",
                    session_id=session_id,
                    error=str(e),
                    traceback=traceback.format_exc()
                )

        # Cleanup
        self._session_queues.pop(session_id, None)
        self._session_processors.pop(session_id, None)
        log_info(f"🔄 Session event processor stopped", session_id=session_id)

    async def _process_global_events(self):
        """Process global events (no session_id)"""
        log_info("🔄 Global event processor started")

        while self._running:
            try:
                # Short timeout so the loop re-checks self._running regularly.
                priority, event = await asyncio.wait_for(
                    self._global_queue.get(),
                    timeout=1.0
                )

                if event is None:  # Sentinel
                    break

                await self._dispatch_event(event)

            except asyncio.TimeoutError:
                continue
            except Exception as e:
                log_error(
                    "❌ Error processing global event",
                    error=str(e),
                    traceback=traceback.format_exc()
                )

        log_info("🔄 Global event processor stopped")

    def subscribe(self, event_type: EventType, handler: Callable):
        """Subscribe to an event type globally"""
        self._subscribers[event_type].append(handler)
        log_debug(f"📌 Global subscription added", event_type=event_type.value)

    def subscribe_session(self, session_id: str, event_type: EventType, handler: Callable):
        """Subscribe to an event type for a specific session"""
        self._session_handlers[session_id][event_type].append(handler)
        log_debug(
            f"📌 Session subscription added",
            event_type=event_type.value,
            session_id=session_id
        )

    def unsubscribe(self, event_type: EventType, handler: Callable):
        """Unsubscribe from an event type"""
        if handler in self._subscribers[event_type]:
            self._subscribers[event_type].remove(handler)
            log_debug(f"📌 Global subscription removed", event_type=event_type.value)

    def unsubscribe_session(self, session_id: str, event_type: Optional[EventType] = None):
        """Unsubscribe session handlers.

        With an event_type, removes only that type's handlers for the
        session; without one, removes every handler for the session.
        """
        if event_type:
            # Remove specific event type for session
            if session_id in self._session_handlers and event_type in self._session_handlers[session_id]:
                del self._session_handlers[session_id][event_type]
        else:
            # Remove all handlers for session
            if session_id in self._session_handlers:
                del self._session_handlers[session_id]
                log_debug(f"📌 All session subscriptions removed", session_id=session_id)

    async def _dispatch_event(self, event: Event):
        """Dispatch event to all subscribers.

        Global and session-specific handlers run concurrently; sync handlers
        are offloaded to a thread. Handler exceptions are logged, never raised.
        """
        try:
            handlers = []

            # Get global handlers
            if event.type in self._subscribers:
                handlers.extend(self._subscribers[event.type])

            # Get session-specific handlers
            if event.session_id in self._session_handlers:
                if event.type in self._session_handlers[event.session_id]:
                    handlers.extend(self._session_handlers[event.session_id][event.type])

            if not handlers:
                log_debug(
                    f"📭 No handlers for event",
                    event_type=event.type.value,
                    session_id=event.session_id
                )
                return

            log_debug(
                f"📨 Dispatching event to {len(handlers)} handlers",
                event_type=event.type.value,
                session_id=event.session_id
            )

            # Call all handlers concurrently
            tasks = []
            for handler in handlers:
                if asyncio.iscoroutinefunction(handler):
                    task = asyncio.create_task(handler(event))
                else:
                    # Wrap sync handler in async
                    task = asyncio.create_task(asyncio.to_thread(handler, event))
                tasks.append(task)

            # Wait for all handlers to complete
            results = await asyncio.gather(*tasks, return_exceptions=True)

            # Log any exceptions
            for i, result in enumerate(results):
                if isinstance(result, Exception):
                    log_error(
                        f"❌ Handler error",
                        handler=handlers[i].__name__,
                        event_type=event.type.value,
                        error=str(result),
                        traceback=traceback.format_exception(type(result), result, result.__traceback__)
                    )

        except Exception as e:
            log_error(
                f"❌ Error dispatching event",
                event_type=event.type.value,
                error=str(e),
                traceback=traceback.format_exc()
            )

    def get_event_history(self, session_id: Optional[str] = None, event_type: Optional[EventType] = None) -> List[Event]:
        """Get event history with optional filters"""
        history = self._event_history

        if session_id:
            history = [e for e in history if e.session_id == session_id]

        if event_type:
            history = [e for e in history if e.type == event_type]

        return history

    def clear_session_data(self, session_id: str):
        """Clear all session-related data and stop processor"""
        # Remove session handlers
        self.unsubscribe_session(session_id)

        # Stop session processor
        if session_id in self._session_processors:
            task = self._session_processors[session_id]
            task.cancel()

        # Clear queues
        self._session_queues.pop(session_id, None)
        self._session_processors.pop(session_id, None)

        # Remove session events from history
        self._event_history = [e for e in self._event_history if e.session_id != session_id]

        log_debug(f"🧹 Session data cleared", session_id=session_id)
|
376 |
+
|
377 |
+
|
378 |
+
# Global event bus instance
# NOTE: shared module-level singleton; call `await event_bus.start()` before
# publishing — publish() drops events while the bus is not running.
event_bus = EventBus()
|
380 |
+
|
381 |
+
|
382 |
+
# Helper functions for common event publishing patterns
|
383 |
+
async def publish_error(session_id: str, error_type: str, error_message: str, details: Dict[str, Any] = None):
    """Publish an error event on the global bus.

    error_type == "critical" maps to CRITICAL_ERROR; anything else is
    RECOVERABLE_ERROR. Errors always go out at priority 10.
    """
    severity = EventType.CRITICAL_ERROR if error_type == "critical" else EventType.RECOVERABLE_ERROR
    payload = {
        "error_type": error_type,
        "message": error_message,
        "details": details or {},
    }
    await event_bus.publish(
        Event(
            type=severity,
            session_id=session_id,
            data=payload,
            priority=10,  # High priority for errors
        )
    )
|
396 |
+
|
397 |
+
|
398 |
+
async def publish_state_transition(session_id: str, from_state: str, to_state: str, reason: str = None):
    """Publish a STATE_TRANSITION event describing a from→to state change."""
    payload = {
        "from_state": from_state,
        "to_state": to_state,
        "reason": reason,
    }
    transition = Event(
        type=EventType.STATE_TRANSITION,
        session_id=session_id,
        data=payload,
        priority=5,  # Medium priority for state changes
    )
    await event_bus.publish(transition)
|
llm_manager.py
ADDED
@@ -0,0 +1,689 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
LLM Manager for Flare
|
3 |
+
====================
|
4 |
+
Manages LLM interactions per session with stateless approach
|
5 |
+
"""
|
6 |
+
import asyncio
|
7 |
+
from typing import Dict, Optional, Any, List
|
8 |
+
from datetime import datetime
|
9 |
+
import traceback
|
10 |
+
from dataclasses import dataclass, field
|
11 |
+
import json
|
12 |
+
|
13 |
+
from event_bus import EventBus, Event, EventType, publish_error
|
14 |
+
from resource_manager import ResourceManager, ResourceType
|
15 |
+
from session import Session
|
16 |
+
from llm_factory import LLMFactory
|
17 |
+
from llm_interface import LLMInterface
|
18 |
+
from prompt_builder import build_intent_prompt, build_parameter_prompt
|
19 |
+
from logger import log_info, log_error, log_debug, log_warning
|
20 |
+
from config_provider import ConfigProvider
|
21 |
+
|
22 |
+
|
23 |
+
@dataclass
class LLMJob:
    """A single unit of LLM work bound to one session.

    Carries the request text and job kind, and records the eventual
    outcome: a response (optionally with a detected intent) or an error.
    """
    job_id: str
    session_id: str
    input_text: str
    # One of "intent_detection", "parameter_collection", "response_generation"
    # (internal kinds such as "parameter_request" / "action_result" / "error"
    # are also created by the manager).
    job_type: str
    created_at: datetime = field(default_factory=datetime.utcnow)
    completed_at: Optional[datetime] = None
    response_text: Optional[str] = None
    detected_intent: Optional[str] = None
    error: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def complete(self, response_text: str, intent: Optional[str] = None):
        """Record a successful result and stamp the completion time."""
        self.detected_intent = intent
        self.response_text = response_text
        self.completed_at = datetime.utcnow()

    def fail(self, error: str):
        """Record a failure message and stamp the completion time."""
        self.error = error
        self.completed_at = datetime.utcnow()
47 |
+
|
48 |
+
|
49 |
+
@dataclass
class LLMSession:
    """LLM session wrapper.

    Pairs a chat ``Session`` with the LLM instance serving it and keeps
    lightweight per-session bookkeeping (active job, recent job history,
    activity timestamps, counters).
    """
    session_id: str
    session: Session
    llm_instance: LLMInterface
    active_job: Optional[LLMJob] = None
    job_history: List[LLMJob] = field(default_factory=list)
    created_at: datetime = field(default_factory=datetime.utcnow)
    last_activity: datetime = field(default_factory=datetime.utcnow)
    # FIX: these were bare class attributes (`total_jobs = 0`), not dataclass
    # fields — annotating them makes each instance carry its own counters and
    # lets them participate in repr/eq like the other fields. Behavior of
    # existing code (`llm_session.total_jobs += 1`) is unchanged.
    total_jobs: int = 0
    total_tokens: int = 0

    def update_activity(self) -> None:
        """Refresh the last-activity timestamp (called on every new job)."""
        self.last_activity = datetime.utcnow()
|
65 |
+
|
66 |
+
|
67 |
+
class LLMManager:
    """Manages LLM interactions with stateless approach.

    One ``LLMSession`` is kept per chat session; the underlying LLM
    instances are pooled via the ``ResourceManager``. Work is driven by
    events: ``LLM_PROCESSING_STARTED`` triggers a job, and results are
    published back as ``LLM_RESPONSE_READY`` events.
    """

    def __init__(self, event_bus: EventBus, resource_manager: ResourceManager):
        self.event_bus = event_bus
        self.resource_manager = resource_manager
        # session_id -> LLMSession wrapper
        self.llm_sessions: Dict[str, LLMSession] = {}
        self.config = ConfigProvider.get()
        self._setup_event_handlers()
        self._setup_resource_pool()

    def _setup_event_handlers(self) -> None:
        """Subscribe to LLM-related events"""
        self.event_bus.subscribe(EventType.LLM_PROCESSING_STARTED, self._handle_llm_processing)
        self.event_bus.subscribe(EventType.SESSION_ENDED, self._handle_session_ended)

    def _setup_resource_pool(self) -> None:
        """Setup LLM instance pool"""
        self.resource_manager.register_pool(
            resource_type=ResourceType.LLM_CONTEXT,
            factory=self._create_llm_instance,
            max_idle=2,  # Lower pool size for LLM
            max_age_seconds=900  # 15 minutes
        )

    async def _create_llm_instance(self) -> LLMInterface:
        """Factory for creating LLM instances (used by the resource pool)."""
        try:
            llm_instance = LLMFactory.create_provider()
            if not llm_instance:
                raise ValueError("Failed to create LLM instance")

            log_debug("🤖 Created new LLM instance")
            return llm_instance

        except Exception as e:
            log_error(f"❌ Failed to create LLM instance", error=str(e))
            raise

    async def _handle_llm_processing(self, event: Event):
        """Handle LLM processing request.

        Entry point for the LLM_PROCESSING_STARTED event: resolves the
        session, classifies the job from session state, and dispatches to
        the matching processor. Any failure is reported via publish_error.
        """
        session_id = event.session_id
        input_text = event.data.get("text", "")

        # Nothing to do for empty input (e.g. empty STT result).
        if not input_text:
            log_warning(f"⚠️ Empty text for LLM", session_id=session_id)
            return

        try:
            log_info(
                f"🤖 Starting LLM processing",
                session_id=session_id,
                text_length=len(input_text)
            )

            # Get or create LLM session
            llm_session = await self._get_or_create_session(session_id)
            if not llm_session:
                raise ValueError("Failed to create LLM session")

            # Determine job type based on session state
            job_type = self._determine_job_type(llm_session.session)

            # Create job (job_id is unique per session via the running counter)
            job_id = f"{session_id}_{llm_session.total_jobs}"
            job = LLMJob(
                job_id=job_id,
                session_id=session_id,
                input_text=input_text,
                job_type=job_type,
                metadata={
                    "session_state": llm_session.session.state,
                    "current_intent": llm_session.session.current_intent
                }
            )

            llm_session.active_job = job
            llm_session.total_jobs += 1
            llm_session.update_activity()

            # Process based on job type
            if job_type == "intent_detection":
                await self._process_intent_detection(llm_session, job)
            elif job_type == "parameter_collection":
                await self._process_parameter_collection(llm_session, job)
            else:
                await self._process_response_generation(llm_session, job)

        except Exception as e:
            log_error(
                f"❌ Failed to process LLM request",
                session_id=session_id,
                error=str(e),
                traceback=traceback.format_exc()
            )

            # Publish error event
            await publish_error(
                session_id=session_id,
                error_type="llm_error",
                error_message=f"LLM processing failed: {str(e)}"
            )

    async def _get_or_create_session(self, session_id: str) -> Optional[LLMSession]:
        """Get or create LLM session.

        Returns the cached wrapper, or builds one by loading the chat
        Session from the store and acquiring an LLM instance from the
        resource pool. Returns None if the session is unknown.
        """
        if session_id in self.llm_sessions:
            return self.llm_sessions[session_id]

        # Get session from store (local import — presumably avoids a circular
        # import at module load time; TODO confirm)
        from session import session_store
        session = session_store.get_session(session_id)
        if not session:
            log_error(f"❌ Session not found", session_id=session_id)
            return None

        # Acquire LLM instance from pool
        resource_id = f"llm_{session_id}"
        llm_instance = await self.resource_manager.acquire(
            resource_id=resource_id,
            session_id=session_id,
            resource_type=ResourceType.LLM_CONTEXT,
            cleanup_callback=self._cleanup_llm_instance
        )

        # Create LLM session
        llm_session = LLMSession(
            session_id=session_id,
            session=session,
            llm_instance=llm_instance
        )

        self.llm_sessions[session_id] = llm_session
        return llm_session

    def _determine_job_type(self, session: Session) -> str:
        """Determine job type based on session state.

        "idle" -> detect a new intent; "collect_params" -> continue
        gathering parameters; anything else -> plain response generation.
        """
        if session.state == "idle":
            return "intent_detection"
        elif session.state == "collect_params":
            return "parameter_collection"
        else:
            return "response_generation"

    async def _process_intent_detection(self, llm_session: LLMSession, job: LLMJob):
        """Process intent detection.

        Builds the intent prompt, calls the LLM, and either transitions the
        session into parameter collection / action execution (when a known
        intent is found) or publishes the cleaned response as-is.
        """
        try:
            session = llm_session.session

            # Get project and version config
            project = next((p for p in self.config.projects if p.name == session.project_name), None)
            if not project:
                raise ValueError(f"Project not found: {session.project_name}")

            version = session.get_version_config()
            if not version:
                raise ValueError("Version config not found")

            # Build intent detection prompt
            prompt = build_intent_prompt(
                version=version,
                conversation=session.chat_history,
                project_locale=project.default_locale
            )

            log_debug(
                f"📝 Intent detection prompt built",
                session_id=job.session_id,
                prompt_length=len(prompt)
            )

            # Call LLM
            response = await llm_session.llm_instance.generate(
                system_prompt=prompt,
                user_input=job.input_text,
                context=session.chat_history[-10:]  # Last 10 messages
            )

            # Parse intent (empty intent_name means none was tagged)
            intent_name, response_text = self._parse_intent_response(response)

            if intent_name:
                # Find intent config; an unknown tag falls through to the
                # plain-response path below.
                intent_config = next((i for i in version.intents if i.name == intent_name), None)

                if intent_config:
                    # Update session
                    session.current_intent = intent_name
                    session.set_intent_config(intent_config)
                    session.state = "collect_params"

                    log_info(
                        f"🎯 Intent detected",
                        session_id=job.session_id,
                        intent=intent_name
                    )

                    # Check if we need to collect parameters
                    missing_params = [
                        p.name for p in intent_config.parameters
                        if p.required and p.variable_name not in session.variables
                    ]

                    if not missing_params:
                        # All parameters ready, execute action
                        await self._execute_intent_action(llm_session, intent_config)
                        return
                    else:
                        # Need to collect parameters
                        await self._request_parameter_collection(llm_session, intent_config, missing_params)
                        return

            # No intent detected, use response as is
            response_text = self._clean_response(response)
            job.complete(response_text, intent_name)

            # Publish response
            await self._publish_response(job)

        except Exception as e:
            # Mark the job failed; the caller's handler publishes the error.
            job.fail(str(e))
            raise

    async def _process_parameter_collection(self, llm_session: LLMSession, job: LLMJob):
        """Process parameter collection.

        Extracts parameter values from the user's message, stores them in
        session.variables, then either executes the intent action (all
        required parameters present) or asks for the remaining ones.
        """
        try:
            session = llm_session.session
            intent_config = session.get_intent_config()

            if not intent_config:
                raise ValueError("No intent config in session")

            # Extract parameters from user input
            extracted_params = await self._extract_parameters(
                llm_session,
                job.input_text,
                intent_config,
                session.variables
            )

            # Update session variables (values are stringified; keys use the
            # parameter's configured variable_name, not its display name)
            for param_name, param_value in extracted_params.items():
                param_config = next(
                    (p for p in intent_config.parameters if p.name == param_name),
                    None
                )
                if param_config:
                    session.variables[param_config.variable_name] = str(param_value)

            # Check what parameters are still missing
            missing_params = [
                p.name for p in intent_config.parameters
                if p.required and p.variable_name not in session.variables
            ]

            if not missing_params:
                # All parameters collected, execute action
                await self._execute_intent_action(llm_session, intent_config)
            else:
                # Still need more parameters
                await self._request_parameter_collection(llm_session, intent_config, missing_params)

        except Exception as e:
            job.fail(str(e))
            raise

    async def _process_response_generation(self, llm_session: LLMSession, job: LLMJob):
        """Process general response generation (no intent flow active)."""
        try:
            session = llm_session.session

            # Get version config
            version = session.get_version_config()
            if not version:
                raise ValueError("Version config not found")

            # Use general prompt
            prompt = version.general_prompt

            # Generate response
            response = await llm_session.llm_instance.generate(
                system_prompt=prompt,
                user_input=job.input_text,
                context=session.chat_history[-10:]
            )

            response_text = self._clean_response(response)
            job.complete(response_text)

            # Publish response
            await self._publish_response(job)

        except Exception as e:
            job.fail(str(e))
            raise

    async def _extract_parameters(self,
                                  llm_session: LLMSession,
                                  user_input: str,
                                  intent_config: Any,
                                  existing_params: Dict[str, str]) -> Dict[str, Any]:
        """Extract parameters from user input.

        Asks the LLM to return a JSON object mapping parameter names to
        values for the not-yet-collected parameters. Returns {} when the
        response contains no parseable JSON.
        """
        # Build extraction prompt — only parameters not already collected
        param_info = []
        for param in intent_config.parameters:
            if param.variable_name not in existing_params:
                param_info.append({
                    'name': param.name,
                    'type': param.type,
                    'required': param.required,
                    'extraction_prompt': param.extraction_prompt
                })

        prompt = f"""
Extract parameters from user message: "{user_input}"

Expected parameters:
{json.dumps(param_info, ensure_ascii=False)}

Return as JSON object with parameter names as keys.
"""

        # Call LLM
        response = await llm_session.llm_instance.generate(
            system_prompt=prompt,
            user_input=user_input,
            context=[]
        )

        # Parse JSON response
        try:
            # Look for JSON block in response: fenced ```json``` first, then
            # any brace-delimited span.
            import re
            json_match = re.search(r'```json\s*(.*?)\s*```', response, re.DOTALL)
            if not json_match:
                json_match = re.search(r'\{[^}]+\}', response)

            if json_match:
                # NOTE(review): if the response contains ``` but only the
                # brace pattern matched, group(1) raises IndexError here;
                # the bare except below swallows it and {} is returned.
                json_str = json_match.group(1) if '```' in response else json_match.group(0)
                return json.loads(json_str)
        except:
            # NOTE(review): bare except — deliberately best-effort (falls back
            # to {}), but it also hides programming errors; consider narrowing
            # to (json.JSONDecodeError, IndexError).
            pass

        return {}

    async def _request_parameter_collection(self,
                                            llm_session: LLMSession,
                                            intent_config: Any,
                                            missing_params: List[str]):
        """Request parameter collection from user.

        Generates a follow-up question asking for up to
        max_params_per_question of the missing parameters and publishes it
        as a "parameter_request" job. Silently returns if project/version
        config is unavailable.
        """
        session = llm_session.session

        # Get project config
        project = next((p for p in self.config.projects if p.name == session.project_name), None)
        if not project:
            return

        version = session.get_version_config()
        if not version:
            return

        # Get parameter collection config
        collection_config = self.config.global_config.llm_provider.settings.get("parameter_collection_config", {})
        max_params = collection_config.get("max_params_per_question", 2)

        # Decide which parameters to ask
        params_to_ask = missing_params[:max_params]

        # Build parameter collection prompt
        prompt = build_parameter_prompt(
            version=version,
            intent_config=intent_config,
            chat_history=session.chat_history,
            collected_params=session.variables,
            missing_params=missing_params,
            params_to_ask=params_to_ask,
            max_params=max_params,
            project_locale=project.default_locale,
            unanswered_params=session.unanswered_parameters
        )

        # Generate question
        response = await llm_session.llm_instance.generate(
            system_prompt=prompt,
            user_input="",
            context=session.chat_history[-5:]
        )

        response_text = self._clean_response(response)

        # Create a job for the response
        job = LLMJob(
            job_id=f"{session.session_id}_param_request",
            session_id=session.session_id,
            input_text="",
            job_type="parameter_request",
            response_text=response_text
        )

        await self._publish_response(job)

    async def _execute_intent_action(self, llm_session: LLMSession, intent_config: Any):
        """Execute intent action (API call).

        Calls the intent's configured backend API, optionally "humanizes"
        the JSON result through the LLM, resets the session flow, and
        publishes the outcome. On failure a user-friendly error message is
        published instead.
        """
        session = llm_session.session

        try:
            # Get API config
            api_name = intent_config.action
            api_config = self.config.get_api(api_name)

            if not api_config:
                raise ValueError(f"API config not found: {api_name}")

            log_info(
                f"📡 Executing intent action",
                session_id=session.session_id,
                api_name=api_name,
                variables=session.variables
            )

            # Execute API call
            # NOTE(review): call_api is invoked synchronously inside an async
            # method — if it does blocking network I/O this stalls the event
            # loop; consider asyncio.to_thread. TODO confirm call_api's nature.
            from api_executor import call_api
            response = call_api(api_config, session)
            api_json = response.json()

            log_info(f"✅ API response received", session_id=session.session_id)

            # Humanize response if prompt exists
            if api_config.response_prompt:
                prompt = api_config.response_prompt.replace(
                    "{{api_response}}",
                    json.dumps(api_json, ensure_ascii=False)
                )

                human_response = await llm_session.llm_instance.generate(
                    system_prompt=prompt,
                    user_input=json.dumps(api_json),
                    context=[]
                )

                response_text = self._clean_response(human_response)
            else:
                response_text = f"İşlem tamamlandı: {api_json}"

            # Reset session flow
            session.reset_flow()

            # Create job for response
            job = LLMJob(
                job_id=f"{session.session_id}_action_result",
                session_id=session.session_id,
                input_text="",
                job_type="action_result",
                response_text=response_text
            )

            await self._publish_response(job)

        except Exception as e:
            log_error(
                f"❌ API execution failed",
                session_id=session.session_id,
                error=str(e)
            )

            # Reset flow
            session.reset_flow()

            # Send error response
            # NOTE(review): api_name is only bound if `intent_config.action`
            # evaluated without raising; a failure on that very line would
            # make this reference a NameError. TODO confirm this is acceptable.
            error_response = self._get_user_friendly_error("api_error", {"api_name": api_name})

            job = LLMJob(
                job_id=f"{session.session_id}_error",
                session_id=session.session_id,
                input_text="",
                job_type="error",
                response_text=error_response
            )

            await self._publish_response(job)

    async def _publish_response(self, job: LLMJob):
        """Publish LLM response as an LLM_RESPONSE_READY event and record
        the job in the session's bounded history."""
        # Update job history
        llm_session = self.llm_sessions.get(job.session_id)
        if llm_session:
            llm_session.job_history.append(job)
            # Keep only last 20 jobs
            if len(llm_session.job_history) > 20:
                llm_session.job_history.pop(0)

        # Publish event
        await self.event_bus.publish(Event(
            type=EventType.LLM_RESPONSE_READY,
            session_id=job.session_id,
            data={
                "text": job.response_text,
                "intent": job.detected_intent,
                "job_type": job.job_type
            }
        ))

        log_info(
            f"✅ LLM response published",
            session_id=job.session_id,
            response_length=len(job.response_text) if job.response_text else 0
        )

    def _parse_intent_response(self, response: str) -> tuple[str, str]:
        """Parse intent from LLM response.

        Returns (intent_name, remaining_text). intent_name is "" when no
        #DETECTED_INTENT: tag is present, in which case the full response
        is returned unchanged.
        """
        import re

        # Look for intent pattern
        match = re.search(r"#DETECTED_INTENT:\s*([A-Za-z0-9_-]+)", response)
        if not match:
            return "", response

        intent_name = match.group(1)

        # Remove 'assistant' suffix if exists (len("assistant") == 9 —
        # some models append their role token to the tag)
        if intent_name.endswith("assistant"):
            intent_name = intent_name[:-9]

        # Get remaining text after intent
        remaining_text = response[match.end():]

        return intent_name, remaining_text

    def _clean_response(self, response: str) -> str:
        """Clean LLM response.

        Truncates at the first stop marker (intent tag, warning emoji, or a
        stray "assistant" role token) and normalizes a common Turkish
        greeting, then strips surrounding whitespace.
        """
        # Remove everything after the first logical assistant block or intent tag
        for stop in ["#DETECTED_INTENT", "⚠️", "\nassistant", "assistant\n", "assistant"]:
            idx = response.find(stop)
            if idx != -1:
                response = response[:idx]

        # Normalize common greetings
        import re
        response = re.sub(r"Hoş[\s-]?geldin(iz)?", "Hoş geldiniz", response, flags=re.IGNORECASE)

        return response.strip()

    def _get_user_friendly_error(self, error_type: str, context: dict = None) -> str:
        """Get user-friendly error messages.

        Maps an internal error_type to a localized (Turkish) message for the
        end user; unknown types fall back to "internal_error". context may
        carry extra detail (currently only api_name for api_error).
        """
        error_messages = {
            "session_not_found": "Oturumunuz bulunamadı. Lütfen yeni bir konuşma başlatın.",
            "project_not_found": "Proje konfigürasyonu bulunamadı. Lütfen yönetici ile iletişime geçin.",
            "version_not_found": "Proje versiyonu bulunamadı. Lütfen geçerli bir versiyon seçin.",
            "intent_not_found": "Üzgünüm, ne yapmak istediğinizi anlayamadım. Lütfen daha açık bir şekilde belirtir misiniz?",
            "api_timeout": "İşlem zaman aşımına uğradı. Lütfen tekrar deneyin.",
            "api_error": "İşlem sırasında bir hata oluştu. Lütfen daha sonra tekrar deneyin.",
            "parameter_validation": "Girdiğiniz bilgide bir hata var. Lütfen kontrol edip tekrar deneyin.",
            "llm_error": "Sistem yanıt veremedi. Lütfen biraz sonra tekrar deneyin.",
            "llm_timeout": "Sistem meşgul. Lütfen birkaç saniye bekleyip tekrar deneyin.",
            "session_expired": "Oturumunuz zaman aşımına uğradı. Lütfen yeni bir konuşma başlatın.",
            "rate_limit": "Çok fazla istek gönderdiniz. Lütfen biraz bekleyin.",
            "internal_error": "Beklenmeyen bir hata oluştu. Lütfen yönetici ile iletişime geçin."
        }

        message = error_messages.get(error_type, error_messages["internal_error"])

        # Add context if available
        if context:
            if error_type == "api_error" and "api_name" in context:
                message = f"{context['api_name']} servisi için {message}"

        return message

    async def _handle_session_ended(self, event: Event):
        """Clean up LLM resources when session ends"""
        session_id = event.session_id
        await self._cleanup_session(session_id)

    async def _cleanup_session(self, session_id: str):
        """Clean up LLM session.

        Removes the session wrapper and schedules the pooled LLM resource
        for delayed release (so a quick reconnect can reuse it).
        """
        llm_session = self.llm_sessions.pop(session_id, None)
        if not llm_session:
            return

        try:
            # Release resource
            resource_id = f"llm_{session_id}"
            await self.resource_manager.release(resource_id, delay_seconds=180)  # 3 minutes

            log_info(
                f"🧹 LLM session cleaned up",
                session_id=session_id,
                total_jobs=llm_session.total_jobs,
                job_history_size=len(llm_session.job_history)
            )

        except Exception as e:
            log_error(
                f"❌ Error cleaning up LLM session",
                session_id=session_id,
                error=str(e)
            )

    async def _cleanup_llm_instance(self, llm_instance: LLMInterface):
        """Cleanup callback for LLM instance (registered with the pool)."""
        try:
            # LLM instances typically don't need special cleanup
            log_debug("🧹 LLM instance cleaned up")

        except Exception as e:
            log_error(f"❌ Error cleaning up LLM instance", error=str(e))

    def get_stats(self) -> Dict[str, Any]:
        """Get LLM manager statistics.

        Returns a snapshot dict with the active session count, the number
        of sessions that currently have a job in flight, and per-session
        details (active job id, counters, uptime, last activity).
        """
        session_stats = {}
        for session_id, llm_session in self.llm_sessions.items():
            session_stats[session_id] = {
                "active_job": llm_session.active_job.job_id if llm_session.active_job else None,
                "total_jobs": llm_session.total_jobs,
                "job_history_size": len(llm_session.job_history),
                "uptime_seconds": (datetime.utcnow() - llm_session.created_at).total_seconds(),
                "last_activity": llm_session.last_activity.isoformat()
            }

        return {
            "active_sessions": len(self.llm_sessions),
            "total_active_jobs": sum(1 for s in self.llm_sessions.values() if s.active_job),
            "sessions": session_stats
        }
|
resource_manager.py
ADDED
@@ -0,0 +1,401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Resource Manager for Flare
|
3 |
+
==========================
|
4 |
+
Manages lifecycle of all session resources
|
5 |
+
"""
|
6 |
+
import asyncio
|
7 |
+
from typing import Dict, Any, Optional, Callable, Set
|
8 |
+
from datetime import datetime, timedelta
|
9 |
+
from dataclasses import dataclass, field
|
10 |
+
import traceback
|
11 |
+
from enum import Enum
|
12 |
+
|
13 |
+
from event_bus import EventBus, Event, EventType
|
14 |
+
from logger import log_info, log_error, log_debug, log_warning
|
15 |
+
|
16 |
+
|
17 |
+
class ResourceType(Enum):
    """Types of resources managed.

    The value strings are used in pool registration and in log output;
    they are part of the external contract, so do not rename them.
    """
    STT_INSTANCE = "stt_instance"   # speech-to-text engine instance
    TTS_INSTANCE = "tts_instance"   # text-to-speech engine instance
    LLM_CONTEXT = "llm_context"     # pooled LLM provider instance
    AUDIO_BUFFER = "audio_buffer"   # per-session audio buffer
    WEBSOCKET = "websocket"         # websocket connection
    GENERIC = "generic"             # catch-all for untyped resources
|
25 |
+
|
26 |
+
|
27 |
+
@dataclass
class Resource:
    """Wrapper pairing a managed instance with its lifecycle metadata.

    Tracks ownership (session_id), timestamps for age/idleness decisions,
    an optional pending disposal task, and the callback used to tear the
    instance down.
    """
    resource_id: str
    resource_type: ResourceType
    session_id: str
    instance: Any
    created_at: datetime = field(default_factory=datetime.utcnow)
    last_accessed: datetime = field(default_factory=datetime.utcnow)
    disposal_task: Optional[asyncio.Task] = None
    cleanup_callback: Optional[Callable] = None

    def touch(self):
        """Mark the resource as just used."""
        self.last_accessed = datetime.utcnow()

    async def cleanup(self):
        """Tear down the wrapped instance via the registered callback.

        Sync callbacks are pushed to a worker thread so the event loop is
        never blocked; all failures are logged rather than raised.
        """
        try:
            callback = self.cleanup_callback
            if callback:
                if asyncio.iscoroutinefunction(callback):
                    await callback(self.instance)
                else:
                    await asyncio.to_thread(callback, self.instance)

            log_debug(
                f"🧹 Resource cleaned up",
                resource_id=self.resource_id,
                resource_type=self.resource_type.value
            )
        except Exception as e:
            log_error(
                f"❌ Error cleaning up resource",
                resource_id=self.resource_id,
                error=str(e)
            )
63 |
+
|
64 |
+
|
65 |
+
class ResourcePool:
    """Pool for reusable resources.

    Keeps up to ``max_idle`` idle instances; anything older than
    ``max_age_seconds`` is disposed instead of being reused. All public
    methods are serialized by an internal asyncio lock.
    """

    def __init__(self,
                 resource_type: ResourceType,
                 factory: Callable,
                 max_idle: int = 10,
                 max_age_seconds: int = 300):
        self.resource_type = resource_type
        self.factory = factory  # sync or async callable producing a new instance
        self.max_idle = max_idle
        self.max_age_seconds = max_age_seconds
        # FIX: the original annotation was ``List[Resource]`` but ``List`` is
        # never imported in this module (latent NameError under runtime
        # annotation evaluation / F821); use the builtin generic instead.
        self.idle_resources: list[Resource] = []
        self.lock = asyncio.Lock()

    async def acquire(self, session_id: str) -> Any:
        """Return a pooled instance if a fresh-enough one exists, else create one.

        Stale pooled entries encountered on the way are cleaned up and
        discarded. Sync factories run in a worker thread.
        """
        async with self.lock:
            # Try to get from pool
            now = datetime.utcnow()
            while self.idle_resources:
                resource = self.idle_resources.pop(0)
                age = (now - resource.created_at).total_seconds()

                if age < self.max_age_seconds:
                    # Reuse this resource
                    resource.session_id = session_id
                    resource.touch()
                    log_debug(
                        f"♻️ Reused pooled resource",
                        resource_type=self.resource_type.value,
                        age_seconds=age
                    )
                    return resource.instance
                else:
                    # Too old, cleanup
                    await resource.cleanup()

            # Create new resource
            if asyncio.iscoroutinefunction(self.factory):
                instance = await self.factory()
            else:
                instance = await asyncio.to_thread(self.factory)

            log_debug(
                f"🏗️ Created new resource",
                resource_type=self.resource_type.value
            )
            return instance

    async def release(self, resource: Resource):
        """Return a resource to the pool, or dispose of it if the pool is full."""
        async with self.lock:
            if len(self.idle_resources) < self.max_idle:
                resource.session_id = ""  # Clear session
                self.idle_resources.append(resource)
                log_debug(
                    f"📥 Resource returned to pool",
                    resource_type=self.resource_type.value,
                    pool_size=len(self.idle_resources)
                )
            else:
                # Pool full, cleanup
                await resource.cleanup()

    async def cleanup_old(self):
        """Dispose of idle resources older than ``max_age_seconds``."""
        async with self.lock:
            now = datetime.utcnow()
            active_resources = []

            for resource in self.idle_resources:
                age = (now - resource.created_at).total_seconds()
                if age < self.max_age_seconds:
                    active_resources.append(resource)
                else:
                    await resource.cleanup()

            self.idle_resources = active_resources
|
144 |
+
|
145 |
+
|
146 |
+
class ResourceManager:
    """Manages all resources with lifecycle and pooling.

    Tracks every live ``Resource`` wrapper by id, records which session
    owns which resources, and optionally recycles instances through
    per-type ``ResourcePool``s. Disposal is deferred by default
    (``disposal_delay_seconds``) so a resource re-acquired shortly after
    release is reused rather than rebuilt. A background task reaps aged
    pool entries and resources that look orphaned.
    """

    def __init__(self, event_bus: EventBus):
        self.event_bus = event_bus
        # resource_id -> live Resource wrapper
        self.resources: Dict[str, Resource] = {}
        # session_id -> ids of resources owned by that session
        self.session_resources: Dict[str, Set[str]] = {}
        # resource type -> optional reuse pool (see register_pool)
        self.pools: Dict[ResourceType, ResourcePool] = {}
        self.disposal_delay_seconds = 60  # Default disposal delay
        self._cleanup_task: Optional[asyncio.Task] = None
        self._running = False
        self._setup_event_handlers()

    def _setup_event_handlers(self):
        """Subscribe to lifecycle events"""
        self.event_bus.subscribe(EventType.SESSION_STARTED, self._handle_session_started)
        self.event_bus.subscribe(EventType.SESSION_ENDED, self._handle_session_ended)

    async def start(self):
        """Start resource manager and its periodic cleanup task."""
        if self._running:
            return  # idempotent: already running

        self._running = True
        self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
        log_info("✅ Resource manager started")

    async def stop(self):
        """Stop resource manager, cancel cleanup, and dispose everything."""
        self._running = False

        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass  # expected outcome of the cancel above

        # Cleanup all resources (copy of keys: release mutates the dict)
        for resource_id in list(self.resources.keys()):
            await self.release(resource_id, immediate=True)

        log_info("✅ Resource manager stopped")

    def register_pool(self,
                      resource_type: ResourceType,
                      factory: Callable,
                      max_idle: int = 10,
                      max_age_seconds: int = 300):
        """Register a resource pool for ``resource_type``.

        Once registered, acquire()/release() for that type goes through
        the pool instead of creating/destroying instances each time.
        """
        self.pools[resource_type] = ResourcePool(
            resource_type=resource_type,
            factory=factory,
            max_idle=max_idle,
            max_age_seconds=max_age_seconds
        )
        log_info(
            f"📊 Resource pool registered",
            resource_type=resource_type.value,
            max_idle=max_idle
        )

    async def acquire(self,
                      resource_id: str,
                      session_id: str,
                      resource_type: ResourceType,
                      factory: Optional[Callable] = None,
                      cleanup_callback: Optional[Callable] = None) -> Any:
        """Acquire a resource, reusing an existing or pooled instance.

        Lookup order: (1) an already-tracked resource with this id
        (cancels its pending disposal), (2) the type's pool if one is
        registered, (3) the supplied ``factory``. Raises ValueError if
        no source is available. Returns the raw instance, not the
        Resource wrapper.
        """

        # Check if already exists
        if resource_id in self.resources:
            resource = self.resources[resource_id]
            resource.touch()

            # Cancel any pending disposal so the resource stays alive
            if resource.disposal_task:
                resource.disposal_task.cancel()
                resource.disposal_task = None

            return resource.instance

        # Try to get from pool
        instance = None
        if resource_type in self.pools:
            # NOTE(review): assumes ResourcePool.acquire returns a raw
            # instance (not a Resource wrapper) — confirm against
            # ResourcePool.acquire, which is defined above this chunk.
            instance = await self.pools[resource_type].acquire(session_id)
        elif factory:
            # Create new resource; run sync factories off the event loop
            if asyncio.iscoroutinefunction(factory):
                instance = await factory()
            else:
                instance = await asyncio.to_thread(factory)
        else:
            raise ValueError(f"No factory or pool for resource type: {resource_type}")

        # Create resource wrapper
        resource = Resource(
            resource_id=resource_id,
            resource_type=resource_type,
            session_id=session_id,
            instance=instance,
            cleanup_callback=cleanup_callback
        )

        # Track resource
        self.resources[resource_id] = resource

        if session_id not in self.session_resources:
            self.session_resources[session_id] = set()
        self.session_resources[session_id].add(resource_id)

        log_info(
            f"📌 Resource acquired",
            resource_id=resource_id,
            resource_type=resource_type.value,
            session_id=session_id
        )

        return instance

    async def release(self,
                      resource_id: str,
                      delay_seconds: Optional[int] = None,
                      immediate: bool = False):
        """Release a resource now (``immediate``) or after a delay.

        The delayed path lets a quickly re-acquired resource be reused;
        acquire() cancels the pending disposal task.
        """

        resource = self.resources.get(resource_id)
        if not resource:
            return  # unknown id or already disposed

        if immediate:
            # Immediate cleanup
            await self._dispose_resource(resource_id)
        else:
            # Schedule disposal
            delay = delay_seconds or self.disposal_delay_seconds
            # NOTE(review): an already-pending disposal_task is not
            # cancelled before being overwritten here; the first task to
            # fire wins because _dispose_resource pops the entry —
            # confirm this is intended.
            resource.disposal_task = asyncio.create_task(
                self._delayed_disposal(resource_id, delay)
            )

            log_debug(
                f"⏱️ Resource disposal scheduled",
                resource_id=resource_id,
                delay_seconds=delay
            )

    async def _delayed_disposal(self, resource_id: str, delay_seconds: int):
        """Dispose resource after delay (cancelled if re-acquired)."""
        try:
            await asyncio.sleep(delay_seconds)
            await self._dispose_resource(resource_id)
        except asyncio.CancelledError:
            log_debug(f"🚫 Disposal cancelled", resource_id=resource_id)

    async def _dispose_resource(self, resource_id: str):
        """Actually dispose of a resource: untrack, then pool or cleanup."""
        resource = self.resources.pop(resource_id, None)
        if not resource:
            return  # lost the race with another disposal path

        # Remove from session tracking
        if resource.session_id in self.session_resources:
            self.session_resources[resource.session_id].discard(resource_id)

        # Return to pool or cleanup
        if resource.resource_type in self.pools:
            await self.pools[resource.resource_type].release(resource)
        else:
            await resource.cleanup()

        log_info(
            f"♻️ Resource disposed",
            resource_id=resource_id,
            resource_type=resource.resource_type.value
        )

    async def release_session_resources(self, session_id: str):
        """Immediately release every resource owned by ``session_id``."""
        # Copy: _dispose_resource mutates the tracked set while we iterate
        resource_ids = self.session_resources.get(session_id, set()).copy()

        for resource_id in resource_ids:
            await self.release(resource_id, immediate=True)

        # Remove session tracking
        self.session_resources.pop(session_id, None)

        log_info(
            f"🧹 Session resources released",
            session_id=session_id,
            count=len(resource_ids)
        )

    async def _handle_session_started(self, event: Event):
        """Initialize session resource tracking"""
        session_id = event.session_id
        self.session_resources[session_id] = set()

    async def _handle_session_ended(self, event: Event):
        """Cleanup session resources"""
        session_id = event.session_id
        await self.release_session_resources(session_id)

    async def _periodic_cleanup(self):
        """Periodic cleanup of old resources.

        Every minute: age out pooled idle resources, and schedule a
        short-delay disposal for any tracked resource not touched for 5
        minutes that has no disposal pending ("orphans").
        """
        while self._running:
            try:
                await asyncio.sleep(60)  # Check every minute

                # Cleanup old pooled resources
                for pool in self.pools.values():
                    await pool.cleanup_old()

                # Check for orphaned resources
                now = datetime.utcnow()
                for resource_id, resource in list(self.resources.items()):
                    age = (now - resource.last_accessed).total_seconds()

                    # If not accessed for 5 minutes and no disposal scheduled
                    if age > 300 and not resource.disposal_task:
                        log_warning(
                            f"⚠️ Orphaned resource detected",
                            resource_id=resource_id,
                            age_seconds=age
                        )
                        await self.release(resource_id, delay_seconds=30)

            except Exception as e:
                # Keep the loop alive: a single failed sweep must not
                # kill periodic cleanup for the process lifetime.
                log_error(
                    f"❌ Error in periodic cleanup",
                    error=str(e),
                    traceback=traceback.format_exc()
                )

    def get_stats(self) -> Dict[str, Any]:
        """Get resource manager statistics (counts only, no instances)."""
        pool_stats = {}
        for resource_type, pool in self.pools.items():
            pool_stats[resource_type.value] = {
                "idle_count": len(pool.idle_resources),
                "max_idle": pool.max_idle
            }

        return {
            "active_resources": len(self.resources),
            "sessions": len(self.session_resources),
            "pools": pool_stats,
            "total_resources_by_type": self._count_by_type()
        }

    def _count_by_type(self) -> Dict[str, int]:
        """Count tracked resources grouped by resource-type value."""
        counts = {}
        for resource in self.resources.values():
            type_name = resource.resource_type.value
            counts[type_name] = counts.get(type_name, 0) + 1
        return counts
|
state_orchestrator.py
ADDED
@@ -0,0 +1,511 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
State Orchestrator for Flare Realtime Chat
|
3 |
+
==========================================
|
4 |
+
Central state machine and flow control
|
5 |
+
"""
|
6 |
+
import asyncio
|
7 |
+
from typing import Dict, Optional, Set, Any
|
8 |
+
from enum import Enum
|
9 |
+
from datetime import datetime
|
10 |
+
import traceback
|
11 |
+
from dataclasses import dataclass, field
|
12 |
+
|
13 |
+
from event_bus import EventBus, Event, EventType, publish_state_transition, publish_error
|
14 |
+
from session import Session
|
15 |
+
from logger import log_info, log_error, log_debug, log_warning
|
16 |
+
|
17 |
+
|
18 |
+
class ConversationState(Enum):
    """Conversation states for the session state machine.

    Values are the wire-friendly strings published with state
    transition events; see StateOrchestrator.VALID_TRANSITIONS for the
    legal moves between them.
    """
    IDLE = "idle"                              # context created, nothing running yet
    INITIALIZING = "initializing"              # session bootstrap in progress
    PREPARING_WELCOME = "preparing_welcome"    # TTS synthesizing the welcome prompt
    PLAYING_WELCOME = "playing_welcome"        # welcome audio playing to the client
    LISTENING = "listening"                    # STT active, waiting for user speech
    PROCESSING_SPEECH = "processing_speech"    # final transcript handed to the LLM
    PREPARING_RESPONSE = "preparing_response"  # TTS synthesizing the LLM response
    PLAYING_RESPONSE = "playing_response"      # response audio playing to the client
    ERROR = "error"                            # recoverable fault; may return to LISTENING
    ENDED = "ended"                            # terminal; no transitions out
|
30 |
+
|
31 |
+
|
32 |
+
@dataclass
class SessionContext:
    """Context for a conversation session.

    Bundles everything the orchestrator tracks per session: the current
    state-machine state, handles to per-session components (populated
    by the respective lifecycle managers — not shown in this module),
    and activity timestamps used for orphan detection.
    """
    session_id: str
    session: Session
    state: ConversationState = ConversationState.IDLE  # current FSM state
    stt_instance: Optional[Any] = None          # per-session STT handle (set elsewhere)
    tts_instance: Optional[Any] = None          # per-session TTS handle (set elsewhere)
    llm_context: Optional[Any] = None           # LLM conversation context (set elsewhere)
    audio_buffer: Optional[Any] = None          # audio buffer handle (set elsewhere)
    websocket_connection: Optional[Any] = None  # client transport handle (set elsewhere)
    created_at: datetime = field(default_factory=datetime.utcnow)
    last_activity: datetime = field(default_factory=datetime.utcnow)
    metadata: Dict[str, Any] = field(default_factory=dict)  # e.g. welcome-prompt flags

    def update_activity(self):
        """Update last activity timestamp"""
        self.last_activity = datetime.utcnow()

    async def cleanup(self):
        """Cleanup all session resources"""
        # Cleanup will be implemented by resource managers; this only logs.
        log_debug(f"🧹 Cleaning up session context", session_id=self.session_id)
|
55 |
+
|
56 |
+
|
57 |
+
class StateOrchestrator:
    """Central state machine for conversation flow.

    Holds one ``SessionContext`` per session and validates every state
    change against ``VALID_TRANSITIONS`` before applying it to the
    context. All coordination with the STT/TTS/LLM components happens
    by publishing events on the shared ``EventBus``.
    """

    # Valid state transitions. LISTENING is reachable from the
    # PREPARING_* / PROCESSING_SPEECH states so the error handlers
    # below can recover by skipping a failed TTS/LLM step, and ENDED is
    # additionally treated as reachable from every non-terminal state
    # (see transition_to) so critical errors can always terminate.
    VALID_TRANSITIONS = {
        ConversationState.IDLE: {ConversationState.INITIALIZING},
        ConversationState.INITIALIZING: {ConversationState.PREPARING_WELCOME, ConversationState.LISTENING},
        ConversationState.PREPARING_WELCOME: {ConversationState.PLAYING_WELCOME, ConversationState.LISTENING, ConversationState.ERROR},
        ConversationState.PLAYING_WELCOME: {ConversationState.LISTENING, ConversationState.ERROR},
        ConversationState.LISTENING: {ConversationState.PROCESSING_SPEECH, ConversationState.ERROR, ConversationState.ENDED},
        ConversationState.PROCESSING_SPEECH: {ConversationState.PREPARING_RESPONSE, ConversationState.LISTENING, ConversationState.ERROR},
        ConversationState.PREPARING_RESPONSE: {ConversationState.PLAYING_RESPONSE, ConversationState.LISTENING, ConversationState.ERROR},
        ConversationState.PLAYING_RESPONSE: {ConversationState.LISTENING, ConversationState.ERROR},
        ConversationState.ERROR: {ConversationState.LISTENING, ConversationState.ENDED},
        ConversationState.ENDED: set()  # Terminal: no transitions from ENDED
    }

    def __init__(self, event_bus: EventBus):
        self.event_bus = event_bus
        # session_id -> SessionContext. The FSM state lives on the
        # context (context.state), never directly in this dict.
        self.sessions: Dict[str, SessionContext] = {}
        self._setup_event_handlers()

    def _setup_event_handlers(self):
        """Subscribe to relevant events"""
        # Session lifecycle
        self.event_bus.subscribe(EventType.SESSION_STARTED, self._handle_session_started)
        self.event_bus.subscribe(EventType.SESSION_ENDED, self._handle_session_ended)

        # STT events
        self.event_bus.subscribe(EventType.STT_READY, self._handle_stt_ready)
        self.event_bus.subscribe(EventType.STT_RESULT, self._handle_stt_result)
        self.event_bus.subscribe(EventType.STT_ERROR, self._handle_stt_error)

        # TTS events
        self.event_bus.subscribe(EventType.TTS_COMPLETED, self._handle_tts_completed)
        self.event_bus.subscribe(EventType.TTS_ERROR, self._handle_tts_error)

        # Audio events
        self.event_bus.subscribe(EventType.AUDIO_PLAYBACK_COMPLETED, self._handle_audio_playback_completed)

        # LLM events
        self.event_bus.subscribe(EventType.LLM_RESPONSE_READY, self._handle_llm_response_ready)
        self.event_bus.subscribe(EventType.LLM_ERROR, self._handle_llm_error)

        # Error events
        self.event_bus.subscribe(EventType.CRITICAL_ERROR, self._handle_critical_error)

    async def _handle_session_started(self, event: Event):
        """Create the session context and kick off welcome TTS or STT."""
        session_id = event.session_id
        session_data = event.data

        log_info(f"🎬 Session started", session_id=session_id)

        # Create session context (starts in IDLE)
        context = SessionContext(
            session_id=session_id,
            session=session_data.get("session"),
            metadata={
                "has_welcome": session_data.get("has_welcome", False),
                "welcome_text": session_data.get("welcome_text", "")
            }
        )

        self.sessions[session_id] = context

        # Transition to INITIALIZING
        await self.transition_to(session_id, ConversationState.INITIALIZING)

        # Check if welcome prompt exists
        if session_data.get("has_welcome"):
            await self.transition_to(session_id, ConversationState.PREPARING_WELCOME)

            # Request TTS for welcome message
            await self.event_bus.publish(Event(
                type=EventType.TTS_STARTED,
                session_id=session_id,
                data={
                    "text": session_data.get("welcome_text", ""),
                    "is_welcome": True
                }
            ))
        else:
            # No welcome, go straight to listening
            await self.transition_to(session_id, ConversationState.LISTENING)

            # Request STT start
            await self.event_bus.publish(Event(
                type=EventType.STT_STARTED,
                session_id=session_id,
                data={}
            ))

    async def _handle_session_ended(self, event: Event):
        """Tear down the session: stop components, cleanup, untrack."""
        session_id = event.session_id

        log_info(f"🏁 Session ended", session_id=session_id)

        # Get context for cleanup
        context = self.sessions.get(session_id)

        # Transition to ended (always allowed from non-terminal states)
        await self.transition_to(session_id, ConversationState.ENDED)

        # Stop all components
        await self.event_bus.publish(Event(
            type=EventType.STT_STOPPED,
            session_id=session_id,
            data={"reason": "session_ended"}
        ))

        # Cleanup session context
        if context:
            await context.cleanup()

        # Remove session
        self.sessions.pop(session_id, None)

        # Clear event bus session data
        self.event_bus.clear_session_data(session_id)

    async def _handle_stt_ready(self, event: Event):
        """Handle STT ready signal"""
        session_id = event.session_id
        current_state = self.get_state(session_id)

        log_debug(f"🎤 STT ready", session_id=session_id, current_state=current_state)

        # Only process if we're expecting STT to be ready; no state
        # change is needed — we are already in the right state.
        if current_state in [ConversationState.LISTENING, ConversationState.PLAYING_WELCOME]:
            pass

    async def _handle_stt_result(self, event: Event):
        """On a final transcription: stop STT and hand the text to the LLM."""
        session_id = event.session_id
        current_state = self.get_state(session_id)

        if current_state != ConversationState.LISTENING:
            log_warning(
                f"⚠️ STT result in unexpected state",
                session_id=session_id,
                state=current_state
            )
            return

        result_data = event.data
        is_final = result_data.get("is_final", False)

        if is_final:
            text = result_data.get("text", "")
            log_info(f"💬 Final transcription: '{text}'", session_id=session_id)

            # Stop STT
            await self.event_bus.publish(Event(
                type=EventType.STT_STOPPED,
                session_id=session_id,
                data={"reason": "final_result"}
            ))

            # Transition to processing
            await self.transition_to(session_id, ConversationState.PROCESSING_SPEECH)

            # Send to LLM
            await self.event_bus.publish(Event(
                type=EventType.LLM_PROCESSING_STARTED,
                session_id=session_id,
                data={"text": text}
            ))

    async def _handle_llm_response_ready(self, event: Event):
        """Forward the LLM's text response to TTS."""
        session_id = event.session_id
        current_state = self.get_state(session_id)

        if current_state != ConversationState.PROCESSING_SPEECH:
            log_warning(
                f"⚠️ LLM response in unexpected state",
                session_id=session_id,
                state=current_state
            )
            return

        response_text = event.data.get("text", "")
        log_info(f"🤖 LLM response ready", session_id=session_id, length=len(response_text))

        # Transition to preparing response
        await self.transition_to(session_id, ConversationState.PREPARING_RESPONSE)

        # Request TTS
        await self.event_bus.publish(Event(
            type=EventType.TTS_STARTED,
            session_id=session_id,
            data={"text": response_text}
        ))

    async def _handle_tts_completed(self, event: Event):
        """Synthesis finished: move to the matching playback state."""
        session_id = event.session_id
        current_state = self.get_state(session_id)

        log_info(f"🔊 TTS completed", session_id=session_id, state=current_state)

        if current_state == ConversationState.PREPARING_WELCOME:
            await self.transition_to(session_id, ConversationState.PLAYING_WELCOME)
        elif current_state == ConversationState.PREPARING_RESPONSE:
            await self.transition_to(session_id, ConversationState.PLAYING_RESPONSE)

    async def _handle_audio_playback_completed(self, event: Event):
        """Playback finished on the client: resume listening."""
        session_id = event.session_id
        current_state = self.get_state(session_id)

        log_info(f"🎵 Audio playback completed", session_id=session_id, state=current_state)

        if current_state in [ConversationState.PLAYING_WELCOME, ConversationState.PLAYING_RESPONSE]:
            # Transition back to listening
            await self.transition_to(session_id, ConversationState.LISTENING)

            # Start STT
            await self.event_bus.publish(Event(
                type=EventType.STT_STARTED,
                session_id=session_id,
                data={}
            ))

    async def _handle_stt_error(self, event: Event):
        """STT failed: enter ERROR, wait, then retry listening."""
        session_id = event.session_id
        error_data = event.data

        log_error(
            f"❌ STT error",
            session_id=session_id,
            error=error_data.get("message")
        )

        # Try to recover by transitioning back to listening
        current_state = self.get_state(session_id)
        if current_state != ConversationState.ENDED:
            await self.transition_to(session_id, ConversationState.ERROR)

            # Try recovery after delay
            await asyncio.sleep(2.0)

            # Only retry if nothing else moved the session meanwhile
            if self.get_state(session_id) == ConversationState.ERROR:
                await self.transition_to(session_id, ConversationState.LISTENING)

                # Restart STT
                await self.event_bus.publish(Event(
                    type=EventType.STT_STARTED,
                    session_id=session_id,
                    data={"retry": True}
                ))

    async def _handle_tts_error(self, event: Event):
        """TTS failed: skip synthesis and fall back to listening."""
        session_id = event.session_id
        error_data = event.data

        log_error(
            f"❌ TTS error",
            session_id=session_id,
            error=error_data.get("message")
        )

        # Skip TTS and go to listening (PREPARING_* -> LISTENING is an
        # explicitly allowed error-recovery transition, see table above)
        current_state = self.get_state(session_id)
        if current_state in [ConversationState.PREPARING_WELCOME, ConversationState.PREPARING_RESPONSE]:
            await self.transition_to(session_id, ConversationState.LISTENING)

            # Start STT
            await self.event_bus.publish(Event(
                type=EventType.STT_STARTED,
                session_id=session_id,
                data={}
            ))

    async def _handle_llm_error(self, event: Event):
        """LLM failed: drop the turn and resume listening."""
        session_id = event.session_id
        error_data = event.data

        log_error(
            f"❌ LLM error",
            session_id=session_id,
            error=error_data.get("message")
        )

        # Go back to listening (PROCESSING_SPEECH -> LISTENING is an
        # allowed error-recovery transition)
        await self.transition_to(session_id, ConversationState.LISTENING)

        # Start STT
        await self.event_bus.publish(Event(
            type=EventType.STT_STARTED,
            session_id=session_id,
            data={}
        ))

    async def _handle_critical_error(self, event: Event):
        """Unrecoverable failure: terminate the session."""
        session_id = event.session_id
        error_data = event.data

        log_error(
            f"💥 Critical error",
            session_id=session_id,
            error=error_data.get("message")
        )

        # End session (ENDED is reachable from any non-terminal state)
        await self.transition_to(session_id, ConversationState.ENDED)

        # Publish session end event
        await self.event_bus.publish(Event(
            type=EventType.SESSION_ENDED,
            session_id=session_id,
            data={"reason": "critical_error"}
        ))

    async def transition_to(self, session_id: str, new_state: ConversationState):
        """Validate and apply a state transition for ``session_id``.

        Invalid transitions are rejected (logged + error event), not
        raised. ENDED is always accepted from any non-terminal state so
        shutdown paths cannot get stuck.
        """
        context = self.sessions.get(session_id)

        if context is None:
            log_warning(f"⚠️ Session not found for transition", session_id=session_id)
            return

        current_state = context.state

        # Check if transition is valid. ENDED is universally reachable
        # from non-terminal states (critical errors / session teardown).
        allowed = self.VALID_TRANSITIONS.get(current_state, set())
        if current_state is not ConversationState.ENDED:
            allowed = allowed | {ConversationState.ENDED}

        if new_state not in allowed:
            log_error(
                f"❌ Invalid state transition",
                session_id=session_id,
                from_state=current_state.value,
                to_state=new_state.value
            )

            await publish_error(
                session_id=session_id,
                error_type="invalid_transition",
                error_message=f"Cannot transition from {current_state.value} to {new_state.value}"
            )
            return

        # BUGFIX: update the state *on the context*. The previous code
        # did `self.sessions[session_id] = new_state`, replacing the
        # SessionContext with a bare enum and breaking every subsequent
        # get_state()/context lookup.
        context.state = new_state
        context.update_activity()

        log_info(
            f"🔄 State transition",
            session_id=session_id,
            from_state=current_state.value,
            to_state=new_state.value
        )

        # Publish state transition event
        await publish_state_transition(
            session_id=session_id,
            from_state=current_state.value,
            to_state=new_state.value
        )

    def get_state(self, session_id: str) -> Optional[ConversationState]:
        """Get current state for a session (None if unknown)."""
        context = self.sessions.get(session_id)
        return context.state if context else None

    def get_session_data(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Get session metadata (None if the session is unknown).

        BUGFIX: previously read ``self.session_data``, an attribute that
        was never defined (guaranteed AttributeError); session metadata
        lives on the SessionContext.
        """
        context = self.sessions.get(session_id)
        return context.metadata if context else None

    async def handle_error_recovery(self, session_id: str, error_type: str):
        """Dispatch to a per-error-type recovery strategy."""
        context = self.sessions.get(session_id)

        if not context or context.state == ConversationState.ENDED:
            return

        log_info(
            f"🔧 Attempting error recovery",
            session_id=session_id,
            error_type=error_type,
            current_state=context.state.value
        )

        # Update activity
        context.update_activity()

        # Define recovery strategies
        recovery_strategies = {
            "stt_error": self._recover_from_stt_error,
            "tts_error": self._recover_from_tts_error,
            "llm_error": self._recover_from_llm_error,
            "websocket_error": self._recover_from_websocket_error
        }

        strategy = recovery_strategies.get(error_type)
        if strategy:
            await strategy(session_id)
        else:
            # Default recovery: go to error state then back to listening
            await self.transition_to(session_id, ConversationState.ERROR)
            await asyncio.sleep(1.0)
            await self.transition_to(session_id, ConversationState.LISTENING)

    async def _recover_from_stt_error(self, session_id: str):
        """Recover from STT error: stop STT, wait, restart."""
        await self.event_bus.publish(Event(
            type=EventType.STT_STOPPED,
            session_id=session_id,
            data={"reason": "error_recovery"}
        ))

        await asyncio.sleep(2.0)

        await self.transition_to(session_id, ConversationState.LISTENING)

        await self.event_bus.publish(Event(
            type=EventType.STT_STARTED,
            session_id=session_id,
            data={"retry": True}
        ))

    async def _recover_from_tts_error(self, session_id: str):
        """Recover from TTS error: skip synthesis, go back to listening."""
        await self.transition_to(session_id, ConversationState.LISTENING)

        await self.event_bus.publish(Event(
            type=EventType.STT_STARTED,
            session_id=session_id,
            data={}
        ))

    async def _recover_from_llm_error(self, session_id: str):
        """Recover from LLM error: drop the turn, go back to listening."""
        await self.transition_to(session_id, ConversationState.LISTENING)

        await self.event_bus.publish(Event(
            type=EventType.STT_STARTED,
            session_id=session_id,
            data={}
        ))

    async def _recover_from_websocket_error(self, session_id: str):
        """Recover from WebSocket error: end the session cleanly."""
        await self.transition_to(session_id, ConversationState.ENDED)

        await self.event_bus.publish(Event(
            type=EventType.SESSION_ENDED,
            session_id=session_id,
            data={"reason": "websocket_error"}
        ))
|
stt_lifecycle_manager.py
ADDED
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
STT Lifecycle Manager for Flare
|
3 |
+
===============================
|
4 |
+
Manages STT instances lifecycle per session
|
5 |
+
"""
|
6 |
+
import asyncio
|
7 |
+
from typing import Dict, Optional, Any
|
8 |
+
from datetime import datetime
|
9 |
+
import traceback
|
10 |
+
import base64
|
11 |
+
|
12 |
+
from event_bus import EventBus, Event, EventType, publish_error
|
13 |
+
from resource_manager import ResourceManager, ResourceType
|
14 |
+
from stt_factory import STTFactory
|
15 |
+
from stt_interface import STTInterface, STTConfig, TranscriptionResult
|
16 |
+
from logger import log_info, log_error, log_debug, log_warning
|
17 |
+
|
18 |
+
|
19 |
+
class STTSession:
    """Per-session wrapper around a provider STT instance.

    Carries the streaming flag, the active configuration, simple usage
    counters (chunks/bytes), and creation/activity timestamps.
    """

    def __init__(self, session_id: str, stt_instance: STTInterface):
        # Identity and provider handle
        self.session_id = session_id
        self.stt_instance = stt_instance
        # Usage counters
        self.total_chunks = 0
        self.total_bytes = 0
        # Streaming state
        self.is_streaming = False
        self.config: Optional[STTConfig] = None
        # Timestamps
        self.created_at = datetime.utcnow()
        self.last_activity = datetime.utcnow()

    def update_activity(self):
        """Stamp the session as active right now."""
        self.last_activity = datetime.utcnow()
|
35 |
+
|
36 |
+
|
37 |
+
class STTLifecycleManager:
    """Owns the speech-to-text side of every session.

    Subscribes to STT start/stop and audio-chunk events on the event bus,
    acquires provider instances from the shared ResourceManager pool, streams
    audio into them, and republishes transcriptions as STT_RESULT events.
    """

    def __init__(self, event_bus: EventBus, resource_manager: ResourceManager):
        self.event_bus = event_bus
        self.resource_manager = resource_manager
        # session_id -> STTSession wrapper for every session with an STT.
        self.stt_sessions: Dict[str, STTSession] = {}
        self._setup_event_handlers()
        self._setup_resource_pool()

    def _setup_event_handlers(self):
        """Subscribe to the STT-related events this manager reacts to."""
        self.event_bus.subscribe(EventType.STT_STARTED, self._handle_stt_start)
        self.event_bus.subscribe(EventType.STT_STOPPED, self._handle_stt_stop)
        self.event_bus.subscribe(EventType.AUDIO_CHUNK_RECEIVED, self._handle_audio_chunk)
        self.event_bus.subscribe(EventType.SESSION_ENDED, self._handle_session_ended)

    def _setup_resource_pool(self):
        """Register the pooled STT-instance factory with the resource manager."""
        self.resource_manager.register_pool(
            resource_type=ResourceType.STT_INSTANCE,
            factory=self._create_stt_instance,
            max_idle=5,
            max_age_seconds=300  # 5 minutes
        )

    async def _create_stt_instance(self) -> STTInterface:
        """Pool factory: create a fresh STT provider instance.

        Raises:
            ValueError: when the factory returns a falsy provider.
            Exception: any provider-construction error (re-raised after logging).
        """
        try:
            stt_instance = STTFactory.create_provider()
            if not stt_instance:
                raise ValueError("Failed to create STT instance")

            log_debug("🎤 Created new STT instance")
            return stt_instance

        except Exception as e:
            log_error(f"❌ Failed to create STT instance", error=str(e))
            raise

    async def _handle_stt_start(self, event: Event):
        """Handle an STT start request: acquire an instance, configure it, start streaming.

        On success publishes STT_READY; on any failure the session is cleaned
        up and an "stt_error" error event is published instead.
        """
        session_id = event.session_id
        config_data = event.data

        try:
            log_info(f"🎤 Starting STT", session_id=session_id)

            # Reuse an existing session if one is present for this session_id.
            if session_id in self.stt_sessions:
                stt_session = self.stt_sessions[session_id]
                if stt_session.is_streaming:
                    # Already live — treat the duplicate start as a no-op.
                    log_warning(f"⚠️ STT already streaming", session_id=session_id)
                    return
            else:
                # Acquire an STT instance from the pool for this session.
                resource_id = f"stt_{session_id}"
                stt_instance = await self.resource_manager.acquire(
                    resource_id=resource_id,
                    session_id=session_id,
                    resource_type=ResourceType.STT_INSTANCE,
                    cleanup_callback=self._cleanup_stt_instance
                )

                # Wrap and register the new session.
                stt_session = STTSession(session_id, stt_instance)
                self.stt_sessions[session_id] = stt_session

            # Locale comes from the event payload; "tr" is the project default.
            locale = config_data.get("locale", "tr")

            # Build the STT config from event data, with conservative defaults.
            stt_config = STTConfig(
                language=self._get_language_code(locale),
                sample_rate=config_data.get("sample_rate", 16000),
                encoding=config_data.get("encoding", "WEBM_OPUS"),
                enable_punctuation=config_data.get("enable_punctuation", True),
                enable_word_timestamps=False,
                model=config_data.get("model", "latest_long"),
                use_enhanced=config_data.get("use_enhanced", True),
                single_utterance=False,  # Continuous listening
                interim_results=config_data.get("interim_results", True),
                vad_enabled=config_data.get("vad_enabled", True),
                speech_timeout_ms=config_data.get("speech_timeout_ms", 2000),
                noise_reduction_enabled=config_data.get("noise_reduction_enabled", True),
                noise_reduction_level=config_data.get("noise_reduction_level", 2)
            )

            stt_session.config = stt_config

            # Start the provider stream and mark the session live.
            await stt_session.stt_instance.start_streaming(stt_config)
            stt_session.is_streaming = True
            stt_session.update_activity()

            log_info(f"✅ STT started", session_id=session_id, language=stt_config.language)

            # Tell the rest of the system (and ultimately the client) STT is ready.
            await self.event_bus.publish(Event(
                type=EventType.STT_READY,
                session_id=session_id,
                data={"language": stt_config.language}
            ))

        except Exception as e:
            log_error(
                f"❌ Failed to start STT",
                session_id=session_id,
                error=str(e),
                traceback=traceback.format_exc()
            )

            # Drop any half-initialized session so a retry starts clean.
            if session_id in self.stt_sessions:
                await self._cleanup_session(session_id)

            # Surface the failure as an error event.
            await publish_error(
                session_id=session_id,
                error_type="stt_error",
                error_message=f"Failed to start STT: {str(e)}"
            )

    async def _handle_stt_stop(self, event: Event):
        """Stop streaming for a session, publishing any final transcription.

        The session wrapper is intentionally kept around (only streaming is
        stopped) so a subsequent start can reuse it.
        """
        session_id = event.session_id
        reason = event.data.get("reason", "unknown")

        log_info(f"🛑 Stopping STT", session_id=session_id, reason=reason)

        stt_session = self.stt_sessions.get(session_id)
        if not stt_session:
            log_warning(f"⚠️ No STT session found", session_id=session_id)
            return

        try:
            if stt_session.is_streaming:
                # Stop the provider stream; it may hand back one last result.
                final_result = await stt_session.stt_instance.stop_streaming()
                stt_session.is_streaming = False

                # Forward the trailing final result, if any, to subscribers.
                if final_result and final_result.text:
                    await self.event_bus.publish(Event(
                        type=EventType.STT_RESULT,
                        session_id=session_id,
                        data={
                            "text": final_result.text,
                            "is_final": True,
                            "confidence": final_result.confidence
                        }
                    ))

            # Don't remove session immediately - might restart
            stt_session.update_activity()

            log_info(f"✅ STT stopped", session_id=session_id)

        except Exception as e:
            log_error(
                f"❌ Error stopping STT",
                session_id=session_id,
                error=str(e)
            )

    async def _handle_audio_chunk(self, event: Event):
        """Feed one base64-encoded audio chunk into the session's STT stream.

        Chunks arriving before STT is streaming are silently dropped. Stream
        errors are classified as "stt_timeout" (restartable) vs "stt_error".
        """
        session_id = event.session_id

        stt_session = self.stt_sessions.get(session_id)
        if not stt_session or not stt_session.is_streaming:
            # STT not ready yet — drop the chunk rather than queueing it.
            return

        try:
            # Chunks travel base64-encoded inside the event payload.
            audio_data = base64.b64decode(event.data.get("audio_data", ""))

            # Bookkeeping before handing the bytes to the provider.
            stt_session.total_chunks += 1
            stt_session.total_bytes += len(audio_data)
            stt_session.update_activity()

            # The provider yields zero or more (interim/final) results per chunk.
            async for result in stt_session.stt_instance.stream_audio(audio_data):
                # Republish every transcription result on the bus.
                await self.event_bus.publish(Event(
                    type=EventType.STT_RESULT,
                    session_id=session_id,
                    data={
                        "text": result.text,
                        "is_final": result.is_final,
                        "confidence": result.confidence,
                        "timestamp": result.timestamp
                    }
                ))

                # Only final results are worth an info-level log line.
                if result.is_final:
                    log_info(
                        f"📝 STT final result",
                        session_id=session_id,
                        text=result.text[:50] + "..." if len(result.text) > 50 else result.text,
                        confidence=result.confidence
                    )

            # Periodic progress log, every 100 chunks.
            if stt_session.total_chunks % 100 == 0:
                log_debug(
                    f"📊 STT progress",
                    session_id=session_id,
                    chunks=stt_session.total_chunks,
                    bytes=stt_session.total_bytes
                )

        except Exception as e:
            log_error(
                f"❌ Error processing audio chunk",
                session_id=session_id,
                error=str(e)
            )

            # Heuristic classification by message text: timeouts are recoverable
            # by restarting the stream, anything else is a generic STT error.
            if "stream duration" in str(e) or "timeout" in str(e).lower():
                await publish_error(
                    session_id=session_id,
                    error_type="stt_timeout",
                    error_message="STT stream timeout, restart needed"
                )
            else:
                await publish_error(
                    session_id=session_id,
                    error_type="stt_error",
                    error_message=str(e)
                )

    async def _handle_session_ended(self, event: Event):
        """Release all STT resources when the owning session ends."""
        session_id = event.session_id
        await self._cleanup_session(session_id)

    async def _cleanup_session(self, session_id: str):
        """Tear down a session: stop streaming and release the pooled instance."""
        stt_session = self.stt_sessions.pop(session_id, None)
        if not stt_session:
            return

        try:
            # Stop the provider stream if it is still running.
            if stt_session.is_streaming:
                await stt_session.stt_instance.stop_streaming()

            # Return the instance to the pool; the 60s delay leaves a grace
            # window in case the session reconnects quickly.
            resource_id = f"stt_{session_id}"
            await self.resource_manager.release(resource_id, delay_seconds=60)

            log_info(
                f"🧹 STT session cleaned up",
                session_id=session_id,
                total_chunks=stt_session.total_chunks,
                total_bytes=stt_session.total_bytes
            )

        except Exception as e:
            log_error(
                f"❌ Error cleaning up STT session",
                session_id=session_id,
                error=str(e)
            )

    async def _cleanup_stt_instance(self, stt_instance: STTInterface):
        """Pool cleanup callback: ensure a recycled instance is not left streaming."""
        try:
            # hasattr guard: not every provider exposes an is_streaming flag.
            if hasattr(stt_instance, 'is_streaming') and stt_instance.is_streaming:
                await stt_instance.stop_streaming()

            log_debug("🧹 STT instance cleaned up")

        except Exception as e:
            log_error(f"❌ Error cleaning up STT instance", error=str(e))

    def _get_language_code(self, locale: str) -> str:
        """Convert a short locale ("tr") to a BCP-47-style STT code ("tr-TR").

        Full 5-character codes (e.g. "en-GB") pass through unchanged; unknown
        short locales fall back to the "xx-XX" pattern, which may not be a
        valid provider code — callers should pass known locales.
        """
        # Map common locales to STT language codes
        locale_map = {
            "tr": "tr-TR",
            "en": "en-US",
            "de": "de-DE",
            "fr": "fr-FR",
            "es": "es-ES",
            "it": "it-IT",
            "pt": "pt-BR",
            "ru": "ru-RU",
            "ja": "ja-JP",
            "ko": "ko-KR",
            "zh": "zh-CN",
            "ar": "ar-SA"
        }

        # Check direct match
        if locale in locale_map:
            return locale_map[locale]

        # Check if it's already a full code
        if "-" in locale and len(locale) == 5:
            return locale

        # Default to locale-LOCALE format
        return f"{locale}-{locale.upper()}"

    def get_stats(self) -> Dict[str, Any]:
        """Return a per-session and aggregate statistics snapshot (for monitoring)."""
        session_stats = {}
        for session_id, stt_session in self.stt_sessions.items():
            session_stats[session_id] = {
                "is_streaming": stt_session.is_streaming,
                "total_chunks": stt_session.total_chunks,
                "total_bytes": stt_session.total_bytes,
                "uptime_seconds": (datetime.utcnow() - stt_session.created_at).total_seconds(),
                "last_activity": stt_session.last_activity.isoformat()
            }

        return {
            "active_sessions": len(self.stt_sessions),
            "streaming_sessions": sum(1 for s in self.stt_sessions.values() if s.is_streaming),
            "sessions": session_stats
        }
tts_lifecycle_manager.py
ADDED
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
TTS Lifecycle Manager for Flare
|
3 |
+
===============================
|
4 |
+
Manages TTS instances lifecycle per session
|
5 |
+
"""
|
6 |
+
import asyncio
|
7 |
+
from typing import Dict, Optional, Any, List
|
8 |
+
from datetime import datetime
|
9 |
+
import traceback
|
10 |
+
import base64
|
11 |
+
|
12 |
+
from event_bus import EventBus, Event, EventType, publish_error
|
13 |
+
from resource_manager import ResourceManager, ResourceType
|
14 |
+
from tts_factory import TTSFactory
|
15 |
+
from tts_interface import TTSInterface
|
16 |
+
from tts_preprocessor import TTSPreprocessor
|
17 |
+
from logger import log_info, log_error, log_debug, log_warning
|
18 |
+
|
19 |
+
|
20 |
+
class TTSJob:
    """One text-to-speech synthesis request and its outcome.

    A job starts pending (no audio, no error, no completion time) and ends in
    exactly one of two terminal states via complete() or fail().
    """

    def __init__(self, job_id: str, session_id: str, text: str, is_welcome: bool = False):
        # Request identity and payload.
        self.job_id = job_id
        self.session_id = session_id
        self.text = text
        self.is_welcome = is_welcome
        # Outcome fields — populated by complete()/fail().
        self.audio_data: Optional[bytes] = None
        self.error: Optional[str] = None
        self.completed_at: Optional[datetime] = None
        # Lifecycle bookkeeping.
        self.created_at = datetime.utcnow()
        self.chunks_sent = 0

    def complete(self, audio_data: bytes):
        """Record successful synthesis and stamp the completion time."""
        self.audio_data = audio_data
        self.completed_at = datetime.utcnow()

    def fail(self, error: str):
        """Record a failure message and stamp the completion time."""
        self.error = error
        self.completed_at = datetime.utcnow()
+
|
45 |
+
class TTSSession:
    """Per-session TTS state: provider instance, preprocessor, and job history."""

    def __init__(self, session_id: str, tts_instance: TTSInterface):
        # Identity and the underlying provider instance.
        self.session_id = session_id
        self.tts_instance = tts_instance
        # Text preprocessor is attached later, once the locale is known.
        self.preprocessor: Optional[TTSPreprocessor] = None
        # In-flight jobs keyed by job id, plus a bounded history of finished ones.
        self.active_jobs: Dict[str, TTSJob] = {}
        self.completed_jobs: List[TTSJob] = []
        # Aggregate counters across the session's lifetime.
        self.total_jobs = 0
        self.total_chars = 0
        # Timestamps are naive UTC, matching the rest of the codebase.
        self.created_at = datetime.utcnow()
        self.last_activity = datetime.utcnow()

    def update_activity(self):
        """Refresh the last-activity timestamp to the current UTC time."""
        self.last_activity = datetime.utcnow()
63 |
+
|
64 |
+
class TTSLifecycleManager:
    """Owns the text-to-speech side of every session.

    Reacts to TTS_STARTED events by acquiring a pooled provider instance,
    preprocessing and synthesizing the text, then streaming the resulting
    audio back to subscribers as base64 TTS_CHUNK_READY events followed by a
    TTS_COMPLETED event.
    """

    def __init__(self, event_bus: EventBus, resource_manager: ResourceManager):
        self.event_bus = event_bus
        self.resource_manager = resource_manager
        # session_id -> TTSSession wrapper for every session with a TTS.
        self.tts_sessions: Dict[str, TTSSession] = {}
        self.chunk_size = 16384  # 16KB chunks for base64
        self._setup_event_handlers()
        self._setup_resource_pool()

    def _setup_event_handlers(self):
        """Subscribe to the TTS-related events this manager reacts to."""
        self.event_bus.subscribe(EventType.TTS_STARTED, self._handle_tts_start)
        self.event_bus.subscribe(EventType.SESSION_ENDED, self._handle_session_ended)

    def _setup_resource_pool(self):
        """Register the pooled TTS-instance factory with the resource manager."""
        self.resource_manager.register_pool(
            resource_type=ResourceType.TTS_INSTANCE,
            factory=self._create_tts_instance,
            max_idle=3,
            max_age_seconds=600  # 10 minutes
        )

    async def _create_tts_instance(self) -> Optional[TTSInterface]:
        """Pool factory: create a TTS provider, or None when none is configured.

        Unlike the STT factory, failures here return None rather than raising —
        the system can run without audio output.
        """
        try:
            tts_instance = TTSFactory.create_provider()
            if not tts_instance:
                log_warning("⚠️ No TTS provider configured")
                return None

            log_debug("🔊 Created new TTS instance")
            return tts_instance

        except Exception as e:
            log_error(f"❌ Failed to create TTS instance", error=str(e))
            return None

    async def _handle_tts_start(self, event: Event):
        """Handle a synthesis request: get/create the session, enqueue and run a job.

        Empty text is ignored. When no TTS provider is available the request
        degrades to a text-only TTS_COMPLETED via _handle_no_tts().
        """
        session_id = event.session_id
        text = event.data.get("text", "")
        is_welcome = event.data.get("is_welcome", False)

        if not text:
            log_warning(f"⚠️ Empty text for TTS", session_id=session_id)
            return

        try:
            log_info(
                f"🔊 Starting TTS",
                session_id=session_id,
                text_length=len(text),
                is_welcome=is_welcome
            )

            # Get or create the per-session TTS wrapper.
            if session_id not in self.tts_sessions:
                # Acquire a TTS instance from the pool for this session.
                resource_id = f"tts_{session_id}"
                tts_instance = await self.resource_manager.acquire(
                    resource_id=resource_id,
                    session_id=session_id,
                    resource_type=ResourceType.TTS_INSTANCE,
                    cleanup_callback=self._cleanup_tts_instance
                )

                if not tts_instance:
                    # No provider configured — degrade gracefully without audio.
                    await self._handle_no_tts(session_id, text, is_welcome)
                    return

                # Create the session wrapper with a locale-bound preprocessor.
                tts_session = TTSSession(session_id, tts_instance)

                # Locale comes from the event payload; "tr" is the project default.
                locale = event.data.get("locale", "tr")
                tts_session.preprocessor = TTSPreprocessor(language=locale)

                self.tts_sessions[session_id] = tts_session
            else:
                tts_session = self.tts_sessions[session_id]

            # Create a job; the id embeds a per-session running counter.
            job_id = f"{session_id}_{tts_session.total_jobs}"
            job = TTSJob(job_id, session_id, text, is_welcome)
            tts_session.active_jobs[job_id] = job
            tts_session.total_jobs += 1
            tts_session.total_chars += len(text)
            tts_session.update_activity()

            # Synthesize and stream; this awaits until streaming is done.
            await self._process_tts_job(tts_session, job)

        except Exception as e:
            log_error(
                f"❌ Failed to start TTS",
                session_id=session_id,
                error=str(e),
                traceback=traceback.format_exc()
            )

            # Surface the failure as an error event.
            await publish_error(
                session_id=session_id,
                error_type="tts_error",
                error_message=f"Failed to synthesize speech: {str(e)}"
            )

    async def _process_tts_job(self, tts_session: TTSSession, job: TTSJob):
        """Run one job end-to-end: preprocess, synthesize, stream, archive.

        On failure the job is marked failed and an error event is published;
        "quota"/"limit" messages get the dedicated tts_quota_exceeded type.
        """
        try:
            # Normalize the text for the specific provider (flags come from it).
            processed_text = tts_session.preprocessor.preprocess(
                job.text,
                tts_session.tts_instance.get_preprocessing_flags()
            )

            log_debug(
                f"📝 TTS preprocessed",
                session_id=job.session_id,
                original_length=len(job.text),
                processed_length=len(processed_text)
            )

            # Synthesize the full audio blob in one provider call.
            audio_data = await tts_session.tts_instance.synthesize(processed_text)

            if not audio_data:
                raise ValueError("TTS returned empty audio data")

            job.complete(audio_data)

            log_info(
                f"✅ TTS synthesis complete",
                session_id=job.session_id,
                audio_size=len(audio_data),
                duration_ms=(datetime.utcnow() - job.created_at).total_seconds() * 1000
            )

            # Fan the audio out to subscribers in base64 chunks.
            await self._stream_audio_chunks(tts_session, job)

            # Move the job from active to the bounded completed history.
            tts_session.active_jobs.pop(job.job_id, None)
            tts_session.completed_jobs.append(job)

            # Keep only the last 10 completed jobs to bound memory.
            if len(tts_session.completed_jobs) > 10:
                tts_session.completed_jobs.pop(0)

        except Exception as e:
            job.fail(str(e))

            # Classify by message text: quota problems get their own error type.
            error_message = str(e)
            if "quota" in error_message.lower() or "limit" in error_message.lower():
                log_error(f"❌ TTS quota exceeded", session_id=job.session_id)
                await publish_error(
                    session_id=job.session_id,
                    error_type="tts_quota_exceeded",
                    error_message="TTS service quota exceeded"
                )
            else:
                log_error(
                    f"❌ TTS synthesis failed",
                    session_id=job.session_id,
                    error=error_message
                )
                await publish_error(
                    session_id=job.session_id,
                    error_type="tts_error",
                    error_message=error_message
                )

    async def _stream_audio_chunks(self, tts_session: TTSSession, job: TTSJob):
        """Publish the job's audio as base64 chunks, then a TTS_COMPLETED event.

        Chunking is done on the base64 text (self.chunk_size characters per
        chunk), not on the raw bytes.
        """
        if not job.audio_data:
            return

        # Encode once, then slice the base64 string into fixed-size chunks.
        audio_base64 = base64.b64encode(job.audio_data).decode('utf-8')
        total_length = len(audio_base64)
        total_chunks = (total_length + self.chunk_size - 1) // self.chunk_size

        log_debug(
            f"📤 Streaming TTS audio",
            session_id=job.session_id,
            total_size=len(job.audio_data),
            base64_size=total_length,
            chunks=total_chunks
        )

        # Publish each chunk with its index so the client can reassemble.
        for i in range(0, total_length, self.chunk_size):
            chunk = audio_base64[i:i + self.chunk_size]
            chunk_index = i // self.chunk_size
            is_last = chunk_index == total_chunks - 1

            await self.event_bus.publish(Event(
                type=EventType.TTS_CHUNK_READY,
                session_id=job.session_id,
                data={
                    "audio_data": chunk,
                    "chunk_index": chunk_index,
                    "total_chunks": total_chunks,
                    "is_last": is_last,
                    "mime_type": "audio/mpeg",
                    "is_welcome": job.is_welcome
                },
                priority=8  # Higher priority for audio chunks
            ))

            job.chunks_sent += 1

            # Small delay between chunks to prevent overwhelming
            await asyncio.sleep(0.01)

        # Signal that all chunks for this job have been published.
        await self.event_bus.publish(Event(
            type=EventType.TTS_COMPLETED,
            session_id=job.session_id,
            data={
                "job_id": job.job_id,
                "total_chunks": total_chunks,
                "is_welcome": job.is_welcome
            }
        ))

        log_info(
            f"✅ TTS streaming complete",
            session_id=job.session_id,
            chunks_sent=job.chunks_sent
        )

    async def _handle_no_tts(self, session_id: str, text: str, is_welcome: bool):
        """Degrade gracefully when no TTS provider exists: complete without audio."""
        log_warning(f"⚠️ No TTS available, skipping audio generation", session_id=session_id)

        # Publish completion with the raw text so the client can still display it.
        await self.event_bus.publish(Event(
            type=EventType.TTS_COMPLETED,
            session_id=session_id,
            data={
                "no_audio": True,
                "text": text,
                "is_welcome": is_welcome
            }
        ))

    async def _handle_session_ended(self, event: Event):
        """Release all TTS resources when the owning session ends."""
        session_id = event.session_id
        await self._cleanup_session(session_id)

    async def _cleanup_session(self, session_id: str):
        """Tear down a session: fail pending jobs and release the pooled instance."""
        tts_session = self.tts_sessions.pop(session_id, None)
        if not tts_session:
            return

        try:
            # Mark any unfinished jobs as failed so nothing waits on them.
            for job in tts_session.active_jobs.values():
                if not job.completed_at:
                    job.fail("Session ended")

            # Return the instance to the pool; the 120s delay leaves a grace
            # window in case the session reconnects quickly.
            resource_id = f"tts_{session_id}"
            await self.resource_manager.release(resource_id, delay_seconds=120)

            log_info(
                f"🧹 TTS session cleaned up",
                session_id=session_id,
                total_jobs=tts_session.total_jobs,
                total_chars=tts_session.total_chars
            )

        except Exception as e:
            log_error(
                f"❌ Error cleaning up TTS session",
                session_id=session_id,
                error=str(e)
            )

    async def _cleanup_tts_instance(self, tts_instance: TTSInterface):
        """Pool cleanup callback. Currently a logging no-op.

        TTS providers hold no streaming state here, so there is nothing to
        release; the try/except is defensive only.
        """
        try:
            # TTS instances typically don't need special cleanup
            log_debug("🧹 TTS instance cleaned up")

        except Exception as e:
            log_error(f"❌ Error cleaning up TTS instance", error=str(e))

    def get_stats(self) -> Dict[str, Any]:
        """Return a per-session and aggregate statistics snapshot (for monitoring)."""
        session_stats = {}
        for session_id, tts_session in self.tts_sessions.items():
            session_stats[session_id] = {
                "active_jobs": len(tts_session.active_jobs),
                "completed_jobs": len(tts_session.completed_jobs),
                "total_jobs": tts_session.total_jobs,
                "total_chars": tts_session.total_chars,
                "uptime_seconds": (datetime.utcnow() - tts_session.created_at).total_seconds(),
                "last_activity": tts_session.last_activity.isoformat()
            }

        return {
            "active_sessions": len(self.tts_sessions),
            "total_active_jobs": sum(len(s.active_jobs) for s in self.tts_sessions.values()),
            "sessions": session_stats
        }
|
websocket_manager.py
ADDED
@@ -0,0 +1,408 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
WebSocket Manager for Flare
|
3 |
+
===========================
|
4 |
+
Manages WebSocket connections and message routing
|
5 |
+
"""
|
6 |
+
import asyncio
|
7 |
+
from typing import Dict, Optional, Set
|
8 |
+
from fastapi import WebSocket, WebSocketDisconnect
|
9 |
+
import json
|
10 |
+
from datetime import datetime
|
11 |
+
import traceback
|
12 |
+
|
13 |
+
from event_bus import EventBus, Event, EventType
|
14 |
+
from logger import log_info, log_error, log_debug, log_warning
|
15 |
+
|
16 |
+
|
17 |
+
class WebSocketConnection:
    """Wrapper for a single client WebSocket with session metadata.

    Tracks connect/last-activity timestamps (naive UTC) and an is_active flag
    that is cleared on any send/receive failure so callers can detect dead
    connections without touching the socket again.
    """

    def __init__(self, websocket: WebSocket, session_id: str):
        self.websocket = websocket
        self.session_id = session_id
        self.connected_at = datetime.utcnow()
        self.last_activity = datetime.utcnow()
        self.is_active = True

    async def send_json(self, data: dict):
        """Send a JSON payload to the client and bump last_activity.

        Sends nothing when the connection is already inactive. On failure the
        connection is marked inactive and the exception re-raised.
        """
        try:
            if self.is_active:
                await self.websocket.send_json(data)
                self.last_activity = datetime.utcnow()
        except Exception as e:
            log_error(
                f"❌ Failed to send message",
                session_id=self.session_id,
                error=str(e)
            )
            self.is_active = False
            raise

    async def receive_json(self) -> dict:
        """Receive one JSON message from the client and bump last_activity.

        Raises:
            WebSocketDisconnect: client closed the connection (re-raised,
                not logged — it's a normal event, not an error).
            Exception: any other receive failure, after logging and marking
                the connection inactive.
        """
        try:
            data = await self.websocket.receive_json()
            self.last_activity = datetime.utcnow()
            return data
        except WebSocketDisconnect:
            self.is_active = False
            raise
        except Exception as e:
            log_error(
                f"❌ Failed to receive message",
                session_id=self.session_id,
                error=str(e)
            )
            self.is_active = False
            raise

    async def close(self):
        """Close the connection, swallowing close-time socket errors.

        Fix: the original used a bare `except:`, which also swallows
        KeyboardInterrupt/SystemExit; narrowed to `except Exception:`.
        """
        try:
            self.is_active = False
            await self.websocket.close()
        except Exception:
            # Best-effort close — the socket may already be gone.
            pass
|
class WebSocketManager:
    """Manages WebSocket connections and routing.

    Keeps one :class:`WebSocketConnection` and one outbound
    :class:`asyncio.Queue` per session.  Event-bus events are translated into
    client messages, and client messages are translated back into event-bus
    events.
    """

    def __init__(self, event_bus: EventBus):
        self.event_bus = event_bus
        # session_id -> active connection wrapper
        self.connections: Dict[str, WebSocketConnection] = {}
        # session_id -> outbound message queue drained by _send_messages()
        self.message_queues: Dict[str, asyncio.Queue] = {}
        self._setup_event_handlers()

    def _setup_event_handlers(self):
        """Subscribe to events that need to be sent to clients."""
        # State events
        self.event_bus.subscribe(EventType.STATE_TRANSITION, self._handle_state_transition)

        # STT events
        self.event_bus.subscribe(EventType.STT_READY, self._handle_stt_ready)
        self.event_bus.subscribe(EventType.STT_RESULT, self._handle_stt_result)

        # TTS events
        self.event_bus.subscribe(EventType.TTS_CHUNK_READY, self._handle_tts_chunk)
        self.event_bus.subscribe(EventType.TTS_COMPLETED, self._handle_tts_completed)

        # LLM events
        self.event_bus.subscribe(EventType.LLM_RESPONSE_READY, self._handle_llm_response)

        # Error events
        self.event_bus.subscribe(EventType.RECOVERABLE_ERROR, self._handle_error)
        self.event_bus.subscribe(EventType.CRITICAL_ERROR, self._handle_error)

    async def connect(self, websocket: WebSocket, session_id: str):
        """Accept a new WebSocket connection for *session_id*.

        A reconnect for an already-connected session supersedes the old
        socket: the stale connection is closed before the new one is
        registered.
        """
        await websocket.accept()

        # Check for existing connection
        if session_id in self.connections:
            log_warning(
                f"⚠️ Existing connection for session, closing old one",
                session_id=session_id
            )
            await self.disconnect(session_id)

        # Create connection wrapper and its outbound queue
        connection = WebSocketConnection(websocket, session_id)
        self.connections[session_id] = connection
        self.message_queues[session_id] = asyncio.Queue()

        log_info(
            f"✅ WebSocket connected",
            session_id=session_id,
            total_connections=len(self.connections)
        )

        # Publish connection event
        await self.event_bus.publish(Event(
            type=EventType.WEBSOCKET_CONNECTED,
            session_id=session_id,
            data={}
        ))

    async def disconnect(self, session_id: str):
        """Tear down a session's connection and queue, then publish the event.

        Uses ``dict.pop(key, None)`` instead of ``get()`` + ``del``: because
        ``connection.close()`` awaits (yielding to the event loop), two
        concurrent ``disconnect()`` calls for the same session — e.g.
        ``handle_connection``'s ``finally`` racing ``connect()``'s
        stale-connection cleanup — could both pass the ``get()`` check and
        the second ``del`` would raise ``KeyError``.  ``pop`` makes removal
        atomic with respect to the event loop.
        """
        connection = self.connections.pop(session_id, None)
        if connection:
            await connection.close()

        # Remove message queue (no-op if already gone)
        self.message_queues.pop(session_id, None)

        log_info(
            f"🔌 WebSocket disconnected",
            session_id=session_id,
            total_connections=len(self.connections)
        )

        # Publish disconnection event
        await self.event_bus.publish(Event(
            type=EventType.WEBSOCKET_DISCONNECTED,
            session_id=session_id,
            data={}
        ))

    async def handle_connection(self, websocket: WebSocket, session_id: str):
        """Handle the full WebSocket connection lifecycle.

        Runs the receive and send loops as concurrent tasks; whichever
        finishes first (disconnect or error) cancels the other.  Always
        disconnects the session on exit.
        """
        try:
            # Connect
            await self.connect(websocket, session_id)

            # Create tasks for bidirectional communication
            receive_task = asyncio.create_task(self._receive_messages(session_id))
            send_task = asyncio.create_task(self._send_messages(session_id))

            # Wait for either task to complete
            done, pending = await asyncio.wait(
                [receive_task, send_task],
                return_when=asyncio.FIRST_COMPLETED
            )

            # Cancel pending tasks and wait for them to acknowledge
            for task in pending:
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass

        except WebSocketDisconnect:
            log_info(f"WebSocket disconnected normally", session_id=session_id)
        except Exception as e:
            log_error(
                f"❌ WebSocket error",
                session_id=session_id,
                error=str(e),
                traceback=traceback.format_exc()
            )

            # Publish error event
            await self.event_bus.publish(Event(
                type=EventType.WEBSOCKET_ERROR,
                session_id=session_id,
                data={
                    "error_type": "websocket_error",
                    "message": str(e)
                }
            ))
        finally:
            # Ensure disconnection
            await self.disconnect(session_id)

    async def _receive_messages(self, session_id: str):
        """Receive messages from the client and route them until disconnect."""
        connection = self.connections.get(session_id)
        if not connection:
            return

        try:
            while connection.is_active:
                # Receive message
                message = await connection.receive_json()

                log_debug(
                    f"📨 Received message",
                    session_id=session_id,
                    message_type=message.get("type")
                )

                # Route message based on type
                await self._route_client_message(session_id, message)

        except WebSocketDisconnect:
            log_info(f"Client disconnected", session_id=session_id)
        except Exception as e:
            log_error(
                f"❌ Error receiving messages",
                session_id=session_id,
                error=str(e)
            )
            raise

    async def _send_messages(self, session_id: str):
        """Drain the session's outbound queue, pinging on 30s of inactivity."""
        connection = self.connections.get(session_id)
        queue = self.message_queues.get(session_id)

        if not connection or not queue:
            return

        try:
            while connection.is_active:
                # Wait for message with timeout
                try:
                    message = await asyncio.wait_for(queue.get(), timeout=30.0)

                    # Send to client
                    await connection.send_json(message)

                    log_debug(
                        f"📤 Sent message",
                        session_id=session_id,
                        message_type=message.get("type")
                    )

                except asyncio.TimeoutError:
                    # Send ping to keep connection alive
                    await connection.send_json({"type": "ping"})

        except Exception as e:
            log_error(
                f"❌ Error sending messages",
                session_id=session_id,
                error=str(e)
            )
            raise

    async def _route_client_message(self, session_id: str, message: dict):
        """Route a message from the client to the appropriate event-bus event."""
        message_type = message.get("type")

        if message_type == "audio_chunk":
            # Audio data from client
            await self.event_bus.publish(Event(
                type=EventType.AUDIO_CHUNK_RECEIVED,
                session_id=session_id,
                data={
                    "audio_data": message.get("data"),
                    "timestamp": message.get("timestamp")
                }
            ))

        elif message_type == "control":
            # Control messages
            action = message.get("action")

            if action == "start_session":
                await self.event_bus.publish(Event(
                    type=EventType.SESSION_STARTED,
                    session_id=session_id,
                    data=message.get("config", {})
                ))

            elif action == "end_session":
                await self.event_bus.publish(Event(
                    type=EventType.SESSION_ENDED,
                    session_id=session_id,
                    data={"reason": "user_request"}
                ))

            elif action == "audio_ended":
                await self.event_bus.publish(Event(
                    type=EventType.AUDIO_PLAYBACK_COMPLETED,
                    session_id=session_id,
                    data={}
                ))

        elif message_type == "ping":
            # Respond to ping
            await self.send_message(session_id, {"type": "pong"})

        else:
            log_warning(
                f"⚠️ Unknown message type",
                session_id=session_id,
                message_type=message_type
            )

    async def send_message(self, session_id: str, message: dict):
        """Queue *message* for asynchronous delivery to the client."""
        queue = self.message_queues.get(session_id)
        if queue:
            await queue.put(message)
        else:
            log_warning(
                f"⚠️ No queue for session",
                session_id=session_id
            )

    async def broadcast_to_session(self, session_id: str, message: dict):
        """Send *message* immediately, bypassing the outbound queue."""
        connection = self.connections.get(session_id)
        if connection and connection.is_active:
            await connection.send_json(message)

    # Event handlers for sending messages to clients

    async def _handle_state_transition(self, event: Event):
        """Send state transition to client."""
        await self.send_message(event.session_id, {
            "type": "state_change",
            "from": event.data.get("from_state"),
            "to": event.data.get("to_state")
        })

    async def _handle_stt_ready(self, event: Event):
        """Send STT ready signal to client."""
        await self.send_message(event.session_id, {
            "type": "stt_ready",
            "message": "STT is ready to receive audio"
        })

    async def _handle_stt_result(self, event: Event):
        """Send STT transcription result to client."""
        await self.send_message(event.session_id, {
            "type": "transcription",
            "text": event.data.get("text", ""),
            "is_final": event.data.get("is_final", False),
            "confidence": event.data.get("confidence", 0.0)
        })

    async def _handle_tts_chunk(self, event: Event):
        """Send one TTS audio chunk to client."""
        await self.send_message(event.session_id, {
            "type": "tts_audio",
            "data": event.data.get("audio_data"),
            "chunk_index": event.data.get("chunk_index"),
            "total_chunks": event.data.get("total_chunks"),
            "is_last": event.data.get("is_last", False),
            "mime_type": event.data.get("mime_type", "audio/mpeg")
        })

    async def _handle_tts_completed(self, event: Event):
        """Notify client that TTS is complete."""
        # Client knows from is_last flag in chunks
        pass

    async def _handle_llm_response(self, event: Event):
        """Send LLM response text to client."""
        await self.send_message(event.session_id, {
            "type": "assistant_response",
            "text": event.data.get("text", ""),
            "is_welcome": event.data.get("is_welcome", False)
        })

    async def _handle_error(self, event: Event):
        """Send a recoverable or critical error to the client."""
        error_type = event.data.get("error_type", "unknown")
        message = event.data.get("message", "An error occurred")

        await self.send_message(event.session_id, {
            "type": "error",
            "error_type": error_type,
            "message": message,
            "details": event.data.get("details", {})
        })

    def get_connection_count(self) -> int:
        """Return the number of active connections."""
        return len(self.connections)

    def get_session_connections(self) -> Set[str]:
        """Return all active session IDs."""
        return set(self.connections.keys())

    async def close_all_connections(self):
        """Close all active connections (e.g. on shutdown)."""
        # Snapshot the keys: disconnect() mutates self.connections.
        session_ids = list(self.connections.keys())
        for session_id in session_ids:
            await self.disconnect(session_id)
|