MUFASA25 commited on
Commit
62a31d8
·
verified ·
1 Parent(s): 511e73c

model respnsive

Browse files
Files changed (1) hide show
  1. app.py +170 -105
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
  import logging
 
5
 
6
  # Configure logging
7
  logging.basicConfig(level=logging.INFO)
@@ -21,33 +22,68 @@ def load_model():
21
  logger.info("Loading model and tokenizer...")
22
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
23
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  logger.info("Model loaded successfully!")
25
  return True
26
  except Exception as e:
27
  logger.error(f"Error loading model: {e}")
28
  return False
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  def predict_email(email_text):
31
  """
32
- Predict whether an email is phishing or legitimate
33
-
34
- Args:
35
- email_text (str): The email content to analyze
36
-
37
- Returns:
38
- str: Formatted prediction results
39
  """
40
  # Input validation
41
  if not email_text or not email_text.strip():
42
  return "⚠️ **Error**: Please enter some email text to analyze."
43
 
44
- if len(email_text.strip()) < 10:
45
- return "⚠️ **Warning**: Email text seems too short for reliable analysis. Please provide more content."
46
 
47
  # Check if model is loaded
48
  if tokenizer is None or model is None:
49
  if not load_model():
50
- return "❌ **Error**: Failed to load the model. Please try again later."
51
 
52
  try:
53
  # Preprocess and tokenize
@@ -59,32 +95,61 @@ def predict_email(email_text):
59
  padding=True
60
  )
61
 
62
- # Get prediction
63
  with torch.no_grad():
64
  outputs = model(**inputs)
65
- predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
 
 
 
66
 
67
- # Get probabilities for each class
68
  probs = predictions[0].tolist()
69
 
