ai-assist-sh committed on
Commit
85540cc
·
verified ·
1 Parent(s): d05f393

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +151 -47
main.py CHANGED
@@ -1,20 +1,32 @@
1
- import os, re, time, json
2
  import gradio as gr
3
  import torch
4
  import torch.nn.functional as F
 
5
 
6
- # Be quiet + CPU friendly
7
  os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
8
 
 
9
  URL_MODEL_ID = "CrabInHoney/urlbert-tiny-v4-malicious-url-classifier"
10
  LABEL_MAP = {0: "benign", 1: "defacement", 2: "malware", 3: "phishing"}
11
- #URL_RE = re.compile(r"""(?xi)\b(?:https?://|www\.)[a-z0-9\-._~%]+(?:/[^\s<>"']*)?""")
12
- URL_RE = re.compile(r"""(?xi)\bhttps?://[^\s<>"']+""")
13
 
 
 
 
 
 
 
 
 
 
14
 
 
15
  _tok = None
16
  _mdl = None
17
 
 
18
  def _extract_urls(text: str):
19
  return sorted(set(m.group(0) for m in URL_RE.finditer(text or "")))
20
 
@@ -42,42 +54,125 @@ def _format_scores_md(scores_sorted):
42
  lines.append(f"| **{s['label']}** | {s['prob']*100:.2f} | {s['logit']:.3f} |")
43
  return "\n".join(lines)
44
 
45
- def _markdown_table(rows):
46
- lines = ["| URL | Prediction | Confidence (%) |", "|---|---|---|"]
47
- for u, lbl, conf in rows:
48
- lines.append(f"| `{u}` | **{lbl}** | {conf:.2f} |")
 
 
 
 
 
 
49
  return "\n".join(lines)
50
 
51
  def _forensic_block(url, token_ids, tokens, scores_sorted, cls_vec, elapsed_s, truncated):
52
- toks_preview = ", ".join(tokens[:64]) + (" …" if len(tokens) > 64 else "")
53
- ids_preview = ", ".join(map(str, token_ids[:64])) + (" …" if len(token_ids) > 64 else "")
54
  cls_dim = len(cls_vec)
55
- cls_preview = ", ".join(f"{v:.4f}" for v in cls_vec[:16]) + (" …" if cls_dim > 16 else "")
56
  l2 = (sum(v*v for v in cls_vec)) ** 0.5
57
-
58
  md = []
59
- md.append(f"### 🔍 Forensics for `{url}`")
60
- md.append("")
61
  md.append(f"- tokens: **{len(tokens)}** • truncated: **{'yes' if truncated else 'no'}**")
62
- md.append(f"- inference time: **{elapsed_s:.2f}s**")
63
- md.append("")
64
  md.append("**Top-k scores**")
65
  md.append(_format_scores_md(scores_sorted))
66
- md.append("")
67
- md.append("**Token IDs (preview)**")
68
- md.append("```txt\n" + ids_preview + "\n```")
69
  md.append("**Tokens (preview)**")
70
- md.append("```txt\n" + toks_preview + "\n```")
71
  md.append("**[CLS] embedding (preview)**")
72
  md.append(f"`dim={cls_dim}`, `L2={l2:.4f}`")
73
- md.append("```txt\n" + cls_preview + "\n```")
74
  return "\n".join(md)
75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  def analyze(text: str, forensic: bool, show_json: bool):
77
  """
78
- Returns a single Markdown block:
79
- - verdict + compact table
80
- - optional forensic blocks (tokens, logits, CLS)
 
81
  - optional raw JSON (copy/paste)
82
  """
83
  text = (text or "").strip()
@@ -99,21 +194,15 @@ def analyze(text: str, forensic: bool, show_json: bool):
99
  idx = int(k.split("_")[-1])
100
  id2label[idx] = v
101
 
102
- rows = []
103
- unsafe = False
104
  forensic_blocks = []
105
  export_data = {"model_id": URL_MODEL_ID, "items": []}
 
106
 
107
  for u in urls:
 
108
  max_len = min(512, getattr(mdl.config, "max_position_embeddings", 512) or 512)
109
- enc = tok(
110
- u,
111
- truncation=True,
112
- max_length=max_len,
113
- padding=False,
114
- return_tensors="pt",
115
- return_attention_mask=True,
116
- )
117
  token_ids = enc["input_ids"][0].tolist()
118
  tokens = tok.convert_ids_to_tokens(enc["input_ids"][0])
