File size: 3,661 Bytes
24bd8e1
c1eda44
4082288
c1eda44
24bd8e1
26f293f
1348198
914ff43
a239ad1
57a96b8
e12152f
24bd8e1
 
 
 
e12152f
54b95bc
24bd8e1
e12152f
54b95bc
24bd8e1
 
54b95bc
24bd8e1
 
e12152f
 
26f293f
 
 
e12152f
 
26f293f
 
 
 
 
 
 
 
 
e12152f
 
 
 
 
 
26f293f
e12152f
26f293f
 
 
e12152f
26f293f
e12152f
26f293f
54b95bc
26f293f
 
e12152f
26f293f
e12152f
 
54b95bc
e12152f
 
54b95bc
e12152f
 
 
54b95bc
e12152f
54b95bc
e12152f
 
24bd8e1
e12152f
26f293f
7429b83
54b95bc
e12152f
26f293f
54b95bc
26f293f
 
 
54b95bc
 
e12152f
 
54b95bc
 
e12152f
26f293f
54b95bc
e12152f
54b95bc
e12152f
 
822e03e
e95bc22
6aba93c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import io, numpy as np, matplotlib
matplotlib.use("Agg")

from PIL import Image
import pandas as pd, plotly.express as px, gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from wfgy_sdk import get_engine
from wfgy_sdk.evaluator import compare_logits, plot_histogram

# Tiny GPT-2 checkpoint keeps the demo CPU-friendly; tokenizer and model are
# loaded from the same hub id so their vocabularies match.
_MODEL_ID = "sshleifer/tiny-gpt2"
tok = AutoTokenizer.from_pretrained(_MODEL_ID)
mdl = AutoModelForCausalLM.from_pretrained(_MODEL_ID)
# WFGY engine instance; run() passes the raw logits through eng.run(...).
eng = get_engine()

# Per-call metric buffer feeding the history plot: run() appends one entry per
# call and clear_hist() resets each list in place to its seed value.
# NOTE(review): module-level, so presumably shared by all concurrent sessions.
history = dict(step=[0], var=[0.0], kl=[0.0])

# Benchmark scores reported for the paper (baseline vs. WFGY), plus the
# absolute gain (1 decimal) and the relative gain in percent (whole number),
# the latter computed from the already-rounded absolute gain.
# NOTE(review): BBH (100.7) and MLQA (106.6) WFGY scores exceed 100 — confirm
# these are the intended values/units.
paper_df = pd.DataFrame(
    {
        "Benchmark": [
            "MMLU", "GSM8K", "BBH", "MathBench", "TruthfulQA",
            "XNLI", "MLQA", "LongBench", "VQAv2", "OK-VQA",
        ],
        "Baseline": [61.0, 78.0, 79.3, 72.2, 62.4, 59.5, 78.1, 51.4, 69.1, 65.7],
        "WFGY": [89.8, 98.7, 100.7, 87.4, 90.4, 77.3, 106.6, 69.6, 86.6, 86.8],
    }
).assign(Abs_gain=lambda d: d["WFGY"].sub(d["Baseline"]).round(1))
paper_df["Rel_gain%"] = (
    paper_df["Abs_gain"].div(paper_df["Baseline"]).mul(100).round(0)
)

# Display version of the benchmark table: green gradient over the two gain
# columns, Abs_gain shown with one decimal and Rel_gain% as a whole number.
styled_df = paper_df.style.background_gradient(
    cmap="Greens", subset=["Abs_gain", "Rel_gain%"]
)
styled_df = styled_df.format({"Abs_gain": "{:.1f}", "Rel_gain%": "{:.0f}"})

# Bar chart of relative gain per benchmark, coloured on the same green scale
# as the styled table.
paper_bar = px.bar(
    data_frame=paper_df,
    x="Benchmark",
    y="Rel_gain%",
    color="Rel_gain%",
    color_continuous_scale="Greens",
    title="Relative gain (%)",
    height=300,
)

# helpers
def top5(logits: np.ndarray):
    """Format the five most probable tokens as ``'token': prob`` lines.

    *logits* is softmax-normalised along dim 0 (so a 1-D array is expected),
    then the highest-probability ids are decoded with the module tokenizer and
    listed most-probable first, one per line, probabilities in ``.2e`` form.
    """
    probs = torch.softmax(torch.tensor(logits), dim=0).numpy()
    # Descending order, keep the first five (fewer if the vocab is smaller).
    ranked = np.argsort(probs)[::-1][:5]
    lines = []
    for token_id in ranked:
        token = tok.decode(int(token_id))
        lines.append(f"{token!r}: {probs[token_id]:.2e}")
    return "\n".join(lines)

