"""Gradio demo that classifies a URL or free text as phishing or benign.

Two models are exposed in separate tabs: a DeBERTa + LSTM classifier loaded
from a checkpoint on the Hugging Face Hub, and a BERT sequence classifier.
When the input looks like a URL, the page behind it is fetched and the URL
and page predictions are blended into one score.
"""

import re
import time
from urllib.parse import urlparse

import gradio as gr
import joblib
import requests
import torch
import torch.nn.functional as F
from bs4 import BeautifulSoup
from huggingface_hub import hf_hub_download
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# --- import your architecture ---
# Make sure this file is in the repo (e.g., models/deberta_lstm_classifier.py)
# and update the import path accordingly.
from model import DeBERTaLSTMClassifier  # <-- your class

# --------- Config ----------
REPO_ID = "khoa-done/phishing-detector"  # HF repo that holds the checkpoint
CKPT_NAME = "deberta_lstm_checkpoint.pt"  # the .pt file name
MODEL_NAME = "microsoft/deberta-base"  # base tokenizer/backbone
LABELS = ["benign", "phishing"]  # adjust to your classes

# If your checkpoint contains hyperparams, you can fetch them like:
# checkpoint.get("config") or checkpoint.get("model_args")
# and pass into DeBERTaLSTMClassifier(**model_args)

# --------- Load model/tokenizer once (global) ----------
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=CKPT_NAME)
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code. Only point REPO_ID at a checkpoint you trust.
checkpoint = torch.load(ckpt_path, map_location=device)

# If you saved hyperparams in the checkpoint, use them:
model_args = checkpoint.get("model_args", {})  # e.g., {"lstm_hidden":256, "num_labels":2, ...}
model = DeBERTaLSTMClassifier(**model_args)

# Load state dict and handle missing attention layer for older models
try:
    model.load_state_dict(checkpoint["model_state_dict"])
except RuntimeError as e:
    if "attention" not in str(e):
        raise
    # Old checkpoint without the attention layer: load every tensor that the
    # current model can accept and leave the rest at their initial values.
    # Matching on name and shape (rather than dropping every key containing
    # "attention") avoids discarding backbone parameters whose names happen
    # to contain that word.
    state_dict = checkpoint["model_state_dict"]
    model_dict = model.state_dict()
    compatible = {
        k: v
        for k, v in state_dict.items()
        if k in model_dict and v.shape == model_dict[k].shape
    }
    model_dict.update(compatible)
    model.load_state_dict(model_dict)
    print("Loaded model without attention layer, using newly initialized attention weights")
model.to(device).eval()

# --------- Load BERT model/tokenizer from Hugging Face Hub ----------
BERT_MODEL_PATH = "th1enq/bert_checkpoint"  # Use Hugging Face Hub model
bert_tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_PATH)
bert_model = AutoModelForSequenceClassification.from_pretrained(BERT_MODEL_PATH)
bert_model.to(device).eval()

# --------- Helper functions ----------
# Compiled once at import time instead of on every call.
# The TLD part allows up to 63 letters (the DNS label limit); the previous
# limit of 6 rejected hosts such as "suspicious-site.example".
_URL_PATTERN = re.compile(
    r'^https?://'  # http:// or https://
    r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,63}\.?|'  # domain...
    r'localhost|'  # localhost...
    r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})'  # ...or ip
    r'(?::\d+)?'  # optional port
    r'(?:/?|[/?]\S+)$',
    re.IGNORECASE,
)

# Tokens that carry no content and are excluded from counts and rankings.
_SPECIAL_TOKENS = frozenset({'[CLS]', '[SEP]', '[PAD]'})


def is_url(text):
    """Return True when *text* is an http(s) URL with a host, localhost or IPv4 address."""
    return _URL_PATTERN.match(text) is not None


def fetch_html_content(url, timeout=10):
    """Fetch the body of *url* with a browser-like User-Agent.

    Returns:
        ``(html_text, status_code)`` on success, or ``(None, error_message)``
        when the request fails or returns an error status.
    """
    # NOTE(review): the URL comes straight from the user and is_url accepts
    # localhost and raw IPs, so this performs server-side requests to
    # arbitrary hosts. Certificate checks are also disabled (verify=False),
    # presumably so pages with broken TLS can still be analysed. Confirm both
    # are acceptable for the deployment.
    try:
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
        }
        response = requests.get(url, headers=headers, timeout=timeout, verify=False)
        response.raise_for_status()
        return response.text, response.status_code
    except requests.exceptions.RequestException as e:
        return None, f"Request error: {str(e)}"
    except Exception as e:
        # Best-effort boundary: any other failure is reported, not raised.
        return None, f"General error: {str(e)}"


