wfgy-demo / app.py
OneStarDao's picture
Update app.py
048a5a0 verified
raw
history blame
2.4 kB
import io, traceback, numpy as np, gradio as gr, matplotlib
matplotlib.use("Agg")
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM
from wfgy_sdk import get_engine
from wfgy_sdk.evaluator import compare_logits, plot_histogram
from tabulate import tabulate
# Hugging Face model id; the same id is used for both tokenizer and model.
MODEL = "sshleifer/tiny-gpt2"

# Load tokenizer first, then the causal-LM weights, once at import time.
# run() uses these module-level objects for every request.
tok, mdl = (
    AutoTokenizer.from_pretrained(MODEL),
    AutoModelForCausalLM.from_pretrained(MODEL),
)

# WFGY engine from wfgy_sdk; run() calls eng.run(...) on each request.
eng = get_engine()
def _format_metrics(m):
    """Return ``(headline, table)`` Markdown strings for a compare_logits() result.

    Reads the keys ``'std_ratio'``, ``'var_drop'``, ``'kl'`` and ``'top1'``
    from ``m``; ``var_drop`` is displayed multiplied by 100 with a ``%`` sign.
    """
    var_drop_txt = f"{m['var_drop'] * 100:4.1f} %"
    kl_txt = f"{m['kl']:.3f}"
    table = tabulate(
        [[f"{m['std_ratio']:.3f}",
          var_drop_txt,
          kl_txt,
          "✔" if m["top1"] else "✘"]],
        headers=["std_ratio", "▼ var", "KL", "top-1"],
        tablefmt="github",
    )
    headline = f"▼ var {var_drop_txt} | KL {kl_txt}"
    return headline, table


def _figure_to_pil(fig):
    """Render a Matplotlib figure to PNG in memory, close it, and return a PIL image."""
    # Backend was forced to "Agg" at the top of this module, so pyplot is headless.
    import matplotlib.pyplot as plt

    buf = io.BytesIO()
    try:
        fig.savefig(buf, format="png")
    finally:
        # A pyplot-created figure stays registered (and in memory) until it is
        # closed; without this, every request would leak one figure.
        # Closing an unregistered Figure is a harmless no-op.
        plt.close(fig)
    buf.seek(0)
    image = Image.open(buf)
    # Image.open is lazy; decode now so any error is raised inside run()'s try.
    image.load()
    return image


def run(prompt: str):
    """Compare the model's raw next-token logits with the WFGY-adjusted ones.

    The prompt is tokenised and passed through ``mdl``. The logits at index
    ``[0, -1]`` (first batch item, last position) are the raw logits. These go
    to ``eng.run`` together with two random float32 vectors to obtain modified
    logits. ``compare_logits`` and ``plot_histogram`` then summarise the
    difference.

    Args:
        prompt: Text from the "Prompt" textbox. Surrounding whitespace is
            stripped; ``None`` is treated as an empty prompt.

    Returns:
        A 5-tuple in the order of the Gradio outputs
        ``[raw_box, mod_box, headline, metrics, img]``:

        * prompt + argmax token decoded from the raw logits
        * prompt + argmax token decoded from the WFGY logits
        * one-line Markdown headline (variance drop and KL)
        * GitHub-style Markdown metrics table
        * PIL image of the logit histogram

        An empty prompt yields ``("", "", "", None, None)``. Any exception
        during processing is caught. The traceback is printed to stderr and
        returned in the second slot as
        ``("runtime error", <traceback>, "runtime error", "", None)``.
    """
    prompt = (prompt or "").strip()
    if not prompt:
        return "", "", "", None, None

    vec_dim = 256  # length of the random vectors handed to eng.run

    try:
        import torch  # only needed for no_grad(); the transformers model runs on it

        ids = tok(prompt, return_tensors="pt").input_ids
        # Inference only: don't let autograd record the forward pass.
        with torch.no_grad():
            raw_logits = mdl(ids).logits[0, -1].cpu().numpy()

        # g_vec is random; i_vec is g_vec plus Gaussian noise (scale 0.05).
        # NOTE(review): their semantic role is defined by wfgy_sdk's eng.run
        # (argument order here is i_vec, g_vec) -- confirm against the SDK.
        g_vec = np.random.randn(vec_dim).astype(np.float32)
        i_vec = g_vec + np.random.normal(scale=0.05, size=vec_dim).astype(np.float32)
        mod_logits = eng.run(i_vec, g_vec, raw_logits)

        metrics = compare_logits(raw_logits, mod_logits)
        headline, table = _format_metrics(metrics)
        image = _figure_to_pil(plot_histogram(raw_logits, mod_logits))

        raw_txt = prompt + tok.decode(int(raw_logits.argmax()))
        mod_txt = prompt + tok.decode(int(mod_logits.argmax()))
        return raw_txt, mod_txt, headline, table, image
    except Exception:
        # UI boundary: surface the error in the page instead of failing the
        # request, and keep a copy in the server log.
        tb = traceback.format_exc()
        traceback.print_exc()
        return "runtime error", tb, "runtime error", "", None
# Gradio UI. Components render in the order they are created inside the
# Blocks context: title, prompt box, run button, the raw vs. WFGY
# continuations in one row, then headline, metrics table and histogram.
with gr.Blocks(title="WFGY variance gate") as demo:
    gr.Markdown("# 🧠 WFGY simulation demo")
    prompt = gr.Textbox(label="Prompt", value="Explain Schrödinger's cat")
    btn = gr.Button("🚀 Run")
    # Raw-model and WFGY-adjusted continuations shown side by side.
    with gr.Row():
        raw_box = gr.Textbox(label="Raw GPT-2")
        mod_box = gr.Textbox(label="After WFGY")
    headline = gr.Markdown()
    metrics = gr.Markdown()  # ← newly added metrics table
    img = gr.Image(label="Logit histogram", type="pil")
    # Clicking calls run(prompt). The outputs list must match the order of
    # run()'s 5-tuple: raw text, WFGY text, headline, table, image.
    btn.click(run, prompt,
              [raw_box, mod_box, headline, metrics, img])
    gr.Markdown("---\n"
                "### ⭐ 10 000 stars → unlock **WFGY 2.0** by 2025-08-01")
if __name__ == "__main__":
    # Turn on Gradio's request queue (queue() returns the same Blocks app),
    # then start serving it.
    queued_app = demo.queue()
    queued_app.launch()