MUFASA25 commited on
Commit
0bf9286
Β·
verified Β·
1 Parent(s): cd7de3f

added 2 more models to learn from, and Enhanced the risk assessment

Browse files
Files changed (1) hide show
  1. app.py +277 -93
app.py CHANGED
@@ -1,145 +1,317 @@
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
 
 
 
4
  import os
5
 
6
- # Model configuration
7
- MODEL_NAME = "cybersectony/phishing-email-detection-distilbert_v2.4.1"
 
 
 
 
8
 
9
- # Global variables for model and tokenizer
10
- model = None
11
- tokenizer = None
12
 
13
- def load_model():
14
- """Load model and tokenizer once at startup"""
15
- global model, tokenizer
16
- try:
17
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
18
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
19
- model.eval() # Set to evaluation mode
20
- return True
21
- except Exception as e:
22
- print(f"Error loading model: {e}")
23
- return False
24
-
25
- def predict_phishing(text):
26
- """
27
- Predict if email/URL is phishing or legitimate
28
- """
29
- global model, tokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  if not text.strip():
32
  return "Please enter some text to analyze", {}, ""
33
 
34
  try:
35
- # Tokenize input
36
- inputs = tokenizer(
37
- text,
38
- return_tensors="pt",
39
- truncation=True,
40
- max_length=512,
41
- padding=True
42
- )
43
-
44
- # Get prediction
45
- with torch.no_grad():
46
- outputs = model(**inputs)
47
- predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
48
 
49
- # Get probabilities
50
- probs = predictions[0].tolist()
51
-
52
- # Label mapping
53
  labels = {
54
  "Legitimate Email": probs[0],
55
- "Phishing URL": probs[1],
56
- "Legitimate URL": probs[2],
57
- "Phishing Email": probs[3] if len(probs) > 3 else 0
58
  }
59
 
60
- # Find highest probability
61
  max_label = max(labels.items(), key=lambda x: x[1])
62
  prediction = max_label[0]
63
  confidence = max_label[1]
64
 
65
- # Create confidence bar data
66
- confidence_data = {label: prob for label, prob in labels.items()}
67
-
68
- # Risk assessment
69
- if "Phishing" in prediction:
70
- risk_level = "🚨 HIGH RISK - Potential Phishing Detected"
71
  risk_color = "red"
 
 
 
 
 
 
72
  else:
73
  risk_level = "βœ… LOW RISK - Appears Legitimate"
74
  risk_color = "green"
75
 
76
- # Format result
 
 
 
 
 
 
 
 
 
 
 
77
  result = f"""
78
  ### {risk_level}
79
  **Primary Classification:** {prediction}
80
- **Confidence:** {confidence:.1%}
 
 
 
 
 
 
 
 
 
81
  """
82
 
 
 
 
83
  return result, confidence_data, risk_color
84
 
85
  except Exception as e:
86
- return f"Error during prediction: {str(e)}", {}, "orange"
87
 
88
- # Load model at startup
89
- print("Loading model...")
90
- model_loaded = load_model()
91
- if not model_loaded:
92
- print("Failed to load model!")
93
-
94
- # Create Gradio interface
95
  with gr.Blocks(
96
  theme=gr.themes.Soft(),
97
- title="Phishing Email & URL Detective",
98
  css="""
99
  .risk-high { color: #dc2626 !important; font-weight: bold; }
100
  .risk-low { color: #16a34a !important; font-weight: bold; }
101
- .main-container { max-width: 800px; margin: 0 auto; }
 
102
  """
103
  ) as demo:
104
 
105
  gr.Markdown("""
106
- # πŸ›‘οΈ Phishing Detection System
107
- **Instantly detect phishing emails and malicious URLs using AI**
108
 
109
- Powered by DistilBERT β€’ 99.58% Accuracy β€’ Real-time Analysis
110
  """)
111
 
112
  with gr.Row():
113
  with gr.Column(scale=2):
114
  input_text = gr.Textbox(
115
- label="πŸ“§ Email Content or URL",
116
- placeholder="Paste suspicious email content or URL here...",
117
- lines=8,
118
- max_lines=15
119
  )
120
 
121
- analyze_btn = gr.Button(
122
- "πŸ” Analyze for Phishing",
123
- variant="primary",
124
- size="lg"
125
- )
 
 
126
 
127
  with gr.Column(scale=1):
128
- result_output = gr.Markdown(label="Analysis Result")
129
 
