# Spaces: Running
""" | |
WFGY HuggingFace Space – deluxe demo | |
* Generates text before/after WFGY | |
* Shows variance, KL, top-1 shift | |
* Renders overlay histogram | |
""" | |
import base64
import io
import tempfile

import gradio as gr
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed

import wfgy_sdk as w
from wfgy_sdk.evaluator import compare_logits
from wfgy_sdk.visual import plot_histogram
# Hub id of the checkpoint used by the demo; the original note says it is
# small enough to run on CPU in about two seconds.
MODEL = "sshleifer/tiny-gpt2"

# Tokenizer and causal-LM weights are loaded once at import time and reused
# by every call to gen_text / wfgy_demo.
tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForCausalLM.from_pretrained(MODEL)

# Fix the seed after the model is loaded (order kept from the original).
set_seed(42)

# Shared WFGY engine handle (the SDK is described as returning a singleton).
ENGINE = w.get_engine()
def gen_text(prompt, max_new_tokens=40):
    """Return the greedy continuation of *prompt*, without the prompt itself.

    Args:
        prompt: Text to continue.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded newly generated tokens, special tokens removed.
    """
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    prompt_len = input_ids.shape[1]

    # Inference only: no gradients needed.
    with torch.no_grad():
        generated = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )

    # generate() echoes the prompt tokens first; keep only what follows them.
    continuation = generated[0, prompt_len:]
    return tokenizer.decode(continuation, skip_special_tokens=True)
def wfgy_demo(prompt, enable_wfgy):
    """Compare the next-token prediction with and without the WFGY engine.

    Args:
        prompt: Text entered by the user.
        enable_wfgy: When true, the last-position logits are passed through
            ``ENGINE.run``; otherwise an unmodified copy is used.

    Returns:
        A 4-tuple matching the Gradio outputs: the prompt plus the raw
        arg-max token, the prompt plus the arg-max token after modulation,
        a one-line metrics summary, and the path of a PNG file holding the
        histogram. For an empty prompt the tuple is two empty strings, a
        short hint, and ``None`` for the image.
    """
    # An empty prompt tokenizes to zero tokens, so there is no "last
    # position" to read logits from; answer with a hint instead of crashing.
    if not prompt:
        return "", "", "Please enter a prompt first.", None

    # ---- raw logits for the next token ----
    ids = tokenizer(prompt, return_tensors="pt").input_ids
    with torch.no_grad():
        output = model(ids)
    raw_logits = output.logits[0, -1].cpu().numpy()

    # Dummy semantic vectors for the demo: a random unit vector and a
    # slightly perturbed copy of it.
    ground_vec = np.random.randn(256)
    ground_vec /= np.linalg.norm(ground_vec)
    input_vec = ground_vec + np.random.normal(scale=0.05, size=256)

    # ---- run WFGY (or pass the logits through untouched) ----
    if enable_wfgy:
        mod_logits = ENGINE.run(
            input_vec=input_vec, ground_vec=ground_vec, logits=raw_logits
        )
    else:
        mod_logits = raw_logits.copy()

    # ---- decode the arg-max token of each version ----
    next_raw = tokenizer.decode(int(raw_logits.argmax()))
    next_mod = tokenizer.decode(int(mod_logits.argmax()))
    raw_txt = prompt + next_raw
    mod_txt = prompt + next_mod

    # ---- metrics badge ----
    m = compare_logits(raw_logits, mod_logits)
    badge = f"variance ↓ {(1-m['std_ratio'])*100:.0f}% | KL {m['kl_divergence']:.2f}"
    top1 = "✔" if m["top1_shift"] else "✘"
    badge += f" | top-1 changed {top1}"

    # ---- histogram ----
    # gr.Image interprets a str output as a file path or URL, not as a
    # base64 data URI, so the figure is written to a PNG file and its path
    # is returned. delete=False because Gradio reads the file only after
    # this function has returned; the file stays in the temp directory.
    fig = plot_histogram(raw_logits, mod_logits, show=False)
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        fig.savefig(tmp, format="png")
    fig.clf()

    return raw_txt, mod_txt, badge, tmp.name
# ---- UI layout: inputs on top, the two texts side by side, then metrics ----
with gr.Blocks(title="WFGY variance gate") as demo:
    gr.Markdown("## WFGY Live Demo — variance drop in real-time")

    # Inputs
    prompt = gr.Textbox(label="Prompt", placeholder="Ask anything…", lines=2)
    enable = gr.Checkbox(label="Enable WFGY", value=True)
    run_btn = gr.Button("Run")

    # Raw vs. modulated output, shown next to each other
    with gr.Row():
        raw_out = gr.Textbox(label="Raw GPT-2")
        mod_out = gr.Textbox(label="After WFGY")

    # Metrics badge and logit histogram
    metrics = gr.HTML(label="Metrics")
    hist = gr.Image(label="Logit distribution", elem_id="hist", width=450)

    # The output order must match the tuple returned by wfgy_demo.
    run_btn.click(
        fn=wfgy_demo,
        inputs=[prompt, enable],
        outputs=[raw_out, mod_out, metrics, hist],
    )

    gr.Markdown(
        "⭐ If the variance drop looks magic, [**star the repo**]"
        "(https://github.com/onestardao/WFGY) and help unlock WFGY 2.0!"
    )

# Start the app as soon as the module is executed.
demo.launch()