File size: 12,211 Bytes
85540cc
755ffe2
 
 
85540cc
755ffe2
85540cc
755ffe2
 
85540cc
755ffe2
 
85540cc
d05f393
85540cc
 
 
 
 
 
 
 
 
755ffe2
85540cc
755ffe2
 
 
85540cc
755ffe2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85540cc
 
 
 
 
 
 
 
 
 
755ffe2
 
 
85540cc
 
755ffe2
85540cc
755ffe2
 
85540cc
755ffe2
85540cc
755ffe2
 
85540cc
 
755ffe2
85540cc
755ffe2
 
85540cc
755ffe2
 
85540cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f9467b7
755ffe2
85540cc
 
 
 
f9467b7
755ffe2
 
 
f9467b7
755ffe2
 
 
f9467b7
755ffe2
 
 
 
 
 
 
 
 
 
 
 
85540cc
755ffe2
 
85540cc
755ffe2
 
85540cc
755ffe2
85540cc
755ffe2
 
 
 
 
 
 
 
 
85540cc
 
 
755ffe2
 
 
 
 
 
 
 
 
85540cc
 
 
 
 
 
 
 
 
 
 
 
755ffe2
 
 
 
 
 
 
85540cc
755ffe2
 
 
85540cc
 
 
755ffe2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85540cc
 
f9467b7
755ffe2
 
 
f9467b7
 
 
 
 
 
 
85540cc
f9467b7
 
 
 
 
 
 
 
85540cc
 
 
 
 
 
f9467b7
755ffe2
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
import os, re, time, json, urllib.parse
import gradio as gr
import torch
import torch.nn.functional as F
import tldextract  # for robust registered-domain parsing

# Quiet + CPU friendly.  setdefault() leaves a value that is already present
# in the environment untouched.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

# -------- Models / Regex --------
URL_MODEL_ID = "CrabInHoney/urlbert-tiny-v4-malicious-url-classifier"
# Fallback class names by output index, used when the model config does not
# provide a name for an index.
LABEL_MAP = {0: "benign", 1: "defacement", 2: "malware", 3: "phishing"}
# Finds links inside free text.  The authority part accepts userinfo and a
# numeric port, and the tail may start with "/", "?" or "#": stopping at "@"
# or ":" would turn "http://brand.com@evil.example:8080/login" into the
# harmless-looking "http://brand.com".
URL_RE = re.compile(
    r"""(?xi)
    \b(?:https?://|www\.)        # scheme, or a bare www. prefix
    (?:[^\s<>"'/?\#@]+@)?        # optional userinfo (user:pass@)
    [a-z0-9\-._~%]+              # host
    (?::\d+)?                    # optional numeric port
    (?:[/?\#][^\s<>"']*)?        # optional path / query / fragment
    """
)

# Heuristic config
# Substrings that count as a "keyword" hit in a host or path.
KEYWORDS = {
    "login","verify","account","secure","update","bank","wallet",
    "password","invoice","pay","reset","support","unlock","confirm"
}
# Final suffix labels that are treated as a risk signal on their own.
SUSPICIOUS_TLDS = {
    "zip","mov","lol","xyz","top","country","link","click","cam",
    "help","gq","cf","tk","work","rest","monster","quest","live"
}

# Lazy globals for tokenizer & model
_tok = None
_mdl = None

# -------- Utilities --------
def _extract_urls(text: str) -> list:
    """Return the distinct URLs found in *text*, sorted alphabetically.

    Prose punctuation that directly follows a link ("... example.com." or
    "(see http://x.com/a)") is matched by ``URL_RE`` as part of the URL, so it
    is trimmed here before the URL is handed to the classifier.

    Args:
        text: Free text such as an email body; ``None`` or ``""`` yields ``[]``.

    Returns:
        Sorted list of unique URL strings.
    """
    openers = {")": "(", "]": "[", "}": "{"}
    found = set()
    for match in URL_RE.finditer(text or ""):
        url = match.group(0)
        while url:
            last = url[-1]
            if last in ".,;:!?":
                url = url[:-1]
            elif last in openers and url.count(last) > url.count(openers[last]):
                # A closer with no opener inside the URL belongs to the
                # surrounding sentence; balanced pairs are kept.
                url = url[:-1]
            else:
                break
        # Trimming can leave a bare prefix such as "http://"; drop those.
        if URL_RE.fullmatch(url):
            found.add(url)
    return sorted(found)

