File size: 2,397 Bytes
14903f8
c1eda44
 
1348198
914ff43
a239ad1
048a5a0
57a96b8
7429b83
048a5a0
 
 
c1eda44
ef37700
 
 
048a5a0
7429b83
048a5a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c1eda44
 
048a5a0
 
 
c1eda44
7429b83
048a5a0
c1eda44
6aba93c
7429b83
a239ad1
048a5a0
7429b83
4ea805f
ef37700
 
048a5a0
ef37700
048a5a0
 
7429b83
048a5a0
 
7429b83
048a5a0
 
822e03e
e95bc22
6aba93c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import io, traceback, numpy as np, gradio as gr, matplotlib
matplotlib.use("Agg")
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM
from wfgy_sdk import get_engine
from wfgy_sdk.evaluator import compare_logits, plot_histogram
from tabulate import tabulate

# Tiny GPT-2 checkpoint: chosen for fast CPU inference in a demo Space,
# not for output quality.
MODEL = "sshleifer/tiny-gpt2"
tok   = AutoTokenizer.from_pretrained(MODEL)
mdl   = AutoModelForCausalLM.from_pretrained(MODEL)
eng   = get_engine()  # WFGY SDK engine; modulates logits in run() below

def run(prompt: str):
    """Run one demo step: raw GPT-2 logits vs. WFGY-modulated logits.

    Parameters
    ----------
    prompt : str
        User prompt from the Gradio textbox.

    Returns
    -------
    tuple
        ``(raw_text, mod_text, headline_md, metrics_md, histogram_img)``
        matching the five output components wired to ``btn.click`` in the UI.
        On empty input or any runtime error, placeholder values are returned
        instead (never raises, so the UI stays responsive).
    """
    prompt = prompt.strip()
    if not prompt:
        # Same shape as the success/error paths: "" for the metrics Markdown
        # (was None, inconsistent with the error branch), None for the image.
        return "", "", "", "", None
    try:
        ids  = tok(prompt, return_tensors="pt").input_ids
        # Last-position logits for next-token prediction.
        rawL = mdl(ids).logits[0, -1].detach().cpu().numpy()
        # Synthetic 256-dim semantic vectors: G is the "ground" vector,
        # I a slightly noisy copy — presumably what eng.run expects as its
        # (I, G) pair; see WFGY SDK for the exact contract.
        G    = np.random.randn(256).astype(np.float32)
        I    = G + np.random.normal(scale=0.05, size=256).astype(np.float32)
        modL = eng.run(I, G, rawL)

        m = compare_logits(rawL, modL)
        tbl = tabulate(
            [[f"{m['std_ratio']:.3f}",
              f"{m['var_drop']*100:4.1f} %",
              f"{m['kl']:.3f}",
              "✔" if m['top1'] else "✘"]],
            headers=["std_ratio", "▼ var", "KL", "top-1"],
            tablefmt="github")
        headline = f"▼ var {m['var_drop']*100:4.1f} % | KL {m['kl']:.3f}"

        fig = plot_histogram(rawL, modL)
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
        buf.seek(0)
        img = Image.open(buf)
        img.load()  # force decode now; Image.open is lazy and buf goes away
        # Close the figure: with the Agg backend, unclosed figures accumulate
        # across requests and leak memory in a long-running server.
        import matplotlib.pyplot as plt
        plt.close(fig)

        # Greedy single-token continuation for each logit vector.
        raw_txt = prompt + tok.decode(int(rawL.argmax()))
        mod_txt = prompt + tok.decode(int(modL.argmax()))
        return raw_txt, mod_txt, headline, tbl, img
    except Exception:
        # Surface the traceback in the UI instead of crashing the demo.
        tb = traceback.format_exc()
        return "runtime error", tb, "runtime error", "", None

# ---- Gradio UI wiring ----------------------------------------------------
with gr.Blocks(title="WFGY variance gate") as demo:
    gr.Markdown("# 🧠 WFGY simulation demo")
    prompt = gr.Textbox(label="Prompt", value="Explain Schrödinger's cat")
    btn    = gr.Button("🚀 Run")

    # Side-by-side comparison: raw model output vs. WFGY-modulated output.
    with gr.Row():
        raw_box = gr.Textbox(label="Raw GPT-2")
        mod_box = gr.Textbox(label="After WFGY")

    headline = gr.Markdown()
    metrics  = gr.Markdown()           # ← metrics table (github-markdown from tabulate)
    img      = gr.Image(label="Logit histogram", type="pil")

    # run() returns exactly these five values, in this order.
    btn.click(run, prompt,
              [raw_box, mod_box, headline, metrics, img])

    gr.Markdown("---\n"
                "### ⭐ 10 000 stars → unlock **WFGY 2.0** by 2025-08-01")

if __name__ == "__main__":
    # queue() serializes requests so concurrent users don't race the model.
    demo.queue().launch()