Spaces:
Runtime error
Runtime error
import logging
import os
import re
import secrets
import sqlite3
from contextlib import closing
from datetime import datetime
from functools import wraps

import bleach
import gradio as gr
import torch
from flask import Flask, request, jsonify, render_template_string, redirect, url_for
from flask_login import LoginManager, UserMixin, login_required, login_user, logout_user, current_user
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Flask app setup
app = Flask(__name__)
# Session cookies are signed with this key, so it must not be hard-coded in
# source. Prefer FLASK_SECRET_KEY from the environment; otherwise fall back to
# a random per-process key (sessions then do not survive a restart and are not
# shared between worker processes).
_secret_key = os.environ.get("FLASK_SECRET_KEY")
if not _secret_key:
    logger.warning("FLASK_SECRET_KEY is not set; using a random key for this process.")
    _secret_key = secrets.token_hex(32)
app.secret_key = _secret_key

# Initialize Flask-Login
login_manager = LoginManager()
login_manager.init_app(app)
# Endpoint name Flask-Login redirects unauthenticated users to.
login_manager.login_view = "login"
# Model configuration: identifier handed to the transformers from_pretrained() loaders.
MODEL_NAME = "cybersectony/phishing-email-detection-distilbert_v2.4.1"

# Labels and their meanings (each is reported with a 0-100% confidence; a
# higher percentage means the model is more certain):
#   - Legitimate: the email is safe and likely from a trusted source.
#   - Phishing:   the email is a scam attempting to steal personal information.
#   - Suspicious: the email has questionable content and may be unsafe.
#   - Spam:       the email is unwanted promotional or junk content.

# Populated by load_model(); None means "not loaded yet".
tokenizer = model = None
# User class for Flask-Login
class User(UserMixin):
    """Logged-in principal for Flask-Login, carrying an RBAC role.

    ``id`` is the primary key of the row in the ``users`` table and ``role``
    is the role string that ``role_required`` compares against.
    """

    def __init__(self, user_id, role):
        self.role = role
        self.id = user_id
# Database setup
def init_db(db_path='phishguardian.db'):
    """Create the application's SQLite tables if absent and seed a demo admin.

    Args:
        db_path: SQLite database file to initialise. Defaults to the file the
            rest of the application reads and writes.

    Safe to call repeatedly: tables are created with ``IF NOT EXISTS`` and the
    admin row is inserted with ``INSERT OR IGNORE`` (keyed on id 'admin1').
    """
    # sqlite3's own connection context manager only commits/rolls back; it
    # does not close the connection. closing() releases the file handle.
    with closing(sqlite3.connect(db_path)) as conn:
        c = conn.cursor()
        c.execute('''CREATE TABLE IF NOT EXISTS users
                     (id TEXT PRIMARY KEY, username TEXT, password TEXT, role TEXT)''')
        c.execute('''CREATE TABLE IF NOT EXISTS analysis_logs
                     (id INTEGER PRIMARY KEY AUTOINCREMENT, user_id TEXT, email_text TEXT,
                      result TEXT, timestamp TEXT)''')
        # Demo-only admin account (admin / admin123) stored in plaintext.
        # SECURITY: change or remove before deployment; use hashed passwords
        # in production.
        c.execute("INSERT OR IGNORE INTO users (id, username, password, role) VALUES (?, ?, ?, ?)",
                  ('admin1', 'admin', 'admin123', 'Admin'))
        conn.commit()
# Load user for Flask-Login
@login_manager.user_loader
def load_user(user_id):
    """Rebuild the session's User from its stored id.

    Registered as Flask-Login's user loader; Flask-Login raises on any
    ``current_user`` access when no loader is registered.

    Returns:
        A ``User`` for an existing row, or ``None`` if the id is unknown, which
        Flask-Login treats as "not logged in".
    """
    with closing(sqlite3.connect('phishguardian.db')) as conn:
        row = conn.execute("SELECT id, role FROM users WHERE id = ?", (user_id,)).fetchone()
    return User(row[0], row[1]) if row else None
