black44 committed
Commit af59fff · verified · 1 parent: d1416af

Update app.py

Files changed (1)
  1. app.py  +42 -23
app.py CHANGED
@@ -1,5 +1,4 @@
- from fastapi import FastAPI, Form
- from fastapi.responses import FileResponse, JSONResponse
+ from fastapi import FastAPI
  from pydantic import BaseModel
  import torch
  from transformers import AutoTokenizer, AutoProcessor, BarkModel, pipeline
@@ -8,30 +7,34 @@ import uuid
  import os
  from typing import Optional
  
- # Ensure proper model loading from pre-downloaded path
+ # Initialize FastAPI app
+ app = FastAPI()
+ 
+ # Configuration
  MODEL_PATH = "/app/models/suno-bark"
+ SENTIMENT_MODEL = "cardiffnlp/twitter-xlm-roberta-base-sentiment"  # PyTorch-compatible model
  
- # Load models and processors once during startup
+ # Load all models in a single try-except block
  try:
+     # TTS Model
      tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
      processor = AutoProcessor.from_pretrained(MODEL_PATH)
      model = BarkModel.from_pretrained(MODEL_PATH)
  
-     # Load sentiment analysis pipeline
-     sentiment_model = pipeline(
+     # Sentiment Analysis Model
+     sentiment_pipeline = pipeline(
          "sentiment-analysis",
-         model="nlptown/bert-base-multilingual-uncased-sentiment"
+         model=SENTIMENT_MODEL,
+         truncation=True,
+         max_length=512
      )
  
+     # Device configuration
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model.to(device)
+ 
  except Exception as e:
-     raise RuntimeError(f"Model loading failed: {str(e)}")
- 
- # Device configuration
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- model.to(device)
- 
- # Initialize FastAPI app
- app = FastAPI()
+     raise RuntimeError(f"Initialization failed: {str(e)}")
  
  # Request models
  class TTSRequest(BaseModel):
@@ -46,10 +49,11 @@ class LegalDocRequest(BaseModel):
  
  @app.get("/")
  def root():
-     return {"message": "Welcome to Kinyarwanda NLP API"}
+     return {"message": "Welcome to Kinyarwanda-Engine"}
  
  @app.post("/tts/")
  def text_to_speech(request: TTSRequest):
+     output_file = None
      try:
          inputs = processor(request.text, return_tensors="pt").to(device)
          with torch.no_grad():
@@ -58,26 +62,37 @@ def text_to_speech(request: TTSRequest):
          output_file = f"output_{uuid.uuid4().hex}.wav"
          wavfile.write(output_file, rate=24000, data=speech.cpu().numpy().squeeze())
  
-         return FileResponse(output_file, media_type="audio/wav", filename=output_file)
+         return FileResponse(
+             output_file,
+             media_type="audio/wav",
+             filename=output_file
+         )
  
      except Exception as e:
-         return JSONResponse(status_code=500, content={"error": f"TTS failed: {str(e)}"})
+         return JSONResponse(
+             status_code=500,
+             content={"error": f"TTS failed: {str(e)}"}
+         )
+ 
      finally:
-         if os.path.exists(output_file):
+         if output_file and os.path.exists(output_file):
              os.remove(output_file)
  
  @app.post("/sentiment/")
  def analyze_sentiment(request: SentimentRequest):
      try:
-         result = sentiment_model(request.text)
+         result = sentiment_pipeline(request.text)
          return {"result": result}
+ 
      except Exception as e:
-         return JSONResponse(status_code=500, content={"error": f"Sentiment analysis failed: {str(e)}"})
+         return JSONResponse(
+             status_code=500,
+             content={"error": f"Sentiment analysis failed: {str(e)}"}
+         )
  
  @app.post("/legal-parse/")
  def parse_legal_document(request: LegalDocRequest):
      try:
-         # Basic keyword extraction (replace with trained model in production)
          keywords = ["contract", "agreement", "party", "terms", "confidential", "jurisdiction"]
          found_keywords = [kw for kw in keywords if kw in request.text.lower()]
  
@@ -86,5 +101,9 @@ def parse_legal_document(request: LegalDocRequest):
              "domain": request.domain,
              "status": "success"
          }
+ 
      except Exception as e:
-         return JSONResponse(status_code=500, content={"error": f"Legal parsing failed: {str(e)}"})
+         return JSONResponse(
+             status_code=500,
+             content={"error": f"Legal parsing failed: {str(e)}"}
+         )
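
Note that the updated handlers still construct FileResponse and JSONResponse objects, while the first hunk drops the fastapi.responses import and no other hunk shows it being re-added. A minimal sketch of the import the module would need for those names to resolve, assuming it is not declared somewhere outside the hunks shown:

```python
# Assumed import: the handlers reference FileResponse and JSONResponse,
# but this line is removed in the first hunk and not visible elsewhere in the diff.
from fastapi.responses import FileResponse, JSONResponse
```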
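For context, a hedged smoke test of the three endpoints as they look after this commit. The base URL, port, and exact request-body field names are assumptions (the Pydantic request models are outside the changed hunks); the field names follow how the handlers read request.text and request.domain:

```python
# Hypothetical client-side check of the API; host/port and field names are assumed.
import requests

BASE = "http://localhost:8000"  # assumed local uvicorn host/port

# Sentiment analysis: handler returns {"result": <pipeline output>}
r = requests.post(f"{BASE}/sentiment/", json={"text": "The service works very well."})
print(r.json())

# Legal keyword scan: handler returns the matched keywords, domain, and status
r = requests.post(
    f"{BASE}/legal-parse/",
    json={"text": "This agreement is binding on each party.", "domain": "contracts"},
)
print(r.json())

# Text-to-speech: handler returns a WAV file via FileResponse
r = requests.post(f"{BASE}/tts/", json={"text": "Muraho neza"})
with open("speech.wav", "wb") as f:
    f.write(r.content)
```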