Spaces:
Running
Running
Update main.py
Browse files
main.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
| 1 |
-
from fastapi import FastAPI, Request, Header, HTTPException
|
| 2 |
-
from fastapi.responses import HTMLResponse, JSONResponse
|
| 3 |
from fastapi.openapi.utils import get_openapi
|
| 4 |
from fastapi.openapi.docs import get_swagger_ui_html
|
|
|
|
| 5 |
from pydantic import BaseModel
|
| 6 |
from transformers import pipeline
|
| 7 |
from io import StringIO
|
|
@@ -19,6 +20,14 @@ app = FastAPI(
|
|
| 19 |
redoc_url="/redoc"
|
| 20 |
)
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
@app.get("/docs", include_in_schema=False)
|
| 23 |
def custom_swagger_ui():
|
| 24 |
return get_swagger_ui_html(
|
|
@@ -97,7 +106,6 @@ def root():
|
|
| 97 |
</html>
|
| 98 |
"""
|
| 99 |
|
| 100 |
-
# --- Models ---
|
| 101 |
class ReviewInput(BaseModel):
|
| 102 |
text: str
|
| 103 |
model: str = "distilbert-base-uncased-finetuned-sst-2-english"
|
|
@@ -123,11 +131,9 @@ class TranslationInput(BaseModel):
|
|
| 123 |
text: str
|
| 124 |
target_lang: str = "fr"
|
| 125 |
|
| 126 |
-
# --- Auth & Logging ---
|
| 127 |
VALID_API_KEY = "my-secret-key"
|
| 128 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
|
| 129 |
|
| 130 |
-
# --- Load Models Once ---
|
| 131 |
summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")
|
| 132 |
emotion_model = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base", top_k=1)
|
| 133 |
sentiment_pipelines = {
|
|
@@ -135,7 +141,6 @@ sentiment_pipelines = {
|
|
| 135 |
"nlptown/bert-base-multilingual-uncased-sentiment": pipeline("sentiment-analysis", model="nlptown/bert-base-multilingual-uncased-sentiment")
|
| 136 |
}
|
| 137 |
|
| 138 |
-
# --- Analyze (Bulk) ---
|
| 139 |
@app.post("/bulk/")
|
| 140 |
async def bulk(data: BulkReviewInput, x_api_key: str = Header(None)):
|
| 141 |
if x_api_key != VALID_API_KEY:
|
|
@@ -167,6 +172,14 @@ async def bulk(data: BulkReviewInput, x_api_key: str = Header(None)):
|
|
| 167 |
|
| 168 |
return {"results": results}
|
| 169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
@app.post("/analyze/")
|
| 171 |
async def analyze(request: Request, data: ReviewInput, x_api_key: str = Header(None), download: str = None):
|
| 172 |
if x_api_key != VALID_API_KEY:
|
|
@@ -205,13 +218,12 @@ async def analyze(request: Request, data: ReviewInput, x_api_key: str = Header(N
|
|
| 205 |
"device": data.device,
|
| 206 |
"industry": data.industry
|
| 207 |
}
|
| 208 |
-
|
| 209 |
@app.post("/translate/")
|
| 210 |
async def translate(data: TranslationInput):
|
| 211 |
translator = pipeline("translation", model=f"Helsinki-NLP/opus-mt-en-{data.target_lang}")
|
| 212 |
return {"translated_text": translator(data.text)[0]["translation_text"]}
|
| 213 |
|
| 214 |
-
# --- LLM Agent Chat ---
|
| 215 |
@app.post("/chat/")
|
| 216 |
async def chat(input: ChatInput, x_api_key: str = Header(None)):
|
| 217 |
if x_api_key != VALID_API_KEY:
|
|
@@ -229,7 +241,6 @@ def chat_llm(question, context):
|
|
| 229 |
)
|
| 230 |
return res.choices[0].message.content.strip()
|
| 231 |
|
| 232 |
-
# --- Custom OpenAPI ---
|
| 233 |
def custom_openapi():
|
| 234 |
if app.openapi_schema:
|
| 235 |
return app.openapi_schema
|
|
@@ -246,4 +257,4 @@ Summarize reviews, detect sentiment/emotion, extract aspects, tag metadata, and
|
|
| 246 |
app.openapi_schema = openapi_schema
|
| 247 |
return app.openapi_schema
|
| 248 |
|
| 249 |
-
app.openapi = custom_openapi
|
|
|
|
| 1 |
+
from fastapi import FastAPI, Request, Header, HTTPException, UploadFile, File
|
| 2 |
+
from fastapi.responses import HTMLResponse, JSONResponse
|
| 3 |
from fastapi.openapi.utils import get_openapi
|
| 4 |
from fastapi.openapi.docs import get_swagger_ui_html
|
| 5 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 6 |
from pydantic import BaseModel
|
| 7 |
from transformers import pipeline
|
| 8 |
from io import StringIO
|
|
|
|
| 20 |
redoc_url="/redoc"
|
| 21 |
)
|
| 22 |
|
| 23 |
+
app.add_middleware(
|
| 24 |
+
CORSMiddleware,
|
| 25 |
+
allow_origins=["*"],
|
| 26 |
+
allow_credentials=True,
|
| 27 |
+
allow_methods=["*"],
|
| 28 |
+
allow_headers=["*"],
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
@app.get("/docs", include_in_schema=False)
|
| 32 |
def custom_swagger_ui():
|
| 33 |
return get_swagger_ui_html(
|
|
|
|
| 106 |
</html>
|
| 107 |
"""
|
| 108 |
|
|
|
|
| 109 |
class ReviewInput(BaseModel):
|
| 110 |
text: str
|
| 111 |
model: str = "distilbert-base-uncased-finetuned-sst-2-english"
|
|
|
|
| 131 |
text: str
|
| 132 |
target_lang: str = "fr"
|
| 133 |
|
|
|
|
| 134 |
VALID_API_KEY = "my-secret-key"
|
| 135 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
|
| 136 |
|
|
|
|
| 137 |
summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")
|
| 138 |
emotion_model = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base", top_k=1)
|
| 139 |
sentiment_pipelines = {
|
|
|
|
| 141 |
"nlptown/bert-base-multilingual-uncased-sentiment": pipeline("sentiment-analysis", model="nlptown/bert-base-multilingual-uncased-sentiment")
|
| 142 |
}
|
| 143 |
|
|
|
|
| 144 |
@app.post("/bulk/")
|
| 145 |
async def bulk(data: BulkReviewInput, x_api_key: str = Header(None)):
|
| 146 |
if x_api_key != VALID_API_KEY:
|
|
|
|
| 172 |
|
| 173 |
return {"results": results}
|
| 174 |
|
| 175 |
+
@app.post("/upload/")
|
| 176 |
+
async def upload(file: UploadFile = File(...)):
|
| 177 |
+
content = await file.read()
|
| 178 |
+
decoded = content.decode("utf-8")
|
| 179 |
+
reader = csv.DictReader(StringIO(decoded))
|
| 180 |
+
reviews = [row["review"] for row in reader if "review" in row]
|
| 181 |
+
return {"count": len(reviews), "sample": reviews[:3]}
|
| 182 |
+
|
| 183 |
@app.post("/analyze/")
|
| 184 |
async def analyze(request: Request, data: ReviewInput, x_api_key: str = Header(None), download: str = None):
|
| 185 |
if x_api_key != VALID_API_KEY:
|
|
|
|
| 218 |
"device": data.device,
|
| 219 |
"industry": data.industry
|
| 220 |
}
|
| 221 |
+
|
| 222 |
@app.post("/translate/")
|
| 223 |
async def translate(data: TranslationInput):
|
| 224 |
translator = pipeline("translation", model=f"Helsinki-NLP/opus-mt-en-{data.target_lang}")
|
| 225 |
return {"translated_text": translator(data.text)[0]["translation_text"]}
|
| 226 |
|
|
|
|
| 227 |
@app.post("/chat/")
|
| 228 |
async def chat(input: ChatInput, x_api_key: str = Header(None)):
|
| 229 |
if x_api_key != VALID_API_KEY:
|
|
|
|
| 241 |
)
|
| 242 |
return res.choices[0].message.content.strip()
|
| 243 |
|
|
|
|
| 244 |
def custom_openapi():
|
| 245 |
if app.openapi_schema:
|
| 246 |
return app.openapi_schema
|
|
|
|
| 257 |
app.openapi_schema = openapi_schema
|
| 258 |
return app.openapi_schema
|
| 259 |
|
| 260 |
+
app.openapi = custom_openapi
|