# Hugging Face Space file-view residue (avatar caption: "ai-assist-sh's picture").
# Commit 85540cc "Update main.py" (verified) · page links: raw, history, blame · size: 12.2 kB
import os, re, time, json, urllib.parse
import gradio as gr
import torch
import torch.nn.functional as F
import tldextract # for robust registered-domain parsing
# Quiet + CPU friendly: keep the HF tokenizers library from using its own
# internal parallelism (and from warning about it) on small CPU hosts.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
# -------- Models / Regex --------
URL_MODEL_ID = "CrabInHoney/urlbert-tiny-v4-malicious-url-classifier"
# Index -> class name, used by _lbl_name when the model config does not
# provide a usable name for a class index.
LABEL_MAP = {0: "benign", 1: "defacement", 2: "malware", 3: "phishing"}
# Candidate URLs in free text, in order:
#   1. an http(s):// or www. prefix,
#   2. optional userinfo ending in "@" (the trusted.com@evil.com spoofing trick),
#   3. the host,
#   4. an optional numeric :port,
#   5. an optional path / query / fragment.
# Userinfo and port are part of the match so the real host, the "@ in URL" and
# "explicit port" heuristics, and any path after a port are not silently dropped.
URL_RE = re.compile(
    r"""(?xi)
    \b(?:https?://|www\.)
    (?:[a-z0-9\-._~%:]+@)?
    [a-z0-9\-._~%]+
    (?::\d{1,5})?
    (?:[/?#][^\s<>"']*)?
    """
)
# Heuristic config
# Substrings looked for (case-insensitively) in the URL host and path.
KEYWORDS = {
    "login", "verify", "account", "secure", "update", "bank", "wallet",
    "password", "invoice", "pay", "reset", "support", "unlock", "confirm",
}
# Compared against the last label of the host's public suffix.
SUSPICIOUS_TLDS = {
    "zip", "mov", "lol", "xyz", "top", "country", "link", "click", "cam",
    "help", "gq", "cf", "tk", "work", "rest", "monster", "quest", "live",
}
# Lazy globals for tokenizer & model (filled on first use by _load_model).
_tok = None
_mdl = None
# -------- Utilities --------
def _strip_trailing_punct(url: str) -> str:
    """Trim sentence punctuation and unbalanced closing brackets from the end of *url*."""
    while url:
        last = url[-1]
        if last in ".,;:!?":
            url = url[:-1]
        elif last == ")" and url.count(")") > url.count("("):
            url = url[:-1]
        elif last == "]" and url.count("]") > url.count("["):
            url = url[:-1]
        else:
            break
    return url


def _extract_urls(text: str):
    """Return the unique URL candidates found in *text*, sorted.

    URL_RE's path part also swallows punctuation that merely ends a sentence
    ("visit https://x.com/login."), which would feed the model and heuristics a
    slightly different string and produce near-duplicate rows. Trailing
    ``.,;:!?`` and unbalanced closing ``)``/``]`` are therefore trimmed; a
    candidate that is no longer a complete URL match afterwards is dropped.

    Args:
        text: free text (e.g. an email body); ``None`` is treated as empty.

    Returns:
        Sorted list of distinct URL strings.
    """
    found = set()
    for match in URL_RE.finditer(text or ""):
        candidate = _strip_trailing_punct(match.group(0))
        if URL_RE.fullmatch(candidate):
            found.add(candidate)
    return sorted(found)
def _load_model():
    """Return ``(tokenizer, model)`` for URL_MODEL_ID, loading them on first call.

    The transformers import is deferred until the first call; the loaded
    objects are cached in the module globals ``_tok``/``_mdl`` and the model is
    switched to eval mode once.
    """
    global _tok, _mdl
    already_loaded = _tok is not None and _mdl is not None
    if not already_loaded:
        from transformers import AutoModelForSequenceClassification, AutoTokenizer
        _tok = AutoTokenizer.from_pretrained(URL_MODEL_ID)
        _mdl = AutoModelForSequenceClassification.from_pretrained(URL_MODEL_ID)
        _mdl.eval()
    return _tok, _mdl
def _softmax(logits: torch.Tensor):
    """Softmax over the last dimension, returned as a (possibly nested) Python list."""
    probabilities = F.softmax(logits, -1)
    return probabilities.tolist()
