"""Gradio demo for the WFGY variance gate.

Runs a tiny GPT-2 model on a prompt, passes the next-token logits through the
WFGY engine, and shows raw vs. WFGY top-5 tokens, a logit histogram, a running
history of the comparison metrics, and a table of paper benchmark numbers.
"""
import io

import numpy as np
import matplotlib

# Headless backend; must be selected before pyplot (or anything that imports
# pyplot) is loaded.
matplotlib.use("Agg")
import matplotlib.pyplot as plt  # noqa: E402  (used to close rendered figures)
from PIL import Image  # noqa: E402
import pandas as pd  # noqa: E402
import plotly.express as px  # noqa: E402
import gradio as gr  # noqa: E402
import torch  # noqa: E402
from transformers import AutoTokenizer, AutoModelForCausalLM  # noqa: E402

from wfgy_sdk import get_engine  # noqa: E402
from wfgy_sdk.evaluator import compare_logits, plot_histogram  # noqa: E402

# tiny model (CPU-friendly demo)
tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
mdl = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
eng = get_engine()

# Module-level metric history shared by every call to ``run``: one entry per
# call, seeded with a single zero point at step 0.
history = {"step": [0], "var": [0.0], "kl": [0.0]}

# Benchmark numbers shown in the "Paper benchmarks" accordion.
paper_df = pd.DataFrame({
    "Benchmark": ["MMLU", "GSM8K", "BBH", "MathBench", "TruthfulQA",
                  "XNLI", "MLQA", "LongBench", "VQAv2", "OK-VQA"],
    "Baseline": [61.0, 78.0, 79.3, 72.2, 62.4, 59.5, 78.1, 51.4, 69.1, 65.7],
    "WFGY": [89.8, 98.7, 100.7, 87.4, 90.4, 77.3, 106.6, 69.6, 86.6, 86.8],
})
# Absolute gain (WFGY - Baseline) and relative gain as a whole-number percent.
paper_df["Abs_gain"] = (paper_df["WFGY"] - paper_df["Baseline"]).round(1)
paper_df["Rel_gain%"] = ((paper_df["Abs_gain"] / paper_df["Baseline"]) * 100).round(0)

styled_df = (
    paper_df.style
    .background_gradient(cmap="Greens", subset=["Abs_gain", "Rel_gain%"])
    .format({"Abs_gain": "{:.1f}", "Rel_gain%": "{:.0f}"})
)
paper_bar = px.bar(
    paper_df, x="Benchmark", y="Rel_gain%", title="Relative gain (%)",
    color="Rel_gain%", color_continuous_scale="Greens", height=300,
)


# helpers
def top5(logits: np.ndarray) -> str:
    """Return the five most probable tokens under softmax(logits), one per line.

    Each line is ``repr(decoded_token): probability`` (probability in
    scientific notation), ordered from most to least probable. ``logits`` is
    softmaxed along axis 0, so a 1-D vector over the vocabulary is expected.
    """
    p = torch.softmax(torch.tensor(logits), dim=0).numpy()
    idx = p.argsort()[-5:][::-1]  # indices of the 5 largest probabilities, descending
    return "\n".join(f"{tok.decode(int(i))!r}: {p[i]:.2e}" for i in idx)


def hist_plot():
    """Build a plotly line chart of the recorded var-drop (%) and KL history."""
    df = pd.DataFrame(history)
    fig = px.line(
        df, x="step", y=["var", "kl"],
        labels={"value": "metric", "step": "call"},
        title="history (var% ↓ & KL)",
    )
    return fig.update_layout(height=260)


def clear_hist():
    """Reset ``history`` to its initial single zero point and return a fresh plot."""
    # Slice-assign so the lists inside the module-level dict are reset in place.
    history["step"][:] = [0]
    history["var"][:] = [0.0]
    history["kl"][:] = [0.0]
    return hist_plot()


def run(prompt: str):
    """Run one prompt through the model and the WFGY engine.

    Returns a 5-tuple matching the outputs wired to the Run button:
    (raw top-5 text, WFGY top-5 text, headline markdown, histogram image,
    history plot). Every non-empty prompt appends one point to ``history``.
    """
    p = prompt.strip()
    if not p:
        # Exactly one value per wired output component (5 outputs).
        return "", "", "", None, hist_plot()

    ids = tok(p, return_tensors="pt").input_ids
    with torch.no_grad():  # inference only; no autograd graph needed
        # Logits of the last position of the first (only) batch item.
        rawL = mdl(ids).logits[0, -1].cpu().numpy()

    # Synthetic engine inputs: G is random, I is G plus small Gaussian noise.
    # NOTE(review): the meaning of (I, G) is defined by wfgy_sdk — confirm there.
    G = np.random.randn(256).astype(np.float32)
    I = G + np.random.normal(scale=0.05, size=256).astype(np.float32)
    modL = eng.run(I, G, rawL)
    m = compare_logits(rawL, modL)

    n = len(history["step"])
    history["step"].append(n)
    history["var"].append(m["var_drop"] * 100)
    history["kl"].append(m["kl"])

    fig = plot_histogram(rawL, modL)
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    # Release the figure once rendered; otherwise every call leaves one open.
    # Assumes plot_histogram returns a matplotlib Figure (it exposes savefig).
    plt.close(fig)
    buf.seek(0)

    head = f"▼ var {m['var_drop']*100:4.1f}% | KL {m['kl']:.3f} | top-1 {'kept' if m['top1'] else 'changed'}"
    return top5(rawL), top5(modL), head, Image.open(buf), hist_plot()


# UI
with gr.Blocks(title="WFGY variance gate demo") as demo:
    gr.Markdown("# 🧠 WFGY simulation demo")
    prompt = gr.Textbox(label="Prompt", value="Explain Schrödinger's cat")
    run_b = gr.Button("🚀 Run")
    with gr.Row():
        raw_box = gr.Textbox(label="Raw top-5 tokens", lines=6)
        mod_box = gr.Textbox(label="WFGY top-5 tokens", lines=6)
    headline = gr.Markdown()
    hist_img = gr.Image(type="pil", label="Logit histogram")
    hist_p = gr.Plot()
    clr_b = gr.Button("Clear history")
    with gr.Accordion("Paper benchmarks", open=False):
        gr.DataFrame(styled_df, interactive=False, wrap=True)
        gr.Plot(paper_bar)
    gr.Markdown("---\n⭐ **10 k GitHub stars before 2025-08-01 unlock WFGY 2.0**")
    # The output list order must match the tuple order returned by ``run``.
    run_b.click(run, prompt, [raw_box, mod_box, headline, hist_img, hist_p])
    clr_b.click(clear_hist, None, hist_p)

if __name__ == "__main__":
    demo.queue().launch()