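"""Gradio app for a phishing detector with two tabs: a custom DeBERTa + LSTM
classifier loaded from a Hugging Face Hub checkpoint, and a fine-tuned BERT
sequence classifier. URLs are scored on both their structure and their fetched
HTML content; plain text is scored directly."""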
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from huggingface_hub import hf_hub_download
import gradio as gr
import requests
import re
from bs4 import BeautifulSoup
# --- Import your architecture ---
# Make sure this file is in the repo (e.g., models/deberta_lstm_classifier.py)
# and update the import path accordingly.
from model import DeBERTaLSTMClassifier  # <-- your class
# --------- Config ----------
REPO_ID = "khoa-done/phishing-detector"   # HF repo that holds the checkpoint
CKPT_NAME = "deberta_lstm_checkpoint.pt"  # the .pt file name
MODEL_NAME = "microsoft/deberta-base"     # base tokenizer/backbone
LABELS = ["benign", "phishing"]           # adjust to your classes

# If your checkpoint contains hyperparams, you can fetch them like:
#   checkpoint.get("config") or checkpoint.get("model_args")
# and pass them into DeBERTaLSTMClassifier(**model_args).
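# For reference, a checkpoint saved like this would round-trip through the
# loader below (the exact hyperparameter names are assumptions for illustration):
#
#   torch.save({
#       "model_state_dict": model.state_dict(),
#       "model_args": {"lstm_hidden": 256, "num_labels": 2},
#   }, "deberta_lstm_checkpoint.pt")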
# --------- Load model/tokenizer once (global) ----------
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=CKPT_NAME)
checkpoint = torch.load(ckpt_path, map_location=device)

# If you saved hyperparams in the checkpoint, use them:
model_args = checkpoint.get("model_args", {})  # e.g., {"lstm_hidden": 256, "num_labels": 2, ...}
model = DeBERTaLSTMClassifier(**model_args)
# Load the state dict; handle older checkpoints that lack the attention layer.
try:
    model.load_state_dict(checkpoint["model_state_dict"])
except RuntimeError as e:
    if "attention" in str(e):
        # Old checkpoint without the attention layer: load everything else and
        # keep the freshly initialized attention weights.
        state_dict = checkpoint["model_state_dict"]
        model_dict = model.state_dict()
        # Filter out attention layer parameters
        filtered_dict = {k: v for k, v in state_dict.items() if "attention" not in k}
        model_dict.update(filtered_dict)
        model.load_state_dict(model_dict)
        print("Loaded model without attention layer, using newly initialized attention weights")
    else:
        raise

model.to(device).eval()
# --------- Load BERT model/tokenizer from the Hugging Face Hub ----------
BERT_MODEL_PATH = "th1enq/bert_checkpoint"
bert_tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_PATH)
bert_model = AutoModelForSequenceClassification.from_pretrained(BERT_MODEL_PATH)
bert_model.to(device).eval()
# --------- Helper functions ----------
def is_url(text):
    """Check whether the text looks like a URL."""
    url_pattern = re.compile(
        r'^https?://'  # http:// or https://
        r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?|'  # domain...
        r'localhost|'  # ...localhost...
        r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})'  # ...or IP
        r'(?::\d+)?'  # optional port
        r'(?:/?|[/?]\S+)$', re.IGNORECASE)
    return url_pattern.match(text) is not None
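# Illustrative examples (hypothetical inputs):
#   is_url("https://www.google.com")        -> True
#   is_url("http://127.0.0.1:8080/login")   -> True
#   is_url("Dear customer, click here...")  -> False  (routed to text analysis)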
def fetch_html_content(url, timeout=10):
    """Fetch HTML content from a URL."""
    try:
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
        }
        # verify=False deliberately skips TLS verification so suspicious sites
        # with broken certificates can still be fetched; requests will emit an
        # InsecureRequestWarning for these.
        response = requests.get(url, headers=headers, timeout=timeout, verify=False)
        response.raise_for_status()
        return response.text, response.status_code
    except requests.exceptions.RequestException as e:
        return None, f"Request error: {str(e)}"
    except Exception as e:
        return None, f"General error: {str(e)}"
def predict_single_text(text, text_type="text"):
    """Predict for a single text input with the DeBERTa + LSTM model."""
    # Tokenize
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=256
    )
    # DeBERTa typically doesn't use token_type_ids
    inputs.pop("token_type_ids", None)
    # Move to device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        try:
            # Try to get predictions together with attention weights
            result = model(**inputs, return_attention=True)
            if isinstance(result, tuple) and len(result) == 3:
                logits, attention_weights, deberta_attentions = result
                has_attention = True
            else:
                logits = result
                has_attention = False
        except TypeError:
            # Fallback for older models without a return_attention parameter
            logits = model(**inputs)
            has_attention = False

    probs = F.softmax(logits, dim=-1).squeeze(0).tolist()
    # Get tokens for visualization
    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'].squeeze(0).tolist())
    return probs, tokens, has_attention, attention_weights if has_attention else None
def combine_predictions(url_probs, html_probs, url_weight=0.3, html_weight=0.7):
    """Combine URL and HTML content predictions with a weighted average."""
    combined_probs = [
        url_weight * url_probs[0] + html_weight * html_probs[0],  # benign
        url_weight * url_probs[1] + html_weight * html_probs[1],  # phishing
    ]
    return combined_probs
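# Worked example with made-up probabilities: a benign-looking URL paired with
# suspicious page content still comes out phishing-leaning overall.
#   combine_predictions([0.9, 0.1], [0.2, 0.8])
#   -> [0.3*0.9 + 0.7*0.2, 0.3*0.1 + 0.7*0.8] == [0.41, 0.59]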
# --------- Inference function ----------
def predict_fn(text: str):
    if not text or not text.strip():
        return {"error": "Please enter a URL or text."}, ""

    # Check whether the input is a URL
    if is_url(text.strip()):
        # Process URL
        url = text.strip()

        # Get a prediction for the URL string itself
        url_probs, url_tokens, url_has_attention, url_attention = predict_single_text(url, "URL")

        # Try to fetch the HTML content
        html_content, status = fetch_html_content(url)

        if html_content:
            # Get a prediction for the HTML content
            html_probs, html_tokens, html_has_attention, html_attention = predict_single_text(html_content, "HTML")
            # Combine predictions
            combined_probs = combine_predictions(url_probs, html_probs)

            # Use the combined probabilities but show analysis for both
            probs = combined_probs
            tokens = url_tokens + ["[SEP]"] + html_tokens[:50]  # limit HTML tokens for display
            has_attention = url_has_attention or html_has_attention
            attention_weights = url_attention if url_has_attention else html_attention
            analysis_type = "Combined URL + HTML Analysis"
            fetch_status = f"✅ Successfully fetched HTML content (Status: {status})"
        else:
            # Fall back to URL-only analysis
            probs = url_probs
            tokens = url_tokens
            has_attention = url_has_attention
            attention_weights = url_attention
            analysis_type = "URL-only Analysis"
            fetch_status = f"⚠️ Could not fetch HTML content: {status}"
    else:
        # Process as regular text
        probs, tokens, has_attention, attention_weights = predict_single_text(text, "text")
        analysis_type = "Text Analysis"
        fetch_status = ""
    # Create the detailed analysis
    predicted_class = "phishing" if probs[1] > probs[0] else "benign"
    confidence = max(probs)
detailed_analysis = f""" | |
<div style="font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; background: #1e1e1e; padding: 20px; border-radius: 15px;"> | |
<div style="background: linear-gradient(135deg, {'#8b0000' if predicted_class == 'phishing' else '#006400'} 0%, {'#dc143c' if predicted_class == 'phishing' else '#228b22'} 100%); padding: 25px; border-radius: 20px; color: white; text-align: center; margin-bottom: 20px; box-shadow: 0 8px 32px rgba(0,0,0,0.5); border: 2px solid {'#ff4444' if predicted_class == 'phishing' else '#44ff44'};"> | |
<h2 style="margin: 0 0 10px 0; font-size: 28px; color: white;">🔍 {analysis_type}</h2> | |
<div style="font-size: 36px; font-weight: bold; margin: 10px 0; color: white;"> | |
{predicted_class.upper()} | |
</div> | |
<div style="font-size: 18px; color: #f0f0f0;"> | |
Confidence: {confidence:.1%} | |
</div> | |
<div style="margin-top: 15px; font-size: 14px; color: #e0e0e0;"> | |
{'This appears to be a phishing attempt!' if predicted_class == 'phishing' else '✅ This appears to be legitimate content.'} | |
</div> | |
</div> | |
""" | |
    if fetch_status:
        detailed_analysis += f"""
<div style="background: #2d2d2d; padding: 15px; border-radius: 10px; margin: 15px 0; border-left: 4px solid #4caf50; color: #e0e0e0;">
<strong>Fetch Status:</strong> {fetch_status}
</div>
"""
    if has_attention and attention_weights is not None:
        attention_scores = attention_weights.squeeze(0).tolist()
        token_analysis = []
        for i, (token, score) in enumerate(zip(tokens, attention_scores)):
            # Lenient filtering: keep every real token above a small threshold
            if token not in ['[CLS]', '[SEP]', '[PAD]', '<s>', '</s>'] and len(token.strip()) > 0 and score > 0.005:
                clean_token = token.replace('▁', '').replace('Ġ', '').strip()  # handle different tokenizer prefixes
                if clean_token:  # only add if the token has content after cleaning
                    token_analysis.append({
                        'token': clean_token,
                        'importance': score,
                        'position': i
                    })
        # Sort by importance
        token_analysis.sort(key=lambda x: x['importance'], reverse=True)
detailed_analysis += f""" | |
## Top important tokens: | |
<div style="background: #2d2d2d; padding: 15px; border-radius: 10px; margin: 15px 0; border-left: 4px solid #4caf50; color: #e0e0e0;"> | |
<strong>Analysis Info:</strong> Found {len(token_analysis)} important tokens out of {len(tokens)} total tokens | |
</div> | |
<div style="font-family: Arial, sans-serif;"> | |
""" | |
        for i, token_info in enumerate(token_analysis[:10]):  # top 10 tokens
            bar_width = int(token_info['importance'] * 100)
            color = "#ff4444" if predicted_class == "phishing" else "#44ff44"
            detailed_analysis += f"""
<div style="margin: 8px 0; display: flex; align-items: center; background: #2d2d2d; padding: 8px; border-radius: 8px; border-left: 4px solid {color};">
<div style="width: 30px; text-align: right; margin-right: 10px; font-weight: bold; color: #ffffff;">
{i+1}.
</div>
<div style="width: 120px; margin-right: 10px; font-weight: bold; color: #e0e0e0; text-align: right;">
{token_info['token']}
</div>
<div style="width: 300px; background-color: #404040; border-radius: 10px; overflow: hidden; margin-right: 10px; border: 1px solid #555;">
<div style="width: {bar_width}%; background-color: {color}; height: 20px; border-radius: 10px; transition: width 0.3s ease;"></div>
</div>
<div style="color: #cccccc; font-size: 12px; font-weight: bold;">
{token_info['importance']:.1%}
</div>
</div>
"""
        detailed_analysis += "</div>\n"
detailed_analysis += f""" | |
## Detailed analysis: | |
<div style="font-family: Arial, sans-serif; background: linear-gradient(135deg, #1a237e 0%, #3949ab 100%); padding: 20px; border-radius: 15px; color: white; margin: 15px 0; border: 2px solid #3f51b5;"> | |
<h3 style="margin: 0 0 15px 0; color: white;">Statistical Overview</h3> | |
<div style="display: grid; grid-template-columns: repeat(2, 1fr); gap: 15px;"> | |
<div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px; border: 1px solid rgba(255,255,255,0.2);"> | |
<div style="font-size: 24px; font-weight: bold; color: white;">{len([t for t in tokens if t not in ['[CLS]', '[SEP]', '[PAD]']])}</div> | |
<div style="font-size: 14px; color: #e0e0e0;">Total tokens</div> | |
</div> | |
<div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px; border: 1px solid rgba(255,255,255,0.2);"> | |
<div style="font-size: 24px; font-weight: bold, color: white;">{len([t for t in token_analysis if t['importance'] > 0.05])}</div> | |
<div style="font-size: 14px, color: #e0e0e0;">High impact tokens (>5%)</div> | |
</div> | |
</div> | |
</div> | |
<div style="font-family: Arial, sans-serif; margin: 15px 0; background: #2d2d2d; padding: 20px; border-radius: 15px; border: 1px solid #555;"> | |
<h3 style="color: #ffffff; margin-bottom: 15px;"> Prediction Confidence</h3> | |
<div style="display: flex; justify-content: space-between; margin-bottom: 10px;"> | |
<span style="font-weight: bold; color: #ff4444;">Phishing</span> | |
<span style="font-weight: bold; color: #44ff44;">Benign</span> | |
</div> | |
<div style="width: 100%; background-color: #404040; border-radius: 25px; overflow: hidden; height: 30px; border: 1px solid #666;"> | |
<div style="width: {probs[1]*100:.1f}%; background: linear-gradient(90deg, #ff4444 0%, #ff6666 100%); height: 100%; display: flex; align-items: center; justify-content: center; color: white; font-weight: bold; font-size: 14px;"> | |
{probs[1]:.1%} | |
</div> | |
</div> | |
<div style="margin-top: 10px; text-align: center; color: #cccccc; font-size: 14px;"> | |
Benign: {probs[0]:.1%} | |
</div> | |
</div> | |
""" | |
    else:
        # Fallback analysis when attention weights are unavailable
        detailed_analysis += f"""
<div style="background: linear-gradient(135deg, #1a237e 0%, #3949ab 100%); padding: 20px; border-radius: 15px; color: white; margin: 15px 0; border: 2px solid #3f51b5;">
<h3 style="margin: 0 0 15px 0; color: white;">Basic Analysis</h3>
<div style="display: grid; grid-template-columns: repeat(3, 1fr); gap: 15px;">
<div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px; text-align: center; border: 1px solid rgba(255,255,255,0.2);">
<div style="font-size: 24px; font-weight: bold; color: white;">{probs[1]:.1%}</div>
<div style="font-size: 14px; color: #e0e0e0;">Phishing</div>
</div>
<div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px; text-align: center; border: 1px solid rgba(255,255,255,0.2);">
<div style="font-size: 24px; font-weight: bold; color: white;">{probs[0]:.1%}</div>
<div style="font-size: 14px; color: #e0e0e0;">Benign</div>
</div>
<div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px; text-align: center; border: 1px solid rgba(255,255,255,0.2);">
<div style="font-size: 24px; font-weight: bold; color: white;">{len([t for t in tokens if t not in ['[CLS]', '[SEP]', '[PAD]']])}</div>
<div style="font-size: 14px; color: #e0e0e0;">Tokens</div>
</div>
</div>
</div>
<div style="background: #2d2d2d; padding: 20px; border-radius: 15px; margin: 15px 0; border: 1px solid #555;">
<h3 style="color: #ffffff; margin: 0 0 15px 0;">🔤 Tokens in text:</h3>
<div style="display: flex; flex-wrap: wrap; gap: 8px;">""" + ''.join([f'<span style="background: #404040; color: #64b5f6; padding: 4px 8px; border-radius: 15px; font-size: 12px; border: 1px solid #666;">{token.replace("▁", "")}</span>' for token in tokens if token not in ['[CLS]', '[SEP]', '[PAD]']]) + f"""</div>
<div style="margin-top: 15px; padding: 10px; background: #3d2914; border-radius: 8px; border-left: 4px solid #ff9800;">
<strong style="color: #ffcc02;">Debug info:</strong> <span style="color: #e0e0e0;">Found {len(tokens)} total tokens, {len([t for t in tokens if t not in ['[CLS]', '[SEP]', '[PAD]']])} content tokens</span>
</div>
</div>
<div style="background: #3d2914; padding: 15px; border-radius: 10px; border-left: 4px solid #ff9800; margin: 15px 0;">
<p style="margin: 0; color: #ffcc02; font-size: 14px;">
<strong>Note:</strong> Detailed attention-weight analysis is not available for the current model.
</p>
</div>
"""
    # Close the outer container opened at the top of detailed_analysis
    detailed_analysis += "</div>"

    # Build the label->prob mapping for the Gradio Label output
    if len(LABELS) == len(probs):
        prediction_result = {LABELS[i]: float(probs[i]) for i in range(len(LABELS))}
    else:
        prediction_result = {f"class_{i}": float(p) for i, p in enumerate(probs)}

    return prediction_result, detailed_analysis
# --------- BERT model functions ----------
def predict_bert_single_text(text, text_type="text"):
    """Predict for a single text input using BERT."""
    # Tokenize
    inputs = bert_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512
    )
    # Move to device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = bert_model(**inputs, output_attentions=True)
        logits = outputs.logits

    probs = F.softmax(logits, dim=-1).squeeze(0).tolist()
    # Get tokens for visualization
    tokens = bert_tokenizer.convert_ids_to_tokens(inputs['input_ids'].squeeze(0).tolist())
    # Derive token-importance scores from the last attention layer, averaged
    # across heads. outputs.attentions[-1] has shape
    # (batch, num_heads, seq_len, seq_len); averaging over dim=1 and taking
    # row 0 gives a (seq_len,) vector of how strongly [CLS] attends to each token.
    attention_weights = None
    has_attention = False
    if hasattr(outputs, 'attentions') and outputs.attentions is not None:
        last_layer_attention = outputs.attentions[-1]  # last layer
        attention_weights = last_layer_attention.mean(dim=1).squeeze(0)  # average across heads
        attention_weights = attention_weights[0]  # [CLS] row: attention to all tokens
        has_attention = True

    return probs, tokens, has_attention, attention_weights
def predict_bert_interface_fn(text: str):
    """Gradio interface function for the BERT model."""
    if not text or not text.strip():
        return {"error": "Please enter a URL or text."}, ""

    # Check whether the input is a URL
    if is_url(text.strip()):
        # Process URL
        url = text.strip()

        # Get a prediction for the URL string itself
        url_probs, url_tokens, url_has_attention, url_attention = predict_bert_single_text(url, "URL")

        # Try to fetch the HTML content
        html_content, status = fetch_html_content(url)

        if html_content:
            # Get a prediction for the HTML content
            html_probs, html_tokens, html_has_attention, html_attention = predict_bert_single_text(html_content, "HTML")
            # Combine predictions
            combined_probs = combine_predictions(url_probs, html_probs)

            # Use the combined probabilities but show analysis for both
            probs = combined_probs
            tokens = url_tokens + ["[SEP]"] + html_tokens[:50]  # limit HTML tokens for display
            has_attention = url_has_attention or html_has_attention
            attention_weights = url_attention if url_has_attention else html_attention
            analysis_type = "Combined URL + HTML BERT Analysis"
            fetch_status = f"✅ Successfully fetched HTML content (Status: {status})"
        else:
            # Fall back to URL-only analysis
            probs = url_probs
            tokens = url_tokens
            has_attention = url_has_attention
            attention_weights = url_attention
            analysis_type = "URL-only BERT Analysis"
            fetch_status = f"⚠️ Could not fetch HTML content: {status}"
    else:
        # Process as regular text
        probs, tokens, has_attention, attention_weights = predict_bert_single_text(text, "text")
        analysis_type = "BERT Text Analysis"
        fetch_status = ""
    # Create the detailed analysis
    predicted_class = "phishing" if probs[1] > probs[0] else "benign"
    confidence = max(probs)

    detailed_analysis = f"""
<div style="font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; background: #1e1e1e; padding: 20px; border-radius: 15px;">
<div style="background: linear-gradient(135deg, {'#8b0000' if predicted_class == 'phishing' else '#006400'} 0%, {'#dc143c' if predicted_class == 'phishing' else '#228b22'} 100%); padding: 25px; border-radius: 20px; color: white; text-align: center; margin-bottom: 20px; box-shadow: 0 8px 32px rgba(0,0,0,0.5); border: 2px solid {'#ff4444' if predicted_class == 'phishing' else '#44ff44'};">
<h2 style="margin: 0 0 10px 0; font-size: 28px; color: white;">🔍 {analysis_type}</h2>
<div style="font-size: 36px; font-weight: bold; margin: 10px 0; color: white;">
{predicted_class.upper()}
</div>
<div style="font-size: 18px; color: #f0f0f0;">
Confidence: {confidence:.1%}
</div>
<div style="margin-top: 15px; font-size: 14px; color: #e0e0e0;">
{'⚠️ This appears to be a phishing attempt!' if predicted_class == 'phishing' else '✅ This appears to be legitimate content.'}
</div>
</div>
"""
    if fetch_status:
        detailed_analysis += f"""
<div style="background: #2d2d2d; padding: 15px; border-radius: 10px; margin: 15px 0; border-left: 4px solid #4caf50; color: #e0e0e0;">
<strong>Fetch Status:</strong> {fetch_status}
</div>
"""
    if has_attention and attention_weights is not None:
        attention_scores = attention_weights.squeeze(0).tolist() if attention_weights.dim() > 1 else attention_weights.tolist()
        token_analysis = []
        for i, (token, score) in enumerate(zip(tokens, attention_scores)):
            # Lenient filtering: keep every real token above a small threshold
            if token not in ['[CLS]', '[SEP]', '[PAD]', '<s>', '</s>'] and len(token.strip()) > 0 and score > 0.005:
                clean_token = token.replace('▁', '').replace('Ġ', '').strip()  # handle different tokenizer prefixes
                if clean_token:  # only add if the token has content after cleaning
                    token_analysis.append({
                        'token': clean_token,
                        'importance': score,
                        'position': i
                    })
        # Sort by importance
        token_analysis.sort(key=lambda x: x['importance'], reverse=True)
detailed_analysis += f""" | |
## Top important tokens: | |
<div style="background: #2d2d2d; padding: 15px; border-radius: 10px; margin: 15px 0; border-left: 4px solid #4caf50; color: #e0e0e0;"> | |
<strong>Analysis Info:</strong> Found {len(token_analysis)} important tokens out of {len(tokens)} total tokens | |
</div> | |
<div style="font-family: Arial, sans-serif;"> | |
""" | |
        for i, token_info in enumerate(token_analysis[:10]):  # top 10 tokens
            bar_width = int(token_info['importance'] * 100)
            color = "#ff4444" if predicted_class == "phishing" else "#44ff44"
            detailed_analysis += f"""
<div style="margin: 8px 0; display: flex; align-items: center; background: #2d2d2d; padding: 8px; border-radius: 8px; border-left: 4px solid {color};">
<div style="width: 30px; text-align: right; margin-right: 10px; font-weight: bold; color: #ffffff;">
{i+1}.
</div>
<div style="width: 120px; margin-right: 10px; font-weight: bold; color: #e0e0e0; text-align: right;">
{token_info['token']}
</div>
<div style="width: 300px; background-color: #404040; border-radius: 10px; overflow: hidden; margin-right: 10px; border: 1px solid #555;">
<div style="width: {bar_width}%; background-color: {color}; height: 20px; border-radius: 10px; transition: width 0.3s ease;"></div>
</div>
<div style="color: #cccccc; font-size: 12px; font-weight: bold;">
{token_info['importance']:.1%}
</div>
</div>
"""
        detailed_analysis += "</div>\n"
detailed_analysis += f""" | |
## Detailed analysis: | |
<div style="font-family: Arial, sans-serif; background: linear-gradient(135deg, #1a237e 0%, #3949ab 100%); padding: 20px; border-radius: 15px; color: white; margin: 15px 0; border: 2px solid #3f51b5;"> | |
<h3 style="margin: 0 0 15px 0; color: white;">Statistical Overview</h3> | |
<div style="display: grid; grid-template-columns: repeat(2, 1fr); gap: 15px;"> | |
<div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px; border: 1px solid rgba(255,255,255,0.2);"> | |
<div style="font-size: 24px; font-weight: bold; color: white;">{len([t for t in tokens if t not in ['[CLS]', '[SEP]', '[PAD]']])}</div> | |
<div style="font-size: 14px; color: #e0e0e0;">Total tokens</div> | |
</div> | |
<div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px; border: 1px solid rgba(255,255,255,0.2);"> | |
<div style="font-size: 24px; font-weight: bold; color: white;">{len([t for t in token_analysis if t['importance'] > 0.05])}</div> | |
<div style="font-size: 14px; color: #e0e0e0;">High impact tokens (>5%)</div> | |
</div> | |
</div> | |
</div> | |
<div style="font-family: Arial, sans-serif; margin: 15px 0; background: #2d2d2d; padding: 20px; border-radius: 15px; border: 1px solid #555;"> | |
<h3 style="color: #ffffff; margin-bottom: 15px;"> Prediction Confidence</h3> | |
<div style="display: flex; justify-content: space-between; margin-bottom: 10px;"> | |
<span style="font-weight: bold; color: #ff4444;">Phishing</span> | |
<span style="font-weight: bold; color: #44ff44;">Benign</span> | |
</div> | |
<div style="width: 100%; background-color: #404040; border-radius: 25px; overflow: hidden; height: 30px; border: 1px solid #666;"> | |
<div style="width: {probs[1]*100:.1f}%; background: linear-gradient(90deg, #ff4444 0%, #ff6666 100%); height: 100%; display: flex; align-items: center; justify-content: center; color: white; font-weight: bold; font-size: 14px;"> | |
{probs[1]:.1%} | |
</div> | |
</div> | |
<div style="margin-top: 10px; text-align: center; color: #cccccc; font-size: 14px;"> | |
Benign: {probs[0]:.1%} | |
</div> | |
</div> | |
""" | |
    else:
        # Fallback analysis when attention weights are unavailable
        detailed_analysis += f"""
<div style="background: linear-gradient(135deg, #1a237e 0%, #3949ab 100%); padding: 20px; border-radius: 15px; color: white; margin: 15px 0; border: 2px solid #3f51b5;">
<h3 style="margin: 0 0 15px 0; color: white;">Basic Analysis</h3>
<div style="display: grid; grid-template-columns: repeat(3, 1fr); gap: 15px;">
<div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px; text-align: center; border: 1px solid rgba(255,255,255,0.2);">
<div style="font-size: 24px; font-weight: bold; color: white;">{probs[1]:.1%}</div>
<div style="font-size: 14px; color: #e0e0e0;">Phishing</div>
</div>
<div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px; text-align: center; border: 1px solid rgba(255,255,255,0.2);">
<div style="font-size: 24px; font-weight: bold; color: white;">{probs[0]:.1%}</div>
<div style="font-size: 14px; color: #e0e0e0;">Benign</div>
</div>
<div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px; text-align: center; border: 1px solid rgba(255,255,255,0.2);">
<div style="font-size: 24px; font-weight: bold; color: white;">{len([t for t in tokens if t not in ['[CLS]', '[SEP]', '[PAD]']])}</div>
<div style="font-size: 14px; color: #e0e0e0;">Tokens</div>
</div>
</div>
</div>
<div style="background: #2d2d2d; padding: 20px; border-radius: 15px; margin: 15px 0; border: 1px solid #555;">
<h3 style="color: #ffffff; margin: 0 0 15px 0;">🔤 Tokens in text:</h3>
<div style="display: flex; flex-wrap: wrap; gap: 8px;">""" + ''.join([f'<span style="background: #404040; color: #64b5f6; padding: 4px 8px; border-radius: 15px; font-size: 12px; border: 1px solid #666;">{token.replace("▁", "")}</span>' for token in tokens if token not in ['[CLS]', '[SEP]', '[PAD]']]) + f"""</div>
<div style="margin-top: 15px; padding: 10px; background: #3d2914; border-radius: 8px; border-left: 4px solid #ff9800;">
<strong style="color: #ffcc02;">Debug info:</strong> <span style="color: #e0e0e0;">Found {len(tokens)} total tokens, {len([t for t in tokens if t not in ['[CLS]', '[SEP]', '[PAD]']])} content tokens</span>
</div>
</div>
<div style="background: #3d2914; padding: 15px; border-radius: 10px; border-left: 4px solid #ff9800; margin: 15px 0;">
<p style="margin: 0; color: #ffcc02; font-size: 14px;">
<strong>Note:</strong> Detailed attention-weight analysis is not available for the current model.
</p>
</div>
"""
detailed_analysis += "</div>" | |
# Build label->prob mapping for Gradio Label output | |
if len(LABELS) == len(probs): | |
prediction_result = {LABELS[i]: float(probs[i]) for i in range(len(LABELS))} | |
else: | |
prediction_result = {f"class_{i}": float(p) for i, p in enumerate(probs)} | |
return prediction_result, detailed_analysis | |
# --------- Gradio UI ----------
# Shared dark-theme CSS and example inputs for both tabs.
CUSTOM_CSS = """
.gradio-container {
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
    background-color: #1e1e1e !important;
    color: #ffffff !important;
}
.dark .gradio-container {
    background-color: #1e1e1e !important;
}
/* Dark theme for all components */
.block {
    background-color: #2d2d2d !important;
    border: 1px solid #444 !important;
}
.gradio-textbox {
    background-color: #3d3d3d !important;
    color: #ffffff !important;
    border: 1px solid #666 !important;
}
.gradio-button {
    background-color: #4a4a4a !important;
    color: #ffffff !important;
    border: 1px solid #666 !important;
}
.gradio-button:hover {
    background-color: #5a5a5a !important;
}
"""

EXAMPLES = [
    ["http://rendmoiunserviceeee.com"],
    ["https://www.google.com"],
    ["Dear customer, your account has been suspended. Click here to verify your identity immediately."],
    ["https://mail-secure-login-verify.example/path?token=suspicious"],
    ["http://paypaI-security-update.net/login"],
    ["Your package has been delivered successfully. Thank you for using our service."],
    ["https://github.com/user/repo"],
]

deberta_interface = gr.Interface(
    fn=predict_fn,
    inputs=gr.Textbox(label="URL or text", placeholder="Example: http://suspicious-site.example or paste any text"),
    outputs=[
        gr.Label(label="Prediction result"),
        gr.Markdown(label="Detailed token analysis")
    ],
    title="Phishing Detector (DeBERTa + LSTM)",
    description="""
Enter a URL or text for analysis.

**Features:**
- **URL Analysis**: For URLs, the system fetches the HTML content and combines URL and content analysis
- **Combined Prediction**: Uses a weighted combination of URL structure and webpage content analysis
- **Visual Analysis**: Shows the phishing/benign probability with visual charts
- **Token Importance**: Displays the tokens that matter most to the classification
- **Detailed Insights**: Comprehensive analysis of each token's impact
- **Dark Theme**: Interface and charts optimized for dark themes

**How it works for URLs:**
1. Analyze the URL structure itself
2. Fetch the webpage HTML content
3. Analyze the webpage content
4. Combine both results into the final prediction (30% URL + 70% content)
""",
    examples=EXAMPLES,
    theme=gr.themes.Soft(),
    css=CUSTOM_CSS,
)
bert_interface = gr.Interface(
    fn=predict_bert_interface_fn,
    inputs=gr.Textbox(label="URL or text", placeholder="Example: http://suspicious-site.example or paste any text"),
    outputs=[
        gr.Label(label="Prediction result"),
        gr.Markdown(label="Detailed token analysis")
    ],
    title="Phishing Detector (BERT)",
    description="""
Enter a URL or text for analysis using the BERT model.

**Features:**
- **URL Analysis**: For URLs, the system fetches the HTML content and combines URL and content analysis
- **Combined Prediction**: Uses a weighted combination of URL structure and webpage content analysis
- **Visual Analysis**: Shows the phishing/benign probability with visual charts
- **Token Importance**: Displays the tokens that matter most to the classification, using attention weights
- **Detailed Insights**: Comprehensive analysis of each token's impact
- **Dark Theme**: Interface and charts optimized for dark themes

**How it works for URLs:**
1. Analyze the URL structure itself
2. Fetch the webpage HTML content
3. Analyze the webpage content
4. Combine both results into the final prediction (30% URL + 70% content)
""",
    examples=EXAMPLES,
    theme=gr.themes.Soft(),
    css=CUSTOM_CSS,
)
demo = gr.TabbedInterface(
    [deberta_interface, bert_interface],
    ["DeBERTa + LSTM", "BERT"]
)

if __name__ == "__main__":
    demo.launch()