MUFASA25 committed on
Commit
42dc091
·
verified ·
1 Parent(s): 0b9d5c7
Files changed (1) hide show
  1. app.py +226 -40
app.py CHANGED
@@ -1,53 +1,239 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- # Load model and tokenizer
6
- tokenizer = AutoTokenizer.from_pretrained("cybersectony/phishing-email-detection-distilbert_v2.4.1")
7
- model = AutoModelForSequenceClassification.from_pretrained("cybersectony/phishing-email-detection-distilbert_v2.4.1")
 
 
 
 
 
 
 
 
 
8
 
9
  def predict_email(email_text):
10
- # Preprocess and tokenize
11
- inputs = tokenizer(
12
- email_text,
13
- return_tensors="pt",
14
- truncation=True,
15
- max_length=512
16
- )
17
-
18
- # Get prediction
19
- with torch.no_grad():
20
- outputs = model(**inputs)
21
- predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
22
 
23
- # Get probabilities for each class
24
- probs = predictions[0].tolist()
 
 
 
 
 
 
 
25
 
26
- # Create labels dictionary
27
- labels = {
28
- "Legitimate Email": probs[0],
29
- "Phishing URL": probs[1],
30
- "Legitimate URL": probs[2],
31
- "Phishing URL (Alt)": probs[3]
32
- }
33
 
34
- # Determine the most likely classification
35
- max_label = max(labels.items(), key=lambda x: x[1])
 
 
36
 
37
- # Format output
38
- result = f"**Prediction**: {max_label[0]}\n**Confidence**: {max_label[1]:.4f}\n\n**All Probabilities**:\n"
39
- for label, prob in labels.items():
40
- result += f"{label}: {prob:.4f}\n"
41
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  # Create Gradio interface
44
- iface = gr.Interface(
45
- fn=predict_email,
46
- inputs=gr.Textbox(lines=5, placeholder="Enter the email text here..."),
47
- outputs="text",
48
- title="Phishing Email Detection",
49
- description="Enter an email text to classify it as legitimate or phishing using a DistilBERT model."
50
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  # Launch the interface
53
- iface.launch()
 
 
 
 
 
 
 
1
  import gradio as gr
 
2
  import torch
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+ import logging
5
+
6
# Configure logging: the root logger is set to INFO so the messages emitted
# while loading the model and while predicting are visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model configuration: the checkpoint identifier handed to from_pretrained().
MODEL_NAME = "cybersectony/phishing-email-detection-distilbert_v2.4.1"

# Global variables for model and tokenizer. Both start as None, are filled in
# by load_model(), and are checked by predict_email() before every use.
tokenizer = None
model = None
16
 
17
def load_model():
    """Load the tokenizer and model named by MODEL_NAME into the module globals.

    Both objects are first loaded into local variables and only published to
    the ``tokenizer`` / ``model`` globals once both loads have succeeded, so a
    failure part-way through never leaves one global set and the other None.

    Returns:
        bool: True when both objects were loaded, False if loading raised.
    """
    global tokenizer, model

    logger.info("Loading model and tokenizer...")
    try:
        loaded_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        loaded_model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    except Exception:
        # Boundary handler: callers only need the boolean, but the traceback
        # is kept in the log so the cause of the failure can be diagnosed.
        logger.exception("Error loading model")
        return False

    tokenizer, model = loaded_tokenizer, loaded_model
    logger.info("Model loaded successfully!")
    return True
29
 
30
# Display names for the classifier's output indices, in logit order. This is
# the mapping the earlier revision of this app used for the same checkpoint.
# NOTE(review): confirm the order against the model card / model.config.id2label.
_CLASS_LABELS = (
    "Legitimate Email",
    "Phishing URL",
    "Legitimate URL",
    "Phishing URL (Alt)",
)


def _risk_banner(label, confidence):
    """Return the (emoji, risk level) pair shown at the top of the report."""
    if "Phishing" not in label:
        return "✅", "SAFE"
    if confidence > 0.8:
        return "🚨", "HIGH RISK"
    if confidence > 0.6:
        return "⚠️", "MEDIUM RISK"
    return "⚡", "LOW RISK"


def _recommendation(label, confidence):
    """Return the closing advice line, kept consistent with the risk banner."""
    if "Phishing" not in label:
        return "✅ **Recommendation**: This email appears to be legitimate, but always remain vigilant."
    if confidence > 0.7:
        return "⚠️ **Recommendation**: This email shows signs of phishing. Do not click any links or provide personal information."
    # A phishing verdict with lower confidence must still not be reported as
    # legitimate; advise caution instead.
    return "🔍 **Recommendation**: Exercise caution with this email. Verify sender identity before taking any action."


def _format_report(probs):
    """Turn the list of per-class probabilities into the text report."""
    scores = dict(zip(_CLASS_LABELS, probs))

    # Drop near-zero classes so the breakdown only lists meaningful entries.
    visible = {name: p for name, p in scores.items() if p > 0.001}

    # Highest probability first; the first entry is the primary classification.
    ranked = sorted(visible.items(), key=lambda item: item[1], reverse=True)
    top_label, confidence = ranked[0]

    risk_emoji, risk_level = _risk_banner(top_label, confidence)

    lines = [
        f"{risk_emoji} **{risk_level}**",
        "",
        f"**Primary Classification**: {top_label}",
        f"**Confidence**: {confidence:.1%}",
        "",
        "**Detailed Analysis**:",
    ]
    for name, p in ranked:
        percentage = p * 100
        filled = int(percentage / 5)  # 20-character bar, 5% per character
        bar = "█" * filled + "░" * (20 - filled)
        lines.append(f"{name}: {percentage:.1f}% {bar}")

    lines.append("")
    lines.append(_recommendation(top_label, confidence))
    return "\n".join(lines)


def predict_email(email_text):
    """Classify an email as phishing or legitimate and describe the result.

    The text is validated, tokenized (truncated to 512 tokens), run through
    the model without gradient tracking, and the softmax probabilities are
    rendered as a report with a risk banner, a per-class breakdown and a
    recommendation.

    Args:
        email_text (str): The email content to analyze.

    Returns:
        str: The formatted report, or an error/warning message when the input
        is empty or shorter than 10 characters, when the model cannot be
        loaded, or when prediction raises.
    """
    # Input validation: these paths return before the model is touched.
    if not email_text or not email_text.strip():
        return "⚠️ **Error**: Please enter some email text to analyze."

    if len(email_text.strip()) < 10:
        return "⚠️ **Warning**: Email text seems too short for reliable analysis. Please provide more content."

    # Load lazily if the startup load did not succeed.
    if (tokenizer is None or model is None) and not load_model():
        return "❌ **Error**: Failed to load the model. Please try again later."

    try:
        inputs = tokenizer(
            email_text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True,
        )

        with torch.no_grad():
            outputs = model(**inputs)
            predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)

        # Single input, so the first row holds the class probabilities.
        return _format_report(predictions[0].tolist())

    except Exception as e:
        # UI boundary: report the failure to the user, keep the traceback in the log.
        logger.exception("Error during prediction")
        return f"❌ **Error**: An error occurred during analysis: {str(e)}"