# RBAC decorator
def role_required(*roles):
    """Restrict a view to authenticated users whose ``role`` is in *roles*.

    Anonymous users are redirected to the login page. Authenticated users
    with any other role receive a 403 Forbidden response.
    """
    def decorator(f):
        # wraps() preserves f.__name__, which Flask uses as the default
        # endpoint name; without it every protected view would register as
        # 'decorated_function' and the endpoints would collide.
        @wraps(f)
        def decorated_function(*args, **kwargs):
            if not current_user.is_authenticated:
                return redirect(url_for('login'))
            if current_user.role not in roles:
                return render_template_string("<h1>403 Forbidden</h1><p>Unauthorized role.</p>"), 403
            return f(*args, **kwargs)
        return decorated_function
    return decorator
def load_model():
    """Load the tokenizer and model named by MODEL_NAME into the module globals.

    Returns:
        True on success. False if loading raised; the error is logged with its
        traceback and any previously loaded tokenizer/model are left untouched.
    """
    global tokenizer, model
    try:
        new_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        new_model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    except Exception:
        # Broad catch is deliberate: callers only need a success flag, and
        # download/parse failures surface as many different exception types.
        logger.exception("Error loading model")
        return False
    # Publish both together so a failure halfway through never leaves a
    # tokenizer without its model.
    tokenizer, model = new_tokenizer, new_model
    logger.info("Model loaded successfully!")
    return True
def is_valid_email_text(text):
    """Check that *text* plausibly looks like an email body.

    Returns:
        ``(True, "")`` when the text passes, otherwise ``(False, reason)``
        where *reason* is a user-facing explanation.
    """
    cleaned = text.strip() if text else ""
    if cleaned == "":
        return False, "Please enter some email text."
    if len(cleaned) < 10:
        return False, "Text too short for analysis."
    # Require at least three whitespace-separated tokens and one run of three
    # or more ASCII letters, so digit/symbol soup is rejected.
    has_enough_words = len(text.split()) >= 3
    has_letter_run = re.search(r"[a-zA-Z]{3,}", text) is not None
    if not (has_enough_words and has_letter_run):
        return False, "Text appears incoherent or not email-like."
    return True, ""
def get_colored_bar(percentage, width=20):
    """Render *percentage* as a fixed-width bar of emoji cells.

    Args:
        percentage: Confidence in percent. Values outside 0-100 are clamped,
            so the bar is always exactly *width* cells long.
        width: Number of cells in the bar (default 20, i.e. 5% per cell).

    Returns:
        A string of *width* emoji. Filled cells are green at >= 85%, yellow at
        >= 50% and white otherwise; at least one cell is always filled.

    Raises:
        ValueError: if *width* is less than 1.
    """
    if width < 1:
        raise ValueError("width must be at least 1")
    # Clamp so out-of-range inputs cannot overflow the bar.
    percentage = min(max(percentage, 0), 100)
    if percentage >= 85:
        color = "🟢"
    elif percentage >= 50:
        color = "🟡"
    else:
        color = "⚪"
    # Multiply before dividing so 100% always fills every cell, even for
    # widths that do not divide 100 evenly.
    bar_length = max(1, int(percentage * width / 100))
    return color * bar_length + "⚪" * (width - bar_length)
def _risk_assessment(label, prob):
    """Map the top label and its confidence (0-100) to ``(risk_level, advice)``.

    Labels are matched case-insensitively by substring, so both the fallback
    names ("Phishing") and model-provided names containing those words work.
    A confidence below 85% downgrades the verdict to a hedged level.
    """
    name = label.lower()
    confident = prob >= 85
    if "phishing" in name or "suspicious" in name:
        return ("⚠️ Risky" if confident else "⚡ Low Risk",
                "Advice: Avoid clicking links or sharing info.")
    if "spam" in name:
        return ("🗑️ Spam" if confident else "⚡ Low Risk",
                "Advice: Mark as spam or delete.")
    return ("✅ Safe" if confident else "❓ Uncertain",
            "Advice: Appears safe, but stay cautious.")


def _log_analysis(user_id, email_text, result):
    """Append one record to ``analysis_logs`` (email text capped at 1000 chars)."""
    # closing() is needed because sqlite3's connection context manager does
    # not close the connection.
    with closing(sqlite3.connect('phishguardian.db')) as conn:
        conn.execute(
            "INSERT INTO analysis_logs (user_id, email_text, result, timestamp) VALUES (?, ?, ?, ?)",
            (user_id, email_text[:1000], result, datetime.utcnow().isoformat()),
        )
        conn.commit()


