# nuera/app.py
from fastapi import FastAPI, File, UploadFile, Response
# ParlerTTS ships in the standalone parler_tts package, not transformers
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, Wav2Vec2ForCTC, Wav2Vec2Processor
from llama_cpp import Llama
from pydantic import BaseModel
import torch
import torchaudio
import soundfile as sf
import io
import os

app = FastAPI()

# Load models
# TTS: use the local fine-tuned model if available, else load from the Hub
if os.path.exists("./models/tts_model"):
    tts_model = ParlerTTSForConditionalGeneration.from_pretrained("./models/tts_model")
    tts_tokenizer = AutoTokenizer.from_pretrained("./models/tts_model")
else:
    tts_model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1")
    tts_tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1")

# STT (speech-to-text): use the local fine-tuned model if available, else load from the Hub
if os.path.exists("./models/stt_model"):
    stt_model = Wav2Vec2ForCTC.from_pretrained("./models/stt_model")
    stt_processor = Wav2Vec2Processor.from_pretrained("./models/stt_model")
else:
    stt_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
    stt_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

# LLM: use the local GGUF file if available, else raise an error
if os.path.exists("./models/llama.gguf"):
    llm = Llama(model_path="./models/llama.gguf")
else:
    raise FileNotFoundError("Please upload llama.gguf to the models/ directory")
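
# A minimal sketch of pulling the GGUF from the Hub instead of failing, assuming
# huggingface_hub is installed (the repo_id below is a placeholder):
#
#   from huggingface_hub import hf_hub_download
#   gguf_path = hf_hub_download(repo_id="<your-gguf-repo>", filename="llama.gguf",
#                               local_dir="./models")
#   llm = Llama(model_path=gguf_path)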

# Request models
class TTSRequest(BaseModel):
    text: str

class LLMRequest(BaseModel):
    prompt: str

# API Endpoints
@app.post("/tts")
async def tts_endpoint(request: TTSRequest):
text = request.text
inputs = tts_tokenizer(text, return_tensors="pt")
with torch.no_grad():
audio = tts_model.generate(**inputs)
audio = audio.squeeze().cpu().numpy()
buffer = io.BytesIO()
sf.write(buffer, audio, 22050, format="WAV")
buffer.seek(0)
return Response(content=buffer.getvalue(), media_type="audio/wav")
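
# Example request (assumes the default uvicorn host/port):
#   curl -X POST http://localhost:8000/tts \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Hello there"}' --output out.wav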
@app.post("/sst")
async def sst_endpoint(file: UploadFile = File(...)):
audio_bytes = await file.read()
audio, sr = sf.read(io.BytesIO(audio_bytes))
inputs = sst_processor(audio, sampling_rate=sr, return_tensors="pt")
with torch.no_grad():
logits = sst_model(inputs.input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = sst_processor.batch_decode(predicted_ids)[0]
return {"text": transcription}
@app.post("/llm")
async def llm_endpoint(request: LLMRequest):
prompt = request.prompt
output = llm(prompt, max_tokens=50)
return {"text": output["choices"][0]["text"]}