Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
@@ -1,7 +1,8 @@
|
|
1 |
-
from fastapi import FastAPI, Request, Header, HTTPException
|
2 |
-
from fastapi.responses import HTMLResponse, JSONResponse
|
3 |
from fastapi.openapi.utils import get_openapi
|
4 |
from fastapi.openapi.docs import get_swagger_ui_html
|
|
|
5 |
from pydantic import BaseModel
|
6 |
from transformers import pipeline
|
7 |
from io import StringIO
|
@@ -19,6 +20,14 @@ app = FastAPI(
|
|
19 |
redoc_url="/redoc"
|
20 |
)
|
21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
@app.get("/docs", include_in_schema=False)
|
23 |
def custom_swagger_ui():
|
24 |
return get_swagger_ui_html(
|
@@ -97,7 +106,6 @@ def root():
|
|
97 |
</html>
|
98 |
"""
|
99 |
|
100 |
-
# --- Models ---
|
101 |
class ReviewInput(BaseModel):
|
102 |
text: str
|
103 |
model: str = "distilbert-base-uncased-finetuned-sst-2-english"
|
@@ -123,11 +131,9 @@ class TranslationInput(BaseModel):
|
|
123 |
text: str
|
124 |
target_lang: str = "fr"
|
125 |
|
126 |
-
# --- Auth & Logging ---
|
127 |
VALID_API_KEY = "my-secret-key"
|
128 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
|
129 |
|
130 |
-
# --- Load Models Once ---
|
131 |
summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")
|
132 |
emotion_model = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base", top_k=1)
|
133 |
sentiment_pipelines = {
|
@@ -135,7 +141,6 @@ sentiment_pipelines = {
|
|
135 |
"nlptown/bert-base-multilingual-uncased-sentiment": pipeline("sentiment-analysis", model="nlptown/bert-base-multilingual-uncased-sentiment")
|
136 |
}
|
137 |
|
138 |
-
# --- Analyze (Bulk) ---
|
139 |
@app.post("/bulk/")
|
140 |
async def bulk(data: BulkReviewInput, x_api_key: str = Header(None)):
|
141 |
if x_api_key != VALID_API_KEY:
|
@@ -167,6 +172,14 @@ async def bulk(data: BulkReviewInput, x_api_key: str = Header(None)):
|
|
167 |
|
168 |
return {"results": results}
|
169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
170 |
@app.post("/analyze/")
|
171 |
async def analyze(request: Request, data: ReviewInput, x_api_key: str = Header(None), download: str = None):
|
172 |
if x_api_key != VALID_API_KEY:
|
@@ -205,13 +218,12 @@ async def analyze(request: Request, data: ReviewInput, x_api_key: str = Header(N
|
|
205 |
"device": data.device,
|
206 |
"industry": data.industry
|
207 |
}
|
208 |
-
|
209 |
@app.post("/translate/")
|
210 |
async def translate(data: TranslationInput):
|
211 |
translator = pipeline("translation", model=f"Helsinki-NLP/opus-mt-en-{data.target_lang}")
|
212 |
return {"translated_text": translator(data.text)[0]["translation_text"]}
|
213 |
|
214 |
-
# --- LLM Agent Chat ---
|
215 |
@app.post("/chat/")
|
216 |
async def chat(input: ChatInput, x_api_key: str = Header(None)):
|
217 |
if x_api_key != VALID_API_KEY:
|
@@ -229,7 +241,6 @@ def chat_llm(question, context):
|
|
229 |
)
|
230 |
return res.choices[0].message.content.strip()
|
231 |
|
232 |
-
# --- Custom OpenAPI ---
|
233 |
def custom_openapi():
|
234 |
if app.openapi_schema:
|
235 |
return app.openapi_schema
|
@@ -246,4 +257,4 @@ Summarize reviews, detect sentiment/emotion, extract aspects, tag metadata, and
|
|
246 |
app.openapi_schema = openapi_schema
|
247 |
return app.openapi_schema
|
248 |
|
249 |
-
app.openapi = custom_openapi
|
|
|
1 |
+
from fastapi import FastAPI, Request, Header, HTTPException, UploadFile, File
|
2 |
+
from fastapi.responses import HTMLResponse, JSONResponse
|
3 |
from fastapi.openapi.utils import get_openapi
|
4 |
from fastapi.openapi.docs import get_swagger_ui_html
|
5 |
+
from fastapi.middleware.cors import CORSMiddleware
|
6 |
from pydantic import BaseModel
|
7 |
from transformers import pipeline
|
8 |
from io import StringIO
|
|
|
20 |
redoc_url="/redoc"
|
21 |
)
|
22 |
|
23 |
+
app.add_middleware(
|
24 |
+
CORSMiddleware,
|
25 |
+
allow_origins=["*"],
|
26 |
+
allow_credentials=True,
|
27 |
+
allow_methods=["*"],
|
28 |
+
allow_headers=["*"],
|
29 |
+
)
|
30 |
+
|
31 |
@app.get("/docs", include_in_schema=False)
|
32 |
def custom_swagger_ui():
|
33 |
return get_swagger_ui_html(
|
|
|
106 |
</html>
|
107 |
"""
|
108 |
|
|
|
109 |
class ReviewInput(BaseModel):
|
110 |
text: str
|
111 |
model: str = "distilbert-base-uncased-finetuned-sst-2-english"
|
|
|
131 |
text: str
|
132 |
target_lang: str = "fr"
|
133 |
|
|
|
134 |
VALID_API_KEY = "my-secret-key"
|
135 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
|
136 |
|
|
|
137 |
summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")
|
138 |
emotion_model = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base", top_k=1)
|
139 |
sentiment_pipelines = {
|
|
|
141 |
"nlptown/bert-base-multilingual-uncased-sentiment": pipeline("sentiment-analysis", model="nlptown/bert-base-multilingual-uncased-sentiment")
|
142 |
}
|
143 |
|
|
|
144 |
@app.post("/bulk/")
|
145 |
async def bulk(data: BulkReviewInput, x_api_key: str = Header(None)):
|
146 |
if x_api_key != VALID_API_KEY:
|
|
|
172 |
|
173 |
return {"results": results}
|
174 |
|
175 |
+
@app.post("/upload/")
|
176 |
+
async def upload(file: UploadFile = File(...)):
|
177 |
+
content = await file.read()
|
178 |
+
decoded = content.decode("utf-8")
|
179 |
+
reader = csv.DictReader(StringIO(decoded))
|
180 |
+
reviews = [row["review"] for row in reader if "review" in row]
|
181 |
+
return {"count": len(reviews), "sample": reviews[:3]}
|
182 |
+
|
183 |
@app.post("/analyze/")
|
184 |
async def analyze(request: Request, data: ReviewInput, x_api_key: str = Header(None), download: str = None):
|
185 |
if x_api_key != VALID_API_KEY:
|
|
|
218 |
"device": data.device,
|
219 |
"industry": data.industry
|
220 |
}
|
221 |
+
|
222 |
@app.post("/translate/")
|
223 |
async def translate(data: TranslationInput):
|
224 |
translator = pipeline("translation", model=f"Helsinki-NLP/opus-mt-en-{data.target_lang}")
|
225 |
return {"translated_text": translator(data.text)[0]["translation_text"]}
|
226 |
|
|
|
227 |
@app.post("/chat/")
|
228 |
async def chat(input: ChatInput, x_api_key: str = Header(None)):
|
229 |
if x_api_key != VALID_API_KEY:
|
|
|
241 |
)
|
242 |
return res.choices[0].message.content.strip()
|
243 |
|
|
|
244 |
def custom_openapi():
|
245 |
if app.openapi_schema:
|
246 |
return app.openapi_schema
|
|
|
257 |
app.openapi_schema = openapi_schema
|
258 |
return app.openapi_schema
|
259 |
|
260 |
+
app.openapi = custom_openapi
|