import os, re, time, json, urllib.parse
import gradio as gr
import torch
import torch.nn.functional as F
import tldextract # for robust registered-domain parsing
# Quiet + CPU friendly
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
# -------- Models / Regex --------
URL_MODEL_ID = "CrabInHoney/urlbert-tiny-v4-malicious-url-classifier"
LABEL_MAP = {0: "benign", 1: "defacement", 2: "malware", 3: "phishing"}
URL_RE = re.compile(r"""(?xi)\b(?:https?://|www\.)[a-z0-9\-._~%]+(?:/[^\s<>"']*)?""")
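# Illustrative example: on the text
#   "see www.example.com/login and https://evil.example/path?x=1"
# URL_RE is expected to match both links; bare domains without a scheme
# or a leading "www." are deliberately not matched.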
# Heuristic config
KEYWORDS = {
    "login", "verify", "account", "secure", "update", "bank", "wallet",
    "password", "invoice", "pay", "reset", "support", "unlock", "confirm",
}
SUSPICIOUS_TLDS = {
    "zip", "mov", "lol", "xyz", "top", "country", "link", "click", "cam",
    "help", "gq", "cf", "tk", "work", "rest", "monster", "quest", "live",
}
# Lazy globals for tokenizer & model
_tok = None
_mdl = None
# -------- Utilities --------
def _extract_urls(text: str):
    """Return a sorted, de-duplicated list of URLs found in free text."""
    return sorted(set(m.group(0) for m in URL_RE.finditer(text or "")))

def _load_model():
    """Lazily load the tokenizer/model on first use (keeps Space startup fast)."""
    global _tok, _mdl
    if _tok is not None and _mdl is not None:
        return _tok, _mdl
    from transformers import AutoTokenizer, AutoModelForSequenceClassification
    _tok = AutoTokenizer.from_pretrained(URL_MODEL_ID)
    _mdl = AutoModelForSequenceClassification.from_pretrained(URL_MODEL_ID)
    _mdl.eval()
    return _tok, _mdl
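# Note: the first call downloads the checkpoint from the Hugging Face Hub and
# may take a while on a cold Space; later calls reuse the cached globals.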

def _softmax(logits: torch.Tensor):
    return F.softmax(logits, dim=-1).tolist()

def _lbl_name(idx: int, id2label: dict):
    if id2label and idx in id2label:
        return id2label[idx]
    return LABEL_MAP.get(idx, str(idx))

def _format_scores_md(scores_sorted):
    lines = ["| Class | Prob (%) | Logit |", "|---|---:|---:|"]
    for s in scores_sorted:
        lines.append(f"| **{s['label']}** | {s['prob']*100:.2f} | {s['logit']:.3f} |")
    return "\n".join(lines)
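# Rendered shape (illustrative values):
#   | Class | Prob (%) | Logit |
#   |---|---:|---:|
#   | **phishing** | 93.12 | 4.102 |
#   | **benign** | 4.51 | 1.077 |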

def _markdown_results_header(rows):
    # rows: [ [url, model_label, model_pct, heur, fused, reason_txt], ... ]
    lines = [
        "| URL | Model | Model Prob (%) | Heuristic | Fused Risk | Reasons |",
        "|---|---|---:|---:|---:|---|",
    ]
    for u, lbl, pct, h, fused, reasons in rows:
        lines.append(
            f"| `{u}` | **{lbl}** | {pct:.2f} | {h:.2f} | {fused:.2f} | {reasons} |"
        )
    return "\n".join(lines)

def _forensic_block(url, token_ids, tokens, scores_sorted, cls_vec, elapsed_s, truncated):
    toks_prev = ", ".join(tokens[:64]) + (" …" if len(tokens) > 64 else "")
    ids_prev = ", ".join(map(str, token_ids[:64])) + (" …" if len(token_ids) > 64 else "")
    cls_dim = len(cls_vec)
    cls_prev = ", ".join(f"{v:.4f}" for v in cls_vec[:16]) + (" …" if cls_dim > 16 else "")
    l2 = (sum(v * v for v in cls_vec)) ** 0.5
    md = []
    md.append(f"### 🔍 Forensics for `{url}`\n")
    md.append(f"- tokens: **{len(tokens)}** • truncated: **{'yes' if truncated else 'no'}**")
    md.append(f"- inference time: **{elapsed_s:.2f}s**\n")
    md.append("**Top-k scores**")
    md.append(_format_scores_md(scores_sorted))
    md.append("\n**Token IDs (preview)**")
    md.append("```txt\n" + ids_prev + "\n```")
    md.append("**Tokens (preview)**")
    md.append("```txt\n" + toks_prev + "\n```")
    md.append("**[CLS] embedding (preview)**")
    md.append(f"`dim={cls_dim}`, `L2={l2:.4f}`")
    md.append("```txt\n" + cls_prev + "\n```")
    return "\n".join(md)
# -------- Heuristics --------
def _safe_parse(url: str):
    # add a scheme if missing so urlparse sees a netloc
    if not re.match(r"^https?://", url, re.I):
        url = "http://" + url
    return urllib.parse.urlparse(url)
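# Example: _safe_parse("example.com/login") is parsed as "http://example.com/login",
# so p.hostname == "example.com" and p.path == "/login".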

def heuristic_features(u: str):
    feats = {}
    try:
        p = _safe_parse(u)
        feats["scheme_https"] = 1 if p.scheme.lower() == "https" else 0
        feats["host"] = p.hostname or ""
        feats["path"] = p.path or "/"
        feats["query"] = p.query or ""
        ext = tldextract.extract(feats["host"])  # subdomain, domain, suffix
        feats["registered_domain"] = f"{ext.domain}.{ext.suffix}" if ext.domain and ext.suffix else feats["host"]
        feats["subdomain"] = ext.subdomain or ""
        feats["tld"] = ext.suffix or ""
        feats["labels"] = feats["host"].count(".") + (1 if feats["host"] else 0)
        feats["has_at"] = "@" in u
        feats["has_port"] = bool(p.netloc and ":" in p.netloc.split("@")[-1])
        feats["has_punycode"] = "xn--" in feats["host"]
        feats["len_url"] = len(u)
        feats["hyphen_in_regdom"] = "-" in (ext.domain or "")
        low_host = feats["host"].lower()
        low_path = feats["path"].lower()
        feats["kw_in_path"] = int(any(k in low_path for k in KEYWORDS))
        feats["kw_in_host"] = int(any(k in low_host for k in KEYWORDS))
        # keyword appears in the host but not in the registered (brand) domain;
        # bool() guards against int("") when ext.domain is empty
        feats["kw_in_subdomain_only"] = int(
            bool(feats["kw_in_host"])
            and bool(ext.domain)
            and not any(k in ext.domain.lower() for k in KEYWORDS)
        )
        feats["suspicious_tld"] = int((feats["tld"].split(".")[-1] or "") in SUSPICIOUS_TLDS)
        # crude "entropy-like" signal for long alphanumeric query blobs
        alnum = sum(c.isalnum() for c in feats["query"])
        feats["query_ratio_alnum"] = (alnum / max(1, len(feats["query"]))) if feats["query"] else 0.0
        feats["parse_error"] = False
    except Exception:
        feats = {"parse_error": True}
    return feats
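# Illustrative feature readout, assuming tldextract splits the host into
# subdomain="secure-login", domain="example", suffix="zip":
#   heuristic_features("http://secure-login.example.zip/verify?a=1")
# should give kw_in_host=1, kw_in_subdomain_only=1, kw_in_path=1, suspicious_tld=1.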

def heuristic_score(feats: dict) -> float:
    """0..1 suspicious score."""
    if feats.get("parse_error"):
        return 0.70  # unparsable => suspicious
    score = 0.0
    score += 0.25 * feats["kw_in_path"]
    score += 0.20 * feats["kw_in_subdomain_only"]
    score += 0.10 * feats["kw_in_host"]
    score += 0.10 * feats["hyphen_in_regdom"]
    score += 0.10 * (feats["labels"] >= 4)
    score += 0.10 * feats["has_punycode"]
    score += 0.10 * feats["suspicious_tld"]
    score += 0.05 * feats["has_at"]
    score += 0.05 * feats["has_port"]
    score += 0.10 * (feats["len_url"] >= 100)
    if feats["query"] and len(feats["query"]) >= 40 and feats["query_ratio_alnum"] > 0.9:
        score += 0.10
    return max(0.0, min(1.0, score))
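# Worked example for the URL above: 0.25 (kw_in_path) + 0.20 (kw_in_subdomain_only)
# + 0.10 (kw_in_host) + 0.10 (suspicious_tld) = 0.65, assuming no other triggers fire.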

def heuristic_reasons(feats: dict) -> str:
    if feats.get("parse_error"):
        return "parse error"
    rs = []
    if feats.get("kw_in_path"): rs.append("keyword in path")
    if feats.get("kw_in_subdomain_only"): rs.append("keyword in subdomain")
    if feats.get("kw_in_host") and not feats.get("kw_in_subdomain_only"): rs.append("keyword in host")
    if feats.get("hyphen_in_regdom"): rs.append("hyphen in registered domain")
    if feats.get("labels", 0) >= 4: rs.append("deep subdomain nesting")
    if feats.get("has_punycode"): rs.append("punycode host")
    if feats.get("suspicious_tld"): rs.append(f"suspicious TLD: {feats.get('tld')}")
    if feats.get("has_at"): rs.append("@ in URL")
    if feats.get("has_port"): rs.append("explicit port")
    if feats.get("len_url", 0) >= 100: rs.append("very long URL")
    if feats.get("query") and len(feats.get("query", "")) >= 40 and feats.get("query_ratio_alnum", 0) > 0.9:
        rs.append("long query blob")
    return ", ".join(rs) if rs else "no heuristic triggers"
# -------- Core --------
def analyze(text: str, forensic: bool, show_json: bool):
    """
    One output: Markdown with
    - verdict
    - table (model, heuristic, fused + reasons)
    - optional forensic blocks (tokens, logits, [CLS])
    - optional raw JSON (copy/paste)
    """
    text = (text or "").strip()
    if not text:
        return "Paste an email body or a URL."
    # a single bare URL is classified directly; otherwise scan the text for links
    urls = [text] if (text.lower().startswith(("http://", "https://", "www.")) and " " not in text) else _extract_urls(text)
    if not urls:
        return "No URLs detected in the text."
    tok, mdl = _load_model()
    # normalize config id2label keys ("3" or "LABEL_3") to int indices
    id2label_raw = getattr(mdl.config, "id2label", None) or {}
    id2label = {}
    for k, v in id2label_raw.items():
        try:
            id2label[int(k)] = v
        except Exception:
            if isinstance(k, str) and k.startswith("LABEL_"):
                idx = int(k.split("_")[-1])
                id2label[idx] = v
    header_rows = []
    forensic_blocks = []
    export_data = {"model_id": URL_MODEL_ID, "items": []}
    any_unsafe = False
    for u in urls:
        # --- Encode & forward pass for logits / [CLS] ---
        max_len = min(512, getattr(mdl.config, "max_position_embeddings", 512) or 512)
        enc = tok(u, truncation=True, max_length=max_len, return_tensors="pt", return_attention_mask=True)
        token_ids = enc["input_ids"][0].tolist()
        tokens = tok.convert_ids_to_tokens(token_ids)  # pass a plain list, not a tensor
        truncated = len(token_ids) >= max_len  # hit the max_length budget
        t0 = time.time()
        with torch.no_grad():
            out = mdl(**enc, output_hidden_states=True)
        elapsed = time.time() - t0
        logits = out.logits.squeeze(0)  # (num_labels,)
        probs = _softmax(logits)        # list[float]
        hidden_states = out.hidden_states
        cls_vec = hidden_states[-1][0, 0, :].cpu().tolist()
        per_class = [
            {"label": _lbl_name(i, id2label), "prob": float(probs[i]), "logit": float(logits[i])}
            for i in range(len(probs))
        ]
        per_class_sorted = sorted(per_class, key=lambda x: x["prob"], reverse=True)
        top = per_class_sorted[0]
        # --- Heuristics & fusion ---
        feats = heuristic_features(u)
        h_score = heuristic_score(feats)
        # probability mass the model assigns to any non-benign class
        mdl_phish_like = sum(s["prob"] for s in per_class_sorted if s["label"].lower() in {"phishing", "malware", "defacement"})
        fused = 0.65 * mdl_phish_like + 0.35 * h_score
        reasons = heuristic_reasons(feats)
        header_rows.append([u, top["label"], top["prob"] * 100.0, h_score, fused, reasons])
        if fused >= 0.50:
            any_unsafe = True
        # collect full details for the optional JSON dump
        export_data["items"].append({
            "url": u,
            "token_ids": token_ids,
            "tokens": tokens,
            "truncated": truncated,
            "logits": [float(x) for x in logits.cpu().tolist()],
            "probs": [float(p) for p in probs],
            "scores_sorted": per_class_sorted,
            "cls_vector": cls_vec,
            "cls_dim": len(cls_vec),
            "elapsed_sec": elapsed,
            "heuristic": feats,
            "heuristic_score": h_score,
            "fused_risk": fused,
        })
        if forensic:
            forensic_blocks.append(
                _forensic_block(
                    url=u,
                    token_ids=token_ids,
                    tokens=tokens,
                    scores_sorted=per_class_sorted,
                    cls_vec=cls_vec,
                    elapsed_s=elapsed,
                    truncated=truncated,
                )
            )
    verdict = "🔴 **UNSAFE (links flagged)**" if any_unsafe else "🟢 **SAFE (no fused risk ≥ 0.50)**"
    body = verdict + "\n\n" + _markdown_results_header(header_rows)
    if forensic and forensic_blocks:
        body += "\n\n---\n\n" + "\n\n---\n\n".join(forensic_blocks)
    if show_json:
        pretty = json.dumps(export_data, ensure_ascii=False, indent=2)
        body += "\n\n---\n\n**Raw forensics JSON (copy & save):**\n"
        body += "```json\n" + pretty + "\n```"
    return body
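# Quick smoke test (illustrative; triggers a model download on first run):
#   print(analyze("http://secure-login.example.zip/verify?a=1", forensic=False, show_json=False))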
# -------- UI --------
demo = gr.Interface(
    fn=analyze,
    inputs=[
        gr.Textbox(lines=6, label="Email or URL", placeholder="Paste a URL or a full email…"),
        gr.Checkbox(label="Forensic mode (tokens, logits, [CLS])", value=True),
        gr.Checkbox(label="Show raw JSON at the end (copy/paste)", value=False),
    ],
    outputs=gr.Markdown(label="Results"),
    title="🛡️ PhishingMail — Model + Heuristics (HF Free CPU)",
    description=(
        "Extracts links, classifies each with a compact malicious-URL model, and fuses the result with transparent heuristics. "
        "The table shows Model Prob, Heuristic Score, and Fused Risk with reasons. "
        "Toggle Forensic mode for tokens/logits/[CLS]."
    ),
)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)