import os

import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Hugging Face model id, shared by the tokenizer and the classifier so the two
# cannot drift apart.
MODEL_NAME = "cybersectony/phishing-email-detection-distilbert_v2.4.1"

# Passed as ``cache_dir`` to ``from_pretrained`` explicitly, rather than relying
# on HF_HOME, so the download location does not depend on the environment.
CACHE_DIR = "/tmp"

# Label for each output logit, by position (index i -> LABELS[i]).
# NOTE(review): this order is assumed by the service; presumably it matches
# model.config.id2label for MODEL_NAME -- verify when upgrading the model.
LABELS = ("legitimate_email", "phishing_email", "legitimate_url", "phishing_url")

app = FastAPI()

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)


class EmailInput(BaseModel):
    """Request body for ``POST /predict``."""

    text: str  # raw text to classify


def _label_probabilities(probs):
    """Pair each per-class probability (a list of floats) with its label.

    Returns a dict ``{label: probability}`` in ``LABELS`` order.

    Raises:
        ValueError: if the number of probabilities differs from ``len(LABELS)``;
            otherwise classes would be mislabeled or silently dropped.
    """
    if len(probs) != len(LABELS):
        raise ValueError(
            f"model returned {len(probs)} class probabilities, expected {len(LABELS)}"
        )
    return dict(zip(LABELS, probs))


@app.post("/predict")
def predict(input: EmailInput):
    """Classify ``input.text`` and report the most likely label.

    The text is truncated to at most 512 tokens before inference.

    Returns:
        A dict with ``prediction`` (the label with the highest probability),
        ``confidence`` (that probability rounded to 4 decimals) and
        ``all_probabilities`` (label -> probability for every class).
    """
    inputs = tokenizer(input.text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():  # inference only: no gradient tracking needed
        outputs = model(**inputs)
    # One text in, so take row 0 of the batch; .tolist() gives JSON-friendly floats.
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0].tolist()
    labels = _label_probabilities(probs)
    # max() returns the first maximal key, so ties resolve in LABELS order.
    best_label = max(labels, key=labels.get)
    return {
        "prediction": best_label,
        "confidence": round(labels[best_label], 4),
        "all_probabilities": labels,
    }