from fastapi import FastAPI
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel
from starlette.background import BackgroundTask
import torch
from transformers import AutoProcessor, BarkModel, pipeline
import scipy.io.wavfile as wavfile
import uuid
import os
from typing import Optional

# Ensure proper model loading from pre-downloaded path
MODEL_PATH = "/app/models/suno-bark"

# Load models and processors once during startup
try:
    processor = AutoProcessor.from_pretrained(MODEL_PATH)
    model = BarkModel.from_pretrained(MODEL_PATH)

    # Load sentiment analysis pipeline
    sentiment_model = pipeline(
        "sentiment-analysis",
        model="nlptown/bert-base-multilingual-uncased-sentiment",
    )
except Exception as e:
    raise RuntimeError(f"Model loading failed: {e}") from e

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Initialize FastAPI app
app = FastAPI()


# Request models
class TTSRequest(BaseModel):
    text: str


class SentimentRequest(BaseModel):
    text: str


class LegalDocRequest(BaseModel):
    text: str
    domain: Optional[str] = "general"


@app.get("/")
def root():
    return {"message": "Welcome to Kinyarwanda NLP API"}


@app.post("/tts/")
def text_to_speech(request: TTSRequest):
    # Generate the output path up front so cleanup code can always reference it
    output_file = f"output_{uuid.uuid4().hex}.wav"
    try:
        inputs = processor(request.text, return_tensors="pt").to(device)
        with torch.no_grad():
            speech = model.generate(**inputs)

        # Bark reports its output sample rate (24 kHz) in the generation config
        sample_rate = model.generation_config.sample_rate
        wavfile.write(output_file, rate=sample_rate, data=speech.cpu().numpy().squeeze())

        # Delete the file only after the response has been streamed; removing it
        # in a `finally` block would race the FileResponse and serve a missing file
        return FileResponse(
            output_file,
            media_type="audio/wav",
            filename=output_file,
            background=BackgroundTask(os.remove, output_file),
        )
    except Exception as e:
        if os.path.exists(output_file):
            os.remove(output_file)
        return JSONResponse(status_code=500, content={"error": f"TTS failed: {e}"})


@app.post("/sentiment/")
def analyze_sentiment(request: SentimentRequest):
    try:
        result = sentiment_model(request.text)
        return {"result": result}
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"error": f"Sentiment analysis failed: {e}"},
        )


@app.post("/legal-parse/")
def parse_legal_document(request: LegalDocRequest):
    try:
        # Basic keyword extraction (replace with a trained model in production)
        keywords = ["contract", "agreement", "party", "terms", "confidential", "jurisdiction"]
        found_keywords = [kw for kw in keywords if kw in request.text.lower()]
        return {
            "identified_keywords": found_keywords,
            "domain": request.domain,
            "status": "success",
        }
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"error": f"Legal parsing failed: {e}"},
        )
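

# A minimal sketch of how the service might be launched and exercised locally.
# The uvicorn entry point, host/port choice, and the sample request texts below
# are assumptions for illustration, not part of the original deployment setup:
#
#   curl -X POST http://localhost:8000/tts/ \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Muraho"}' --output speech.wav
#
#   curl -X POST http://localhost:8000/sentiment/ \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Iyi serivisi ni nziza"}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)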