# DeepStress / app.py
# SrivarshiniGanesan's picture
# Update app.py
# ff0671e verified
# raw
# history blame
# 982 Bytes
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Location the classifier weights and the tokenizer files are loaded from.
# Both loads happen once, when this module is imported.
model_path = "./finetuned_model"
model = AutoModelForSequenceClassification.from_pretrained(
    model_path,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
)

# Application object that the route decorators in this module attach to.
app = FastAPI(
    title="Stress Detection API",
    version="1.0",
)
# Request body schema accepted by POST /predict.
class TextInput(BaseModel):
    # The text to classify.
    text: str
# POST /predict: classify the submitted text as "Stress" or "No Stress".
#
# The text is tokenized, passed through the sequence-classification model
# with gradient tracking disabled, and the highest-scoring class index is
# mapped to a label (index 1 -> "Stress", anything else -> "No Stress").
# The response echoes the input text next to the predicted label.
@app.post("/predict")
def predict_stress(input_text: TextInput):
    raw_text = input_text.text
    encoded = tokenizer(
        raw_text,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )

    # Inference only, so there is no need to record operations for autograd.
    with torch.no_grad():
        output = model(**encoded)
    class_id = output.logits.argmax(dim=-1).item()

    if class_id == 1:
        label = "Stress"
    else:
        label = "No Stress"
    return {"text": raw_text, "stress_prediction": label}
# GET /: return a fixed welcome message.
@app.get("/")
def home():
    welcome = {"message": "Welcome to the Stress Detection API!"}
    return welcome