119
  truncated = enc["input_ids"].shape[1] >= max_len and len(tokens) >= max_len
@@ -123,9 +212,9 @@ def analyze(text: str, forensic: bool, show_json: bool):
123
  out = mdl(**enc, output_hidden_states=True)
124
  elapsed = time.time() - t0
125
 
126
- logits = out.logits.squeeze(0) # (num_labels,)
127
- probs = _softmax(logits) # list[float]
128
- hidden_states = out.hidden_states # tuple of layers
129
  cls_vec = hidden_states[-1][0, 0, :].cpu().tolist()
130
 
131
  per_class = [
@@ -134,10 +223,19 @@ def analyze(text: str, forensic: bool, show_json: bool):
134
  ]
135
  per_class_sorted = sorted(per_class, key=lambda x: x["prob"], reverse=True)
136
  top = per_class_sorted[0]
137
- rows.append([u, top["label"], top["prob"] * 100.0])
138
- if top["label"].lower() in {"phishing", "malware", "defacement"}:
139
- unsafe = True
140
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  export_data["items"].append({
142
  "url": u,
143
  "token_ids": token_ids,
@@ -145,10 +243,13 @@ def analyze(text: str, forensic: bool, show_json: bool):
145
  "truncated": truncated,
146
  "logits": [float(x) for x in logits.cpu().tolist()],
147
  "probs": [float(p) for p in probs],
148
- "scores_sorted": per_class_sorted, # label+prob+logit
149
  "cls_vector": cls_vec,
150
  "cls_dim": len(cls_vec),
151
  "elapsed_sec": elapsed,
 
 
 
152
  })
153
 
154
  if forensic:
@@ -164,20 +265,20 @@ def analyze(text: str, forensic: bool, show_json: bool):
164
  )
165
  )
166
 
167
- verdict = "🔴 **UNSAFE (links flagged)**" if unsafe else "🟢 **SAFE (all links benign)**"
168
- body = verdict + "\n\n" + _markdown_table(rows)
169
 
170
  if forensic and forensic_blocks:
171
  body += "\n\n---\n\n" + "\n\n---\n\n".join(forensic_blocks)
172
 
173
  if show_json:
174
- # raw JSON for copy-paste (no File component needed)
175
  pretty = json.dumps(export_data, ensure_ascii=False, indent=2)
176
  body += "\n\n---\n\n**Raw forensics JSON (copy & save):**\n"
177
  body += "```json\n" + pretty + "\n```"
178
 
179
  return body
