# Hugging Face Space — status: Running
import io
import traceback

import gradio as gr
import matplotlib
import numpy as np

# Select the non-interactive Agg backend before the imports below; figures are
# only ever rendered into in-memory PNG buffers.
matplotlib.use("Agg")

import torch
from matplotlib import pyplot as plt
from PIL import Image
from tabulate import tabulate
from transformers import AutoModelForCausalLM, AutoTokenizer

from wfgy_sdk import get_engine
from wfgy_sdk.evaluator import compare_logits, plot_histogram
# Hugging Face checkpoint used for the "Raw GPT-2" side of the comparison.
MODEL = "sshleifer/tiny-gpt2"

# Tokenizer and model are loaded once at import time and shared by every
# call to run().
tok, mdl = (
    AutoTokenizer.from_pretrained(MODEL),
    AutoModelForCausalLM.from_pretrained(MODEL),
)

# WFGY engine whose .run() transforms the raw next-token logits.
eng = get_engine()
def run(prompt: str):
    """Compare the model's next-token prediction before and after the WFGY engine.

    The logits at the prompt's last position are passed to ``eng.run``
    together with two simulated inputs: a random 256-element float32 vector
    and a copy of it perturbed by Gaussian noise (scale 0.05). Both vectors are
    regenerated on every call, so results are not reproducible between runs.
    The raw and modified logits are then compared with ``compare_logits`` and
    ``plot_histogram``.

    Args:
        prompt: Text from the Gradio textbox. ``None`` or whitespace-only input
            is treated as empty.

    Returns:
        A 5-tuple in the order of the UI outputs:
        ``(raw_txt, mod_txt, headline, table_markdown, histogram_image)``.
        ``raw_txt``/``mod_txt`` are the prompt plus the argmax token of the raw
        and modified logits respectively. Empty input yields
        ``("", "", "", None, None)``. Any exception raised while processing is
        caught and reported as
        ``("runtime error", traceback_text, "runtime error", "", None)``.
    """
    prompt = (prompt or "").strip()
    if not prompt:
        return "", "", "", None, None
    try:
        ids = tok(prompt, return_tensors="pt").input_ids
        # Inference only: no_grad avoids building an autograd graph per request.
        with torch.no_grad():
            raw_logits = mdl(ids).logits[0, -1].cpu().numpy()

        # Simulated engine inputs: a random vector and a slightly noisy copy.
        g_vec = np.random.randn(256).astype(np.float32)
        i_vec = g_vec + np.random.normal(scale=0.05, size=256).astype(np.float32)
        mod_logits = eng.run(i_vec, g_vec, raw_logits)

        m = compare_logits(raw_logits, mod_logits)
        var_pct = f"{m['var_drop'] * 100:4.1f} %"
        tbl = tabulate(
            [[
                f"{m['std_ratio']:.3f}",
                var_pct,
                f"{m['kl']:.3f}",
                "✔" if m["top1"] else "✘",
            ]],
            headers=["std_ratio", "▼ var", "KL", "top-1"],
            tablefmt="github",
        )
        headline = f"▼ var {var_pct} | KL {m['kl']:.3f}"

        fig = plot_histogram(raw_logits, mod_logits)
        buf = io.BytesIO()
        try:
            fig.savefig(buf, format="png")
        finally:
            # Release the figure after rendering. If plot_histogram creates it
            # through pyplot (not visible here), it would otherwise stay
            # registered and accumulate across requests; close() is a no-op
            # for figures pyplot does not manage.
            plt.close(fig)
        buf.seek(0)
        img = Image.open(buf)

        raw_txt = prompt + tok.decode(int(raw_logits.argmax()))
        mod_txt = prompt + tok.decode(int(mod_logits.argmax()))
        return raw_txt, mod_txt, headline, tbl, img
    except Exception:
        # UI boundary: show the traceback in the second output box instead of
        # failing the Gradio request.
        return "runtime error", traceback.format_exc(), "runtime error", "", None
with gr.Blocks(title="WFGY variance gate") as demo:
    # Components appear in the page in the order they are created here.
    gr.Markdown("# 🧠 WFGY simulation demo")
    prompt = gr.Textbox(label="Prompt", value="Explain Schrödinger's cat")
    btn = gr.Button("🚀 Run")

    # Side-by-side: prompt + argmax token from raw vs. WFGY-modified logits.
    with gr.Row():
        raw_box = gr.Textbox(label="Raw GPT-2")
        mod_box = gr.Textbox(label="After WFGY")

    headline = gr.Markdown()  # one-line variance-drop / KL summary
    metrics = gr.Markdown()   # metrics table (GitHub-style markdown from run)
    img = gr.Image(label="Logit histogram", type="pil")

    # The outputs list must follow the order of run()'s 5-tuple.
    btn.click(
        fn=run,
        inputs=prompt,
        outputs=[raw_box, mod_box, headline, metrics, img],
    )

    gr.Markdown("---\n### ⭐ 10 000 stars → unlock **WFGY 2.0** by 2025-08-01")
if __name__ == "__main__":
    # Configure Gradio's request queue on the app, then start serving it.
    demo.queue()
    demo.launch()