def _lbl_name(idx: int, id2label: dict, fallback=None):
    """Return a human-readable class name for model output index *idx*.

    Args:
        idx: class index into the model's logits.
        id2label: index -> name map from the model config (may be empty or None).
        fallback: index -> name map used when *id2label* has no usable name;
            defaults to the module-level ``LABEL_MAP``.

    Returns:
        The config's name when it is a real name. Otherwise the fallback's name,
        else the config placeholder (if any), else ``str(idx)``.

    Transformers configs carry placeholder names such as ``"LABEL_3"`` when a
    checkpoint does not declare its classes. Those are treated as missing,
    because ``analyze`` sums the model's risk by the names "phishing",
    "malware" and "defacement"; a placeholder would silently zero that term.
    """
    mapping = LABEL_MAP if fallback is None else fallback
    default = str(idx)
    if id2label and idx in id2label:
        name = id2label[idx]
        if not (isinstance(name, str) and re.fullmatch(r"LABEL_\d+", name)):
            return name
        default = name  # keep the placeholder if the fallback cannot name idx either
    return mapping.get(idx, default)
def _format_scores_md(scores_sorted):
    """Render per-class scores as a Markdown table: Class | Prob (%) | Logit.

    Each item is a dict with ``label``, ``prob`` (fraction, shown as percent)
    and ``logit`` keys; rows are emitted in the order given.
    """
    header = "| Class | Prob (%) | Logit |\n|---|---:|---:|"
    body = "".join(
        f"\n| **{row['label']}** | {row['prob']*100:.2f} | {row['logit']:.3f} |"
        for row in scores_sorted
    )
    return header + body
def _markdown_results_header(rows):
    """Render the per-URL summary as a GitHub-flavoured Markdown table.

    Args:
        rows: iterable of ``[url, model_label, model_pct, heuristic, fused, reasons]``;
            the three numbers are floats (``model_pct`` already in percent).

    Returns:
        The table as a single string (header, separator, one line per row).

    URLs come from user-pasted text and may contain ``|`` or backticks. A raw
    ``|`` splits the row into extra columns, and a backtick ends the inline code
    span early. Pipes are backslash-escaped (GFM requires this even inside code
    spans in tables). Backticks in the URL are written in their percent-encoded
    form ``%60``.
    """
    def cell(text) -> str:
        # Escape "|" so it cannot terminate the table cell.
        return str(text).replace("|", "\\|")

    lines = [
        "| URL | Model | Model Prob (%) | Heuristic | Fused Risk | Reasons |",
        "|---|---|---:|---:|---:|---|",
    ]
    for u, lbl, pct, h, fused, reasons in rows:
        url_cell = cell(u).replace("`", "%60")
        lines.append(
            f"| `{url_cell}` | **{cell(lbl)}** | {pct:.2f} | {h:.2f} | {fused:.2f} | {cell(reasons)} |"
        )
    return "\n".join(lines)
def _forensic_block(url, token_ids, tokens, scores_sorted, cls_vec, elapsed_s, truncated):
    """Render the Markdown forensic section for one analysed URL.

    Args:
        url: the URL that was classified.
        token_ids: encoder input ids (list of ints).
        tokens: token strings matching *token_ids*.
        scores_sorted: per-class score dicts as accepted by ``_format_scores_md``.
        cls_vec: list of floats from the model's first-position hidden state.
        elapsed_s: forward-pass time in seconds.
        truncated: whether the encoding hit the length cap.

    Long sequences are previewed: the first 64 ids/tokens and the first 16
    embedding values, followed by " …" when something was cut off.
    """
    def _preview(parts, limit):
        head = ", ".join(parts[:limit])
        return head + " …" if len(parts) > limit else head

    ids_preview = _preview([str(t) for t in token_ids], 64)
    tokens_preview = _preview(tokens, 64)
    cls_preview = _preview([f"{v:.4f}" for v in cls_vec], 16)
    norm = sum(v * v for v in cls_vec) ** 0.5  # Euclidean (L2) norm of the vector
    trunc_flag = "yes" if truncated else "no"
    fence = "```"

    sections = [
        f"### 🔍 Forensics for `{url}`\n",
        f"- tokens: **{len(tokens)}** • truncated: **{trunc_flag}**",
        f"- inference time: **{elapsed_s:.2f}s**\n",
        "**Top-k scores**",
        _format_scores_md(scores_sorted),
        "\n**Token IDs (preview)**",
        f"{fence}txt\n{ids_preview}\n{fence}",
        "**Tokens (preview)**",
        f"{fence}txt\n{tokens_preview}\n{fence}",
        "**[CLS] embedding (preview)**",
        f"`dim={len(cls_vec)}`, `L2={norm:.4f}`",
        f"{fence}txt\n{cls_preview}\n{fence}",
    ]
    return "\n".join(sections)
