import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import re
# Multi-Model Configuration
MODELS = {
    "primary": "cybersectony/phishing-email-detection-distilbert_v2.4.1",
    "secondary": "microsoft/DialoGPT-medium",  # Fallback for context
    "url_specialist": "cybersectony/phishing-email-detection-distilbert_v2.4.1"  # URL-focused
}
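# Note: "primary" and "url_specialist" currently point to the same checkpoint, so the
# ensemble below averages two identical predictions; substitute a genuinely
# URL-focused model here if one becomes available.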
# Global model storage
models = {}
tokenizers = {}

class AdvancedPhishingDetector:
    def __init__(self):
        self.load_models()

    def load_models(self):
        """Load multiple models for ensemble prediction."""
        global models, tokenizers
        try:
            for name, model_path in MODELS.items():
                if name == "secondary":
                    continue  # Skip for now, use primary model
                tokenizers[name] = AutoTokenizer.from_pretrained(model_path)
                models[name] = AutoModelForSequenceClassification.from_pretrained(model_path)
                models[name].eval()
            return True
        except Exception as e:
            print(f"Error loading models: {e}")
            return False
    def extract_features(self, text):
        """Extract hand-crafted features for bias reduction."""
        features = {}

        # URL features
        urls = re.findall(
            r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+',
            text,
        )
        features['url_count'] = len(urls)
        features['has_suspicious_domains'] = any(
            domain in url.lower()
            for url in urls
            for domain in ['bit.ly', 'tinyurl', 'shorturl', 'suspicious', 'phish', 'scam']
        )

        # Text pattern features
        features['urgency_words'] = len(re.findall(r'urgent|immediate|expire|suspend|verify|confirm|click|act now', text.lower()))
        features['money_mentions'] = len(re.findall(r'\$|money|payment|refund|prize|winner|lottery', text.lower()))
        features['personal_info_requests'] = len(re.findall(r'password|ssn|social security|credit card|pin|account', text.lower()))
        features['spelling_errors'] = self.count_potential_errors(text)
        features['excessive_caps'] = len(re.findall(r'[A-Z]{3,}', text))

        # Sender authenticity indicators
        features['generic_greetings'] = 1 if re.search(r'^(dear (customer|user|sir|madam))', text.lower()) else 0
        features['email_length'] = len(text)
        features['has_attachments'] = 1 if 'attachment' in text.lower() else 0

        return features
    def count_potential_errors(self, text):
        """Simple heuristic for spelling errors."""
        # Look for common phishing misspellings
        errors = re.findall(r'recieve|occured|seperate|definately|goverment|secruity|varify', text.lower())
        return len(errors)
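    # The detection model is assumed to emit four logits in the order used by the label
    # mapping in advanced_predict_phishing below:
    # [legitimate email, phishing URL, legitimate URL, phishing email].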
    def get_model_predictions(self, text):
        """Get predictions from multiple models."""
        predictions = {}
        for model_name in ['primary', 'url_specialist']:
            if model_name not in models:
                continue
            try:
                inputs = tokenizers[model_name](
                    text,
                    return_tensors="pt",
                    truncation=True,
                    max_length=512,
                    padding=True
                )
                with torch.no_grad():
                    outputs = models[model_name](**inputs)
                    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
                    predictions[model_name] = probs[0].tolist()
            except Exception as e:
                print(f"Error with model {model_name}: {e}")
                predictions[model_name] = [0.25, 0.25, 0.25, 0.25]  # Uniform (neutral) fallback
        return predictions
    def ensemble_predict(self, text):
        """Advanced ensemble prediction with feature weighting."""
        # Get model predictions
        model_preds = self.get_model_predictions(text)

        # Extract hand-crafted features
        features = self.extract_features(text)

        # Calculate feature-based risk score
        risk_score = self.calculate_risk_score(features)

        # Ensemble combination
        if len(model_preds) == 0:
            return self.fallback_prediction(features)

        # Weight model predictions
        weights = {'primary': 0.7, 'url_specialist': 0.3}
        ensemble_probs = [0.0, 0.0, 0.0, 0.0]
        total_weight = 0
        for model_name, probs in model_preds.items():
            weight = weights.get(model_name, 0.5)
            total_weight += weight
            for i in range(len(probs)):
                ensemble_probs[i] += probs[i] * weight

        # Normalize
        if total_weight > 0:
            ensemble_probs = [p / total_weight for p in ensemble_probs]

        # Adjust with feature-based risk
        ensemble_probs = self.adjust_with_features(ensemble_probs, risk_score)

        return ensemble_probs, features, risk_score
    def calculate_risk_score(self, features):
        """Calculate a risk score from hand-crafted features."""
        score = 0

        # URL-based risk
        score += features['url_count'] * 0.1
        score += features['has_suspicious_domains'] * 0.3

        # Content-based risk
        score += min(features['urgency_words'] * 0.15, 0.4)
        score += min(features['money_mentions'] * 0.1, 0.3)
        score += min(features['personal_info_requests'] * 0.2, 0.5)
        score += min(features['spelling_errors'] * 0.1, 0.2)
        score += min(features['excessive_caps'] * 0.05, 0.15)

        # Generic patterns
        score += features['generic_greetings'] * 0.1

        return min(score, 1.0)  # Cap at 1.0
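    # Worked example: two URLs (0.2), a link-shortener domain (0.3), and three urgency
    # words (capped at 0.4) already yield a risk score of 0.9 before any other signal.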
    def adjust_with_features(self, probs, risk_score):
        """Adjust model predictions with feature-based risk."""
        adjusted = probs.copy()

        # If the risk score is high, boost the phishing probabilities
        if risk_score > 0.5:
            phishing_boost = risk_score * 0.3
            adjusted[1] += phishing_boost  # Phishing URL
            adjusted[3] += phishing_boost  # Phishing Email
            # Reduce legitimate probabilities
            adjusted[0] = max(0, adjusted[0] - phishing_boost / 2)
            adjusted[2] = max(0, adjusted[2] - phishing_boost / 2)

        # Normalize so the probabilities sum to 1
        total = sum(adjusted)
        if total > 0:
            adjusted = [p / total for p in adjusted]
        return adjusted
    def fallback_prediction(self, features):
        """Fallback prediction when the models fail to load."""
        risk_score = self.calculate_risk_score(features)
        if risk_score > 0.7:
            return [0.1, 0.4, 0.1, 0.4], features, risk_score  # High phishing
        elif risk_score > 0.4:
            return [0.3, 0.2, 0.3, 0.2], features, risk_score  # Medium risk
        else:
            return [0.45, 0.05, 0.45, 0.05], features, risk_score  # Low risk

# Initialize detector
detector = AdvancedPhishingDetector()
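
# Quick local sanity check (hypothetical usage, not wired into the app):
#   probs, feats, risk = detector.ensemble_predict("Verify your account at http://bit.ly/update now!")
#   print(probs, risk)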

def advanced_predict_phishing(text):
    """Advanced phishing prediction with ensemble and feature analysis."""
    if not text.strip():
        return "Please enter some text to analyze", {}, ""

    try:
        # Get ensemble prediction
        probs, features, risk_score = detector.ensemble_predict(text)

        # Create label mapping
        labels = {
            "Legitimate Email": probs[0],
            "Phishing URL": probs[1],
            "Legitimate URL": probs[2],
            "Phishing Email": probs[3]
        }

        # Find primary classification
        max_label = max(labels.items(), key=lambda x: x[1])
        prediction = max_label[0]
        confidence = max_label[1]

        # Enhanced risk assessment
        if "Phishing" in prediction and confidence > 0.8:
            risk_level = "🚨 HIGH RISK - Strong Phishing Indicators"
            risk_color = "red"
        elif "Phishing" in prediction or risk_score > 0.5:
            risk_level = "⚠️ MEDIUM RISK - Suspicious Patterns Detected"
            risk_color = "orange"
        elif risk_score > 0.3:
            risk_level = "⚡ LOW-MEDIUM RISK - Some Concerns"
            risk_color = "yellow"
        else:
            risk_level = "✅ LOW RISK - Appears Legitimate"
            risk_color = "green"

        # Feature analysis summary
        feature_alerts = []
        if features['has_suspicious_domains']:
            feature_alerts.append("Suspicious domain detected")
        if features['urgency_words'] > 2:
            feature_alerts.append("High urgency language")
        if features['personal_info_requests'] > 1:
            feature_alerts.append("Requests personal information")
        if features['spelling_errors'] > 0:
            feature_alerts.append("Potential spelling errors")

        # Format detailed result
        alerts_text = "\n".join(f"• {alert}" for alert in feature_alerts) if feature_alerts else "• No significant risk patterns detected"
        result = f"""
### {risk_level}

**Primary Classification:** {prediction}
**Confidence:** {confidence:.1%}
**Feature Risk Score:** {risk_score:.2f}/1.00

**Analysis Alerts:**
{alerts_text}

**Technical Details:**
• URLs found: {features['url_count']}
• Urgency indicators: {features['urgency_words']}
• Personal info requests: {features['personal_info_requests']}
"""

        # Confidence breakdown for display (raw floats for gr.Label)
        confidence_data = {label: prob for label, prob in labels.items()}

        return result, confidence_data, risk_color

    except Exception as e:
        return f"Error during analysis: {str(e)}", {}, "orange"

# Enhanced Gradio Interface
with gr.Blocks(
    theme=gr.themes.Soft(),
    title="EmailGuard - Advanced Phishing Detection",
    css="""
    .risk-high { color: #dc2626 !important; font-weight: bold; }
    .risk-low { color: #16a34a !important; font-weight: bold; }
    .main-container { max-width: 900px; margin: 0 auto; }
    .feature-box { background: #f8f9fa; padding: 15px; border-radius: 8px; margin: 10px 0; }
    """
) as demo:
gr.Markdown(""" | |
# π‘οΈ EmailGuard2 - Advanced AI Phishing Detection | |
**Multi-Model Ensemble System with Feature Analysis** | |
β¨ **Enhanced Accuracy** β’ π **Deep Pattern Analysis** β’ π **Real-time Results** | |
""") | |
    with gr.Row():
        with gr.Column(scale=2):
            input_text = gr.Textbox(
                label="📧 Email Content, URL, or Suspicious Message",
                placeholder="Paste your email content, suspicious URL, or any text message here for comprehensive analysis...",
                lines=10,
                max_lines=20
            )
            with gr.Row():
                analyze_btn = gr.Button(
                    "🔍 Advanced Analysis",
                    variant="primary",
                    size="lg"
                )
                clear_btn = gr.Button("🗑️ Clear", variant="secondary")
        with gr.Column(scale=1):
            result_output = gr.Markdown(label="📊 Analysis Results")
            confidence_output = gr.Label(
                label="🎯 Confidence Breakdown",
                num_top_classes=4
            )
    # Enhanced examples
    gr.Markdown("### 📝 Test These Examples:")
    examples = [
        ["URGENT: Your PayPal account has been limited! Verify immediately at http://paypal-security-check.suspicious.com/verify or lose access forever!"],
        ["Hi Mufasa, Thanks for sending the quarterly report. I've reviewed the numbers and they look good. Let's discuss in tomorrow's meeting. Best, Simba"],
        ["🎉 CONGRATULATIONS, Chinno! You've won $50,000! Click here to claim: bit.ly/winner123. Act fast, expires in 24hrs! Reply with SSN to confirm."],
        ["Your Microsoft Office subscription expires tomorrow. Renew now to avoid service interruption. Visit: https://office.microsoft.com/renew"],
        ["Dear Valued Customer, We detected unusual activity on your account. Please verify your identity by clicking the link below and entering your password."],
        ["Meeting reminder: Team standup at 10 AM in conference room A, Y4C Hub. Please bring your project updates. Thanks!"]
    ]
    gr.Examples(
        examples=examples,
        inputs=input_text,
        outputs=[result_output, confidence_output]
    )
    # Event handlers
    # advanced_predict_phishing returns (markdown, label dict, risk_color); the color is
    # parked in a throwaway gr.State component since no styled output consumes it yet.
    analyze_btn.click(
        fn=advanced_predict_phishing,
        inputs=input_text,
        outputs=[result_output, confidence_output, gr.State()]
    )
    clear_btn.click(
        fn=lambda: ("", "", {}),
        outputs=[input_text, result_output, confidence_output]
    )
    input_text.submit(
        fn=advanced_predict_phishing,
        inputs=input_text,
        outputs=[result_output, confidence_output, gr.State()]
    )
gr.Markdown(""" | |
--- | |
### π¬ Advanced Detection Features | |
**π€ Multi-Model Ensemble:** Combines predictions from specialized models | |
**π― Feature Engineering:** Hand-crafted rules for pattern detection | |
**βοΈ Bias Reduction:** Multiple validation layers prevent false positives | |
**π Risk Scoring:** Comprehensive analysis beyond simple classification | |
**π URL Analysis:** Specialized detection for malicious links | |
**π Content Analysis:** Deep text pattern recognition | |
### β‘ What Makes This More Accurate: | |
- **Ensemble Learning:** Multiple models vote on final decision | |
- **Feature Fusion:** AI + Rule-based detection combined | |
- **Adaptive Thresholds:** Dynamic risk assessment | |
- **Comprehensive Coverage:** Email, URL, and text message analysis | |
**β οΈ Academic Research Tool:** For educational purposes - always verify through official channels. | |
""") | |
if __name__ == "__main__": | |
demo.launch( | |
share=False, | |
server_name="0.0.0.0", | |
server_port=7860, | |
show_error=True | |
) |