black44 committed
Commit 80bc30e · verified · 1 Parent(s): cc0ca05

Update app.py

Files changed (1)
  1. app.py +38 -40
app.py CHANGED
@@ -1,4 +1,5 @@
-from fastapi import FastAPI
+from fastapi import FastAPI, HTTPException
+from fastapi.responses import FileResponse, JSONResponse
 from pydantic import BaseModel
 import torch
 from transformers import (
@@ -13,21 +14,26 @@ import uuid
 import os
 from typing import Optional
 
-# Initialize FastAPI app
-app = FastAPI()
+# FastAPI instance
+app = FastAPI(title="Kinyarwanda NLP API", version="1.0")
 
-# Configuration
+# Config
 MODEL_PATH = "/app/models/suno-bark"
 SENTIMENT_MODEL_PATH = "/app/models/sentiment"
+SAMPLE_RATE = 24000
 
-# Load all models in a single try-except block
+# Ensure working directory for audio
+AUDIO_DIR = "/tmp/audio"
+os.makedirs(AUDIO_DIR, exist_ok=True)
+
+# Load models
 try:
-    # TTS Model
+    # TTS
     tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
     processor = AutoProcessor.from_pretrained(MODEL_PATH)
     model = BarkModel.from_pretrained(MODEL_PATH)
-
-    # Sentiment Analysis Model (pre-downloaded)
+
+    # Sentiment
     sentiment_tokenizer = AutoTokenizer.from_pretrained(SENTIMENT_MODEL_PATH)
     sentiment_model = AutoModelForSequenceClassification.from_pretrained(SENTIMENT_MODEL_PATH)
     sentiment_pipeline = pipeline(
@@ -37,15 +43,15 @@ try:
         truncation=True,
         max_length=512
     )
-
-    # Device configuration
+
+    # Device config
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model.to(device)
-
+
 except Exception as e:
-    raise RuntimeError(f"Initialization failed: {str(e)}")
+    raise RuntimeError(f"Model initialization failed: {e}")
 
-# Request models
+# Request schemas
 class TTSRequest(BaseModel):
     text: str
 
@@ -56,63 +62,55 @@ class LegalDocRequest(BaseModel):
     text: str
     domain: Optional[str] = "general"
 
+# Root route
 @app.get("/")
 def root():
     return {"message": "Welcome to Kinyarwanda NLP API"}
 
+# Text-to-Speech endpoint
 @app.post("/tts/")
 def text_to_speech(request: TTSRequest):
-    output_file = None
+    output_file = os.path.join(AUDIO_DIR, f"tts_{uuid.uuid4().hex}.wav")
+
     try:
         inputs = processor(request.text, return_tensors="pt").to(device)
         with torch.no_grad():
-            speech = model.generate(**inputs)
-
-        output_file = f"output_{uuid.uuid4().hex}.wav"
-        wavfile.write(output_file, rate=24000, data=speech.cpu().numpy().squeeze())
-
+            audio_array = model.generate(**inputs)
+
+        wavfile.write(output_file, rate=SAMPLE_RATE, data=audio_array.cpu().numpy().squeeze())
+
         return FileResponse(
             output_file,
             media_type="audio/wav",
-            filename=output_file
+            filename=os.path.basename(output_file)
        )
-
     except Exception as e:
-        return JSONResponse(
-            status_code=500,
-            content={"error": f"TTS failed: {str(e)}"}
-        )
-
+        raise HTTPException(status_code=500, detail=f"TTS generation failed: {str(e)}")
+
     finally:
-        if output_file and os.path.exists(output_file):
+        if os.path.exists(output_file):
             os.remove(output_file)
 
+# Sentiment Analysis endpoint
 @app.post("/sentiment/")
 def analyze_sentiment(request: SentimentRequest):
     try:
         result = sentiment_pipeline(request.text)
         return {"result": result}
-
     except Exception as e:
-        return JSONResponse(
-            status_code=500,
-            content={"error": f"Sentiment analysis failed: {str(e)}"}
-        )
+        raise HTTPException(status_code=500, detail=f"Sentiment analysis failed: {str(e)}")
 
+# Legal Parsing endpoint
 @app.post("/legal-parse/")
 def parse_legal_document(request: LegalDocRequest):
     try:
         keywords = ["contract", "agreement", "party", "terms", "confidential", "jurisdiction"]
-        found_keywords = [kw for kw in keywords if kw in request.text.lower()]
-
+        found = [kw for kw in keywords if kw in request.text.lower()]
         return {
-            "identified_keywords": found_keywords,
+            "identified_keywords": found,
             "domain": request.domain,
             "status": "success"
        }
-
     except Exception as e:
-        return JSONResponse(
-            status_code=500,
-            content={"error": f"Legal parsing failed: {str(e)}"}
-        )
+        raise HTTPException(status_code=500, detail=f"Legal parsing failed: {str(e)}")
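
Note on the new /tts/ handler: FileResponse does not read output_file until after the handler returns, while the finally: block runs as the handler exits, so the WAV can be removed before it is streamed to the client. A pattern commonly used with FastAPI/Starlette is to attach the cleanup to the response itself via a BackgroundTask, which runs only after the file has been sent. The sketch below is illustrative only, reusing the names introduced in this commit; it is not part of the change.

from starlette.background import BackgroundTask

@app.post("/tts/")
def text_to_speech(request: TTSRequest):
    # Same generation path as in the commit; only the cleanup strategy differs.
    output_file = os.path.join(AUDIO_DIR, f"tts_{uuid.uuid4().hex}.wav")
    try:
        inputs = processor(request.text, return_tensors="pt").to(device)
        with torch.no_grad():
            audio_array = model.generate(**inputs)
        wavfile.write(output_file, rate=SAMPLE_RATE, data=audio_array.cpu().numpy().squeeze())
        # The background task deletes the file only after the response body
        # has been streamed, instead of in a finally: block.
        return FileResponse(
            output_file,
            media_type="audio/wav",
            filename=os.path.basename(output_file),
            background=BackgroundTask(os.remove, output_file),
        )
    except Exception as e:
        # On failure the file is never sent, so clean it up immediately.
        if os.path.exists(output_file):
            os.remove(output_file)
        raise HTTPException(status_code=500, detail=f"TTS generation failed: {str(e)}")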
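
Both model directories are expected to exist locally before startup (MODEL_PATH and SENTIMENT_MODEL_PATH point at paths inside the container, and loading aborts with a RuntimeError if anything is missing). A minimal sketch of populating them with huggingface_hub, assuming the Bark checkpoint is suno/bark based on the /app/models/suno-bark path and using a placeholder repo id for the sentiment model, which this commit does not name:

from huggingface_hub import snapshot_download

# Bark TTS weights; repo id assumed from the /app/models/suno-bark path.
snapshot_download(repo_id="suno/bark", local_dir="/app/models/suno-bark")

# Sentiment model; "your-org/kinyarwanda-sentiment" is a placeholder repo id.
snapshot_download(repo_id="your-org/kinyarwanda-sentiment", local_dir="/app/models/sentiment")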
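
For a quick smoke test of the three endpoints, a minimal client sketch using requests, assuming the app is served locally (for example with uvicorn app:app --host 0.0.0.0 --port 8000); the host, port, and sample texts are placeholders:

import requests

BASE = "http://localhost:8000"  # placeholder host/port

# /tts/ returns a WAV file; write the response body to disk.
r = requests.post(f"{BASE}/tts/", json={"text": "Muraho, amakuru?"})
r.raise_for_status()
with open("sample.wav", "wb") as f:
    f.write(r.content)

# /sentiment/ returns the raw pipeline output under "result".
r = requests.post(f"{BASE}/sentiment/", json={"text": "The new model update works very well."})
print(r.json())

# /legal-parse/ returns the matched keywords plus the requested domain.
r = requests.post(f"{BASE}/legal-parse/", json={"text": "This agreement binds each party.", "domain": "contracts"})
print(r.json())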