ai-assist-sh committed
Commit 9256d25 · verified · Parent: 546fc56

Upload 2 files

Files changed (2):
  1. main.py +319 -299
  2. requirements.txt +8 -6
main.py CHANGED
@@ -1,299 +1,319 @@
- import os, re, time, json, urllib.parse
- import gradio as gr
- import torch
- import torch.nn.functional as F
- import tldextract  # for robust registered-domain parsing
-
- # Quiet + CPU friendly
- os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
-
- # -------- Models / Regex --------
- URL_MODEL_ID = "CrabInHoney/urlbert-tiny-v4-malicious-url-classifier"
- LABEL_MAP = {0: "benign", 1: "defacement", 2: "malware", 3: "phishing"}
- URL_RE = re.compile(r"""(?xi)\b(?:https?://|www\.)[a-z0-9\-._~%]+(?:/[^\s<>"']*)?""")
-
- # Heuristic config
- KEYWORDS = {
-     "login","verify","account","secure","update","bank","wallet",
-     "password","invoice","pay","reset","support","unlock","confirm"
- }
- SUSPICIOUS_TLDS = {
-     "zip","mov","lol","xyz","top","country","link","click","cam",
-     "help","gq","cf","tk","work","rest","monster","quest","live"
- }
-
- # Lazy globals for tokenizer & model
- _tok = None
- _mdl = None
-
- # -------- Utilities --------
- def _extract_urls(text: str):
-     return sorted(set(m.group(0) for m in URL_RE.finditer(text or "")))
-
- def _load_model():
-     global _tok, _mdl
-     if _tok is not None and _mdl is not None:
-         return _tok, _mdl
-     from transformers import AutoTokenizer, AutoModelForSequenceClassification
-     _tok = AutoTokenizer.from_pretrained(URL_MODEL_ID)
-     _mdl = AutoModelForSequenceClassification.from_pretrained(URL_MODEL_ID)
-     _mdl.eval()
-     return _tok, _mdl
-
- def _softmax(logits: torch.Tensor):
-     return F.softmax(logits, dim=-1).tolist()
-
- def _lbl_name(idx: int, id2label: dict):
-     if id2label and idx in id2label:
-         return id2label[idx]
-     return LABEL_MAP.get(idx, str(idx))
-
- def _format_scores_md(scores_sorted):
-     lines = ["| Class | Prob (%) | Logit |", "|---|---:|---:|"]
-     for s in scores_sorted:
-         lines.append(f"| **{s['label']}** | {s['prob']*100:.2f} | {s['logit']:.3f} |")
-     return "\n".join(lines)
-
- def _markdown_results_header(rows):
-     # rows: [ [url, model_label, model_pct, heur, fused, reason_txt], ... ]
-     lines = [
-         "| URL | Model | Model Prob (%) | Heuristic | Fused Risk | Reasons |",
-         "|---|---|---:|---:|---:|---|",
-     ]
-     for u, lbl, pct, h, fused, reasons in rows:
-         lines.append(
-             f"| `{u}` | **{lbl}** | {pct:.2f} | {h:.2f} | {fused:.2f} | {reasons} |"
-         )
-     return "\n".join(lines)
-
- def _forensic_block(url, token_ids, tokens, scores_sorted, cls_vec, elapsed_s, truncated):
-     toks_prev = ", ".join(tokens[:64]) + (" …" if len(tokens) > 64 else "")
-     ids_prev = ", ".join(map(str, token_ids[:64])) + (" …" if len(token_ids) > 64 else "")
-     cls_dim = len(cls_vec)
-     cls_prev = ", ".join(f"{v:.4f}" for v in cls_vec[:16]) + (" …" if cls_dim > 16 else "")
-     l2 = (sum(v*v for v in cls_vec)) ** 0.5
-     md = []
-     md.append(f"### 🔍 Forensics for `{url}`\n")
-     md.append(f"- tokens: **{len(tokens)}** • truncated: **{'yes' if truncated else 'no'}**")
-     md.append(f"- inference time: **{elapsed_s:.2f}s**\n")
-     md.append("**Top-k scores**")
-     md.append(_format_scores_md(scores_sorted))
-     md.append("\n**Token IDs (preview)**")
-     md.append("```txt\n" + ids_prev + "\n```")
-     md.append("**Tokens (preview)**")
-     md.append("```txt\n" + toks_prev + "\n```")
-     md.append("**[CLS] embedding (preview)**")
-     md.append(f"`dim={cls_dim}`, `L2={l2:.4f}`")
-     md.append("```txt\n" + cls_prev + "\n```")
-     return "\n".join(md)
-
- # -------- Heuristics --------
- def _safe_parse(url: str):
-     # add scheme if missing so urlparse sees netloc
-     if not re.match(r"^https?://", url, re.I):
-         url = "http://" + url
-     return urllib.parse.urlparse(url)
-
- def heuristic_features(u: str):
-     feats = {}
-     try:
-         p = _safe_parse(u)
-         feats["scheme_https"] = 1 if p.scheme.lower() == "https" else 0
-         feats["host"] = p.hostname or ""
-         feats["path"] = p.path or "/"
-         feats["query"] = p.query or ""
-         ext = tldextract.extract(feats["host"])  # subdomain, domain, suffix
-         feats["registered_domain"] = f"{ext.domain}.{ext.suffix}" if ext.domain and ext.suffix else feats["host"]
-         feats["subdomain"] = ext.subdomain or ""
-         feats["tld"] = ext.suffix or ""
-         feats["labels"] = feats["host"].count(".") + (1 if feats["host"] else 0)
-         feats["has_at"] = "@" in u
-         feats["has_port"] = bool(p.netloc and ":" in p.netloc.split("@")[-1])
-         feats["has_punycode"] = "xn--" in feats["host"]
-         feats["len_url"] = len(u)
-         feats["hyphen_in_regdom"] = "-" in (ext.domain or "")
-         low_host = feats["host"].lower()
-         low_path = feats["path"].lower()
-         feats["kw_in_path"] = int(any(k in low_path for k in KEYWORDS))
-         feats["kw_in_host"] = int(any(k in low_host for k in KEYWORDS))
-         # keyword appears in subdomain but not in registered brand
-         feats["kw_in_subdomain_only"] = int(
-             feats["kw_in_host"] and (ext.domain and not any(k in ext.domain.lower() for k in KEYWORDS))
-         )
-         feats["suspicious_tld"] = int((feats["tld"].split(".")[-1] or "") in SUSPICIOUS_TLDS)
-         # crude “entropy-like” signal for long alnum query blobs
-         alnum = sum(c.isalnum() for c in feats["query"])
-         feats["query_ratio_alnum"] = (alnum / max(1, len(feats["query"]))) if feats["query"] else 0.0
-         feats["parse_error"] = False
-     except Exception:
-         feats = {"parse_error": True}
-     return feats
-
- def heuristic_score(feats: dict) -> float:
-     """0..1 suspicious score."""
-     if feats.get("parse_error"):
-         return 0.70  # unparsable => suspicious
-     score = 0.0
-     score += 0.25 * feats["kw_in_path"]
-     score += 0.20 * feats["kw_in_subdomain_only"]
-     score += 0.10 * feats["kw_in_host"]
-     score += 0.10 * feats["hyphen_in_regdom"]
-     score += 0.10 * (feats["labels"] >= 4)
-     score += 0.10 * feats["has_punycode"]
-     score += 0.10 * feats["suspicious_tld"]
-     score += 0.05 * feats["has_at"]
-     score += 0.05 * feats["has_port"]
-     score += 0.10 * (feats["len_url"] >= 100)
-     if feats["query"] and len(feats["query"]) >= 40 and feats["query_ratio_alnum"] > 0.9:
-         score += 0.10
-     return max(0.0, min(1.0, score))
-
- def heuristic_reasons(feats: dict) -> str:
-     if feats.get("parse_error"):
-         return "parse error"
-     rs = []
-     if feats.get("kw_in_path"): rs.append("keyword in path")
-     if feats.get("kw_in_subdomain_only"): rs.append("keyword in subdomain")
-     if feats.get("kw_in_host") and not feats.get("kw_in_subdomain_only"): rs.append("keyword in host")
-     if feats.get("hyphen_in_regdom"): rs.append("hyphen in registered domain")
-     if feats.get("labels", 0) >= 4: rs.append("deep subdomain nesting")
-     if feats.get("has_punycode"): rs.append("punycode host")
-     if feats.get("suspicious_tld"): rs.append(f"suspicious TLD: {feats.get('tld')}")
-     if feats.get("has_at"): rs.append("@ in URL")
-     if feats.get("has_port"): rs.append("explicit port")
-     if feats.get("len_url", 0) >= 100: rs.append("very long URL")
-     if feats.get("query") and len(feats.get("query", "")) >= 40 and feats.get("query_ratio_alnum", 0) > 0.9:
-         rs.append("long query blob")
-     return ", ".join(rs) if rs else "no heuristic triggers"
-
- # -------- Core --------
- def analyze(text: str, forensic: bool, show_json: bool):
-     """
-     One output: Markdown with
-       - verdict
-       - table (model, heuristic, fused + reasons)
-       - optional forensic blocks (tokens, logits, [CLS])
-       - optional raw JSON (copy/paste)
-     """
-     text = (text or "").strip()
-     if not text:
-         return "Paste an email body or a URL."
-
-     urls = [text] if (text.lower().startswith(("http://","https://","www.")) and " " not in text) else _extract_urls(text)
-     if not urls:
-         return "No URLs detected in the text."
-
-     tok, mdl = _load_model()
-     id2label_raw = getattr(mdl.config, "id2label", None) or {}
-     id2label = {}
-     for k, v in id2label_raw.items():
-         try:
-             id2label[int(k)] = v
-         except Exception:
-             if isinstance(k, str) and k.startswith("LABEL_"):
-                 idx = int(k.split("_")[-1])
-                 id2label[idx] = v
-
-     header_rows = []
-     forensic_blocks = []
-     export_data = {"model_id": URL_MODEL_ID, "items": []}
-     any_unsafe = False
-
-     for u in urls:
-         # --- Encode & forward for logits / CLS ---
-         max_len = min(512, getattr(mdl.config, "max_position_embeddings", 512) or 512)
-         enc = tok(u, truncation=True, max_length=max_len, return_tensors="pt", return_attention_mask=True)
-         token_ids = enc["input_ids"][0].tolist()
-         tokens = tok.convert_ids_to_tokens(enc["input_ids"][0])
-         truncated = enc["input_ids"].shape[1] >= max_len and len(tokens) >= max_len
-
-         t0 = time.time()
-         with torch.no_grad():
-             out = mdl(**enc, output_hidden_states=True)
-         elapsed = time.time() - t0
-
-         logits = out.logits.squeeze(0)  # (num_labels,)
-         probs = _softmax(logits)  # list[float]
-         hidden_states = out.hidden_states
-         cls_vec = hidden_states[-1][0, 0, :].cpu().tolist()
-
-         per_class = [
-             {"label": _lbl_name(i, id2label), "prob": float(probs[i]), "logit": float(logits[i])}
-             for i in range(len(probs))
-         ]
-         per_class_sorted = sorted(per_class, key=lambda x: x["prob"], reverse=True)
-         top = per_class_sorted[0]
-
-         # --- Heuristics & fusion ---
-         feats = heuristic_features(u)
-         h_score = heuristic_score(feats)
-         mdl_phish_like = sum(s["prob"] for s in per_class_sorted if s["label"].lower() in {"phishing","malware","defacement"})
-         fused = 0.65 * mdl_phish_like + 0.35 * h_score
-         reasons = heuristic_reasons(feats)
-
-         header_rows.append([u, top["label"], top["prob"] * 100.0, h_score, fused, reasons])
-         if fused >= 0.50:
-             any_unsafe = True
-
-         # collect full details for optional JSON dump
-         export_data["items"].append({
-             "url": u,
-             "token_ids": token_ids,
-             "tokens": tokens,
-             "truncated": truncated,
-             "logits": [float(x) for x in logits.cpu().tolist()],
-             "probs": [float(p) for p in probs],
-             "scores_sorted": per_class_sorted,
-             "cls_vector": cls_vec,
-             "cls_dim": len(cls_vec),
-             "elapsed_sec": elapsed,
-             "heuristic": feats,
-             "heuristic_score": h_score,
-             "fused_risk": fused,
-         })
-
-         if forensic:
-             forensic_blocks.append(
-                 _forensic_block(
-                     url=u,
-                     token_ids=token_ids,
-                     tokens=tokens,
-                     scores_sorted=per_class_sorted,
-                     cls_vec=cls_vec,
-                     elapsed_s=elapsed,
-                     truncated=truncated,
-                 )
-             )
-
-     verdict = "🔴 **UNSAFE (links flagged)**" if any_unsafe else "🟢 **SAFE (no fused risk ≥ 0.50)**"
-     body = verdict + "\n\n" + _markdown_results_header(header_rows)
-
-     if forensic and forensic_blocks:
-         body += "\n\n---\n\n" + "\n\n---\n\n".join(forensic_blocks)
-
-     if show_json:
-         pretty = json.dumps(export_data, ensure_ascii=False, indent=2)
-         body += "\n\n---\n\n**Raw forensics JSON (copy & save):**\n"
-         body += "```json\n" + pretty + "\n```"
-
-     return body
-
- # -------- UI --------
- demo = gr.Interface(
-     fn=analyze,
-     inputs=[
-         gr.Textbox(lines=6, label="Email or URL", placeholder="Paste a URL or a full email…"),
-         gr.Checkbox(label="Forensic mode (tokens, logits, [CLS])", value=True),
-         gr.Checkbox(label="Show raw JSON at the end (copy/paste)", value=False),
-     ],
-     outputs=gr.Markdown(label="Results"),
-     title="🛡️ PhishingMail — Model + Heuristics (HF Free CPU)",
-     description=(
-         "We extract links and classify each with a compact malicious-URL model, then fuse with transparent heuristics. "
-         "Table shows Model Prob, Heuristic Score, and Fused Risk with reasons. "
-         "Toggle Forensic mode for tokens/logits/[CLS]."
-     ),
- )
-
- if __name__ == "__main__":
-     demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
+ import os, re, time, json, urllib.parse
+ import gradio as gr
+ import torch
+ import torch.nn.functional as F
+
+ # Optional (robust domain parsing). If not installed, code falls back gracefully.
+ try:
+     import tldextract
+ except Exception:
+     tldextract = None
+
+ # Quiet + CPU friendly
+ os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
+
+ # -------- Models / Regex --------
+ URL_MODEL_ID = "CrabInHoney/urlbert-tiny-v4-malicious-url-classifier"
+ LABEL_MAP = {0: "benign", 1: "defacement", 2: "malware", 3: "phishing"}
+ URL_RE = re.compile(r"""(?xi)\b(?:https?://|www\.)[a-z0-9\-._~%]+(?:/[^\s<>"']*)?""")
+
+ # Heuristic config (covers your cases)
+ KEYWORDS = {
+     "phish","login","verify","account","secure","update","bank","wallet",
+     "password","invoice","pay","reset","support","unlock","confirm"
+ }
+ SUSPICIOUS_TLDS = {
+     "zip","mov","lol","xyz","top","country","link","click","cam","help",
+     "gq","cf","tk","work","rest","monster","quest","live","io","ly"
+ }
+
+ # Lazy globals for tokenizer & model
+ _tok = None
+ _mdl = None
+
+ # -------- Utilities --------
+ def _extract_urls(text: str):
+     return sorted(set(m.group(0) for m in URL_RE.finditer(text or "")))
+
+ def _load_model():
+     global _tok, _mdl
+     if _tok is not None and _mdl is not None:
+         return _tok, _mdl
+     from transformers import AutoTokenizer, AutoModelForSequenceClassification
+     _tok = AutoTokenizer.from_pretrained(URL_MODEL_ID)
+     _mdl = AutoModelForSequenceClassification.from_pretrained(URL_MODEL_ID)
+     _mdl.eval()
+     return _tok, _mdl
+
+ def _softmax(logits: torch.Tensor):
+     return F.softmax(logits, dim=-1).tolist()
+
+ def _lbl_name(idx: int, id2label: dict):
+     if id2label and idx in id2label:
+         return id2label[idx]
+     return LABEL_MAP.get(idx, str(idx))
+
+ def _format_scores_md(scores_sorted):
+     lines = ["| Class | Prob (%) | Logit |", "|---|---:|---:|"]
+     for s in scores_sorted:
+         lines.append(f"| **{s['label']}** | {s['prob']*100:.2f} | {s['logit']:.3f} |")
+     return "\n".join(lines)
+
+ def _markdown_results_header(rows):
+     # rows: [ [url, model_label, model_pct, heur, fused, reasons], ... ]
+     lines = [
+         "| URL | Model | Model Prob (%) | Heuristic | Fused Risk | Reasons |",
+         "|---|---|---:|---:|---:|---|",
+     ]
+     for u, lbl, pct, h, fused, reasons in rows:
+         lines.append(
+             f"| `{u}` | **{lbl}** | {pct:.2f} | {h:.2f} | {fused:.2f} | {reasons} |"
+         )
+     return "\n".join(lines)
+
+ def _forensic_block(url, token_ids, tokens, scores_sorted, cls_vec, elapsed_s, truncated):
+     toks_prev = ", ".join(tokens[:64]) + (" …" if len(tokens) > 64 else "")
+     ids_prev = ", ".join(map(str, token_ids[:64])) + (" …" if len(token_ids) > 64 else "")
+     cls_dim = len(cls_vec)
+     cls_prev = ", ".join(f"{v:.4f}" for v in cls_vec[:16]) + (" …" if cls_dim > 16 else "")
+     l2 = (sum(v*v for v in cls_vec)) ** 0.5
+     md = []
+     md.append(f"### 🔍 Forensics for `{url}`\n")
+     md.append(f"- tokens: **{len(tokens)}** • truncated: **{'yes' if truncated else 'no'}**")
+     md.append(f"- inference time: **{elapsed_s:.2f}s**\n")
+     md.append("**Top-k scores**")
+     md.append(_format_scores_md(scores_sorted))
+     md.append("\n**Token IDs (preview)**")
+     md.append("```txt\n" + ids_prev + "\n```")
+     md.append("**Tokens (preview)**")
+     md.append("```txt\n" + toks_prev + "\n```")
+     md.append("**[CLS] embedding (preview)**")
+     md.append(f"`dim={cls_dim}`, `L2={l2:.4f}`")
+     md.append("```txt\n" + cls_prev + "\n```")
+     return "\n".join(md)
+
+ # -------- Heuristics --------
+ def _safe_parse(url: str):
+     if not re.match(r"^https?://", url, re.I):
+         url = "http://" + url
+     return urllib.parse.urlparse(url)
+
+ def _split_reg_domain(host: str):
+     """Fallback registered-domain guess if tldextract isn't available."""
+     parts = host.split(".")
+     if len(parts) >= 2:
+         return parts[-2] + "." + parts[-1]
+     return host
+
+ def heuristic_features(u: str):
+     feats = {}
+     try:
+         p = _safe_parse(u)
+         feats["scheme_https"] = 1 if p.scheme.lower() == "https" else 0
+         feats["host"] = p.hostname or ""
+         feats["path"] = p.path or "/"
+         feats["query"] = p.query or ""
+         if tldextract:
+             ext = tldextract.extract(feats["host"])  # subdomain, domain, suffix
+             regdom = f"{ext.domain}.{ext.suffix}" if ext.domain and ext.suffix else feats["host"]
+             subdomain = ext.subdomain or ""
+             tld = ext.suffix or ""
+             domain_core = ext.domain or ""
+         else:
+             # lightweight fallback
+             regdom = _split_reg_domain(feats["host"])
+             tld = regdom.split(".")[-1] if "." in regdom else ""
+             subdomain = feats["host"][:-len(regdom)].rstrip(".") if feats["host"].endswith(regdom) else ""
+             domain_core = regdom.split(".")[0] if "." in regdom else regdom
+
+         feats["registered_domain"] = regdom
+         feats["subdomain"] = subdomain
+         feats["tld"] = tld
+         feats["labels"] = feats["host"].count(".") + (1 if feats["host"] else 0)
+         feats["has_at"] = "@" in u
+         feats["has_port"] = bool(p.netloc and ":" in p.netloc.split("@")[-1])
+         feats["has_punycode"] = "xn--" in feats["host"]
+         feats["len_url"] = len(u)
+         feats["hyphen_in_regdom"] = "-" in (domain_core or "")
+         low_host = feats["host"].lower()
+         low_path = feats["path"].lower()
+         feats["kw_in_path"] = int(any(k in low_path for k in KEYWORDS))
+         feats["kw_in_host"] = int(any(k in low_host for k in KEYWORDS))
+         # keyword appears in subdomain but not in registered brand core
+         feats["kw_in_subdomain_only"] = int(
+             feats["kw_in_host"] and (domain_core and not any(k in (domain_core.lower()) for k in KEYWORDS))
+         )
+         feats["suspicious_tld"] = int((feats["tld"].split(".")[-1] or "") in SUSPICIOUS_TLDS)
+         alnum = sum(c.isalnum() for c in feats["query"])
+         feats["query_ratio_alnum"] = (alnum / max(1, len(feats["query"]))) if feats["query"] else 0.0
+         feats["parse_error"] = False
+     except Exception:
+         feats = {"parse_error": True}
+     return feats
+
+ def heuristic_score(feats: dict) -> float:
+     """0..1 suspicious score."""
+     if feats.get("parse_error"):
+         return 0.70  # unparsable => suspicious
+     score = 0.0
+     score += 0.25 * feats["kw_in_path"]
+     score += 0.20 * feats["kw_in_subdomain_only"]
+     score += 0.10 * feats["kw_in_host"]
+     score += 0.10 * feats["hyphen_in_regdom"]
+     score += 0.10 * (feats["labels"] >= 4)
+     score += 0.10 * feats["has_punycode"]
+     score += 0.10 * feats["suspicious_tld"]
+     score += 0.05 * feats["has_at"]
+     score += 0.05 * feats["has_port"]
+     score += 0.10 * (feats["len_url"] >= 100)
+     if feats.get("query") and len(feats.get("query", "")) >= 40 and feats.get("query_ratio_alnum", 0) > 0.9:
+         score += 0.10
+     return max(0.0, min(1.0, score))
+
+ def heuristic_reasons(feats: dict) -> str:
+     if feats.get("parse_error"):
+         return "parse error"
+     rs = []
+     if feats.get("kw_in_path"): rs.append("keyword in path")
+     if feats.get("kw_in_subdomain_only"): rs.append("keyword in subdomain")
+     if feats.get("kw_in_host") and not feats.get("kw_in_subdomain_only"): rs.append("keyword in host")
+     if feats.get("hyphen_in_regdom"): rs.append("hyphen in registered domain")
+     if feats.get("labels", 0) >= 4: rs.append("deep subdomain nesting")
+     if feats.get("has_punycode"): rs.append("punycode host")
+     if feats.get("suspicious_tld"): rs.append(f"suspicious TLD: {feats.get('tld')}")
+     if feats.get("has_at"): rs.append("@ in URL")
+     if feats.get("has_port"): rs.append("explicit port")
+     if feats.get("len_url", 0) >= 100: rs.append("very long URL")
+     if feats.get("query") and len(feats.get("query", "")) >= 40 and feats.get("query_ratio_alnum", 0) > 0.9:
+         rs.append("long query blob")
+     return ", ".join(rs) if rs else "no heuristic triggers"
+
+ # -------- Core --------
+ def analyze(text: str, forensic: bool, show_json: bool):
+     """
+     One Markdown output:
+       - verdict
+       - table (model, heuristic, fused + reasons)
+       - optional forensic blocks (tokens, logits, [CLS])
+       - optional raw JSON (copy/paste)
+     """
+     text = (text or "").strip()
+     if not text:
+         return "Paste an email body or a URL."
+
+     urls = [text] if (text.lower().startswith(("http://","https://","www.")) and " " not in text) else _extract_urls(text)
+     if not urls:
+         return "No URLs detected in the text."
+
+     tok, mdl = _load_model()
+     id2label_raw = getattr(mdl.config, "id2label", None) or {}
+     id2label = {}
+     for k, v in id2label_raw.items():
+         try:
+             id2label[int(k)] = v
+         except Exception:
+             if isinstance(k, str) and k.startswith("LABEL_"):
+                 idx = int(k.split("_")[-1])
+                 id2label[idx] = v
+
+     header_rows = []
+     forensic_blocks = []
+     export_data = {"model_id": URL_MODEL_ID, "items": []}
+     any_unsafe = False
+
+     for u in urls:
+         max_len = min(512, getattr(mdl.config, "max_position_embeddings", 512) or 512)
+         enc = tok(u, truncation=True, max_length=max_len, return_tensors="pt", return_attention_mask=True)
+         token_ids = enc["input_ids"][0].tolist()
+         tokens = tok.convert_ids_to_tokens(enc["input_ids"][0])
+         truncated = enc["input_ids"].shape[1] >= max_len and len(tokens) >= max_len
+
+         t0 = time.time()
+         with torch.no_grad():
+             out = mdl(**enc, output_hidden_states=True)
+         elapsed = time.time() - t0
+
+         logits = out.logits.squeeze(0)
+         probs = _softmax(logits)
+         hidden_states = out.hidden_states
+         cls_vec = hidden_states[-1][0, 0, :].cpu().tolist()
+
+         per_class = [
+             {"label": _lbl_name(i, id2label), "prob": float(probs[i]), "logit": float(logits[i])}
+             for i in range(len(probs))
+         ]
+         per_class_sorted = sorted(per_class, key=lambda x: x["prob"], reverse=True)
+         top = per_class_sorted[0]
+
+         # Heuristics & fusion
+         feats = heuristic_features(u)
+         h_score = heuristic_score(feats)
+         mdl_phish_like = sum(s["prob"] for s in per_class_sorted if s["label"].lower() in {"phishing","malware","defacement"})
+         fused = 0.65 * mdl_phish_like + 0.35 * h_score
+         reasons = heuristic_reasons(feats)
+
+         header_rows.append([u, top["label"], top["prob"] * 100.0, h_score, fused, reasons])
+         if fused >= 0.50:
+             any_unsafe = True
+
+         export_data["items"].append({
+             "url": u,
+             "token_ids": token_ids,
+             "tokens": tokens,
+             "truncated": truncated,
+             "logits": [float(x) for x in logits.cpu().tolist()],
+             "probs": [float(p) for p in probs],
+             "scores_sorted": per_class_sorted,
+             "cls_vector": cls_vec,
+             "cls_dim": len(cls_vec),
+             "elapsed_sec": elapsed,
+             "heuristic": feats,
+             "heuristic_score": h_score,
+             "fused_risk": fused,
+         })
+
+         if forensic:
+             forensic_blocks.append(
+                 _forensic_block(
+                     url=u,
+                     token_ids=token_ids,
+                     tokens=tokens,
+                     scores_sorted=per_class_sorted,
+                     cls_vec=cls_vec,
+                     elapsed_s=elapsed,
+                     truncated=truncated,
+                 )
+             )
+
+     verdict = "🔴 **UNSAFE (links flagged)**" if any_unsafe else "🟢 **SAFE (no fused risk ≥ 0.50)**"
+     body = verdict + "\n\n" + _markdown_results_header(header_rows)
+
+     if forensic and forensic_blocks:
+         body += "\n\n---\n\n" + "\n\n---\n\n".join(forensic_blocks)
+
+     if show_json:
+         pretty = json.dumps(export_data, ensure_ascii=False, indent=2)
+         body += "\n\n---\n\n**Raw forensics JSON (copy & save):**\n"
+         body += "```json\n" + pretty + "\n```"
+
+     return body
+
+ # -------- UI --------
+ demo = gr.Interface(
+     fn=analyze,
+     inputs=[
+         gr.Textbox(lines=8, label="Email or URL", placeholder="Paste a URL or a full email…"),
+         gr.Checkbox(label="Forensic mode (tokens, logits, [CLS])", value=True),
+         gr.Checkbox(label="Show raw JSON at the end (copy/paste)", value=False),
+     ],
+     outputs=gr.Markdown(label="Results"),
+     title="🛡️ PhishingMail — Model + Heuristics (HF Free CPU)",
+     description=(
+         "We extract links and classify each with a compact malicious-URL model, then fuse with transparent heuristics. "
+         "Table shows Model Prob, Heuristic Score, and Fused Risk with reasons. "
+         "Toggle Forensic mode for tokens/logits/[CLS]."
+     ),
+ )
+
+ if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
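
A note on the new optional-import path: the `_split_reg_domain` fallback keeps only the last two host labels, so hosts under multi-label public suffixes (e.g. `co.uk`) are misparsed, which in turn skews the `hyphen_in_regdom` and `kw_in_subdomain_only` features. A minimal sketch of the divergence (not part of the commit; `login.example.co.uk` is a made-up host):

```python
# Sketch: where the naive fallback diverges from tldextract.
def _split_reg_domain(host: str):
    # Same logic as the fallback added in this commit.
    parts = host.split(".")
    if len(parts) >= 2:
        return parts[-2] + "." + parts[-1]
    return host

print(_split_reg_domain("login.example.com"))    # example.com -> correct
print(_split_reg_domain("login.example.co.uk"))  # co.uk -> a public suffix, not a registrable domain

try:
    import tldextract
    ext = tldextract.extract("login.example.co.uk")
    print(f"{ext.domain}.{ext.suffix}")          # example.co.uk -> public-suffix aware
except ImportError:
    pass  # fallback path above is all that's available
```
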
requirements.txt CHANGED
@@ -1,6 +1,8 @@
- gradio==4.44.1
- transformers==4.55.2
- tldextract==5.1.2  # <--- add this line
-
- --extra-index-url https://download.pytorch.org/whl/cpu
- torch==2.4.0+cpu
+ gradio==4.44.1
+ transformers==4.55.2
+
+ # optional but recommended; code falls back if missing
+ tldextract==5.1.2
+
+ --extra-index-url https://download.pytorch.org/whl/cpu
+ torch==2.4.0+cpu
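
For a quick check of the heuristic half without downloading the model, the scoring functions can be exercised directly. A sketch, assuming the dependencies above are installed and the new file is saved as `main.py`; the URL is hypothetical:

```python
# Smoke test for the heuristic pipeline only; importing main builds the
# Gradio Interface but does not launch it or load the transformer model.
from main import heuristic_features, heuristic_reasons, heuristic_score

url = "http://secure-login.example.tk/verify/account?token=abc123"  # hypothetical
feats = heuristic_features(url)
h = heuristic_score(feats)
print(f"heuristic={h:.2f}")   # ~0.65 with this commit's weights: keywords in path + subdomain-only keyword + .tk TLD
print(heuristic_reasons(feats))

# Fusion as implemented in analyze():
#   fused = 0.65 * P(phishing + malware + defacement) + 0.35 * heuristic
# and a link is flagged when fused >= 0.50.
p_malicious = 0.80  # placeholder for the model's malicious probability mass, illustration only
print(f"fused={0.65 * p_malicious + 0.35 * h:.2f}")
```
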