MUFASA25 committed
Commit 09b3e31 · verified · 1 Parent(s): cc8ce94

simplest form

Files changed (1):
  1. app.py +91 -222
app.py CHANGED
Old version (removed lines are marked with "-"):

@@ -2,91 +2,67 @@ import gradio as gr
  import torch
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
  import logging
- import numpy as np

- # Configure logging
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

  # Model configuration
  MODEL_NAME = "cybersectony/phishing-email-detection-distilbert_v2.4.1"

  # Global variables for model and tokenizer
  tokenizer = None
  model = None

  def load_model():
-     """Load the model and tokenizer with error handling"""
      global tokenizer, model
      try:
-         logger.info("Loading model and tokenizer...")
          tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
          model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
-
-         # Debug: Print model configuration
-         logger.info(f"Model config: {model.config}")
-         logger.info(f"Number of labels: {model.config.num_labels}")
-         if hasattr(model.config, 'id2label'):
-             logger.info(f"Label mapping: {model.config.id2label}")
-
-         # Test model with simple input to check if it's working
-         test_input = "Hello world"
-         inputs = tokenizer(test_input, return_tensors="pt", truncation=True, max_length=512)
-         with torch.no_grad():
-             outputs = model(**inputs)
-             test_probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
-         logger.info(f"Test probabilities: {test_probs[0].tolist()}")
-
          logger.info("Model loaded successfully!")
          return True
      except Exception as e:
          logger.error(f"Error loading model: {e}")
          return False

- def get_colored_bar(percentage, label):
-     """Create colored progress bar based on percentage and label type"""
-     # Determine color based on percentage and label
-     if "phishing" in label.lower() or "suspicious" in label.lower():
-         # Red scale for dangerous content
-         if percentage >= 70:
-             color = "🟥" # High danger - red
-         elif percentage >= 40:
-             color = "🟠" # Medium danger - orange
-         else:
-             color = "🟡" # Low danger - yellow
-     else:
-         # Green scale for legitimate content
-         if percentage >= 70:
-             color = "🟢" # High confidence - green
-         elif percentage >= 40:
-             color = "🟡" # Medium confidence - yellow
-         else:
-             color = "⚪" # Low confidence - white
-
-     # Create bar (scale to 20 characters)
-     bar_length = max(1, int(percentage / 5)) # Ensure at least 1 if percentage > 0
-     bar = color * bar_length + "⚪" * (20 - bar_length)
-
-     return bar

  def predict_email(email_text):
-     """
-     Enhanced prediction function with proper model output handling
-     """
      # Input validation
-     if not email_text or not email_text.strip():
-         return "⚠️ **Error**: Please enter some email text to analyze."
-
-     if len(email_text.strip()) < 5:
-         return "⚠️ **Warning**: Email text too short for reliable analysis."

      # Check if model is loaded
      if tokenizer is None or model is None:
          if not load_model():
-             return "❌ **Error**: Failed to load the model."

      try:
-         # Preprocess and tokenize
          inputs = tokenizer(
              email_text,
              return_tensors="pt",
@@ -95,210 +71,103 @@ def predict_email(email_text):
              padding=True
          )

-         # Get prediction with proper handling
          with torch.no_grad():
              outputs = model(**inputs)
-             # Apply temperature scaling to prevent overconfidence
-             temperature = 1.5
-             scaled_logits = outputs.logits / temperature
-             predictions = torch.nn.functional.softmax(scaled_logits, dim=-1)

-         # Get probabilities
-         probs = predictions[0].tolist()

-         # Log raw outputs for debugging
-         logger.info(f"Raw logits: {outputs.logits[0].tolist()}")
-         logger.info(f"Scaled probabilities: {probs}")

-         # Get proper labels from model config or use fallback
-         if hasattr(model.config, 'id2label') and model.config.id2label:
-             labels = {model.config.id2label[i]: probs[i] for i in range(len(probs))}
-         else:
-             # Fallback - check the actual model output dimension
-             if len(probs) == 2:
-                 labels = {
-                     "Legitimate Email": probs[0],
-                     "Phishing Email": probs[1]
-                 }
-             elif len(probs) == 4:
-                 labels = {
-                     "Legitimate Email": probs[0],
-                     "Phishing Email": probs[1],
-                     "Suspicious Content": probs[2],
-                     "Spam Email": probs[3]
-                 }
-             else:
-                 # Generic labels
-                 labels = {f"Class {i}": probs[i] for i in range(len(probs))}
-
-         # Check if model is giving reasonable outputs
-         prob_variance = np.var(probs)
-         max_prob = max(probs)
-
-         # If variance is too low, the model might not be working properly
-         if prob_variance < 0.01 and max_prob > 0.99:
-             logger.warning("Model showing signs of overconfidence or poor calibration")
-             # Apply smoothing
-             smoothed_probs = [(p * 0.8 + 0.2/len(probs)) for p in probs]
-             labels = {list(labels.keys())[i]: smoothed_probs[i] for i in range(len(smoothed_probs))}
-
-         # Find prediction
-         max_label = max(labels.items(), key=lambda x: x[1])
-
-         # Determine risk level and emoji
-         confidence = max_label[1]
-         prediction_name = max_label[0]

-         if any(word in prediction_name.lower() for word in ['phishing', 'suspicious', 'spam']):
-             if confidence > 0.8:
-                 risk_emoji = "🚨"
-                 risk_level = "HIGH RISK"
-             elif confidence > 0.6:
-                 risk_emoji = "⚠️"
-                 risk_level = "MEDIUM RISK"
-             else:
-                 risk_emoji = "⚡"
-                 risk_level = "LOW RISK"
          else:
-             if confidence > 0.8:
-                 risk_emoji = "✅"
-                 risk_level = "SAFE"
-             elif confidence > 0.6:
-                 risk_emoji = ""
-                 risk_level = "LIKELY SAFE"
-             else:
-                 risk_emoji = ""
-                 risk_level = "UNCERTAIN"
-
-         # Format output with colored bars
-         result = f"{risk_emoji} **{risk_level}**\n\n"
-         result += f"**Primary Classification**: {prediction_name}\n"
-         result += f"**Confidence**: {confidence:.1%}\n\n"
-         result += f"**Detailed Analysis**:\n"
-
-         # Sort by probability and add colored bars
-         for label, prob in sorted(labels.items(), key=lambda x: x[1], reverse=True):
-             percentage = prob * 100
-             colored_bar = get_colored_bar(percentage, label)
-             result += f"{label}: {percentage:.1f}% {colored_bar}\n"
-
-         # Add debug info
-         result += f"\n**Debug Info**:\n"
-         result += f"Model Variance: {prob_variance:.4f}\n"
-         result += f"Raw Probabilities: {[f'{p:.3f}' for p in probs]}\n"
-
-         # Add recommendations based on actual classification
-         if any(word in prediction_name.lower() for word in ['phishing', 'suspicious']) and confidence > 0.6:
-             result += f"\n⚠️ **Recommendation**: This email shows signs of being malicious. Avoid clicking links or providing personal information."
-         elif 'spam' in prediction_name.lower():
-             result += f"\n🗑️ **Recommendation**: This appears to be spam. Consider deleting or marking as junk."
-         elif confidence > 0.7:
-             result += f"\n✅ **Recommendation**: This email appears legitimate, but always remain vigilant."
          else:
-             result += f"\n❓ **Recommendation**: Classification uncertain. Exercise caution and verify sender if needed."
-
-         return result

      except Exception as e:
-         logger.error(f"Error during prediction: {e}", exc_info=True)
-         return f"❌ **Error**: Analysis failed - {str(e)}"

- # Example emails for testing
  example_legitimate = """Dear Customer,
-
- Thank you for your recent purchase from TechStore. Your order #ORD-2024-001234 has been successfully processed.
-
  Order Details:
- - Product: Wireless Headphones
  - Amount: $79.99
- - Estimated delivery: 3-5 business days
-
- You will receive a tracking number once your item ships.
-
  Best regards,
- TechStore Customer Service"""
-
- example_phishing = """URGENT SECURITY ALERT!!!
-
- Your account has been COMPROMISED! Immediate action required!
-
- Click here NOW to secure your account: http://fake-security-site.malicious.com/urgent-verify
-
- WARNING: You have only 24 hours before your account is permanently suspended!
-
- This is your FINAL notice - act immediately!
-
- Security Department"""
-
  example_neutral = """Hi team,
-
- Hope everyone is doing well. Just wanted to remind you about the meeting scheduled for tomorrow at 2 PM in the conference room.
-
- Please bring your project updates and any questions you might have.
-
  Thanks,
- Sarah"""

  # Load model on startup
  load_model()

- # Create enhanced Gradio interface
- with gr.Blocks(title="PhishGuardian AI", theme=gr.themes.Soft()) as iface:
-     gr.Markdown("""
-     # 🛡️ PhishGuardian AI - Enhanced Detection
-
-     Advanced phishing email detection with colored risk indicators and improved model handling.
-     """)

      with gr.Row():
          with gr.Column(scale=2):
              email_input = gr.Textbox(
-                 lines=10,
-                 placeholder="Paste your email content here for analysis...",
-                 label="📧 Email Content",
-                 info="Enter the complete email text for comprehensive analysis"
              )
-
              with gr.Row():
-                 analyze_btn = gr.Button("🔍 Analyze Email", variant="primary", size="lg")
-                 clear_btn = gr.Button("🗑️ Clear", variant="secondary")

          with gr.Column(scale=2):
              output = gr.Textbox(
-                 label="🛡️ Security Analysis Results",
-                 lines=20,
                  interactive=False,
                  show_copy_button=True
              )

-     # Example section with better examples
-     gr.Markdown("### 📝 Test Examples")
      with gr.Row():
-         legit_btn = gr.Button("✅ Legitimate Email", size="sm")
-         phish_btn = gr.Button("🚨 Phishing Email", size="sm")
-         neutral_btn = gr.Button("📄 Neutral Text", size="sm")
-
-     # Event handlers
      analyze_btn.click(predict_email, inputs=email_input, outputs=output)
      clear_btn.click(lambda: ("", ""), outputs=[email_input, output])
-
-     legit_btn.click(lambda: example_legitimate, outputs=email_input)
-     phish_btn.click(lambda: example_phishing, outputs=email_input)
-     neutral_btn.click(lambda: example_neutral, outputs=email_input)
-
-     # Footer with model info
-     gr.Markdown("""
-     ---
-     **🔧 Model**: cybersectony/phishing-email-detection-distilbert_v2.4.1
-     **🎯 Features**: Temperature scaling, colored risk bars, enhanced debugging
-     **🏛️ Institution**: University of Dar es Salaam (UDSM)
-     """)

  if __name__ == "__main__":
-     iface.launch(
-         share=True,
-         server_name="0.0.0.0",
-         server_port=7860,
-         show_error=True,
-         debug=True
-     )
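The removed `predict_email` body above divided the logits by a fixed temperature of 1.5 before the softmax to soften overconfident predictions; the new version drops this and uses the raw softmax instead. As a standalone illustration of that calibration trick (not part of this repository), a minimal sketch might look like:

import torch
import torch.nn.functional as F

def scaled_probabilities(logits: torch.Tensor, temperature: float = 1.5) -> torch.Tensor:
    # Dividing logits by a temperature > 1 flattens the softmax distribution;
    # temperature == 1.0 reproduces the plain softmax.
    return F.softmax(logits / temperature, dim=-1)

# Hypothetical two-class logits from an overconfident classifier
logits = torch.tensor([[4.0, 0.5]])
print(scaled_probabilities(logits, temperature=1.0))  # roughly [[0.97, 0.03]]
print(scaled_probabilities(logits, temperature=1.5))  # flatter, roughly [[0.91, 0.09]]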
 
New version (added lines are marked with "+"):

  import torch
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
  import logging
+ import re

+ # Configure logging (minimal)
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

  # Model configuration
  MODEL_NAME = "cybersectony/phishing-email-detection-distilbert_v2.4.1"

+ # Explanation of labels and their values
+ """
+ Labels and Their Meanings:
+ - Legitimate: The email appears safe and is likely from a trusted source.
+ - Phishing: The email may be a scam attempting to steal personal information.
+ - Suspicious: The email has questionable content and may not be safe.
+ - Spam: The email is likely unwanted promotional or junk content.
+ Each label comes with a percentage (0-100%) indicating the model's confidence.
+ Higher percentages mean the model is more certain of the classification.
+ """
+
  # Global variables for model and tokenizer
  tokenizer = None
  model = None

  def load_model():
+     """Load the model and tokenizer with basic error handling"""
      global tokenizer, model
      try:
          tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
          model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
          logger.info("Model loaded successfully!")
          return True
      except Exception as e:
          logger.error(f"Error loading model: {e}")
          return False

+ def is_valid_email_text(text):
+     """Basic validation for email-like text"""
+     if not text or not text.strip():
+         return False, "Please enter some email text."
+     if len(text.strip()) < 10:
+         return False, "Text too short for analysis."
+     # Check for basic email-like structure or meaningful words
+     if len(text.split()) < 3 or not re.search(r"[a-zA-Z]{3,}", text):
+         return False, "Text appears incoherent or not email-like."
+     return True, ""

  def predict_email(email_text):
+     """Simplified prediction function with clear output"""
      # Input validation
+     valid, message = is_valid_email_text(email_text)
+     if not valid:
+         return f"⚠️ Error: {message}"

      # Check if model is loaded
      if tokenizer is None or model is None:
          if not load_model():
+             return "❌ Error: Failed to load the model."

      try:
+         # Tokenize input
          inputs = tokenizer(
              email_text,
              return_tensors="pt",
@@ -95,210 +71,103 @@ def predict_email(email_text):
              padding=True
          )

+         # Get prediction
          with torch.no_grad():
              outputs = model(**inputs)
+         probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0].tolist()

+         # Get labels from model config or fallback
+         labels = (model.config.id2label if hasattr(model.config, 'id2label') and model.config.id2label
+                   else {0: "Legitimate", 1: "Phishing", 2: "Suspicious", 3: "Spam"} if len(probs) == 4
+                   else {0: "Legitimate", 1: "Phishing"})

+         # Map probabilities to labels
+         results = {labels[i]: probs[i] * 100 for i in range(len(probs))}

+         # Get top prediction
+         max_label, max_prob = max(results.items(), key=lambda x: x[1])

+         # Simplified risk levels
+         if "phishing" in max_label.lower() or "suspicious" in max_label.lower():
+             risk_level = "⚠️ Risky" if max_prob > 60 else "⚡ Low Risk"
+         elif "spam" in max_label.lower():
+             risk_level = "🗑️ Spam"
          else:
+             risk_level = "✅ Safe" if max_prob > 60 else "❓ Uncertain"
+
+         # Format output
+         output = f"Result: {risk_level}\n"
+         output += f"Top Prediction: {max_label} ({max_prob:.1f}%)\n"
+         output += "Details:\n"
+         for label, prob in sorted(results.items(), key=lambda x: x[1], reverse=True):
+             output += f"{label}: {prob:.1f}%\n"
+
+         # Simple recommendation
+         if "phishing" in max_label.lower() or "suspicious" in max_label.lower():
+             output += "Advice: Avoid clicking links or sharing info."
+         elif "spam" in max_label.lower():
+             output += "Advice: Mark as spam or delete."
          else:
+             output += "Advice: Appears safe, but stay cautious."

+         return output
+
      except Exception as e:
+         logger.error(f"Error during prediction: {e}")
+         return f"❌ Error: Analysis failed - {str(e)}"

+ # Example emails
  example_legitimate = """Dear Customer,
+ Thank you for your purchase from TechStore. Your order #ORD-2024-001234 is processed.
  Order Details:
+ - Product: Wireless Headphones
  - Amount: $79.99
+ - Delivery: 3-5 days
  Best regards,
+ TechStore"""
+ example_phishing = """URGENT!!!
+ Your account is COMPROMISED! Click here to secure: http://fake-site.com/verify
+ Act NOW or your account will be suspended!
+ Security Team"""
  example_neutral = """Hi team,
+ Reminder: meeting today at 10 PM. Bring project updates.
  Thanks,
+ Byabato"""

  # Load model on startup
  load_model()

+ # Minimalist Gradio interface
+ with gr.Blocks(title="PhishGuardian", theme=gr.themes.Soft()) as iface:
+     gr.Markdown("# 🛡️ PhishGuardian\nSimple email safety checker.\n\nCheck if an email is safe or risky. Paste the email text and click 'Check'.")

      with gr.Row():
          with gr.Column(scale=2):
              email_input = gr.Textbox(
+                 lines=8,
+                 placeholder="Paste email here...",
+                 label="📧 Email"
              )
              with gr.Row():
+                 analyze_btn = gr.Button("🔍 Check", variant="primary")
+                 clear_btn = gr.Button("🗑️ Clear")

          with gr.Column(scale=2):
              output = gr.Textbox(
+                 label=" Results",
+                 lines=10,
                  interactive=False,
                  show_copy_button=True
              )

+     gr.Markdown("### 📝 Examples")
      with gr.Row():
+         gr.Button("✅ Legitimate", size="sm").click(lambda: example_legitimate, outputs=email_input)
+         gr.Button("🚨 Phishing", size="sm").click(lambda: example_phishing, outputs=email_input)
+         gr.Button("📄 Neutral", size="sm").click(lambda: example_neutral, outputs=email_input)
+
      analyze_btn.click(predict_email, inputs=email_input, outputs=output)
      clear_btn.click(lambda: ("", ""), outputs=[email_input, output])

  if __name__ == "__main__":
+     iface.launch(server_port=7860, show_error=True)
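For reference, the same checkpoint can be queried outside the Gradio UI. This is a minimal sketch, assuming torch and transformers are installed and the model can be downloaded; the label names come from the checkpoint's own config rather than the app's fallback mapping:

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

MODEL_NAME = "cybersectony/phishing-email-detection-distilbert_v2.4.1"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

text = "URGENT!!! Your account is COMPROMISED! Click here to verify."
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding=True)
with torch.no_grad():
    probs = torch.nn.functional.softmax(model(**inputs).logits, dim=-1)[0]

# id2label is populated for most Hugging Face checkpoints; fall back to generic names otherwise
id2label = getattr(model.config, "id2label", None) or {i: f"Class {i}" for i in range(len(probs))}
for i, p in enumerate(probs.tolist()):
    print(f"{id2label[i]}: {p:.1%}")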