|
|
|
|
|
import io
import logging
import os
import tempfile

from dotenv import load_dotenv
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import Response
from faster_whisper import WhisperModel
from gtts import gTTS
from pydantic import BaseModel
|
|
|
|
|
# Load environment overrides (python-dotenv) before the os.getenv() lookup below.
load_dotenv()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Voice Agent")

# Whisper model size/name; configurable through the WHISPER_MODEL_SIZE env var.
WHISPER_MODEL_SIZE = os.getenv("WHISPER_MODEL_SIZE", "small")

# The STT model is created once at import time. If construction fails the
# error is logged and `whisper_model` stays None, which the /stt endpoint
# checks before doing any work.
whisper_model = None
try:
    whisper_model = WhisperModel(WHISPER_MODEL_SIZE, device="cpu")
except Exception as e:
    logger.error(f"Error loading Whisper model '{WHISPER_MODEL_SIZE}': {e}")
else:
    logger.info(f"Whisper model '{WHISPER_MODEL_SIZE}' loaded successfully on CPU.")
|
|
|
|
|
class TTSRequest(BaseModel):
    """Request body accepted by the text-to-speech endpoint."""

    # Text to be spoken.
    text: str
    # Language code handed to gTTS as `lang`; defaults to English.
    lang: str = "en"
|
|
|
|
|
def _transcribe_file(path: str) -> str:
    """Transcribe the audio file at *path* and return the joined segment text."""
    # language=None asks the model to detect the spoken language itself.
    segments, _info = whisper_model.transcribe(path, language=None)
    # The segments are consumed here so that all decoding work happens inside
    # this (worker-thread) call rather than back on the event loop.
    return " ".join(seg.text for seg in segments).strip()


@app.post("/stt")
async def stt(audio: UploadFile = File(...)):
    """Perform Speech-to-Text on an uploaded audio file.

    The upload is written to a temporary file (reusing the uploaded file's
    extension), transcribed with the module-level Whisper model in a worker
    thread, and the temporary file is removed afterwards.

    Args:
        audio: The uploaded audio file.

    Returns:
        A dict ``{"transcript": <text>}`` where the text is the segment
        texts joined with single spaces and stripped.

    Raises:
        HTTPException: 503 if the Whisper model could not be loaded at
            startup; 500 if saving or transcribing the audio fails.
    """
    if whisper_model is None:
        raise HTTPException(status_code=503, detail="STT model not loaded.")

    logger.info("Received audio file for STT: %s", audio.filename)

    # Reuse the upload's extension for the temp file; fall back to .wav when
    # the client sent no filename.
    suffix = os.path.splitext(audio.filename)[1] if audio.filename else ".wav"
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            # Record the path before reading/writing so the `finally` block
            # can clean up even if the upload read or the write fails.
            tmp_path = tmp.name
            tmp.write(await audio.read())
        logger.info("Audio saved to temporary file: %s", tmp_path)

        # Transcription is CPU-bound; run it off the event loop so other
        # requests are not stalled while it executes.
        transcript = await run_in_threadpool(_transcribe_file, tmp_path)
        logger.info("Transcription complete. Transcript: '%s'", transcript)

        return {"transcript": transcript}
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception("Error during STT processing: %s", e)
        raise HTTPException(
            status_code=500, detail=f"STT processing failed: {e}"
        ) from e
    finally:
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
                logger.info("Temporary file removed: %s", tmp_path)
            except OSError as cleanup_err:
                # Cleanup is best-effort: never let it replace the response.
                logger.warning(
                    "Could not remove temporary file %s: %s",
                    tmp_path,
                    cleanup_err,
                )
|
|
|
|
|
@app.post("/tts")
def tts(request: TTSRequest):
    """Perform Text-to-Speech using gTTS.

    Returns the audio data as a hex string (to match original orchestrator
    expectation).
    NOTE: Returning raw bytes with media_type='audio/mpeg' is more standard
    for APIs. This implementation keeps the hex encoding to avoid changing
    the orchestrator.

    Args:
        request: Text to speak and the gTTS language code.

    Returns:
        A dict ``{"audio": <hex string>}`` holding the hex-encoded audio
        bytes produced by gTTS.

    Raises:
        HTTPException: 400 if the text is empty or whitespace-only; 500 if
            speech synthesis fails.
    """
    # Reject empty input up front: there is nothing to synthesize, and this
    # is a client error rather than a server failure.
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text must not be empty.")

    logger.info(
        f"Generating TTS for text (lang={request.lang}): '{request.text[:50]}...'"
    )

    try:
        speech = gTTS(text=request.text, lang=request.lang, slow=False)
        # Write into memory instead of a temp file: no file can be leaked if
        # synthesis fails, and no second open of an already-open temp file.
        buffer = io.BytesIO()
        speech.write_to_fp(buffer)
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception("Error during TTS processing: %s", e)
        raise HTTPException(
            status_code=500, detail=f"TTS processing failed: {e}"
        ) from e

    audio_bytes = buffer.getvalue()
    logger.info("Synthesized %d bytes of TTS audio.", len(audio_bytes))

    return {"audio": audio_bytes.hex()}
|
|