# app.py
# --------------------------------------------------------------------------------------------------
# Gradio app for Beeper
# - Loads released safetensors + tokenizer from Hugging Face
# - Auto-sizes pentachora banks to match checkpoints (across Beeper v1..v4)
# - Generation uses the same knobs & penalties as the training script
# --------------------------------------------------------------------------------------------------
import gradio as gr
import torch
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file as load_safetensors

from beeper_model import BeeperRoseGPT, generate, prepare_model_for_state_dict
# ----------------------------
# 🔧 Model versions configuration
# ----------------------------
MODEL_VERSIONS = {
    "Beeper v4 (Advanced)": {
        "repo_id": "AbstractPhil/beeper-rose-v4",
        "model_file": "beeper_final.safetensors",
        "description": "Beeper v4 trained on nearly 40% of the full corpus - currently the most capable version."
    },
    "Beeper v3 (Multi-Concept)": {
        "repo_id": "AbstractPhil/beeper-rose-v3",
        "model_file": "beeper_final.safetensors",
        "description": "Beeper v3 with 30+ epochs including reasoning, math, and ethics"
    },
    "Beeper v2 (Extended)": {
        "repo_id": "AbstractPhil/beeper-rose-v2",
        "model_file": "beeper_final.safetensors",
        "description": "Beeper v2 with extended training (~15 epochs)"
    },
    "Beeper v1 (Original)": {
        "repo_id": "AbstractPhil/beeper-rose-tinystories-6l-512d-ctx512",
        "model_file": "beeper_rose.safetensors",
        "description": "Original Beeper trained on TinyStories"
    },
}
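
# To register another release, add an entry with its repo id and weights filename.
# (Sketch only; "beeper-rose-v5" is a hypothetical repo name, not a released checkpoint.)
#   "Beeper v5 (Experimental)": {
#       "repo_id": "AbstractPhil/beeper-rose-v5",
#       "model_file": "beeper_final.safetensors",
#       "description": "Hypothetical future release.",
#   },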
# Base configuration (matches training defaults)
CONFIG = {
    "context": 512,
    "vocab_size": 8192,
    "dim": 512,
    "n_heads": 8,
    "n_layers": 6,
    "mlp_ratio": 4.0,
    "temperature": 0.9,
    "top_k": 40,
    "top_p": 0.9,
    "repetition_penalty": 1.10,
    "presence_penalty": 0.6,
    "frequency_penalty": 0.0,
    "resid_dropout": 0.1,
    "dropout": 0.0,
    "grad_checkpoint": False,
    # tokenizer_path not needed here; we load tokenizer.json from the HF repo
}
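
# How this dict is consumed (a sketch of the pattern used below; assumes BeeperRoseGPT
# reads the architecture keys and `generate` reads the sampling keys from the same dict):
#   model = BeeperRoseGPT(CONFIG).to(device)
#   n_params = sum(p.numel() for p in model.parameters())  # 6-layer, 512-dim decoder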
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Globals (kept simple for a single-process Gradio app)
infer: BeeperRoseGPT | None = None
tok: Tokenizer | None = None
current_version: str | None = None
def load_model_version(version_name: str) -> str:
    """
    Download the checkpoint and tokenizer, build the model, ensure pentachora sizes match,
    then strictly load weights. Robust to v1/v2 (no pentachora) and v3/v4 (with pentachora).
    """
    global infer, tok, current_version

    if current_version == version_name and infer is not None and tok is not None:
        return f"Already loaded: {version_name}"

    version_info = MODEL_VERSIONS[version_name]
    try:
        # Download artifacts
        model_file = hf_hub_download(
            repo_id=version_info["repo_id"],
            filename=version_info["model_file"]
        )
        tokenizer_file = hf_hub_download(
            repo_id=version_info["repo_id"],
            filename="tokenizer.json"
        )

        # Load state dict on CPU, inspect pentachora shapes if present
        state_dict = load_safetensors(model_file, device="cpu")

        # Build model & pre-create pentachora if needed
        m = BeeperRoseGPT(CONFIG).to(device)
        prepare_model_for_state_dict(m, state_dict, device=device)

        # Try a strict load first; if shapes drift (rare), fall back to non-strict
        try:
            missing, unexpected = m.load_state_dict(state_dict, strict=True)
            # PyTorch returns a NamedTuple; report counts
            _msg = f"strict load ok | missing={len(missing)} unexpected={len(unexpected)}"
        except Exception as e:
            _msg = f"strict load failed ({e}); trying non-strict…"
            # Non-strict load for very old snapshots
            m.load_state_dict(state_dict, strict=False)

        m.eval()

        # Tokenizer
        t = Tokenizer.from_file(tokenizer_file)

        # Swap globals
        infer, tok = m, t
        current_version = version_name
        return f"Successfully loaded: {version_name} ({_msg})"
    except Exception as e:
        infer = None
        tok = None
        current_version = None
        return f"Error loading {version_name}: {str(e)}"
# Load default on startup — prefer v4, fall back to v3
try:
    load_status = load_model_version("Beeper v4 (Advanced)")
    if "Error" in load_status:
        print(f"v4 not ready yet: {load_status}")
        load_status = load_model_version("Beeper v3 (Multi-Concept)")
except Exception:
    load_status = load_model_version("Beeper v3 (Multi-Concept)")
print(load_status)
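
# Quick smoke test outside the UI (a sketch; mirrors the keyword signature used in
# beeper_reply below and uses whichever default checkpoint loaded above):
#   if infer is not None and tok is not None:
#       print(generate(model=infer, tok=tok, cfg=CONFIG, prompt="Once upon a time, ",
#                      max_new_tokens=40, temperature=0.9, top_k=40, top_p=0.9,
#                      repetition_penalty=1.10, presence_penalty=0.6, frequency_penalty=0.0,
#                      device=device, detokenize=True))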
# ----------------------------
# 💬 Chat wrapper
# ----------------------------
def beeper_reply(
    message: str,
    history: list[tuple[str, str]] | None,
    model_version: str,
    temperature: float | None,
    top_k: int | None,
    top_p: float | None,
    max_new_tokens: int = 80
) -> str:
    global infer, tok, current_version

    # Hot-swap versions if the dropdown changed
    if model_version != current_version:
        status = load_model_version(model_version)
        if "Error" in status:
            return f"⚠️ {status}"
    if infer is None or tok is None:
        return "⚠️ Model not loaded. Please select a version and try again."

    # Light prompting heuristics to steer the tiny model
    m = message.strip()
    if "?" in m:
        prompt = f"Q: {m}\nA:"
    elif m.lower() in {"hi", "hello", "hey"}:
        prompt = 'The little robot said hello. She said, "'
    elif "story" in m.lower():
        prompt = "Once upon a time, there was a robot. "
    else:
        prompt = m + ". "

    # Generate
    text = generate(
        model=infer,
        tok=tok,
        cfg=CONFIG,
        prompt=prompt,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature) if temperature is not None else None,
        top_k=int(top_k) if top_k is not None else None,
        top_p=float(top_p) if top_p is not None else None,
        repetition_penalty=1.10,
        presence_penalty=0.8,
        frequency_penalty=0.1,
        device=device,
        detokenize=True,
    )

    # Strip prompt echoes & artifacts
    if text.startswith(prompt):
        text = text[len(prompt):]
    text = text.replace("Q:", "").replace("A:", "")
    lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
    if lines:
        text = lines[0]

    # If the user message is echoed at the head, trim past its first occurrence
    head = m[:20].lower()
    if text.lower().startswith(head):
        idx = text.lower().find(head)
        text = text[idx + len(head):].strip() or text

    for artifact in ("User:", "Beeper:", "U ser:", "Beep er:", "User ", "Beeper "):
        text = text.replace(artifact, "")
    text = text.strip()

    if not text or len(text) < 3:
        text = "I like robots and stories!"
    if text[-1:] not in ".!?”\"'":
        text += "."
    return text[:200]
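
# Exercising the wrapper directly (illustrative sketch; arguments mirror the Gradio
# controls wired up below):
#   print(beeper_reply("Can you tell me a story?", None, "Beeper v3 (Multi-Concept)",
#                      temperature=0.9, top_k=40, top_p=0.9, max_new_tokens=80))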
# ----------------------------
# 🖼️ Interface
# ----------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🤖 Beeper — A Rose-based Tiny Language Model
        Hello! I'm Beeper, a small language model trained with love and care. Please be patient with me — I'm still learning! 💕
        """
    )

    with gr.Row():
        with gr.Column(scale=3):
            model_dropdown = gr.Dropdown(
                choices=list(MODEL_VERSIONS.keys()),
                value="Beeper v3 (Multi-Concept)",  # safer default
                label="Select Beeper Version",
                info="Choose which version of Beeper to chat with",
            )
        with gr.Column(scale=7):
            version_info = gr.Markdown("**Current:** " + MODEL_VERSIONS["Beeper v3 (Multi-Concept)"]["description"])

    def update_version_info(version_name: str):
        return f"**Current:** {MODEL_VERSIONS[version_name]['description']}"

    model_dropdown.change(
        fn=update_version_info,
        inputs=[model_dropdown],
        outputs=[version_info],
    )

    chatbot = gr.Chatbot(label="Chat with Beeper", height=400)
    msg = gr.Textbox(label="Message", placeholder="Type your message here...")

    with gr.Row():
        with gr.Column(scale=2):
            temperature_slider = gr.Slider(0.1, 1.5, value=0.9, step=0.1, label="Temperature")
        with gr.Column(scale=2):
            top_k_slider = gr.Slider(1, 100, value=40, step=1, label="Top-k")
        with gr.Column(scale=2):
            top_p_slider = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
        with gr.Column(scale=2):
            max_new_tokens_slider = gr.Slider(20, 512, value=128, step=1, label="Max new tokens")

    with gr.Row():
        submit = gr.Button("Send", variant="primary")
        clear = gr.Button("Clear")

    gr.Examples(
        examples=[
            ["Hello Beeper! How are you today?"],
            ["Can you tell me a story about a robot?"],
            ["What do you like to do for fun?"],
            ["What makes you happy?"],
            ["Tell me about your dreams"],
        ],
        inputs=msg,
    )

    def respond(message, chat_history, model_version, temperature, top_k, top_p, max_new_tokens):
        if chat_history is None:
            chat_history = []
        response = beeper_reply(message, chat_history, model_version, temperature, top_k, top_p, max_new_tokens)
        chat_history.append((message, response))
        return "", chat_history

    msg.submit(
        respond,
        [msg, chatbot, model_dropdown, temperature_slider, top_k_slider, top_p_slider, max_new_tokens_slider],
        [msg, chatbot],
    )
    submit.click(
        respond,
        [msg, chatbot, model_dropdown, temperature_slider, top_k_slider, top_p_slider, max_new_tokens_slider],
        [msg, chatbot],
    )
    clear.click(lambda: None, None, chatbot, queue=False)
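
# Optional: demo.queue() can be enabled before launch to serialize concurrent requests
# (a judgment call; not required for this simple single-process app).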
if __name__ == "__main__":
    demo.launch()