flare / stt /stt_lifecycle_manager.py
ciyidogan's picture
Update stt/stt_lifecycle_manager.py
913182c verified
raw
history blame
8.77 kB
"""
STT Lifecycle Manager for Flare - Batch Mode
===============================
Manages STT instances and audio collection
"""
import asyncio
from typing import Dict, Optional, Any
from datetime import datetime
import traceback
import base64
from chat_session.event_bus import EventBus, Event, EventType, publish_error
from chat_session.resource_manager import ResourceManager, ResourceType
from stt.stt_factory import STTFactory
from stt.stt_interface import STTInterface, STTConfig, TranscriptionResult
from stt.voice_activity_detector import VoiceActivityDetector
from utils.logger import log_info, log_error, log_debug, log_warning
class STTSession:
    """Per-session STT state: the STT instance plus the audio collected for it.

    Attributes set by the constructor:
        session_id: Identifier of the owning chat session.
        stt_instance: STT instance used for this session.
        is_active: Starts as ``False``.
        config: Starts as ``None``.
        created_at: Naive UTC timestamp taken at construction.
        audio_buffer: Collected audio chunks, initially empty.
        vad: A ``VoiceActivityDetector`` owned by this session.
        total_chunks / total_bytes: Counters, initially zero.
    """

    def __init__(self, session_id: str, stt_instance: STTInterface):
        # Identity
        self.session_id = session_id
        self.stt_instance = stt_instance
        self.created_at = datetime.utcnow()

        # Lifecycle state
        self.is_active = False
        self.config: Optional[STTConfig] = None

        # Audio collection
        self.audio_buffer = list()
        self.vad = VoiceActivityDetector()

        # Stats
        self.total_chunks = self.total_bytes = 0

    def reset(self):
        """Prepare the session for a new utterance.

        Replaces the audio buffer with a fresh empty list, resets the voice
        activity detector and zeroes the counters. ``is_active`` and
        ``config`` are left untouched.
        """
        self.audio_buffer = list()
        self.vad.reset()
        self.total_chunks = self.total_bytes = 0
class STTLifecycleManager:
"""Manages STT instances lifecycle"""
def __init__(self, event_bus: EventBus, resource_manager: ResourceManager):
self.event_bus = event_bus
self.resource_manager = resource_manager
self.stt_sessions: Dict[str, STTSession] = {}
self._setup_event_handlers()
self._setup_resource_pool()
def _setup_event_handlers(self):
    """Register this manager's handlers for the STT-related event types.

    Subscriptions are made on ``self.event_bus`` in this order: STT start,
    STT stop, incoming audio chunk, session end.
    """
    subscriptions = (
        (EventType.STT_STARTED, self._handle_stt_start),
        (EventType.STT_STOPPED, self._handle_stt_stop),
        (EventType.AUDIO_CHUNK_RECEIVED, self._handle_audio_chunk),
        (EventType.SESSION_ENDED, self._handle_session_ended),
    )
    for event_type, handler in subscriptions:
        self.event_bus.subscribe(event_type, handler)
