from fastapi import FastAPI
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel
from starlette.background import BackgroundTask
import torch
from transformers import (
    AutoTokenizer,
    AutoProcessor,
    BarkModel,
    pipeline,
    AutoModelForSequenceClassification,
)
import scipy.io.wavfile as wavfile
import uuid
import os
from typing import Optional
# Initialize FastAPI app
app = FastAPI()
# Configuration
MODEL_PATH = "/app/models/suno-bark"
SENTIMENT_MODEL_PATH = "/app/models/sentiment"
# Load all models in a single try-except block
try:
    # TTS model
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    processor = AutoProcessor.from_pretrained(MODEL_PATH)
    model = BarkModel.from_pretrained(MODEL_PATH)

    # Sentiment analysis model (pre-downloaded)
    sentiment_tokenizer = AutoTokenizer.from_pretrained(SENTIMENT_MODEL_PATH)
    sentiment_model = AutoModelForSequenceClassification.from_pretrained(SENTIMENT_MODEL_PATH)
    sentiment_pipeline = pipeline(
        "sentiment-analysis",
        model=sentiment_model,
        tokenizer=sentiment_tokenizer,
        truncation=True,
        max_length=512,
    )

    # Device configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
except Exception as e:
    raise RuntimeError(f"Initialization failed: {str(e)}")
# Request models
class TTSRequest(BaseModel):
    text: str


class SentimentRequest(BaseModel):
    text: str


class LegalDocRequest(BaseModel):
    text: str
    domain: Optional[str] = "general"


@app.get("/")
def root():
    return {"message": "Welcome to Kinyarwanda NLP API"}
@app.post("/tts/")
def text_to_speech(request: TTSRequest):
output_file = None
try:
inputs = processor(request.text, return_tensors="pt").to(device)
with torch.no_grad():
speech = model.generate(**inputs)
output_file = f"output_{uuid.uuid4().hex}.wav"
wavfile.write(output_file, rate=24000, data=speech.cpu().numpy().squeeze())
return FileResponse(
output_file,
media_type="audio/wav",
filename=output_file
)
except Exception as e:
return JSONResponse(
status_code=500,
content={"error": f"TTS failed: {str(e)}"}
)
finally:
if output_file and os.path.exists(output_file):
os.remove(output_file)
@app.post("/sentiment/")
def analyze_sentiment(request: SentimentRequest):
try:
result = sentiment_pipeline(request.text)
return {"result": result}
except Exception as e:
return JSONResponse(
status_code=500,
content={"error": f"Sentiment analysis failed: {str(e)}"}
)
@app.post("/legal-parse/")
def parse_legal_document(request: LegalDocRequest):
try:
keywords = ["contract", "agreement", "party", "terms", "confidential", "jurisdiction"]
found_keywords = [kw for kw in keywords if kw in request.text.lower()]
return {
"identified_keywords": found_keywords,
"domain": request.domain,
"status": "success"
}
except Exception as e:
return JSONResponse(
status_code=500,
content={"error": f"Legal parsing failed: {str(e)}"}
)