180
 
 
181
  demo = gr.Interface(
182
  fn=analyze,
183
  inputs=[
@@ -186,10 +287,13 @@ demo = gr.Interface(
186
  gr.Checkbox(label="Show raw JSON at the end (copy/paste)", value=False),
187
  ],
188
  outputs=gr.Markdown(label="Results"),
189
- title="🛡️ PhishingMail — Forensics (HF Free CPU)",
190
- description="Extract links, classify with a tiny URL model, and (optionally) view tokens, logits, and [CLS] embedding.",
 
 
 
 
191
  )
192
 
193
  if __name__ == "__main__":
194
- # Safe defaults for HF Spaces (no share=True needed)
195
  demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
 
1
+ import os, re, time, json, urllib.parse
2
  import gradio as gr
3
  import torch
4
  import torch.nn.functional as F
5
+ import tldextract # for robust registered-domain parsing
6
 
7
+ # Quiet + CPU friendly
8
  os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
9
 
10
+ # -------- Models / Regex --------
11
  URL_MODEL_ID = "CrabInHoney/urlbert-tiny-v4-malicious-url-classifier"
12
  LABEL_MAP = {0: "benign", 1: "defacement", 2: "malware", 3: "phishing"}
13
+ URL_RE = re.compile(r"""(?xi)\b(?:https?://|www\.)[a-z0-9\-._~%]+(?:/[^\s<>"']*)?""")
 
14
 
15
+ # Heuristic config
16
+ KEYWORDS = {
17
+ "login","verify","account","secure","update","bank","wallet",
18
+ "password","invoice","pay","reset","support","unlock","confirm"
19
+ }
20
+ SUSPICIOUS_TLDS = {
21
+ "zip","mov","lol","xyz","top","country","link","click","cam",
22
+ "help","gq","cf","tk","work","rest","monster","quest","live"
23
+ }
24
 
25
+ # Lazy globals for tokenizer & model
26
  _tok = None
27
  _mdl = None
28
 
29
+ # -------- Utilities --------
30
  def _extract_urls(text: str):
31
  return sorted(set(m.group(0) for m in URL_RE.finditer(text or "")))
32
 
 
54
  lines.append(f"| **{s['label']}** | {s['prob']*100:.2f} | {s['logit']:.3f} |")
55
  return "\n".join(lines)
56
 
57
+ def _markdown_results_header(rows):
58
+ # rows: [ [url, model_label, model_pct, heur, fused, reason_txt], ... ]
59
+ lines = [
60
+ "| URL | Model | Model Prob (%) | Heuristic | Fused Risk | Reasons |",
61
+ "|---|---|---:|---:|---:|---|",
62
+ ]
63
+ for u, lbl, pct, h, fused, reasons in rows:
64
+ lines.append(
65
+ f"| `{u}` | **{lbl}** | {pct:.2f} | {h:.2f} | {fused:.2f} | {reasons} |"
66
+ )
67
  return "\n".join(lines)
68
 
69
  def _forensic_block(url, token_ids, tokens, scores_sorted, cls_vec, elapsed_s, truncated):
70
+ toks_prev = ", ".join(tokens[:64]) + (" …" if len(tokens) > 64 else "")
71
+ ids_prev = ", ".join(map(str, token_ids[:64])) + (" …" if len(token_ids) > 64 else "")
72
  cls_dim = len(cls_vec)
73
+ cls_prev = ", ".join(f"{v:.4f}" for v in cls_vec[:16]) + (" …" if cls_dim > 16 else "")
74
  l2 = (sum(v*v for v in cls_vec)) ** 0.5
 
75
  md = []
76
+ md.append(f"### 🔍 Forensics for `{url}`\n")
 
77
  md.append(f"- tokens: **{len(tokens)}** • truncated: **{'yes' if truncated else 'no'}**")
78
+ md.append(f"- inference time: **{elapsed_s:.2f}s**\n")
 
79
  md.append("**Top-k scores**")
80
  md.append(_format_scores_md(scores_sorted))
81
+ md.append("\n**Token IDs (preview)**")
82
+ md.append("```txt\n" + ids_prev + "\n```")
 
83
  md.append("**Tokens (preview)**")
84
+ md.append("```txt\n" + toks_prev + "\n```")
85
  md.append("**[CLS] embedding (preview)**")
86
  md.append(f"`dim={cls_dim}`, `L2={l2:.4f}`")
87
+ md.append("```txt\n" + cls_prev + "\n```")
88
  return "\n".join(md)
89
 
90
+ # -------- Heuristics --------
91
+ def _safe_parse(url: str):
92
+ # add scheme if missing so urlparse sees netloc
93
+ if not re.match(r"^https?://", url, re.I):
94
+ url = "http://" + url
95
+ return urllib.parse.urlparse(url)
96
+
97
def heuristic_features(u: str):
    """Derive lexical and structural features of a URL for heuristic scoring.

    Args:
        u: URL string as extracted from the input text (scheme optional).

    Returns:
        dict with the parsed pieces (``host``, ``path``, ``query``,
        ``registered_domain``, ``subdomain``, ``tld``), counters
        (``labels``, ``len_url``, ``query_ratio_alnum``), 0/1 or boolean
        trigger flags, and ``parse_error=False``. If anything raises while
        parsing, ``{"parse_error": True}`` is returned instead.
    """
    feats = {}
    try:
        p = _safe_parse(u)
        host = p.hostname or ""
        path = p.path or "/"
        query = p.query or ""
        ext = tldextract.extract(host)  # subdomain, domain, suffix
        domain = ext.domain or ""
        suffix = ext.suffix or ""

        feats["scheme_https"] = 1 if p.scheme.lower() == "https" else 0
        feats["host"] = host
        feats["path"] = path
        feats["query"] = query
        feats["registered_domain"] = f"{domain}.{suffix}" if domain and suffix else host
        feats["subdomain"] = ext.subdomain or ""
        feats["tld"] = suffix
        feats["labels"] = host.count(".") + (1 if host else 0)
        feats["has_at"] = "@" in u

        # Drop any userinfo, then look only after a closing "]" so the colons
        # inside a bracketed IPv6 literal are not mistaken for a port separator.
        hostport = p.netloc.rpartition("@")[2]
        feats["has_port"] = ":" in hostport.rpartition("]")[2]

        feats["has_punycode"] = "xn--" in host
        feats["len_url"] = len(u)
        feats["hyphen_in_regdom"] = "-" in domain

        low_host = host.lower()
        low_path = path.lower()
        low_domain = domain.lower()
        feats["kw_in_path"] = int(any(k in low_path for k in KEYWORDS))
        feats["kw_in_host"] = int(any(k in low_host for k in KEYWORDS))
        # Keyword appears in the host but not in the registered brand.
        # bool() keeps an empty domain from reaching int("") and being
        # misreported as a parse error.
        kw_in_domain = any(k in low_domain for k in KEYWORDS)
        feats["kw_in_subdomain_only"] = int(
            bool(feats["kw_in_host"] and domain and not kw_in_domain)
        )
        # Only the last label of a multi-part suffix is checked.
        feats["suspicious_tld"] = int(suffix.rpartition(".")[2] in SUSPICIOUS_TLDS)

        # Crude "entropy-like" signal for long alnum query blobs.
        alnum = sum(c.isalnum() for c in query)
        feats["query_ratio_alnum"] = (alnum / len(query)) if query else 0.0
        feats["parse_error"] = False
    except Exception:
        # Deliberate best-effort: any failure is reported as a flag so the
        # caller can treat the URL as unparsable rather than crash.
        feats = {"parse_error": True}
    return feats
131
+
132
def heuristic_score(feats: dict) -> float:
    """Combine heuristic features into a 0..1 suspiciousness score.

    Args:
        feats: feature dict as produced by ``heuristic_features``. Missing
            keys count as "not triggered", the same way
            ``heuristic_reasons`` reads them, so both functions agree on a
            partial dict.

    Returns:
        0.70 when ``feats["parse_error"]`` is truthy; otherwise the sum of
        the weights of all triggered signals, clamped to ``[0.0, 1.0]``.
    """
    if feats.get("parse_error"):
        return 0.70  # unparsable => suspicious

    get = feats.get
    query = get("query") or ""
    # (weight, triggered) pairs; order is kept so the float sum is stable.
    weighted_flags = (
        (0.25, bool(get("kw_in_path"))),
        (0.20, bool(get("kw_in_subdomain_only"))),
        (0.10, bool(get("kw_in_host"))),
        (0.10, bool(get("hyphen_in_regdom"))),
        (0.10, get("labels", 0) >= 4),
        (0.10, bool(get("has_punycode"))),
        (0.10, bool(get("suspicious_tld"))),
        (0.05, bool(get("has_at"))),
        (0.05, bool(get("has_port"))),
        (0.10, get("len_url", 0) >= 100),
        # Long, almost purely alphanumeric query string.
        (0.10, len(query) >= 40 and get("query_ratio_alnum", 0) > 0.9),
    )

    score = 0.0
    for weight, triggered in weighted_flags:
        score += weight * triggered
    return max(0.0, min(1.0, score))
150
+
151
def heuristic_reasons(feats: dict) -> str:
    """Describe which heuristic signals fired for a feature dict.

    Returns ``"parse error"`` when ``feats["parse_error"]`` is truthy,
    otherwise a comma-separated list of the triggered signals in a fixed
    order, or ``"no heuristic triggers"`` when none fired. Missing keys
    count as not triggered.
    """
    if feats.get("parse_error"):
        return "parse error"

    get = feats.get
    sub_only = get("kw_in_subdomain_only")
    query = get("query")
    long_blob = bool(query) and len(query) >= 40 and get("query_ratio_alnum", 0) > 0.9

    # (triggered, message) pairs in the order they are reported.
    checks = (
        (get("kw_in_path"), "keyword in path"),
        (sub_only, "keyword in subdomain"),
        (get("kw_in_host") and not sub_only, "keyword in host"),
        (get("hyphen_in_regdom"), "hyphen in registered domain"),
        (get("labels", 0) >= 4, "deep subdomain nesting"),
        (get("has_punycode"), "punycode host"),
        (get("suspicious_tld"), f"suspicious TLD: {feats.get('tld')}"),
        (get("has_at"), "@ in URL"),
        (get("has_port"), "explicit port"),
        (get("len_url", 0) >= 100, "very long URL"),
        (long_blob, "long query blob"),
    )
    fired = [message for triggered, message in checks if triggered]
    if not fired:
        return "no heuristic triggers"
    return ", ".join(fired)
168
+
169
+ # -------- Core --------
170
  def analyze(text: str, forensic: bool, show_json: bool):
171
  """
172
+ One output: Markdown with
173
+ - verdict
174
+ - table (model, heuristic, fused + reasons)
175
+ - optional forensic blocks (tokens, logits, [CLS])
176
  - optional raw JSON (copy/paste)
177
  """
178
  text = (text or "").strip()
 
194
  idx = int(k.split("_")[-1])
195
  id2label[idx] = v
196
 
197
+ header_rows = []
 
198
  forensic_blocks = []
199
  export_data = {"model_id": URL_MODEL_ID, "items": []}
200
+ any_unsafe = False
201
 
202
  for u in urls:
203
+ # --- Encode & forward for logits / CLS ---
204
  max_len = min(512, getattr(mdl.config, "max_position_embeddings", 512) or 512)
205
+ enc = tok(u, truncation=True, max_length=max_len, return_tensors="pt", return_attention_mask=True)
 
 
 
 
 
 
 
206
  token_ids = enc["input_ids"][0].tolist()
207
  tokens = tok.convert_ids_to_tokens(enc["input_ids"][0])
208
  truncated = enc["input_ids"].shape[1] >= max_len and len(tokens) >= max_len
 
212
  out = mdl(**enc, output_hidden_states=True)
213
  elapsed = time.time() - t0
214
 
215
+ logits = out.logits.squeeze(0) # (num_labels,)
216
+ probs = _softmax(logits) # list[float]
217
+ hidden_states = out.hidden_states
218
  cls_vec = hidden_states[-1][0, 0, :].cpu().tolist()
219
 
220
  per_class = [
 
223
  ]
224
  per_class_sorted = sorted(per_class, key=lambda x: x["prob"], reverse=True)
225
  top = per_class_sorted[0]
 
 
 
226
 
227
+ # --- Heuristics & fusion ---
228
+ feats = heuristic_features(u)
229
+ h_score = heuristic_score(feats)
230
+ mdl_phish_like = sum(s["prob"] for s in per_class_sorted if s["label"].lower() in {"phishing","malware","defacement"})
231
+ fused = 0.65 * mdl_phish_like + 0.35 * h_score
232
+ reasons = heuristic_reasons(feats)
233
+
234
+ header_rows.append([u, top["label"], top["prob"] * 100.0, h_score, fused, reasons])
235
+ if fused >= 0.50:
236
+ any_unsafe = True
237
+
238
+ # collect full details for optional JSON dump
239
  export_data["items"].append({
240
  "url": u,
241
  "token_ids": token_ids,
 
243
  "truncated": truncated,
244
  "logits": [float(x) for x in logits.cpu().tolist()],
245
  "probs": [float(p) for p in probs],
246
+ "scores_sorted": per_class_sorted,
247
  "cls_vector": cls_vec,
248
  "cls_dim": len(cls_vec),
249
  "elapsed_sec": elapsed,
250
+ "heuristic": feats,
251
+ "heuristic_score": h_score,
252
+ "fused_risk": fused,
253
  })
254
 
255
  if forensic:
 
265
  )