# -------- Heuristics --------
def _safe_parse(url: str):
    """Parse *url* with urllib, prefixing ``http://`` when it lacks an http(s) scheme.

    Without a scheme (e.g. ``www.example.com/x``), ``urlparse`` would leave
    ``netloc`` empty and put the host into ``path``.
    """
    has_scheme = re.match(r"^https?://", url, re.I) is not None
    target = url if has_scheme else "http://" + url
    return urllib.parse.urlparse(target)
def heuristic_features(u: str):
    """Extract the hand-written URL features consumed by the heuristics.

    Args:
        u: URL string as extracted from the input (scheme optional).

    Returns:
        A dict of features: host/path/query parts, registered domain, subdomain,
        suffix, label count, and the boolean/int signals read by
        ``heuristic_score`` / ``heuristic_reasons``. If anything fails while
        parsing, ``{"parse_error": True}`` is returned instead, which
        ``heuristic_score`` treats as suspicious.
    """
    feats = {}
    try:
        p = _safe_parse(u)
        host = p.hostname or ""
        feats["scheme_https"] = 1 if p.scheme.lower() == "https" else 0
        feats["host"] = host
        feats["path"] = p.path or "/"
        feats["query"] = p.query or ""

        ext = tldextract.extract(host)  # subdomain, domain, suffix
        feats["registered_domain"] = f"{ext.domain}.{ext.suffix}" if ext.domain and ext.suffix else host
        feats["subdomain"] = ext.subdomain or ""
        feats["tld"] = ext.suffix or ""
        feats["labels"] = host.count(".") + (1 if host else 0)  # number of dot-separated labels
        feats["has_at"] = "@" in u
        # Look at the part after any userinfo so "user:pass@host" is not read as a port.
        feats["has_port"] = bool(p.netloc and ":" in p.netloc.split("@")[-1])
        feats["has_punycode"] = "xn--" in host
        feats["len_url"] = len(u)
        feats["hyphen_in_regdom"] = "-" in (ext.domain or "")

        low_host = host.lower()
        low_path = feats["path"].lower()
        low_sub = feats["subdomain"].lower()
        low_dom = (ext.domain or "").lower()
        feats["kw_in_path"] = int(any(k in low_path for k in KEYWORDS))
        feats["kw_in_host"] = int(any(k in low_host for k in KEYWORDS))
        # Keyword in the subdomain while the registered domain itself is clean,
        # e.g. "secure-login.evil.com". Checked on the subdomain directly, so a
        # keyword that only appears in the suffix is not counted here, and an
        # empty registered domain cannot reach int("").
        feats["kw_in_subdomain_only"] = int(
            any(k in low_sub for k in KEYWORDS) and not any(k in low_dom for k in KEYWORDS)
        )
        feats["suspicious_tld"] = int((feats["tld"].split(".")[-1] or "") in SUSPICIOUS_TLDS)

        # Crude "entropy-like" signal for long alphanumeric query blobs.
        query = feats["query"]
        alnum = sum(c.isalnum() for c in query)
        feats["query_ratio_alnum"] = (alnum / max(1, len(query))) if query else 0.0
        feats["parse_error"] = False
    except Exception:
        # Deliberate best-effort: malformed URLs (e.g. urllib's ValueError on a
        # bad IPv6 literal) are reported as a parse error rather than failing
        # the whole analysis.
        feats = {"parse_error": True}
    return feats
