MUFASA25's picture
model responsive
62a31d8 verified
raw
history blame
11.4 kB
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import logging
import numpy as np
# Route this module's INFO-and-above log records through the root logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)

# Hugging Face Hub identifier of the sequence-classification checkpoint to load.
MODEL_NAME = "cybersectony/phishing-email-detection-distilbert_v2.4.1"

# Shared tokenizer/model handles; start as None and are assigned by load_model().
tokenizer = model = None
def load_model():
    """Load the tokenizer and classifier for MODEL_NAME and publish them as globals.

    Both objects are loaded into locals, their configuration is logged, and a
    one-sentence smoke test ("Hello world") is run through them. Only when every
    step succeeds are the module-level ``tokenizer`` and ``model`` replaced, so a
    failure part-way through never leaves the globals half-initialised (e.g. a
    tokenizer without a model, or a model that failed its smoke test).

    Returns:
        bool: True on success; False if any step raised (the exception and its
        traceback are logged).
    """
    global tokenizer, model
    try:
        logger.info("Loading model and tokenizer...")
        new_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        new_model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

        # Debug: record the configuration and label mapping the app relies on.
        logger.info("Model config: %s", new_model.config)
        logger.info("Number of labels: %s", new_model.config.num_labels)
        if hasattr(new_model.config, "id2label"):
            logger.info("Label mapping: %s", new_model.config.id2label)

        # Smoke test: push one tiny input end to end before accepting the model.
        inputs = new_tokenizer(
            "Hello world", return_tensors="pt", truncation=True, max_length=512
        )
        with torch.no_grad():
            outputs = new_model(**inputs)
        test_probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
        logger.info("Test probabilities: %s", test_probs[0].tolist())
    except Exception as exc:
        # Top-level boundary for startup/lazy loading: report with traceback and
        # let the caller decide how to surface the failure.
        logger.exception("Error loading model: %s", exc)
        return False

    # Publish both handles together, only after the smoke test passed.
    tokenizer, model = new_tokenizer, new_model
    logger.info("Model loaded successfully!")
    return True
def get_colored_bar(percentage, label, width=20):
    """Render a probability (in percent) as a fixed-width bar of emoji segments.

    Args:
        percentage: Probability expressed in percent. Values are truncated to
            whole segments; anything above 100 fills the bar and anything at or
            below 0 leaves it empty.
        label: Class label. Labels containing "phishing", "suspicious" or
            "spam" (case-insensitive) use the danger palette
            (🟥 >= 70, 🟠 >= 40, else 🟡); all other labels use the confidence
            palette (🟢 >= 70, 🟡 >= 40, else ⚪).
        width: Total number of segments in the bar (default 20, i.e. 5% each).

    Returns:
        str: Exactly ``width`` segments: the coloured fill followed by "⚪" padding.

    Raises:
        ValueError: If ``width`` is not positive.
    """
    if width <= 0:
        raise ValueError("width must be a positive integer")

    lowered = label.lower()
    # Keep this keyword list in sync with predict_email's risk keywords, which
    # also treat spam as a risky class.
    if any(word in lowered for word in ("phishing", "suspicious", "spam")):
        if percentage >= 70:
            color = "🟥"  # High danger - red
        elif percentage >= 40:
            color = "🟠"  # Medium danger - orange
        else:
            color = "🟡"  # Low danger - yellow
    else:
        if percentage >= 70:
            color = "🟢"  # High confidence - green
        elif percentage >= 40:
            color = "🟡"  # Medium confidence - yellow
        else:
            color = "⚪"  # Low confidence - white

    # Truncate to whole segments, show at least one segment for any non-zero
    # probability, and never draw more than `width` segments.
    segment_size = 100 / width
    filled = int(percentage / segment_size)
    if percentage > 0:
        filled = max(1, filled)
    filled = min(width, max(0, filled))
    return color * filled + "⚪" * (width - filled)
def _label_probabilities(probs, id2label):
    """Pair each class probability with a display label, preserving class order.

    Uses the model's ``id2label`` mapping when it is present and non-empty;
    otherwise falls back to fixed names for 2- and 4-class heads and to generic
    ``Class i`` names for any other class count.
    """
    if id2label:
        return {id2label[i]: p for i, p in enumerate(probs)}
    if len(probs) == 2:
        names = ["Legitimate Email", "Phishing Email"]
    elif len(probs) == 4:
        names = [
            "Legitimate Email",
            "Phishing Email",
            "Suspicious Content",
            "Spam Email",
        ]
    else:
        names = [f"Class {i}" for i in range(len(probs))]
    return dict(zip(names, probs))


