File size: 3,355 Bytes
80bc30e
 
364b411
 
2dacdf9
 
 
 
 
 
 
364b411
 
 
 
 
80bc30e
 
af59fff
80bc30e
cf0cbad
2dacdf9
80bc30e
cf0cbad
80bc30e
 
 
 
 
cf0cbad
80bc30e
cf0cbad
 
 
80bc30e
 
2dacdf9
 
af59fff
cf0cbad
2dacdf9
 
af59fff
 
cf0cbad
80bc30e
 
af59fff
 
80bc30e
cf0cbad
80bc30e
364b411
80bc30e
364b411
 
 
 
 
 
 
 
 
 
80bc30e
364b411
 
2dacdf9
364b411
80bc30e
364b411
 
80bc30e
 
364b411
cf0cbad
 
80bc30e
 
 
 
af59fff
 
 
80bc30e
af59fff
80bc30e
364b411
80bc30e
 
cf0cbad
80bc30e
cf0cbad
364b411
80bc30e
364b411
 
 
af59fff
364b411
 
80bc30e
364b411
80bc30e
364b411
 
 
cf0cbad
80bc30e
cf0cbad
80bc30e
cf0cbad
 
 
364b411
80bc30e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel
import torch
from transformers import (
    AutoTokenizer,
    AutoProcessor,
    BarkModel,
    pipeline,
    AutoModelForSequenceClassification
)
import scipy.io.wavfile as wavfile
import uuid
import os
from typing import Optional

# Application object; every route below is registered on it via decorators.
app = FastAPI(version="1.0", title="Kinyarwanda NLP API")

# Config.
# Filesystem locations may be overridden through environment variables of the
# same name (handy for running outside the default layout); when a variable is
# unset, the original hard-coded path is used.
MODEL_PATH = os.environ.get("MODEL_PATH", "/app/models/suno-bark")
SENTIMENT_MODEL_PATH = os.environ.get("SENTIMENT_MODEL_PATH", "/app/models/sentiment")

# Rate written into the WAV header of TTS output. It has to match the rate the
# TTS model actually generates at (presumably Bark's native 24 kHz -- confirm
# against model.generation_config.sample_rate), so it is intentionally NOT
# configurable: a mismatch would play back at the wrong speed/pitch.
SAMPLE_RATE = 24000

# Scratch directory for generated audio files; created up front so handlers
# can write into it without checking.
AUDIO_DIR = os.environ.get("AUDIO_DIR", "/tmp/audio")
os.makedirs(AUDIO_DIR, exist_ok=True)

# Load all models once, at import time. Any failure aborts startup with a
# RuntimeError so the service never runs half-initialised.
try:
    # Pick the device first so every model below can be placed on it.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # TTS
    # NOTE(review): `tokenizer` is not used by any handler in this module (the
    # /tts/ handler uses `processor`); it is kept only so the module-level
    # name remains available. Consider dropping it if nothing imports it.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    processor = AutoProcessor.from_pretrained(MODEL_PATH)
    model = BarkModel.from_pretrained(MODEL_PATH)
    model.to(device)

    # Sentiment
    sentiment_tokenizer = AutoTokenizer.from_pretrained(SENTIMENT_MODEL_PATH)
    sentiment_model = AutoModelForSequenceClassification.from_pretrained(SENTIMENT_MODEL_PATH)
    sentiment_pipeline = pipeline(
        "sentiment-analysis",
        model=sentiment_model,
        tokenizer=sentiment_tokenizer,
        # Long inputs are cut to the tokenizer's 512-token window instead of
        # raising.
        truncation=True,
        max_length=512,
        # Place the classifier explicitly on the same device as the TTS model
        # rather than relying on pipeline()'s default device selection.
        device=device,
    )

except Exception as e:
    # Chain the original error so the traceback shows the root cause.
    raise RuntimeError(f"Model initialization failed: {e}") from e

# Request body schemas; FastAPI validates incoming JSON against these models
# before the matching handler is called.
class TTSRequest(BaseModel):
    """Body for ``POST /tts/``: the text to synthesize into speech."""

    text: str


class SentimentRequest(BaseModel):
    """Body for ``POST /sentiment/``: the text to classify."""

    text: str


class LegalDocRequest(BaseModel):
    """Body for ``POST /legal-parse/``: document text plus an optional domain tag."""

    text: str
    # Defaults to "general" when omitted; an explicit null is also accepted
    # because the annotation is Optional.
    domain: Optional[str] = "general"

# Root route
@app.get("/")
def root():
    """Return a fixed welcome payload for ``GET /``."""
    greeting = "Welcome to Kinyarwanda NLP API"
    return {"message": greeting}

# Text-to-Speech endpoint
@app.post("/tts/")
def text_to_speech(request: TTSRequest):
    """Synthesize speech for ``request.text`` and return it as a WAV file.

    The audio is written to a uniquely named file under ``AUDIO_DIR``.
    ``FileResponse`` opens that file only while the response body is being
    sent -- i.e. *after* this function has returned -- so the file must not
    be deleted here on the success path. Deletion is instead deferred to a
    background task attached to the response, which runs after sending.

    Raises:
        HTTPException: status 500 if synthesis or writing the WAV file fails.
    """
    # BackgroundTasks is part of FastAPI's public API (already a dependency);
    # imported locally to keep this handler self-contained.
    from fastapi import BackgroundTasks

    output_file = os.path.join(AUDIO_DIR, f"tts_{uuid.uuid4().hex}.wav")

    try:
        inputs = processor(request.text, return_tensors="pt").to(device)
        with torch.no_grad():
            audio_array = model.generate(**inputs)

        wavfile.write(output_file, rate=SAMPLE_RATE, data=audio_array.cpu().numpy().squeeze())
    except Exception as e:
        # Failure path: no response will reference the file, so remove any
        # partially written output right away.
        if os.path.exists(output_file):
            os.remove(output_file)
        raise HTTPException(status_code=500, detail=f"TTS generation failed: {str(e)}") from e

    # Success path: remove the temporary WAV only once it has been streamed.
    cleanup = BackgroundTasks()
    cleanup.add_task(os.remove, output_file)
    return FileResponse(
        output_file,
        media_type="audio/wav",
        filename=os.path.basename(output_file),
        background=cleanup,
    )

# Sentiment Analysis endpoint
@app.post("/sentiment/")
def analyze_sentiment(request: SentimentRequest):
    """Run the module-level sentiment pipeline on ``request.text``.

    Returns ``{"result": <pipeline output>}``. Any exception raised by the
    pipeline is surfaced to the client as an HTTP 500 with the error text.
    """
    text = request.text
    try:
        prediction = sentiment_pipeline(text)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Sentiment analysis failed: {exc!s}")
    return {"result": prediction}

# Legal Parsing endpoint
@app.post("/legal-parse/")
def parse_legal_document(request: LegalDocRequest):
    """Report which of a fixed set of legal keywords occur in ``request.text``.

    Matching is a case-insensitive substring test, so a keyword also matches
    inside longer words (e.g. "contract" inside "subcontractor"). Matches are
    listed in the fixed keyword order, and ``request.domain`` is echoed back
    unchanged.
    """
    try:
        haystack = request.text.lower()
        candidates = ("contract", "agreement", "party", "terms", "confidential", "jurisdiction")
        matches = [term for term in candidates if term in haystack]
        return {
            "identified_keywords": matches,
            "domain": request.domain,
            "status": "success",
        }
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Legal parsing failed: {exc!s}")