import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from huggingface_hub import hf_hub_download
import gradio as gr
import requests
import re
from urllib.parse import urlparse
from bs4 import BeautifulSoup
import time
import joblib
# --- import your architecture ---
# Make sure this file is in the repo (e.g., models/deberta_lstm_classifier.py)
# and update the import path accordingly.
from model import DeBERTaLSTMClassifier # <-- your class
# --------- Config ----------
# Hugging Face Hub repository that stores the DeBERTa + LSTM checkpoint.
REPO_ID = "khoa-done/phishing-detector"
# Name of the ``.pt`` checkpoint file inside REPO_ID.
CKPT_NAME = "deberta_lstm_checkpoint.pt"
# Base model providing the tokenizer (and the classifier's backbone).
MODEL_NAME = "microsoft/deberta-base"
# Class names, indexed in the same order as the classifier's probabilities
# (index 0 = benign, index 1 = phishing); adjust to your classes.
LABELS = ["benign", "phishing"]
# Hyperparameters: when the checkpoint stores them (for example under
# "model_args" or "config"), pass them on as DeBERTaLSTMClassifier(**model_args).
# --------- Load model/tokenizer once (global) ----------
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=CKPT_NAME)
# NOTE(review): torch.load unpickles the checkpoint. On torch versions whose
# default is weights_only=False this can execute code embedded in the file,
# so only point REPO_ID at a repository you trust.
checkpoint = torch.load(ckpt_path, map_location=device)
# Constructor hyperparameters saved next to the weights, e.g.
# {"lstm_hidden": 256, "num_labels": 2}; an empty dict uses the class defaults.
model_args = checkpoint.get("model_args", {})
model = DeBERTaLSTMClassifier(**model_args)
state_dict = checkpoint["model_state_dict"]
try:
    model.load_state_dict(state_dict)
except RuntimeError as err:
    # Checkpoints saved before the classifier gained its attention layer lack
    # those parameters. Load every checkpoint tensor whose name and shape still
    # match, leaving only the unmatched attention parameters freshly initialised.
    #
    # Matching by name/shape (instead of dropping every checkpoint key that
    # contains "attention") matters: if the classifier wraps a Hugging Face
    # backbone (presumably, given MODEL_NAME), the backbone's own
    # self-attention weights are named "...attention..." too and would
    # otherwise be silently thrown away.
    if "attention" not in str(err):
        raise
    model_state = model.state_dict()
    compatible = {
        key: tensor
        for key, tensor in state_dict.items()
        if key in model_state and model_state[key].shape == tensor.shape
    }
    not_loaded = [key for key in model_state if key not in compatible]
    unused = [key for key in state_dict if key not in compatible]
    # Any mismatch outside an attention module is a genuine incompatibility.
    if any("attention" not in key for key in not_loaded + unused):
        raise
    model.load_state_dict(compatible, strict=False)
    print(
        "Loaded model without attention layer, using newly initialized "
        f"attention weights for {len(not_loaded)} parameter tensor(s)"
    )
model.to(device).eval()
# --------- Load BERT model/tokenizer from Hugging Face Hub ----------
# Hub repository id of the fine-tuned BERT sequence classifier.
BERT_MODEL_PATH = "th1enq/bert_checkpoint"
bert_tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_PATH)
# Place the classifier on the shared device, then switch it to inference mode.
bert_model = AutoModelForSequenceClassification.from_pretrained(BERT_MODEL_PATH).to(device)
bert_model.eval()
# --------- Helper functions ----------
# Anchored http(s) URL matcher, compiled once at import time.
_URL_PATTERN = re.compile(
    r'^https?://'  # http:// or https://
    r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+'  # domain labels...
    r'(?:[A-Z]{2,63}|XN--[A-Z0-9-]{1,59})\.?|'  # ...alphabetic or punycode TLD
    r'localhost|'  # localhost...
    r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})'  # ...or IPv4 address
    r'(?::\d+)?'  # optional port
    r'(?:/?|[/?]\S+)$',  # optional path / query
    re.IGNORECASE,
)