def _load_model():
    """Return the shared ``(tokenizer, model)`` pair, creating it on first use.

    The pair is stored in the module globals ``_tok`` / ``_mdl`` so later calls
    return the same objects without loading again.
    """
    global _tok, _mdl
    needs_loading = _tok is None or _mdl is None
    if needs_loading:
        # Imported here rather than at module level, so the import only
        # happens when a model is actually requested.
        from transformers import AutoTokenizer, AutoModelForSequenceClassification

        _tok = AutoTokenizer.from_pretrained(URL_MODEL_ID)
        _mdl = AutoModelForSequenceClassification.from_pretrained(URL_MODEL_ID)
        _mdl.eval()
    return _tok, _mdl

def _softmax(logits: torch.Tensor):
    return F.softmax(logits, dim=-1).tolist()

def _lbl_name(idx: int, id2label: dict):
    if id2label and idx in id2label:
        return id2label[idx]
    return LABEL_MAP.get(idx, str(idx))

def _format_scores_md(scores_sorted):
    lines = ["| Class | Prob (%) | Logit |", "|---|---:|---:|"]
    for s in scores_sorted:
        lines.append(f"| **{s['label']}** | {s['prob']*100:.2f} | {s['logit']:.3f} |")
    return "\n".join(lines)

def _markdown_results_header(rows):
    # rows: [ [url, model_label, model_pct, heur, fused, reason_txt], ... ]
    lines = [
        "| URL | Model | Model Prob (%) | Heuristic | Fused Risk | Reasons |",
        "|---|---|---:|---:|---:|---|",
    ]
    for u, lbl, pct, h, fused, reasons in rows:
        lines.append(
            f"| `{u}` | **{lbl}** | {pct:.2f} | {h:.2f} | {fused:.2f} | {reasons} |"
        )
    return "\n".join(lines)

def _forensic_block(url, token_ids, tokens, scores_sorted, cls_vec, elapsed_s, truncated):
    """Compose the per-URL forensic Markdown section.

    The section lists token count, truncation flag and inference time, the
    score table, and previews of the token ids, tokens (first 64 each) and the
    embedding vector (first 16 values) together with its dimension and L2 norm.
    """
    def clip(parts, total, limit):
        # Join the already-sliced preview and mark it when items were left out.
        joined = ", ".join(parts)
        return joined + " …" if total > limit else joined

    tokens_preview = clip(tokens[:64], len(tokens), 64)
    ids_preview = clip([str(i) for i in token_ids[:64]], len(token_ids), 64)
    embedding_dim = len(cls_vec)
    embedding_preview = clip([f"{v:.4f}" for v in cls_vec[:16]], embedding_dim, 16)
    norm = (sum(v*v for v in cls_vec)) ** 0.5

    sections = [
        f"### 🔍 Forensics for `{url}`\n",
        f"- tokens: **{len(tokens)}** • truncated: **{'yes' if truncated else 'no'}**",
        f"- inference time: **{elapsed_s:.2f}s**\n",
        "**Top-k scores**",
        _format_scores_md(scores_sorted),
        "\n**Token IDs (preview)**",
        "```txt\n" + ids_preview + "\n```",
        "**Tokens (preview)**",
        "```txt\n" + tokens_preview + "\n```",
        "**[CLS] embedding (preview)**",
        f"`dim={embedding_dim}`, `L2={norm:.4f}`",
        "```txt\n" + embedding_preview + "\n```",
    ]
    return "\n".join(sections)

# -------- Heuristics --------
def _safe_parse(url: str):
    # add scheme if missing so urlparse sees netloc
    if not re.match(r"^https?://", url, re.I):
        url = "http://" + url
    return urllib.parse.urlparse(url)

