Spaces:
Running
Running
Update main.py
Browse files
main.py
CHANGED
@@ -109,11 +109,11 @@ def root():
|
|
109 |
class ReviewInput(BaseModel):
|
110 |
text: str
|
111 |
model: str = "distilbert-base-uncased-finetuned-sst-2-english"
|
112 |
-
industry: str =
|
113 |
aspects: bool = False
|
114 |
-
follow_up: str = None
|
115 |
-
product_category: str = None
|
116 |
-
device: str = None
|
117 |
|
118 |
class BulkReviewInput(BaseModel):
|
119 |
reviews: list[str]
|
@@ -141,6 +141,11 @@ sentiment_pipelines = {
|
|
141 |
"nlptown/bert-base-multilingual-uncased-sentiment": pipeline("sentiment-analysis", model="nlptown/bert-base-multilingual-uncased-sentiment")
|
142 |
}
|
143 |
|
|
|
|
|
|
|
|
|
|
|
144 |
@app.post("/bulk/")
|
145 |
async def bulk(data: BulkReviewInput, x_api_key: str = Header(None)):
|
146 |
if x_api_key != VALID_API_KEY:
|
@@ -164,22 +169,14 @@ async def bulk(data: BulkReviewInput, x_api_key: str = Header(None)):
|
|
164 |
"sentiment": label,
|
165 |
"emotion": emotions[i][0]["label"],
|
166 |
"aspects": [],
|
167 |
-
"product_category": data.product_category[i] if data.product_category else None,
|
168 |
-
"device": data.device[i] if data.device else None,
|
169 |
-
"industry": data.industry[i] if data.industry else None,
|
170 |
}
|
171 |
results.append(result)
|
172 |
|
173 |
return {"results": results}
|
174 |
|
175 |
-
@app.post("/upload/")
|
176 |
-
async def upload(file: UploadFile = File(...)):
|
177 |
-
content = await file.read()
|
178 |
-
decoded = content.decode("utf-8")
|
179 |
-
reader = csv.DictReader(StringIO(decoded))
|
180 |
-
reviews = [row["review"] for row in reader if "review" in row]
|
181 |
-
return {"count": len(reviews), "sample": reviews[:3]}
|
182 |
-
|
183 |
@app.post("/analyze/")
|
184 |
async def analyze(request: Request, data: ReviewInput, x_api_key: str = Header(None), download: str = None):
|
185 |
if x_api_key != VALID_API_KEY:
|
@@ -214,9 +211,9 @@ async def analyze(request: Request, data: ReviewInput, x_api_key: str = Header(N
|
|
214 |
"emotion": emotion,
|
215 |
"aspects": aspects_list,
|
216 |
"follow_up": follow_up_response,
|
217 |
-
"product_category": data.product_category,
|
218 |
-
"device": data.device,
|
219 |
-
"industry": data.industry
|
220 |
}
|
221 |
|
222 |
@app.post("/translate/")
|
|
|
109 |
class ReviewInput(BaseModel):
|
110 |
text: str
|
111 |
model: str = "distilbert-base-uncased-finetuned-sst-2-english"
|
112 |
+
industry: Optional[str] = None
|
113 |
aspects: bool = False
|
114 |
+
follow_up: Optional[str] = None
|
115 |
+
product_category: Optional[str] = None
|
116 |
+
device: Optional[str] = None
|
117 |
|
118 |
class BulkReviewInput(BaseModel):
|
119 |
reviews: list[str]
|
|
|
141 |
"nlptown/bert-base-multilingual-uncased-sentiment": pipeline("sentiment-analysis", model="nlptown/bert-base-multilingual-uncased-sentiment")
|
142 |
}
|
143 |
|
144 |
+
def auto_fill(value: Optional[str], default: str = "Generic") -> str:
|
145 |
+
if not value or value.strip().lower() == "auto-detect":
|
146 |
+
return default
|
147 |
+
return value
|
148 |
+
|
149 |
@app.post("/bulk/")
|
150 |
async def bulk(data: BulkReviewInput, x_api_key: str = Header(None)):
|
151 |
if x_api_key != VALID_API_KEY:
|
|
|
169 |
"sentiment": label,
|
170 |
"emotion": emotions[i][0]["label"],
|
171 |
"aspects": [],
|
172 |
+
"product_category": auto_fill(data.product_category[i]) if data.product_category else None,
|
173 |
+
"device": auto_fill(data.device[i], "Web") if data.device else None,
|
174 |
+
"industry": auto_fill(data.industry[i]) if data.industry else None,
|
175 |
}
|
176 |
results.append(result)
|
177 |
|
178 |
return {"results": results}
|
179 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
@app.post("/analyze/")
|
181 |
async def analyze(request: Request, data: ReviewInput, x_api_key: str = Header(None), download: str = None):
|
182 |
if x_api_key != VALID_API_KEY:
|
|
|
211 |
"emotion": emotion,
|
212 |
"aspects": aspects_list,
|
213 |
"follow_up": follow_up_response,
|
214 |
+
"product_category": auto_fill(data.product_category),
|
215 |
+
"device": auto_fill(data.device, "Web"),
|
216 |
+
"industry": auto_fill(data.industry)
|
217 |
}
|
218 |
|
219 |
@app.post("/translate/")
|