Hasitha16 commited on
Commit
7b67b3d
·
verified ·
1 Parent(s): 0e76a85

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +15 -18
main.py CHANGED
@@ -109,11 +109,11 @@ def root():
109
  class ReviewInput(BaseModel):
110
  text: str
111
  model: str = "distilbert-base-uncased-finetuned-sst-2-english"
112
- industry: str = "Generic"
113
  aspects: bool = False
114
- follow_up: str = None
115
- product_category: str = None
116
- device: str = None
117
 
118
  class BulkReviewInput(BaseModel):
119
  reviews: list[str]
@@ -141,6 +141,11 @@ sentiment_pipelines = {
141
  "nlptown/bert-base-multilingual-uncased-sentiment": pipeline("sentiment-analysis", model="nlptown/bert-base-multilingual-uncased-sentiment")
142
  }
143
 
 
 
 
 
 
144
  @app.post("/bulk/")
145
  async def bulk(data: BulkReviewInput, x_api_key: str = Header(None)):
146
  if x_api_key != VALID_API_KEY:
@@ -164,22 +169,14 @@ async def bulk(data: BulkReviewInput, x_api_key: str = Header(None)):
164
  "sentiment": label,
165
  "emotion": emotions[i][0]["label"],
166
  "aspects": [],
167
- "product_category": data.product_category[i] if data.product_category else None,
168
- "device": data.device[i] if data.device else None,
169
- "industry": data.industry[i] if data.industry else None,
170
  }
171
  results.append(result)
172
 
173
  return {"results": results}
174
 
175
- @app.post("/upload/")
176
- async def upload(file: UploadFile = File(...)):
177
- content = await file.read()
178
- decoded = content.decode("utf-8")
179
- reader = csv.DictReader(StringIO(decoded))
180
- reviews = [row["review"] for row in reader if "review" in row]
181
- return {"count": len(reviews), "sample": reviews[:3]}
182
-
183
  @app.post("/analyze/")
184
  async def analyze(request: Request, data: ReviewInput, x_api_key: str = Header(None), download: str = None):
185
  if x_api_key != VALID_API_KEY:
@@ -214,9 +211,9 @@ async def analyze(request: Request, data: ReviewInput, x_api_key: str = Header(N
214
  "emotion": emotion,
215
  "aspects": aspects_list,
216
  "follow_up": follow_up_response,
217
- "product_category": data.product_category,
218
- "device": data.device,
219
- "industry": data.industry
220
  }
221
 
222
  @app.post("/translate/")
 
109
  class ReviewInput(BaseModel):
110
  text: str
111
  model: str = "distilbert-base-uncased-finetuned-sst-2-english"
112
+ industry: Optional[str] = None
113
  aspects: bool = False
114
+ follow_up: Optional[str] = None
115
+ product_category: Optional[str] = None
116
+ device: Optional[str] = None
117
 
118
  class BulkReviewInput(BaseModel):
119
  reviews: list[str]
 
141
  "nlptown/bert-base-multilingual-uncased-sentiment": pipeline("sentiment-analysis", model="nlptown/bert-base-multilingual-uncased-sentiment")
142
  }
143
 
144
+ def auto_fill(value: Optional[str], default: str = "Generic") -> str:
145
+ if not value or value.strip().lower() == "auto-detect":
146
+ return default
147
+ return value
148
+
149
  @app.post("/bulk/")
150
  async def bulk(data: BulkReviewInput, x_api_key: str = Header(None)):
151
  if x_api_key != VALID_API_KEY:
 
169
  "sentiment": label,
170
  "emotion": emotions[i][0]["label"],
171
  "aspects": [],
172
+ "product_category": auto_fill(data.product_category[i]) if data.product_category else None,
173
+ "device": auto_fill(data.device[i], "Web") if data.device else None,
174
+ "industry": auto_fill(data.industry[i]) if data.industry else None,
175
  }
176
  results.append(result)
177
 
178
  return {"results": results}
179
 
 
 
 
 
 
 
 
 
180
  @app.post("/analyze/")
181
  async def analyze(request: Request, data: ReviewInput, x_api_key: str = Header(None), download: str = None):
182
  if x_api_key != VALID_API_KEY:
 
211
  "emotion": emotion,
212
  "aspects": aspects_list,
213
  "follow_up": follow_up_response,
214
+ "product_category": auto_fill(data.product_category),
215
+ "device": auto_fill(data.device, "Web"),
216
+ "industry": auto_fill(data.industry)
217
  }
218
 
219
  @app.post("/translate/")