def heuristic_features(u: str):
    """Derive lexical and structural red-flag features from one URL.

    Args:
        u: URL text. A missing http(s) scheme is tolerated (``_safe_parse``
            supplies one).

    Returns:
        dict. On success it holds the parsed pieces (``host``, ``path``,
        ``query``, ``registered_domain``, ``subdomain``, ``tld``), a set of
        0/1 or bool flags, ``len_url``, ``labels`` (number of dot-separated
        host labels), ``query_ratio_alnum`` and ``parse_error=False``.
        If anything raises while parsing, the result is exactly
        ``{"parse_error": True}``.
    """
    feats = {}
    try:
        p = _safe_parse(u)
        feats["scheme_https"] = 1 if p.scheme.lower() == "https" else 0
        feats["host"] = p.hostname or ""
        feats["path"] = p.path or "/"
        feats["query"] = p.query or ""
        ext = tldextract.extract(feats["host"])  # subdomain, domain, suffix
        feats["registered_domain"] = f"{ext.domain}.{ext.suffix}" if ext.domain and ext.suffix else feats["host"]
        feats["subdomain"] = ext.subdomain or ""
        feats["tld"] = ext.suffix or ""
        feats["labels"] = feats["host"].count(".") + (1 if feats["host"] else 0)
        feats["has_at"] = "@" in u
        # Look for the port separator only after any userinfo and after a
        # bracketed IPv6 literal, whose own colons are not a port.
        host_port = p.netloc.rpartition("@")[2]
        feats["has_port"] = ":" in host_port.rpartition("]")[2]
        feats["has_punycode"] = "xn--" in feats["host"]
        feats["len_url"] = len(u)
        feats["hyphen_in_regdom"] = "-" in (ext.domain or "")
        low_host = feats["host"].lower()
        low_path = feats["path"].lower()
        low_subdomain = feats["subdomain"].lower()
        low_domain = (ext.domain or "").lower()
        feats["kw_in_path"] = int(any(k in low_path for k in KEYWORDS))
        feats["kw_in_host"] = int(any(k in low_host for k in KEYWORDS))
        # Keyword appears in the subdomain but not in the registered brand.
        # Testing the subdomain itself (not the whole host) keeps a keyword
        # that is merely the public suffix, e.g. ".support", from counting,
        # and the explicit booleans avoid int("") when the domain is empty.
        kw_in_subdomain = any(k in low_subdomain for k in KEYWORDS)
        kw_in_domain = any(k in low_domain for k in KEYWORDS)
        feats["kw_in_subdomain_only"] = int(kw_in_subdomain and not kw_in_domain)
        feats["suspicious_tld"] = int((feats["tld"].split(".")[-1] or "") in SUSPICIOUS_TLDS)
        # crude "entropy-like" signal for long alnum query blobs
        alnum = sum(c.isalnum() for c in feats["query"])
        feats["query_ratio_alnum"] = (alnum / max(1, len(feats["query"]))) if feats["query"] else 0.0
        feats["parse_error"] = False
    except Exception:
        # Deliberately broad: any failure while parsing is reported as a
        # feature of the URL rather than raised to the caller.
        feats = {"parse_error": True}
    return feats

def heuristic_score(feats: dict) -> float:
    """Collapse heuristic features into a suspicion score clamped to 0..1.

    A feature dict flagged with ``parse_error`` scores a flat 0.70; otherwise
    each triggered signal contributes its fixed weight.
    """
    if feats.get("parse_error"):
        return 0.70  # unparsable => suspicious

    weighted_signals = (
        (0.25, feats["kw_in_path"]),
        (0.20, feats["kw_in_subdomain_only"]),
        (0.10, feats["kw_in_host"]),
        (0.10, feats["hyphen_in_regdom"]),
        (0.10, feats["labels"] >= 4),
        (0.10, feats["has_punycode"]),
        (0.10, feats["suspicious_tld"]),
        (0.05, feats["has_at"]),
        (0.05, feats["has_port"]),
        (0.10, feats["len_url"] >= 100),
    )
    total = 0.0
    # Accumulate in a fixed order so the floating-point result is reproducible.
    for weight, signal in weighted_signals:
        total += weight * signal

    query = feats["query"]
    if query and len(query) >= 40 and feats["query_ratio_alnum"] > 0.9:
        total += 0.10
    return max(0.0, min(1.0, total))

