SrivarshiniGanesan commited on
Commit
2136020
·
verified ·
1 Parent(s): e147874

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -0
app.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
4
+ import torch
5
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
6
+
7
+ # Initialize FastAPI app
8
+ app = FastAPI(title="Stress Detection API", version="1.0")
9
+
10
+ model = AutoModelForSequenceClassification.from_pretrained("SrivarshiniGanesan/finetuned-stress-model")
11
+ tokenizer = AutoTokenizer.from_pretrained("SrivarshiniGanesan/finetuned-stress-model")
12
+
13
+ # Request format
14
+ class TextInput(BaseModel):
15
+ text: str
16
+
17
+ @app.post("/predict")
18
+ def predict_stress(input_text: TextInput):
19
+ inputs = tokenizer(input_text.text, return_tensors="pt", padding=True, truncation=True)
20
+ with torch.no_grad():
21
+ logits = model(**inputs).logits
22
+ prediction = torch.argmax(logits, dim=-1).item()
23
+
24
+ return {"text": input_text.text, "stress_prediction": "Stress" if prediction == 1 else "No Stress"}
25
+
26
+ @app.get("/")
27
+ def home():
28
+ return {"message": "Welcome to the Stress Detection API!"}