def predict_single_text(text, text_type="text"):
    """Classify one text with the DeBERTa + LSTM model.

    Args:
        text: The string to classify (truncated to 256 tokens).
        text_type: Unused; kept so callers can label the input kind.

    Returns:
        ``(probs, tokens, has_attention, attention_weights)`` where *probs* is
        the softmax output as a list, *tokens* the tokenizer's tokens for the
        input, and *attention_weights* is ``None`` unless the model returned a
        3-tuple when called with ``return_attention=True``.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=256,
    )
    # DeBERTa typically doesn't use token_type_ids
    inputs.pop("token_type_ids", None)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    has_attention = False
    attention_weights = None
    with torch.no_grad():
        try:
            # Try to get predictions with attention weights
            result = model(**inputs, return_attention=True)
            if isinstance(result, tuple) and len(result) == 3:
                logits, attention_weights, _ = result
                has_attention = True
            else:
                logits = result
        except TypeError:
            # Fallback for older model without return_attention parameter
            logits = model(**inputs)
            has_attention = False
        probs = F.softmax(logits, dim=-1).squeeze(0).tolist()

    # Get tokens for visualization
    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'].squeeze(0).tolist())
    return probs, tokens, has_attention, attention_weights if has_attention else None


def combine_predictions(url_probs, html_probs, url_weight=0.3, html_weight=0.7):
    """Return the per-class weighted sum of the URL and HTML probabilities."""
    return [
        url_weight * url_p + html_weight * html_p
        for url_p, html_p in zip(url_probs, html_probs)
    ]


def _analyze_input(text, predict_one, titles):
    """Run *predict_one* on *text*, fetching and blending page content for URLs.

    Args:
        text: Raw user input (already checked to be non-blank).
        predict_one: Callable with the signature of predict_single_text.
        titles: Mapping with "combined", "url_only" and "text" headings.

    Returns:
        ``(probs, tokens, scored_tokens, has_attention, attention_weights,
        analysis_type, fetch_status)``. *tokens* is the list shown in the
        statistics; *scored_tokens* is the list the attention scores were
        computed from, so the two can be zipped without misalignment.
    """
    candidate = text.strip()
    if not is_url(candidate):
        probs, tokens, has_attention, attention = predict_one(text, "text")
        return probs, tokens, tokens, has_attention, attention, titles["text"], ""

    url_probs, url_tokens, url_has_attention, url_attention = predict_one(candidate, "URL")
    html_content, status = fetch_html_content(candidate)
    if not html_content:
        # Fallback to URL-only analysis
        return (
            url_probs,
            url_tokens,
            url_tokens,
            url_has_attention,
            url_attention,
            titles["url_only"],
            f"⚠️ Could not fetch HTML content: {status}",
        )

    html_probs, html_tokens, html_has_attention, html_attention = predict_one(html_content, "HTML")
    probs = combine_predictions(url_probs, html_probs)
    tokens = url_tokens + ["[SEP]"] + html_tokens[:50]  # Limit HTML tokens for display
    # Keep each attention vector with its own tokens: HTML scores must not be
    # zipped against the URL tokens that lead the display list.
    if url_has_attention:
        scored_tokens, attention = url_tokens, url_attention
    else:
        scored_tokens, attention = html_tokens, html_attention
    return (
        probs,
        tokens,
        scored_tokens,
        url_has_attention or html_has_attention,
        attention,
        titles["combined"],
        f"✅ Successfully fetched HTML content (Status: {status})",
    )


def _rank_tokens(scored_tokens, attention_weights):
    """Return tokens with a score above 0.005, cleaned and sorted by importance."""
    if attention_weights.dim() > 1:
        attention_scores = attention_weights.squeeze(0).tolist()
    else:
        attention_scores = attention_weights.tolist()

    ranked = []
    for position, (token, score) in enumerate(zip(scored_tokens, attention_scores)):
        if token in _SPECIAL_TOKENS or not token.strip():
            continue
        if score > 0.005:
            # Handle different tokenizer prefixes
            clean_token = token.replace('▁', '').replace('Ġ', '').strip()
            if clean_token:
                ranked.append({
                    'token': clean_token,
                    'importance': score,
                    'position': position,
                })
    ranked.sort(key=lambda item: item['importance'], reverse=True)
    return ranked


def _build_report(probs, tokens, scored_tokens, has_attention, attention_weights,
                  analysis_type, fetch_status, closing=""):
    """Render the Markdown analysis shown next to the prediction label."""
    predicted_class = "phishing" if probs[1] > probs[0] else "benign"
    confidence = max(probs)
    if predicted_class == 'phishing':
        verdict = 'This appears to be a phishing attempt!'
    else:
        verdict = '✅ This appears to be legitimate content.'
    content_count = sum(1 for t in tokens if t not in _SPECIAL_TOKENS)

    parts = [
        f"\n\n🔍 {analysis_type}\n\n"
        f"{predicted_class.upper()}\n"
        f"Confidence: {confidence:.1%}\n"
        f"{verdict}\n"
    ]
    if fetch_status:
        parts.append(f"\nFetch Status: {fetch_status}\n")

    if has_attention and attention_weights is not None:
        token_analysis = _rank_tokens(scored_tokens, attention_weights)
        parts.append(
            " ## Top important tokens:\n"
            f"Analysis Info: Found {len(token_analysis)} important tokens "
            f"out of {len(tokens)} total tokens\n"
        )
        for rank, info in enumerate(token_analysis[:10], start=1):  # Top 10 tokens
            parts.append(f"\n{rank}.\n{info['token']}\n{info['importance']:.1%}\n")
        parts.append("\n\n")

        high_impact = sum(1 for info in token_analysis if info['importance'] > 0.05)
        parts.append(
            " ## Detailed analysis:\n\n"
            "Statistical Overview\n\n"
            f"{content_count}\n"
            "Total tokens\n"
            f"{high_impact}\n"
            "High impact tokens (>5%)\n\n"
            "Prediction Confidence\n\n"
            "Phishing Benign\n"
            f"{probs[1]:.1%}\n"
            f"Benign: {probs[0]:.1%}\n"
        )
    else:
        # Fallback analysis without attention weights
        token_text = ''.join(
            t.replace("▁", "") for t in tokens if t not in _SPECIAL_TOKENS
        )
        parts.append(
            "\n\nBasic Analysis\n\n"
            f"{probs[1]:.1%}\n"
            "Phishing\n"
            f"{probs[0]:.1%}\n"
            "Benign\n"
            f"{content_count}\n"
            "Tokens\n\n"
            "🔤 Tokens in text:\n\n"
            f"{token_text}"
            f"\nDebug info: Found {len(tokens)} total tokens, {content_count} content tokens\n\n"
            "Note: Detailed attention weights analysis is not available for the current model.\n\n"
        )

    parts.append(closing)
    return "".join(parts)


def _label_probabilities(probs):
    """Build the label -> probability mapping for the Gradio Label output."""
    if len(LABELS) == len(probs):
        return {label: float(p) for label, p in zip(LABELS, probs)}
    return {f"class_{i}": float(p) for i, p in enumerate(probs)}


_DEBERTA_TITLES = {
    "combined": "Combined URL + HTML Analysis",
    "url_only": "URL-only Analysis",
    "text": "Text Analysis",
}

_BERT_TITLES = {
    "combined": "Combined URL + HTML BERT Analysis",
    "url_only": "URL-only BERT Analysis",
    "text": "BERT Text Analysis",
}


# --------- Inference function ----------
def predict_fn(text: str):
    """Gradio handler for the DeBERTa + LSTM tab.

    Returns:
        ``(label_probabilities, markdown_report)``. Blank input returns an
        ``{"error": ...}`` mapping and an empty report.
    """
    if not text or not text.strip():
        # NOTE(review): the value here is a string, not a probability;
        # confirm the installed gr.Label renders it as intended.
        return {"error": "Please enter a URL or text."}, ""

    analysis = _analyze_input(text, predict_single_text, _DEBERTA_TITLES)
    return _label_probabilities(analysis[0]), _build_report(*analysis)


# --------- BERT Model Functions ----------
def predict_bert_single_text(text, text_type="text"):
    """Classify one text with the BERT model.

    Args:
        text: The string to classify (truncated to 512 tokens).
        text_type: Unused; kept for parity with predict_single_text.

    Returns:
        ``(probs, tokens, has_attention, attention_weights)``. When the model
        returns attentions, *attention_weights* is the first row of the
        head-averaged last-layer attention map, i.e. the first token's
        attention to every token.
    """
    inputs = bert_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = bert_model(**inputs, output_attentions=True)
        probs = F.softmax(outputs.logits, dim=-1).squeeze(0).tolist()

    # Get tokens for visualization
    tokens = bert_tokenizer.convert_ids_to_tokens(inputs['input_ids'].squeeze(0).tolist())

    attention_weights = None
    has_attention = False
    attentions = getattr(outputs, 'attentions', None)
    if attentions is not None:
        # Average the last layer's attention across all heads, then take the
        # first row as per-token importance scores.
        head_mean = attentions[-1].mean(dim=1).squeeze(0)
        attention_weights = head_mean[0]  # [CLS] token attention to all tokens
        has_attention = True

    return probs, tokens, has_attention, attention_weights


def predict_bert_interface_fn(text: str):
    """Gradio handler for the BERT tab; same contract as predict_fn."""
    if not text or not text.strip():
        # NOTE(review): see predict_fn regarding the string value.
        return {"error": "Please enter a URL or text."}, ""

    analysis = _analyze_input(text, predict_bert_single_text, _BERT_TITLES)
    return _label_probabilities(analysis[0]), _build_report(*analysis, closing="\n")


# --------- Gradio UI ----------
_DESCRIPTION_TEMPLATE = """
Enter a URL or text for analysis{model_clause}.

