import gradio as gr
import random  # used for fallback canned replies
import torch
from beeper_model import BeeperRoseGPT, generate
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file as load_safetensors
# ----------------------------
# 🔧 Model versions configuration
# ----------------------------
MODEL_VERSIONS = {
    "Beeper v3 (Multi-Concept)": {
        "repo_id": "AbstractPhil/beeper-rose-v3",
        "model_file": "beeper_rose_final.safetensors",
        "description": "Beeper v3 with 30+ epochs including reasoning, math, coding, and more."
    },
    "Beeper v2 (Extended)": {
        "repo_id": "AbstractPhil/beeper-rose-v2",
        "model_file": "beeper_rose_final.safetensors",
        "description": "Beeper v2 with extended training (~15 epochs)"
    },
    "Beeper v1 (Original)": {
        "repo_id": "AbstractPhil/beeper-rose-tinystories-6l-512d-ctx512",
        "model_file": "beeper_rose_final.safetensors",
        "description": "Original Beeper trained on TinyStories"
    },
}
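# To expose another checkpoint in the dropdown, add an entry above with its
# repo_id, .safetensors filename, and a short description; the UI is built from
# MODEL_VERSIONS.keys(), and load_model_version() expects the same repo to also
# contain a tokenizer.json.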
# Base configuration
config = {
    "context": 512,
    "vocab_size": 8192,
    "dim": 512,
    "n_heads": 8,
    "n_layers": 6,
    "mlp_ratio": 4.0,
    "temperature": 0.9,
    "top_k": 40,
    "top_p": 0.9,
    "repetition_penalty": 1.1,
    "presence_penalty": 0.6,
    "frequency_penalty": 0.0,
    "resid_dropout": 0.1,
    "dropout": 0.0,
    "grad_checkpoint": False,
    "tokenizer_path": "beeper.tokenizer.json"
}
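# The architecture fields above (context, vocab_size, dim, n_heads, n_layers,
# mlp_ratio) must match the released checkpoints for load_state_dict to succeed.
# The sampling fields are only defaults: beeper_reply() below passes explicit
# temperature/top_k/top_p and penalty values to generate() on each request.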
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Globals holding the currently loaded model, tokenizer, and version name
infer = None
tok = None
current_version = None
def load_model_version(version_name):
    """Load the selected model version."""
    global infer, tok, current_version

    if current_version == version_name and infer is not None:
        return f"Already loaded: {version_name}"

    version_info = MODEL_VERSIONS[version_name]
    try:
        # Download model and tokenizer files
        model_file = hf_hub_download(
            repo_id=version_info["repo_id"],
            filename=version_info["model_file"]
        )
        tokenizer_file = hf_hub_download(
            repo_id=version_info["repo_id"],
            filename="tokenizer.json"
        )

        # Initialize model
        infer = BeeperRoseGPT(config).to(device)

        # Load safetensors
        state_dict = load_safetensors(model_file, device=str(device))
        infer.load_state_dict(state_dict)
        infer.eval()

        # Load tokenizer
        tok = Tokenizer.from_file(tokenizer_file)

        current_version = version_name
        return f"Successfully loaded: {version_name}"
    except Exception as e:
        return f"Error loading {version_name}: {str(e)}"


# Load default model on startup
load_status = load_model_version("Beeper v3 (Multi-Concept)")
print(load_status)
# ----------------------------
# 💬 Gradio Chat Wrapper
# ----------------------------
def beeper_reply(message, history, model_version, temperature=None, top_k=None, top_p=None):
    global infer, tok, current_version

    # Load model if the version changed
    if model_version != current_version:
        status = load_model_version(model_version)
        if "Error" in status:
            return f"⚠️ {status}"

    # Check if a model is loaded
    if infer is None or tok is None:
        return "⚠️ Model not loaded. Please select a version and try again."

    # Use defaults if not provided
    if temperature is None:
        temperature = 0.9
    if top_k is None:
        top_k = 40
    if top_p is None:
        top_p = 0.9

    # Try Q&A format, since the training corpus includes some Q&A pairs
    if "?" in message:
        prompt = f"Q: {message}\nA:"
    elif message.lower().strip() in ["hi", "hello", "hey"]:
        prompt = "The little robot said hello. She said, \""
    elif "story" in message.lower():
        prompt = "Once upon a time, there was a robot. "
    else:
        # Simple continuation
        prompt = message + ". "

    # Generate a response with slightly lower temperature to reduce repetition
    response = generate(
        model=infer,
        tok=tok,
        cfg=config,
        prompt=prompt,
        max_new_tokens=80,                      # shorter to avoid rambling
        temperature=float(temperature) * 0.8,   # slightly lower temperature
        top_k=int(top_k),
        top_p=float(top_p),
        repetition_penalty=1.2,                 # higher penalty for repetition
        presence_penalty=0.8,                   # higher presence penalty
        frequency_penalty=0.1,                  # add frequency penalty
        device=device,
        detokenize=True
    )

    # Aggressive cleanup: remove the prompt completely
    if response.startswith(prompt):
        response = response[len(prompt):]

    # Remove Q&A format artifacts
    response = response.replace("Q:", "").replace("A:", "")

    # Split on newlines and take the first non-empty line
    lines = response.split('\n')
    for line in lines:
        clean_line = line.strip()
        if clean_line and not clean_line.startswith(message[:10]):
            response = clean_line
            break

    # If the response still echoes the user message, try to extract what follows it
    if message.lower()[:20] in response.lower()[:50]:
        # Find where the echo ends
        words_in_message = message.split()
        for i in range(min(5, len(words_in_message)), 0, -1):
            pattern = ' '.join(words_in_message[:i])
            if pattern.lower() in response.lower():
                idx = response.lower().find(pattern.lower()) + len(pattern)
                response = response[idx:].strip()
                break

    # Remove any remaining "User" or "Beeper" artifacts
    for artifact in ["User:", "Beeper:", "U ser:", "Beep er:", "User ", "Beeper "]:
        response = response.replace(artifact, "")

    # Fall back to a canned reply if nothing usable survived the cleanup
    if not response or len(response) < 3:
        responses = [
            "I like robots and stories!",
            "That's interesting!",
            "I want to play in the park.",
            "The robot was happy.",
            "Yes, I think so too!"
        ]
        response = random.choice(responses)

    # Clean ending
    response = response.strip()
    if response and response[-1] not in '.!?"':
        response = response.rsplit('.', 1)[0] + '.' if '.' in response else response + '.'

    return response[:200]  # Cap length
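# Illustrative direct call outside the Gradio UI (hypothetical usage; the
# version string must be a key of MODEL_VERSIONS):
#   print(beeper_reply("Can you tell me a story?", [], "Beeper v3 (Multi-Concept)"))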
# ----------------------------
# 🖼️ Interface
# ----------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🤖 Beeper - A Rose-based Tiny Language Model
        Hello! I'm Beeper, a small language model trained with love and care. Please be patient with me - I'm still learning! 💕
        """
    )

    with gr.Row():
        with gr.Column(scale=3):
            model_dropdown = gr.Dropdown(
                choices=list(MODEL_VERSIONS.keys()),
                value="Beeper v3 (Multi-Concept)",
                label="Select Beeper Version",
                info="Choose which version of Beeper to chat with"
            )
        with gr.Column(scale=7):
            version_info = gr.Markdown(
                "**Current:** Beeper v3 with 30+ epochs including reasoning, math, coding, and more."
            )

    # Update version info when the dropdown changes
    def update_version_info(version_name):
        info = MODEL_VERSIONS[version_name]["description"]
        return f"**Current:** {info}"

    model_dropdown.change(
        fn=update_version_info,
        inputs=[model_dropdown],
        outputs=[version_info]
    )

    # Chat interface
    chatbot = gr.Chatbot(label="Chat with Beeper", type="tuples", height=400)
    msg = gr.Textbox(
        label="Message",
        placeholder="Type your message here... She will probably complete it for now, but maybe she'll answer."
    )

    with gr.Row():
        with gr.Column(scale=2):
            temperature_slider = gr.Slider(0.1, 1.5, value=0.9, step=0.1, label="Temperature")
        with gr.Column(scale=2):
            top_k_slider = gr.Slider(1, 100, value=40, step=1, label="Top-k")
        with gr.Column(scale=2):
            top_p_slider = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")

    with gr.Row():
        submit = gr.Button("Send", variant="primary")
        clear = gr.Button("Clear")

    # Examples
    gr.Examples(
        examples=[
            ["Hello Beeper! How are you today?"],
            ["Can you tell me a story about a robot?"],
            ["What do you like to do for fun?"],
            ["What makes you happy?"],
            ["Tell me about your dreams"],
        ],
        inputs=msg
    )

    # Handle chat
    def respond(message, chat_history, model_version, temperature, top_k, top_p):
        if not chat_history:
            chat_history = []
        response = beeper_reply(message, chat_history, model_version, temperature, top_k, top_p)
        chat_history.append([message, response])
        return "", chat_history

    msg.submit(
        respond,
        [msg, chatbot, model_dropdown, temperature_slider, top_k_slider, top_p_slider],
        [msg, chatbot]
    )
    submit.click(
        respond,
        [msg, chatbot, model_dropdown, temperature_slider, top_k_slider, top_p_slider],
        [msg, chatbot]
    )
    clear.click(lambda: None, None, chatbot, queue=False)

if __name__ == "__main__":
    demo.launch()
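# Rough local-run sketch (assuming this script is saved as app.py alongside
# beeper_model.py; weights and tokenizer.json are downloaded from the Hub at
# startup by load_model_version):
#   pip install gradio torch tokenizers huggingface_hub safetensors
#   python app.py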