def heuristic_reasons(feats: dict) -> str:
    """Describe, as a comma-separated string, which heuristic signals fired.

    Returns ``"parse error"`` for a feature dict flagged with ``parse_error``
    and ``"no heuristic triggers"`` when nothing fired.
    """
    if feats.get("parse_error"):
        return "parse error"

    query = feats.get("query")
    subdomain_only = feats.get("kw_in_subdomain_only")
    # (condition, message) pairs, listed in the order they are reported.
    checks = (
        (feats.get("kw_in_path"), "keyword in path"),
        (subdomain_only, "keyword in subdomain"),
        (feats.get("kw_in_host") and not subdomain_only, "keyword in host"),
        (feats.get("hyphen_in_regdom"), "hyphen in registered domain"),
        (feats.get("labels", 0) >= 4, "deep subdomain nesting"),
        (feats.get("has_punycode"), "punycode host"),
        (feats.get("suspicious_tld"), f"suspicious TLD: {feats.get('tld')}"),
        (feats.get("has_at"), "@ in URL"),
        (feats.get("has_port"), "explicit port"),
        (feats.get("len_url", 0) >= 100, "very long URL"),
        (
            query and len(query) >= 40 and feats.get("query_ratio_alnum", 0) > 0.9,
            "long query blob",
        ),
    )
    triggered = [message for fired, message in checks if fired]
    return ", ".join(triggered) or "no heuristic triggers"

# -------- Core --------
def _normalize_id2label(raw) -> dict:
    """Return ``{int index: label}`` from a model config's ``id2label`` mapping.

    Keys may be ints, numeric strings or ``"LABEL_<n>"`` strings; keys that
    cannot be mapped to an index are skipped.  Placeholder names of the form
    ``LABEL_<n>`` are dropped as well: they carry no class meaning, and keeping
    them would shadow the real names in ``LABEL_MAP`` so that no class would
    ever be recognised as phishing/malware/defacement.
    """
    cleaned = {}
    for key, name in (raw or {}).items():
        try:
            idx = int(key)
        except (TypeError, ValueError):
            if not (isinstance(key, str) and key.startswith("LABEL_")):
                continue
            try:
                idx = int(key.split("_")[-1])
            except ValueError:
                continue
        if re.fullmatch(r"LABEL_\d+", str(name)):
            continue
        cleaned[idx] = name
    return cleaned


