Spaces:
Paused
Paused
| """ | |
| Audio Buffer Manager for Flare | |
| ============================== | |
| Manages audio buffering, silence detection, and chunk processing | |
| """ | |
| import asyncio | |
| from typing import Dict, Optional, List, Tuple, Any | |
| from collections import deque | |
| from datetime import datetime | |
| import base64 | |
| import numpy as np | |
| from dataclasses import dataclass | |
| import traceback | |
| from event_bus import EventBus, Event, EventType | |
| from utils.logger import log_info, log_error, log_debug, log_warning | |
@dataclass
class AudioChunk:
    """One piece of buffered audio plus the metadata recorded alongside it.

    Declared as a dataclass so instances can be created with keyword
    arguments, e.g. ``AudioChunk(data=..., timestamp=..., chunk_index=...)``;
    without the decorator the annotated fields would not produce an
    ``__init__`` and such a call would raise ``TypeError``.
    """

    # Raw audio payload carried by this chunk.
    data: bytes
    # Time associated with the chunk.
    timestamp: datetime
    # Sequence number of the chunk.
    chunk_index: int
    # Speech flag; defaults to True.
    is_speech: bool = True
    # Energy measurement for the chunk; defaults to 0.0.
    energy_level: float = 0.0
class SilenceDetector:
    """Classify audio chunks as silence and track how long silence has lasted.

    A chunk is read as 16-bit integer samples; it counts as silence when its
    RMS energy, normalized by 32768, is below ``energy_threshold``. The
    duration of the current silent stretch is measured with wall-clock time
    (``datetime.utcnow()``) between calls.

    ``threshold_ms`` and ``sample_rate`` are stored on the instance but are
    not consulted by any method of this class.
    """

    def __init__(self,
                 threshold_ms: int = 2000,
                 energy_threshold: float = 0.01,
                 sample_rate: int = 16000):
        self.threshold_ms = threshold_ms
        self.energy_threshold = energy_threshold
        self.sample_rate = sample_rate
        # Start of the current silent stretch; None while not in silence.
        self.silence_start: Optional[datetime] = None

    def detect_silence(self, audio_chunk: bytes) -> Tuple[bool, int]:
        """Decide whether a chunk is silence and report the silence duration.

        Args:
            audio_chunk: Raw bytes interpreted as 16-bit integer samples. A
                trailing odd byte is dropped before decoding.

        Returns:
            ``(is_silence, silence_duration_ms)``. ``is_silence`` is always a
            builtin ``bool``. The duration is the time elapsed since the
            first silent chunk of the current stretch, or 0 for non-silent
            chunks. Empty or too-short input yields ``(True, 0)`` without
            touching the tracked state. Any error during analysis is logged
            and reported as ``(False, 0)``.
        """
        try:
            # Nothing to analyse: fewer than one full 16-bit sample.
            if not audio_chunk or len(audio_chunk) < 2:
                return True, 0

            # 16-bit samples need an even byte count; drop a dangling byte.
            if len(audio_chunk) % 2 != 0:
                audio_chunk = audio_chunk[:-1]

            samples = np.frombuffer(audio_chunk, dtype=np.int16)
            if samples.size == 0:
                return True, 0

            # RMS energy scaled by the 16-bit magnitude (32768).
            rms = np.sqrt(np.mean(samples.astype(float) ** 2))
            normalized_rms = rms / 32768.0

            # The comparison yields numpy.bool_; convert so callers get the
            # builtin bool promised by the return annotation.
            is_silence = bool(normalized_rms < self.energy_threshold)

            now = datetime.utcnow()
            if not is_silence:
                # Sound ends any silent stretch in progress.
                self.silence_start = None
                return False, 0

            if self.silence_start is None:
                self.silence_start = now
            duration_ms = int((now - self.silence_start).total_seconds() * 1000)
            return True, duration_ms
        except Exception as e:
            # Best effort: a chunk that cannot be analysed is treated as
            # non-silent rather than interrupting the caller.
            log_warning(f"Silence detection error: {e}")
            return False, 0

    def reset(self):
        """Forget any silent stretch currently being tracked."""
        self.silence_start = None
