# File size: 4,708 Bytes (revision 3f43e82)
# agents/voice_agent/main.py
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import Response # Import Response for returning audio bytes
from pydantic import BaseModel
from gtts import gTTS
import tempfile
import os
import logging
from faster_whisper import WhisperModel # For STT
from dotenv import load_dotenv
# Load environment variables before any os.getenv() lookup below.
load_dotenv()

# Configure logging: INFO level via basicConfig, plus a module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Voice Agent")

# Whisper model size comes from the WHISPER_MODEL_SIZE environment variable
# and defaults to "small".
WHISPER_MODEL_SIZE = os.getenv("WHISPER_MODEL_SIZE", "small")

# Initialize the Whisper model once, at import time, on the CPU. A load
# failure is deliberately non-fatal: whisper_model is left as None so the
# application still starts, and the /stt endpoint reports 503 in that case.
try:
    whisper_model = WhisperModel(WHISPER_MODEL_SIZE, device="cpu")
except Exception as e:
    # logger.exception records the full traceback, which is what is needed to
    # diagnose a startup failure (missing weights, download error, ...).
    logger.exception(
        "Error loading Whisper model '%s': %s", WHISPER_MODEL_SIZE, e
    )
    whisper_model = None
else:
    logger.info(
        "Whisper model '%s' loaded successfully on CPU.", WHISPER_MODEL_SIZE
    )
class TTSRequest(BaseModel):
    """Request body accepted by the /tts endpoint."""

    # Text to be converted to speech.
    text: str
    # Language code handed to gTTS; "en" when the client omits it.
    lang: str = "en"
@app.post("/stt")
async def stt(audio: UploadFile = File(...)):
    """Perform Speech-to-Text on an uploaded audio file.

    The upload is written to a temporary file, transcribed with the
    module-level Whisper model, and the temporary file is removed again
    whether or not transcription succeeds.

    Args:
        audio: Uploaded audio file. The extension of its filename (".wav"
            when no filename is supplied) is reused as the temporary file's
            suffix.

    Returns:
        A dict ``{"transcript": <text>}`` where the text is every segment's
        text joined with single spaces and stripped of surrounding whitespace.

    Raises:
        HTTPException: 503 if the Whisper model failed to load at startup;
            500 if saving or transcribing the audio fails.
    """
    if whisper_model is None:
        raise HTTPException(status_code=503, detail="STT model not loaded.")

    logger.info(f"Received audio file for STT: {audio.filename}")

    suffix = os.path.splitext(audio.filename)[1] if audio.filename else ".wav"
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            # Record the path before reading/writing so the finally-block can
            # remove the file even if the upload cannot be read or written.
            tmp_path = tmp.name
            tmp.write(await audio.read())
        logger.info(f"Audio saved to temporary file: {tmp_path}")

        # No language is forced (language=None); per the faster-whisper
        # documentation the language is then detected from the audio.
        # NOTE: this call and the iteration over `segments` below are not
        # awaited, so they run synchronously inside this coroutine.
        segments, _info = whisper_model.transcribe(tmp_path, language=None)
        transcript = " ".join(seg.text for seg in segments).strip()
        logger.info(f"Transcription complete. Transcript: '{transcript}'")
        return {"transcript": transcript}
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception(f"Error during STT processing: {e}")
        raise HTTPException(
            status_code=500, detail=f"STT processing failed: {e}"
        ) from e
    finally:
        # Clean up the temporary file on both the success and failure paths.
        if tmp_path and os.path.exists(tmp_path):
            os.remove(tmp_path)
            logger.info(f"Temporary file removed: {tmp_path}")
@app.post("/tts")
def tts(request: TTSRequest):
    """Perform Text-to-Speech using gTTS.

    The generated MP3 is written to a temporary file, read back, and returned
    hex-encoded. The hex encoding is kept to match what the orchestrator
    expects; returning raw bytes with media_type="audio/mpeg" would be the
    more standard API shape but requires an orchestrator change.

    Args:
        request: Text to synthesize and the language code to use.

    Returns:
        A dict ``{"audio": <hex string>}`` containing the MP3 bytes encoded
        as hexadecimal.

    Raises:
        HTTPException: 400 if the text is empty or whitespace only; 500 if
            speech generation or reading the result fails.
    """
    # Reject empty input up front as a client error instead of letting it
    # surface from gTTS as a generic 500.
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text must not be empty.")

    logger.info(
        f"Generating TTS for text (lang={request.lang}): '{request.text[:50]}...'"
    )
    tmp_path = None
    try:
        # Reserve a temporary path and close the handle before gTTS writes to
        # it. Recording the path first lets the finally-block remove the file
        # even when generation fails, and gTTS never has to reopen a file
        # that is still held open here.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
            tmp_path = tmp.name

        tts_obj = gTTS(text=request.text, lang=request.lang, slow=False)
        tts_obj.save(tmp_path)
        logger.info(f"TTS audio saved to temporary file: {tmp_path}")

        # Read the audio file bytes back from disk.
        with open(tmp_path, "rb") as f:
            audio_bytes = f.read()
        logger.info(f"Read {len(audio_bytes)} bytes from temporary file.")

        # Return as a hex string, as the orchestrator expects.
        audio_hex = audio_bytes.hex()
        logger.info("Audio bytes converted to hex.")
        return {"audio": audio_hex}
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception(f"Error during TTS processing: {e}")
        raise HTTPException(
            status_code=500, detail=f"TTS processing failed: {e}"
        ) from e
    finally:
        # Clean up the temporary file on both the success and failure paths.
        if tmp_path and os.path.exists(tmp_path):
            os.remove(tmp_path)
            logger.info(f"Temporary file removed: {tmp_path}")
|