# MUFASA25's picture
# Update app.py
# d453c46 verified
import html
import logging
import os
import re
import secrets
import sqlite3
import uuid
from contextlib import closing
from datetime import datetime
from functools import wraps

import bleach
import gradio as gr
import torch
from flask import Flask, request, jsonify, render_template_string, redirect, url_for
from flask_login import LoginManager, UserMixin, login_required, login_user, logout_user, current_user
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Flask app setup
app = Flask(__name__)

# The session-signing key must not live in source control: anyone who can read
# it can forge a session cookie for any user. Take it from the environment;
# when it is absent, fall back to a random per-process key so the app still
# starts (sessions are then lost on restart and are not shared between
# worker processes).
_secret_key = os.environ.get("FLASK_SECRET_KEY")
if not _secret_key:
    _secret_key = secrets.token_hex(32)
    logger.warning(
        "FLASK_SECRET_KEY is not set; using a random per-process secret key. "
        "Sessions will not survive a restart."
    )
app.secret_key = _secret_key

# Initialize Flask-Login
login_manager = LoginManager()
login_manager.init_app(app)
login_manager.login_view = "login"  # endpoint unauthenticated users are sent to

# Model configuration
MODEL_NAME = "cybersectony/phishing-email-detection-distilbert_v2.4.1"

# Labels and their meanings, as presented to users:
# - Legitimate: The email is safe and likely from a trusted source.
# - Phishing: The email is a scam attempting to steal personal information.
# - Suspicious: The email has questionable content and may be unsafe.
# - Spam: The email is unwanted promotional or junk content.
# Each label has a percentage (0-100%) showing the model's confidence.
# Higher percentages indicate greater certainty.

# Global variables for model and tokenizer; both stay None until load_model()
# assigns them.
tokenizer = None
model = None
# User class for Flask-Login
class User(UserMixin):
    """Logged-in principal handed to Flask-Login.

    ``id`` is the value Flask-Login stores in the session and later passes
    back to the user loader; ``role`` is the string compared by the role
    checks in this module.
    """

    def __init__(self, user_id, role):
        self.id, self.role = user_id, role
# Database setup
def init_db():
    """Create the SQLite schema and seed the default admin account.

    Safe to call more than once: both tables use ``CREATE TABLE IF NOT
    EXISTS`` and the admin row is inserted with ``INSERT OR IGNORE``.
    """
    # ``with conn`` only commits or rolls back the transaction; ``closing``
    # is what actually releases the connection when the block ends.
    with closing(sqlite3.connect('phishguardian.db')) as conn, conn:
        conn.execute('''CREATE TABLE IF NOT EXISTS users
(id TEXT PRIMARY KEY, username TEXT, password TEXT, role TEXT)''')
        conn.execute('''CREATE TABLE IF NOT EXISTS analysis_logs
(id INTEGER PRIMARY KEY AUTOINCREMENT, user_id TEXT, email_text TEXT,
result TEXT, timestamp TEXT)''')
        # Default admin user (password: 'admin123' for demo).
        # NOTE: the password is stored in plain text because login() compares
        # it verbatim in SQL; switch both sides to hashed passwords before
        # production use.
        conn.execute(
            "INSERT OR IGNORE INTO users (id, username, password, role) VALUES (?, ?, ?, ?)",
            ('admin1', 'admin', 'admin123', 'Admin'),
        )
# Load user for Flask-Login
@login_manager.user_loader
def load_user(user_id):
    """Rebuild the ``User`` for a session from its stored id.

    Args:
        user_id: Primary key of the row in the ``users`` table.

    Returns:
        A ``User`` carrying the row's id and role, or ``None`` when no row
        matches.
    """
    # sqlite3's own context manager does not close the connection, so wrap it
    # in ``closing`` to avoid leaking a handle on every request.
    with closing(sqlite3.connect('phishguardian.db')) as conn:
        row = conn.execute(
            "SELECT id, role FROM users WHERE id = ?", (user_id,)
        ).fetchone()
    if row:
        return User(row[0], row[1])
    return None
