Spaces:
Paused
Paused
| """ | |
| TTS Lifecycle Manager for Flare | |
| =============================== | |
| Manages TTS instances lifecycle per session | |
| """ | |
| import asyncio | |
| from typing import Dict, Optional, Any, List | |
| from datetime import datetime | |
| import traceback | |
| import base64 | |
| from event_bus import EventBus, Event, EventType, publish_error | |
| from resource_manager import ResourceManager, ResourceType | |
| from tts.tts_factory import TTSFactory | |
| from tts.tts_interface import TTSInterface | |
| from tts.tts_preprocessor import TTSPreprocessor | |
| from utils.logger import log_info, log_error, log_debug, log_warning | |
class TTSJob:
    """One text-to-speech request together with its eventual outcome.

    A job starts out pending (``completed_at`` is ``None``) and is finished
    exactly by calling :meth:`complete` (success, audio attached) or
    :meth:`fail` (error message attached). Timestamps are naive UTC values
    produced by ``datetime.utcnow()``.
    """

    def __init__(self, job_id: str, session_id: str, text: str, is_welcome: bool = False):
        """Record the request and initialise all outcome fields as "not yet".

        Args:
            job_id: Identifier of this job.
            session_id: Identifier of the session the job belongs to.
            text: Text to synthesize.
            is_welcome: Whether this job is flagged as a welcome message.
        """
        # What was asked for.
        self.job_id, self.session_id = job_id, session_id
        self.text, self.is_welcome = text, is_welcome

        # When it was asked for / when it finished (None while pending).
        self.created_at = datetime.utcnow()
        self.completed_at: Optional[datetime] = None

        # Outcome: at most one of audio_data / error is filled in by
        # complete() / fail().
        self.audio_data: Optional[bytes] = None
        self.error: Optional[str] = None

        # Counter of audio chunks already published for this job.
        self.chunks_sent = 0

    def _stamp_completion(self) -> None:
        """Set ``completed_at`` to the current naive UTC time."""
        self.completed_at = datetime.utcnow()

    def complete(self, audio_data: bytes):
        """Attach the synthesized audio and mark the job as finished."""
        self.audio_data = audio_data
        self._stamp_completion()

    def fail(self, error: str):
        """Attach an error message and mark the job as finished."""
        self.error = error
        self._stamp_completion()
class TTSSession:
    """Per-session bookkeeping around a single TTS instance.

    Holds the TTS instance used by the session, an optional text
    preprocessor, the jobs currently in flight, a list of finished jobs,
    and running totals. Timestamps are naive UTC values produced by
    ``datetime.utcnow()``.
    """

    def __init__(self, session_id: str, tts_instance: TTSInterface):
        """Create an empty session bound to ``tts_instance``.

        Args:
            session_id: Identifier of the owning session.
            tts_instance: The TTS instance this session wraps.
        """
        # Identity and the wrapped TTS instance.
        self.session_id = session_id
        self.tts_instance = tts_instance

        # Starts unset; the attribute is assigned from outside this class.
        self.preprocessor: Optional[TTSPreprocessor] = None

        # In-flight jobs keyed by job id, plus a list of finished jobs.
        self.active_jobs: Dict[str, TTSJob] = dict()
        self.completed_jobs: List[TTSJob] = list()

        # Two separate clock reads, so last_activity >= created_at.
        self.created_at = datetime.utcnow()
        self.last_activity = datetime.utcnow()

        # Running totals for the session.
        self.total_jobs = self.total_chars = 0

    def update_activity(self):
        """Refresh ``last_activity`` to the current naive UTC time."""
        self.last_activity = datetime.utcnow()
