Spaces:
Runtime error
Runtime error
File size: 11,423 Bytes
a50dd05 da7d74b 42dc091 62a31d8 42dc091 da7d74b 42dc091 62a31d8 42dc091 da7d74b 62a31d8 da7d74b 42dc091 62a31d8 42dc091 da7d74b 62a31d8 da7d74b 42dc091 62a31d8 da7d74b 42dc091 62a31d8 42dc091 62a31d8 42dc091 62a31d8 42dc091 62a31d8 42dc091 62a31d8 42dc091 62a31d8 42dc091 62a31d8 42dc091 62a31d8 42dc091 62a31d8 42dc091 62a31d8 42dc091 62a31d8 42dc091 62a31d8 42dc091 62a31d8 42dc091 62a31d8 42dc091 62a31d8 42dc091 62a31d8 42dc091 62a31d8 42dc091 62a31d8 42dc091 62a31d8 42dc091 62a31d8 42dc091 62a31d8 42dc091 62a31d8 42dc091 62a31d8 42dc091 62a31d8 42dc091 62a31d8 42dc091 62a31d8 42dc091 da7d74b 62a31d8 42dc091 62a31d8 42dc091 62a31d8 42dc091 62a31d8 42dc091 62a31d8 42dc091 62a31d8 42dc091 62a31d8 42dc091 62a31d8 42dc091 62a31d8 42dc091 62a31d8 42dc091 62a31d8 42dc091 62a31d8 42dc091 a50dd05 42dc091 62a31d8 42dc091 62a31d8 42dc091 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 |
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import logging
import numpy as np
# Configure module-wide logging at INFO level so model-load and prediction
# diagnostics are visible in the Space logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face Hub id of the fine-tuned DistilBERT phishing classifier.
MODEL_NAME = "cybersectony/phishing-email-detection-distilbert_v2.4.1"