# RBAC decorator
def role_required(*roles):
    """Build a view decorator that restricts access to the given roles.

    Args:
        *roles: Role strings allowed to call the wrapped view.

    Returns:
        A decorator. The wrapped view redirects anonymous users to the login
        page, answers HTTP 403 when the current user's role is not in
        ``roles``, and otherwise calls the original view.
    """
    def decorator(f):
        @wraps(f)
        def decorated_function(*args, **kwargs):
            if not current_user.is_authenticated:
                return redirect(url_for('login'))
            if current_user.role not in roles:
                # Send a real 403 status; a bare body would go out as 200 OK
                # and look like success to clients, proxies and monitoring.
                return render_template_string(
                    "<h1>403 Forbidden</h1><p>Unauthorized role.</p>"
                ), 403
            return f(*args, **kwargs)
        return decorated_function
    return decorator
def load_model():
    """Populate the module-level ``tokenizer`` and ``model`` from ``MODEL_NAME``.

    Returns:
        ``True`` when both objects were loaded, ``False`` when loading raised
        (the error is logged rather than propagated).
    """
    global tokenizer, model
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        return False
    logger.info("Model loaded successfully!")
    return True
def is_valid_email_text(text):
    """Decide whether ``text`` looks enough like an email to analyse.

    Returns:
        ``(True, "")`` when the text passes every check, otherwise
        ``(False, reason)`` where ``reason`` is a user-facing message.
    """
    trimmed = text.strip() if text else ""
    if not trimmed:
        verdict = (False, "Please enter some email text.")
    elif len(trimmed) < 10:
        verdict = (False, "Text too short for analysis.")
    elif len(text.split()) >= 3 and re.search(r"[a-zA-Z]{3,}", text):
        # At least three whitespace-separated tokens and one run of three or
        # more ASCII letters.
        verdict = (True, "")
    else:
        verdict = (False, "Text appears incoherent or not email-like.")
    return verdict
def get_colored_bar(percentage):
    """Render ``percentage`` as a fixed-width, 20-cell emoji bar.

    Args:
        percentage: Confidence value, nominally in the range 0-100.

    Returns:
        A 20-character string. The filled cells are green at 85 and above,
        yellow from 50, and white below 50; the remaining cells are white.
        At least one cell is always counted as filled.
    """
    if percentage >= 85:
        color = "🟢"
    elif percentage >= 50:
        color = "🟡"
    else:
        color = "⚪"
    # One cell per 5 points. Cap at 20 so an out-of-range value cannot make
    # the bar wider than its fixed width.
    bar_length = min(20, max(1, int(percentage / 5)))
    return color * bar_length + "⚪" * (20 - bar_length)
def _resolve_labels(config, num_classes):
    """Return an index -> display-name mapping for the classifier's classes.

    The names shipped in ``config.id2label`` are used when they are
    descriptive. transformers fills ``id2label`` with generic ``LABEL_<n>``
    names when a checkpoint defines none, and those names never match the
    "phishing"/"spam" keyword checks, so every mail would be reported as
    safe. In that case the hard-coded 2- or 4-class names are used instead.

    NOTE(review): the hard-coded class order is the one this file has always
    assumed; confirm it against the model card for MODEL_NAME.
    """
    configured = getattr(config, 'id2label', None) or {}
    names = [configured.get(i) for i in range(num_classes)]
    descriptive = all(
        isinstance(name, str) and not re.fullmatch(r"LABEL_\d+", name)
        for name in names
    )
    if descriptive:
        return dict(enumerate(names))
    if num_classes == 4:
        return {0: "Legitimate", 1: "Phishing", 2: "Suspicious", 3: "Spam"}
    if num_classes == 2:
        return {0: "Legitimate", 1: "Phishing"}
    # Unknown class count: keep whatever the config offers rather than fail.
    return {
        i: str(name) if name is not None else f"LABEL_{i}"
        for i, name in enumerate(names)
    }


def _format_report(results):
    """Turn a ``{label: percentage}`` mapping into the user-facing report."""
    ranked = sorted(results.items(), key=lambda x: x[1], reverse=True)
    # The sort is stable, so on ties this is the first label in input order.
    max_label, max_prob = ranked[0]
    lowered = max_label.lower()
    confident = max_prob >= 85  # risk levels use an 85% threshold

    if "phishing" in lowered or "suspicious" in lowered:
        risk_level = "⚠️ Risky" if confident else "⚡ Low Risk"
        advice = "Advice: Avoid clicking links or sharing info."
    elif "spam" in lowered:
        risk_level = "🗑️ Spam" if confident else "⚡ Low Risk"
        advice = "Advice: Mark as spam or delete."
    else:
        risk_level = "✅ Safe" if confident else "❓ Uncertain"
        advice = "Advice: Appears safe, but stay cautious."

    lines = [
        f"Result: {risk_level}",
        f"Top Prediction: {max_label} ({max_prob:.1f}%)",
        "Details:",
    ]
    lines.extend(
        f"{label}: {prob:.1f}% {get_colored_bar(prob)}" for label, prob in ranked
    )
    lines.append(advice)
    return "\n".join(lines)