**Features:**
- **URL Analysis**: For URLs, the system will fetch HTML content and combine both URL and content analysis
- **Combined Prediction**: Uses weighted combination of URL structure and webpage content analysis
- **Visual Analysis**: Predict phishing/benign probability with visual charts
- **Token Importance**: Display the most important tokens in classification{importance_clause}
- **Detailed Insights**: Comprehensive analysis of the impact of each token
- **Dark Theme**: Beautiful interface with colorful charts optimized for dark themes

**How it works for URLs:**
1. Analyze the URL structure itself
2. Fetch the webpage HTML content
3. Analyze the webpage content
4. Combine both results for final prediction (30% URL + 70% content)
"""

_EXAMPLES = [
    ["http://rendmoiunserviceeee.com"],
    ["https://www.google.com"],
    ["Dear customer, your account has been suspended. Click here to verify your identity immediately."],
    ["https://mail-secure-login-verify.example/path?token=suspicious"],
    ["http://paypaI-security-update.net/login"],
    ["Your package has been delivered successfully. Thank you for using our service."],
    ["https://github.com/user/repo"],
]

_DARK_CSS = """
.gradio-container {
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
    background-color: #1e1e1e !important;
    color: #ffffff !important;
}
.dark .gradio-container {
    background-color: #1e1e1e !important;
}
/* Dark theme for all components */
.block {
    background-color: #2d2d2d !important;
    border: 1px solid #444 !important;
}
.gradio-textbox {
    background-color: #3d3d3d !important;
    color: #ffffff !important;
    border: 1px solid #666 !important;
}
.gradio-button {
    background-color: #4a4a4a !important;
    color: #ffffff !important;
    border: 1px solid #666 !important;
}
.gradio-button:hover {
    background-color: #5a5a5a !important;
}
"""


def _build_interface(fn, title, description):
    """Create one tab's gr.Interface with the shared inputs, examples and CSS."""
    return gr.Interface(
        fn=fn,
        inputs=gr.Textbox(
            label="URL or text",
            placeholder="Example: http://suspicious-site.example or paste any text",
        ),
        outputs=[
            gr.Label(label="Prediction result"),
            gr.Markdown(label="Detailed token analysis"),
        ],
        title=title,
        description=description,
        examples=[list(example) for example in _EXAMPLES],
        theme=gr.themes.Soft(),
        css=_DARK_CSS,
    )


deberta_interface = _build_interface(
    predict_fn,
    "Phishing Detector (DeBERTa + LSTM)",
    _DESCRIPTION_TEMPLATE.format(model_clause="", importance_clause=""),
)

bert_interface = _build_interface(
    predict_bert_interface_fn,
    "Phishing Detector (BERT)",
    _DESCRIPTION_TEMPLATE.format(
        model_clause=" using the BERT model",
        importance_clause=" using attention weights",
    ),
)

demo = gr.TabbedInterface(
    [deberta_interface, bert_interface],
    ["DeBERTa + LSTM", "BERT"],
)

if __name__ == "__main__":
    demo.launch()