# Captured from a Hugging Face Spaces page; the Space's status at capture time was "Runtime error".
from fastapi import FastAPI, Form | |
from fastapi.responses import FileResponse, JSONResponse | |
from pydantic import BaseModel | |
import torch | |
from transformers import AutoProcessor, BarkModel, pipeline | |
import scipy.io.wavfile as wavfile | |
import uuid | |
import os | |
from typing import Optional | |
# Pick the compute device first so every model below is placed on the same
# device: CUDA when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load TTS model and processor (Bark), then move the model to the chosen device.
processor = AutoProcessor.from_pretrained("suno/bark")
model = BarkModel.from_pretrained("suno/bark")
model.to(device)

# Load sentiment analysis pipeline (using multilingual model). Passing
# ``device`` lets it run on the GPU when one is available, like Bark above,
# instead of always running on CPU.
# NOTE(review): per its model card this checkpoint targets a handful of
# European languages; Kinyarwanda does not appear to be among them — confirm
# accuracy on the intended input language.
sentiment_model = pipeline(
    "sentiment-analysis",
    model="nlptown/bert-base-multilingual-uncased-sentiment",
    device=device,
)

# FastAPI app
# NOTE(review): no route decorators (``@app.get`` / ``@app.post``) are visible on
# the handler functions in this capture; without them the handlers are not
# registered on ``app`` — confirm they exist in the deployed source.
app = FastAPI()
# Endpoint input models (pydantic models used as request bodies)
class TTSRequest(BaseModel):
    # Request body for the text-to-speech handler.
    text: str  # text to be synthesized into speech by the Bark model
class SentimentRequest(BaseModel):
    # Request body for the sentiment-analysis handler.
    text: str  # text passed unchanged to the sentiment pipeline
class LegalDocRequest(BaseModel):
    # Request body for the legal-document keyword handler.
    text: str  # document text scanned (case-insensitively) for keywords
    # Echoed back unchanged in the response; defaults to "general" when omitted.
    # Because the type is Optional, a client may also send an explicit null,
    # which is then echoed back as None.
    domain: Optional[str] = "general"
def root():
    """Return the API's static welcome payload.

    NOTE(review): no route decorator is visible on this function in the
    captured source; confirm it is registered on ``app`` (e.g. for "/").
    """
    greeting = "Welcome to Kinyarwanda NLP API"
    payload = {"message": greeting}
    return payload
def text_to_speech(request: TTSRequest):
    """Synthesize ``request.text`` with the module-level Bark model and return a WAV file.

    The audio is written to a uniquely named ``output_<hex>.wav`` in the
    current working directory, streamed back as ``audio/wav``, and deleted
    once the response has been sent.

    Returns:
        A ``FileResponse`` with the WAV on success, or a ``JSONResponse`` with
        status 500 and ``{"error": <exception message>}`` if generation or
        writing fails.
    """
    # Local import: only this handler needs it, and fastapi is already a dependency.
    from fastapi import BackgroundTasks

    output_file = f"output_{uuid.uuid4().hex}.wav"
    try:
        # Generate speech
        inputs = processor(request.text, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        speech = model.generate(**inputs)

        # Save audio. Bark produces samples at its own rate (exposed on the
        # generation config); writing the header with any other rate makes
        # playback too slow/fast and pitch-shifted.
        speech_np = speech.cpu().numpy().squeeze()
        sample_rate = model.generation_config.sample_rate
        wavfile.write(output_file, rate=sample_rate, data=speech_np)
    except Exception as e:
        # Don't leave a partially written file behind on failure.
        if os.path.exists(output_file):
            os.remove(output_file)
        return JSONResponse(status_code=500, content={"error": str(e)})

    # Remove the temporary WAV after the response body has been streamed,
    # otherwise every request leaks a file on disk.
    cleanup = BackgroundTasks()
    cleanup.add_task(os.remove, output_file)
    return FileResponse(output_file, media_type="audio/wav", background=cleanup)
def analyze_sentiment(request: SentimentRequest):
    """Run the module-level sentiment pipeline on ``request.text``.

    Returns:
        ``{"result": <raw pipeline output>}`` on success, or a ``JSONResponse``
        with status 500 and ``{"error": <exception message>}`` if the
        pipeline raises.
    """
    try:
        prediction = sentiment_model(request.text)
    except Exception as err:
        return JSONResponse(status_code=500, content={"error": str(err)})
    return {"result": prediction}
def parse_legal_document(request: LegalDocRequest):
    """Report which placeholder legal keywords occur in ``request.text``.

    Matching is a case-insensitive substring test, so a word such as
    "counterparty" also counts as a hit for "party". Keywords are reported in
    their fixed declaration order, and ``request.domain`` is echoed back.

    Returns:
        ``{"identified_keywords": [...], "domain": request.domain}``, or a
        ``JSONResponse`` with status 500 and ``{"error": <message>}`` if
        processing raises.
    """
    try:
        # Placeholder logic (replace with training-based custom logic)
        candidate_terms = ("contract", "agreement", "party", "terms")
        lowered_text = request.text.lower()
        matches = [term for term in candidate_terms if term in lowered_text]
        return {"identified_keywords": matches, "domain": request.domain}
    except Exception as err:
        return JSONResponse(status_code=500, content={"error": str(err)})