266
  )
267
 
268
+ verdict = "🔴 **UNSAFE (links flagged)**" if any_unsafe else "🟢 **SAFE (no fused risk ≥ 0.50)**"
269
+ body = verdict + "\n\n" + _markdown_results_header(header_rows)
270
 
271
  if forensic and forensic_blocks:
272
  body += "\n\n---\n\n" + "\n\n---\n\n".join(forensic_blocks)
273
 
274
  if show_json:
 
275
  pretty = json.dumps(export_data, ensure_ascii=False, indent=2)
276
  body += "\n\n---\n\n**Raw forensics JSON (copy & save):**\n"
277
  body += "```json\n" + pretty + "\n```"
278
 
279
  return body
280
 
281
+ # -------- UI --------
282
  demo = gr.Interface(
283
  fn=analyze,
284
  inputs=[
 
287
  gr.Checkbox(label="Show raw JSON at the end (copy/paste)", value=False),
288
  ],
289
  outputs=gr.Markdown(label="Results"),
290
+ title="🛡️ PhishingMail — Model + Heuristics (HF Free CPU)",
291
+ description=(
292
+ "We extract links and classify each with a compact malicious-URL model, then fuse with transparent heuristics. "
293
+ "Table shows Model Prob, Heuristic Score, and Fused Risk with reasons. "
294
+ "Toggle Forensic mode for tokens/logits/[CLS]."
295
+ ),
296
  )
297
 
298
  if __name__ == "__main__":
 
299
  demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)