async def _handle_stt_start(self, event: Event):
    """Handle an STT start request for the event's session.

    An existing ``STTSession`` for the session is reset and reused;
    otherwise an STT instance is acquired from the resource manager and a
    new ``STTSession`` is stored. An ``STTConfig`` is then built from the
    event data (with defaults for every missing key), the session is marked
    active and ``STT_READY`` is published with the resolved language.

    On any exception the failure is logged with its traceback, the session
    (if one is registered) is cleaned up, and an ``stt_error`` is published.

    Args:
        event: Start event; ``event.session_id`` selects the session and
            ``event.data`` supplies the configuration values.
    """
    session_id = event.session_id
    options = event.data

    try:
        log_info(f"🎀 Starting STT", session_id=session_id)

        if session_id in self.stt_sessions:
            # Known session: reuse it, starting from a clean utterance state.
            stt_session = self.stt_sessions[session_id]
            stt_session.reset()
        else:
            # First start for this session: acquire an STT instance.
            resource_key = f"stt_{session_id}"
            acquired_instance = await self.resource_manager.acquire(
                resource_id=resource_key,
                session_id=session_id,
                resource_type=ResourceType.STT_INSTANCE,
                cleanup_callback=self._cleanup_stt_instance,
            )
            stt_session = STTSession(session_id, acquired_instance)
            self.stt_sessions[session_id] = stt_session

        # Resolve every setting from the event data, falling back to defaults.
        settings = {
            "language": self._get_language_code(options.get("locale", "tr")),
            "sample_rate": options.get("sample_rate", 16000),
            "encoding": options.get("encoding", "LINEAR16"),
            "enable_punctuation": options.get("enable_punctuation", True),
            "model": options.get("model", "latest_long"),
            "use_enhanced": options.get("use_enhanced", True),
        }
        stt_config = STTConfig(**settings)

        stt_session.config = stt_config
        stt_session.is_active = True

        log_info(f"βœ… STT started in batch mode", session_id=session_id, language=stt_config.language)

        # Tell listeners that STT is ready for audio.
        ready_event = Event(
            type=EventType.STT_READY,
            session_id=session_id,
            data={"language": stt_config.language},
        )
        await self.event_bus.publish(ready_event)

    except Exception as exc:
        log_error(
            f"❌ Failed to start STT",
            session_id=session_id,
            error=str(exc),
            traceback=traceback.format_exc(),
        )

        # Drop whatever was registered for this session before reporting.
        if session_id in self.stt_sessions:
            await self._cleanup_session(session_id)

        await publish_error(
            session_id=session_id,
            error_type="stt_error",
            error_message=f"Failed to start STT: {str(exc)}",
        )
async def _handle_audio_chunk(self, event: Event):
"""Process audio chunk through VAD and collect"""
session_id = event.session_id
stt_session = self.stt_sessions.get(session_id)
if not stt_session or not stt_session.is_active:
return
try:
# Decode audio data
audio_data = base64.b64decode(event.data.get("audio_data", ""))
# Add to buffer
stt_session.audio_buffer.append(audio_data)
stt_session.total_chunks += 1
stt_session.total_bytes += len(audio_data)
# Process through VAD
is_speech, silence_duration_ms = stt_session.vad.process_chunk(audio_data)
# Check if utterance ended (silence threshold reached)
if not is_speech and silence_duration_ms >= 2000: # 2 seconds of silence
log_info(f"πŸ’¬ Utterance ended after {silence_duration_ms}ms silence", session_id=session_id)
# Stop STT to trigger transcription
await self.event_bus.publish(Event(
type=EventType.STT_STOPPED,
session_id=session_id,
data={"reason": "silence_detected"}
))
# Log progress periodically
if stt_session.total_chunks % 100 == 0:
log_debug(
f"πŸ“Š STT progress",
session_id=session_id,
chunks=stt_session.total_chunks,
bytes=stt_session.total_bytes,
vad_stats=stt_session.vad.get_stats()
)
except Exception as e:
log_error(
f"❌ Error processing audio chunk",
session_id=session_id,
error=str(e)
)
async def _handle_stt_stop(self, event: Event):
"""Handle STT stop request and perform transcription"""
session_id = event.session_id
reason = event.data.get("reason", "unknown")
log_info(f"πŸ›‘ Stopping STT", session_id=session_id, reason=reason)
stt_session = self.stt_sessions.get(session_id)
if not stt_session:
log_warning(f"⚠️ No STT session found", session_id=session_id)
return
try:
if stt_session.is_active and stt_session.audio_buffer:
# Combine audio chunks
combined_audio = b''.join(stt_session.audio_buffer)
# Transcribe using batch mode
log_info(f"πŸ“ Transcribing {len(combined_audio)} bytes of audio", session_id=session_id)
result = await stt_session.stt_instance.transcribe(
audio_data=combined_audio,
config=stt_session.config
)
# Publish result if we got transcription
if result and result.text:
await self.event_bus.publish(Event(
type=EventType.STT_RESULT,
session_id=session_id,
data={
"text": result.text,
"is_final": True,
"confidence": result.confidence
}
))
else:
log_warning(f"⚠️ No transcription result", session_id=session_id)
# Mark as inactive and reset
stt_session.is_active = False
stt_session.reset()
# Send STT_STOPPED event
await self.event_bus.publish(Event(
type=EventType.STT_STOPPED,
session_id=session_id,
data={"reason": reason}
))
log_info(f"βœ… STT stopped", session_id=session_id)
except Exception as e:
log_error(
f"❌ Error stopping STT",
session_id=session_id,
error=str(e)
)