def _log_analysis(user_id, email_text, report):
    """Append one row to ``analysis_logs`` (email text capped at 1000 chars)."""
    # ``closing`` releases the connection; the inner ``conn`` context commits.
    with closing(sqlite3.connect('phishguardian.db')) as conn, conn:
        conn.execute(
            "INSERT INTO analysis_logs (user_id, email_text, result, timestamp) VALUES (?, ?, ?, ?)",
            (user_id, email_text[:1000], report, datetime.utcnow().isoformat()),
        )


def predict_email(email_text, user_id):
    """Classify an email and return a formatted, human-readable report.

    Args:
        email_text: Raw email text supplied by the user.
        user_id: Id recorded alongside the analysis in ``analysis_logs``.

    Returns:
        The report string on success, or a string starting with
        "⚠️ Error:" / "❌ Error:" when validation, model loading, inference
        or logging fails. This function does not raise for those failures.
    """
    # Sanitize input: strip all HTML tags. ``or ""`` turns None into the
    # normal "please enter text" validation message instead of a TypeError.
    email_text = bleach.clean(email_text or "", tags=[], strip=True)

    valid, message = is_valid_email_text(email_text)
    if not valid:
        return f"⚠️ Error: {message}"

    # Load the model on first use if it is not available yet.
    if (tokenizer is None or model is None) and not load_model():
        return "❌ Error: Failed to load the model."

    try:
        inputs = tokenizer(
            email_text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True,
        )
        with torch.no_grad():
            logits = model(**inputs).logits
        probs = torch.nn.functional.softmax(logits, dim=-1)[0].tolist()

        labels = _resolve_labels(model.config, len(probs))
        results = {labels[i]: prob * 100 for i, prob in enumerate(probs)}

        output = _format_report(results)
        _log_analysis(user_id, email_text, output)
        return output
    except Exception as e:
        logger.exception("Error during prediction: %s", e)
        return f"❌ Error: Analysis failed - {str(e)}"
# Flask routes
# Home page markup; ``result`` is the report text shown under the form.
_INDEX_PAGE = """
<!DOCTYPE html>
<html>
<head>
<title>PhishGuardian - MDA Email System</title>
<style>
body { font-family: Arial, sans-serif; margin: 20px; background-color: #f4f4f9; }
h1 { color: #333; }
.container { max-width: 800px; margin: auto; }
textarea { width: 100%; height: 200px; margin-bottom: 10px; }
button { padding: 10px 20px; margin-right: 10px; }
pre { background-color: #fff; padding: 15px; border: 1px solid #ddd; }
.error { color: red; }
</style>
</head>
<body>
<div class="container">
<h1>🛡️ PhishGuardian - MDA Email System</h1>
<p>Analyze emails for safety. Paste email text below.</p>
<p>Labels: Legitimate (safe), Phishing (scam), Suspicious (questionable), Spam (junk). Percentages show confidence (0-100%).</p>
{% if current_user.is_authenticated %}
<p>Logged in as: {{ current_user.id }} ({{ current_user.role }}) | <a href="{{ url_for('logout') }}">Logout</a></p>
{% if current_user.role in ['Admin', 'Analyst'] %}
<form method="POST" action="{{ url_for('analyze') }}">
<textarea name="email_text" placeholder="Paste email here..."></textarea>
<button type="submit">🔍 Check</button>
<button type="button" onclick="document.querySelector('textarea').value=''">🗑️ Clear</button>
</form>
{% if result %}
<h3>Results</h3>
<pre>{{ result }}</pre>
{% endif %}
{% endif %}
{% if current_user.role in ['Admin', 'Auditor'] %}
<p><a href="{{ url_for('view_logs') }}">View Analysis Logs</a></p>
{% endif %}
{% if current_user.role == 'Admin' %}
<p><a href="{{ url_for('manage_users') }}">Manage Users</a></p>
{% endif %}
{% endif %}
</div>
</body>
</html>
"""


