# meet-beeper / app.py
import gradio as gr
import torch
from beeper_model import BeeperRoseGPT, generate
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file as load_safetensors
# ----------------------------
# 🔧 Model versions configuration
# ----------------------------
MODEL_VERSIONS = {
"Beeper v1 (Original)": {
"repo_id": "AbstractPhil/beeper-rose-tinystories-6l-512d-ctx512",
"model_file": "beeper_rose_final.safetensors",
"description": "Original Beeper trained on TinyStories"
},
"Beeper v2 (Extended)": {
"repo_id": "AbstractPhil/beeper-rose-v2",
"model_file": "beeper_rose_final.safetensors",
"description": "Beeper v2 with extended training (~15 epochs) on a good starting corpus of general knowledge."
}
}
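# To add another checkpoint, append an entry with the same three fields.
# The values below are hypothetical placeholders, not a published repo:
# "Beeper v3 (Hypothetical)": {
#     "repo_id": "AbstractPhil/beeper-rose-v3",
#     "model_file": "beeper_rose_final.safetensors",
#     "description": "Example entry showing the expected fields."
# },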
# Base configuration: architecture, default sampling settings, and tokenizer path
config = {
    # Architecture
    "context": 512,
    "vocab_size": 8192,
    "dim": 512,
    "n_heads": 8,
    "n_layers": 6,
    "mlp_ratio": 4.0,
    # Default sampling settings (the UI sliders override temperature/top-k/top-p)
    "temperature": 0.9,
    "top_k": 40,
    "top_p": 0.9,
    "repetition_penalty": 1.1,
    "presence_penalty": 0.6,
    "frequency_penalty": 0.0,
    # Training-time flags (inactive in eval mode)
    "resid_dropout": 0.1,
    "dropout": 0.0,
    "grad_checkpoint": False,
    "tokenizer_path": "beeper.tokenizer.json"
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Global model and tokenizer variables
infer = None
tok = None
current_version = None
def load_model_version(version_name):
"""Load the selected model version"""
global infer, tok, current_version
if current_version == version_name and infer is not None:
return f"Already loaded: {version_name}"
version_info = MODEL_VERSIONS[version_name]
try:
# Download model and tokenizer files
model_file = hf_hub_download(
repo_id=version_info["repo_id"],
filename=version_info["model_file"]
)
tokenizer_file = hf_hub_download(
repo_id=version_info["repo_id"],
filename="tokenizer.json"
)
# Initialize model
infer = BeeperRoseGPT(config).to(device)
# Load safetensors
state_dict = load_safetensors(model_file, device=str(device))
infer.load_state_dict(state_dict)
infer.eval()
# Load tokenizer
tok = Tokenizer.from_file(tokenizer_file)
current_version = version_name
return f"Successfully loaded: {version_name}"
except Exception as e:
return f"Error loading {version_name}: {str(e)}"
# Load default model on startup
load_status = load_model_version("Beeper v1 (Original)")
print(load_status)
# ----------------------------
# 💬 Gradio Chat Wrapper
# ----------------------------
def beeper_reply(message, history, model_version, temperature=None, top_k=None, top_p=None):
global infer, tok, current_version
# Load model if version changed
if model_version != current_version:
status = load_model_version(model_version)
if "Error" in status:
return f"⚠️ {status}"
# Check if model is loaded
if infer is None or tok is None:
return "⚠️ Model not loaded. Please select a version and try again."
# Use defaults if not provided (for examples caching)
if temperature is None:
temperature = 0.9
if top_k is None:
top_k = 40
if top_p is None:
top_p = 0.9
    # Build conversation context from the messages-format history
    # (dicts with "role"/"content", as produced by gr.Chatbot(type="messages"))
    prompt_parts = []
    if history:
        for turn in history:
            content = turn.get("content")
            if not content:
                continue
            if turn.get("role") == "user":
                prompt_parts.append(f"User: {content}")
            elif turn.get("role") == "assistant":
                prompt_parts.append(f"Beeper: {content}")
    # Add current message
    prompt_parts.append(f"User: {message}")
    prompt_parts.append("Beeper:")
    prompt = "\n".join(prompt_parts)
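    # The assembled prompt is plain turn-tagged text, e.g.:
    #   User: Hello Beeper!
    #   Beeper: Hi there! I like stories.
    #   User: Can you tell me one?
    #   Beeper: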
# Generate response
response = generate(
model=infer,
tok=tok,
cfg=config,
prompt=prompt,
max_new_tokens=128,
temperature=float(temperature),
top_k=int(top_k),
top_p=float(top_p),
repetition_penalty=config["repetition_penalty"],
presence_penalty=config["presence_penalty"],
frequency_penalty=config["frequency_penalty"],
device=device,
detokenize=True
)
    # Clean up: strip the echoed prompt if the generator returns it, then cut
    # at any point where the model starts inventing the next user turn
    if response.startswith(prompt):
        response = response[len(prompt):].strip()
    response = response.split("\nUser:")[0].strip()
    return response
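# Example call (assuming the default model loaded successfully at startup):
#   beeper_reply("Hello Beeper!", [], "Beeper v1 (Original)")
# This returns Beeper's reply as a plain string.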
# ----------------------------
# 🖼️ Interface
# ----------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown(
"""
# 🤖 Beeper - A Rose-based Tiny Language Model
Hello! I'm Beeper, a small language model trained with love and care. Please be patient with me - I'm still learning! 💕
"""
)
with gr.Row():
with gr.Column(scale=3):
model_dropdown = gr.Dropdown(
choices=list(MODEL_VERSIONS.keys()),
value="Beeper v1 (Original)",
label="Select Beeper Version",
info="Choose which version of Beeper to chat with"
)
with gr.Column(scale=7):
            version_info = gr.Markdown("**Current:** Original Beeper trained on TinyStories")
# Update version info when dropdown changes
def update_version_info(version_name):
info = MODEL_VERSIONS[version_name]["description"]
return f"**Current:** {info}"
model_dropdown.change(
fn=update_version_info,
inputs=[model_dropdown],
outputs=[version_info]
)
# Chat interface
chatbot = gr.Chatbot(label="Chat with Beeper", type="messages", height=400)
msg = gr.Textbox(label="Message", placeholder="Type your message here...")
with gr.Row():
with gr.Column(scale=2):
temperature_slider = gr.Slider(0.1, 1.5, value=0.9, step=0.1, label="Temperature")
with gr.Column(scale=2):
top_k_slider = gr.Slider(1, 100, value=40, step=1, label="Top-k")
with gr.Column(scale=2):
top_p_slider = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
with gr.Row():
submit = gr.Button("Send", variant="primary")
clear = gr.Button("Clear")
# Examples
gr.Examples(
examples=[
["Hello Beeper! How are you today?"],
["Can you tell me a story about a robot?"],
["What do you like to do for fun?"],
["What makes you happy?"],
["Tell me about your dreams"],
],
inputs=msg
)
    # Handle chat: generate a reply, then append both turns in messages format
    def respond(message, chat_history, model_version, temperature, top_k, top_p):
        response = beeper_reply(message, chat_history, model_version, temperature, top_k, top_p)
        chat_history = chat_history + [
            {"role": "user", "content": message},
            {"role": "assistant", "content": response},
        ]
        return "", chat_history
msg.submit(
respond,
[msg, chatbot, model_dropdown, temperature_slider, top_k_slider, top_p_slider],
[msg, chatbot]
)
submit.click(
respond,
[msg, chatbot, model_dropdown, temperature_slider, top_k_slider, top_p_slider],
[msg, chatbot]
)
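    # Both the Enter key (msg.submit) and the Send button route through the
    # same respond handler, so their behavior stays identical.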
clear.click(lambda: None, None, chatbot, queue=False)
if __name__ == "__main__":
demo.launch()