def is_url(text):
    """Return True if *text* looks like an absolute http(s) URL.

    Accepted hosts are DNS names whose TLD is 2-63 letters (e.g. ``.com``,
    ``.example``, ``.photography``) or an IDN punycode TLD (``xn--...``),
    ``localhost``, and dotted IPv4 addresses, optionally followed by a port
    and a path/query. The check is purely syntactic; nothing is resolved.
    """
    return _URL_PATTERN.match(text) is not None
def _blocked_destination(url):
    """Return why *url* must not be fetched server-side, or None if it may be.

    Only http(s) URLs whose host resolves exclusively to public (globally
    routable, non-multicast) addresses are allowed. Hosts that cannot be
    resolved are let through so the HTTP client reports the failure as usual.
    This is best effort: the client resolves the host again when connecting,
    so it does not defend against DNS rebinding.
    """
    import ipaddress
    import socket

    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https"):
        return f"unsupported scheme {parsed.scheme!r}"
    host = parsed.hostname
    if not host:
        return "missing host"
    port = parsed.port or (443 if parsed.scheme == "https" else 80)
    try:
        addr_infos = socket.getaddrinfo(host, port, proto=socket.IPPROTO_TCP)
    except socket.gaierror:
        return None
    for *_, sockaddr in addr_infos:
        # IPv6 link-local results may carry a "%scope" suffix.
        address = ipaddress.ip_address(sockaddr[0].split("%", 1)[0])
        if not address.is_global or address.is_multicast:
            return f"host {host!r} resolves to non-public address {address}"
    return None


def fetch_html_content(url, timeout=10, *, max_bytes=2_000_000, max_redirects=5,
                       allow_private=False):
    """Fetch the body of a user-supplied URL for content analysis.

    The URL comes straight from the user and is fetched by the server, so by
    default every hop (including redirects) is refused when its host resolves
    to a loopback, private, link-local or otherwise non-public address. This
    prevents the app from being used to probe internal services or cloud
    metadata endpoints.

    Args:
        url: http(s) URL to fetch.
        timeout: per-request timeout in seconds, passed to ``requests``.
        max_bytes: maximum number of body bytes read; the rest is discarded.
        max_redirects: maximum number of redirect hops followed.
        allow_private: skip the non-public address check (e.g. local testing).

    Returns:
        ``(text, status_code)`` on success, or ``(None, error_message)``.
    """
    from urllib.parse import urljoin

    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
    }
    current_url = url
    try:
        # Redirects are followed by hand so that each hop is re-checked.
        for _ in range(max_redirects + 1):
            if not allow_private:
                reason = _blocked_destination(current_url)
                if reason is not None:
                    return None, f"Blocked: {reason}"
            # NOTE(review): certificate verification is disabled, presumably on
            # purpose (phishing pages often have invalid certificates and
            # nothing sensitive is sent); urllib3 warns on each such request.
            with requests.get(
                current_url,
                headers=headers,
                timeout=timeout,
                verify=False,
                allow_redirects=False,
                stream=True,
            ) as response:
                if response.is_redirect:
                    current_url = urljoin(current_url, response.headers["location"])
                    continue
                response.raise_for_status()
                chunks, size = [], 0
                for chunk in response.iter_content(chunk_size=65536):
                    chunks.append(chunk)
                    size += len(chunk)
                    if size >= max_bytes:
                        break
                body = b"".join(chunks)[:max_bytes]
                # response.encoding is what requests derives from the
                # Content-Type header (also used by response.text); fall back
                # to UTF-8 when it is unset or unknown.
                encoding = response.encoding or "utf-8"
                try:
                    text = body.decode(encoding, errors="replace")
                except LookupError:
                    text = body.decode("utf-8", errors="replace")
                return text, response.status_code
        return None, f"Request error: more than {max_redirects} redirects"
    except requests.exceptions.RequestException as e:
        return None, f"Request error: {str(e)}"
    except Exception as e:
        return None, f"General error: {str(e)}"