def predict_email(email_text, user_id):
    """Classify *email_text*, record the analysis, and return a text report.

    Args:
        email_text: Raw email text from the user. HTML tags are stripped with
            bleach before validation and classification.
        user_id: Id of the requesting user, stored with the log entry.

    Returns:
        On success, a multi-line report: verdict, top label, per-label
        confidences with bars, and advice. Otherwise a one-line message
        prefixed with "⚠️ Error:" (invalid input) or "❌ Error:" (model load,
        inference or logging failure).
    """
    # Sanitize input
    email_text = bleach.clean(email_text, tags=[], strip=True)

    valid, message = is_valid_email_text(email_text)
    if not valid:
        return f"⚠️ Error: {message}"

    # Lazy (re)load, e.g. when the load attempted at import time failed.
    if (tokenizer is None or model is None) and not load_model():
        return "❌ Error: Failed to load the model."

    try:
        inputs = tokenizer(email_text, return_tensors="pt", truncation=True,
                           max_length=512, padding=True)
        with torch.no_grad():
            outputs = model(**inputs)
        # One input string was tokenized, so take the first (only) row.
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0].tolist()

        # Prefer the model's own label names; fall back to the documented
        # 4-class or 2-class names when the config provides none.
        labels = getattr(model.config, 'id2label', None)
        if not labels:
            labels = ({0: "Legitimate", 1: "Phishing", 2: "Suspicious", 3: "Spam"}
                      if len(probs) == 4 else {0: "Legitimate", 1: "Phishing"})
        # .get() guards against a label map shorter than the number of logits.
        results = {labels.get(i, f"LABEL_{i}"): p * 100 for i, p in enumerate(probs)}

        max_label, max_prob = max(results.items(), key=lambda x: x[1])
        risk_level, advice = _risk_assessment(max_label, max_prob)

        lines = [f"Result: {risk_level}",
                 f"Top Prediction: {max_label} ({max_prob:.1f}%)",
                 "Details:"]
        lines += [f"{label}: {prob:.1f}% {get_colored_bar(prob)}"
                  for label, prob in sorted(results.items(), key=lambda x: x[1], reverse=True)]
        lines.append(advice)
        output = "\n".join(lines)

        # A logging failure is reported as a failed analysis (audit trail
        # must not be silently skipped).
        _log_analysis(user_id, email_text, output)
        return output
    except Exception as e:
        logger.exception("Error during prediction")
        return f"❌ Error: Analysis failed - {str(e)}"
# Flask routes
@app.route('/')
def index():
    """Home page: analysis form, result display and role-dependent links.

    Unauthenticated users are redirected to the login page. The ``result``
    query parameter (set by ``analyze``) is rendered autoescaped inside <pre>.
    """
    if not current_user.is_authenticated:
        return redirect(url_for('login'))
    return render_template_string("""
<!DOCTYPE html>
<html>
<head>
<title>PhishGuardian - MDA Email System</title>
<style>
body { font-family: Arial, sans-serif; margin: 20px; background-color: #f4f4f9; }
h1 { color: #333; }
.container { max-width: 800px; margin: auto; }
textarea { width: 100%; height: 200px; margin-bottom: 10px; }
button { padding: 10px 20px; margin-right: 10px; }
pre { background-color: #fff; padding: 15px; border: 1px solid #ddd; }
.error { color: red; }
</style>
</head>
<body>
<div class="container">
<h1>🛡️ PhishGuardian - MDA Email System</h1>
<p>Analyze emails for safety. Paste email text below.</p>
<p>Labels: Legitimate (safe), Phishing (scam), Suspicious (questionable), Spam (junk). Percentages show confidence (0-100%).</p>
{% if current_user.is_authenticated %}
<p>Logged in as: {{ current_user.id }} ({{ current_user.role }}) | <a href="{{ url_for('logout') }}">Logout</a></p>
{% if current_user.role in ['Admin', 'Analyst'] %}
<form method="POST" action="{{ url_for('analyze') }}">
<textarea name="email_text" placeholder="Paste email here..."></textarea>
<button type="submit">🔍 Check</button>
<button type="button" onclick="document.querySelector('textarea').value=''">🗑️ Clear</button>
</form>
{% if result %}
<h3>Results</h3>
<pre>{{ result }}</pre>
{% endif %}
{% endif %}
{% if current_user.role in ['Admin', 'Auditor'] %}
<p><a href="{{ url_for('view_logs') }}">View Analysis Logs</a></p>
{% endif %}
{% if current_user.role == 'Admin' %}
<p><a href="{{ url_for('manage_users') }}">Manage Users</a></p>
{% endif %}
{% endif %}
</div>
</body>
</html>
""", result=request.args.get('result', ''))
@app.route('/login', methods=['GET', 'POST'])
def login():
    """Show the login form (GET) or authenticate submitted credentials (POST).

    On success the user is logged in and redirected to the index page. On
    failure a "Login Failed" page is returned with status 401.
    """
    if request.method == 'POST':
        # Both fields go through bleach.clean, as manage_users does when it
        # stores them, so stored and submitted values compare equal.
        username = bleach.clean(request.form['username'])
        password = bleach.clean(request.form['password'])
        with closing(sqlite3.connect('phishguardian.db')) as conn:
            # SECURITY: plaintext password comparison; switch to salted hashes
            # (together with user creation) before production use.
            user = conn.execute("SELECT id, role FROM users WHERE username = ? AND password = ?",
                                (username, password)).fetchone()
        if user:
            login_user(User(user[0], user[1]))
            return redirect(url_for('index'))
        return render_template_string("<h1>Login Failed</h1><p>Invalid credentials.</p>"), 401
    return render_template_string("""
<!DOCTYPE html>
<html>
<head>
<title>Login - PhishGuardian</title>
<style>
body { font-family: Arial, sans-serif; margin: 20px; background-color: #f4f4f9; }
.container { max-width: 400px; margin: auto; }
input { width: 100%; padding: 10px; margin-bottom: 10px; }
button { padding: 10px 20px; }
</style>
</head>
<body>
<div class="container">
<h1>Login</h1>
<form method="POST">
<input type="text" name="username" placeholder="Username" required>
<input type="password" name="password" placeholder="Password" required>
<button type="submit">Login</button>
</form>
</div>
</body>
</html>
""")
@app.route('/logout')
def logout():
    """Log the current user out and return to the login page."""
    logout_user()
    return redirect(url_for('login'))
