"""
WFGY HuggingFace Space – deluxe demo
* Generates text before/after WFGY
* Shows variance, KL, top-1 shift
* Renders overlay histogram
"""

import io

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed

import wfgy_sdk as w
from wfgy_sdk.evaluator import compare_logits
from wfgy_sdk.visual import plot_histogram

MODEL = "sshleifer/tiny-gpt2"  # 124-MB, runs on CPU in ~2 s
tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForCausalLM.from_pretrained(MODEL)
set_seed(42)

ENGINE = w.get_engine()  # singleton


def gen_text(prompt, max_new_tokens=40):
    """Greedy full-text generation helper (not wired into the UI below)."""
    ids = tokenizer(prompt, return_tensors="pt").input_ids
    with torch.no_grad():
        out = model.generate(ids, max_new_tokens=max_new_tokens, do_sample=False)
    return tokenizer.decode(out[0, ids.shape[1]:], skip_special_tokens=True)


def wfgy_demo(prompt, enable_wfgy):
    # ---- generate raw text & logits ----
    ids = tokenizer(prompt, return_tensors="pt").input_ids
    with torch.no_grad():
        output = model(ids)
    raw_logits = output.logits[0, -1].cpu().numpy()

    # dummy semantic vectors for demo
    G = np.random.randn(256); G /= np.linalg.norm(G)
    I = G + np.random.normal(scale=0.05, size=256)

    # run WFGY
    if enable_wfgy:
        mod_logits = ENGINE.run(input_vec=I, ground_vec=G, logits=raw_logits)
    else:
        mod_logits = raw_logits.copy()

    # decode next-token text for both versions
    next_raw = tokenizer.decode(int(raw_logits.argmax()))
    next_mod = tokenizer.decode(int(mod_logits.argmax()))
    raw_txt  = prompt + next_raw
    mod_txt  = prompt + next_mod

    # metrics
    m = compare_logits(raw_logits, mod_logits)
    badge = f"variance ↓ {(1-m['std_ratio'])*100:.0f}% | KL {m['kl_divergence']:.2f}"
    top1  = "✔" if m["top1_shift"] else "✘"
    badge += f" | top-1 changed {top1}"

    # histogram (returned as a PIL image so gr.Image can render it directly)
    fig = plot_histogram(raw_logits, mod_logits, show=False)
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    fig.clf()
    buf.seek(0)
    img = Image.open(buf)

    return raw_txt, mod_txt, badge, img


with gr.Blocks(title="WFGY variance gate") as demo:
    gr.Markdown("## WFGY Live Demo  — variance drop in real-time")

    prompt = gr.Textbox(label="Prompt", placeholder="Ask anything…", lines=2)
    enable = gr.Checkbox(label="Enable WFGY", value=True)
    run_btn = gr.Button("Run")

    with gr.Row():
        raw_out = gr.Textbox(label="Raw GPT-2")
        mod_out = gr.Textbox(label="After WFGY")

    metrics = gr.HTML(label="Metrics")
    hist    = gr.Image(label="Logit distribution", elem_id="hist", width=450)

    run_btn.click(wfgy_demo, [prompt, enable],
                  [raw_out, mod_out, metrics, hist])

    gr.Markdown(
        "⭐ If the variance drop looks magic, [**star the repo**]"
        "(https://github.com/onestardao/WFGY) and help unlock WFGY 2.0!"
    )

demo.launch()
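
# Optional local smoke test: exercises wfgy_demo() without launching the UI.
# A minimal sketch assuming the wfgy_sdk calls above behave as used in this file;
# uncomment to try it from a plain Python session.
#
#   raw, mod, badge, img = wfgy_demo("The moon is made of", enable_wfgy=True)
#   print(badge)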