MasumBhuiyan commited on
Commit
01e8001
·
verified ·
1 Parent(s): 6f2a3aa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -18
app.py CHANGED
@@ -81,21 +81,35 @@ app = FastAPI()
81
  def greet_json():
82
  return {"Hello": "World!"}
83
 
84
- #
85
- # # Predict endpoint for JSON input
86
- # @app.post("/predict")
87
- # async def predict_image(file: UploadFile = File(...)):
88
- # try:
89
- # # Read and preprocess the uploaded image
90
- # image = Image.open(file.file)
91
- # image = preprocess_image(image)
92
- #
93
- # # Make prediction
94
- # model.eval()
95
- # with torch.no_grad():
96
- # output = model(image)
97
- # prediction = output.argmax(dim=1).item()
98
- #
99
- # return JSONResponse(content={"prediction": f"The digit is {prediction}"})
100
- # except Exception as e:
101
- # return JSONResponse(content={"error": str(e)}, status_code=500)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  def greet_json():
82
  return {"Hello": "World!"}
83
 
84
+ # Input model for request validation
85
+ class TextAspectInput(BaseModel):
86
+ text: str
87
+ aspect: str
88
+
89
+
90
+ # Sentiment labels
91
+ SENTIMENT_LABELS = {0: "Negative", 1: "Neutral", 2: "Positive"}
92
+
93
+
94
+ # Predict endpoint
95
+ @app.post("/predict")
96
+ async def predict_sentiment(input_data: TextAspectInput):
97
+ print(input_data)
98
+ try:
99
+ # Extract text and aspect
100
+ text = input_data.text
101
+ aspect = input_data.aspect
102
+
103
+ # Process input
104
+ input_ids = process_text(text, aspect)
105
+
106
+ # Make prediction
107
+ with torch.no_grad():
108
+ logits = model(input_ids)
109
+ probabilities = torch.softmax(logits, dim=-1)
110
+ prediction = probabilities.argmax(dim=-1).item()
111
+ sentiment = SENTIMENT_LABELS[prediction]
112
+
113
+ return {"sentiment": sentiment, "probabilities": probabilities.squeeze().tolist()}
114
+ except Exception as e:
115
+ raise HTTPException(status_code=500, detail=str(e))