"""FastAPI service exposing text-to-speech, sentiment analysis and a
placeholder legal-document keyword parser.

All models are loaded once at import time, so importing this module loads
(and, on first use, downloads) the Hugging Face checkpoints named below.
"""

import io
import logging
import os
import uuid
from typing import Optional

import scipy.io.wavfile as wavfile
import torch
from fastapi import FastAPI, Form
from fastapi.responses import FileResponse, JSONResponse, Response
from pydantic import BaseModel
from transformers import AutoProcessor, BarkModel, pipeline

logger = logging.getLogger(__name__)

# Load TTS model and processor.
processor = AutoProcessor.from_pretrained("suno/bark")
model = BarkModel.from_pretrained("suno/bark")

# Load sentiment analysis pipeline (multilingual checkpoint).
# NOTE(review): confirm this checkpoint (and Bark above) actually handles
# Kinyarwanda input; the API advertises Kinyarwanda support but the
# checkpoints were not necessarily trained on it.
sentiment_model = pipeline(
    "sentiment-analysis",
    model="nlptown/bert-base-multilingual-uncased-sentiment",
)

# Run the TTS model on CUDA when available, otherwise on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Read Bark's output sample rate from its generation config instead of
# hard-coding it, so the WAV header matches the generated samples.
TTS_SAMPLE_RATE = model.generation_config.sample_rate

# Placeholder keywords for /legal-parse/ (replace with trained logic).
# Matching is case-insensitive *substring* matching, so e.g. "party" also
# matches inside "third-party".
LEGAL_KEYWORDS = ("contract", "agreement", "party", "terms")

# FastAPI app
app = FastAPI()


# Endpoint input models
class TTSRequest(BaseModel):
    # Text to synthesize into speech.
    text: str


class SentimentRequest(BaseModel):
    # Text to classify.
    text: str


class LegalDocRequest(BaseModel):
    # Raw document text to scan for keywords.
    text: str
    # Free-form domain label; echoed back unchanged in the response.
    domain: Optional[str] = "general"


@app.get("/")
def root():
    """Return a static welcome message."""
    return {"message": "Welcome to Kinyarwanda NLP API"}


@app.post("/tts/")
def text_to_speech(request: TTSRequest):
    """Synthesize ``request.text`` with Bark and return it as WAV audio.

    Returns:
        An ``audio/wav`` response containing the generated speech, or a
        500 JSON response ``{"error": <message>}`` if anything fails.
    """
    try:
        # Generate speech
        inputs = processor(request.text, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        speech = model.generate(**inputs)

        # Encode the WAV in memory so no per-request file is left behind
        # in the working directory.
        speech_np = speech.cpu().numpy().squeeze()
        buffer = io.BytesIO()
        wavfile.write(buffer, rate=TTS_SAMPLE_RATE, data=speech_np)
        return Response(content=buffer.getvalue(), media_type="audio/wav")
    except Exception as e:
        logger.exception("TTS generation failed")
        return JSONResponse(status_code=500, content={"error": str(e)})


@app.post("/sentiment/")
def analyze_sentiment(request: SentimentRequest):
    """Run the sentiment pipeline on ``request.text``.

    Returns:
        ``{"result": <pipeline output>}``, or a 500 JSON response
        ``{"error": <message>}`` if the pipeline raises.
    """
    try:
        result = sentiment_model(request.text)
        return {"result": result}
    except Exception as e:
        logger.exception("Sentiment analysis failed")
        return JSONResponse(status_code=500, content={"error": str(e)})


@app.post("/legal-parse/")
def parse_legal_document(request: LegalDocRequest):
    """Report which placeholder legal keywords occur in ``request.text``.

    Returns:
        ``{"identified_keywords": [...], "domain": request.domain}``, where
        the keywords appear in ``LEGAL_KEYWORDS`` order. On an unexpected
        error, a 500 JSON response ``{"error": <message>}``.
    """
    try:
        # Placeholder logic (replace with training-based custom logic).
        text = request.text.lower()  # lower-case once, not per keyword
        found_keywords = [kw for kw in LEGAL_KEYWORDS if kw in text]
        return {"identified_keywords": found_keywords, "domain": request.domain}
    except Exception as e:
        logger.exception("Legal document parsing failed")
        return JSONResponse(status_code=500, content={"error": str(e)})