Hasitha16 commited on
Commit
e280c8a
·
verified ·
1 Parent(s): 7f93129

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +21 -10
main.py CHANGED
@@ -1,7 +1,8 @@
1
- from fastapi import FastAPI, Request, Header, HTTPException
2
- from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
3
  from fastapi.openapi.utils import get_openapi
4
  from fastapi.openapi.docs import get_swagger_ui_html
 
5
  from pydantic import BaseModel
6
  from transformers import pipeline
7
  from io import StringIO
@@ -19,6 +20,14 @@ app = FastAPI(
19
  redoc_url="/redoc"
20
  )
21
 
 
 
 
 
 
 
 
 
22
  @app.get("/docs", include_in_schema=False)
23
  def custom_swagger_ui():
24
  return get_swagger_ui_html(
@@ -97,7 +106,6 @@ def root():
97
  </html>
98
  """
99
 
100
- # --- Models ---
101
  class ReviewInput(BaseModel):
102
  text: str
103
  model: str = "distilbert-base-uncased-finetuned-sst-2-english"
@@ -123,11 +131,9 @@ class TranslationInput(BaseModel):
123
  text: str
124
  target_lang: str = "fr"
125
 
126
- # --- Auth & Logging ---
127
  VALID_API_KEY = "my-secret-key"
128
  logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
129
 
130
- # --- Load Models Once ---
131
  summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")
132
  emotion_model = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base", top_k=1)
133
  sentiment_pipelines = {
@@ -135,7 +141,6 @@ sentiment_pipelines = {
135
  "nlptown/bert-base-multilingual-uncased-sentiment": pipeline("sentiment-analysis", model="nlptown/bert-base-multilingual-uncased-sentiment")
136
  }
137
 
138
- # --- Analyze (Bulk) ---
139
  @app.post("/bulk/")
140
  async def bulk(data: BulkReviewInput, x_api_key: str = Header(None)):
141
  if x_api_key != VALID_API_KEY:
@@ -167,6 +172,14 @@ async def bulk(data: BulkReviewInput, x_api_key: str = Header(None)):
167
 
168
  return {"results": results}
169
 
 
 
 
 
 
 
 
 
170
  @app.post("/analyze/")
171
  async def analyze(request: Request, data: ReviewInput, x_api_key: str = Header(None), download: str = None):
172
  if x_api_key != VALID_API_KEY:
@@ -205,13 +218,12 @@ async def analyze(request: Request, data: ReviewInput, x_api_key: str = Header(N
205
  "device": data.device,
206
  "industry": data.industry
207
  }
208
- # --- Translate ---
209
  @app.post("/translate/")
210
  async def translate(data: TranslationInput):
211
  translator = pipeline("translation", model=f"Helsinki-NLP/opus-mt-en-{data.target_lang}")
212
  return {"translated_text": translator(data.text)[0]["translation_text"]}
213
 
214
- # --- LLM Agent Chat ---
215
  @app.post("/chat/")
216
  async def chat(input: ChatInput, x_api_key: str = Header(None)):
217
  if x_api_key != VALID_API_KEY:
@@ -229,7 +241,6 @@ def chat_llm(question, context):
229
  )
230
  return res.choices[0].message.content.strip()
231
 
232
- # --- Custom OpenAPI ---
233
  def custom_openapi():
234
  if app.openapi_schema:
235
  return app.openapi_schema
@@ -246,4 +257,4 @@ Summarize reviews, detect sentiment/emotion, extract aspects, tag metadata, and
246
  app.openapi_schema = openapi_schema
247
  return app.openapi_schema
248
 
249
- app.openapi = custom_openapi
 
1
+ from fastapi import FastAPI, Request, Header, HTTPException, UploadFile, File
2
+ from fastapi.responses import HTMLResponse, JSONResponse
3
  from fastapi.openapi.utils import get_openapi
4
  from fastapi.openapi.docs import get_swagger_ui_html
5
+ from fastapi.middleware.cors import CORSMiddleware
6
  from pydantic import BaseModel
7
  from transformers import pipeline
8
  from io import StringIO
 
20
  redoc_url="/redoc"
21
  )
22
 
23
+ app.add_middleware(
24
+ CORSMiddleware,
25
+ allow_origins=["*"],
26
+ allow_credentials=True,
27
+ allow_methods=["*"],
28
+ allow_headers=["*"],
29
+ )
30
+
31
  @app.get("/docs", include_in_schema=False)
32
  def custom_swagger_ui():
33
  return get_swagger_ui_html(
 
106
  </html>
107
  """
108
 
 
109
  class ReviewInput(BaseModel):
110
  text: str
111
  model: str = "distilbert-base-uncased-finetuned-sst-2-english"
 
131
  text: str
132
  target_lang: str = "fr"
133
 
 
134
  VALID_API_KEY = "my-secret-key"
135
  logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
136
 
 
137
  summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")
138
  emotion_model = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base", top_k=1)
139
  sentiment_pipelines = {
 
141
  "nlptown/bert-base-multilingual-uncased-sentiment": pipeline("sentiment-analysis", model="nlptown/bert-base-multilingual-uncased-sentiment")
142
  }
143
 
 
144
  @app.post("/bulk/")
145
  async def bulk(data: BulkReviewInput, x_api_key: str = Header(None)):
146
  if x_api_key != VALID_API_KEY:
 
172
 
173
  return {"results": results}
174
 
175
+ @app.post("/upload/")
176
+ async def upload(file: UploadFile = File(...)):
177
+ content = await file.read()
178
+ decoded = content.decode("utf-8")
179
+ reader = csv.DictReader(StringIO(decoded))
180
+ reviews = [row["review"] for row in reader if "review" in row]
181
+ return {"count": len(reviews), "sample": reviews[:3]}
182
+
183
  @app.post("/analyze/")
184
  async def analyze(request: Request, data: ReviewInput, x_api_key: str = Header(None), download: str = None):
185
  if x_api_key != VALID_API_KEY:
 
218
  "device": data.device,
219
  "industry": data.industry
220
  }
221
+
222
  @app.post("/translate/")
223
  async def translate(data: TranslationInput):
224
  translator = pipeline("translation", model=f"Helsinki-NLP/opus-mt-en-{data.target_lang}")
225
  return {"translated_text": translator(data.text)[0]["translation_text"]}
226
 
 
227
  @app.post("/chat/")
228
  async def chat(input: ChatInput, x_api_key: str = Header(None)):
229
  if x_api_key != VALID_API_KEY:
 
241
  )
242
  return res.choices[0].message.content.strip()
243
 
 
244
  def custom_openapi():
245
  if app.openapi_schema:
246
  return app.openapi_schema
 
257
  app.openapi_schema = openapi_schema
258
  return app.openapi_schema
259
 
260
+ app.openapi = custom_openapi