# Explicit endpoint keeps url_for('analyze') stable however the RBAC wrapper
# names itself.
@app.route('/analyze', methods=['POST'], endpoint='analyze')
@role_required('Admin', 'Analyst')
def analyze():
    """Analyze the submitted email text and show the report on the index page.

    Restricted to the roles for which the index page renders the analysis
    form. A missing ``email_text`` field is treated as empty input, which
    predict_email reports as a validation error.
    """
    email_text = request.form.get('email_text', '')
    result = predict_email(email_text, current_user.id)
    # The report travels back in the query string; index renders it autoescaped.
    return redirect(url_for('index', result=result))
# Explicit endpoint keeps url_for('view_logs') stable however the RBAC wrapper
# names itself.
@app.route('/logs', endpoint='view_logs')
@role_required('Admin', 'Auditor')
def view_logs():
    """List all analysis log entries, newest first (Admin and Auditor only)."""
    with closing(sqlite3.connect('phishguardian.db')) as conn:
        logs = conn.execute(
            "SELECT user_id, email_text, result, timestamp FROM analysis_logs ORDER BY timestamp DESC"
        ).fetchall()
    # Rows are rendered by Jinja (autoescaped) rather than concatenated into
    # an HTML string marked |safe. email_text is the one exception:
    # predict_email stores it after bleach.clean has already HTML-escaped it,
    # so escaping it again would display entities such as &amp; literally.
    return render_template_string("""
<!DOCTYPE html>
<html>
<head>
<title>Logs - PhishGuardian</title>
<style>
body { font-family: Arial, sans-serif; margin: 20px; background-color: #f4f4f9; }
.container { max-width: 800px; margin: auto; }
ul { list-style-type: none; padding: 0; }
li { margin-bottom: 10px; }
</style>
</head>
<body>
<div class="container">
<h1>Analysis Logs</h1>
<p><a href="{{ url_for('index') }}">Back to Home</a></p>
<h3>Analysis Logs</h3>
<ul>
{% for user_id, email_text, result, timestamp in logs %}
<li><b>{{ timestamp }}</b> | User: {{ user_id }} | Email: {{ email_text[:50] | safe }}... | Result: {{ result[:100] }}...</li>
{% endfor %}
</ul>
</div>
</body>
</html>
""", logs=logs)
# Explicit endpoint keeps url_for('manage_users') stable however the RBAC
# wrapper names itself.
@app.route('/manage_users', methods=['GET', 'POST'], endpoint='manage_users')
@role_required('Admin')
def manage_users():
    """Admin-only page: list users (GET) or create a new user (POST).

    POST expects ``username``, ``password`` and ``role`` form fields. An empty
    username/password, or a role other than those offered by the form,
    yields a 400 response.
    """
    if request.method == 'POST':
        username = bleach.clean(request.form['username'])
        password = bleach.clean(request.form['password'])
        role = bleach.clean(request.form['role'])
        # The role string drives RBAC, so only accept the roles the form offers.
        if not username or not password or role not in ('Admin', 'Analyst', 'Auditor'):
            return render_template_string("<h1>400 Bad Request</h1><p>Invalid user details.</p>"), 400
        user_id = f"user_{datetime.utcnow().timestamp()}"
        with closing(sqlite3.connect('phishguardian.db')) as conn:
            # SECURITY: password stored in plaintext, matching login's comparison.
            conn.execute("INSERT INTO users (id, username, password, role) VALUES (?, ?, ?, ?)",
                         (user_id, username, password, role))
            conn.commit()
        return redirect(url_for('manage_users'))
    with closing(sqlite3.connect('phishguardian.db')) as conn:
        users = conn.execute("SELECT id, username, role FROM users").fetchall()
    # Rendered by Jinja (autoescaped) instead of a |safe HTML string. Usernames
    # are stored after bleach.clean has HTML-escaped them, so they are not
    # escaped a second time.
    return render_template_string("""
<!DOCTYPE html>
<html>
<head>
<title>Manage Users - PhishGuardian</title>
<style>
body { font-family: Arial, sans-serif; margin: 20px; background-color: #f4f4f9; }
.container { max-width: 800px; margin: auto; }
input, select { width: 100%; padding: 10px; margin-bottom: 10px; }
button { padding: 10px 20px; }
ul { list-style-type: none; padding: 0; }
</style>
</head>
<body>
<div class="container">
<h1>Manage Users</h1>
<p><a href="{{ url_for('index') }}">Back to Home</a></p>
<form method="POST">
<input type="text" name="username" placeholder="Username" required>
<input type="password" name="password" placeholder="Password" required>
<select name="role">
<option value="Admin">Admin</option>
<option value="Analyst">Analyst</option>
<option value="Auditor">Auditor</option>
</select>
<button type="submit">Add User</button>
</form>
<h3>Users</h3>
<ul>
{% for user_id, username, role in users %}
<li>ID: {{ user_id }} | Username: {{ username | safe }} | Role: {{ role }}</li>
{% endfor %}
</ul>
</div>
</body>
</html>
""", users=users)
# Example emails: sample texts (legitimate, phishing, neutral) that can be
# pasted into the analyzer.
example_legitimate = (
    "Dear Customer,\n"
    "Thank you for your purchase from TechStore. Your order #ORD-2024-001234 is processed.\n"
    "Order Details:\n"
    "- Product: Wireless Headphones\n"
    "- Amount: $79.99\n"
    "- Delivery: 3-5 days\n"
    "Best regards,\n"
    "TechStore"
)
example_phishing = (
    "URGENT!!!\n"
    "Your account is COMPROMISED! Click here to secure: http://fake-site.com/verify\n"
    "Act NOW or your account will be suspended!\n"
    "Security Team"
)
example_neutral = (
    "Hi team,\n"
    "Reminder: meeting tomorrow at 2 PM. Bring project updates.\n"
    "Thanks,\n"
    "Sarah"
)
# Initialize database and load model. This runs at import time, not only
# under __main__; a failed model load is logged and predict_email retries it.
init_db()
load_model()

if __name__ == "__main__":
    # Defaults reproduce the original hard-coded settings; each can be
    # overridden through the environment (e.g. a platform-assigned PORT).
    host = os.environ.get("HOST", "0.0.0.0")
    port = int(os.environ.get("PORT", "5000"))
    # 'adhoc' serves HTTPS with a throwaway self-signed certificate and needs
    # the `cryptography` package; set USE_ADHOC_SSL=0 to serve plain HTTP
    # (e.g. behind a TLS-terminating proxy).
    ssl_context = None if os.environ.get("USE_ADHOC_SSL", "1") == "0" else 'adhoc'
    app.run(ssl_context=ssl_context, host=host, port=port)