# Origin: Hugging Face file view by black44 — commit 364b411 (verified),
# commit message "creating the main engine", file size 2.39 kB.
from fastapi import FastAPI, Form
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel
import torch
from transformers import AutoProcessor, BarkModel, pipeline
import scipy.io.wavfile as wavfile
import uuid
import os
from typing import Optional
# Pick the compute device once so every model below is placed consistently.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load TTS model and processor, then move the TTS model to the chosen device.
processor = AutoProcessor.from_pretrained("suno/bark")
model = BarkModel.from_pretrained("suno/bark")
model.to(device)

# Load sentiment analysis pipeline (using multilingual model). Pass the same
# device explicitly: without a `device` argument the pipeline is placed on the
# CPU by default, even when the TTS model above is on CUDA.
sentiment_model = pipeline(
    "sentiment-analysis",
    model="nlptown/bert-base-multilingual-uncased-sentiment",
    device=device,
)

# FastAPI app
app = FastAPI()
# Endpoint input models
class TTSRequest(BaseModel):
    """Request body for ``POST /tts/``."""

    text: str  # text passed to the Bark processor for synthesis
class SentimentRequest(BaseModel):
    """Request body for ``POST /sentiment/``."""

    text: str  # text handed to the sentiment-analysis pipeline
class LegalDocRequest(BaseModel):
    """Request body for ``POST /legal-parse/``."""

    text: str  # document text scanned for keywords
    # Free-form label echoed back unchanged in the response; defaults to
    # "general" and, being Optional, may also be sent as null.
    domain: Optional[str] = "general"
@app.get("/")
def root():
    """Serve ``GET /`` with a fixed welcome message."""
    welcome_text = "Welcome to Kinyarwanda NLP API"
    payload = {"message": welcome_text}
    return payload
@app.post("/tts/")
def text_to_speech(request: TTSRequest):
    """Synthesize ``request.text`` with Bark and return it as a WAV file.

    The waveform is written to a uniquely named file in the system temp
    directory. It is streamed back as ``audio/wav`` and deleted once the
    response has been sent.

    Returns:
        A FileResponse containing the audio on success, or a JSONResponse
        with status 500 and ``{"error": <message>}`` if any step fails.
    """
    # Function-scope imports keep this fix self-contained.
    import tempfile

    from fastapi import BackgroundTasks

    output_file = os.path.join(
        tempfile.gettempdir(), f"output_{uuid.uuid4().hex}.wav"
    )
    try:
        # Generate speech: tokenize, then move every tensor to the model's device.
        inputs = processor(request.text, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        speech = model.generate(**inputs)

        # Bark's output rate lives in its generation config (24 kHz for
        # suno/bark per the Transformers Bark docs). The previous hard-coded
        # 22050 made playback slow and low-pitched.
        sample_rate = model.generation_config.sample_rate
        speech_np = speech.cpu().numpy().squeeze()
        wavfile.write(output_file, rate=sample_rate, data=speech_np)
    except Exception as e:
        # Don't leave a partially written file behind on failure.
        if os.path.exists(output_file):
            os.remove(output_file)
        return JSONResponse(status_code=500, content={"error": str(e)})

    # Delete the file only after the response body has been streamed;
    # previously every request leaked one WAV file on disk.
    cleanup = BackgroundTasks()
    cleanup.add_task(os.remove, output_file)
    return FileResponse(output_file, media_type="audio/wav", background=cleanup)
@app.post("/sentiment/")
def analyze_sentiment(request: SentimentRequest):
    """Classify the sentiment of ``request.text`` with the multilingual pipeline.

    Returns:
        ``{"result": <pipeline output>}`` on success, or a JSONResponse with
        status 500 and ``{"error": <message>}`` if the pipeline fails.
    """
    try:
        # BERT-based models accept a bounded number of tokens. Without
        # truncation, longer inputs fail inside the model and surface as a 500.
        # With truncation=True, only the leading tokens are scored.
        result = sentiment_model(request.text, truncation=True)
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
    return {"result": result}
@app.post("/legal-parse/")
def parse_legal_document(request: LegalDocRequest):
    """Report which placeholder legal keywords appear in ``request.text``.

    Matching is a case-insensitive substring test, checked in a fixed keyword
    order. For example, "Contracts" counts as a hit for "contract".

    Returns:
        ``{"identified_keywords": [...], "domain": request.domain}`` on
        success, or a JSONResponse with status 500 and
        ``{"error": <message>}`` on failure.
    """
    try:
        # Placeholder logic (replace with training-based custom logic)
        lowered = request.text.lower()
        hits = []
        for term in ("contract", "agreement", "party", "terms"):
            if term in lowered:
                hits.append(term)
        return {"identified_keywords": hits, "domain": request.domain}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})