70
- # Create labels dictionary
71
- # Note: Verify these labels match your model's actual training configuration
72
- labels = {
73
- "Legitimate Email": probs[0],
74
- "Phishing Email": probs[1] if len(probs) > 1 else 0.0,
75
- "Suspicious Content": probs[2] if len(probs) > 2 else 0.0,
76
- "Other": probs[3] if len(probs) > 3 else 0.0
77
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
- # Remove zero probability labels
80
- labels = {k: v for k, v in labels.items() if v > 0.001}
 
 
 
 
81
 
82
- # Determine the most likely classification
83
  max_label = max(labels.items(), key=lambda x: x[1])
84
 
85
  # Determine risk level and emoji
86
  confidence = max_label[1]
87
- if "Phishing" in max_label[0] or "Suspicious" in max_label[0]:
 
 
88
  if confidence > 0.8:
89
  risk_emoji = "🚨"
90
  risk_level = "HIGH RISK"
@@ -95,145 +160,145 @@ def predict_email(email_text):
95
  risk_emoji = "⚡"
96
  risk_level = "LOW RISK"
97
  else:
98
- risk_emoji = "✅"
99
- risk_level = "SAFE"
 
 
 
 
 
 
 
100
 
101
- # Format output
102
  result = f"{risk_emoji} **{risk_level}**\n\n"
103
- result += f"**Primary Classification**: {max_label[0]}\n"
104
  result += f"**Confidence**: {confidence:.1%}\n\n"
105
  result += f"**Detailed Analysis**:\n"
106
 
 
107
  for label, prob in sorted(labels.items(), key=lambda x: x[1], reverse=True):
108
  percentage = prob * 100
109
- bar_length = int(percentage / 5) # Scale bar to 20 chars max
110
- bar = "█" * bar_length + "░" * (20 - bar_length)
111
- result += f"{label}: {percentage:.1f}% {bar}\n"
112
-
113
- # Add recommendations
114
- if "Phishing" in max_label[0] and confidence > 0.7:
115
- result += f"\n⚠️ **Recommendation**: This email shows signs of phishing. Do not click any links or provide personal information."
116
- elif "Suspicious" in max_label[0] and confidence > 0.6:
117
- result += f"\n🔍 **Recommendation**: Exercise caution with this email. Verify sender identity before taking any action."
 
 
 
 
 
 
118
  else:
119
- result += f"\n **Recommendation**: This email appears to be legitimate, but always remain vigilant."
120
 
121
  return result
122
 
123
  except Exception as e:
124
- logger.error(f"Error during prediction: {e}")
125
- return f"❌ **Error**: An error occurred during analysis: {str(e)}"
126
 
127
- # Example emails for demonstration
128
  example_legitimate = """Dear Customer,
129
 
130
- Thank you for your recent purchase from our store. Your order #ORD-2024-001234 has been successfully processed and will be shipped within 2-3 business days.
131
 
132
  Order Details:
133
- - Product: Wireless Headphones
134
  - Amount: $79.99
135
- - Shipping Address: [Your provided address]
136
 
137
- You can track your shipment using the tracking number that will be sent to your email once the item is dispatched.
138
-
139
- If you have any questions, please contact our customer service team.
140
 
141
  Best regards,
142
- Customer Service Team
143
- TechStore Inc."""
 
144
 
145
- example_phishing = """URGENT - Account Security Alert!!!
146
 
147
- Your account has been COMPROMISED and will be SUSPENDED in 24 hours!
148
 
149
- Immediate action required: Click here to verify your account NOW: http://security-verify-account-urgent.suspicious-domain.com/verify-now
150
 
151
- If you don't verify within 24 hours, your account will be permanently deleted and all your data will be lost forever!
152
 
153
- This is your FINAL WARNING - Act immediately!
154
 
155
- Security Team
156
- [Suspicious Bank Name]"""
 
 
 
 
 
 
157
 
158
  # Load model on startup
159
  load_model()
160
 
161
- # Create Gradio interface
162
- with gr.Blocks(title="Phishing Email Detection", theme=gr.themes.Soft()) as iface:
163
  gr.Markdown("""
164
- # 🛡️ Phishing Email Detection System
165
-
166
- This tool uses a DistilBERT model to analyze email content and detect potential phishing attempts.
167
- Simply paste the email text below and get an instant security assessment.
168
 
169
- **⚠️ Disclaimer**: This is an AI-based tool for educational purposes. Always use your judgment and follow your organization's security policies.
170
  """)
171
 
172
  with gr.Row():
173
  with gr.Column(scale=2):
174
  email_input = gr.Textbox(
175
  lines=10,
176
- placeholder="Paste the email content here...",
177
- label="Email Text",
178
- info="Enter the complete email text including headers, body, and any suspicious elements."
179
  )
180
 
181
  with gr.Row():
182
  analyze_btn = gr.Button("🔍 Analyze Email", variant="primary", size="lg")
183
  clear_btn = gr.Button("🗑️ Clear", variant="secondary")
184
 
185
- with gr.Column(scale=1):
186
  output = gr.Textbox(
187
- label="Analysis Results",
188
- lines=15,
189
- interactive=False
 
190
  )
191
 
192
- # Example section
193
- gr.Markdown("### 📝 Try These Examples:")
194
  with gr.Row():
195
- with gr.Column():
196
- gr.Markdown("**Legitimate Email Example:**")
197
- legitimate_btn = gr.Button("Load Legitimate Email", size="sm")
198
- with gr.Column():
199
- gr.Markdown("**Phishing Email Example:**")
200
- phishing_btn = gr.Button("Load Phishing Email", size="sm")
201
 
202
  # Event handlers
203
- analyze_btn.click(
204
- fn=predict_email,
205
- inputs=email_input,
206
- outputs=output
207
- )
208
-
209
- clear_btn.click(
210
- fn=lambda: ("", ""),
211
- outputs=[email_input, output]
212
- )
213
-
214
- legitimate_btn.click(
215
- fn=lambda: example_legitimate,
216
- outputs=email_input
217
- )
218
 
219
- phishing_btn.click(
220
- fn=lambda: example_phishing,
221
- outputs=email_input
222
- )
223
 
224
- # Footer
225
  gr.Markdown("""
226
  ---
227
- **Model**: cybersectony/phishing-email-detection-distilbert_v2.4.1
228
- **Framework**: Transformers + DistilBERT
229
- **Interface**: Gradio
230
  """)
231
 
232
- # Launch the interface
233
  if __name__ == "__main__":
234
  iface.launch(
235
  share=True,
236
- server_name="0.0.0.0",
237
  server_port=7860,
238
- show_error=True
 
239
  )
 
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
  import logging
5
+ import numpy as np
6
 
7
  # Configure logging
8
  logging.basicConfig(level=logging.INFO)
 
22
  logger.info("Loading model and tokenizer...")
23
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
24
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
25
+
26
+ # Debug: Print model configuration
27
+ logger.info(f"Model config: {model.config}")
28
+ logger.info(f"Number of labels: {model.config.num_labels}")
29
+ if hasattr(model.config, 'id2label'):
30
+ logger.info(f"Label mapping: {model.config.id2label}")
31
+
32
+ # Test model with simple input to check if it's working
33
+ test_input = "Hello world"
34
+ inputs = tokenizer(test_input, return_tensors="pt", truncation=True, max_length=512)
35
+ with torch.no_grad():
36
+ outputs = model(**inputs)
37
+ test_probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
38
+ logger.info(f"Test probabilities: {test_probs[0].tolist()}")
39
+
40
  logger.info("Model loaded successfully!")
41
  return True
42
  except Exception as e:
43
  logger.error(f"Error loading model: {e}")
44
  return False
45
 
46
+ def get_colored_bar(percentage, label):
47
+ """Create colored progress bar based on percentage and label type"""
48
+ # Determine color based on percentage and label
49
+ if "phishing" in label.lower() or "suspicious" in label.lower():
50
+ # Red scale for dangerous content
51
+ if percentage >= 70:
52
+ color = "🟥" # High danger - red
53
+ elif percentage >= 40:
54
+ color = "🟠" # Medium danger - orange
55
+ else:
56
+ color = "🟡" # Low danger - yellow
57
+ else:
58
+ # Green scale for legitimate content
59
+ if percentage >= 70:
60
+ color = "🟢" # High confidence - green
61
+ elif percentage >= 40:
62
+ color = "🟡" # Medium confidence - yellow
63
+ else:
64
+ color = "⚪" # Low confidence - white
65
+
66
+ # Create bar (scale to 20 characters)
67
+ bar_length = max(1, int(percentage / 5)) # Ensure at least 1 if percentage > 0
68
+ bar = color * bar_length + "⚪" * (20 - bar_length)
69
+
70
+ return bar
71
+
72
  def predict_email(email_text):
73
  """
74
+ Enhanced prediction function with proper model output handling
 
 
 
 
 
 
75
  """
76
  # Input validation
77
  if not email_text or not email_text.strip():
78
  return "⚠️ **Error**: Please enter some email text to analyze."
79
 
80
+ if len(email_text.strip()) < 5:
81
+ return "⚠️ **Warning**: Email text too short for reliable analysis."
82
 
83
  # Check if model is loaded
84
  if tokenizer is None or model is None:
85
  if not load_model():
86
+ return "❌ **Error**: Failed to load the model."
87
 
88
  try:
89
  # Preprocess and tokenize
 
95
  padding=True
96
  )