def hist_plot():
    """Return a Plotly line chart of the recorded var% and KL values per call.

    Reads the module-level ``history`` buffer; the figure is 260 px tall.
    """
    frame = pd.DataFrame(history)
    fig = px.line(
        frame,
        x="step",
        y=["var", "kl"],
        title="history (var% ↓  &  KL)",
        labels={"value": "metric", "step": "call"},
    )
    fig.update_layout(height=260)
    return fig

def clear_hist():
    """Reset every history list to its single seed value and redraw the plot.

    The lists are cleared in place (slice assignment), so existing references
    to them stay valid.
    """
    for key, seed in (("step", 0), ("var", 0.0), ("kl", 0.0)):
        history[key][:] = [seed]
    return hist_plot()

def run(prompt: str):
    """Score *prompt* with the tiny model, pass its logits through WFGY, report.

    Parameters
    ----------
    prompt : str
        Text from the prompt box. Surrounding whitespace is ignored and
        ``None`` is treated as an empty prompt.

    Returns
    -------
    tuple
        Exactly five values, one per output wired to the Run button:
        ``(raw_top5, wfgy_top5, headline, histogram_image, history_plot)``.
        For an empty prompt the three text outputs are ``""``, the image is
        ``None`` and the history buffer is left unchanged.
    """
    # Local import: pyplot is only needed here to release the histogram figure.
    import matplotlib.pyplot as plt

    p = (prompt or "").strip()
    if not p:
        # One value per output component. The previous 6-tuple shifted every
        # value after the headline (Image got "" and Plot got None).
        return "", "", "", None, hist_plot()

    ids = tok(p, return_tensors="pt").input_ids
    # Inference only: no autograd graph is needed for this forward pass.
    with torch.no_grad():
        # Last-position logits of batch item 0 (the next-token scores).
        rawL = mdl(ids).logits[0, -1].cpu().numpy()

    # Random 256-d float32 vectors for the engine: I is G plus Gaussian noise
    # with scale 0.05. NOTE(review): how eng.run uses I/G is not visible here.
    G = np.random.randn(256).astype(np.float32)
    I = G + np.random.normal(scale=0.05, size=256).astype(np.float32)
    modL = eng.run(I, G, rawL)

    # Record this call's metrics; steps continue from the current buffer length.
    m = compare_logits(rawL, modL)
    n = len(history["step"])
    history["step"].append(n)
    history["var"].append(m["var_drop"] * 100)
    history["kl"].append(m["kl"])

    fig = plot_histogram(rawL, modL)
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    # Release pyplot's reference so repeated calls don't accumulate open figures.
    plt.close(fig)
    buf.seek(0)

    head = f"▼ var {m['var_drop']*100:4.1f}% | KL {m['kl']:.3f} | top-1 {'kept' if m['top1'] else 'changed'}"
    return top5(rawL), top5(modL), head, Image.open(buf), hist_plot()

# UI
# Layout: prompt box and Run button; raw vs. WFGY top-5 tokens side by side;
# a metrics headline, the logit histogram image and the running history plot;
# a Clear-history button; and a collapsed accordion with the paper table/chart.
with gr.Blocks(title="WFGY variance gate demo") as demo:
    gr.Markdown("# 🧠 WFGY simulation demo")
    prompt = gr.Textbox(label="Prompt", value="Explain Schrödinger's cat")
    run_b  = gr.Button("🚀 Run")

    # Top-5 next-token probabilities before and after the WFGY engine.
    with gr.Row():
        raw_box = gr.Textbox(label="Raw top-5 tokens", lines=6)
        mod_box = gr.Textbox(label="WFGY top-5 tokens", lines=6)

    # Headline shows var drop / KL / whether top-1 was kept; the image takes the
    # PIL histogram returned by run(); the Plot shows the per-call history chart.
    headline = gr.Markdown()
    hist_img = gr.Image(type="pil", label="Logit histogram")
    hist_p   = gr.Plot()
    clr_b    = gr.Button("Clear history")

    # Hard-coded benchmark numbers (paper_df), not computed by this demo.
    with gr.Accordion("Paper benchmarks", open=False):
        gr.DataFrame(styled_df, interactive=False, wrap=True)
        gr.Plot(paper_bar)

    gr.Markdown("---\n⭐ **10 k GitHub stars before 2025-08-01 unlock WFGY 2.0**")

    # The output list order must match the tuple returned by run():
    # (raw top-5, WFGY top-5, headline, histogram image, history plot).
    run_b.click(run, prompt, [raw_box,mod_box,headline,hist_img,hist_p])
    # Clear resets the history lists in place and redraws only the history plot.
    clr_b.click(clear_hist, None, hist_p)

if __name__ == "__main__":
    # Enable request queuing on the Blocks app, then start the server.
    app = demo.queue()
    app.launch()