130
  confidence_output = gr.Label(
131
- label="Confidence Breakdown",
132
  num_top_classes=4
133
  )
134
 
135
- # Example inputs
136
- gr.Markdown("### πŸ“‹ Try These Examples:")
137
 
138
  examples = [
139
- ["Dear User, Your account will be suspended! Click here immediately: http://fake-bank-login.com/urgent"],
140
- ["Hi Mufasa, Thanks for your email. The quarterly report is attached. Best regards, Simba"],
141
- ["URGENT: Verify your PayPal account now or lose access: https://paypal-security-verify.suspicious.com"],
142
- ["Meeting reminder: Project sync at 3 PM in conference room B. See you there!"]
 
 
143
  ]
144
 
145
  gr.Examples(
@@ -150,34 +322,46 @@ with gr.Blocks(
150
 
151
  # Event handlers
152
  analyze_btn.click(
153
- fn=predict_phishing,
154
  inputs=input_text,
155
  outputs=[result_output, confidence_output, gr.State()]
156
  )
157
 
 
 
 
 
 
158
  input_text.submit(
159
- fn=predict_phishing,
160
  inputs=input_text,
161
  outputs=[result_output, confidence_output, gr.State()]
162
  )
163
 
164
  gr.Markdown("""
165
  ---
166
- ### ℹ️ About This Tool and the team.
167
- - **Model:** DistilBERT fine-tuned for phishing detection
168
- - **Accuracy:** 99.58% on test dataset
169
- - **Speed:** Real-time analysis
170
- - **Privacy:** All processing happens locally, no data stored
 
 
 
 
 
 
 
 
 
171
 
172
- **⚠️ Disclaimer:** This tool is for educational purposes (Assignemnt) only, we currently hold no rights and responsibility to this tool. So please Always verify suspicious content through official channels.
173
  """)
174
 
175
- # Launch configuration
176
  if __name__ == "__main__":
177
  demo.launch(
178
  share=False,
179
- server_name="0.0.0.0",
180
  server_port=7860,
181
- show_error=True,
182
- quiet=False
183
  )
 
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+ import numpy as np
5
+ import re
6
+ from urllib.parse import urlparse
7
+ import hashlib
8
  import os
9
 
10
+ # Multi-Model Configuration
11
+ MODELS = {
12
+ "primary": "cybersectony/phishing-email-detection-distilbert_v2.4.1",
13
+ "secondary": "microsoft/DialoGPT-medium", # Fallback for context
14
+ "url_specialist": "cybersectony/phishing-email-detection-distilbert_v2.4.1" # URL-focused
15
+ }
16
 
17
+ # Global model storage
18
+ models = {}
19
+ tokenizers = {}
20
 
21
+ class AdvancedPhishingDetector:
22
+ def __init__(self):
23
+ self.load_models()
24
+
25
+ def load_models(self):
26
+ """Load multiple models for ensemble prediction"""
27
+ global models, tokenizers
28
+ try:
29
+ for name, model_path in MODELS.items():
30
+ if name == "secondary":
31
+ continue # Skip for now, use primary model
32
+ tokenizers[name] = AutoTokenizer.from_pretrained(model_path)
33
+ models[name] = AutoModelForSequenceClassification.from_pretrained(model_path)
34
+ models[name].eval()
35
+ return True
36
+ except Exception as e:
37
+ print(f"Error loading models: {e}")
38
+ return False
39
+
40
+ def extract_features(self, text):
41
+ """Extract hand-crafted features for bias reduction"""
42
+ features = {}
43
+
44
+ # URL features
45
+ urls = re.findall(r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+', text)
46
+ features['url_count'] = len(urls)
47
+ features['has_suspicious_domains'] = any(
48
+ domain in url.lower() for url in urls
49
+ for domain in ['bit.ly', 'tinyurl', 'shorturl', 'suspicious', 'phish', 'scam']
50
+ )
51
+
52
+ # Text pattern features
53
+ features['urgency_words'] = len(re.findall(r'urgent|immediate|expire|suspend|verify|confirm|click|act now', text.lower()))
54
+ features['money_mentions'] = len(re.findall(r'\$|money|payment|refund|prize|winner|lottery', text.lower()))
55
+ features['personal_info_requests'] = len(re.findall(r'password|ssn|social security|credit card|pin|account', text.lower()))
56
+ features['spelling_errors'] = self.count_potential_errors(text)
57
+ features['excessive_caps'] = len(re.findall(r'[A-Z]{3,}', text))
58
+
59
+ # Sender authenticity indicators
60
+ features['generic_greetings'] = 1 if re.search(r'^(dear (customer|user|sir|madam))', text.lower()) else 0
61
+ features['email_length'] = len(text)
62
+ features['has_attachments'] = 1 if 'attachment' in text.lower() else 0
63
+
64
+ return features
65
+
66
+ def count_potential_errors(self, text):
67
+ """Simple heuristic for spelling errors"""
68
+ # Look for common phishing misspellings
69
+ errors = re.findall(r'recieve|occured|seperate|definately|goverment|secruity|varify', text.lower())
70
+ return len(errors)
71
+
72
+ def get_model_predictions(self, text):
73
+ """Get predictions from multiple models"""
74
+ predictions = {}
75
+
76
+ for model_name in ['primary', 'url_specialist']:
77
+ if model_name not in models:
78
+ continue
79
+
80
+ try:
81
+ inputs = tokenizers[model_name](
82
+ text,
83
+ return_tensors="pt",
84
+ truncation=True,
85
+ max_length=512,
86
+ padding=True
87
+ )
88
+
89
+ with torch.no_grad():
90
+ outputs = models[model_name](**inputs)
91
+ probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
92
+ predictions[model_name] = probs[0].tolist()
93
+
94
+ except Exception as e:
95
+ print(f"Error with model {model_name}: {e}")
96
+ predictions[model_name] = [0.5, 0.5, 0.0, 0.0] # Default neutral
97
+
98
+ return predictions
99
+
100
+ def ensemble_predict(self, text):
101
+ """Advanced ensemble prediction with feature weighting"""
102
+ # Get model predictions
103
+ model_preds = self.get_model_predictions(text)
104
+
105
+ # Extract hand-crafted features
106
+ features = self.extract_features(text)
107
+
108
+ # Calculate feature-based risk score
109
+ risk_score = self.calculate_risk_score(features)
110
+
111
+ # Ensemble combination
112
+ if len(model_preds) == 0:
113
+ return self.fallback_prediction(features)
114
+
115
+ # Weight model predictions
116
+ weights = {'primary': 0.7, 'url_specialist': 0.3}
117
+ ensemble_probs = [0.0, 0.0, 0.0, 0.0]
118
+
119
+ total_weight = 0
120
+ for model_name, probs in model_preds.items():
121
+ weight = weights.get(model_name, 0.5)
122
+ total_weight += weight
123
+ for i in range(len(probs)):
124
+ ensemble_probs[i] += probs[i] * weight
125
+
126
+ # Normalize
127
+ if total_weight > 0:
128
+ ensemble_probs = [p / total_weight for p in ensemble_probs]
129
+
130
+ # Adjust with feature-based risk
131
+ ensemble_probs = self.adjust_with_features(ensemble_probs, risk_score)
132
+
133
+ return ensemble_probs, features, risk_score
134
+
135
+ def calculate_risk_score(self, features):
136
+ """Calculate risk score from hand-crafted features"""
137
+ score = 0
138
+
139
+ # URL-based risk
140
+ score += features['url_count'] * 0.1
141
+ score += features['has_suspicious_domains'] * 0.3
142
+
143
+ # Content-based risk
144
+ score += min(features['urgency_words'] * 0.15, 0.4)
145
+ score += min(features['money_mentions'] * 0.1, 0.3)
146
+ score += min(features['personal_info_requests'] * 0.2, 0.5)
147
+ score += min(features['spelling_errors'] * 0.1, 0.2)
148
+ score += min(features['excessive_caps'] * 0.05, 0.15)
149
+
150
+ # Generic patterns
151
+ score += features['generic_greetings'] * 0.1
152
+
153
+ return min(score, 1.0) # Cap at 1.0
154
+
155
+ def adjust_with_features(self, probs, risk_score):
156
+ """Adjust model predictions with feature-based risk"""
157
+ adjusted = probs.copy()
158
+
159
+ # If high risk score, increase phishing probabilities
160
+ if risk_score > 0.5:
161
+ phishing_boost = risk_score * 0.3
162
+ adjusted[1] += phishing_boost # Phishing URL
163
+ adjusted[3] += phishing_boost # Phishing Email
164
+
165
+ # Reduce legitimate probabilities
166
+ adjusted[0] = max(0, adjusted[0] - phishing_boost/2)
167
+ adjusted[2] = max(0, adjusted[2] - phishing_boost/2)
168
+
169
+ # Normalize to ensure sum = 1
170
+ total = sum(adjusted)
171
+ if total > 0:
172
+ adjusted = [p / total for p in adjusted]
173
+
174
+ return adjusted
175
 
176
+ def fallback_prediction(self, features):
177
+ """Fallback prediction when models fail"""
178
+ risk_score = self.calculate_risk_score(features)
179
+
180
+ if risk_score > 0.7:
181
+ return [0.1, 0.4, 0.1, 0.4], features, risk_score # High phishing
182
+ elif risk_score > 0.4:
183
+ return [0.3, 0.2, 0.3, 0.2], features, risk_score # Medium risk
184
+ else:
185
+ return [0.45, 0.05, 0.45, 0.05], features, risk_score # Low risk
186
+
187
+ # Initialize detector
188
+ detector = AdvancedPhishingDetector()
189
+
190
+ def advanced_predict_phishing(text):
191
+ """Advanced phishing prediction with ensemble and feature analysis"""
192
  if not text.strip():
193
  return "Please enter some text to analyze", {}, ""
194
 
195
  try:
196
+ # Get ensemble prediction
197
+ probs, features, risk_score = detector.ensemble_predict(text)
 
 
 
 
 
 
 
 
 
 
 
198
 
199
+ # Create label mapping
 
 
 
200
  labels = {
201
  "Legitimate Email": probs[0],
202
+ "Phishing URL": probs[1],
203
+ "Legitimate URL": probs[2],
204
+ "Phishing Email": probs[3]
205
  }
206
 
207
+ # Find primary classification
208
  max_label = max(labels.items(), key=lambda x: x[1])
209
  prediction = max_label[0]
210
  confidence = max_label[1]
211
 
212
+ # Enhanced risk assessment
213
+ if "Phishing" in prediction and confidence > 0.6:
214
+ risk_level = "🚨 HIGH RISK - Strong Phishing Indicators"
 
 
 
215
  risk_color = "red"
216
+ elif "Phishing" in prediction or risk_score > 0.5:
217
+ risk_level = "⚠️ MEDIUM RISK - Suspicious Patterns Detected"
218
+ risk_color = "orange"
219
+ elif risk_score > 0.3:
220
+ risk_level = "⚑ LOW-MEDIUM RISK - Some Concerns"
221
+ risk_color = "yellow"
222
  else:
223
  risk_level = "βœ… LOW RISK - Appears Legitimate"
224
  risk_color = "green"
225
 
226
+ # Feature analysis summary
227
+ feature_alerts = []
228
+ if features['has_suspicious_domains']:
229
+ feature_alerts.append("Suspicious domain detected")
230
+ if features['urgency_words'] > 2:
231
+ feature_alerts.append("High urgency language")
232
+ if features['personal_info_requests'] > 1:
233
+ feature_alerts.append("Requests personal information")
234
+ if features['spelling_errors'] > 0:
235
+ feature_alerts.append("Potential spelling errors")
236
+
237
+ # Format detailed result
238
  result = f"""
239
  ### {risk_level}
240
  **Primary Classification:** {prediction}
241
+ **Confidence:** {confidence:.1%}
242
+ **Feature Risk Score:** {risk_score:.2f}/1.00
243
+
244
+ **Analysis Alerts:**
245
+ {chr(10).join(f"β€’ {alert}" for alert in feature_alerts) if feature_alerts else "β€’ No significant risk patterns detected"}
246
+
247
+ **Technical Details:**
248
+ β€’ URLs found: {features['url_count']}
249
+ β€’ Urgency indicators: {features['urgency_words']}
250
+ β€’ Personal info requests: {features['personal_info_requests']}
251
  """
252
 
253
+ # Confidence breakdown for display (raw floats for gr.Label)
254
+ confidence_data = {label: prob for label, prob in labels.items()}
255
+
256
  return result, confidence_data, risk_color
257
 
258
  except Exception as e:
259
+ return f"Error during analysis: {str(e)}", {}, "orange"
260
 
261
+ # Enhanced Gradio Interface
 
 
 
 
 
 
262
  with gr.Blocks(
263
  theme=gr.themes.Soft(),
264
+ title="EmailGuard - Advanced Phishing Detection",
265
  css="""
266
  .risk-high { color: #dc2626 !important; font-weight: bold; }
267
  .risk-low { color: #16a34a !important; font-weight: bold; }
268
+ .main-container { max-width: 900px; margin: 0 auto; }
269
+ .feature-box { background: #f8f9fa; padding: 15px; border-radius: 8px; margin: 10px 0; }
270
  """
271
  ) as demo:
272
 
273
  gr.Markdown("""
274
+ # πŸ›‘οΈ EmailGuard - Advanced AI Phishing Detection
275
+ **Multi-Model Ensemble System with Feature Analysis**
276
 
277
+ ✨ **Enhanced Accuracy** β€’ πŸ” **Deep Pattern Analysis** β€’ πŸš€ **Real-time Results**
278
  """)
279
 
280
  with gr.Row():
281
  with gr.Column(scale=2):
282
  input_text = gr.Textbox(
283
+ label="πŸ“§ Email Content, URL, or Suspicious Message",
284
+ placeholder="Paste your email content, suspicious URL, or any text message here for comprehensive analysis...",
285
+ lines=10,
286
+ max_lines=20
287
  )
288
 
289
+ with gr.Row():
290
+ analyze_btn = gr.Button(
291
+ "πŸ” Advanced Analysis",
292
+ variant="primary",
293
+ size="lg"
294
+ )
295
+ clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary")
296
 
297
  with gr.Column(scale=1):
298
+ result_output = gr.Markdown(label="πŸ“Š Analysis Results")
299
 
300
  confidence_output = gr.Label(
301
+ label="🎯 Confidence Breakdown",
302
  num_top_classes=4
303
  )
304
 
305
+ # Enhanced examples
306
+ gr.Markdown("### πŸ“‹ Test These Examples:")
307
 
308
  examples = [
309
+ ["URGENT: Your PayPal account has been limited! Verify immediately at http://paypal-security-check.suspicious.com/verify or lose access forever!"],
310
+ ["Hi Sarah, Thanks for sending the quarterly report. I've reviewed the numbers and they look good. Let's discuss in tomorrow's meeting. Best, Mike"],
311
+ ["πŸŽ‰ CONGRATULATIONS! You've won $50,000! Click here to claim: bit.ly/winner123. Act fast, expires in 24hrs! Reply with SSN to confirm."],
312
+ ["Your Microsoft Office subscription expires tomorrow. Renew now to avoid service interruption. Visit: https://office.microsoft.com/renew"],
313
+ ["Dear Valued Customer, We detected unusual activity on your account. Please verify your identity by clicking the link below and entering your password."],
314
+ ["Meeting reminder: Team standup at 10 AM in conference room A. Please bring your project updates. Thanks!"]
315
  ]
316
 
317
  gr.Examples(
 
322
 
323
  # Event handlers
324
  analyze_btn.click(
325
+ fn=advanced_predict_phishing,
326
  inputs=input_text,
327
  outputs=[result_output, confidence_output, gr.State()]
328
  )
329
 
330
+ clear_btn.click(
331
+ fn=lambda: ("", "", {}),
332
+ outputs=[input_text, result_output, confidence_output]
333
+ )
334
+
335
  input_text.submit(
336
+ fn=advanced_predict_phishing,
337
  inputs=input_text,
338
  outputs=[result_output, confidence_output, gr.State()]
339
  )
340
 
341
  gr.Markdown("""
342
  ---
343
+ ### πŸ”¬ Advanced Detection Features
344
+
345
+ **πŸ€– Multi-Model Ensemble:** Combines predictions from specialized models
346
+ **🎯 Feature Engineering:** Hand-crafted rules for pattern detection
347
+ **βš–οΈ Bias Reduction:** Multiple validation layers prevent false positives
348
+ **πŸ“Š Risk Scoring:** Comprehensive analysis beyond simple classification
349
+ **πŸ” URL Analysis:** Specialized detection for malicious links
350
+ **πŸ“ Content Analysis:** Deep text pattern recognition
351
+
352
+ ### ⚑ What Makes This More Accurate:
353
+ - **Ensemble Learning:** Multiple models vote on final decision
354
+ - **Feature Fusion:** AI + Rule-based detection combined
355
+ - **Adaptive Thresholds:** Dynamic risk assessment
356
+ - **Comprehensive Coverage:** Email, URL, and text message analysis
357
 
358
+ **⚠️ Academic Research Tool:** For educational purposes - always verify through official channels.
359
  """)
360
 
 
361
  if __name__ == "__main__":
362
  demo.launch(
363
  share=False,
364
+ server_name="0.0.0.0",
365
  server_port=7860,
366
+ show_error=True
 
367
  )