@app.route('/')
def index():
    """Serve the home page, or bounce anonymous visitors to the login page.

    The report to display is taken from the ``result`` query parameter.
    """
    if current_user.is_authenticated:
        shown_result = request.args.get('result', '')
        return render_template_string(_INDEX_PAGE, result=shown_result)
    return redirect(url_for('login'))
@app.route('/login', methods=['GET', 'POST'])
def login():
    """Show the login form (GET) or authenticate the submitted credentials (POST).

    Returns:
        On POST, a redirect to the home page when the username/password pair
        matches a row in ``users``, otherwise a failure page with HTTP 401.
        On GET, the login form.
    """
    if request.method == 'POST':
        # NOTE: both values go through bleach.clean() because manage_users()
        # stores them after the same call; the two sides must stay in sync.
        # Passwords are compared in plain text - hash them before production.
        username = bleach.clean(request.form['username'])
        password = bleach.clean(request.form['password'])
        # ``closing`` releases the connection; sqlite3's own context manager
        # would only end the transaction.
        with closing(sqlite3.connect('phishguardian.db')) as conn:
            row = conn.execute(
                "SELECT id, role FROM users WHERE username = ? AND password = ?",
                (username, password),
            ).fetchone()
        if row:
            login_user(User(row[0], row[1]))
            return redirect(url_for('index'))
        # Report the failure with a real 401 instead of a 200 OK page.
        return render_template_string(
            "<h1>Login Failed</h1><p>Invalid credentials.</p>"
        ), 401
    return render_template_string("""
<!DOCTYPE html>
<html>
<head>
<title>Login - PhishGuardian</title>
<style>
body { font-family: Arial, sans-serif; margin: 20px; background-color: #f4f4f9; }
.container { max-width: 400px; margin: auto; }
input { width: 100%; padding: 10px; margin-bottom: 10px; }
button { padding: 10px 20px; }
</style>
</head>
<body>
<div class="container">
<h1>Login</h1>
<form method="POST">
<input type="text" name="username" placeholder="Username" required>
<input type="password" name="password" placeholder="Password" required>
<button type="submit">Login</button>
</form>
</div>
</body>
</html>
""")
@app.route('/logout')
@login_required
def logout():
    """End the current session and send the user back to the login page."""
    logout_user()
    login_url = url_for('login')
    return redirect(login_url)
@app.route('/analyze', methods=['POST'])
@role_required('Admin', 'Analyst')
def analyze():
    """Run the submitted email through the classifier and show the report.

    The report is carried to the home page in the ``result`` query parameter
    of the redirect.
    """
    report = predict_email(request.form['email_text'], current_user.id)
    return redirect(url_for('index', result=report))
