import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import logging
import numpy as np

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model configuration
MODEL_NAME = "cybersectony/phishing-email-detection-distilbert_v2.4.1"

# Global variables for the model and tokenizer
tokenizer = None
model = None


def load_model():
    """Load the model and tokenizer with error handling."""
    global tokenizer, model
    try:
        logger.info("Loading model and tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

        # Debug: print the model configuration
        logger.info(f"Model config: {model.config}")
        logger.info(f"Number of labels: {model.config.num_labels}")
        if hasattr(model.config, 'id2label'):
            logger.info(f"Label mapping: {model.config.id2label}")

        # Smoke-test the model with a simple input to check that it works
        test_input = "Hello world"
        inputs = tokenizer(test_input, return_tensors="pt", truncation=True, max_length=512)
        with torch.no_grad():
            outputs = model(**inputs)
            test_probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
            logger.info(f"Test probabilities: {test_probs[0].tolist()}")

        logger.info("Model loaded successfully!")
        return True
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        return False
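
# For reference, the label order can be inspected without downloading the full
# model weights; a minimal sketch (not executed by this app):
#
#   from transformers import AutoConfig
#   cfg = AutoConfig.from_pretrained(MODEL_NAME)
#   print(cfg.num_labels, cfg.id2label)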


def get_colored_bar(percentage, label):
    """Create a colored progress bar based on the percentage and label type."""
    # Determine the color from the percentage and label
    if "phishing" in label.lower() or "suspicious" in label.lower():
        # Red scale for dangerous content
        if percentage >= 70:
            color = "🟥"  # High danger - red
        elif percentage >= 40:
            color = "🟠"  # Medium danger - orange
        else:
            color = "🟡"  # Low danger - yellow
    else:
        # Green scale for legitimate content
        if percentage >= 70:
            color = "🟢"  # High confidence - green
        elif percentage >= 40:
            color = "🟡"  # Medium confidence - yellow
        else:
            color = "⚪"  # Low confidence - white

    # Build the bar (scaled to 20 characters)
    bar_length = max(1, int(percentage / 5))  # Always show at least one segment
    bar = color * bar_length + "⚪" * (20 - bar_length)
    return bar
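
# Example outputs, given the thresholds above:
#   get_colored_bar(75.0, "Phishing Email")   -> "🟥" * 15 + "⚪" * 5
#   get_colored_bar(30.0, "Legitimate Email") -> "⚪" * 20 (the low-confidence
#   green-scale color matches the empty-track character, so that bar renders flat)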


def predict_email(email_text):
    """
    Enhanced prediction function with proper model output handling.
    """
    # Input validation
    if not email_text or not email_text.strip():
        return "⚠️ **Error**: Please enter some email text to analyze."
    if len(email_text.strip()) < 5:
        return "⚠️ **Warning**: Email text is too short for reliable analysis."

    # Check whether the model is loaded
    if tokenizer is None or model is None:
        if not load_model():
            return "❌ **Error**: Failed to load the model."

    try:
        # Preprocess and tokenize
        inputs = tokenizer(
            email_text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True
        )
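
        # Note: DistilBERT accepts at most 512 tokens, so longer emails are
        # truncated and only their beginning is analyzed.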

        # Get the prediction
        with torch.no_grad():
            outputs = model(**inputs)
            # Apply temperature scaling to prevent overconfidence
            temperature = 1.5
            scaled_logits = outputs.logits / temperature
            predictions = torch.nn.functional.softmax(scaled_logits, dim=-1)
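
            # For reference: temperature scaling computes softmax(logits / T);
            # T > 1 flattens the distribution and T = 1 is plain softmax.
            # E.g. logits [4.0, 0.0] give roughly [0.982, 0.018] at T = 1.0
            # but [0.935, 0.065] at T = 1.5.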

        # Get the probabilities
        probs = predictions[0].tolist()

        # Log raw outputs for debugging
        logger.info(f"Raw logits: {outputs.logits[0].tolist()}")
        logger.info(f"Scaled probabilities: {probs}")

        # Get proper labels from the model config, or fall back to defaults
        if hasattr(model.config, 'id2label') and model.config.id2label:
            labels = {model.config.id2label[i]: probs[i] for i in range(len(probs))}
        else:
            # Fallback - infer labels from the actual output dimension
            if len(probs) == 2:
                labels = {
                    "Legitimate Email": probs[0],
                    "Phishing Email": probs[1]
                }
            elif len(probs) == 4:
                labels = {
                    "Legitimate Email": probs[0],
                    "Phishing Email": probs[1],
                    "Suspicious Content": probs[2],
                    "Spam Email": probs[3]
                }
            else:
                # Generic labels
                labels = {f"Class {i}": probs[i] for i in range(len(probs))}

        # Sanity-check the output distribution
        prob_variance = np.var(probs)
        max_prob = max(probs)

        # A near-flat distribution or a saturated maximum suggests the model
        # is poorly calibrated or overconfident
        if prob_variance < 0.01 or max_prob > 0.99:
            logger.warning("Model showing signs of overconfidence or poor calibration")
            # Apply smoothing
            smoothed_probs = [(p * 0.8 + 0.2 / len(probs)) for p in probs]
            labels = {list(labels.keys())[i]: smoothed_probs[i] for i in range(len(smoothed_probs))}
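
        # The smoothing above is a convex mix with the uniform distribution:
        # p' = 0.8 * p + 0.2 / K over K classes, so it still sums to 1 while
        # pulling extreme probabilities toward 1 / K.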

        # Find the top prediction
        max_label = max(labels.items(), key=lambda x: x[1])

        # Determine the risk level and emoji
        confidence = max_label[1]
        prediction_name = max_label[0]

        if any(word in prediction_name.lower() for word in ['phishing', 'suspicious', 'spam']):
            if confidence > 0.8:
                risk_emoji = "🚨"
                risk_level = "HIGH RISK"
            elif confidence > 0.6:
                risk_emoji = "⚠️"
                risk_level = "MEDIUM RISK"
            else:
                risk_emoji = "⚡"
                risk_level = "LOW RISK"
        else:
            if confidence > 0.8:
                risk_emoji = "✅"
                risk_level = "SAFE"
            elif confidence > 0.6:
                risk_emoji = "✅"
                risk_level = "LIKELY SAFE"
            else:
                risk_emoji = "❓"
                risk_level = "UNCERTAIN"

        # Format the output with colored bars
        result = f"{risk_emoji} **{risk_level}**\n\n"
        result += f"**Primary Classification**: {prediction_name}\n"
        result += f"**Confidence**: {confidence:.1%}\n\n"
        result += "**Detailed Analysis**:\n"

        # Sort by probability and add colored bars
        for label, prob in sorted(labels.items(), key=lambda x: x[1], reverse=True):
            percentage = prob * 100
            colored_bar = get_colored_bar(percentage, label)
            result += f"{label}: {percentage:.1f}% {colored_bar}\n"

        # Add debug info (the displayed probabilities are temperature-scaled)
        result += "\n**Debug Info**:\n"
        result += f"Model Variance: {prob_variance:.4f}\n"
        result += f"Scaled Probabilities: {[f'{p:.3f}' for p in probs]}\n"

        # Add a recommendation based on the actual classification
        if any(word in prediction_name.lower() for word in ['phishing', 'suspicious']) and confidence > 0.6:
            result += "\n⚠️ **Recommendation**: This email shows signs of being malicious. Avoid clicking links or providing personal information."
        elif 'spam' in prediction_name.lower():
            result += "\n🗑️ **Recommendation**: This appears to be spam. Consider deleting it or marking it as junk."
        elif confidence > 0.7:
            result += "\n✅ **Recommendation**: This email appears legitimate, but always remain vigilant."
        else:
            result += "\n❓ **Recommendation**: Classification uncertain. Exercise caution and verify the sender if needed."

        return result

    except Exception as e:
        logger.error(f"Error during prediction: {e}", exc_info=True)
        return f"❌ **Error**: Analysis failed - {str(e)}"


# Example emails for testing
example_legitimate = """Dear Customer,

Thank you for your recent purchase from TechStore. Your order #ORD-2024-001234 has been successfully processed.

Order Details:
- Product: Wireless Headphones
- Amount: $79.99
- Estimated delivery: 3-5 business days

You will receive a tracking number once your item ships.

Best regards,
TechStore Customer Service"""

example_phishing = """URGENT SECURITY ALERT!!!

Your account has been COMPROMISED! Immediate action required!

Click here NOW to secure your account: http://fake-security-site.malicious.com/urgent-verify

WARNING: You have only 24 hours before your account is permanently suspended!

This is your FINAL notice - act immediately!

Security Department"""

example_neutral = """Hi team,

Hope everyone is doing well. Just wanted to remind you about the meeting scheduled for tomorrow at 2 PM in the conference room.

Please bring your project updates and any questions you might have.

Thanks,
Sarah"""

# Load the model on startup
load_model()
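# (If this startup load fails, predict_email() retries lazily on the first request.)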

# Create the enhanced Gradio interface
with gr.Blocks(title="PhishGuardian AI", theme=gr.themes.Soft()) as iface:
    gr.Markdown("""
    # 🛡️ PhishGuardian AI - Enhanced Detection
    Advanced phishing email detection with colored risk indicators and improved model handling.
    """)

    with gr.Row():
        with gr.Column(scale=2):
            email_input = gr.Textbox(
                lines=10,
                placeholder="Paste your email content here for analysis...",
                label="📧 Email Content",
                info="Enter the complete email text for comprehensive analysis"
            )
            with gr.Row():
                analyze_btn = gr.Button("🔍 Analyze Email", variant="primary", size="lg")
                clear_btn = gr.Button("🗑️ Clear", variant="secondary")
        with gr.Column(scale=2):
            output = gr.Textbox(
                label="🛡️ Security Analysis Results",
                lines=20,
                interactive=False,
                show_copy_button=True
            )

    # Example section
    gr.Markdown("### 📝 Test Examples")
    with gr.Row():
        legit_btn = gr.Button("✅ Legitimate Email", size="sm")
        phish_btn = gr.Button("🚨 Phishing Email", size="sm")
        neutral_btn = gr.Button("📄 Neutral Text", size="sm")

    # Event handlers
    analyze_btn.click(predict_email, inputs=email_input, outputs=output)
    clear_btn.click(lambda: ("", ""), outputs=[email_input, output])
    legit_btn.click(lambda: example_legitimate, outputs=email_input)
    phish_btn.click(lambda: example_phishing, outputs=email_input)
    neutral_btn.click(lambda: example_neutral, outputs=email_input)
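
    # An alternative to the three example buttons would be Gradio's built-in
    # examples widget; a sketch, untested here:
    #
    #   gr.Examples(
    #       examples=[example_legitimate, example_phishing, example_neutral],
    #       inputs=email_input,
    #   )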

    # Footer with model info
    gr.Markdown("""
    ---
    **🔧 Model**: cybersectony/phishing-email-detection-distilbert_v2.4.1
    **🎯 Features**: Temperature scaling, colored risk bars, enhanced debugging
    **🏛️ Institution**: University of Dar es Salaam (UDSM)
    """)


if __name__ == "__main__":
    iface.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        debug=True
    )