126
+
127
# Example emails for demonstration. The "Load ... Email" buttons in the UI copy
# these strings into the input box; they are sample data only.

# A routine order-confirmation message with no links or urgent language.
example_legitimate = """Dear Customer,

Thank you for your recent purchase from our store. Your order #ORD-2024-001234 has been successfully processed and will be shipped within 2-3 business days.

Order Details:
- Product: Wireless Headphones
- Amount: $79.99
- Shipping Address: [Your provided address]

You can track your shipment using the tracking number that will be sent to your email once the item is dispatched.

If you have any questions, please contact our customer service team.

Best regards,
Customer Service Team
TechStore Inc."""

# A message using urgency, threats and a verification link.
example_phishing = """URGENT - Account Security Alert!!!

Your account has been COMPROMISED and will be SUSPENDED in 24 hours!

Immediate action required: Click here to verify your account NOW: http://security-verify-account-urgent.suspicious-domain.com/verify-now

If you don't verify within 24 hours, your account will be permanently deleted and all your data will be lost forever!

This is your FINAL WARNING - Act immediately!

Security Team
[Suspicious Bank Name]"""
157
+
158
# Load model on startup. The boolean result is not checked here:
# predict_email() calls load_model() again whenever tokenizer or model is
# still None, so a failed startup load is retried on the next request.
load_model()

# Create Gradio interface
with gr.Blocks(title="Phishing Email Detection", theme=gr.themes.Soft()) as iface:
    # Page header and disclaimer.
    gr.Markdown("""
    # 🛡️ Phishing Email Detection System

    This tool uses a DistilBERT model to analyze email content and detect potential phishing attempts.
    Simply paste the email text below and get an instant security assessment.

    **⚠️ Disclaimer**: This is an AI-based tool for educational purposes. Always use your judgment and follow your organization's security policies.
    """)

    with gr.Row():
        # Left column: email input and the action buttons.
        with gr.Column(scale=2):
            email_input = gr.Textbox(
                lines=10,
                placeholder="Paste the email content here...",
                label="Email Text",
                info="Enter the complete email text including headers, body, and any suspicious elements."
            )

            with gr.Row():
                analyze_btn = gr.Button("🔍 Analyze Email", variant="primary", size="lg")
                clear_btn = gr.Button("🗑️ Clear", variant="secondary")

        # Right column: read-only box that receives the string returned by
        # predict_email().
        with gr.Column(scale=1):
            output = gr.Textbox(
                label="Analysis Results",
                lines=15,
                interactive=False
            )

    # Example section
    gr.Markdown("### 📝 Try These Examples:")
    with gr.Row():
        with gr.Column():
            gr.Markdown("**Legitimate Email Example:**")
            legitimate_btn = gr.Button("Load Legitimate Email", size="sm")
        with gr.Column():
            gr.Markdown("**Phishing Email Example:**")
            phishing_btn = gr.Button("Load Phishing Email", size="sm")

    # Event handlers
    # Run the classifier on the current input and show the report.
    analyze_btn.click(
        fn=predict_email,
        inputs=email_input,
        outputs=output
    )

    # Reset both boxes: one empty string per output component.
    clear_btn.click(
        fn=lambda: ("", ""),
        outputs=[email_input, output]
    )

    # The example buttons only fill the input box; the user still has to
    # press the analyze button.
    legitimate_btn.click(
        fn=lambda: example_legitimate,
        outputs=email_input
    )

    phishing_btn.click(
        fn=lambda: example_phishing,
        outputs=email_input
    )

    # Footer
    gr.Markdown("""
    ---
    **Model**: cybersectony/phishing-email-detection-distilbert_v2.4.1
    **Framework**: Transformers + DistilBERT
    **Interface**: Gradio
    """)

# Launch the interface
# Only launch when run as a script, so importing this module builds the UI
# without starting a server.
if __name__ == "__main__":
    # NOTE(review): share=True requests a public Gradio share link in addition
    # to binding on all interfaces (0.0.0.0) at port 7860 — confirm that public
    # exposure is intended for this deployment.
    iface.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )