# black44's picture
# Update app.py
# cf0cbad verified
# raw
# history blame
# 2.9 kB
from fastapi import FastAPI, Form
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel
import torch
from transformers import AutoTokenizer, AutoProcessor, BarkModel, pipeline
import scipy.io.wavfile as wavfile
import uuid
import os
from typing import Optional
# Ensure proper model loading from pre-downloaded path
MODEL_PATH = "/app/models/suno-bark"

# Load models and processors once, at import time, so that every request
# handler below reuses the same in-memory objects.
try:
    # NOTE(review): `tokenizer` is not referenced by any handler visible in
    # this file; it is kept because other modules may import it.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    processor = AutoProcessor.from_pretrained(MODEL_PATH)
    model = BarkModel.from_pretrained(MODEL_PATH)
    # Load sentiment analysis pipeline
    sentiment_model = pipeline(
        "sentiment-analysis",
        model="nlptown/bert-base-multilingual-uncased-sentiment",
    )
except Exception as e:
    # Abort the import when any model cannot be loaded: the handlers below
    # cannot work without them. `from e` keeps the original failure attached
    # as the explicit cause of the RuntimeError.
    raise RuntimeError(f"Model loading failed: {str(e)}") from e

# Device configuration: use CUDA when torch reports it as available,
# otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Initialize FastAPI app
app = FastAPI()
# Request models
class TTSRequest(BaseModel):
    """Request body for the ``/tts/`` endpoint."""

    # Text passed to the speech processor.
    text: str
class SentimentRequest(BaseModel):
    """Request body for the ``/sentiment/`` endpoint."""

    # Text passed to the sentiment-analysis pipeline.
    text: str
class LegalDocRequest(BaseModel):
    """Request body for the ``/legal-parse/`` endpoint."""

    # Document text that is scanned for keywords.
    text: str
    # Caller-supplied label that is echoed back in the response;
    # defaults to "general" when omitted.
    domain: Optional[str] = "general"
@app.get("/")
def root():
    """Return a static welcome message."""
    return {"message": "Welcome to Kinyarwanda NLP API"}
def _remove_file(path):
    """Delete *path*, doing nothing if the file does not exist."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


@app.post("/tts/")
def text_to_speech(request: TTSRequest):
    """Synthesize speech for ``request.text`` and return it as a WAV file.

    Args:
        request: Body carrying the text to synthesize.

    Returns:
        A ``FileResponse`` with media type ``audio/wav`` on success, or a
        ``JSONResponse`` with status 500 and an ``"error"`` message when any
        step of the synthesis raises.
    """
    # Local import so this handler does not depend on the file's import block.
    from fastapi import BackgroundTasks

    # Choose the file name before the try block so that the error path can
    # always refer to it, even when synthesis fails on the very first step.
    output_file = f"output_{uuid.uuid4().hex}.wav"
    try:
        inputs = processor(request.text, return_tensors="pt").to(device)
        with torch.no_grad():
            speech = model.generate(**inputs)
        # NOTE(review): 24000 is presumably the model's output sample rate;
        # confirm against model.generation_config.sample_rate.
        wavfile.write(output_file, rate=24000, data=speech.cpu().numpy().squeeze())
    except Exception as e:
        # Do not leave a partially written file behind on failure.
        _remove_file(output_file)
        return JSONResponse(status_code=500, content={"error": f"TTS failed: {str(e)}"})

    # The response body is sent only after this function returns, so the file
    # must still exist at that point. Deleting it in a background task, which
    # runs after the response has been sent, avoids removing it too early.
    cleanup = BackgroundTasks()
    cleanup.add_task(_remove_file, output_file)
    return FileResponse(
        output_file,
        media_type="audio/wav",
        filename=output_file,
        background=cleanup,
    )
@app.post("/sentiment/")
def analyze_sentiment(request: SentimentRequest):
    """Run the sentiment pipeline on ``request.text``.

    Args:
        request: Body carrying the text to classify.

    Returns:
        ``{"result": <pipeline output>}`` on success, or a ``JSONResponse``
        with status 500 and an ``"error"`` message when the pipeline raises.
    """
    try:
        result = sentiment_model(request.text)
        return {"result": result}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": f"Sentiment analysis failed: {str(e)}"})
@app.post("/legal-parse/")
def parse_legal_document(request: LegalDocRequest):
    """Report which of a fixed keyword list occur in ``request.text``.

    Matching is a case-insensitive substring test, so a keyword also matches
    inside a longer word (for example "party" inside "counterparty").

    Args:
        request: Body carrying the document text and a domain label.

    Returns:
        A dict with ``"identified_keywords"`` (matches, in keyword-list
        order), ``"domain"`` (echoed from the request) and ``"status"``, or a
        ``JSONResponse`` with status 500 and an ``"error"`` message when
        processing raises.
    """
    try:
        # Basic keyword extraction (replace with trained model in production)
        keywords = ["contract", "agreement", "party", "terms", "confidential", "jurisdiction"]
        # Lower-case the document once instead of once per keyword.
        lowered_text = request.text.lower()
        found_keywords = [kw for kw in keywords if kw in lowered_text]
        return {
            "identified_keywords": found_keywords,
            "domain": request.domain,
            "status": "success",
        }
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": f"Legal parsing failed: {str(e)}"})