EMS-Royal committed on
Commit
c50573b
·
verified ·
1 Parent(s): 12b8fe9

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +6 -3
main.py CHANGED
@@ -2,6 +2,10 @@ from fastapi import FastAPI
2
  from pydantic import BaseModel
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
  import torch
 
 
 
 
5
 
6
  app = FastAPI()
7
 
@@ -9,7 +13,6 @@ app = FastAPI()
9
  tokenizer = AutoTokenizer.from_pretrained("cybersectony/phishing-email-detection-distilbert_v2.4.1")
10
  model = AutoModelForSequenceClassification.from_pretrained("cybersectony/phishing-email-detection-distilbert_v2.4.1")
11
 
12
- # Define input schema
13
  class EmailInput(BaseModel):
14
  text: str
15
 
@@ -19,7 +22,7 @@ def predict(input: EmailInput):
19
  with torch.no_grad():
20
  outputs = model(**inputs)
21
  predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
22
-
23
  probs = predictions[0].tolist()
24
  labels = {
25
  "legitimate_email": probs[0],
@@ -28,7 +31,7 @@ def predict(input: EmailInput):
28
  "phishing_url": probs[3]
29
  }
30
  max_label = max(labels.items(), key=lambda x: x[1])
31
-
32
  return {
33
  "prediction": max_label[0],
34
  "confidence": round(max_label[1], 4),
 
2
  from pydantic import BaseModel
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
  import torch
5
+ import os
6
+
7
+ # Set cache directory to avoid permission errors
8
+ # NOTE(review): this assignment runs *after* `from transformers import ...`
+ # earlier in the file; transformers resolves its default cache location
+ # (from HF_HOME) at import time, so setting the variable here may have no
+ # effect — confirm, and move it above the transformers import if the
+ # permission error persists.
+ os.environ["HF_HOME"] = "/tmp/hf-cache"
9
 
10
  app = FastAPI()
11
 
 
13
  tokenizer = AutoTokenizer.from_pretrained("cybersectony/phishing-email-detection-distilbert_v2.4.1")
14
  model = AutoModelForSequenceClassification.from_pretrained("cybersectony/phishing-email-detection-distilbert_v2.4.1")
15
 
 
16
  class EmailInput(BaseModel):
17
  text: str
18
 
 
22
  with torch.no_grad():
23
  outputs = model(**inputs)
24
  predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
25
+
26
  probs = predictions[0].tolist()
27
  labels = {
28
  "legitimate_email": probs[0],
 
31
  "phishing_url": probs[3]
32
  }
33
  max_label = max(labels.items(), key=lambda x: x[1])
34
+
35
  return {
36
  "prediction": max_label[0],
37
  "confidence": round(max_label[1], 4),