# Spaces: Runtime error
import os
import uuid
from typing import Optional

import scipy.io.wavfile as wavfile
import torch
from fastapi import BackgroundTasks, FastAPI
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel
from transformers import (
    AutoModelForSequenceClassification,
    AutoProcessor,
    AutoTokenizer,
    BarkModel,
    pipeline,
)
# The FastAPI application object for this NLP API.
app: FastAPI = FastAPI()
# Configuration: directories of the pre-downloaded checkpoints that are loaded
# with ``from_pretrained``. Each may be overridden through an environment
# variable (TTS_MODEL_PATH / SENTIMENT_MODEL_PATH) so the app can run outside
# the /app container layout; when unset, the original paths are used.
MODEL_PATH = os.environ.get("TTS_MODEL_PATH", "/app/models/suno-bark")
SENTIMENT_MODEL_PATH = os.environ.get("SENTIMENT_MODEL_PATH", "/app/models/sentiment")
# Load all models when this module is imported, inside a single try/except, so a
# missing or broken checkpoint surfaces immediately as one RuntimeError whose
# __cause__ is the original loader exception.
try:
    # Text-to-speech (Bark) model plus its processor.
    # NOTE(review): `tokenizer` is not referenced by any handler below — confirm
    # it is still needed.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    processor = AutoProcessor.from_pretrained(MODEL_PATH)
    model = BarkModel.from_pretrained(MODEL_PATH)

    # Sentiment-analysis model (pre-downloaded), wrapped in a pipeline whose
    # tokenizer truncates inputs to max_length=512.
    sentiment_tokenizer = AutoTokenizer.from_pretrained(SENTIMENT_MODEL_PATH)
    sentiment_model = AutoModelForSequenceClassification.from_pretrained(SENTIMENT_MODEL_PATH)
    sentiment_pipeline = pipeline(
        "sentiment-analysis",
        model=sentiment_model,
        tokenizer=sentiment_tokenizer,
        truncation=True,
        max_length=512,
    )

    # Put the TTS model on the GPU when CUDA is available, otherwise the CPU.
    # NOTE(review): the sentiment pipeline is not given `device`, so it is not
    # moved along with the TTS model — confirm that is intended.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
except Exception as e:
    # Chain explicitly so the loader's traceback is reported as the direct cause.
    raise RuntimeError(f"Initialization failed: {e}") from e
# Request models: pydantic schemas for the handlers' request payloads.
class TTSRequest(BaseModel):
    # Text-to-speech payload.
    text: str  # text to synthesize into speech
class SentimentRequest(BaseModel):
    # Sentiment-analysis payload.
    text: str  # text passed to the sentiment pipeline
class LegalDocRequest(BaseModel):
    # Legal-document keyword-scan payload.
    text: str  # document text to scan for keywords
    domain: Optional[str] = "general"  # echoed back unchanged in the response
def root():
    """Return the API's welcome-message payload.

    NOTE(review): no route decorator is visible for this handler; presumably it
    is meant to be registered as ``@app.get("/")`` — confirm.
    """
    welcome = "Welcome to Kinyarwanda NLP API"
    return dict(message=welcome)
def text_to_speech(request: TTSRequest):
    """Synthesize speech for ``request.text`` with Bark and return it as a WAV file.

    The audio is written to a uniquely named ``output_<hex>.wav`` file in the
    current working directory and returned as an ``audio/wav`` attachment.
    ``FileResponse`` reads the file lazily while the response is being sent,
    i.e. after this function returns, so the file is deleted by a background
    task that runs once sending finishes. Deleting it in ``finally`` (as before)
    removed the file before it could be served.

    Returns:
        ``FileResponse`` with the WAV audio on success, or a ``JSONResponse``
        with status 500 and an ``{"error": ...}`` body if synthesis fails.

    NOTE(review): no route decorator is visible for this handler; presumably it
    is registered on ``app`` (e.g. as a POST route) — confirm.
    """
    output_file = None
    try:
        inputs = processor(request.text, return_tensors="pt").to(device)
        with torch.no_grad():
            speech = model.generate(**inputs)

        # Bark's generation config carries its output sample rate; fall back to
        # the previously hard-coded 24 kHz if the attribute is absent.
        sample_rate = getattr(model.generation_config, "sample_rate", 24000)
        output_file = f"output_{uuid.uuid4().hex}.wav"
        wavfile.write(output_file, rate=sample_rate, data=speech.cpu().numpy().squeeze())

        # Defer deletion until after the response body has been sent.
        cleanup = BackgroundTasks()
        cleanup.add_task(os.remove, output_file)
        response = FileResponse(
            output_file,
            media_type="audio/wav",
            filename=output_file,
            background=cleanup,
        )
        output_file = None  # ownership of the file now belongs to `cleanup`
        return response
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"error": f"TTS failed: {str(e)}"},
        )
    finally:
        # `output_file` is still set only if an error occurred after the path was
        # chosen; remove any (possibly partial) file left behind.
        if output_file and os.path.exists(output_file):
            os.remove(output_file)
def analyze_sentiment(request: SentimentRequest):
    """Classify the sentiment of ``request.text`` with the preloaded pipeline.

    Returns:
        ``{"result": <pipeline output>}`` on success, or a ``JSONResponse`` with
        status 500 and an ``{"error": ...}`` body if the pipeline raises.

    NOTE(review): no route decorator is visible for this handler; presumably it
    is registered on ``app`` (e.g. as a POST route) — confirm.
    """
    try:
        predictions = sentiment_pipeline(request.text)
    except Exception as exc:
        failure = f"Sentiment analysis failed: {str(exc)}"
        return JSONResponse(status_code=500, content={"error": failure})
    return {"result": predictions}
def parse_legal_document(request: LegalDocRequest):
    """Report which fixed legal keywords occur in ``request.text``.

    Matching is a case-insensitive plain substring test, so a keyword also
    matches inside longer words (e.g. "disagreement" contains "agreement").
    Keywords are reported in the fixed order of ``legal_terms``.

    Returns:
        ``{"identified_keywords": [...], "domain": request.domain,
        "status": "success"}``, or a ``JSONResponse`` with status 500 and an
        ``{"error": ...}`` body if anything raises.

    NOTE(review): no route decorator is visible for this handler; presumably it
    is registered on ``app`` (e.g. as a POST route) — confirm.
    """
    legal_terms = ("contract", "agreement", "party", "terms", "confidential", "jurisdiction")
    try:
        lowered = request.text.lower()
        hits = [term for term in legal_terms if term in lowered]
        payload = {
            "identified_keywords": hits,
            "domain": request.domain,
            "status": "success",
        }
    except Exception as exc:
        return JSONResponse(
            status_code=500,
            content={"error": f"Legal parsing failed: {str(exc)}"},
        )
    return payload