97
 
98
+ # Get prediction with proper handling
99
  with torch.no_grad():
100
  outputs = model(**inputs)
101
+ # Apply temperature scaling to prevent overconfidence
102
+ temperature = 1.5
103
+ scaled_logits = outputs.logits / temperature
104
+ predictions = torch.nn.functional.softmax(scaled_logits, dim=-1)
105
 
106
+ # Get probabilities
107
  probs = predictions[0].tolist()
108
 
109
+ # Log raw outputs for debugging
110
+ logger.info(f"Raw logits: {outputs.logits[0].tolist()}")
111
+ logger.info(f"Scaled probabilities: {probs}")
112
+
113
+ # Get proper labels from model config or use fallback
114
+ if hasattr(model.config, 'id2label') and model.config.id2label:
115
+ labels = {model.config.id2label[i]: probs[i] for i in range(len(probs))}
116
+ else:
117
+ # Fallback - check the actual model output dimension
118
+ if len(probs) == 2:
119
+ labels = {
120
+ "Legitimate Email": probs[0],
121
+ "Phishing Email": probs[1]
122
+ }
123
+ elif len(probs) == 4:
124
+ labels = {
125
+ "Legitimate Email": probs[0],
126
+ "Phishing Email": probs[1],
127
+ "Suspicious Content": probs[2],
128
+ "Spam Email": probs[3]
129
+ }
130
+ else:
131
+ # Generic labels
132
+ labels = {f"Class {i}": probs[i] for i in range(len(probs))}
133
+
134
+ # Check if model is giving reasonable outputs
135
+ prob_variance = np.var(probs)
136
+ max_prob = max(probs)
137
 
