import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import logging
import numpy as np

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model configuration
MODEL_NAME = "cybersectony/phishing-email-detection-distilbert_v2.4.1"

# Global variables for model and tokenizer
tokenizer = None
model = None

def load_model():
    """Load the model and tokenizer with error handling"""
    global tokenizer, model
    try:
        logger.info("Loading model and tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
        
        # Debug: Print model configuration
        logger.info(f"Model config: {model.config}")
        logger.info(f"Number of labels: {model.config.num_labels}")
        if hasattr(model.config, 'id2label'):
            logger.info(f"Label mapping: {model.config.id2label}")
        
        # Test model with simple input to check if it's working
        test_input = "Hello world"
        inputs = tokenizer(test_input, return_tensors="pt", truncation=True, max_length=512)
        with torch.no_grad():
            outputs = model(**inputs)
            test_probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
            logger.info(f"Test probabilities: {test_probs[0].tolist()}")
        
        logger.info("Model loaded successfully!")
        return True
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        return False

def get_colored_bar(percentage, label):
    """Create colored progress bar based on percentage and label type"""
    # Determine color based on percentage and label
    if "phishing" in label.lower() or "suspicious" in label.lower():
        # Red scale for dangerous content
        if percentage >= 70:
            color = "🟥"  # High danger - red
        elif percentage >= 40:
            color = "🟠"  # Medium danger - orange  
        else:
            color = "🟡"  # Low danger - yellow
    else:
        # Green scale for legitimate content
        if percentage >= 70:
            color = "🟢"  # High confidence - green
        elif percentage >= 40:
            color = "🟡"  # Medium confidence - yellow
        else:
            color = "⚪"  # Low confidence - white
    
    # Create bar (20 segments, one per 5 percentage points)
    bar_length = int(percentage / 5)
    if percentage > 0:
        bar_length = max(1, bar_length)  # show at least one segment for any nonzero score
    bar = color * bar_length + "⚪" * (20 - bar_length)
    
    return bar
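# Example: get_colored_bar(75.0, "Phishing Email") returns 15 red segments
# followed by 5 white ones ("🟥" * 15 + "⚪" * 5).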

def predict_email(email_text):
    """
    Enhanced prediction function with proper model output handling
    """
    # Input validation
    if not email_text or not email_text.strip():
        return "⚠️ **Error**: Please enter some email text to analyze."
    
    if len(email_text.strip()) < 5:
        return "⚠️ **Warning**: Email text too short for reliable analysis."
    
    # Check if model is loaded
    if tokenizer is None or model is None:
        if not load_model():
            return "❌ **Error**: Failed to load the model."
    
    try:
        # Preprocess and tokenize
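        # truncation=True silently caps input at the 512-token limit set below,
        # so very long emails are classified from their opening text only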
        inputs = tokenizer(
            email_text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True
        )
        
        # Get prediction with proper handling
        with torch.no_grad():
            outputs = model(**inputs)
            # Apply temperature scaling to prevent overconfidence
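            # Dividing logits by T > 1 flattens the softmax distribution without
            # changing the argmax; T = 1.5 is a heuristic choice here, not a
            # value calibrated on held-out data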
            temperature = 1.5
            scaled_logits = outputs.logits / temperature
            predictions = torch.nn.functional.softmax(scaled_logits, dim=-1)
        
        # Get probabilities
        probs = predictions[0].tolist()
        
        # Log raw outputs for debugging
        logger.info(f"Raw logits: {outputs.logits[0].tolist()}")
        logger.info(f"Scaled probabilities: {probs}")
        
        # Get proper labels from model config or use fallback
        if hasattr(model.config, 'id2label') and model.config.id2label:
            labels = {model.config.id2label[i]: probs[i] for i in range(len(probs))}
        else:
            # Fallback when the config carries no label names; these mappings
            # assume an index order that may not match the checkpoint
            if len(probs) == 2:
                labels = {
                    "Legitimate Email": probs[0],
                    "Phishing Email": probs[1]
                }
            elif len(probs) == 4:
                labels = {
                    "Legitimate Email": probs[0],
                    "Phishing Email": probs[1],
                    "Suspicious Content": probs[2],
                    "Spam Email": probs[3]
                }
            else:
                # Generic labels
                labels = {f"Class {i}": probs[i] for i in range(len(probs))}
        
        # Check if model is giving reasonable outputs
        prob_variance = np.var(probs)
        max_prob = max(probs)
        
        # Near-uniform outputs (tiny variance) or an overconfident spike
        # (max probability above 0.99) both suggest poor calibration, so
        # either condition alone triggers smoothing
        if prob_variance < 0.01 or max_prob > 0.99:
            logger.warning("Model showing signs of overconfidence or poor calibration")
            # Blend 80% of the model output with a 20% uniform prior (simple smoothing)
            smoothed_probs = [(p * 0.8 + 0.2/len(probs)) for p in probs]
            labels = {list(labels.keys())[i]: smoothed_probs[i] for i in range(len(smoothed_probs))}
        
        # Find prediction
        max_label = max(labels.items(), key=lambda x: x[1])
        
        # Determine risk level and emoji
        confidence = max_label[1]
        prediction_name = max_label[0]
        
        if any(word in prediction_name.lower() for word in ['phishing', 'suspicious', 'spam']):
            if confidence > 0.8:
                risk_emoji = "🚨"
                risk_level = "HIGH RISK"
            elif confidence > 0.6:
                risk_emoji = "⚠️"
                risk_level = "MEDIUM RISK"
            else:
                risk_emoji = "⚡"
                risk_level = "LOW RISK"
        else:
            if confidence > 0.8:
                risk_emoji = "✅"
                risk_level = "SAFE"
            elif confidence > 0.6:
                risk_emoji = "✅"
                risk_level = "LIKELY SAFE"
            else:
                risk_emoji = "❓"
                risk_level = "UNCERTAIN"
        
        # Format output with colored bars
        result = f"{risk_emoji} **{risk_level}**\n\n"
        result += f"**Primary Classification**: {prediction_name}\n"
        result += f"**Confidence**: {confidence:.1%}\n\n"
        result += f"**Detailed Analysis**:\n"
        
        # Sort by probability and add colored bars
        for label, prob in sorted(labels.items(), key=lambda x: x[1], reverse=True):
            percentage = prob * 100
            colored_bar = get_colored_bar(percentage, label)
            result += f"{label}: {percentage:.1f}% {colored_bar}\n"
        
        # Add debug info
        result += f"\n**Debug Info**:\n"
        result += f"Model Variance: {prob_variance:.4f}\n"
        result += f"Raw Probabilities: {[f'{p:.3f}' for p in probs]}\n"
        
        # Add recommendations based on actual classification
        if any(word in prediction_name.lower() for word in ['phishing', 'suspicious']) and confidence > 0.6:
            result += f"\n⚠️ **Recommendation**: This email shows signs of being malicious. Avoid clicking links or providing personal information."
        elif 'spam' in prediction_name.lower():
            result += f"\n🗑️ **Recommendation**: This appears to be spam. Consider deleting or marking as junk."
        elif confidence > 0.7:
            result += f"\n✅ **Recommendation**: This email appears legitimate, but always remain vigilant."
        else:
            result += f"\n❓ **Recommendation**: Classification uncertain. Exercise caution and verify sender if needed."
            
        return result
        
    except Exception as e:
        logger.error(f"Error during prediction: {e}", exc_info=True)
        return f"❌ **Error**: Analysis failed - {str(e)}"

# Example emails for testing
example_legitimate = """Dear Customer,

Thank you for your recent purchase from TechStore. Your order #ORD-2024-001234 has been successfully processed.

Order Details:
- Product: Wireless Headphones  
- Amount: $79.99
- Estimated delivery: 3-5 business days

You will receive a tracking number once your item ships.

Best regards,
TechStore Customer Service"""

example_phishing = """URGENT SECURITY ALERT!!!

Your account has been COMPROMISED! Immediate action required!

Click here NOW to secure your account: http://fake-security-site.malicious.com/urgent-verify

WARNING: You have only 24 hours before your account is permanently suspended!

This is your FINAL notice - act immediately!

Security Department"""

example_neutral = """Hi team,

Hope everyone is doing well. Just wanted to remind you about the meeting scheduled for tomorrow at 2 PM in the conference room.

Please bring your project updates and any questions you might have.

Thanks,
Sarah"""

# Load model on startup
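# (if this fails, predict_email() retries lazily on the first request)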
load_model()

# Create enhanced Gradio interface
with gr.Blocks(title="PhishGuardian AI", theme=gr.themes.Soft()) as iface:
    gr.Markdown("""
    # 🛡️ PhishGuardian AI - Enhanced Detection
    
    Advanced phishing email detection with colored risk indicators and improved model handling.
    """)
    
    with gr.Row():
        with gr.Column(scale=2):
            email_input = gr.Textbox(
                lines=10,
                placeholder="Paste your email content here for analysis...",
                label="📧 Email Content",
                info="Enter the complete email text for comprehensive analysis"
            )
            
            with gr.Row():
                analyze_btn = gr.Button("🔍 Analyze Email", variant="primary", size="lg")
                clear_btn = gr.Button("🗑️ Clear", variant="secondary")
        
        with gr.Column(scale=2):
            output = gr.Textbox(
                label="🛡️ Security Analysis Results",
                lines=20,
                interactive=False,
                show_copy_button=True
            )
    
    # Example section with better examples
    gr.Markdown("### 📝 Test Examples")
    with gr.Row():
        legit_btn = gr.Button("✅ Legitimate Email", size="sm")
        phish_btn = gr.Button("🚨 Phishing Email", size="sm") 
        neutral_btn = gr.Button("📄 Neutral Text", size="sm")
    
    # Event handlers
    analyze_btn.click(predict_email, inputs=email_input, outputs=output)
    clear_btn.click(lambda: ("", ""), outputs=[email_input, output])
    
    legit_btn.click(lambda: example_legitimate, outputs=email_input)
    phish_btn.click(lambda: example_phishing, outputs=email_input)
    neutral_btn.click(lambda: example_neutral, outputs=email_input)
    
    # Footer with model info
    gr.Markdown("""
    ---
    **🔧 Model**: cybersectony/phishing-email-detection-distilbert_v2.4.1  
    **🎯 Features**: Temperature scaling, colored risk bars, enhanced debugging  
    **🏛️ Institution**: University of Dar es Salaam (UDSM)
    """)

if __name__ == "__main__":
    iface.launch(
        share=True,
        server_name="0.0.0.0", 
        server_port=7860,
        show_error=True,
        debug=True
    )