wfgy-demo / app.py
OneStarDao's picture
Update app.py
e12152f verified
raw
history blame
3.66 kB
import io, numpy as np, matplotlib
matplotlib.use("Agg")
from PIL import Image
import pandas as pd, plotly.express as px, gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from wfgy_sdk import get_engine
from wfgy_sdk.evaluator import compare_logits, plot_histogram
# Tiny GPT-2 checkpoint so the demo stays CPU-friendly; the tokenizer and
# the causal LM are loaded from the same identifier.
_CHECKPOINT = "sshleifer/tiny-gpt2"
tok = AutoTokenizer.from_pretrained(_CHECKPOINT)
mdl = AutoModelForCausalLM.from_pretrained(_CHECKPOINT)
# Single engine instance obtained from wfgy_sdk at import time.
eng = get_engine()
# Module-level metrics log: three parallel lists, each seeded with a
# zero-valued entry for step 0.
history = dict(step=[0], var=[0.0], kl=[0.0])
# Paper benchmark table, one row per benchmark: (name, baseline, WFGY).
paper_df = pd.DataFrame(
    [
        ("MMLU", 61.0, 89.8),
        ("GSM8K", 78.0, 98.7),
        ("BBH", 79.3, 100.7),
        ("MathBench", 72.2, 87.4),
        ("TruthfulQA", 62.4, 90.4),
        ("XNLI", 59.5, 77.3),
        ("MLQA", 78.1, 106.6),
        ("LongBench", 51.4, 69.6),
        ("VQAv2", 69.1, 86.6),
        ("OK-VQA", 65.7, 86.8),
    ],
    columns=["Benchmark", "Baseline", "WFGY"],
)
# Absolute gain (1 decimal), then that gain relative to the baseline as a
# whole-number percentage.
paper_df["Abs_gain"] = paper_df["WFGY"].sub(paper_df["Baseline"]).round(1)
paper_df["Rel_gain%"] = (
    paper_df["Abs_gain"].div(paper_df["Baseline"]).mul(100).round(0)
)
# Styled view of the benchmark table: green gradient on both gain columns,
# with fixed display precision for each.
styled_df = paper_df.style.background_gradient(
    subset=["Abs_gain", "Rel_gain%"], cmap="Greens"
)
styled_df = styled_df.format({"Abs_gain": "{:.1f}", "Rel_gain%": "{:.0f}"})
# Bar chart of relative gain per benchmark, with bars coloured by the same
# value on a green scale.
paper_bar = px.bar(
    data_frame=paper_df,
    x="Benchmark",
    y="Rel_gain%",
    color="Rel_gain%",
    color_continuous_scale="Greens",
    title="Relative gain (%)",
    height=300,
)
# helpers
def top5(logits: np.ndarray):
    """Format the five most probable tokens of a logit vector.

    A softmax over axis 0 turns ``logits`` into probabilities. Each output
    line is the ``repr`` of the decoded token followed by its probability
    in scientific notation; lines are ordered most probable first and
    joined with newlines.
    """
    probs = torch.softmax(torch.tensor(logits), dim=0).numpy()
    # Descending order of probability, truncated to the first five ids.
    ranked_ids = probs.argsort()[::-1][:5]
    return "\n".join(
        f"{tok.decode(int(token_id))!r}: {probs[token_id]:.2e}"
        for token_id in ranked_ids
    )
def hist_plot():
    """Build the line chart of the metrics currently stored in ``history``.

    Returns a Plotly figure with the ``var`` and ``kl`` series plotted
    against ``step``, with its layout height set to 260.
    """
    frame = pd.DataFrame(history)
    chart = px.line(
        frame,
        x="step",
        y=["var", "kl"],
        labels={"value": "metric", "step": "call"},
        title="history (var% ↓ & KL)",
    )
    chart.update_layout(height=260)
    return chart
def clear_hist():
    """Reset ``history`` to its initial one-entry state and redraw the chart.

    Each list is overwritten in place (slice assignment), so the dict and
    its list objects keep their identity. Returns the refreshed figure
    from ``hist_plot()``.
    """
    for series, initial in (("step", 0), ("var", 0.0), ("kl", 0.0)):
        history[series][:] = [initial]
    return hist_plot()
