black44's picture
Update app.py
80bc30e verified
raw
history blame
3.36 kB
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel
import torch
from transformers import (
AutoTokenizer,
AutoProcessor,
BarkModel,
pipeline,
AutoModelForSequenceClassification
)
import scipy.io.wavfile as wavfile
import uuid
import os
from typing import Optional
# --- Application object ---------------------------------------------------
app = FastAPI(
    title="Kinyarwanda NLP API",
    version="1.0",
)

# --- Model locations and audio settings -----------------------------------
# Paths handed to the `from_pretrained` loaders further down this file.
MODEL_PATH = "/app/models/suno-bark"
SENTIMENT_MODEL_PATH = "/app/models/sentiment"

# Rate passed to the WAV writer when generated speech is saved.
SAMPLE_RATE = 24000

# --- Scratch space for generated audio ------------------------------------
# Created at import time; an already-existing directory is not an error.
AUDIO_DIR = "/tmp/audio"
os.makedirs(AUDIO_DIR, exist_ok=True)
# Load models
# Everything is loaded once, at import time. Any failure is re-raised as a
# RuntimeError so the process stops at startup instead of on first request.
try:
    # TTS
    # NOTE(review): `tokenizer` is not referenced by any endpoint visible in
    # this file; it is kept because other modules may import it.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    processor = AutoProcessor.from_pretrained(MODEL_PATH)
    model = BarkModel.from_pretrained(MODEL_PATH)

    # Sentiment
    sentiment_tokenizer = AutoTokenizer.from_pretrained(SENTIMENT_MODEL_PATH)
    sentiment_model = AutoModelForSequenceClassification.from_pretrained(
        SENTIMENT_MODEL_PATH
    )
    # `truncation` / `max_length` are given at construction time so they
    # apply to every call made through this pipeline object.
    sentiment_pipeline = pipeline(
        "sentiment-analysis",
        model=sentiment_model,
        tokenizer=sentiment_tokenizer,
        truncation=True,
        max_length=512,
    )

    # Device config
    # Only the Bark model is moved explicitly; the sentiment pipeline is
    # built without a `device` argument and keeps its default placement.
    # NOTE(review): confirm that is intended when CUDA is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
except Exception as e:
    # Chain explicitly so the startup traceback reports the loader error as
    # the direct cause of this RuntimeError.
    raise RuntimeError(f"Model initialization failed: {e}") from e
# Request schemas
# Body of POST /tts/.
class TTSRequest(BaseModel):
    # Text to synthesise; the /tts/ handler passes it to the Bark processor.
    text: str
# Body of POST /sentiment/.
class SentimentRequest(BaseModel):
    # Text to classify; the /sentiment/ handler passes it to the sentiment
    # pipeline unchanged.
    text: str
# Body of POST /legal-parse/.
class LegalDocRequest(BaseModel):
    # Document text that the /legal-parse/ handler scans for keywords.
    text: str
    # Caller-supplied label, "general" when omitted; the handler echoes it
    # back in the response without interpreting it.
    domain: Optional[str] = "general"
# Root route
# Answers GET / with a fixed welcome payload; takes no input.
@app.get("/")
def root():
    welcome = {"message": "Welcome to Kinyarwanda NLP API"}
    return welcome
# Text-to-Speech endpoint
def _discard_audio_file(path):
    """Delete *path*; a file that is already gone is not an error."""
    try:
        os.remove(path)
    except FileNotFoundError:
        # Nothing to clean up: the file was never written or was removed.
        pass


@app.post("/tts/")
def text_to_speech(request: TTSRequest):
    """Synthesise speech for ``request.text`` and return it as a WAV file.

    The audio is written to a uniquely named file under ``AUDIO_DIR`` and
    served with media type ``audio/wav``. The file is deleted once the
    response has been sent.

    Args:
        request: Payload whose ``text`` field is the text to synthesise.

    Returns:
        A ``FileResponse`` for the generated WAV file.

    Raises:
        HTTPException: Status 500 if processing, generation, or writing the
            WAV file fails; ``detail`` carries the underlying error message.
    """
    # Imported here so this handler does not rely on the file-level imports
    # being edited.
    from fastapi import BackgroundTasks

    output_file = os.path.join(AUDIO_DIR, f"tts_{uuid.uuid4().hex}.wav")
    try:
        inputs = processor(request.text, return_tensors="pt").to(device)
        with torch.no_grad():
            audio_array = model.generate(**inputs)
        wavfile.write(
            output_file,
            rate=SAMPLE_RATE,
            data=audio_array.cpu().numpy().squeeze(),
        )
    except Exception as e:
        # No response will be sent for this file, so remove any partial
        # output immediately.
        _discard_audio_file(output_file)
        raise HTTPException(
            status_code=500, detail=f"TTS generation failed: {str(e)}"
        ) from e

    # FileResponse reads the file only while the response is being sent,
    # i.e. after this function has returned. Deleting it in a `finally`
    # here would remove it before it is read, so deletion is attached to
    # the response as a background task that runs after sending.
    cleanup = BackgroundTasks()
    cleanup.add_task(_discard_audio_file, output_file)
    return FileResponse(
        output_file,
        media_type="audio/wav",
        filename=os.path.basename(output_file),
        background=cleanup,
    )
# Sentiment Analysis endpoint
# Runs the module-level sentiment pipeline on `request.text` and wraps its
# output under the "result" key. Any pipeline error is reported to the
# client as HTTP 500 with the error text in `detail`.
@app.post("/sentiment/")
def analyze_sentiment(request: SentimentRequest):
    try:
        prediction = sentiment_pipeline(request.text)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Sentiment analysis failed: {str(exc)}")
    return {"result": prediction}
# Legal Parsing endpoint
# Keywords reported by /legal-parse/, in the order they appear in the
# response. Built once at import time instead of on every request.
_LEGAL_KEYWORDS = (
    "contract",
    "agreement",
    "party",
    "terms",
    "confidential",
    "jurisdiction",
)


@app.post("/legal-parse/")
def parse_legal_document(request: LegalDocRequest):
    """Report which known legal keywords occur in ``request.text``.

    Matching is a case-insensitive substring test, so a keyword also
    matches inside a longer word (for example "party" in "counterparty").

    Args:
        request: Payload with the document ``text`` and a ``domain`` label.

    Returns:
        A dict with ``identified_keywords`` (matches, in keyword-list
        order), ``domain`` (echoed from the request) and ``status`` set to
        ``"success"``.

    Raises:
        HTTPException: Status 500 if the keyword scan raises.
    """
    try:
        # Lowercase the document once rather than once per keyword.
        lowered_text = request.text.lower()
        found = [kw for kw in _LEGAL_KEYWORDS if kw in lowered_text]
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Legal parsing failed: {str(e)}"
        ) from e
    return {
        "identified_keywords": found,
        "domain": request.domain,
        "status": "success",
    }