138
+ # If variance is too low, the model might not be working properly
139
+ if prob_variance < 0.01 and max_prob > 0.99:
140
+ logger.warning("Model showing signs of overconfidence or poor calibration")
141
+ # Apply smoothing
142
+ smoothed_probs = [(p * 0.8 + 0.2/len(probs)) for p in probs]
143
+ labels = {list(labels.keys())[i]: smoothed_probs[i] for i in range(len(smoothed_probs))}
144
 
145
+ # Find prediction
146
  max_label = max(labels.items(), key=lambda x: x[1])
147
 
148
  # Determine risk level and emoji
149
  confidence = max_label[1]
150
+ prediction_name = max_label[0]
151
+
152
+ if any(word in prediction_name.lower() for word in ['phishing', 'suspicious', 'spam']):
153
  if confidence > 0.8:
154
  risk_emoji = "🚨"
155
  risk_level = "HIGH RISK"
 
160
  risk_emoji = "⚡"
161
  risk_level = "LOW RISK"
162
  else:
163
+ if confidence > 0.8:
164
+ risk_emoji = ""
165
+ risk_level = "SAFE"
166
+ elif confidence > 0.6:
167
+ risk_emoji = "✅"
168
+ risk_level = "LIKELY SAFE"
169
+ else:
170
+ risk_emoji = "❓"
171
+ risk_level = "UNCERTAIN"
172
 
173
+ # Format output with colored bars
174
  result = f"{risk_emoji} **{risk_level}**\n\n"
175
+ result += f"**Primary Classification**: {prediction_name}\n"
176
  result += f"**Confidence**: {confidence:.1%}\n\n"
177
  result += f"**Detailed Analysis**:\n"
178
 
179
+ # Sort by probability and add colored bars
180
  for label, prob in sorted(labels.items(), key=lambda x: x[1], reverse=True):
181
  percentage = prob * 100
182
+ colored_bar = get_colored_bar(percentage, label)
183
+ result += f"{label}: {percentage:.1f}% {colored_bar}\n"
184
+
185
+ # Add debug info
186
+ result += f"\n**Debug Info**:\n"
187
+ result += f"Model Variance: {prob_variance:.4f}\n"
188
+ result += f"Raw Probabilities: {[f'{p:.3f}' for p in probs]}\n"
189
+
190
+ # Add recommendations based on actual classification
191
+ if any(word in prediction_name.lower() for word in ['phishing', 'suspicious']) and confidence > 0.6:
192
+ result += f"\n⚠️ **Recommendation**: This email shows signs of being malicious. Avoid clicking links or providing personal information."
193
+ elif 'spam' in prediction_name.lower():
194
+ result += f"\n🗑️ **Recommendation**: This appears to be spam. Consider deleting or marking as junk."
195
+ elif confidence > 0.7:
196
+ result += f"\n✅ **Recommendation**: This email appears legitimate, but always remain vigilant."
197
  else:
198
+ result += f"\n **Recommendation**: Classification uncertain. Exercise caution and verify sender if needed."
199
 
200
  return result
201
 
202
  except Exception as e:
203
+ logger.error(f"Error during prediction: {e}", exc_info=True)
204
+ return f"❌ **Error**: Analysis failed - {str(e)}"
205
 