def _risk_assessment(prediction_name, confidence):
    """Return ``(emoji, level)`` for the winning label and its probability."""
    name = prediction_name.lower()
    if any(word in name for word in ("phishing", "suspicious", "spam")):
        if confidence > 0.8:
            return "🚨", "HIGH RISK"
        if confidence > 0.6:
            return "⚠️", "MEDIUM RISK"
        return "⚡", "LOW RISK"
    if confidence > 0.8:
        return "✅", "SAFE"
    if confidence > 0.6:
        return "✅", "LIKELY SAFE"
    return "❓", "UNCERTAIN"


def _recommendation(prediction_name, confidence):
    """Return the closing recommendation paragraph (prefixed with a newline)."""
    name = prediction_name.lower()
    if any(word in name for word in ("phishing", "suspicious")) and confidence > 0.6:
        return "\n⚠️ **Recommendation**: This email shows signs of being malicious. Avoid clicking links or providing personal information."
    if "spam" in name:
        return "\n🗑️ **Recommendation**: This appears to be spam. Consider deleting or marking as junk."
    if confidence > 0.7:
        return "\n✅ **Recommendation**: This email appears legitimate, but always remain vigilant."
    return "\n❓ **Recommendation**: Classification uncertain. Exercise caution and verify sender if needed."


def predict_email(email_text):
    """Classify an email and return a Markdown-formatted risk report.

    Empty/whitespace-only input and input shorter than 5 non-whitespace-trimmed
    characters are rejected with a message before the model is touched. If the
    model is not loaded yet, ``load_model()`` is attempted once per call.

    Args:
        email_text: Raw email text (``None`` is treated as empty).

    Returns:
        str: The report (risk level, top label, per-class bars, debug values and
        a recommendation), or an error/warning message. Exceptions raised during
        tokenisation/inference are logged and reported in the returned string.
    """
    # Input validation
    if not email_text or not email_text.strip():
        return "⚠️ **Error**: Please enter some email text to analyze."
    if len(email_text.strip()) < 5:
        return "⚠️ **Warning**: Email text too short for reliable analysis."

    # Lazily (re)load the model if startup loading did not succeed.
    if tokenizer is None or model is None:
        if not load_model():
            return "❌ **Error**: Failed to load the model."

    try:
        inputs = tokenizer(
            email_text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True,
        )
        with torch.no_grad():
            outputs = model(**inputs)

        # Temperature > 1 flattens the softmax to damp overconfident logits.
        temperature = 1.5
        scaled = torch.nn.functional.softmax(outputs.logits / temperature, dim=-1)
        probs = scaled[0].tolist()
        logger.info("Raw logits: %s", outputs.logits[0].tolist())
        logger.info("Scaled probabilities: %s", probs)

        labels = _label_probabilities(probs, getattr(model.config, "id2label", None))

        prob_variance = np.var(probs)
        # Overconfidence guard. It previously also required variance < 0.01,
        # which cannot hold together with a class above 0.99 for 2- or 4-class
        # heads, so the smoothing never ran. Blend toward uniform instead.
        if max(probs) > 0.99:
            logger.warning("Model showing signs of overconfidence or poor calibration")
            n_classes = len(probs)
            labels = {
                name: p * 0.8 + 0.2 / n_classes for name, p in labels.items()
            }

        prediction_name, confidence = max(labels.items(), key=lambda item: item[1])
        risk_emoji, risk_level = _risk_assessment(prediction_name, confidence)

        parts = [
            f"{risk_emoji} **{risk_level}**\n\n",
            f"**Primary Classification**: {prediction_name}\n",
            f"**Confidence**: {confidence:.1%}\n\n",
            "**Detailed Analysis**:\n",
        ]
        for label, prob in sorted(labels.items(), key=lambda item: item[1], reverse=True):
            percentage = prob * 100
            parts.append(f"{label}: {percentage:.1f}% {get_colored_bar(percentage, label)}\n")

        # Debug values are shown before any smoothing (temperature-scaled only).
        parts.append("\n**Debug Info**:\n")
        parts.append(f"Model Variance: {prob_variance:.4f}\n")
        parts.append(f"Raw Probabilities: {[f'{p:.3f}' for p in probs]}\n")
        parts.append(_recommendation(prediction_name, confidence))
        return "".join(parts)
    except Exception as exc:
        # UI boundary: log with traceback, show a readable message instead of crashing.
        logger.exception("Error during prediction: %s", exc)
        return f"❌ **Error**: Analysis failed - {exc}"
# Sample texts loaded into the input box by the "Test Examples" buttons.
# Each sample is built from its lines joined with "\n" (no trailing newline).

