| """ | |
| Audio API endpoints for Flare (Refactored with Event-Driven Architecture) | |
| ======================================================================== | |
| Provides text-to-speech (TTS) and speech-to-text (STT) endpoints. | |
| """ | |
| from fastapi import APIRouter, HTTPException, Response, Body, Request, WebSocket | |
| from pydantic import BaseModel | |
| from typing import Optional | |
| from datetime import datetime | |
| import sys | |
| import base64 | |
| from utils.logger import log_info, log_error, log_warning, log_debug | |
| from tts.tts_factory import TTSFactory | |
| from tts.tts_preprocessor import TTSPreprocessor | |
| from config.config_provider import ConfigProvider | |
| router = APIRouter(tags=["audio"]) | |

# ===================== Models =====================
class TTSRequest(BaseModel):
    text: str
    voice_id: Optional[str] = None
    language: Optional[str] = "tr-TR"
    session_id: Optional[str] = None  # For event-driven mode

class STTRequest(BaseModel):
    audio_data: str  # Base64 encoded audio
    language: Optional[str] = "tr-TR"
    format: Optional[str] = "webm"  # webm, wav, mp3
    session_id: Optional[str] = None  # For event-driven mode
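
# Illustrative example payloads for the models above (values are invented for
# illustration, not taken from the codebase):
#
#     TTSRequest: {"text": "Merhaba dünya", "language": "tr-TR", "session_id": "abc123"}
#     STTRequest: {"audio_data": "<base64-encoded bytes>", "language": "tr-TR", "format": "webm"}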

# ===================== TTS Endpoints =====================
async def generate_tts(request: TTSRequest, req: Request):
    """
    Generate TTS audio from text.
    - If session_id is provided and the event bus is available, uses event-driven mode
    - Otherwise, uses direct TTS generation
    """
    try:
        # Check if we should use event-driven mode
        if request.session_id and hasattr(req.app.state, 'event_bus'):
            # Event-driven mode for realtime sessions
            from event_bus import Event, EventType

            log_info(f"🎤 TTS request via event bus for session: {request.session_id}")

            # Publish TTS event
            await req.app.state.event_bus.publish(Event(
                type=EventType.TTS_STARTED,
                session_id=request.session_id,
                data={
                    "text": request.text,
                    "voice_id": request.voice_id,
                    "language": request.language,
                    "is_api_call": True  # Flag to indicate this is from REST API
                }
            ))

            # Return a response indicating audio will be streamed via WebSocket
            return {
                "status": "processing",
                "message": "TTS audio will be streamed via WebSocket connection",
                "session_id": request.session_id
            }
        else:
            # Direct TTS generation (legacy mode)
            tts_provider = TTSFactory.create_provider()
            if not tts_provider:
                log_info("🔇 TTS disabled - returning empty response")
                return Response(
                    content=b"",
                    media_type="audio/mpeg",
                    headers={"X-TTS-Status": "disabled"}
                )

            log_info(f"🎤 Direct TTS request: '{request.text[:50]}...' with provider: {tts_provider.get_provider_name()}")

            # Preprocess text if needed
            preprocessor = TTSPreprocessor(language=request.language)
            processed_text = preprocessor.preprocess(
                request.text,
                tts_provider.get_preprocessing_flags()
            )
            log_debug(f"📝 Preprocessed text: {processed_text[:100]}...")

            # Generate audio
            audio_data = await tts_provider.synthesize(
                text=processed_text,
                voice_id=request.voice_id
            )
            log_info(f"✅ TTS generated {len(audio_data)} bytes of audio")

            # Return audio as binary response
            return Response(
                content=audio_data,
                media_type="audio/mpeg",
                headers={
                    "Content-Disposition": 'inline; filename="tts_output.mp3"',
                    "X-TTS-Provider": tts_provider.get_provider_name(),
                    "X-TTS-Language": request.language,
                    "Cache-Control": "no-cache"
                }
            )
    except Exception as e:
        log_error("❌ TTS generation error", e)
        raise HTTPException(
            status_code=500,
            detail=f"TTS generation failed: {str(e)}"
        )
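
# A minimal client sketch for the direct (legacy) TTS mode, assuming this
# handler is mounted at POST /api/tts/generate (a hypothetical path; routes
# are not declared in this file):
#
#     import httpx
#
#     resp = httpx.post(
#         "http://localhost:8000/api/tts/generate",
#         json={"text": "Merhaba dünya", "language": "tr-TR"},
#     )
#     resp.raise_for_status()
#     with open("tts_output.mp3", "wb") as f:
#         f.write(resp.content)  # raw MP3 bytes from the Response above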

async def get_tts_voices():
    """Get available TTS voices"""
    try:
        tts_provider = TTSFactory.create_provider()
        if not tts_provider:
            return {
                "voices": [],
                "provider": "none",
                "enabled": False
            }

        voices = tts_provider.get_supported_voices()

        # Convert dict to list format
        voice_list = [
            {"id": voice_id, "name": voice_name}
            for voice_id, voice_name in voices.items()
        ]

        return {
            "voices": voice_list,
            "provider": tts_provider.get_provider_name(),
            "enabled": True
        }
    except Exception as e:
        log_error("❌ Error getting TTS voices", e)
        return {
            "voices": [],
            "provider": "error",
            "enabled": False,
            "error": str(e)
        }

async def get_tts_status():
    """Get TTS service status"""
    cfg = ConfigProvider.get()
    return {
        "enabled": cfg.global_config.tts_provider.name != "no_tts",
        "provider": cfg.global_config.tts_provider.name,
        "provider_config": {
            "name": cfg.global_config.tts_provider.name,
            "has_api_key": bool(cfg.global_config.tts_provider.api_key),
            "endpoint": cfg.global_config.tts_provider.endpoint
        }
    }

# ===================== STT Endpoints =====================
async def transcribe_audio(request: STTRequest, req: Request):
    """
    Transcribe audio to text.
    - If session_id is provided and the event bus is available, uses event-driven mode
    - Otherwise, uses direct STT transcription
    """
    try:
        # Check if we should use event-driven mode
        if request.session_id and hasattr(req.app.state, 'event_bus'):
            # Event-driven mode for realtime sessions
            from event_bus import Event, EventType

            log_info(f"🎧 STT request via event bus for session: {request.session_id}")

            # Publish audio chunk event
            await req.app.state.event_bus.publish(Event(
                type=EventType.AUDIO_CHUNK_RECEIVED,
                session_id=request.session_id,
                data={
                    "audio_data": request.audio_data,  # Already base64
                    "format": request.format,
                    "language": request.language,
                    "is_api_call": True
                }
            ))

            # Return a response indicating transcription will be available via WebSocket
            return {
                "status": "processing",
                "message": "Transcription will be available via WebSocket connection",
                "session_id": request.session_id
            }
        else:
            # Direct STT transcription (legacy mode)
            from stt.stt_factory import STTFactory
            from stt.stt_interface import STTConfig

            # Create STT provider
            stt_provider = STTFactory.create_provider()
            if not stt_provider or not stt_provider.supports_realtime():
                log_warning("🔇 STT disabled or doesn't support streaming transcription")
                raise HTTPException(
                    status_code=503,
                    detail="STT service not available"
                )

            # Get config
            cfg = ConfigProvider.get()
            stt_config = cfg.global_config.stt_provider.settings

            # Decode audio data
            audio_bytes = base64.b64decode(request.audio_data)

            # Create STT config
            config = STTConfig(
                language=request.language or stt_config.get("language", "tr-TR"),
                sample_rate=16000,
                encoding=request.format.upper() if request.format else "WEBM_OPUS",
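                # NOTE: request.format.upper() yields e.g. "WEBM", while some
                # backends (Google STT, for instance) expect compound encodings
                # such as "WEBM_OPUS"; a format-to-encoding mapping may be
                # needed here depending on the provider.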
                enable_punctuation=stt_config.get("enable_punctuation", True),
                enable_word_timestamps=False,
                model=stt_config.get("model", "latest_long"),
                use_enhanced=stt_config.get("use_enhanced", True),
                single_utterance=True,
                interim_results=False
            )

            # Start streaming session
            await stt_provider.start_streaming(config)

            # Process audio
            transcription = ""
            confidence = 0.0
            try:
                async for result in stt_provider.stream_audio(audio_bytes):
                    if result.is_final:
                        transcription = result.text
                        confidence = result.confidence
                        break
            finally:
                # Stop streaming
                await stt_provider.stop_streaming()

            log_info(f"✅ STT transcription completed: '{transcription[:50]}...'")

            return {
                "text": transcription,
                "confidence": confidence,
                "language": request.language,
                "provider": stt_provider.get_provider_name()
            }
    except HTTPException:
        raise
    except Exception as e:
        log_error("❌ STT transcription error", e)
        raise HTTPException(
            status_code=500,
            detail=f"Transcription failed: {str(e)}"
        )
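
# A minimal sketch of building an STTRequest payload from a local file
# ("sample.wav" is a hypothetical input; endpoint mounting is not shown here):
#
#     import base64
#
#     with open("sample.wav", "rb") as f:
#         payload = {
#             "audio_data": base64.b64encode(f.read()).decode("ascii"),
#             "language": "tr-TR",
#             "format": "wav",
#         }
#     # then POST `payload` as JSON to the transcribe endpoint, e.g. with httpx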

async def get_stt_languages():
    """Get supported STT languages"""
    try:
        from stt.stt_factory import STTFactory

        stt_provider = STTFactory.create_provider()
        if not stt_provider:
            return {
                "languages": [],
                "provider": "none",
                "enabled": False
            }

        languages = stt_provider.get_supported_languages()

        return {
            "languages": languages,
            "provider": stt_provider.get_provider_name(),
            "enabled": True
        }
    except Exception as e:
        log_error("❌ Error getting STT languages", e)
        return {
            "languages": [],
            "provider": "error",
            "enabled": False,
            "error": str(e)
        }

async def get_stt_status():
    """Get STT service status"""
    cfg = ConfigProvider.get()
    return {
        "enabled": cfg.global_config.stt_provider.name != "no_stt",
        "provider": cfg.global_config.stt_provider.name,
        "provider_config": {
            "name": cfg.global_config.stt_provider.name,
            "has_api_key": bool(cfg.global_config.stt_provider.api_key),
            "endpoint": cfg.global_config.stt_provider.endpoint
        }
    }

# ===================== WebSocket Audio Stream Endpoint =====================
async def audio_websocket(websocket: WebSocket, session_id: str):
    """
    WebSocket endpoint for streaming audio.
    This is a dedicated audio stream separate from the main conversation WebSocket.
    """
    from fastapi import WebSocketDisconnect
    from event_bus import Event, EventType

    try:
        await websocket.accept()
        log_info(f"🎵 Audio WebSocket connected for session: {session_id}")

        # WebSocket handlers don't receive a Request object; application state
        # is reached through websocket.app.state instead
        if not hasattr(websocket.app.state, 'event_bus'):
            await websocket.send_json({
                "type": "error",
                "message": "Event bus not initialized"
            })
            await websocket.close()
            return

        while True:
            try:
                # Receive audio data
                data = await websocket.receive_json()

                if data.get("type") == "audio_chunk":
                    # Forward to event bus
                    await websocket.app.state.event_bus.publish(Event(
                        type=EventType.AUDIO_CHUNK_RECEIVED,
                        session_id=session_id,
                        data={
                            "audio_data": data.get("data"),
                            "timestamp": data.get("timestamp"),
                            "chunk_index": data.get("chunk_index", 0)
                        }
                    ))
                elif data.get("type") == "control":
                    action = data.get("action")
                    if action == "start_recording":
                        await websocket.app.state.event_bus.publish(Event(
                            type=EventType.STT_STARTED,
                            session_id=session_id,
                            data={
                                "language": data.get("language", "tr-TR"),
                                "format": data.get("format", "webm")
                            }
                        ))
                    elif action == "stop_recording":
                        await websocket.app.state.event_bus.publish(Event(
                            type=EventType.STT_STOPPED,
                            session_id=session_id,
                            data={"reason": "user_request"}
                        ))
            except WebSocketDisconnect:
                break
            except Exception as e:
                log_error("❌ Error in audio WebSocket", e)
                await websocket.send_json({
                    "type": "error",
                    "message": str(e)
                })
    except Exception as e:
        log_error("❌ Audio WebSocket error", e)
    finally:
        log_info(f"🎵 Audio WebSocket disconnected for session: {session_id}")