Spaces:
Running
Running
File size: 3,661 Bytes
24bd8e1 c1eda44 4082288 c1eda44 24bd8e1 26f293f 1348198 914ff43 a239ad1 57a96b8 e12152f 24bd8e1 e12152f 54b95bc 24bd8e1 e12152f 54b95bc 24bd8e1 54b95bc 24bd8e1 e12152f 26f293f e12152f 26f293f e12152f 26f293f e12152f 26f293f e12152f 26f293f e12152f 26f293f 54b95bc 26f293f e12152f 26f293f e12152f 54b95bc e12152f 54b95bc e12152f 54b95bc e12152f 54b95bc e12152f 24bd8e1 e12152f 26f293f 7429b83 54b95bc e12152f 26f293f 54b95bc 26f293f 54b95bc e12152f 54b95bc e12152f 26f293f 54b95bc e12152f 54b95bc e12152f 822e03e e95bc22 6aba93c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
import io, numpy as np, matplotlib
matplotlib.use("Agg")
from PIL import Image
import pandas as pd, plotly.express as px, gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from wfgy_sdk import get_engine
from wfgy_sdk.evaluator import compare_logits, plot_histogram
# tiny model (CPU-friendly demo) — downloads sshleifer/tiny-gpt2 from the HF hub on first run
tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
mdl = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
# WFGY engine instance shared by every call to run()
eng = get_engine()
# history buffer: per-call metrics appended by run(), reset by clear_hist()
# "step" = call index, "var" = variance drop in %, "kl" = KL divergence
history = {"step": [0], "var": [0.0], "kl": [0.0]}
# paper table: benchmark scores reproduced from the WFGY paper
_bench_data = {
    "Benchmark": ["MMLU", "GSM8K", "BBH", "MathBench", "TruthfulQA",
                  "XNLI", "MLQA", "LongBench", "VQAv2", "OK-VQA"],
    "Baseline": [61.0, 78.0, 79.3, 72.2, 62.4, 59.5, 78.1, 51.4, 69.1, 65.7],
    "WFGY": [89.8, 98.7, 100.7, 87.4, 90.4, 77.3, 106.6, 69.6, 86.6, 86.8],
}
paper_df = pd.DataFrame(_bench_data)
# derived columns: absolute and relative improvement, rounded for display
paper_df["Abs_gain"] = (paper_df["WFGY"] - paper_df["Baseline"]).round(1)
paper_df["Rel_gain%"] = (paper_df["Abs_gain"] / paper_df["Baseline"] * 100).round(0)
# green-shaded table view of the gain columns, used inside the accordion
_styler = paper_df.style
_styler = _styler.background_gradient(cmap="Greens", subset=["Abs_gain", "Rel_gain%"])
styled_df = _styler.format({"Abs_gain": "{:.1f}", "Rel_gain%": "{:.0f}"})

# bar chart of the relative gains, colored by magnitude
paper_bar = px.bar(
    paper_df,
    x="Benchmark",
    y="Rel_gain%",
    color="Rel_gain%",
    color_continuous_scale="Greens",
    title="Relative gain (%)",
    height=300,
)
# helpers
def top5(logits: np.ndarray) -> str:
    """Format the five most probable next tokens as 'token: prob' lines."""
    probs = torch.softmax(torch.tensor(logits), dim=0).numpy()
    # full descending sort, keep the first five indices
    ranked = np.argsort(probs)[::-1][:5]
    lines = [f"{tok.decode(int(i))!r}: {probs[i]:.2e}" for i in ranked]
    return "\n".join(lines)
def hist_plot():
    """Return a line chart of the per-call variance-drop and KL history."""
    frame = pd.DataFrame(history)
    fig = px.line(
        frame,
        x="step",
        y=["var", "kl"],
        labels={"value": "metric", "step": "call"},
        title="history (var% ↓ & KL)",
    )
    return fig.update_layout(height=260)
def clear_hist():
    """Reset the metric history to its initial single zero entry and redraw."""
    for key, zero in (("step", 0), ("var", 0.0), ("kl", 0.0)):
        # in-place slice assignment keeps the original list objects
        history[key][:] = [zero]
    return hist_plot()
def run(prompt: str):
    """Run one WFGY modulation step on *prompt* and build the UI outputs.

    Returns a 5-tuple matching the wired Gradio outputs:
    (raw top-5 text, WFGY top-5 text, headline markdown,
     logit-histogram PIL image or None, history plot).
    """
    p = prompt.strip()
    if not p:
        # BUG FIX: this branch previously returned SIX values while the
        # success path returns five and run_b.click wires exactly five
        # outputs [raw_box, mod_box, headline, hist_img, hist_p].
        return "", "", "", None, hist_plot()
    ids = tok(p, return_tensors="pt").input_ids
    # last-position logits of the tiny LM as a numpy vector
    raw_logits = mdl(ids).logits[0, -1].detach().cpu().numpy()
    # synthetic ground/instance fingerprint vectors for the demo gate
    G = np.random.randn(256).astype(np.float32)
    I = G + np.random.normal(scale=0.05, size=256).astype(np.float32)
    mod_logits = eng.run(I, G, raw_logits)
    m = compare_logits(raw_logits, mod_logits)
    # append this call's metrics to the shared history buffer
    history["step"].append(len(history["step"]))
    history["var"].append(m["var_drop"] * 100)
    history["kl"].append(m["kl"])
    # render the logit histogram to a PNG in memory
    fig = plot_histogram(raw_logits, mod_logits)
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    # FIX: close the figure — previously one Figure leaked per call.
    # Local import: pyplot is not imported at module level.
    from matplotlib import pyplot as plt
    plt.close(fig)
    head = (f"▼ var {m['var_drop']*100:4.1f}% | KL {m['kl']:.3f} | "
            f"top-1 {'kept' if m['top1'] else 'changed'}")
    return top5(raw_logits), top5(mod_logits), head, Image.open(buf), hist_plot()
# UI — layout and event wiring (call order preserved; indentation restored)
with gr.Blocks(title="WFGY variance gate demo") as demo:
    gr.Markdown("# 🧠 WFGY simulation demo")
    prompt = gr.Textbox(label="Prompt", value="Explain Schrödinger's cat")
    run_b = gr.Button("🚀 Run")
    # side-by-side comparison of raw vs. modulated token distributions
    with gr.Row():
        raw_box = gr.Textbox(label="Raw top-5 tokens", lines=6)
        mod_box = gr.Textbox(label="WFGY top-5 tokens", lines=6)
    headline = gr.Markdown()
    hist_img = gr.Image(type="pil", label="Logit histogram")
    hist_p = gr.Plot()
    clr_b = gr.Button("Clear history")
    # static benchmark table and chart from the paper
    with gr.Accordion("Paper benchmarks", open=False):
        gr.DataFrame(styled_df, interactive=False, wrap=True)
        gr.Plot(paper_bar)
    gr.Markdown("---\n⭐ **10 k GitHub stars before 2025-08-01 unlock WFGY 2.0**")
    # wiring: run() feeds five outputs; clear_hist() refreshes the plot only
    run_b.click(run, prompt, [raw_box, mod_box, headline, hist_img, hist_p])
    clr_b.click(clear_hist, None, hist_p)

if __name__ == "__main__":
    demo.queue().launch()
|