Spaces: Runtime error
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel
from starlette.background import BackgroundTask  # used to delete audio files after the response is sent
import torch
from transformers import (
    AutoTokenizer,
    AutoProcessor,
    BarkModel,
    pipeline,
    AutoModelForSequenceClassification,
)
import scipy.io.wavfile as wavfile
import uuid
import os
from typing import Optional
# FastAPI instance
app = FastAPI(title="Kinyarwanda NLP API", version="1.0")

# Config
MODEL_PATH = "/app/models/suno-bark"
SENTIMENT_MODEL_PATH = "/app/models/sentiment"
SAMPLE_RATE = 24000  # Bark generates audio at 24 kHz

# Ensure working directory for audio exists
AUDIO_DIR = "/tmp/audio"
os.makedirs(AUDIO_DIR, exist_ok=True)
# Load models
try:
    # TTS
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    processor = AutoProcessor.from_pretrained(MODEL_PATH)
    model = BarkModel.from_pretrained(MODEL_PATH)

    # Sentiment
    sentiment_tokenizer = AutoTokenizer.from_pretrained(SENTIMENT_MODEL_PATH)
    sentiment_model = AutoModelForSequenceClassification.from_pretrained(SENTIMENT_MODEL_PATH)
    sentiment_pipeline = pipeline(
        "sentiment-analysis",
        model=sentiment_model,
        tokenizer=sentiment_tokenizer,
        truncation=True,
        max_length=512,
    )

    # Device config
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
except Exception as e:
    raise RuntimeError(f"Model initialization failed: {e}")
# Request schemas
class TTSRequest(BaseModel):
    text: str


class SentimentRequest(BaseModel):
    text: str


class LegalDocRequest(BaseModel):
    text: str
    domain: Optional[str] = "general"
# Root route
@app.get("/")
def root():
    return {"message": "Welcome to Kinyarwanda NLP API"}
# Text-to-Speech endpoint (route path assumed)
@app.post("/tts")
def text_to_speech(request: TTSRequest):
    output_file = os.path.join(AUDIO_DIR, f"tts_{uuid.uuid4().hex}.wav")
    try:
        inputs = processor(request.text, return_tensors="pt").to(device)
        with torch.no_grad():
            audio_array = model.generate(**inputs)
        wavfile.write(output_file, rate=SAMPLE_RATE, data=audio_array.cpu().numpy().squeeze())
        # Remove the file only after the response has been streamed; deleting it in a
        # finally block would remove it before FileResponse can read it.
        return FileResponse(
            output_file,
            media_type="audio/wav",
            filename=os.path.basename(output_file),
            background=BackgroundTask(os.remove, output_file),
        )
    except Exception as e:
        # Clean up a partially written file on failure
        if os.path.exists(output_file):
            os.remove(output_file)
        raise HTTPException(status_code=500, detail=f"TTS generation failed: {str(e)}")
# Sentiment Analysis endpoint (route path assumed)
@app.post("/sentiment")
def analyze_sentiment(request: SentimentRequest):
    try:
        result = sentiment_pipeline(request.text)
        return {"result": result}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Sentiment analysis failed: {str(e)}")
# Legal Parsing endpoint (route path assumed)
@app.post("/parse-legal")
def parse_legal_document(request: LegalDocRequest):
    try:
        keywords = ["contract", "agreement", "party", "terms", "confidential", "jurisdiction"]
        found = [kw for kw in keywords if kw in request.text.lower()]
        return {
            "identified_keywords": found,
            "domain": request.domain,
            "status": "success",
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Legal parsing failed: {str(e)}")