# wfgy-demo / app.py — Hugging Face Space source (author: OneStarDao, commit 6aba93c, 2.1 kB)
# HF Space · WFGY variance gate demo (Gradio 4.31+)
import io
import tempfile

import gradio as gr
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM

from wfgy_sdk import get_engine
from wfgy_sdk.evaluator import compare_logits, plot_histogram
# A very small GPT-2 checkpoint — presumably chosen so the Space cold-starts
# quickly; generation quality is irrelevant for a variance-gate demo.
MODEL_ID = "sshleifer/tiny-gpt2"
tok = AutoTokenizer.from_pretrained(MODEL_ID)
mdl = AutoModelForCausalLM.from_pretrained(MODEL_ID)
# WFGY engine instance shared by every run() call.
eng = get_engine()
def run(prompt: str):
    """Run tiny-GPT-2 raw and through the WFGY variance gate on *prompt*.

    Returns a 4-tuple matching the Gradio outputs:
      (raw continuation, gated continuation, headline metrics markdown,
       path to the logit-histogram PNG — or None when there is no prompt).
    """
    prompt = prompt.strip()
    if not prompt:
        return "", "", "no prompt – nothing to show", None

    ids = tok(prompt, return_tensors="pt").input_ids
    # Logits of the final position → next-token distribution.
    logits_raw = mdl(ids).logits[0, -1].detach().cpu().numpy()

    # Toy fingerprints: G is a random anchor, I a slightly-noised copy of it.
    G = np.random.randn(256).astype(np.float32)
    I = G + np.random.normal(scale=0.05, size=256).astype(np.float32)

    logits_mod = eng.run(I, G, logits_raw)
    m = compare_logits(logits_raw, logits_mod)
    headline = f"▼ var {m['var_drop']*100:4.1f} % | KL {m['kl']:.3f}"

    fig = plot_histogram(logits_raw, logits_mod)
    # BUG FIX: gr.Image cannot render an io.BytesIO — it accepts a filepath,
    # numpy array, or PIL image. Save to a temp PNG and return the path
    # (Gradio serves/copies the file itself; delete=False keeps it readable).
    tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    fig.savefig(tmp.name, format="png")
    tmp.close()
    # Release the figure's contents so repeated clicks don't accumulate state.
    # NOTE(review): a full plt.close(fig) would also drop it from pyplot's
    # registry, but matplotlib is not imported here — confirm before adding.
    fig.clf()

    # Greedy single-token continuation for each distribution.
    raw_txt = prompt + tok.decode(int(logits_raw.argmax()))
    mod_txt = prompt + tok.decode(int(logits_mod.argmax()))
    return raw_txt, mod_txt, headline, tmp.name
# UI layout: prompt box + run button, two side-by-side result boxes,
# a metrics headline, and the logit-histogram image.
with gr.Blocks(title="WFGY variance gate") as demo:
    gr.Markdown(
        "# 🧠 WFGY simulation demo \n"
        "Type any prompt and watch the logit variance collapse in real time."
    )
    prompt = gr.Textbox(label="Prompt", value="Explain Schrödinger's cat")
    btn = gr.Button("🚀 Run")
    with gr.Row():
        raw_box = gr.Textbox(label="Raw GPT-2")
        mod_box = gr.Textbox(label="After WFGY")
    headline = gr.Markdown()
    img = gr.Image(label="Logit histogram")
    # run() returns a 4-tuple mapped positionally onto these four outputs.
    btn.click(run, prompt, [raw_box, mod_box, headline, img])
    gr.Markdown(
        "---\n"
        "### ⭐ Help unlock **WFGY 2.0** \n"
        "10 000 GitHub stars by **2025-08-01** → next-gen release."
    )
if __name__ == "__main__":
    # Gradio ≥4.31: queue() has no arg; use default queue size (=2)
    demo.queue().launch()