# Hugging Face Space (status: Running)
import io, numpy as np, matplotlib | |
matplotlib.use("Agg") | |
from PIL import Image | |
import pandas as pd, plotly.express as px, gradio as gr | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
from wfgy_sdk import get_engine | |
from wfgy_sdk.evaluator import compare_logits, plot_histogram | |
# Tiny GPT-2 checkpoint, chosen so the demo stays CPU-friendly.
_MODEL_ID = "sshleifer/tiny-gpt2"
tok = AutoTokenizer.from_pretrained(_MODEL_ID)
mdl = AutoModelForCausalLM.from_pretrained(_MODEL_ID)
eng = get_engine()

# Metric history buffer: three parallel lists, seeded with one all-zero point.
history = dict(step=[0], var=[0.0], kl=[0.0])
# Paper benchmark table: raw scores plus derived absolute / relative gains.
_benchmarks = [
    "MMLU", "GSM8K", "BBH", "MathBench", "TruthfulQA",
    "XNLI", "MLQA", "LongBench", "VQAv2", "OK-VQA",
]
_baseline = [61.0, 78.0, 79.3, 72.2, 62.4, 59.5, 78.1, 51.4, 69.1, 65.7]
_wfgy = [89.8, 98.7, 100.7, 87.4, 90.4, 77.3, 106.6, 69.6, 86.6, 86.8]

paper_df = pd.DataFrame(
    {"Benchmark": _benchmarks, "Baseline": _baseline, "WFGY": _wfgy}
)
# Absolute gain is rounded first; the relative gain is derived from that
# rounded value, then rounded to a whole percent.
paper_df["Abs_gain"] = paper_df["WFGY"].sub(paper_df["Baseline"]).round(1)
paper_df["Rel_gain%"] = (
    paper_df["Abs_gain"].div(paper_df["Baseline"]).mul(100).round(0)
)

# Styled view: green gradient on the two gain columns, fixed number formats.
_gain_columns = ["Abs_gain", "Rel_gain%"]
_gain_formats = {"Abs_gain": "{:.1f}", "Rel_gain%": "{:.0f}"}
styled_df = paper_df.style.background_gradient(
    cmap="Greens", subset=_gain_columns
).format(_gain_formats)

# Bar chart of the relative gain per benchmark, coloured by the same value.
paper_bar = px.bar(
    paper_df,
    x="Benchmark",
    y="Rel_gain%",
    color="Rel_gain%",
    color_continuous_scale="Greens",
    title="Relative gain (%)",
    height=300,
)
# helpers | |
def top5(logits: np.ndarray):
    """Return the five most probable tokens as ``'token': prob`` lines.

    The logits are turned into probabilities with a softmax over axis 0,
    the five largest entries are listed in descending order, and each token
    id is decoded with the module-level tokenizer ``tok``.
    """
    probs = torch.softmax(torch.tensor(logits), dim=0).numpy()
    # argsort is ascending: keep the last five, then flip to descending.
    best_ids = probs.argsort()[-5:][::-1]
    rows = []
    for token_id in best_ids:
        rows.append(f"{tok.decode(int(token_id))!r}: {probs[token_id]:.2e}")
    return "\n".join(rows)
def hist_plot():
    """Return a Plotly line chart of the ``var`` and ``kl`` history per step."""
    frame = pd.DataFrame(history)
    figure = px.line(
        frame,
        x="step",
        y=["var", "kl"],
        labels={"value": "metric", "step": "call"},
        title="history (var% ↓ & KL)",
    )
    # update_layout mutates the figure in place and returns that same figure.
    figure.update_layout(height=260)
    return figure
def clear_hist():
    """Reset the history buffer to its single zero point and redraw the plot."""
    # Slice-assign so the existing list objects are reused, not replaced.
    for key, initial in (("step", 0), ("var", 0.0), ("kl", 0.0)):
        history[key][:] = [initial]
    return hist_plot()
def run(prompt: str):
    """Run one raw-vs-modulated comparison for ``prompt`` and update history.

    Returns a 5-tuple with exactly one value per output wired to the Run
    button: ``(raw top-5 text, modulated top-5 text, headline markdown,
    histogram image, history plot)``.

    A blank prompt returns empty texts, no image, and the current history
    plot, leaving the history untouched.
    """
    p = prompt.strip()
    if not p:
        # Must yield five values to match the five outputs. The previous
        # six-value tuple made Gradio fail with an output-count mismatch.
        return "", "", "", None, hist_plot()

    ids = tok(p, return_tensors="pt").input_ids
    # Inference only: skip building an autograd graph for the forward pass.
    with torch.no_grad():
        rawL = mdl(ids).logits[0, -1].detach().cpu().numpy()

    # Random 256-dim vector G and a slightly perturbed copy I (noise std 0.05)
    # fed to the engine together with the last-position logits.
    # NOTE(review): the meaning of I and G is defined by wfgy_sdk - confirm.
    G = np.random.randn(256).astype(np.float32)
    I = G + np.random.normal(scale=0.05, size=256).astype(np.float32)
    modL = eng.run(I, G, rawL)

    m = compare_logits(rawL, modL)

    # Record this call; the step index continues from the current length.
    n = len(history["step"])
    history["step"].append(n)
    history["var"].append(m["var_drop"] * 100)
    history["kl"].append(m["kl"])

    # Render the histogram figure to an in-memory PNG for the image output.
    # NOTE(review): if plot_histogram creates the figure through pyplot, it
    # is never closed here and figures may accumulate across calls - confirm.
    fig = plot_histogram(rawL, modL)
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)

    head = f"▼ var {m['var_drop']*100:4.1f}% | KL {m['kl']:.3f} | top-1 {'kept' if m['top1'] else 'changed'}"
    return top5(rawL), top5(modL), head, Image.open(buf), hist_plot()
# UI layout and event wiring.
with gr.Blocks(title="WFGY variance gate demo") as demo:
    gr.Markdown("# 🧠 WFGY simulation demo")

    # Input row: prompt plus the button that triggers a run.
    prompt = gr.Textbox(label="Prompt", value="Explain Schrödinger's cat")
    run_b = gr.Button("🚀 Run")

    # Raw vs modulated top-5 tokens, shown side by side.
    with gr.Row():
        raw_box = gr.Textbox(label="Raw top-5 tokens", lines=6)
        mod_box = gr.Textbox(label="WFGY top-5 tokens", lines=6)

    # Per-run headline, logit histogram, and the running metric history.
    headline = gr.Markdown()
    hist_img = gr.Image(type="pil", label="Logit histogram")
    hist_p = gr.Plot()
    clr_b = gr.Button("Clear history")

    # Collapsed section with the static paper benchmark table and chart.
    with gr.Accordion("Paper benchmarks", open=False):
        gr.DataFrame(styled_df, interactive=False, wrap=True)
        gr.Plot(paper_bar)

    gr.Markdown("---\n⭐ **10 k GitHub stars before 2025-08-01 unlock WFGY 2.0**")

    # Event wiring: one returned value per listed output component.
    run_b.click(
        fn=run,
        inputs=prompt,
        outputs=[raw_box, mod_box, headline, hist_img, hist_p],
    )
    clr_b.click(fn=clear_hist, inputs=None, outputs=hist_p)

if __name__ == "__main__":
    demo.queue().launch()