def heuristic_score(feats: dict) -> float:
    """Combine heuristic features into a suspiciousness score in [0, 1].

    A feature dict flagged ``parse_error`` gets a fixed 0.70 (unparsable URLs
    are themselves suspicious). Otherwise each triggered signal contributes its
    weight, and the sum is clamped to [0, 1].
    """
    if feats.get("parse_error"):
        return 0.70
    weighted_signals = (
        (0.25, feats["kw_in_path"]),
        (0.20, feats["kw_in_subdomain_only"]),
        (0.10, feats["kw_in_host"]),
        (0.10, feats["hyphen_in_regdom"]),
        (0.10, feats["labels"] >= 4),  # deep subdomain nesting
        (0.10, feats["has_punycode"]),
        (0.10, feats["suspicious_tld"]),
        (0.05, feats["has_at"]),
        (0.05, feats["has_port"]),
        (0.10, feats["len_url"] >= 100),
        (
            0.10,
            bool(feats["query"] and len(feats["query"]) >= 40 and feats["query_ratio_alnum"] > 0.9),
        ),
    )
    total = 0.0
    for weight, signal in weighted_signals:
        # Plain left-to-right += (not sum()) keeps the float result bit-identical.
        total += weight * signal
    return max(0.0, min(1.0, total))
def heuristic_reasons(feats: dict) -> str:
    """Describe which heuristic signals fired, as a comma-separated string.

    Returns "parse error" for an unparsable URL and "no heuristic triggers"
    when nothing fired. Reasons appear in a fixed order.
    """
    if feats.get("parse_error"):
        return "parse error"
    candidates = (
        (feats.get("kw_in_path"), "keyword in path"),
        (feats.get("kw_in_subdomain_only"), "keyword in subdomain"),
        # Only mention the generic host hit when the subdomain reason did not fire.
        (feats.get("kw_in_host") and not feats.get("kw_in_subdomain_only"), "keyword in host"),
        (feats.get("hyphen_in_regdom"), "hyphen in registered domain"),
        (feats.get("labels", 0) >= 4, "deep subdomain nesting"),
        (feats.get("has_punycode"), "punycode host"),
        (feats.get("suspicious_tld"), f"suspicious TLD: {feats.get('tld')}"),
        (feats.get("has_at"), "@ in URL"),
        (feats.get("has_port"), "explicit port"),
        (feats.get("len_url", 0) >= 100, "very long URL"),
        (
            feats.get("query")
            and len(feats.get("query", "")) >= 40
            and feats.get("query_ratio_alnum", 0) > 0.9,
            "long query blob",
        ),
    )
    fired = [text for triggered, text in candidates if triggered]
    return ", ".join(fired) if fired else "no heuristic triggers"
# -------- Core --------
def _normalize_id2label(raw) -> dict:
    """Return the model config's ``id2label`` mapping keyed by ``int``.

    Keys may be ints, numeric strings (``"3"``) or placeholder strings
    (``"LABEL_3"``). Keys matching none of these are skipped instead of raising,
    so one odd config entry cannot break every request.
    """
    normalized = {}
    for key, name in (raw or {}).items():
        if isinstance(key, str) and key.startswith("LABEL_"):
            key = key[len("LABEL_"):]
        try:
            normalized[int(key)] = name
        except (TypeError, ValueError):
            continue
    return normalized