def predict_single_text(text, text_type="text", *, max_length=256):
    """Classify one string with the DeBERTa + LSTM model.

    Args:
        text: the input to classify (a URL, fetched HTML, or free text).
        text_type: caller-supplied tag ("URL", "HTML", "text"); not used by
            the computation, kept for call-site compatibility.
        max_length: tokenizer truncation length; longer inputs are cut.

    Returns:
        ``(probs, tokens, has_attention, attention_weights)``: softmax class
        probabilities as a list, the tokens of the (truncated) input, whether
        attention weights are available, and those weights (the second item
        of the model's 3-tuple output) or None.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=max_length,
    )
    # DeBERTa typically doesn't use token_type_ids
    inputs.pop("token_type_ids", None)
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        try:
            result = model(**inputs, return_attention=True)
        except TypeError as err:
            # Older classifier versions don't accept ``return_attention``.
            # Retry only in that case so TypeErrors raised for other reasons
            # inside forward() are not masked by a second, silent call.
            if "return_attention" not in str(err):
                raise
            result = model(**inputs)

    attention_weights = None
    if isinstance(result, tuple):
        # Layout with attention: (logits, attention_weights, deberta_attentions).
        logits = result[0]
        if len(result) == 3:
            attention_weights = result[1]
    else:
        logits = result

    probs = F.softmax(logits, dim=-1).squeeze(0).tolist()
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"].squeeze(0).tolist())
    return probs, tokens, attention_weights is not None, attention_weights
def combine_predictions(url_probs, html_probs, url_weight=0.3, html_weight=0.7):
    """Blend URL and HTML-content class probabilities with fixed weights.

    Works for any number of classes: entry ``i`` of the result is
    ``url_weight * url_probs[i] + html_weight * html_probs[i]``. With weights
    that sum to 1 (the defaults do), the result is again a probability
    distribution whenever both inputs are.

    Raises:
        ValueError: if the two probability lists differ in length.
    """
    if len(url_probs) != len(html_probs):
        raise ValueError(
            f"probability lists differ in length: {len(url_probs)} vs {len(html_probs)}"
        )
    return [
        url_weight * url_p + html_weight * html_p
        for url_p, html_p in zip(url_probs, html_probs)
    ]
# --------- Inference function ----------
# Tokens never listed in the token reports, in addition to each tokenizer's
# own special tokens. "[SEP]" also marks the URL/HTML boundary in combined mode.
_REPORT_SKIP_TOKENS = frozenset({"[CLS]", "[SEP]", "[PAD]", ""})
# Number of HTML tokens appended after the URL tokens in a combined report.
_HTML_TOKENS_SHOWN = 50
# Tokens whose attention score does not exceed this are not listed as important.
_MIN_TOKEN_SCORE = 0.005
# Tokens scoring above this count as "high impact" (>5%) in the statistics.
_HIGH_IMPACT_SCORE = 0.05
# Number of tokens listed under "Top important tokens".
_TOP_TOKENS_SHOWN = 10

# Report headings per model, keyed by which input path was taken.
_DEBERTA_ANALYSIS_NAMES = {
    "combined": "Combined URL + HTML Analysis",
    "url": "URL-only Analysis",
    "text": "Text Analysis",
}
_BERT_ANALYSIS_NAMES = {
    "combined": "Combined URL + HTML BERT Analysis",
    "url": "URL-only BERT Analysis",
    "text": "BERT Text Analysis",
}


def _clean_token(token):
    """Strip subword markers ('▁' SentencePiece, 'Ġ' byte-level BPE) and spaces."""
    return token.replace("▁", "").replace("Ġ", "").strip()


def _md_code(text):
    """Render *text* as an inline Markdown code span.

    Token text may come from a fetched, untrusted web page; inside a code span
    it is displayed literally instead of being interpreted as Markdown/HTML.
    Backticks are swapped for apostrophes so they cannot close the span.
    """
    return "`" + text.replace("`", "'") + "`"


def _to_scores(weights):
    """Flatten an attention tensor into a list of per-token floats, or None.

    Assumes *weights* is shaped (1, seq) or (seq,) -- TODO confirm for the
    DeBERTa + LSTM classifier's attention output.
    """
    if weights is None:
        return None
    if weights.dim() > 1:
        weights = weights.squeeze(0)
    return weights.tolist()


def _classify_input(text, predict_single, analysis_names):
    """Classify *text*, fetching and blending the page HTML when it is a URL.

    Args:
        text: raw user input (non-empty after stripping).
        predict_single: per-text predictor returning
            ``(probs, tokens, has_attention, attention_weights)``.
        analysis_names: headings keyed by "combined", "url" and "text".

    Returns:
        ``(probs, tokens, scores, analysis_type, fetch_status)`` where *scores*
        is a list of per-token attention scores index-aligned with *tokens*,
        or None when the model produced no attention weights.
    """
    url = text.strip()
    if not is_url(url):
        probs, tokens, _, weights = predict_single(text, "text")
        return probs, tokens, _to_scores(weights), analysis_names["text"], ""

    url_probs, url_tokens, _, url_weights = predict_single(url, "URL")
    html_content, status = fetch_html_content(url)
    if not html_content:
        return (
            url_probs,
            url_tokens,
            _to_scores(url_weights),
            analysis_names["url"],
            f"⚠️ Could not fetch HTML content: {status}",
        )

    html_probs, html_tokens, _, html_weights = predict_single(html_content, "HTML")
    probs = combine_predictions(url_probs, html_probs)
    # Only a prefix of the HTML tokens is displayed, after a "[SEP]" marker.
    tokens = url_tokens + ["[SEP]"] + html_tokens[:_HTML_TOKENS_SHOWN]
    url_scores = _to_scores(url_weights)
    html_scores = _to_scores(html_weights)
    if url_scores is None and html_scores is None:
        scores = None
    else:
        # Keep scores index-aligned with tokens: URL scores, 0 for the marker,
        # then the same HTML prefix. The two parts come from separate forward
        # passes, so their magnitudes are only roughly comparable.
        if url_scores is None:
            url_scores = [0.0] * len(url_tokens)
        if html_scores is None:
            html_scores = [0.0] * len(html_tokens)
        scores = url_scores + [0.0] + html_scores[:_HTML_TOKENS_SHOWN]
    return (
        probs,
        tokens,
        scores,
        analysis_names["combined"],
        f"✅ Successfully fetched HTML content (Status: {status})",
    )


def _render_report(analysis_type, fetch_status, probs, tokens, scores, skip_tokens):
    """Build the Markdown report shown next to the prediction label.

    Args:
        analysis_type: heading describing which input path was used.
        fetch_status: HTML fetch outcome, or "" when nothing was fetched.
        probs: class probabilities; index 0 is benign, index 1 phishing.
        tokens: tokens to report on.
        scores: per-token attention scores aligned with *tokens*, or None.
        skip_tokens: tokens excluded from every listing and count.
    """
    predicted_class = "phishing" if probs[1] > probs[0] else "benign"
    verdict = (
        "This appears to be a phishing attempt!"
        if predicted_class == "phishing"
        else "✅ This appears to be legitimate content."
    )
    lines = [
        f"### 🔍 {analysis_type}",
        f"**{predicted_class.upper()}** — Confidence: {max(probs):.1%}",
        "",
        verdict,
    ]
    if fetch_status:
        lines += ["", f"**Fetch Status:** {fetch_status}"]

    content_tokens = [token for token in tokens if token not in skip_tokens]

    if scores is None:
        # Fallback analysis without attention weights.
        shown = " ".join(
            _md_code(clean) for clean in map(_clean_token, content_tokens) if clean
        )
        lines += [
            "",
            "**Basic Analysis**",
            f"- {len(content_tokens)} Tokens",
            "",
            "🔤 Tokens in text:",
            "",
            shown,
            "",
            f"Debug info: Found {len(tokens)} total tokens, {len(content_tokens)} content tokens",
            "",
            "Note: Detailed attention weights analysis is not available for the current model.",
        ]
        return "\n".join(lines) + "\n"

    ranked = []
    for token, score in zip(tokens, scores):
        # ``not score > x`` (rather than ``score <= x``) also drops NaN scores.
        if token in skip_tokens or not score > _MIN_TOKEN_SCORE:
            continue
        clean = _clean_token(token)
        if clean:
            ranked.append((score, clean))
    # Stable sort: equal scores keep their token order.
    ranked.sort(key=lambda item: item[0], reverse=True)
    high_impact = sum(1 for score, _ in ranked if score > _HIGH_IMPACT_SCORE)

    lines += [
        "",
        "## Top important tokens:",
        f"Analysis Info: Found {len(ranked)} important tokens out of {len(tokens)} total tokens",
        "",
    ]
    lines += [
        f"{rank}. {_md_code(clean)} — {score:.1%}"
        for rank, (score, clean) in enumerate(ranked[:_TOP_TOKENS_SHOWN], start=1)
    ]
    lines += [
        "",
        "## Detailed analysis:",
        "**Statistical Overview**",
        f"- {len(content_tokens)} Total tokens",
        f"- {high_impact} High impact tokens (>5%)",
        "",
        "**Prediction Confidence**",
        f"- Phishing: {probs[1]:.1%}",
        f"- Benign: {probs[0]:.1%}",
    ]
    return "\n".join(lines) + "\n"


def _run_analysis(text, predict_single, tok, analysis_names):
    """Shared Gradio callback body: classify *text* and build both outputs.

    Returns ``(label_value, markdown)`` for a ``gr.Label`` and a ``gr.Markdown``.
    """
    if not text or not text.strip():
        # gr.Label takes a plain string or a {label: confidence} dict with
        # numeric confidences, so the message is returned as a string.
        return "Please enter a URL or text.", ""
    probs, tokens, scores, analysis_type, fetch_status = _classify_input(
        text, predict_single, analysis_names
    )
    skip_tokens = _REPORT_SKIP_TOKENS | set(tok.all_special_tokens)
    report = _render_report(analysis_type, fetch_status, probs, tokens, scores, skip_tokens)
    # Build label->prob mapping for Gradio Label output
    if len(LABELS) == len(probs):
        prediction = {label: float(p) for label, p in zip(LABELS, probs)}
    else:
        prediction = {f"class_{i}": float(p) for i, p in enumerate(probs)}
    return prediction, report


def predict_fn(text: str):
    """Gradio callback for the DeBERTa + LSTM tab.

    Returns a ``{label: probability}`` dict (or a message string for empty
    input) for ``gr.Label`` and a Markdown token-analysis report.
    """
    return _run_analysis(text, predict_single_text, tokenizer, _DEBERTA_ANALYSIS_NAMES)


# --------- BERT Model Functions ----------
def predict_bert_single_text(text, text_type="text", *, max_length=512):
    """Classify one string with the fine-tuned BERT model.

    Args:
        text: the input to classify (a URL, fetched HTML, or free text).
        text_type: caller-supplied tag ("URL", "HTML", "text"); not used by
            the computation, kept for call-site compatibility.
        max_length: tokenizer truncation length; longer inputs are cut.

    Returns:
        ``(probs, tokens, has_attention, attention_weights)`` where
        *attention_weights* is a 1-D tensor giving, in the last layer averaged
        over heads, how much the first ([CLS]) position attends to each
        token; None if the model returned no attentions.
    """
    inputs = bert_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=max_length,
    )
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    with torch.no_grad():
        outputs = bert_model(**inputs, output_attentions=True)
    probs = F.softmax(outputs.logits, dim=-1).squeeze(0).tolist()
    tokens = bert_tokenizer.convert_ids_to_tokens(inputs["input_ids"].squeeze(0).tolist())

    attention_weights = None
    attentions = getattr(outputs, "attentions", None)
    if attentions is not None and len(attentions) > 0:
        # attentions[-1] is the last layer, (batch, heads, seq, seq); average
        # over all heads and drop the batch axis -> (seq, seq).
        last_layer = attentions[-1].mean(dim=1).squeeze(0)
        # Row 0: attention from the [CLS] position to every token, used as a
        # per-token importance proxy.
        attention_weights = last_layer[0]
    return probs, tokens, attention_weights is not None, attention_weights


def predict_bert_interface_fn(text: str):
    """Gradio callback for the BERT tab (same outputs as ``predict_fn``)."""
    return _run_analysis(text, predict_bert_single_text, bert_tokenizer, _BERT_ANALYSIS_NAMES)
# --------- Gradio UI ----------
# Example inputs offered under both tabs.
_EXAMPLES = [
    ["http://rendmoiunserviceeee.com"],
    ["https://www.google.com"],
    ["Dear customer, your account has been suspended. Click here to verify your identity immediately."],
    ["https://mail-secure-login-verify.example/path?token=suspicious"],
    ["http://paypaI-security-update.net/login"],
    ["Your package has been delivered successfully. Thank you for using our service."],
    ["https://github.com/user/repo"],
]

# Dark styling shared by both tabs and the tab container.
_DARK_CSS = """
.gradio-container {
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
    background-color: #1e1e1e !important;
    color: #ffffff !important;
}
.dark .gradio-container {
    background-color: #1e1e1e !important;
}
/* Dark theme for all components */
.block {
    background-color: #2d2d2d !important;
    border: 1px solid #444 !important;
}
.gradio-textbox {
    background-color: #3d3d3d !important;
    color: #ffffff !important;
    border: 1px solid #666 !important;
}
.gradio-button {
    background-color: #4a4a4a !important;
    color: #ffffff !important;
    border: 1px solid #666 !important;
}
.gradio-button:hover {
    background-color: #5a5a5a !important;
}
"""

_DEBERTA_DESCRIPTION = """
Enter a URL or text for analysis.

