Spaces:
Running
Running
File size: 2,397 Bytes
14903f8 c1eda44 1348198 914ff43 a239ad1 048a5a0 57a96b8 7429b83 048a5a0 c1eda44 ef37700 048a5a0 7429b83 048a5a0 c1eda44 048a5a0 c1eda44 7429b83 048a5a0 c1eda44 6aba93c 7429b83 a239ad1 048a5a0 7429b83 4ea805f ef37700 048a5a0 ef37700 048a5a0 7429b83 048a5a0 7429b83 048a5a0 822e03e e95bc22 6aba93c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import io, traceback, numpy as np, gradio as gr, matplotlib
matplotlib.use("Agg")
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM
from wfgy_sdk import get_engine
from wfgy_sdk.evaluator import compare_logits, plot_histogram
from tabulate import tabulate
# Hugging Face model id used by the demo (a tiny GPT-2 checkpoint).
MODEL = "sshleifer/tiny-gpt2"

# Loaded once at import time; run() reuses these module-level objects on
# every request. The tokenizer is loaded before the model, then the engine.
tok, mdl = (
    AutoTokenizer.from_pretrained(MODEL),
    AutoModelForCausalLM.from_pretrained(MODEL),
)
eng = get_engine()
def _metrics_table(m):
    """Format ``compare_logits`` metrics as a one-row GitHub-style Markdown table.

    Reads the keys ``std_ratio``, ``var_drop`` (multiplied by 100 for display
    as a percentage), ``kl`` and ``top1`` (rendered as a check / cross mark).
    """
    row = [
        f"{m['std_ratio']:.3f}",
        f"{m['var_drop']*100:4.1f} %",
        f"{m['kl']:.3f}",
        "✔" if m['top1'] else "✘",
    ]
    return tabulate(
        [row],
        headers=["std_ratio", "▼ var", "KL", "top-1"],
        tablefmt="github",
    )


def _fig_to_image(fig):
    """Render a Matplotlib figure to an in-memory PNG and return it as a PIL image.

    The figure is always closed afterwards so repeated requests do not pile up
    open figures in pyplot's global registry.
    """
    # Local import: the "Agg" backend was already selected at module import.
    import matplotlib.pyplot as plt

    buf = io.BytesIO()
    try:
        fig.savefig(buf, format="png")
    finally:
        plt.close(fig)
    buf.seek(0)
    image = Image.open(buf)
    # Image.open is lazy; decode now so any PNG error is raised here (inside
    # run()'s try) rather than later while Gradio serialises the output.
    image.load()
    return image


def run(prompt: str):
    """Predict one next token with and without the WFGY engine and compare them.

    Parameters
    ----------
    prompt : str
        Text from the Gradio prompt box. ``None`` or whitespace-only input
        short-circuits to empty outputs without touching the model.

    Returns
    -------
    tuple
        ``(raw_txt, mod_txt, headline, table_md, image)``:

        - the prompt followed by the raw model's argmax token;
        - the prompt followed by the argmax token of the WFGY-modified logits;
        - a one-line variance-drop / KL summary;
        - a Markdown metrics table;
        - a PIL image of the logit histogram.

        On any exception it returns
        ``("runtime error", traceback_text, "runtime error", "", None)``, so
        the UI displays the traceback instead of failing the request.
    """
    prompt = (prompt or "").strip()
    if not prompt:
        return "", "", "", None, None
    try:
        import torch  # transformers is called with return_tensors="pt"

        ids = tok(prompt, return_tensors="pt").input_ids
        # Inference only: avoid building an autograd graph for the forward pass.
        with torch.no_grad():
            raw_logits = mdl(ids).logits[0, -1].cpu().numpy()

        # Simulated 256-d input vectors for the engine: vec_i is vec_g plus
        # small Gaussian noise. NOTE(review): presumably stand-ins for real
        # semantic vectors; confirm the expected meaning in wfgy_sdk.
        vec_g = np.random.randn(256).astype(np.float32)
        vec_i = vec_g + np.random.normal(scale=0.05, size=256).astype(np.float32)
        mod_logits = eng.run(vec_i, vec_g, raw_logits)

        m = compare_logits(raw_logits, mod_logits)
        table_md = _metrics_table(m)
        headline = f"▼ var {m['var_drop']*100:4.1f} % | KL {m['kl']:.3f}"
        image = _fig_to_image(plot_histogram(raw_logits, mod_logits))

        raw_txt = prompt + tok.decode(int(raw_logits.argmax()))
        mod_txt = prompt + tok.decode(int(mod_logits.argmax()))
        return raw_txt, mod_txt, headline, table_md, image
    except Exception:
        # Top-level UI boundary: surface the full traceback in the output box.
        return "runtime error", traceback.format_exc(), "runtime error", "", None
with gr.Blocks(title="WFGY variance gate") as demo:
    # Title, prompt input and trigger button.
    gr.Markdown("# 🧠 WFGY simulation demo")
    prompt = gr.Textbox(label="Prompt", value="Explain Schrödinger's cat")
    btn = gr.Button("🚀 Run")

    # Side-by-side completions: raw model vs. after the WFGY engine.
    with gr.Row():
        raw_box = gr.Textbox(label="Raw GPT-2")
        mod_box = gr.Textbox(label="After WFGY")

    headline = gr.Markdown()
    metrics = gr.Markdown()  # newly added metrics table
    img = gr.Image(label="Logit histogram", type="pil")

    # The output list order must match the 5-tuple returned by run().
    btn.click(
        fn=run,
        inputs=prompt,
        outputs=[raw_box, mod_box, headline, metrics, img],
    )

    gr.Markdown(
        "---\n"
        "### ⭐ 10 000 stars → unlock **WFGY 2.0** by 2025-08-01"
    )
if __name__ == "__main__":
    # Enable Gradio's event queue on the Blocks app, then start the server.
    app = demo.queue()
    app.launch()
|