Update main.py
main.py CHANGED
@@ -1,20 +1,32 @@
-import os, re, time, json
+import os, re, time, json, urllib.parse
 import gradio as gr
 import torch
 import torch.nn.functional as F
+import tldextract  # for robust registered-domain parsing
 
-#
+# Quiet + CPU friendly
 os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
 
+# -------- Models / Regex --------
 URL_MODEL_ID = "CrabInHoney/urlbert-tiny-v4-malicious-url-classifier"
 LABEL_MAP = {0: "benign", 1: "defacement", 2: "malware", 3: "phishing"}
-
-URL_RE = re.compile(r"""(?xi)\bhttps?://[^\s<>"']+""")
+URL_RE = re.compile(r"""(?xi)\b(?:https?://|www\.)[a-z0-9\-._~%]+(?:/[^\s<>"']*)?""")
 
+# Heuristic config
+KEYWORDS = {
+    "login","verify","account","secure","update","bank","wallet",
+    "password","invoice","pay","reset","support","unlock","confirm"
+}
+SUSPICIOUS_TLDS = {
+    "zip","mov","lol","xyz","top","country","link","click","cam",
+    "help","gq","cf","tk","work","rest","monster","quest","live"
+}
 
+# Lazy globals for tokenizer & model
 _tok = None
 _mdl = None
 
+# -------- Utilities --------
 def _extract_urls(text: str):
     return sorted(set(m.group(0) for m in URL_RE.finditer(text or "")))
 
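The broadened `URL_RE` now also catches scheme-less `www.` links, which the old `\bhttps?://`-only pattern missed. A standalone sanity check (the sample text is made up for illustration):

```python
import re

# Same pattern as the new URL_RE above.
URL_RE = re.compile(r"""(?xi)\b(?:https?://|www\.)[a-z0-9\-._~%]+(?:/[^\s<>"']*)?""")

sample = "Reset here: https://foo.xyz/reset or via www.example.com/login today"
print(URL_RE.findall(sample))
# ['https://foo.xyz/reset', 'www.example.com/login']
```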
@@ -42,42 +54,125 @@ def _format_scores_md(scores_sorted):
         lines.append(f"| **{s['label']}** | {s['prob']*100:.2f} | {s['logit']:.3f} |")
     return "\n".join(lines)
 
-def …
-…
-…
-…
+def _markdown_results_header(rows):
+    # rows: [ [url, model_label, model_pct, heur, fused, reason_txt], ... ]
+    lines = [
+        "| URL | Model | Model Prob (%) | Heuristic | Fused Risk | Reasons |",
+        "|---|---|---:|---:|---:|---|",
+    ]
+    for u, lbl, pct, h, fused, reasons in rows:
+        lines.append(
+            f"| `{u}` | **{lbl}** | {pct:.2f} | {h:.2f} | {fused:.2f} | {reasons} |"
+        )
     return "\n".join(lines)
 
 def _forensic_block(url, token_ids, tokens, scores_sorted, cls_vec, elapsed_s, truncated):
-…
-…
+    toks_prev = ", ".join(tokens[:64]) + (" …" if len(tokens) > 64 else "")
+    ids_prev = ", ".join(map(str, token_ids[:64])) + (" …" if len(token_ids) > 64 else "")
     cls_dim = len(cls_vec)
-…
+    cls_prev = ", ".join(f"{v:.4f}" for v in cls_vec[:16]) + (" …" if cls_dim > 16 else "")
     l2 = (sum(v*v for v in cls_vec)) ** 0.5
-…
     md = []
-    md.append(f"### 🔍 Forensics for `{url}…
-    md.append("")
+    md.append(f"### 🔍 Forensics for `{url}`\n")
     md.append(f"- tokens: **{len(tokens)}** • truncated: **{'yes' if truncated else 'no'}**")
-    md.append(f"- inference time: **{elapsed_s:.2f}s…
-    md.append("")
+    md.append(f"- inference time: **{elapsed_s:.2f}s**\n")
     md.append("**Top-k scores**")
     md.append(_format_scores_md(scores_sorted))
-    md.append("")
-    md.append("…
-    md.append("```txt\n" + ids_preview + "\n```")
+    md.append("\n**Token IDs (preview)**")
+    md.append("```txt\n" + ids_prev + "\n```")
     md.append("**Tokens (preview)**")
-    md.append("```txt\n" +…
+    md.append("```txt\n" + toks_prev + "\n```")
     md.append("**[CLS] embedding (preview)**")
     md.append(f"`dim={cls_dim}`, `L2={l2:.4f}`")
-    md.append("```txt\n" +…
+    md.append("```txt\n" + cls_prev + "\n```")
     return "\n".join(md)
 
+# -------- Heuristics --------
+def _safe_parse(url: str):
+    # add scheme if missing so urlparse sees netloc
+    if not re.match(r"^https?://", url, re.I):
+        url = "http://" + url
+    return urllib.parse.urlparse(url)
+
+def heuristic_features(u: str):
+    feats = {}
+    try:
+        p = _safe_parse(u)
+        feats["scheme_https"] = 1 if p.scheme.lower() == "https" else 0
+        feats["host"] = p.hostname or ""
+        feats["path"] = p.path or "/"
+        feats["query"] = p.query or ""
+        ext = tldextract.extract(feats["host"])  # subdomain, domain, suffix
+        feats["registered_domain"] = f"{ext.domain}.{ext.suffix}" if ext.domain and ext.suffix else feats["host"]
+        feats["subdomain"] = ext.subdomain or ""
+        feats["tld"] = ext.suffix or ""
+        feats["labels"] = feats["host"].count(".") + (1 if feats["host"] else 0)
+        feats["has_at"] = "@" in u
+        feats["has_port"] = bool(p.netloc and ":" in p.netloc.split("@")[-1])
+        feats["has_punycode"] = "xn--" in feats["host"]
+        feats["len_url"] = len(u)
+        feats["hyphen_in_regdom"] = "-" in (ext.domain or "")
+        low_host = feats["host"].lower()
+        low_path = feats["path"].lower()
+        feats["kw_in_path"] = int(any(k in low_path for k in KEYWORDS))
+        feats["kw_in_host"] = int(any(k in low_host for k in KEYWORDS))
+        # keyword appears in subdomain but not in registered brand
+        feats["kw_in_subdomain_only"] = int(
+            feats["kw_in_host"] and (ext.domain and not any(k in ext.domain.lower() for k in KEYWORDS))
+        )
+        feats["suspicious_tld"] = int((feats["tld"].split(".")[-1] or "") in SUSPICIOUS_TLDS)
+        # crude “entropy-like” signal for long alnum query blobs
+        alnum = sum(c.isalnum() for c in feats["query"])
+        feats["query_ratio_alnum"] = (alnum / max(1, len(feats["query"]))) if feats["query"] else 0.0
+        feats["parse_error"] = False
+    except Exception:
+        feats = {"parse_error": True}
+    return feats
+
+def heuristic_score(feats: dict) -> float:
+    """0..1 suspicious score."""
+    if feats.get("parse_error"):
+        return 0.70  # unparsable => suspicious
+    score = 0.0
+    score += 0.25 * feats["kw_in_path"]
+    score += 0.20 * feats["kw_in_subdomain_only"]
+    score += 0.10 * feats["kw_in_host"]
+    score += 0.10 * feats["hyphen_in_regdom"]
+    score += 0.10 * (feats["labels"] >= 4)
+    score += 0.10 * feats["has_punycode"]
+    score += 0.10 * feats["suspicious_tld"]
+    score += 0.05 * feats["has_at"]
+    score += 0.05 * feats["has_port"]
+    score += 0.10 * (feats["len_url"] >= 100)
+    if feats["query"] and len(feats["query"]) >= 40 and feats["query_ratio_alnum"] > 0.9:
+        score += 0.10
+    return max(0.0, min(1.0, score))
+
+def heuristic_reasons(feats: dict) -> str:
+    if feats.get("parse_error"):
+        return "parse error"
+    rs = []
+    if feats.get("kw_in_path"): rs.append("keyword in path")
+    if feats.get("kw_in_subdomain_only"): rs.append("keyword in subdomain")
+    if feats.get("kw_in_host") and not feats.get("kw_in_subdomain_only"): rs.append("keyword in host")
+    if feats.get("hyphen_in_regdom"): rs.append("hyphen in registered domain")
+    if feats.get("labels", 0) >= 4: rs.append("deep subdomain nesting")
+    if feats.get("has_punycode"): rs.append("punycode host")
+    if feats.get("suspicious_tld"): rs.append(f"suspicious TLD: {feats.get('tld')}")
+    if feats.get("has_at"): rs.append("@ in URL")
+    if feats.get("has_port"): rs.append("explicit port")
+    if feats.get("len_url", 0) >= 100: rs.append("very long URL")
+    if feats.get("query") and len(feats.get("query", "")) >= 40 and feats.get("query_ratio_alnum", 0) > 0.9:
+        rs.append("long query blob")
+    return ", ".join(rs) if rs else "no heuristic triggers"
+
+# -------- Core --------
 def analyze(text: str, forensic: bool, show_json: bool):
     """
-…
-    - verdict
-…
+    One output: Markdown with
+    - verdict
+    - table (model, heuristic, fused + reasons)
+    - optional forensic blocks (tokens, logits, [CLS])
     - optional raw JSON (copy/paste)
     """
     text = (text or "").strip()
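A minimal smoke test of the new heuristic helpers, assuming this `main.py` is importable as `main`; the URL is a made-up example and the expected score follows from the weights in `heuristic_score`:

```python
# Hypothetical usage; "main" is this Space's main.py on PYTHONPATH.
from main import heuristic_features, heuristic_score, heuristic_reasons

url = "http://secure-login.example-bank.tk/verify/account?session=abc123"
feats = heuristic_features(url)
print(heuristic_reasons(feats))
# keyword in path, keyword in host, hyphen in registered domain, suspicious TLD: tk
print(f"{heuristic_score(feats):.2f}")
# 0.55 = 0.25 (path keyword) + 0.10 (host keyword) + 0.10 (hyphen) + 0.10 (TLD)
```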
@@ -99,21 +194,15 @@ def analyze(text: str, forensic: bool, show_json: bool):
             idx = int(k.split("_")[-1])
             id2label[idx] = v
 
-    rows = []
-    unsafe = False
+    header_rows = []
     forensic_blocks = []
     export_data = {"model_id": URL_MODEL_ID, "items": []}
+    any_unsafe = False
 
     for u in urls:
+        # --- Encode & forward for logits / CLS ---
         max_len = min(512, getattr(mdl.config, "max_position_embeddings", 512) or 512)
-        enc = tok(
-            u,
-            truncation=True,
-            max_length=max_len,
-            padding=False,
-            return_tensors="pt",
-            return_attention_mask=True,
-        )
+        enc = tok(u, truncation=True, max_length=max_len, return_tensors="pt", return_attention_mask=True)
         token_ids = enc["input_ids"][0].tolist()
         tokens = tok.convert_ids_to_tokens(enc["input_ids"][0])
         truncated = enc["input_ids"].shape[1] >= max_len and len(tokens) >= max_len
@@ -123,9 +212,9 @@ def analyze(text: str, forensic: bool, show_json: bool):
             out = mdl(**enc, output_hidden_states=True)
         elapsed = time.time() - t0
 
-        logits = out.logits.squeeze(0)
-        probs = _softmax(logits)
-        hidden_states = out.hidden_states
+        logits = out.logits.squeeze(0)  # (num_labels,)
+        probs = _softmax(logits)        # list[float]
+        hidden_states = out.hidden_states
         cls_vec = hidden_states[-1][0, 0, :].cpu().tolist()
 
         per_class = [
@@ -134,10 +223,19 @@ def analyze(text: str, forensic: bool, show_json: bool):
         ]
         per_class_sorted = sorted(per_class, key=lambda x: x["prob"], reverse=True)
         top = per_class_sorted[0]
-        rows.append([u, top["label"], top["prob"] * 100.0])
-        if top["label"].lower() in {"phishing", "malware", "defacement"}:
-            unsafe = True
 
+        # --- Heuristics & fusion ---
+        feats = heuristic_features(u)
+        h_score = heuristic_score(feats)
+        mdl_phish_like = sum(s["prob"] for s in per_class_sorted if s["label"].lower() in {"phishing","malware","defacement"})
+        fused = 0.65 * mdl_phish_like + 0.35 * h_score
+        reasons = heuristic_reasons(feats)
+
+        header_rows.append([u, top["label"], top["prob"] * 100.0, h_score, fused, reasons])
+        if fused >= 0.50:
+            any_unsafe = True
+
+        # collect full details for optional JSON dump
         export_data["items"].append({
             "url": u,
             "token_ids": token_ids,
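Note what the 0.65/0.35 weighting implies for the 0.50 verdict threshold used later: heuristics contribute at most 0.35, so they can never flag a URL on their own, while a model phishing/malware/defacement mass above roughly 0.77 crosses the bar with no heuristic support. A quick worked check (`fused_risk` is a throwaway helper mirroring the fusion line above):

```python
def fused_risk(mdl_phish_like: float, h_score: float) -> float:
    # Same weighting as in analyze() above.
    return 0.65 * mdl_phish_like + 0.35 * h_score

print(fused_risk(0.97, 0.55))  # ≈ 0.82 -> UNSAFE, model and heuristics agree
print(fused_risk(0.10, 1.00))  # ≈ 0.42 -> SAFE, heuristics alone cannot reach 0.50
print(fused_risk(0.80, 0.00))  # ≈ 0.52 -> UNSAFE on model evidence alone
```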
@@ -145,10 +243,13 @@ def analyze(text: str, forensic: bool, show_json: bool):
             "truncated": truncated,
             "logits": [float(x) for x in logits.cpu().tolist()],
             "probs": [float(p) for p in probs],
-            "scores_sorted": per_class_sorted,
+            "scores_sorted": per_class_sorted,
             "cls_vector": cls_vec,
             "cls_dim": len(cls_vec),
             "elapsed_sec": elapsed,
+            "heuristic": feats,
+            "heuristic_score": h_score,
+            "fused_risk": fused,
         })
 
         if forensic:
@@ -164,20 +265,20 @@ def analyze(text: str, forensic: bool, show_json: bool):
                 )
             )
 
-    verdict = "🔴 **UNSAFE (links flagged)**" if …
-    body = verdict + "\n\n" + …
+    verdict = "🔴 **UNSAFE (links flagged)**" if any_unsafe else "🟢 **SAFE (no fused risk ≥ 0.50)**"
+    body = verdict + "\n\n" + _markdown_results_header(header_rows)
 
     if forensic and forensic_blocks:
         body += "\n\n---\n\n" + "\n\n---\n\n".join(forensic_blocks)
 
     if show_json:
-        # raw JSON for copy-paste (no File component needed)
         pretty = json.dumps(export_data, ensure_ascii=False, indent=2)
         body += "\n\n---\n\n**Raw forensics JSON (copy & save):**\n"
         body += "```json\n" + pretty + "\n```"
 
     return body
 
+# -------- UI --------
 demo = gr.Interface(
     fn=analyze,
     inputs=[
@@ -186,10 +287,13 @@ demo = gr.Interface(
         gr.Checkbox(label="Show raw JSON at the end (copy/paste)", value=False),
     ],
     outputs=gr.Markdown(label="Results"),
-    title="🛡️ PhishingMail — …
-    description=…
+    title="🛡️ PhishingMail — Model + Heuristics (HF Free CPU)",
+    description=(
+        "We extract links and classify each with a compact malicious-URL model, then fuse with transparent heuristics. "
+        "Table shows Model Prob, Heuristic Score, and Fused Risk with reasons. "
+        "Toggle Forensic mode for tokens/logits/[CLS]."
+    ),
 )
 
 if __name__ == "__main__":
-    # Safe defaults for HF Spaces (no share=True needed)
     demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
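Once deployed, the single `gr.Interface` endpoint can be exercised remotely with `gradio_client`; the Space id below is a placeholder, not this Space's real name:

```python
from gradio_client import Client

client = Client("user/phishingmail")  # placeholder Space id
result = client.predict(
    "Please verify your account: http://secure-login.example-bank.tk/verify",  # email text
    False,  # forensic mode off
    False,  # raw JSON off
    api_name="/predict",
)
print(result)  # Markdown verdict + results table returned by analyze()
```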