black44 committed on
Commit
cf0cbad
·
verified ·
1 Parent(s): 0cb9548

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -36
app.py CHANGED
@@ -1,35 +1,39 @@
1
- from transformers import AutoTokenizer, AutoProcessor
2
  from fastapi import FastAPI, Form
3
  from fastapi.responses import FileResponse, JSONResponse
4
  from pydantic import BaseModel
5
  import torch
6
- from transformers import AutoProcessor, BarkModel, pipeline
7
  import scipy.io.wavfile as wavfile
8
  import uuid
9
  import os
10
  from typing import Optional
11
 
12
- # Load the pre-downloaded model and tokenizer
13
- tokenizer = AutoTokenizer.from_pretrained("/app/models/suno-bark")
14
- processor = AutoProcessor.from_pretrained("/app/models/suno-bark")
15
-
16
- # Rest of your application code...
17
-
18
- # Load TTS model and processor
19
- processor = AutoProcessor.from_pretrained("suno/bark")
20
- model = BarkModel.from_pretrained("suno/bark")
21
-
22
- # Load sentiment analysis pipeline (using multilingual model)
23
- sentiment_model = pipeline("sentiment-analysis", model="nlptown/bert-base-multilingual-uncased-sentiment")
24
-
25
- # Ensure model is on CPU or CUDA if available
26
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
27
  model.to(device)
28
 
29
- # FastAPI app
30
  app = FastAPI()
31
 
32
- # Endpoint input models
33
  class TTSRequest(BaseModel):
34
  text: str
35
 
@@ -47,20 +51,20 @@ def root():
47
  @app.post("/tts/")
48
  def text_to_speech(request: TTSRequest):
49
  try:
50
- # Generate speech
51
- inputs = processor(request.text, return_tensors="pt")
52
- inputs = {k: v.to(device) for k, v in inputs.items()}
53
- speech = model.generate(**inputs)
54
-
55
- # Save audio
56
  output_file = f"output_{uuid.uuid4().hex}.wav"
57
- speech_np = speech.cpu().numpy().squeeze()
58
- wavfile.write(output_file, rate=22050, data=speech_np)
59
-
60
- return FileResponse(output_file, media_type="audio/wav")
61
-
62
  except Exception as e:
63
- return JSONResponse(status_code=500, content={"error": str(e)})
 
 
 
64
 
65
  @app.post("/sentiment/")
66
  def analyze_sentiment(request: SentimentRequest):
@@ -68,14 +72,19 @@ def analyze_sentiment(request: SentimentRequest):
68
  result = sentiment_model(request.text)
69
  return {"result": result}
70
  except Exception as e:
71
- return JSONResponse(status_code=500, content={"error": str(e)})
72
 
73
  @app.post("/legal-parse/")
74
  def parse_legal_document(request: LegalDocRequest):
75
  try:
76
- # Placeholder logic (replace with training-based custom logic)
77
- keywords = ["contract", "agreement", "party", "terms"]
78
  found_keywords = [kw for kw in keywords if kw in request.text.lower()]
79
- return {"identified_keywords": found_keywords, "domain": request.domain}
 
 
 
 
 
80
  except Exception as e:
81
- return JSONResponse(status_code=500, content={"error": str(e)})
 
 
1
  from fastapi import FastAPI, Form
2
  from fastapi.responses import FileResponse, JSONResponse
3
  from pydantic import BaseModel
4
  import torch
5
+ from transformers import AutoTokenizer, AutoProcessor, BarkModel, pipeline
6
  import scipy.io.wavfile as wavfile
7
  import uuid
8
  import os
9
  from typing import Optional
10
 
11
# Local directory holding the pre-downloaded Bark checkpoint.
MODEL_PATH = "/app/models/suno-bark"

# Load every model once, at import time, so request handlers can reuse them.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    processor = AutoProcessor.from_pretrained(MODEL_PATH)
    model = BarkModel.from_pretrained(MODEL_PATH)

    # Sentiment pipeline; unlike the Bark files above, this checkpoint is
    # referenced by its hub name rather than by MODEL_PATH.
    sentiment_model = pipeline(
        "sentiment-analysis",
        model="nlptown/bert-base-multilingual-uncased-sentiment",
    )
except Exception as e:
    # Chain the original exception explicitly so its traceback is reported as
    # the direct cause of the startup failure.
    raise RuntimeError(f"Model loading failed: {str(e)}") from e

# Run the TTS model on CUDA when it is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Initialize FastAPI app
app = FastAPI()
35
 
36
+ # Request models
37
  class TTSRequest(BaseModel):
38
  text: str
39
 
 
51
@app.post("/tts/")
def text_to_speech(request: TTSRequest):
    """Synthesize speech for ``request.text`` and return it as a WAV file.

    The audio is generated with the module-level ``processor`` and ``model``,
    written to a uniquely named ``output_<hex>.wav`` file in the current
    working directory at a hard-coded 24000 Hz rate, and sent back as an
    ``audio/wav`` download.

    Returns:
        A ``FileResponse`` for the generated file on success, or a
        ``JSONResponse`` with status 500 and an ``"error"`` message when
        generation or writing fails.
    """
    # Local import keeps this block self-contained; Starlette is the package
    # that provides FastAPI's response classes.
    from starlette.background import BackgroundTask

    # Pick the file name before the try block so the failure path can always
    # refer to it, even when generation fails before anything is written.
    output_file = f"output_{uuid.uuid4().hex}.wav"
    try:
        inputs = processor(request.text, return_tensors="pt").to(device)
        with torch.no_grad():
            speech = model.generate(**inputs)

        wavfile.write(output_file, rate=24000, data=speech.cpu().numpy().squeeze())
    except Exception as e:
        # A partially written file may be left behind; remove it if present.
        if os.path.exists(output_file):
            os.remove(output_file)
        return JSONResponse(status_code=500, content={"error": f"TTS failed: {str(e)}"})

    # FileResponse reads the file only after this handler has returned, so the
    # file must still exist at that point. Deleting it in a background task
    # runs the cleanup once the response has been sent.
    return FileResponse(
        output_file,
        media_type="audio/wav",
        filename=output_file,
        background=BackgroundTask(os.remove, output_file),
    )
68
 
69
  @app.post("/sentiment/")
70
  def analyze_sentiment(request: SentimentRequest):
 
72
  result = sentiment_model(request.text)
73
  return {"result": result}
74
  except Exception as e:
75
+ return JSONResponse(status_code=500, content={"error": f"Sentiment analysis failed: {str(e)}"})
76
 
77
@app.post("/legal-parse/")
def parse_legal_document(request: LegalDocRequest):
    """Report which of a fixed set of legal keywords occur in ``request.text``.

    Matching is a case-insensitive substring test, so a keyword is also
    reported when it appears inside a longer word.

    Returns:
        A dict with ``"identified_keywords"`` (the matches, in the order of
        the keyword list), ``"domain"`` (echoed from the request) and
        ``"status"``; or a ``JSONResponse`` with status 500 and an
        ``"error"`` message if anything raises.
    """
    try:
        # Basic keyword extraction (replace with trained model in production)
        keywords = ["contract", "agreement", "party", "terms", "confidential", "jurisdiction"]

        # Lower-case the document once rather than once per keyword.
        lowered_text = request.text.lower()
        found_keywords = [kw for kw in keywords if kw in lowered_text]

        return {
            "identified_keywords": found_keywords,
            "domain": request.domain,
            "status": "success",
        }
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": f"Legal parsing failed: {str(e)}"})