from fastapi import FastAPI, HTTPException
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
import torch

app = FastAPI()

# Inference-only service: pick the GPU when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MODEL_NAME = "SrivarshiniGanesan/finetuned-stress-model"

# Load config first so id2label is available for label resolution below.
config = AutoConfig.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME, config=config
).to(device)
model.eval()  # disable dropout etc. -- torch.no_grad() alone does not do this
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


def _resolve_stress_index(id2label) -> int:
    """Return the class index whose label is 'Stress' (case-insensitive).

    HF configs always populate ``id2label`` (defaulting to LABEL_0/LABEL_1),
    so the original ``if config.id2label`` fallback dict was dead code, and
    ``list(values).index("Stress")`` raised ValueError for models without a
    literal "Stress" label. Fall back to index 1, the conventional positive
    class (matching the original's intended {0: "No Stress", 1: "Stress"}).
    """
    if id2label:
        for idx, label in id2label.items():
            if str(label).lower() == "stress":
                return int(idx)
    return 1


# Resolved once at startup instead of on every request.
STRESS_IDX = _resolve_stress_index(config.id2label)


@app.post("/predict/")
def predict(text: str):
    """Return the probability that *text* expresses stress.

    Args:
        text: Input sentence. NOTE(review): as a bare ``str`` parameter on a
            POST route, FastAPI binds this as a *query* parameter, not the
            request body -- confirm that is what clients expect.

    Returns:
        ``{"stress_probability": p}`` with ``p`` a float in [0, 1].

    Raises:
        HTTPException: 500 with the underlying error message on any
            tokenizer/model failure.
    """
    try:
        inputs = tokenizer(
            text, return_tensors="pt", truncation=True, padding=True
        ).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        # Softmax over the class dimension of the single-item batch.
        probs = torch.softmax(outputs.logits, dim=-1)
        return {"stress_probability": probs[0, STRESS_IDX].item()}
    except Exception as e:
        # Chain the cause so the original traceback survives in logs.
        raise HTTPException(
            status_code=500, detail=f"Prediction failed: {str(e)}"
        ) from e