# Global tokenizer/model handles. None until load_model() succeeds;
# predict_email() lazy-loads them on demand if the startup load failed.
tokenizer = None
model = None
def load_model():
    """Populate the module-level `tokenizer` and `model` globals.

    Downloads (or reads from cache) the tokenizer and sequence-classification
    model named by MODEL_NAME, logs the model configuration, and runs a tiny
    smoke-test inference so wiring problems surface at load time rather than
    on the first user request.

    Returns:
        True on success; False on any failure (the error is logged, not
        raised, so callers can degrade gracefully).
    """
    global tokenizer, model
    try:
        logger.info("Loading model and tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

        # Surface the label layout in the logs for debugging.
        logger.info(f"Model config: {model.config}")
        logger.info(f"Number of labels: {model.config.num_labels}")
        if hasattr(model.config, 'id2label'):
            logger.info(f"Label mapping: {model.config.id2label}")

        # Smoke-test: run one trivial input through the full pipeline.
        probe = tokenizer("Hello world", return_tensors="pt", truncation=True, max_length=512)
        with torch.no_grad():
            probe_outputs = model(**probe)
        probe_probs = torch.nn.functional.softmax(probe_outputs.logits, dim=-1)
        logger.info(f"Test probabilities: {probe_probs[0].tolist()}")

        logger.info("Model loaded successfully!")
        return True
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        return False
def get_colored_bar(percentage: float, label: str) -> str:
    """Render a 20-slot emoji progress bar for a class probability.

    Args:
        percentage: Class probability expressed as a percentage (0-100).
        label: Class label. Labels containing "phishing" or "suspicious"
            (case-insensitive) use a danger palette (red/orange/yellow);
            all others use a safety palette (green/yellow/white).

    Returns:
        A 20-character string: filled slots in the chosen color, the rest
        padded with white circles. Exactly 0% renders an empty (all-white)
        bar; any non-zero percentage shows at least one colored slot, and
        the fill is clamped to 20 slots so out-of-range inputs cannot
        overflow the bar.
    """
    lowered = label.lower()
    if "phishing" in lowered or "suspicious" in lowered:
        # Danger palette: red at high probability, cooling toward yellow.
        if percentage >= 70:
            color = "🟥"  # High danger - red
        elif percentage >= 40:
            color = "🟠"  # Medium danger - orange
        else:
            color = "🟡"  # Low danger - yellow
    else:
        # Safety palette: green at high confidence, fading to white.
        if percentage >= 70:
            color = "🟢"  # High confidence - green
        elif percentage >= 40:
            color = "🟡"  # Medium confidence - yellow
        else:
            color = "⚪"  # Low confidence - white
    # One slot per 5%. Show at least one colored slot only when the
    # probability is actually non-zero (the old code showed one even at 0%),
    # and clamp at 20 so percentages above 100 cannot produce an oversized
    # bar with negative padding.
    if percentage <= 0:
        bar_length = 0
    else:
        bar_length = min(20, max(1, int(percentage / 5)))
    return color * bar_length + "⚪" * (20 - bar_length)
def predict_email(email_text: str) -> str:
    """Classify pasted email text and return a Markdown-formatted risk report.

    Args:
        email_text: Raw email body text supplied by the user via the UI.

    Returns:
        A Markdown string containing a risk headline, the top classification
        with its confidence, per-class probability bars, debug information,
        and a recommendation. Validation problems and runtime errors are
        returned as Markdown strings as well — this function never raises,
        so the Gradio output Textbox always receives displayable text.
    """
    # Input validation: reject empty/whitespace-only and very short inputs.
    if not email_text or not email_text.strip():
        return "⚠️ **Error**: Please enter some email text to analyze."
    if len(email_text.strip()) < 5:
        return "⚠️ **Warning**: Email text too short for reliable analysis."
    # Lazy-load the model if the startup load failed or hasn't happened yet.
    if tokenizer is None or model is None:
        if not load_model():
            return "❌ **Error**: Failed to load the model."
    try:
        # Tokenize; input beyond 512 tokens is silently truncated.
        inputs = tokenizer(
            email_text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True
        )
        # Run inference without building a gradient graph.
        with torch.no_grad():
            outputs = model(**inputs)
        # Temperature scaling (T=1.5) softens the softmax distribution to
        # counter overconfidence. NOTE: the "confidence" reported below is
        # this scaled value, not the model's raw probability.
        temperature = 1.5
        scaled_logits = outputs.logits / temperature
        predictions = torch.nn.functional.softmax(scaled_logits, dim=-1)
        # Per-class probabilities for the single (batch size 1) input.
        probs = predictions[0].tolist()
        # Log raw outputs for debugging
        logger.info(f"Raw logits: {outputs.logits[0].tolist()}")
        logger.info(f"Scaled probabilities: {probs}")
        # Map class indices to human-readable labels. Prefer the model's own
        # id2label mapping (assumes int keys, the transformers convention —
        # TODO confirm for this checkpoint); otherwise fall back on fixed
        # label sets chosen by output width.
        if hasattr(model.config, 'id2label') and model.config.id2label:
            labels = {model.config.id2label[i]: probs[i] for i in range(len(probs))}
        else:
            if len(probs) == 2:
                labels = {
                    "Legitimate Email": probs[0],
                    "Phishing Email": probs[1]
                }
            elif len(probs) == 4:
                labels = {
                    "Legitimate Email": probs[0],
                    "Phishing Email": probs[1],
                    "Suspicious Content": probs[2],
                    "Spam Email": probs[3]
                }
            else:
                # Generic labels for an unexpected output width.
                labels = {f"Class {i}": probs[i] for i in range(len(probs))}
        # Calibration sanity check on the probability vector.
        prob_variance = np.var(probs)
        max_prob = max(probs)
        # NOTE(review): for the 2- or 4-class vectors handled above, a max
        # probability > 0.99 forces the variance well above 0.01, so both
        # conditions cannot hold at once and this smoothing branch appears
        # unreachable — confirm whether `or` was intended here.
        if prob_variance < 0.01 and max_prob > 0.99:
            logger.warning("Model showing signs of overconfidence or poor calibration")
            # Blend 80% model output with 20% uniform distribution.
            smoothed_probs = [(p * 0.8 + 0.2/len(probs)) for p in probs]
            labels = {list(labels.keys())[i]: smoothed_probs[i] for i in range(len(smoothed_probs))}
        # Winning class: the (label, probability) pair with highest probability.
        max_label = max(labels.items(), key=lambda x: x[1])
        confidence = max_label[1]
        prediction_name = max_label[0]
        # Map (class family, confidence) to a headline emoji and risk level.
        if any(word in prediction_name.lower() for word in ['phishing', 'suspicious', 'spam']):
            if confidence > 0.8:
                risk_emoji = "🚨"
                risk_level = "HIGH RISK"
            elif confidence > 0.6:
                risk_emoji = "⚠️"
                risk_level = "MEDIUM RISK"
            else:
                risk_emoji = "⚡"
                risk_level = "LOW RISK"
        else:
            if confidence > 0.8:
                risk_emoji = "✅"
                risk_level = "SAFE"
            elif confidence > 0.6:
                risk_emoji = "✅"
                risk_level = "LIKELY SAFE"
            else:
                risk_emoji = "❓"
                risk_level = "UNCERTAIN"
        # Format output with colored bars
        result = f"{risk_emoji} **{risk_level}**\n\n"
        result += f"**Primary Classification**: {prediction_name}\n"
        result += f"**Confidence**: {confidence:.1%}\n\n"
        result += f"**Detailed Analysis**:\n"
        # Sort classes by descending probability; render an emoji bar per class.
        for label, prob in sorted(labels.items(), key=lambda x: x[1], reverse=True):
            percentage = prob * 100
            colored_bar = get_colored_bar(percentage, label)
            result += f"{label}: {percentage:.1f}% {colored_bar}\n"
        # Add debug info
        result += f"\n**Debug Info**:\n"
        result += f"Model Variance: {prob_variance:.4f}\n"
        result += f"Raw Probabilities: {[f'{p:.3f}' for p in probs]}\n"
        # Recommendation text keyed off the winning class and its confidence.
        if any(word in prediction_name.lower() for word in ['phishing', 'suspicious']) and confidence > 0.6:
            result += f"\n⚠️ **Recommendation**: This email shows signs of being malicious. Avoid clicking links or providing personal information."
        elif 'spam' in prediction_name.lower():
            result += f"\n🗑️ **Recommendation**: This appears to be spam. Consider deleting or marking as junk."
        elif confidence > 0.7:
            result += f"\n✅ **Recommendation**: This email appears legitimate, but always remain vigilant."
        else:
            result += f"\n❓ **Recommendation**: Classification uncertain. Exercise caution and verify sender if needed."
        return result
    except Exception as e:
        logger.error(f"Error during prediction: {e}", exc_info=True)
        return f"❌ **Error**: Analysis failed - {str(e)}"
# Example emails wired to the three demo buttons in the UI below.

# Transactional order confirmation — fills the input via the
# "✅ Legitimate Email" button.
example_legitimate = """Dear Customer,
Thank you for your recent purchase from TechStore. Your order #ORD-2024-001234 has been successfully processed.
Order Details:
- Product: Wireless Headphones
- Amount: $79.99
- Estimated delivery: 3-5 business days
You will receive a tracking number once your item ships.
Best regards,
TechStore Customer Service"""

# Urgency/scare-tactic message with a suspicious link — fills the input via
# the "🚨 Phishing Email" button.
example_phishing = """URGENT SECURITY ALERT!!!
Your account has been COMPROMISED! Immediate action required!
Click here NOW to secure your account: http://fake-security-site.malicious.com/urgent-verify
WARNING: You have only 24 hours before your account is permanently suspended!
This is your FINAL notice - act immediately!
Security Department"""

# Mundane internal memo — fills the input via the "📄 Neutral Text" button.
example_neutral = """Hi team,
Hope everyone is doing well. Just wanted to remind you about the meeting scheduled for tomorrow at 2 PM in the conference room.
Please bring your project updates and any questions you might have.
Thanks,
Sarah"""
# Load the model once at import time so the first request doesn't pay the
# download/initialization cost (predict_email retries on demand if this fails).
load_model()

# Build the Gradio UI: input column + results column, demo buttons, footer.
with gr.Blocks(title="PhishGuardian AI", theme=gr.themes.Soft()) as iface:
    gr.Markdown("""
# 🛡️ PhishGuardian AI - Enhanced Detection
Advanced phishing email detection with colored risk indicators and improved model handling.
""")
    with gr.Row():
        # Left column: email input and action buttons.
        with gr.Column(scale=2):
            email_input = gr.Textbox(
                lines=10,
                placeholder="Paste your email content here for analysis...",
                label="📧 Email Content",
                info="Enter the complete email text for comprehensive analysis"
            )
            with gr.Row():
                analyze_btn = gr.Button("🔍 Analyze Email", variant="primary", size="lg")
                clear_btn = gr.Button("🗑️ Clear", variant="secondary")
        # Right column: read-only, copyable analysis report.
        with gr.Column(scale=2):
            output = gr.Textbox(
                label="🛡️ Security Analysis Results",
                lines=20,
                interactive=False,
                show_copy_button=True
            )
    # Demo buttons that fill the input with the canned examples above.
    gr.Markdown("### 📝 Test Examples")
    with gr.Row():
        legit_btn = gr.Button("✅ Legitimate Email", size="sm")
        phish_btn = gr.Button("🚨 Phishing Email", size="sm")
        neutral_btn = gr.Button("📄 Neutral Text", size="sm")
    # Event handlers
    analyze_btn.click(predict_email, inputs=email_input, outputs=output)
    # Clear returns a 2-tuple — one empty string per output component.
    clear_btn.click(lambda: ("", ""), outputs=[email_input, output])
    legit_btn.click(lambda: example_legitimate, outputs=email_input)
    phish_btn.click(lambda: example_phishing, outputs=email_input)
    neutral_btn.click(lambda: example_neutral, outputs=email_input)
    # Footer with model info
    gr.Markdown("""
---
**🔧 Model**: cybersectony/phishing-email-detection-distilbert_v2.4.1
**🎯 Features**: Temperature scaling, colored risk bars, enhanced debugging
**🏛️ Institution**: University of Dar es Salaam (UDSM)
""")
if __name__ == "__main__":
    # Launch configuration: create a public share link, bind all interfaces
    # on port 7860, surface server errors in the UI, enable debug logging.
    launch_options = {
        "share": True,
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
        "debug": True,
    }
    iface.launch(**launch_options)