def analyze(text: str, forensic: bool, show_json: bool):
    """
    Classify every URL found in *text* and report the result as Markdown.

    One output: Markdown with
      - verdict
      - table (model, heuristic, fused + reasons)
      - optional forensic blocks (tokens, logits, [CLS])
      - optional raw JSON (copy/paste)

    Args:
        text: A single URL, or free text (e.g. an email body) containing URLs.
        forensic: Append a forensic section per URL.
        show_json: Append the full per-URL details as a JSON code block.

    Returns:
        A Markdown string. Empty input and text without URLs produce a short
        hint message instead of a report.
    """
    text = (text or "").strip()
    if not text:
        return "Paste an email body or a URL."

    # A lone URL is analysed verbatim.  Any whitespace (newline or tab, not
    # only a space) means the input is a body of text to extract links from.
    is_single_url = (
        text.lower().startswith(("http://", "https://", "www."))
        and len(text.split()) == 1
    )
    urls = [text] if is_single_url else _extract_urls(text)
    if not urls:
        return "No URLs detected in the text."

    tok, mdl = _load_model()
    id2label = _normalize_id2label(getattr(mdl.config, "id2label", None))
    # Loop-invariant: the encoding limit depends on the model config only.
    max_len = min(512, getattr(mdl.config, "max_position_embeddings", 512) or 512)
    risky_labels = {"phishing", "malware", "defacement"}

    header_rows = []
    forensic_blocks = []
    export_data = {"model_id": URL_MODEL_ID, "items": []}
    any_unsafe = False

    for u in urls:
        # --- Encode & forward for logits / CLS ---
        enc = tok(u, truncation=True, max_length=max_len, return_tensors="pt", return_attention_mask=True)
        token_ids = enc["input_ids"][0].tolist()
        tokens = tok.convert_ids_to_tokens(enc["input_ids"][0])
        # Approximation: an input that reaches the limit is reported as
        # truncated, including one that is exactly max_len tokens long.
        truncated = len(token_ids) >= max_len

        t0 = time.perf_counter()
        with torch.no_grad():
            out = mdl(**enc, output_hidden_states=True)
        elapsed = time.perf_counter() - t0

        logits = out.logits.squeeze(0)              # (num_labels,)
        probs = _softmax(logits)                    # list[float]
        # First position of the last hidden layer, shown as the [CLS] embedding.
        cls_vec = out.hidden_states[-1][0, 0, :].cpu().tolist()

        per_class = [
            {"label": _lbl_name(i, id2label), "prob": float(prob), "logit": float(logits[i])}
            for i, prob in enumerate(probs)
        ]
        per_class_sorted = sorted(per_class, key=lambda x: x["prob"], reverse=True)
        top = per_class_sorted[0]

        # --- Heuristics & fusion ---
        feats = heuristic_features(u)
        h_score = heuristic_score(feats)
        # Model risk = total probability mass on the non-benign classes.
        mdl_phish_like = sum(
            s["prob"] for s in per_class_sorted if str(s["label"]).lower() in risky_labels
        )
        fused = 0.65 * mdl_phish_like + 0.35 * h_score
        reasons = heuristic_reasons(feats)

        header_rows.append([u, top["label"], top["prob"] * 100.0, h_score, fused, reasons])
        if fused >= 0.50:
            any_unsafe = True

        # collect full details for optional JSON dump
        export_data["items"].append({
            "url": u,
            "token_ids": token_ids,
            "tokens": tokens,
            "truncated": truncated,
            "logits": [float(x) for x in logits.cpu().tolist()],
            "probs":  [float(p) for p in probs],
            "scores_sorted": per_class_sorted,
            "cls_vector": cls_vec,
            "cls_dim": len(cls_vec),
            "elapsed_sec": elapsed,
            "heuristic": feats,
            "heuristic_score": h_score,
            "fused_risk": fused,
        })

        if forensic:
            forensic_blocks.append(
                _forensic_block(
                    url=u,
                    token_ids=token_ids,
                    tokens=tokens,
                    scores_sorted=per_class_sorted,
                    cls_vec=cls_vec,
                    elapsed_s=elapsed,
                    truncated=truncated,
                )
            )

    verdict = "🔴 **UNSAFE (links flagged)**" if any_unsafe else "🟢 **SAFE (no fused risk ≥ 0.50)**"
    body = verdict + "\n\n" + _markdown_results_header(header_rows)

    if forensic and forensic_blocks:
        body += "\n\n---\n\n" + "\n\n---\n\n".join(forensic_blocks)

    if show_json:
        pretty = json.dumps(export_data, ensure_ascii=False, indent=2)
        body += "\n\n---\n\n**Raw forensics JSON (copy & save):**\n"
        body += "```json\n" + pretty + "\n```"

    return body

# -------- UI --------
# Gradio app object.  The three input components are passed positionally to
# analyze(text, forensic, show_json); its Markdown string is rendered by the
# single Markdown output.
demo = gr.Interface(
    fn=analyze,
    inputs=[
        gr.Textbox(lines=6, label="Email or URL", placeholder="Paste a URL or a full email…"),
        # Forensic sections are on by default; the raw JSON dump is opt-in.
        gr.Checkbox(label="Forensic mode (tokens, logits, [CLS])", value=True),
        gr.Checkbox(label="Show raw JSON at the end (copy/paste)", value=False),
    ],
    outputs=gr.Markdown(label="Results"),
    title="🛡️ PhishingMail — Model + Heuristics (HF Free CPU)",
    description=(
        "We extract links and classify each with a compact malicious-URL model, then fuse with transparent heuristics. "
        "Table shows Model Prob, Heuristic Score, and Fused Risk with reasons. "
        "Toggle Forensic mode for tokens/logits/[CLS]."
    ),
)

# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    # 0.0.0.0 accepts connections on every network interface, on port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)