# app.py
# --------------------------------------------------------------------------------------------------
# Gradio app for Beeper
# - Loads released safetensors + tokenizer from Hugging Face
# - Auto-sizes pentachora banks to match checkpoints (across Beeper v1..v4)
# - Generation uses the same sampling knobs & penalties as the training script
# --------------------------------------------------------------------------------------------------
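#
# Local run sketch (assumes the usual dependencies for this Space are installed,
# e.g. `pip install gradio torch tokenizers huggingface_hub safetensors`, plus
# the sibling beeper_model.py from this repo):
#   python app.py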
import gradio as gr
import torch
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file as load_safetensors
from beeper_model import BeeperRoseGPT, generate, prepare_model_for_state_dict
# ----------------------------
# 🔧 Model versions configuration
# ----------------------------
MODEL_VERSIONS = {
"Beeper v4 (Advanced)": {
"repo_id": "AbstractPhil/beeper-rose-v4",
"model_file": "beeper_final.safetensors",
"description": "Beeper v4 with nearly 40% the full corpus training - the most capable version currently."
},
"Beeper v3 (Multi-Concept)": {
"repo_id": "AbstractPhil/beeper-rose-v3",
"model_file": "beeper_final.safetensors",
"description": "Beeper v3 with 30+ epochs including reasoning, math, and ethics"
},
"Beeper v2 (Extended)": {
"repo_id": "AbstractPhil/beeper-rose-v2",
"model_file": "beeper_final.safetensors",
"description": "Beeper v2 with extended training (~15 epochs)"
},
"Beeper v1 (Original)": {
"repo_id": "AbstractPhil/beeper-rose-tinystories-6l-512d-ctx512",
"model_file": "beeper_rose.safetensors",
"description": "Original Beeper trained on TinyStories"
},
}
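# To expose a future checkpoint, add an entry of the same shape. The repo id and
# description below are hypothetical placeholders, not a released model:
# MODEL_VERSIONS["Beeper v5 (Example)"] = {
#     "repo_id": "AbstractPhil/beeper-rose-v5",  # hypothetical
#     "model_file": "beeper_final.safetensors",
#     "description": "Example entry only",
# }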
# Base configuration (matches training defaults)
CONFIG = {
"context": 512,
"vocab_size": 8192,
"dim": 512,
"n_heads": 8,
"n_layers": 6,
"mlp_ratio": 4.0,
"temperature": 0.9,
"top_k": 40,
"top_p": 0.9,
"repetition_penalty": 1.10,
"presence_penalty": 0.6,
"frequency_penalty": 0.0,
"resid_dropout": 0.1,
"dropout": 0.0,
"grad_checkpoint": False,
# tokenizer_path not needed here; we load tokenizer.json from the HF repo
}
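# Note: the architecture fields above (context, vocab_size, dim, n_heads, n_layers,
# mlp_ratio) must match the released checkpoints; if they drift, the strict load
# below falls back to non-strict and some weights may be skipped. The sampling
# knobs (temperature, top_k, top_p, penalties) are the ones meant to be tweaked.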
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Globals (kept simple for a single process Gradio app)
infer: BeeperRoseGPT | None = None
tok: Tokenizer | None = None
current_version: str | None = None
def load_model_version(version_name: str) -> str:
"""
Download the checkpoint and tokenizer, build model, ensure pentachora sizes match,
then strictly load weights. Robust to v1/v2 (no pentas) and v3/v4 (with pentas).
"""
global infer, tok, current_version
if current_version == version_name and infer is not None and tok is not None:
return f"Already loaded: {version_name}"
version_info = MODEL_VERSIONS[version_name]
try:
# Download artifacts
model_file = hf_hub_download(
repo_id=version_info["repo_id"],
filename=version_info["model_file"]
)
tokenizer_file = hf_hub_download(
repo_id=version_info["repo_id"],
filename="tokenizer.json"
)
# Load state dict on CPU, inspect pentachora shapes if present
state_dict = load_safetensors(model_file, device="cpu")
# Build model & pre-create pentachora if needed
m = BeeperRoseGPT(CONFIG).to(device)
prepare_model_for_state_dict(m, state_dict, device=device)
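        # Note: per the docstring above, prepare_model_for_state_dict (from
        # beeper_model) sizes the pentachora banks to whatever the checkpoint
        # contains, so v1/v2 (no pentas) and v3/v4 (with pentas) both load.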
        # Try strict load first; if shapes drift (rare), fall back to non-strict
        try:
            missing, unexpected = m.load_state_dict(state_dict, strict=True)
            # PyTorch returns a NamedTuple; report counts
            _msg = f"strict load ok | missing={len(missing)} unexpected={len(unexpected)}"
        except Exception as e:
            _msg = f"strict load failed ({e}); trying non-strict…"
            # Non-strict load for very old snapshots
            m.load_state_dict(state_dict, strict=False)
        m.eval()
        # Tokenizer
        t = Tokenizer.from_file(tokenizer_file)
        # Swap globals
        infer, tok = m, t
        current_version = version_name
        return f"Successfully loaded: {version_name} ({_msg})"
    except Exception as e:
        infer = None
        tok = None
        current_version = None
        return f"Error loading {version_name}: {str(e)}"
# Load default on startup: prefer v4, fall back to v3
try:
    load_status = load_model_version("Beeper v4 (Advanced)")
    if "Error" in load_status:
        print(f"v4 not ready yet: {load_status}")
        load_status = load_model_version("Beeper v3 (Multi-Concept)")
except Exception as _:
    load_status = load_model_version("Beeper v3 (Multi-Concept)")
print(load_status)
# ----------------------------
# 💬 Chat wrapper
# ----------------------------
def beeper_reply(
    message: str,
    history: list[tuple[str, str]] | None,
    model_version: str,
    temperature: float | None,
    top_k: int | None,
    top_p: float | None,
    max_new_tokens: int = 80
) -> str:
    global infer, tok, current_version
    # Hot-swap versions if the dropdown changed
    if model_version != current_version:
        status = load_model_version(model_version)
        if "Error" in status:
            return f"⚠️ {status}"
    if infer is None or tok is None:
        return "⚠️ Model not loaded. Please select a version and try again."
    # Light prompting heuristics (consistent with your example)
    m = message.strip()
    if "?" in m:
        prompt = f"Q: {m}\nA:"
    elif m.lower() in {"hi", "hello", "hey"}:
        prompt = 'The little robot said hello. She said, "'
    elif "story" in m.lower():
        prompt = "Once upon a time, there was a robot. "
    else:
        prompt = m + ". "
    # Generate
    text = generate(
        model=infer,
        tok=tok,
        cfg=CONFIG,
        prompt=prompt,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature) if temperature is not None else None,
        top_k=int(top_k) if top_k is not None else None,
        top_p=float(top_p) if top_p is not None else None,
        repetition_penalty=1.10,
        presence_penalty=0.8,
        frequency_penalty=0.1,
        device=device,
        detokenize=True,
    )
    # Strip prompt echoes & artifacts
    if text.startswith(prompt):
        text = text[len(prompt):]
    text = text.replace("Q:", "").replace("A:", "")
    lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
    if lines:
        text = lines[0]
    # If user message echoed at head, trim after first occurrence
    head = m[:20].lower()
    if text.lower().startswith(head):
        idx = text.lower().find(head)
        text = text[idx + len(head):].strip() or text
    for artifact in ("User:", "Beeper:", "U ser:", "Beep er:", "User ", "Beeper "):
        text = text.replace(artifact, "")
    text = text.strip()
    if not text or len(text) < 3:
        text = "I like robots and stories!"
    if text[-1:] not in ".!?”\"'":
        text += "."
    return text[:200]
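# Quick sanity check outside the UI (illustrative only; assumes one of the
# versions above loaded successfully at startup):
#   print(beeper_reply("Tell me a story about a robot", [],
#                      "Beeper v3 (Multi-Concept)", 0.9, 40, 0.9, 64))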
# ----------------------------
# 🖼️ Interface
# ----------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🤖 Beeper - A Rose-based Tiny Language Model
        Hello! I'm Beeper, a small language model trained with love and care. Please be patient with me - I'm still learning! 💕
        """
    )
    with gr.Row():
        with gr.Column(scale=3):
            model_dropdown = gr.Dropdown(
                choices=list(MODEL_VERSIONS.keys()),
                value="Beeper v3 (Multi-Concept)",  # safer default
                label="Select Beeper Version",
                info="Choose which version of Beeper to chat with",
            )
        with gr.Column(scale=7):
            version_info = gr.Markdown("**Current:** " + MODEL_VERSIONS["Beeper v3 (Multi-Concept)"]["description"])

    def update_version_info(version_name: str):
        return f"**Current:** {MODEL_VERSIONS[version_name]['description']}"

    model_dropdown.change(
        fn=update_version_info,
        inputs=[model_dropdown],
        outputs=[version_info],
    )
    chatbot = gr.Chatbot(label="Chat with Beeper", height=400)
    msg = gr.Textbox(label="Message", placeholder="Type your message here...")
    with gr.Row():
        with gr.Column(scale=2):
            temperature_slider = gr.Slider(0.1, 1.5, value=0.9, step=0.1, label="Temperature")
        with gr.Column(scale=2):
            top_k_slider = gr.Slider(1, 100, value=40, step=1, label="Top-k")
        with gr.Column(scale=2):
            top_p_slider = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
        with gr.Column(scale=2):
            max_new_tokens_slider = gr.Slider(20, 512, value=128, step=1, label="Max new tokens")
    with gr.Row():
        submit = gr.Button("Send", variant="primary")
        clear = gr.Button("Clear")
    gr.Examples(
        examples=[
            ["Hello Beeper! How are you today?"],
            ["Can you tell me a story about a robot?"],
            ["What do you like to do for fun?"],
            ["What makes you happy?"],
            ["Tell me about your dreams"],
        ],
        inputs=msg,
    )

    def respond(message, chat_history, model_version, temperature, top_k, top_p, max_new_tokens):
        if chat_history is None:
            chat_history = []
        response = beeper_reply(message, chat_history, model_version, temperature, top_k, top_p, max_new_tokens)
        chat_history.append((message, response))
        return "", chat_history

    msg.submit(
        respond,
        [msg, chatbot, model_dropdown, temperature_slider, top_k_slider, top_p_slider, max_new_tokens_slider],
        [msg, chatbot],
    )
    submit.click(
        respond,
        [msg, chatbot, model_dropdown, temperature_slider, top_k_slider, top_p_slider, max_new_tokens_slider],
        [msg, chatbot],
    )
    clear.click(lambda: None, None, chatbot, queue=False)
if __name__ == "__main__":
    demo.launch()