def analyze(text: str, forensic: bool, show_json: bool):
    """Classify every URL in *text* and return a single Markdown report.

    The report contains:
    - a verdict line;
    - a table (model, heuristic, fused risk + reasons);
    - optional forensic blocks (tokens, logits, [CLS]);
    - optional raw JSON (copy/paste).

    Args:
        text: an email body or a single URL.
        forensic: append a forensic section per URL.
        show_json: append the full export as a JSON code block.

    Returns:
        Markdown string. If *text* is empty or has no URLs, a short hint message instead.
    """
    text = (text or "").strip()
    if not text:
        return "Paste an email body or a URL."
    # A single whitespace-free token with a URL prefix is used verbatim. Anything
    # else, including multi-line pastes that contain no spaces, is scanned.
    looks_like_url = text.lower().startswith(("http://", "https://", "www."))
    urls = [text] if (looks_like_url and len(text.split()) == 1) else _extract_urls(text)
    if not urls:
        return "No URLs detected in the text."

    tok, mdl = _load_model()
    id2label = _normalize_id2label(getattr(mdl.config, "id2label", None))
    # Loop-invariant: the model's positional limit, capped at 512 tokens.
    max_len = min(512, getattr(mdl.config, "max_position_embeddings", 512) or 512)
    unsafe_labels = {"phishing", "malware", "defacement"}

    header_rows = []
    forensic_blocks = []
    export_data = {"model_id": URL_MODEL_ID, "items": []}
    any_unsafe = False

    for u in urls:
        # --- Encode & forward for logits / CLS ---
        enc = tok(u, truncation=True, max_length=max_len, return_tensors="pt", return_attention_mask=True)
        token_ids = enc["input_ids"][0].tolist()
        tokens = tok.convert_ids_to_tokens(token_ids)
        # The encoding is clipped at max_len, so reaching the cap is reported as
        # truncated (this also flags a URL that fills the window exactly).
        truncated = len(token_ids) >= max_len

        t0 = time.perf_counter()  # monotonic, high-resolution timer
        with torch.no_grad():
            out = mdl(**enc, output_hidden_states=True)
        elapsed = time.perf_counter() - t0

        logits = out.logits.squeeze(0)  # drop the batch dim of the single input
        probs = _softmax(logits)  # list[float]
        # Last layer's hidden state at position 0 (the [CLS] slot).
        cls_vec = out.hidden_states[-1][0, 0, :].cpu().tolist()
        per_class = [
            {"label": _lbl_name(i, id2label), "prob": float(p), "logit": float(logits[i])}
            for i, p in enumerate(probs)
        ]
        per_class_sorted = sorted(per_class, key=lambda s: s["prob"], reverse=True)
        top = per_class_sorted[0]

        # --- Heuristics & fusion ---
        feats = heuristic_features(u)
        h_score = heuristic_score(feats)
        # Model probability mass on the phishing/malware/defacement classes.
        mdl_phish_like = sum(s["prob"] for s in per_class_sorted if s["label"].lower() in unsafe_labels)
        fused = 0.65 * mdl_phish_like + 0.35 * h_score
        reasons = heuristic_reasons(feats)
        header_rows.append([u, top["label"], top["prob"] * 100.0, h_score, fused, reasons])
        if fused >= 0.50:
            any_unsafe = True

        # Collect full details for the optional JSON dump.
        export_data["items"].append({
            "url": u,
            "token_ids": token_ids,
            "tokens": tokens,
            "truncated": truncated,
            "logits": [float(x) for x in logits.cpu().tolist()],
            "probs": [float(p) for p in probs],
            "scores_sorted": per_class_sorted,
            "cls_vector": cls_vec,
            "cls_dim": len(cls_vec),
            "elapsed_sec": elapsed,
            "heuristic": feats,
            "heuristic_score": h_score,
            "fused_risk": fused,
        })
        if forensic:
            forensic_blocks.append(
                _forensic_block(
                    url=u,
                    token_ids=token_ids,
                    tokens=tokens,
                    scores_sorted=per_class_sorted,
                    cls_vec=cls_vec,
                    elapsed_s=elapsed,
                    truncated=truncated,
                )
            )

    verdict = "🔴 **UNSAFE (links flagged)**" if any_unsafe else "🟢 **SAFE (no fused risk ≥ 0.50)**"
    body = verdict + "\n\n" + _markdown_results_header(header_rows)
    if forensic and forensic_blocks:
        body += "\n\n---\n\n" + "\n\n---\n\n".join(forensic_blocks)
    if show_json:
        pretty = json.dumps(export_data, ensure_ascii=False, indent=2)
        body += "\n\n---\n\n**Raw forensics JSON (copy & save):**\n"
        body += "```json\n" + pretty + "\n```"
    return body
# -------- UI --------
# Single-page Gradio UI: text + two toggles in, one Markdown report out.
demo = gr.Interface(
    analyze,
    [
        gr.Textbox(lines=6, label="Email or URL", placeholder="Paste a URL or a full email…"),
        gr.Checkbox(label="Forensic mode (tokens, logits, [CLS])", value=True),
        gr.Checkbox(label="Show raw JSON at the end (copy/paste)", value=False),
    ],
    gr.Markdown(label="Results"),
    title="🛡️ PhishingMail — Model + Heuristics (HF Free CPU)",
    description=(
        "We extract links and classify each with a compact malicious-URL model, "
        "then fuse with transparent heuristics. Table shows Model Prob, Heuristic "
        "Score, and Fused Risk with reasons. Toggle Forensic mode for tokens/logits/[CLS]."
    ),
)
if __name__ == "__main__":
    # Listen on all IPv4 interfaces, on Gradio's usual port 7860.
    demo.launch(show_error=True, server_name="0.0.0.0", server_port=7860)