class TTSLifecycleManager:
    """Create, use, and release per-session TTS instances in response to bus events.

    On construction the manager subscribes to ``EventType.TTS_STARTED`` and
    ``EventType.SESSION_ENDED`` and registers a ``ResourceType.TTS_INSTANCE``
    pool with the resource manager. Each TTS request becomes a ``TTSJob``
    that is preprocessed, synthesized, and streamed back as base64 chunks
    (``TTS_CHUNK_READY`` events) followed by a single ``TTS_COMPLETED`` event.
    """

    # Number of finished jobs kept per session for inspection.
    _MAX_COMPLETED_JOBS = 10

    def __init__(self, event_bus: EventBus, resource_manager: ResourceManager):
        """Wire the manager to the event bus and the resource manager.

        Args:
            event_bus: Bus used both to receive requests and to publish
                audio chunks / completion events.
            resource_manager: Pool manager that TTS instances are acquired
                from and released to.
        """
        self.event_bus = event_bus
        self.resource_manager = resource_manager
        self.tts_sessions: Dict[str, TTSSession] = {}
        # Measured in base64 characters, not raw bytes. 16384 is a multiple
        # of 4, so chunk boundaries fall on whole base64 quads.
        self.chunk_size = 16384  # 16KB chunks for base64
        self._setup_event_handlers()
        self._setup_resource_pool()

    def _setup_event_handlers(self):
        """Subscribe to TTS-related events."""
        self.event_bus.subscribe(EventType.TTS_STARTED, self._handle_tts_start)
        self.event_bus.subscribe(EventType.SESSION_ENDED, self._handle_session_ended)

    def _setup_resource_pool(self):
        """Register the TTS instance pool with the resource manager."""
        self.resource_manager.register_pool(
            resource_type=ResourceType.TTS_INSTANCE,
            factory=self._create_tts_instance,
            max_idle=3,
            max_age_seconds=600  # 10 minutes
        )

    async def _create_tts_instance(self) -> Optional[TTSInterface]:
        """Pool factory: build a TTS provider instance.

        Returns:
            The instance from ``TTSFactory.create_provider()``, or ``None``
            when the factory yields nothing or raises (the failure is logged,
            not propagated).
        """
        try:
            tts_instance = TTSFactory.create_provider()
            if not tts_instance:
                log_warning("β οΈ No TTS provider configured")
                return None
            log_debug("π Created new TTS instance")
            return tts_instance
        except Exception as e:
            log_error("β Failed to create TTS instance", error=str(e))
            return None

    async def _handle_tts_start(self, event: Event):
        """Handle a TTS synthesis request.

        Reads ``text``, ``is_welcome`` and (for new sessions) ``locale`` from
        ``event.data``. Empty text is logged and ignored. For a session not
        seen before, a TTS instance is acquired from the pool; if none is
        available a completion event without audio is published instead.
        Any exception is logged and reported through ``publish_error``.
        """
        session_id = event.session_id
        text = event.data.get("text", "")
        is_welcome = event.data.get("is_welcome", False)

        if not text:
            log_warning("β οΈ Empty text for TTS", session_id=session_id)
            return

        try:
            log_info(
                "π Starting TTS",
                session_id=session_id,
                text_length=len(text),
                is_welcome=is_welcome
            )

            # Get or create session.
            # NOTE(review): the membership check and the insert below are
            # separated by an await; two overlapping first requests for the
            # same session could both acquire — confirm the event bus
            # serialises handlers per session.
            if session_id not in self.tts_sessions:
                # Acquire TTS instance from pool
                resource_id = f"tts_{session_id}"
                tts_instance = await self.resource_manager.acquire(
                    resource_id=resource_id,
                    session_id=session_id,
                    resource_type=ResourceType.TTS_INSTANCE,
                    cleanup_callback=self._cleanup_tts_instance
                )

                if not tts_instance:
                    # No TTS available
                    await self._handle_no_tts(session_id, text, is_welcome)
                    return

                # Create session
                tts_session = TTSSession(session_id, tts_instance)

                # Get locale from event data or default
                locale = event.data.get("locale", "tr")
                tts_session.preprocessor = TTSPreprocessor(language=locale)

                self.tts_sessions[session_id] = tts_session
            else:
                tts_session = self.tts_sessions[session_id]

            # Create job; the id is the session id plus a per-session counter.
            job_id = f"{session_id}_{tts_session.total_jobs}"
            job = TTSJob(job_id, session_id, text, is_welcome)
            tts_session.active_jobs[job_id] = job
            tts_session.total_jobs += 1
            tts_session.total_chars += len(text)
            tts_session.update_activity()

            # Process TTS
            await self._process_tts_job(tts_session, job)

        except Exception as e:
            log_error(
                "β Failed to start TTS",
                session_id=session_id,
                error=str(e),
                traceback=traceback.format_exc()
            )

            # Publish error event
            await publish_error(
                session_id=session_id,
                error_type="tts_error",
                error_message=f"Failed to synthesize speech: {e}"
            )

    async def _process_tts_job(self, tts_session: TTSSession, job: TTSJob):
        """Preprocess, synthesize and stream one job.

        On success the job moves from ``active_jobs`` to ``completed_jobs``
        (only the most recent ``_MAX_COMPLETED_JOBS`` are kept). On any
        exception the job is marked failed, removed from ``active_jobs``,
        and an error event is published; messages containing "quota" or
        "limit" are reported as ``tts_quota_exceeded``, everything else as
        ``tts_error``. Exceptions are not re-raised.
        """
        try:
            # Preprocess text
            processed_text = tts_session.preprocessor.preprocess(
                job.text,
                tts_session.tts_instance.get_preprocessing_flags()
            )

            log_debug(
                "π TTS preprocessed",
                session_id=job.session_id,
                original_length=len(job.text),
                processed_length=len(processed_text)
            )

            # Synthesize audio
            audio_data = await tts_session.tts_instance.synthesize(processed_text)

            if not audio_data:
                raise ValueError("TTS returned empty audio data")

            job.complete(audio_data)

            log_info(
                "β TTS synthesis complete",
                session_id=job.session_id,
                audio_size=len(audio_data),
                duration_ms=(datetime.utcnow() - job.created_at).total_seconds() * 1000
            )

            # Stream audio chunks
            await self._stream_audio_chunks(tts_session, job)

            # Move to completed
            tts_session.active_jobs.pop(job.job_id, None)
            tts_session.completed_jobs.append(job)

            # Keep only the most recent completed jobs (slice is empty when
            # the list is still within the limit).
            del tts_session.completed_jobs[:-self._MAX_COMPLETED_JOBS]

        except Exception as e:
            job.fail(str(e))

            # A failed job is finished: drop it from the in-flight map right
            # away so it does not accumulate there or get counted as active
            # in get_stats(). Done before logging so it happens even if
            # reporting itself raises.
            tts_session.active_jobs.pop(job.job_id, None)

            # Handle specific TTS errors
            error_message = str(e)
            lowered = error_message.lower()
            if "quota" in lowered or "limit" in lowered:
                log_error("β TTS quota exceeded", session_id=job.session_id)
                await publish_error(
                    session_id=job.session_id,
                    error_type="tts_quota_exceeded",
                    error_message="TTS service quota exceeded"
                )
            else:
                log_error(
                    "β TTS synthesis failed",
                    session_id=job.session_id,
                    error=error_message
                )
                await publish_error(
                    session_id=job.session_id,
                    error_type="tts_error",
                    error_message=error_message
                )

    async def _stream_audio_chunks(self, tts_session: TTSSession, job: TTSJob):
        """Publish the job's audio as base64 chunks, then a completion event.

        Does nothing when the job has no audio. The audio is base64-encoded
        once and sliced into pieces of ``self.chunk_size`` characters; each
        piece is published as ``TTS_CHUNK_READY`` and ``job.chunks_sent`` is
        incremented. A ``TTS_COMPLETED`` event follows the last chunk.
        """
        if not job.audio_data:
            return

        # Convert to base64
        audio_base64 = base64.b64encode(job.audio_data).decode('utf-8')
        total_length = len(audio_base64)
        # Ceiling division: a trailing partial chunk still counts as a chunk.
        total_chunks = (total_length + self.chunk_size - 1) // self.chunk_size

        log_debug(
            "π€ Streaming TTS audio",
            session_id=job.session_id,
            total_size=len(job.audio_data),
            base64_size=total_length,
            chunks=total_chunks
        )

        # Stream chunks
        offsets = range(0, total_length, self.chunk_size)
        for chunk_index, offset in enumerate(offsets):
            chunk = audio_base64[offset:offset + self.chunk_size]
            is_last = chunk_index == total_chunks - 1

            await self.event_bus.publish(Event(
                type=EventType.TTS_CHUNK_READY,
                session_id=job.session_id,
                data={
                    "audio_data": chunk,
                    "chunk_index": chunk_index,
                    "total_chunks": total_chunks,
                    "is_last": is_last,
                    # NOTE(review): MIME type is hard-coded; confirm every
                    # configured provider actually returns MPEG audio.
                    "mime_type": "audio/mpeg",
                    "is_welcome": job.is_welcome
                },
                priority=8  # Higher priority for audio chunks
            ))

            job.chunks_sent += 1

            # Small delay between chunks to prevent overwhelming
            await asyncio.sleep(0.01)

        # Notify completion
        await self.event_bus.publish(Event(
            type=EventType.TTS_COMPLETED,
            session_id=job.session_id,
            data={
                "job_id": job.job_id,
                "total_chunks": total_chunks,
                "is_welcome": job.is_welcome
            }
        ))

        log_info(
            "β TTS streaming complete",
            session_id=job.session_id,
            chunks_sent=job.chunks_sent
        )

    async def _handle_no_tts(self, session_id: str, text: str, is_welcome: bool):
        """Publish a ``TTS_COMPLETED`` event flagged ``no_audio`` when no TTS is available."""
        log_warning("β οΈ No TTS available, skipping audio generation", session_id=session_id)

        # Just notify completion without audio
        await self.event_bus.publish(Event(
            type=EventType.TTS_COMPLETED,
            session_id=session_id,
            data={
                "no_audio": True,
                "text": text,
                "is_welcome": is_welcome
            }
        ))

    async def _handle_session_ended(self, event: Event):
        """Clean up TTS resources when a session ends."""
        await self._cleanup_session(event.session_id)

    async def _cleanup_session(self, session_id: str):
        """Forget a session, fail its unfinished jobs and release its TTS instance.

        A no-op for unknown sessions. Errors during cleanup are logged and
        swallowed.
        """
        tts_session = self.tts_sessions.pop(session_id, None)
        if not tts_session:
            return

        try:
            # Mark any still-pending jobs as failed. This only flags them;
            # it does not interrupt work already running for the job.
            for job in tts_session.active_jobs.values():
                if not job.completed_at:
                    job.fail("Session ended")

            # Release resource
            resource_id = f"tts_{session_id}"
            await self.resource_manager.release(resource_id, delay_seconds=120)

            log_info(
                "π§Ή TTS session cleaned up",
                session_id=session_id,
                total_jobs=tts_session.total_jobs,
                total_chars=tts_session.total_chars
            )

        except Exception as e:
            log_error(
                "β Error cleaning up TTS session",
                session_id=session_id,
                error=str(e)
            )

    async def _cleanup_tts_instance(self, tts_instance: TTSInterface):
        """Cleanup callback passed to the resource manager; currently only logs."""
        try:
            # TTS instances typically don't need special cleanup
            log_debug("π§Ή TTS instance cleaned up")
        except Exception as e:
            log_error("β Error cleaning up TTS instance", error=str(e))

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of manager statistics.

        Returns:
            A dict with ``active_sessions`` (session count),
            ``total_active_jobs`` (in-flight jobs across all sessions) and
            ``sessions`` (per-session counters, uptime in seconds and the
            ISO-formatted last-activity timestamp).
        """
        now = datetime.utcnow()
        session_stats = {
            session_id: {
                "active_jobs": len(tts_session.active_jobs),
                "completed_jobs": len(tts_session.completed_jobs),
                "total_jobs": tts_session.total_jobs,
                "total_chars": tts_session.total_chars,
                "uptime_seconds": (now - tts_session.created_at).total_seconds(),
                "last_activity": tts_session.last_activity.isoformat()
            }
            for session_id, tts_session in self.tts_sessions.items()
        }

        return {
            "active_sessions": len(self.tts_sessions),
            "total_active_jobs": sum(len(s.active_jobs) for s in self.tts_sessions.values()),
            "sessions": session_stats
        }