SrivarshiniGanesan commited on
Commit
54b76e3
·
verified ·
1 Parent(s): 2136020

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -28
app.py CHANGED
@@ -1,28 +1,26 @@
1
- from fastapi import FastAPI
2
- from pydantic import BaseModel
3
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
4
- import torch
5
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
6
-
7
- # Initialize FastAPI app
8
- app = FastAPI(title="Stress Detection API", version="1.0")
9
-
10
- model = AutoModelForSequenceClassification.from_pretrained("SrivarshiniGanesan/finetuned-stress-model")
11
- tokenizer = AutoTokenizer.from_pretrained("SrivarshiniGanesan/finetuned-stress-model")
12
-
13
- # Request format
14
- class TextInput(BaseModel):
15
- text: str
16
-
17
- @app.post("/predict")
18
- def predict_stress(input_text: TextInput):
19
- inputs = tokenizer(input_text.text, return_tensors="pt", padding=True, truncation=True)
20
- with torch.no_grad():
21
- logits = model(**inputs).logits
22
- prediction = torch.argmax(logits, dim=-1).item()
23
-
24
- return {"text": input_text.text, "stress_prediction": "Stress" if prediction == 1 else "No Stress"}
25
-
26
- @app.get("/")
27
- def home():
28
- return {"message": "Welcome to the Stress Detection API!"}
 
1
+ from fastapi import FastAPI
2
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
3
+ import torch
4
+
5
+ # Initialize FastAPI app
6
+ app = FastAPI(title="Stress Detection API", version="1.0")
7
+
8
+ model = AutoModelForSequenceClassification.from_pretrained("SrivarshiniGanesan/finetuned-stress-model")
9
+ tokenizer = AutoTokenizer.from_pretrained("SrivarshiniGanesan/finetuned-stress-model")
10
+
11
+ # Request format
12
+ class TextInput(BaseModel):
13
+ text: str
14
+
15
+ @app.post("/predict")
16
+ def predict_stress(input_text: TextInput):
17
+ inputs = tokenizer(input_text.text, return_tensors="pt", padding=True, truncation=True)
18
+ with torch.no_grad():
19
+ logits = model(**inputs).logits
20
+ prediction = torch.argmax(logits, dim=-1).item()
21
+
22
+ return {"text": input_text.text, "stress_prediction": "Stress" if prediction == 1 else "No Stress"}
23
+
24
+ @app.get("/")
25
+ def home():
26
+ return {"message": "Welcome to the Stress Detection API!"}