# usta-llm-demo / app.py
import os
import gradio as gr
import torch
from v1.usta_model import UstaModel
from v1.usta_tokenizer import UstaTokenizer

# Load the model and tokenizer
def load_model():
    try:
        u_tokenizer = UstaTokenizer("v1/tokenizer.json")
        print("βœ… Tokenizer loaded successfully! vocab size:", len(u_tokenizer.vocab))

        # Model parameters - adjust these to match your trained model
        context_length = 32
        vocab_size = len(u_tokenizer.vocab)
        embedding_dim = 12
        num_heads = 4
        num_layers = 8

        # Build the model
        u_model = UstaModel(
            vocab_size=vocab_size,
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            context_length=context_length,
            num_layers=num_layers
        )

        # Load the trained weights if available
        model_path = "v1/u_model.pth"
        if not os.path.exists(model_path):
            print("❌ Model file not found at", model_path)
            # Download the model file from GitHub
            try:
                print("πŸ“₯ Downloading model weights from GitHub...")
                import requests

                url = "https://github.com/malibayram/llm-from-scratch/raw/main/u_model_4000.pth"
                headers = {
                    'Accept': 'application/octet-stream',
                    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
                }
                response = requests.get(url, headers=headers)
                response.raise_for_status()  # Raise an exception for bad status codes

                # Check that we got a proper binary file: PyTorch checkpoints saved in the
                # default zip format start with the ZIP magic bytes PK\x03\x04, whereas an
                # HTML payload means the URL served an error page instead of the weights
                if response.content[:4] != b'PK\x03\x04' and b'<html' in response.content[:100].lower():
                    raise Exception("Downloaded HTML instead of binary file - check URL")

                print(f"πŸ“¦ Downloaded {len(response.content)} bytes")

                # Create v1 directory if it doesn't exist
                os.makedirs("v1", exist_ok=True)

                # Save the model weights to the local file system
                with open(model_path, "wb") as f:
                    f.write(response.content)
                print("βœ… Model weights saved successfully!")
            except Exception as e:
                print(f"❌ Failed to download model weights: {e}")
                print("Using random initialization.")

        if os.path.exists(model_path):
            try:
                u_model.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=False))
                u_model.eval()
                print("βœ… Model weights loaded successfully!")
            except Exception as e:
                print(f"⚠️ Warning: Could not load trained weights: {e}")
                print("Using random initialization.")
        else:
            print(f"⚠️ Model file not found at {model_path}. Using random initialization.")

        return u_model, u_tokenizer
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        raise e

# Initialize model and tokenizer globally
try:
    model, tokenizer = load_model()
    print("πŸš€ UstaModel and tokenizer initialized successfully!")
except Exception as e:
    print(f"❌ Failed to initialize model: {e}")
    model, tokenizer = None, None

def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """
    Generate a response using the UstaModel.
    """
    if model is None or tokenizer is None:
        yield "Sorry, the UstaModel is not available. Please try again later."
        return

    try:
        # For UstaModel, use the message directly (ignoring system_message for now),
        # since it's a simpler model focused on geographical knowledge

        # Encode the input message
        tokens = tokenizer.encode(message)

        # Make sure we don't exceed the context length
        if len(tokens) > 25:  # Leave some room for generation
            tokens = tokens[-25:]

        # Generate response
        with torch.no_grad():
            # Use the max_tokens parameter, but cap it at a reasonable limit for this model
            actual_max_tokens = min(max_tokens, 32 - len(tokens))
            generated_tokens = model.generate(tokens, actual_max_tokens)

        # Decode the generated tokens
        response = tokenizer.decode(generated_tokens)

        # Clean up the response (remove the original input)
        original_text = tokenizer.decode(tokens.tolist())
        if response.startswith(original_text):
            response = response[len(original_text):]

        # Clean up any unwanted tokens
        response = response.replace("<unk>", "").replace("<pad>", "").strip()

        if not response:
            response = "I'm not sure how to respond to that with my geographical knowledge."

        # Yield the response (to maintain compatibility with the streaming interface)
        yield response
    except Exception as e:
        yield f"Sorry, I encountered an error: {str(e)}"
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""

demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are Usta, a geographical knowledge assistant trained from scratch.",
            label="System message",
            info="Note: This model focuses on geographical knowledge (countries, capitals, cities)"
        ),
        gr.Slider(minimum=1, maximum=30, value=20, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
            info="Note: This parameter is not used by UstaModel but kept for interface compatibility"
        ),
    ],
    title="πŸ€– Usta Model Chat",
    description="Chat with a custom transformer language model built from scratch! This model specializes in geographical knowledge including countries, capitals, and cities."
)
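
# A hedged sketch of one such customization (assumes the installed Gradio version
# supports ChatInterface's `examples` argument; the prompts are illustrative only):
#
#   demo = gr.ChatInterface(
#       respond,
#       additional_inputs=[...],  # same inputs as above
#       examples=["what is the capital of france", "which continent is kenya in"],
#   )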

if __name__ == "__main__":
    demo.launch()
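
# When running outside Hugging Face Spaces (e.g. in a container), you may want to
# bind explicitly, for example demo.launch(server_name="0.0.0.0", server_port=7860);
# both are standard Gradio launch() options.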