206
+ # Example emails for testing
207
  example_legitimate = """Dear Customer,
208
 
209
+ Thank you for your recent purchase from TechStore. Your order #ORD-2024-001234 has been successfully processed.
210
 
211
  Order Details:
212
+ - Product: Wireless Headphones
213
  - Amount: $79.99
214
+ - Estimated delivery: 3-5 business days
215
 
216
+ You will receive a tracking number once your item ships.
 
 
217
 
218
  Best regards,
219
+ TechStore Customer Service"""
220
+
221
+ example_phishing = """URGENT SECURITY ALERT!!!
222
 
223
+ Your account has been COMPROMISED! Immediate action required!
224
 
225
+ Click here NOW to secure your account: http://fake-security-site.malicious.com/urgent-verify
226
 
227
+ WARNING: You have only 24 hours before your account is permanently suspended!
228
 
229
+ This is your FINAL notice - act immediately!
230
 
231
+ Security Department"""
232
 
233
+ example_neutral = """Hi team,
234
+
235
+ Hope everyone is doing well. Just wanted to remind you about the meeting scheduled for tomorrow at 2 PM in the conference room.
236
+
237
+ Please bring your project updates and any questions you might have.
238
+
239
+ Thanks,
240
+ Sarah"""
241
 
242
  # Load model on startup
243
  load_model()
244
 
245
+ # Create enhanced Gradio interface
246
+ with gr.Blocks(title="PhishGuardian AI", theme=gr.themes.Soft()) as iface:
247
  gr.Markdown("""
248
+ # 🛡️ PhishGuardian AI - Enhanced Detection
 
 
 
249
 
250
+ Advanced phishing email detection with colored risk indicators and improved model handling.
251
  """)
252
 
253
  with gr.Row():
254
  with gr.Column(scale=2):
255
  email_input = gr.Textbox(
256
  lines=10,
257
+ placeholder="Paste your email content here for analysis...",
258
+ label="📧 Email Content",
259
+ info="Enter the complete email text for comprehensive analysis"
260
  )
261
 
262
  with gr.Row():
263
  analyze_btn = gr.Button("🔍 Analyze Email", variant="primary", size="lg")
264
  clear_btn = gr.Button("🗑️ Clear", variant="secondary")
265
 
266
+ with gr.Column(scale=2):
267
  output = gr.Textbox(
268
+ label="🛡️ Security Analysis Results",
269
+ lines=20,
270
+ interactive=False,
271
+ show_copy_button=True
272
  )
273
 
274
+ # Example section with better examples
275
+ gr.Markdown("### 📝 Test Examples")
276
  with gr.Row():
277
+ legit_btn = gr.Button("✅ Legitimate Email", size="sm")
278
+ phish_btn = gr.Button("🚨 Phishing Email", size="sm")
279
+ neutral_btn = gr.Button("📄 Neutral Text", size="sm")
 
 
 
280
 
281
  # Event handlers
282
+ analyze_btn.click(predict_email, inputs=email_input, outputs=output)
283
+ clear_btn.click(lambda: ("", ""), outputs=[email_input, output])
 
 
 
 
 
 
 
 
 
 
 
 
 
284
 
285
+ legit_btn.click(lambda: example_legitimate, outputs=email_input)
286
+ phish_btn.click(lambda: example_phishing, outputs=email_input)
287
+ neutral_btn.click(lambda: example_neutral, outputs=email_input)
 
288
 
289
+ # Footer with model info
290
  gr.Markdown("""
291
  ---
292
+ **🔧 Model**: cybersectony/phishing-email-detection-distilbert_v2.4.1
293
+ **🎯 Features**: Temperature scaling, colored risk bars, enhanced debugging
294
+ **🏛️ Institution**: University of Dar es Salaam (UDSM)
295
  """)
296
 
 
297
  if __name__ == "__main__":
298
  iface.launch(
299
  share=True,
300
+ server_name="0.0.0.0",
301
  server_port=7860,
302
+ show_error=True,
303
+ debug=True
304
  )