@app.route('/logs')
@role_required('Admin', 'Auditor')
def view_logs():
    """List every stored analysis, newest first.

    Rows are handed to the template as data so Jinja's autoescaping applies
    to every field, instead of pre-building HTML and marking it ``safe``.
    """
    # ``closing`` releases the connection once the rows are fetched.
    with closing(sqlite3.connect('phishguardian.db')) as conn:
        rows = conn.execute(
            "SELECT user_id, email_text, result, timestamp FROM analysis_logs ORDER BY timestamp DESC"
        ).fetchall()
    logs = [
        {
            "timestamp": timestamp,
            "user_id": user_id,
            # predict_email() stores the email after bleach.clean(), i.e.
            # already entity-escaped; unescape first so autoescaping does not
            # show doubled entities such as "&amp;amp;".
            "email": html.unescape(email_text or "")[:50],
            "result": (result or "")[:100],
        }
        for user_id, email_text, result, timestamp in rows
    ]
    return render_template_string("""
<!DOCTYPE html>
<html>
<head>
<title>Logs - PhishGuardian</title>
<style>
body { font-family: Arial, sans-serif; margin: 20px; background-color: #f4f4f9; }
.container { max-width: 800px; margin: auto; }
ul { list-style-type: none; padding: 0; }
li { margin-bottom: 10px; }
</style>
</head>
<body>
<div class="container">
<h1>Analysis Logs</h1>
<p><a href="{{ url_for('index') }}">Back to Home</a></p>
<h3>Analysis Logs</h3><ul>
{% for log in logs %}
<li><b>{{ log.timestamp }}</b> | User: {{ log.user_id }} | Email: {{ log.email }}... | Result: {{ log.result }}...</li>
{% endfor %}
</ul>
</div>
</body>
</html>
""", logs=logs)
@app.route('/users', methods=['GET', 'POST'])
@role_required('Admin')
def manage_users():
    """List existing users (GET) or create a new one (POST).

    Returns:
        On POST, a redirect back to this page after the insert, or a page
        with HTTP 400 when the submitted role is not one of the known roles.
        On GET, the user list with the "add user" form.
    """
    allowed_roles = ('Admin', 'Analyst', 'Auditor')
    if request.method == 'POST':
        # NOTE: login() applies the same bleach.clean() to the submitted
        # credentials, so the stored values must keep going through it.
        # Passwords are stored in plain text - hash them before production.
        username = bleach.clean(request.form['username'])
        password = bleach.clean(request.form['password'])
        role = request.form['role']
        # The <select> is only a client-side hint; enforce the role list here.
        if role not in allowed_roles:
            return render_template_string(
                "<h1>400 Bad Request</h1><p>Unknown role.</p>"
            ), 400
        # Random id: a timestamp-derived id can collide when two users are
        # created at the same instant, which fails the PRIMARY KEY insert.
        user_id = f"user_{uuid.uuid4().hex}"
        # ``closing`` releases the connection; the inner ``conn`` commits.
        with closing(sqlite3.connect('phishguardian.db')) as conn, conn:
            conn.execute(
                "INSERT INTO users (id, username, password, role) VALUES (?, ?, ?, ?)",
                (user_id, username, password, role),
            )
        return redirect(url_for('manage_users'))

    with closing(sqlite3.connect('phishguardian.db')) as conn:
        rows = conn.execute("SELECT id, username, role FROM users").fetchall()
    # Pass data, not pre-built HTML, so Jinja autoescapes every field.
    # Usernames were stored entity-escaped by bleach.clean(); unescape them
    # first so they are not displayed with doubled entities.
    users = [
        {"id": uid, "username": html.unescape(name or ""), "role": user_role}
        for uid, name, user_role in rows
    ]
    return render_template_string("""
<!DOCTYPE html>
<html>
<head>
<title>Manage Users - PhishGuardian</title>
<style>
body { font-family: Arial, sans-serif; margin: 20px; background-color: #f4f4f9; }
.container { max-width: 800px; margin: auto; }
input, select { width: 100%; padding: 10px; margin-bottom: 10px; }
button { padding: 10px 20px; }
ul { list-style-type: none; padding: 0; }
</style>
</head>
<body>
<div class="container">
<h1>Manage Users</h1>
<p><a href="{{ url_for('index') }}">Back to Home</a></p>
<form method="POST">
<input type="text" name="username" placeholder="Username" required>
<input type="password" name="password" placeholder="Password" required>
<select name="role">
<option value="Admin">Admin</option>
<option value="Analyst">Analyst</option>
<option value="Auditor">Auditor</option>
</select>
<button type="submit">Add User</button>
</form>
<h3>Users</h3><ul>
{% for user in users %}
<li>ID: {{ user.id }} | Username: {{ user.username }} | Role: {{ user.role }}</li>
{% endfor %}
</ul>
</div>
</body>
</html>
""", users=users)
# Example emails: three sample inputs for the analyzer (an order
# confirmation, an urgent account-compromise scam, and a routine meeting
# reminder).
# NOTE(review): none of these constants is referenced in the visible part of
# this file - confirm whether they are still needed.
example_legitimate = """Dear Customer,
Thank you for your purchase from TechStore. Your order #ORD-2024-001234 is processed.
Order Details:
- Product: Wireless Headphones
- Amount: $79.99
- Delivery: 3-5 days
Best regards,
TechStore"""
example_phishing = """URGENT!!!
Your account is COMPROMISED! Click here to secure: http://fake-site.com/verify
Act NOW or your account will be suspended!
Security Team"""
example_neutral = """Hi team,
Reminder: meeting tomorrow at 2 PM. Bring project updates.
Thanks,
Sarah"""
# Runs at import time: make sure the SQLite schema exists, then try to load
# the classifier (load_model() logs and returns False on failure, and
# predict_email() retries the load on demand).
init_db()
load_model()

if __name__ == "__main__":
    # Direct execution: serve on all interfaces, port 5000, using the 'adhoc'
    # SSL context.
    app.run(host='0.0.0.0', port=5000, ssl_context='adhoc')