SrivarshiniGanesan committed on
Commit 36f5dbc · verified · 1 Parent(s): 69857da

Update app.py

Files changed (1)
  1. app.py +23 -15
app.py CHANGED
@@ -1,25 +1,33 @@
-from fastapi import FastAPI
+from fastapi import FastAPI, HTTPException
 from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
 import torch
 
-# Verify model configuration
-config = AutoConfig.from_pretrained("SrivarshiniGanesan/finetuned-stress-model")
-print(config)
-# Initialize FastAPI app
 app = FastAPI()
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-model = AutoModelForSequenceClassification.from_pretrained("SrivarshiniGanesan/finetuned-stress-model",config=config)
+# Load config first
+config = AutoConfig.from_pretrained("SrivarshiniGanesan/finetuned-stress-model")
+model = AutoModelForSequenceClassification.from_pretrained(
+    "SrivarshiniGanesan/finetuned-stress-model",
+    config=config
+).to(device)
 tokenizer = AutoTokenizer.from_pretrained("SrivarshiniGanesan/finetuned-stress-model")
 
-@app.get("/")
-def home():
-    return {"message": "Stress Prediction API is running"}
-
 @app.post("/predict/")
 def predict(text: str):
-    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
-    with torch.no_grad():
-        outputs = model(**inputs)
-    probs = torch.softmax(outputs.logits, dim=-1)
-    stress_prob = probs[:, 1].item()  # Probability of "Stress" class
-    return {"stress_probability": stress_prob}
+    try:
+        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
+        with torch.no_grad():
+            outputs = model(**inputs)
+
+        probs = torch.softmax(outputs.logits, dim=-1)
+        class_labels = config.id2label if config.id2label else {0: "No Stress", 1: "Stress"}
+        stress_idx = list(class_labels.values()).index("Stress")
+
+        return {"stress_probability": probs[0, stress_idx].item()}
+
+    except Exception as e:
+        raise HTTPException(
+            status_code=500,
+            detail=f"Prediction failed: {str(e)}"
+        )
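
A quick way to exercise the updated endpoint, as a sketch only: the snippet below assumes the app is served locally with "uvicorn app:app --port 8000" (the module name, host, port, and example sentence are assumptions, not part of this commit). Because text is declared as a plain str parameter on the POST route, FastAPI reads it from the query string rather than from a JSON body.

# Hypothetical client call; server address and example text are assumptions.
import requests

resp = requests.post(
    "http://localhost:8000/predict/",
    params={"text": "I have three deadlines tomorrow and no time to prepare."},
)
resp.raise_for_status()
print(resp.json())  # e.g. {"stress_probability": 0.93}; the value depends on the model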