File size: 3,308 Bytes
af59fff
364b411
 
2dacdf9
 
 
 
 
 
 
364b411
 
 
 
 
af59fff
 
 
 
cf0cbad
2dacdf9
cf0cbad
af59fff
cf0cbad
af59fff
cf0cbad
 
 
 
2dacdf9
 
 
af59fff
cf0cbad
2dacdf9
 
af59fff
 
cf0cbad
 
af59fff
 
 
 
cf0cbad
af59fff
364b411
cf0cbad
364b411
 
 
 
 
 
 
 
 
 
 
 
2dacdf9
364b411
 
 
af59fff
364b411
cf0cbad
 
 
 
364b411
cf0cbad
 
af59fff
 
 
 
 
cf0cbad
364b411
af59fff
 
 
 
 
cf0cbad
af59fff
cf0cbad
364b411
 
 
 
af59fff
364b411
af59fff
364b411
af59fff
 
 
 
364b411
 
 
 
cf0cbad
364b411
cf0cbad
 
 
 
 
 
af59fff
364b411
af59fff
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import io
import os
import uuid
from typing import Optional

import scipy.io.wavfile as wavfile
import torch
from fastapi import FastAPI
from fastapi.responses import FileResponse, JSONResponse, Response
from pydantic import BaseModel
from transformers import (
    AutoTokenizer,
    AutoProcessor,
    BarkModel,
    pipeline,
    AutoModelForSequenceClassification
)

# Initialize FastAPI app
# The route decorators below (@app.get / @app.post) register their handlers
# on this single application instance.
app = FastAPI()

# Configuration
# Directories holding the pre-downloaded model files. The defaults are the
# original hard-coded container paths; the environment variables allow the
# service to run outside that layout without editing the source.
MODEL_PATH = os.environ.get("TTS_MODEL_PATH", "/app/models/suno-bark")
SENTIMENT_MODEL_PATH = os.environ.get("SENTIMENT_MODEL_PATH", "/app/models/sentiment")

# Load all models in a single try-except block.
# This runs at import time, so any failure aborts application startup instead
# of surfacing later on the first request.
try:
    # TTS Model
    # NOTE: `tokenizer` is loaded here but is not referenced by any endpoint
    # in this file; it is kept so the module-level name stays available.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    processor = AutoProcessor.from_pretrained(MODEL_PATH)
    model = BarkModel.from_pretrained(MODEL_PATH)

    # Sentiment Analysis Model (pre-downloaded)
    sentiment_tokenizer = AutoTokenizer.from_pretrained(SENTIMENT_MODEL_PATH)
    sentiment_model = AutoModelForSequenceClassification.from_pretrained(SENTIMENT_MODEL_PATH)
    sentiment_pipeline = pipeline(
        "sentiment-analysis",
        model=sentiment_model,
        tokenizer=sentiment_tokenizer,
        truncation=True,
        max_length=512
    )

    # Device configuration: use the GPU when one is available.
    # Only the TTS model is moved; the sentiment pipeline is created without
    # a device argument and is left as constructed.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

except Exception as e:
    # Chain explicitly so the traceback reports the original loading error as
    # the direct cause rather than as an error raised inside this handler.
    raise RuntimeError(f"Initialization failed: {str(e)}") from e

# Request models
class TTSRequest(BaseModel):
    # Request body for the POST /tts/ endpoint.
    # text: the input handed to the TTS processor for speech synthesis.
    text: str

class SentimentRequest(BaseModel):
    # Request body for the POST /sentiment/ endpoint.
    # text: the input passed to the sentiment-analysis pipeline.
    text: str

class LegalDocRequest(BaseModel):
    # Request body for the POST /legal-parse/ endpoint.
    # text: the document text that is searched for legal keywords.
    text: str
    # domain: optional label that the endpoint echoes back in its response;
    # defaults to "general" when the client omits it.
    domain: Optional[str] = "general"

@app.get("/")
def root():
    # Landing endpoint: responds with a fixed welcome payload.
    welcome = {"message": "Welcome to Kinyarwanda NLP API"}
    return welcome

@app.post("/tts/")
def text_to_speech(request: TTSRequest):
    """Synthesize speech for the submitted text and return it as a WAV download.

    On success the response has media type ``audio/wav`` and is offered as an
    attachment named ``output_<random hex>.wav``. If synthesis or encoding
    fails, an HTTP 500 JSON body ``{"error": "TTS failed: ..."}`` is returned.
    """
    try:
        inputs = processor(request.text, return_tensors="pt").to(device)
        with torch.no_grad():
            speech = model.generate(**inputs)

        # Encode the WAV in memory. The previous implementation wrote a file,
        # returned a FileResponse, and deleted the file in `finally`; that
        # cleanup ran when the handler returned, i.e. before the response
        # body was read from disk, so the download could never be served.
        wav_buffer = io.BytesIO()
        # NOTE(review): 24000 Hz is the rate the original code hard-coded;
        # confirm it matches the loaded model's generation config.
        wavfile.write(wav_buffer, rate=24000, data=speech.cpu().numpy().squeeze())

    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"error": f"TTS failed: {str(e)}"}
        )

    # Keep the same per-request download name the file-based version exposed.
    download_name = f"output_{uuid.uuid4().hex}.wav"
    return Response(
        content=wav_buffer.getvalue(),
        media_type="audio/wav",
        headers={"Content-Disposition": f'attachment; filename="{download_name}"'},
    )

@app.post("/sentiment/")
def analyze_sentiment(request: SentimentRequest):
    # Classify the submitted text with the module-level sentiment pipeline.
    # Success: the pipeline output is wrapped under the "result" key.
    # Failure: any exception becomes an HTTP 500 JSON error payload.
    try:
        prediction = sentiment_pipeline(request.text)
    except Exception as exc:
        failure = {"error": f"Sentiment analysis failed: {str(exc)}"}
        return JSONResponse(status_code=500, content=failure)
    else:
        return {"result": prediction}

@app.post("/legal-parse/")
def parse_legal_document(request: LegalDocRequest):
    """Report which known legal keywords occur in the submitted text.

    Matching is a case-insensitive substring search, so a keyword also matches
    inside a longer word (for example "party" within "counterparty").

    The response contains ``identified_keywords`` (in the fixed keyword
    order), ``domain`` echoed from the request, and ``status``. An unexpected
    failure yields an HTTP 500 JSON body with an ``error`` message.
    """
    try:
        keywords = ("contract", "agreement", "party", "terms", "confidential", "jurisdiction")
        # Lowercase the document once instead of once per keyword.
        lowered_text = request.text.lower()
        found_keywords = [kw for kw in keywords if kw in lowered_text]

        return {
            "identified_keywords": found_keywords,
            "domain": request.domain,
            "status": "success"
        }

    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"error": f"Legal parsing failed: {str(e)}"}
        )