# NOTE(review): the lines below are Hugging Face web-UI residue captured by a
# raw copy/paste ("raw / history blame" page chrome); they are not Python and
# are commented out so the module can be imported. Nothing removed.
# black44's picture
# Update app.py
# fb42ae8 verified
# raw
# history blame
# 4.41 kB
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
import torch
from transformers import (
AutoTokenizer,
AutoProcessor,
BarkModel,
pipeline,
AutoModelForSequenceClassification,
Wav2Vec2Processor,
Wav2Vec2ForCTC
)
import scipy.io.wavfile as wavfile
import uuid
import os
from io import BytesIO
import soundfile as sf
from typing import Optional
# FastAPI instance
app = FastAPI(title="Kinyarwanda Engine", version="1.0")

# Config
# Local filesystem paths baked into the container image for the TTS (Bark)
# and sentiment models; both are loaded with from_pretrained() below.
MODEL_PATH = "/app/models/suno-bark"
SENTIMENT_MODEL_PATH = "/app/models/sentiment"
# Output rate (Hz) used when writing TTS WAV responses in /tts/.
SAMPLE_RATE = 24000
# Hub model ID for the Kinyarwanda ASR checkpoint (resolved by from_pretrained).
ASR_MODEL_PATH = "jonatasgrosman/wav2vec2-large-xlsr-53-Kinyarwanda"

# Ensure working directory for audio
# NOTE(review): AUDIO_DIR is created here but never referenced anywhere else
# in this file — confirm whether it is still needed.
AUDIO_DIR = "/tmp/audio"
os.makedirs(AUDIO_DIR, exist_ok=True)
# Load models once at import time so every request reuses the same instances.
try:
    # Decide the device once, before any model needs it. (Previously the
    # same cuda-availability expression was evaluated twice and the ASR model
    # was moved to a device computed inline before `device` existed.)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # TTS (Bark).
    # NOTE(review): `tokenizer` is loaded but never referenced in this file;
    # kept so any external importer of this module keeps working.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    processor = AutoProcessor.from_pretrained(MODEL_PATH)
    model = BarkModel.from_pretrained(MODEL_PATH)
    model.to(device)

    # Sentiment analysis: model + tokenizer wrapped in a pipeline with
    # truncation so over-long inputs don't error out at 512 tokens.
    sentiment_tokenizer = AutoTokenizer.from_pretrained(SENTIMENT_MODEL_PATH)
    sentiment_model = AutoModelForSequenceClassification.from_pretrained(SENTIMENT_MODEL_PATH)
    sentiment_pipeline = pipeline(
        "sentiment-analysis",
        model=sentiment_model,
        tokenizer=sentiment_tokenizer,
        truncation=True,
        max_length=512,
    )

    # STT (wav2vec2 CTC).
    asr_processor = Wav2Vec2Processor.from_pretrained(ASR_MODEL_PATH)
    asr_model = Wav2Vec2ForCTC.from_pretrained(ASR_MODEL_PATH)
    asr_model.to(device)
except Exception as e:
    # Chain the original exception so the real cause stays in the traceback.
    raise RuntimeError(f"Model initialization failed: {e}") from e
# Request schemas
class TTSRequest(BaseModel):
    """Request body for POST /tts/: the text to synthesize."""
    text: str
class SentimentRequest(BaseModel):
    """Request body for POST /sentiment/: the text to classify."""
    text: str
class LegalDocRequest(BaseModel):
    """Request body for POST /legal-parse/."""
    text: str
    # Echoed back in the response; does not change parsing behavior.
    domain: Optional[str] = "general"
# Root route
@app.get("/")
def root():
    """Landing route: confirm the service is up and reachable."""
    greeting = {"message": "Welcome to Kinyarwanda Engine"}
    return greeting
# Text-to-Speech endpoint
@app.post("/tts/")
def text_to_speech(request: TTSRequest):
    """Synthesize speech for `request.text` with Bark and stream it as a WAV.

    Returns a streaming attachment named tts_<uuid>.wav; any failure is
    surfaced as HTTP 500 with the underlying error message.
    """
    try:
        model_inputs = processor(request.text, return_tensors="pt").to(device)
        with torch.no_grad():
            generated = model.generate(**model_inputs)
        waveform = generated.cpu().numpy().squeeze()

        # Serialize to an in-memory WAV so nothing touches the filesystem.
        wav_buffer = BytesIO()
        wavfile.write(wav_buffer, rate=SAMPLE_RATE, data=waveform)
        wav_buffer.seek(0)

        disposition = f"attachment; filename=tts_{uuid.uuid4().hex}.wav"
        return StreamingResponse(
            wav_buffer,
            media_type="audio/wav",
            headers={"Content-Disposition": disposition},
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"TTS generation failed: {str(e)}")
# Speech-to-Text endpoint
@app.post("/stt/")
def speech_to_text(audio_file: UploadFile = File(...)):
    """Transcribe an uploaded audio file with the Kinyarwanda wav2vec2 model.

    Accepts any format soundfile can decode. Multi-channel audio is averaged
    to mono and the signal is resampled to 16 kHz before inference — XLSR
    wav2vec2 checkpoints are trained on 16 kHz mono audio, so the previous
    code (which forwarded the upload's native rate and channel layout
    unchanged) failed or mis-transcribed anything that wasn't already
    16 kHz mono. Errors are surfaced as HTTP 500.
    """
    try:
        audio_bytes = audio_file.file.read()
        audio, sample_rate = sf.read(BytesIO(audio_bytes))

        # Collapse multi-channel audio to mono; the CTC model expects 1-D input.
        if audio.ndim > 1:
            audio = audio.mean(axis=1)

        # Resample to the model's expected rate when necessary.
        target_rate = 16000
        if sample_rate != target_rate:
            # Local import: scipy is already a dependency of this file.
            from scipy.signal import resample_poly
            audio = resample_poly(audio, target_rate, sample_rate)
            sample_rate = target_rate

        inputs = asr_processor(
            audio, sampling_rate=sample_rate, return_tensors="pt", padding=True
        ).input_values.to(device)
        with torch.no_grad():
            logits = asr_model(inputs).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = asr_processor.batch_decode(predicted_ids)[0]
        return {"transcription": transcription}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"STT failed: {str(e)}")
# Sentiment Analysis endpoint
@app.post("/sentiment/")
def analyze_sentiment(request: SentimentRequest):
    """Classify the sentiment of `request.text` via the preloaded pipeline.

    Failures are surfaced as HTTP 500 with the underlying error message.
    """
    try:
        prediction = sentiment_pipeline(request.text)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Sentiment analysis failed: {str(e)}")
    return {"result": prediction}
# Legal Parsing endpoint
@app.post("/legal-parse/")
def parse_legal_document(request: LegalDocRequest):
    """Scan `request.text` for a fixed set of English legal keywords.

    Matching is case-insensitive substring search; the caller-supplied
    `domain` label is echoed back unchanged alongside a status flag.
    """
    try:
        legal_terms = ["contract", "agreement", "party", "terms", "confidential", "jurisdiction"]
        lowered_text = request.text.lower()
        matches = [term for term in legal_terms if term in lowered_text]
        return {
            "identified_keywords": matches,
            "domain": request.domain,
            "status": "success",
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Legal parsing failed: {str(e)}")