def run(prompt: str):
    """Run one prompt through the model and the WFGY engine and report both.

    The last-position logits of ``mdl`` for the prompt are passed to
    ``eng.run`` together with two random 256-element float32 vectors; the
    raw and modified logits are compared with ``compare_logits`` and the
    resulting metrics are appended to the module-level ``history``.

    Args:
        prompt: Text entered by the user. ``None`` or whitespace-only input
            is treated as a blank prompt.

    Returns:
        A 5-tuple in the order of the UI outputs wired to this handler:
        raw top-5 text, modified top-5 text, headline markdown, histogram
        image (``PIL.Image`` or ``None``), and the history figure. The
        blank-prompt branch returns the same five slots, with empty text
        and no image.
    """
    # Local import: pyplot is only needed to release the histogram figure.
    import matplotlib.pyplot as plt

    text = (prompt or "").strip()
    if not text:
        # Exactly five values, matching the normal return below; an extra
        # leading "" would shift the image and plot into the wrong slots.
        return "", "", "", None, hist_plot()

    ids = tok(text, return_tensors="pt").input_ids
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        rawL = mdl(ids).logits[0, -1].detach().cpu().numpy()

    # Random vector and a slightly perturbed copy (noise scale 0.05), passed
    # to the engine in the same (I, G, logits) order as before.
    G = np.random.randn(256).astype(np.float32)
    I = G + np.random.normal(scale=0.05, size=256).astype(np.float32)
    modL = eng.run(I, G, rawL)
    m = compare_logits(rawL, modL)

    # The next step index equals the current history length.
    history["step"].append(len(history["step"]))
    history["var"].append(m["var_drop"] * 100)
    history["kl"].append(m["kl"])

    fig = plot_histogram(rawL, modL)
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format="png")
    finally:
        # Release the figure so repeated calls do not accumulate open
        # figures. NOTE(review): assumes plot_histogram returns a
        # matplotlib Figure -- confirm against wfgy_sdk.
        plt.close(fig)
    buf.seek(0)

    head = f"▼ var {m['var_drop']*100:4.1f}% | KL {m['kl']:.3f} | top-1 {'kept' if m['top1'] else 'changed'}"
    return top5(rawL), top5(modL), head, Image.open(buf), hist_plot()
# UI layout and event wiring.
# NOTE(review): nesting is reconstructed from a flattened source -- the Row
# is assumed to hold only the two top-5 text boxes and the Accordion only
# the benchmark table and bar chart; confirm against the intended layout.
with gr.Blocks(title="WFGY variance gate demo") as demo:
    gr.Markdown(value="# 🧠 WFGY simulation demo")

    prompt = gr.Textbox(value="Explain Schrödinger's cat", label="Prompt")
    run_b = gr.Button(value="🚀 Run")

    with gr.Row():
        raw_box = gr.Textbox(lines=6, label="Raw top-5 tokens")
        mod_box = gr.Textbox(lines=6, label="WFGY top-5 tokens")

    headline = gr.Markdown()
    hist_img = gr.Image(label="Logit histogram", type="pil")
    hist_p = gr.Plot()
    clr_b = gr.Button(value="Clear history")

    with gr.Accordion(label="Paper benchmarks", open=False):
        gr.DataFrame(value=styled_df, interactive=False, wrap=True)
        gr.Plot(value=paper_bar)

    gr.Markdown(value="---\n⭐ **10 k GitHub stars before 2025-08-01 unlock WFGY 2.0**")

    # Run fills all five result widgets; Clear only redraws the history plot.
    run_b.click(
        fn=run,
        inputs=prompt,
        outputs=[raw_box, mod_box, headline, hist_img, hist_p],
    )
    clr_b.click(fn=clear_hist, inputs=None, outputs=hist_p)
if __name__ == "__main__":
    # Script entry point: turn on the request queue, then start the server.
    queued_demo = demo.queue()
    queued_demo.launch()