# Spaces status: Runtime error
from fastapi import FastAPI, HTTPException | |
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig | |
import torch | |
# Hugging Face Hub repository holding the fine-tuned stress classifier.
# Kept in one place so config, model and tokenizer always come from the
# same checkpoint.
MODEL_ID = "SrivarshiniGanesan/finetuned-stress-model"

app = FastAPI()

# Run inference on the GPU when torch can see one, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the config first: it is passed to the model below and its id2label
# mapping is also read at prediction time.
config = AutoConfig.from_pretrained(MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_ID,
    config=config,
).to(device)
# Inference only: make sure dropout etc. are disabled. from_pretrained already
# returns the model in eval mode; this call makes that intent explicit.
model.eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
def _stress_index(id2label: dict) -> int:
    """Return the logit column that corresponds to the "Stress" class.

    Args:
        id2label: Mapping from class id to label name (a model config's
            ``id2label``); may be empty or None.

    Returns:
        The class id whose label is "Stress" (case-insensitive).

    Raises:
        ValueError: If no "Stress" class can be identified.
    """
    labels = id2label or {0: "No Stress", 1: "Stress"}
    for idx, label in labels.items():
        if str(label).casefold() == "stress":
            # Use the key itself, not the position in the dict: id2label
            # insertion order is not guaranteed to follow class ids.
            return int(idx)
    # transformers fills id2label with generic "LABEL_<i>" names when the
    # checkpoint does not define any, so the empty-dict fallback above never
    # triggers. For such a binary head, fall back to the default mapping
    # above, where index 1 is "Stress".
    generic = all(str(label) == f"LABEL_{idx}" for idx, label in labels.items())
    if generic and len(labels) == 2:
        return 1
    raise ValueError(f"No 'Stress' label found in id2label: {labels}")


def predict(text: str):
    """Return the model's probability that *text* expresses stress.

    Args:
        text: Input text to classify.

    Returns:
        ``{"stress_probability": float}``: the softmax probability of the
        "Stress" class for *text*.

    Raises:
        HTTPException: status 500 with detail ``"Prediction failed: ..."`` if
            tokenization, inference, or label resolution fails.
    """
    # NOTE(review): predict is not decorated with an app route in the code
    # visible here; confirm it is registered on `app` elsewhere, otherwise
    # the API exposes no prediction endpoint.
    try:
        inputs = tokenizer(
            text, return_tensors="pt", truncation=True, padding=True
        ).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        probs = torch.softmax(outputs.logits, dim=-1)
        stress_idx = _stress_index(config.id2label)
        return {"stress_probability": probs[0, stress_idx].item()}
    except Exception as e:
        # Boundary handler: surface any failure as a 500 while keeping the
        # original exception as the cause for server-side tracebacks.
        raise HTTPException(
            status_code=500,
            detail=f"Prediction failed: {str(e)}",
        ) from e