import contextlib
import os
import uuid
from typing import Optional

import scipy.io.wavfile as wavfile
import torch
from fastapi import BackgroundTasks, FastAPI, HTTPException
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel
from transformers import (
    AutoModelForSequenceClassification,
    AutoProcessor,
    AutoTokenizer,
    BarkModel,
    pipeline,
)

# FastAPI instance
app = FastAPI(title="Kinyarwanda NLP API", version="1.0")

# Config
MODEL_PATH = "/app/models/suno-bark"
SENTIMENT_MODEL_PATH = "/app/models/sentiment"
# Sample rate written into the WAV header of TTS output.
# NOTE(review): presumably matches Bark's native output rate — confirm against
# model.generation_config if the checkpoint changes.
SAMPLE_RATE = 24000

# Scratch directory for generated audio; each file is deleted once it has been
# served (or immediately, if generation fails).
AUDIO_DIR = "/tmp/audio"
os.makedirs(AUDIO_DIR, exist_ok=True)

# Keywords reported by /legal-parse/ when they occur (case-insensitively, as
# substrings) anywhere in the submitted text.
LEGAL_KEYWORDS = (
    "contract",
    "agreement",
    "party",
    "terms",
    "confidential",
    "jurisdiction",
)

# Load models once at import time; any failure aborts application startup.
try:
    # TTS
    # NOTE(review): `tokenizer` is not used by any endpoint in this module (the
    # processor prepares Bark's inputs) — confirm nothing else imports it
    # before removing it.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    processor = AutoProcessor.from_pretrained(MODEL_PATH)
    model = BarkModel.from_pretrained(MODEL_PATH)

    # Sentiment
    sentiment_tokenizer = AutoTokenizer.from_pretrained(SENTIMENT_MODEL_PATH)
    sentiment_model = AutoModelForSequenceClassification.from_pretrained(SENTIMENT_MODEL_PATH)
    # NOTE(review): no `device` is passed and sentiment_model is never moved,
    # so this pipeline does not follow the CUDA selection below — confirm that
    # is intended.
    sentiment_pipeline = pipeline(
        "sentiment-analysis",
        model=sentiment_model,
        tokenizer=sentiment_tokenizer,
        truncation=True,
        max_length=512,
    )

    # Device config: only the Bark TTS model is moved to the selected device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
except Exception as e:
    # Chain the original exception so the real cause appears in the traceback.
    raise RuntimeError(f"Model initialization failed: {e}") from e


# Request schemas
class TTSRequest(BaseModel):
    text: str


class SentimentRequest(BaseModel):
    text: str


class LegalDocRequest(BaseModel):
    text: str
    domain: Optional[str] = "general"


def _remove_if_exists(path: str) -> None:
    """Delete *path*, ignoring the case where it has already been removed."""
    with contextlib.suppress(FileNotFoundError):
        os.remove(path)


# Root route
@app.get("/")
def root():
    """Return a static welcome message."""
    return {"message": "Welcome to Kinyarwanda NLP API"}


# Text-to-Speech endpoint
@app.post("/tts/")
def text_to_speech(request: TTSRequest):
    """Synthesize speech for ``request.text`` and return it as a WAV download.

    The audio is written to a uniquely named file under ``AUDIO_DIR``.
    ``FileResponse`` reads that file lazily, only when the response body is
    sent after this function has returned. The file therefore must not be
    deleted here (a ``finally`` clause would remove it before it is served).
    Instead, it is deleted by a background task that runs once the response
    has been sent.

    Raises:
        HTTPException: 500 if inference or writing the WAV file fails.
    """
    output_file = os.path.join(AUDIO_DIR, f"tts_{uuid.uuid4().hex}.wav")
    try:
        inputs = processor(request.text, return_tensors="pt").to(device)
        with torch.no_grad():
            audio_array = model.generate(**inputs)
        # squeeze() drops size-1 dims so wavfile gets a plain sample array.
        # NOTE(review): assumes a single-item batch — confirm.
        wavfile.write(output_file, rate=SAMPLE_RATE, data=audio_array.cpu().numpy().squeeze())
    except Exception as e:
        # Nothing will be served, so remove any partially written file now.
        _remove_if_exists(output_file)
        raise HTTPException(status_code=500, detail=f"TTS generation failed: {str(e)}") from e

    # Runs after the response body has been streamed to the client.
    cleanup = BackgroundTasks()
    cleanup.add_task(_remove_if_exists, output_file)
    return FileResponse(
        output_file,
        media_type="audio/wav",
        filename=os.path.basename(output_file),
        background=cleanup,
    )


# Sentiment Analysis endpoint
@app.post("/sentiment/")
def analyze_sentiment(request: SentimentRequest):
    """Run the sentiment pipeline on ``request.text``.

    Returns:
        ``{"result": <raw pipeline output>}``.

    Raises:
        HTTPException: 500 if the pipeline raises.
    """
    try:
        result = sentiment_pipeline(request.text)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Sentiment analysis failed: {str(e)}") from e
    return {"result": result}


# Legal Parsing endpoint
@app.post("/legal-parse/")
def parse_legal_document(request: LegalDocRequest):
    """Report which of ``LEGAL_KEYWORDS`` occur in ``request.text``.

    Matching is a case-insensitive substring test, so e.g. "contract" also
    matches "subcontractor". Keywords are returned in ``LEGAL_KEYWORDS`` order.

    Raises:
        HTTPException: 500 on any unexpected failure.
    """
    try:
        lowered = request.text.lower()
        found = [kw for kw in LEGAL_KEYWORDS if kw in lowered]
        return {
            "identified_keywords": found,
            "domain": request.domain,
            "status": "success",
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Legal parsing failed: {str(e)}") from e