class AudioBuffer:
    """Bounded, lock-guarded store of the audio chunks of one session.

    Chunks live in a ``deque`` capped at ``max_chunks`` entries, so the
    oldest chunk is discarded once the cap is exceeded. ``chunk_counter``
    and ``total_bytes`` are only increased by ``add_chunk`` and only reset
    by ``clear``.
    """

    def __init__(self,
                 session_id: str,
                 max_chunks: int = 1000,
                 chunk_size_bytes: int = 4096):
        self.session_id = session_id
        self.max_chunks = max_chunks
        self.chunk_size_bytes = chunk_size_bytes
        # Running totals maintained by add_chunk().
        self.chunk_counter = 0
        self.total_bytes = 0
        self.chunks: deque[AudioChunk] = deque(maxlen=max_chunks)
        # Serializes the async mutators and readers below.
        self.lock = asyncio.Lock()

    async def add_chunk(self, audio_data: bytes, timestamp: Optional[datetime] = None) -> AudioChunk:
        """Append a chunk and return the stored ``AudioChunk``.

        Args:
            audio_data: Audio payload to store.
            timestamp: Time to record for the chunk; the current UTC time is
                used when omitted.
        """
        async with self.lock:
            stamped_at = datetime.utcnow() if timestamp is None else timestamp
            new_chunk = AudioChunk(
                data=audio_data,
                timestamp=stamped_at,
                chunk_index=self.chunk_counter,
            )
            self.chunks.append(new_chunk)
            self.chunk_counter += 1
            self.total_bytes += len(audio_data)
            return new_chunk

    async def get_recent_audio(self, duration_ms: int = 5000) -> bytes:
        """Return the concatenated payloads of the most recent chunks.

        Walks from the newest chunk towards the oldest and stops at the
        first chunk older than ``duration_ms`` milliseconds; the result is
        in chronological order.
        """
        async with self.lock:
            reference_time = datetime.utcnow()
            # Prepending while walking newest-to-oldest keeps the payloads
            # in chronological order.
            collected = deque()
            for item in reversed(self.chunks):
                age_ms = (reference_time - item.timestamp).total_seconds() * 1000
                if age_ms > duration_ms:
                    break
                collected.appendleft(item.data)
            return b''.join(collected)

    async def clear(self):
        """Drop every chunk and zero the counters."""
        async with self.lock:
            self.chunks.clear()
            self.chunk_counter = 0
            self.total_bytes = 0

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of the buffer's counters and timestamp range.

        The oldest/newest entries are ``None`` while the buffer is empty.
        """
        oldest = None
        newest = None
        if self.chunks:
            oldest = self.chunks[0].timestamp
            newest = self.chunks[-1].timestamp
        return {
            "chunks": len(self.chunks),
            "total_bytes": self.total_bytes,
            "chunk_counter": self.chunk_counter,
            "oldest_chunk": oldest,
            "newest_chunk": newest,
        }
class AudioBufferManager:
    """Own one ``AudioBuffer`` per session and keep it in sync with bus events.

    On construction the manager subscribes to session-start, session-end and
    audio-chunk events on the given event bus.
    """

    def __init__(self, event_bus: EventBus):
        self.event_bus = event_bus
        self.session_buffers: Dict[str, AudioBuffer] = {}
        # NOTE(review): nothing in this class adds entries here; they are
        # only reset and removed. Presumably populated elsewhere - confirm.
        self.silence_detectors: Dict[str, SilenceDetector] = {}
        self._setup_event_handlers()

    def _setup_event_handlers(self):
        """Subscribe the handlers below to their event types."""
        self.event_bus.subscribe(EventType.SESSION_STARTED, self._handle_session_started)
        self.event_bus.subscribe(EventType.SESSION_ENDED, self._handle_session_ended)
        self.event_bus.subscribe(EventType.AUDIO_CHUNK_RECEIVED, self._handle_audio_chunk)

    async def _handle_session_started(self, event: Event):
        """Create the buffer for a newly started session.

        Reads the optional ``max_chunks`` and ``chunk_size`` keys from the
        event payload; a missing payload falls back to the defaults.
        """
        session_id = event.session_id
        # Tolerate an event without a payload instead of failing on `.get`.
        config = event.data or {}

        self.session_buffers[session_id] = AudioBuffer(
            session_id=session_id,
            max_chunks=config.get("max_chunks", 1000),
            chunk_size_bytes=config.get("chunk_size", 4096),
        )

        log_info("π¦ Audio buffer initialized", session_id=session_id)

    async def _handle_session_ended(self, event: Event):
        """Clear and drop everything held for a finished session."""
        session_id = event.session_id

        buffer = self.session_buffers.get(session_id)
        if buffer is not None:
            await buffer.clear()
            # pop() rather than del: the entry may already be gone if another
            # end event for the same session ran while clear() was awaited.
            self.session_buffers.pop(session_id, None)

        self.silence_detectors.pop(session_id, None)

        log_info("π¦ Audio buffer cleaned up", session_id=session_id)

    async def _handle_audio_chunk(self, event: Event):
        """Decode an incoming base64 audio payload and append it to the buffer.

        Events for sessions without a buffer are logged and ignored. Any
        error while decoding or storing is logged and not propagated.
        """
        session_id = event.session_id

        buffer = self.session_buffers.get(session_id)
        if buffer is None:
            log_warning("β οΈ No buffer for session", session_id=session_id)
            return

        try:
            audio_data = base64.b64decode(event.data.get("audio_data", ""))
            chunk = await buffer.add_chunk(audio_data)

            # Emit buffer statistics once every 100 chunks.
            if chunk.chunk_index % 100 == 0:
                log_debug(
                    "π Buffer stats",
                    session_id=session_id,
                    **buffer.get_stats()
                )
        except Exception as e:
            log_error(
                "β Error processing audio chunk",
                session_id=session_id,
                error=str(e),
                traceback=traceback.format_exc()
            )

    async def get_buffer(self, session_id: str) -> Optional[AudioBuffer]:
        """Return the session's buffer, or ``None`` if there is none."""
        return self.session_buffers.get(session_id)

    async def reset_buffer(self, session_id: str):
        """Empty the session's buffer and reset its silence detector, if any."""
        buffer = self.session_buffers.get(session_id)
        detector = self.silence_detectors.get(session_id)

        if buffer is not None:
            await buffer.clear()
        if detector is not None:
            detector.reset()

        log_debug("π Audio buffer reset", session_id=session_id)

    def get_all_stats(self) -> Dict[str, Dict[str, Any]]:
        """Return ``get_stats()`` of every buffer, keyed by session id."""
        return {
            session_id: buffer.get_stats()
            for session_id, buffer in self.session_buffers.items()
        }