File size: 1,719 Bytes
914ff43
a239ad1
914ff43
18538cc
a239ad1
1348198
914ff43
a239ad1
57a96b8
18538cc
 
 
1348198
914ff43
a239ad1
914ff43
a239ad1
c92d178
a239ad1
 
 
 
243324b
a239ad1
 
243324b
a239ad1
 
 
1348198
a239ad1
 
 
1348198
a239ad1
 
 
 
1348198
4ea805f
a239ad1
 
1348198
a239ad1
 
914ff43
a239ad1
822e03e
e95bc22
a239ad1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
"""
HF Space · WFGY 1-click Variance Gate (paste-and-deploy demo).
"""

import io, numpy as np, gradio as gr, matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForCausalLM
from wfgy_sdk import get_engine
from wfgy_sdk.evaluator import compare_logits, plot_histogram

# Hub checkpoint id; the tokenizer and the causal LM are both loaded from it.
MODEL = "sshleifer/tiny-gpt2"

# Loaded once at import time and reused by every call to run().
tok = AutoTokenizer.from_pretrained(MODEL)
mdl = AutoModelForCausalLM.from_pretrained(MODEL)

# WFGY engine handle obtained from the SDK; run() calls ENG.run(...) on it.
ENG = get_engine()

def _figure_to_pixels(fig) -> np.ndarray:
    """Rasterise a Matplotlib figure to a uint8 pixel array and close it.

    The figure is written to an in-memory PNG and decoded again, which gives
    an array that ``gr.Image`` can display (it does not accept a raw
    ``BytesIO`` buffer). The figure is always closed afterwards so repeated
    calls do not accumulate open figures.
    """
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format="png")
        buf.seek(0)
        pixels = plt.imread(buf, format="png")
    finally:
        buf.close()
        plt.close(fig)

    # Matplotlib decodes PNG data to floats in [0, 1]; rescale to 0-255.
    if np.issubdtype(pixels.dtype, np.floating):
        pixels = (pixels * 255).round().astype(np.uint8)
    return pixels


def run(prompt: str):
    """Compare raw next-token logits with the WFGY-modulated logits.

    Args:
        prompt: Text typed by the user. ``None``, empty and whitespace-only
            prompts are treated as "nothing to do".

    Returns:
        A 4-tuple ``(raw_txt, mod_txt, head, image)``:
        the prompt followed by the arg-max token of the raw logits, the
        prompt followed by the arg-max token of the modulated logits, a
        one-line metrics summary, and the histogram as a uint8 pixel array.
        For a blank prompt the tuple is ``("-", "-", "-", None)``.
    """
    if not prompt or not prompt.strip():
        return "-", "-", "-", None

    inp = tok(prompt, return_tensors="pt")
    # Logits at the last position of the first (only) sequence in the batch.
    rawL = mdl(**inp).logits[0, -1].detach().cpu().numpy()
    # Two independent standard-normal float32 vectors of length 256, passed to
    # the engine alongside the logits.
    # NOTE(review): their meaning to ENG.run is defined by wfgy_sdk - confirm.
    I, G = np.random.randn(2, 256).astype(np.float32)
    modL = ENG.run(I, G, rawL)

    mets = compare_logits(rawL, modL)
    head = f"▼ Var {mets['var_drop']*100:.1f}% | KL {mets['kl']:.2f}"

    # Convert the chart to a pixel array for the gr.Image output.
    image = _figure_to_pixels(plot_histogram(rawL, modL))

    raw_txt = prompt + tok.decode(int(rawL.argmax()))
    mod_txt = prompt + tok.decode(int(modL.argmax()))
    return raw_txt, mod_txt, head, image

# UI layout: prompt box and trigger button on top, the two completions side by
# side, then the metrics line and the logit histogram.
with gr.Blocks(title="WFGY 1-Click Variance Gate") as demo:
    gr.Markdown("# 🧠 WFGY 模擬實驗\n*輸入任意 Prompt,立刻觀看 Logit 直方圖*")
    prompt = gr.Textbox(value="Explain Schrödinger's cat", label="Prompt")
    run_b = gr.Button("🚀 Run")

    with gr.Row():
        raw = gr.Textbox(label="Raw GPT-2")
        mod = gr.Textbox(label="After WFGY")

    head = gr.Markdown()
    img = gr.Image(label="Logit Histogram")

    # One click feeds the prompt to run() and fans its 4-tuple out to the
    # four output components, in order.
    run_b.click(fn=run, inputs=prompt, outputs=[raw, mod, head, img])

if __name__ == "__main__":
    # Enable the request queue with a default concurrency limit of 2,
    # then start the server.
    demo.queue(default_concurrency_limit=2)
    demo.launch()