# Intended as a benign transactional email (order confirmation).
example_legitimate = "\n".join([
    "Dear Customer,",
    "Thank you for your recent purchase from TechStore. Your order #ORD-2024-001234 has been successfully processed.",
    "Order Details:",
    "- Product: Wireless Headphones",
    "- Amount: $79.99",
    "- Estimated delivery: 3-5 business days",
    "You will receive a tracking number once your item ships.",
    "Best regards,",
    "TechStore Customer Service",
])

# Intended as a classic phishing lure: urgency, threats and a suspicious link.
example_phishing = "\n".join([
    "URGENT SECURITY ALERT!!!",
    "Your account has been COMPROMISED! Immediate action required!",
    "Click here NOW to secure your account: http://fake-security-site.malicious.com/urgent-verify",
    "WARNING: You have only 24 hours before your account is permanently suspended!",
    "This is your FINAL notice - act immediately!",
    "Security Department",
])

# Intended as an ordinary internal message with no commercial or threat content.
example_neutral = "\n".join([
    "Hi team,",
    "Hope everyone is doing well. Just wanted to remind you about the meeting scheduled for tomorrow at 2 PM in the conference room.",
    "Please bring your project updates and any questions you might have.",
    "Thanks,",
    "Sarah",
])
# Load the model eagerly when the module is imported, so the first analysis
# request does not pay the loading cost. The boolean result is deliberately
# ignored: on failure load_model() logs the error itself, and predict_email()
# retries the load lazily while the tokenizer/model globals are still None.
load_model()
# Create enhanced Gradio interface.
# Components appear in the order they are created inside the nested Row/Column
# contexts, so the statement order below defines the page layout.
with gr.Blocks(title="PhishGuardian AI", theme=gr.themes.Soft()) as iface:
    # Page header.
    gr.Markdown("""
# 🛡️ PhishGuardian AI - Enhanced Detection
Advanced phishing email detection with colored risk indicators and improved model handling.
""")
    # Two equal-width columns: input + actions on the left, report on the right.
    with gr.Row():
        with gr.Column(scale=2):
            email_input = gr.Textbox(
                lines=10,
                placeholder="Paste your email content here for analysis...",
                label="📧 Email Content",
                info="Enter the complete email text for comprehensive analysis"
            )
            with gr.Row():
                analyze_btn = gr.Button("🔍 Analyze Email", variant="primary", size="lg")
                clear_btn = gr.Button("🗑️ Clear", variant="secondary")
        with gr.Column(scale=2):
            # NOTE(review): predict_email() returns Markdown (e.g. **bold**);
            # a Textbox presumably shows that markup as raw text. Consider a
            # gr.Markdown output if rendered formatting is wanted — verify
            # against the Gradio version in use.
            output = gr.Textbox(
                label="🛡️ Security Analysis Results",
                lines=20,
                interactive=False,
                show_copy_button=True
            )
    # Example section with better examples
    gr.Markdown("### 📝 Test Examples")
    with gr.Row():
        legit_btn = gr.Button("✅ Legitimate Email", size="sm")
        phish_btn = gr.Button("🚨 Phishing Email", size="sm")
        neutral_btn = gr.Button("📄 Neutral Text", size="sm")
    # Event handlers
    analyze_btn.click(predict_email, inputs=email_input, outputs=output)
    # One empty string per output component: clears both the input and the report.
    clear_btn.click(lambda: ("", ""), outputs=[email_input, output])
    # Example buttons only fill the input box; the user still clicks Analyze.
    legit_btn.click(lambda: example_legitimate, outputs=email_input)
    phish_btn.click(lambda: example_phishing, outputs=email_input)
    neutral_btn.click(lambda: example_neutral, outputs=email_input)
    # Footer with model info.
    # NOTE(review): with no blank lines (or trailing double spaces) between
    # them, standard Markdown joins the three "**...**" lines below into one
    # paragraph — confirm that single-line rendering is intended.
    gr.Markdown("""
---
**🔧 Model**: cybersectony/phishing-email-detection-distilbert_v2.4.1
**🎯 Features**: Temperature scaling, colored risk bars, enhanced debugging
**🏛️ Institution**: University of Dar es Salaam (UDSM)
""")
if __name__ == "__main__":
    # Options used only when the file is executed directly (not on import).
    # NOTE(review): share=True together with server_name="0.0.0.0" exposes the
    # app beyond localhost; confirm this is intended for the deployment target.
    launch_options = {
        "share": True,             # request a public Gradio share link
        "server_name": "0.0.0.0",  # bind on all network interfaces
        "server_port": 7860,
        "show_error": True,        # surface handler errors in the UI
        "debug": True,
    }
    iface.launch(**launch_options)