**Features:**
- **URL Analysis**: For URLs, the system will fetch HTML content and combine both URL and content analysis
- **Combined Prediction**: Uses weighted combination of URL structure and webpage content analysis
- **Visual Analysis**: Predict phishing/benign probability with visual charts
- **Token Importance**: Display the most important tokens in classification
- **Detailed Insights**: Comprehensive analysis of the impact of each token
- **Dark Theme**: Beautiful interface with colorful charts optimized for dark themes

**How it works for URLs:**
1. Analyze the URL structure itself
2. Fetch the webpage HTML content
3. Analyze the webpage content
4. Combine both results for final prediction (30% URL + 70% content)
"""

_BERT_DESCRIPTION = """
Enter a URL or text for analysis using the BERT model.

**Features:**
- **URL Analysis**: For URLs, the system will fetch HTML content and combine both URL and content analysis
- **Combined Prediction**: Uses weighted combination of URL structure and webpage content analysis
- **Visual Analysis**: Predict phishing/benign probability with visual charts
- **Token Importance**: Display the most important tokens in classification using attention weights
- **Detailed Insights**: Comprehensive analysis of the impact of each token
- **Dark Theme**: Beautiful interface with colorful charts optimized for dark themes

**How it works for URLs:**
1. Analyze the URL structure itself
2. Fetch the webpage HTML content
3. Analyze the webpage content
4. Combine both results for final prediction (30% URL + 70% content)
"""


def _build_interface(fn, title, description):
    """Create one tab's Interface with the shared inputs, outputs and styling.

    Components and the examples list are created per call so each tab gets
    its own instances.
    """
    return gr.Interface(
        fn=fn,
        inputs=gr.Textbox(
            label="URL or text",
            placeholder="Example: http://suspicious-site.example or paste any text",
        ),
        outputs=[
            gr.Label(label="Prediction result"),
            gr.Markdown(label="Detailed token analysis"),
        ],
        title=title,
        description=description,
        examples=[list(row) for row in _EXAMPLES],
        theme=gr.themes.Soft(),
        css=_DARK_CSS,
    )


deberta_interface = _build_interface(
    predict_fn, "Phishing Detector (DeBERTa + LSTM)", _DEBERTA_DESCRIPTION
)
bert_interface = _build_interface(
    predict_bert_interface_fn, "Phishing Detector (BERT)", _BERT_DESCRIPTION
)

# The tabs are rendered inside the TabbedInterface's own Blocks, whose theme
# and CSS style the page, so they are passed here as well.
# NOTE(review): per-Interface theme/css do not appear to be merged into the
# container -- confirm for the installed Gradio version.
demo = gr.TabbedInterface(
    [deberta_interface, bert_interface],
    ["DeBERTa + LSTM", "BERT"],
    theme=gr.themes.Soft(),
    css=_DARK_CSS,
)

if __name__ == "__main__":
    demo.launch()