# usta-llm-demo / app.py
# Last commit ff7d616 by alibayram ("space update"), 10.9 kB
import os
import gradio as gr
import torch
from v1.usta_model import UstaModel
from v1.usta_tokenizer import UstaTokenizer
# Load the model and tokenizer
def _download_weights(url, dest_path):
    """Download a checkpoint from ``url`` and save it at ``dest_path``.

    Raises on HTTP errors, on timeouts, and when the server returns an HTML page
    instead of a binary checkpoint.
    """
    import requests

    headers = {
        'Accept': 'application/octet-stream',
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
    }
    # A timeout keeps app start-up from hanging forever on a stalled connection.
    response = requests.get(url, headers=headers, timeout=60)
    response.raise_for_status()  # Raise an exception for bad status codes
    content = response.content
    # Zip-format PyTorch checkpoints start with the ZIP magic bytes; an HTML body
    # means the URL pointed at a web page rather than at the raw file.
    if content[:4] != b'PK\x03\x04' and b'<html' in content[:100].lower():
        raise RuntimeError("Downloaded HTML instead of binary file - check URL")
    print(f"📦 Downloaded {len(content)} bytes")
    os.makedirs(os.path.dirname(dest_path) or ".", exist_ok=True)
    # Write to a side file and rename it, so an interrupted write never leaves a
    # truncated checkpoint at dest_path for later runs to trip over.
    partial_path = dest_path + ".part"
    with open(partial_path, "wb") as f:
        f.write(content)
    os.replace(partial_path, dest_path)


def _remap_legacy_keys(state_dict):
    """Map old checkpoint key names to the names the current model uses.

    Checkpoints that have ``embedding.weight`` but not
    ``embedding.embedding.weight`` are rewritten: ``embedding.weight`` is renamed
    and ``pos_embedding.weight`` is dropped. Other checkpoints are returned as-is.
    """
    if "embedding.weight" not in state_dict or "embedding.embedding.weight" in state_dict:
        return state_dict
    remapped = {}
    for key, value in state_dict.items():
        if key == "embedding.weight":
            remapped["embedding.embedding.weight"] = value
        elif key != "pos_embedding.weight":  # skip positional embedding if not expected
            remapped[key] = value
    return remapped


def load_model(custom_model_path=None):
    """Build the UstaModel and tokenizer, loading trained weights when possible.

    Weights come from ``custom_model_path`` if it names an existing file.
    Otherwise they come from ``v1/u_model.pth``, which is downloaded from GitHub
    when missing. If no weights can be loaded, the model keeps its random
    initialization.

    Args:
        custom_model_path: Optional path to a ``.pth`` checkpoint.

    Returns:
        ``(model, tokenizer, status)``, where ``status`` is a message for the UI.

    Raises:
        Exception: re-raised if the tokenizer or model cannot be constructed.
    """
    try:
        u_tokenizer = UstaTokenizer("v1/tokenizer.json")
        print("✅ Tokenizer loaded successfully! vocab size:", len(u_tokenizer.vocab))
        # Model hyperparameters - adjust these to match your trained model.
        u_model = UstaModel(
            vocab_size=len(u_tokenizer.vocab),
            embedding_dim=12,
            num_heads=4,
            context_length=32,
            num_layers=8,
        )

        # Determine which model file to use.
        if custom_model_path and os.path.exists(custom_model_path):
            model_path = custom_model_path
            print(f"🎯 Using uploaded model: {model_path}")
        else:
            model_path = "v1/u_model.pth"
            if not os.path.exists(model_path):
                print("❌ Model file not found at", model_path)
                try:
                    print("📥 Downloading model weights from GitHub...")
                    _download_weights(
                        "https://github.com/malibayram/llm-from-scratch/raw/main/u_model_4000.pth",
                        model_path,
                    )
                    print("✅ Model weights saved successfully!")
                except Exception as e:
                    # Best effort: fall through to random initialization.
                    print(f"❌ Failed to download model weights: {e}")
                    print("Using random initialization.")

        if not os.path.exists(model_path):
            print(f"⚠️ Model file not found at {model_path}. Using random initialization.")
            return u_model, u_tokenizer, "⚠️ Using random initialization"

        try:
            # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects,
            # so a malicious uploaded/downloaded checkpoint can run code on load.
            # Consider weights_only=True once checkpoints are confirmed tensor-only.
            state_dict = torch.load(model_path, map_location="cpu", weights_only=False)
            u_model.load_state_dict(_remap_legacy_keys(state_dict))
            u_model.eval()
            print("✅ Model weights loaded successfully!")
            return u_model, u_tokenizer, f"✅ Model loaded from: {model_path}"
        except Exception as e:
            print(f"⚠️ Warning: Could not load trained weights: {e}")
            print("Using random initialization.")
            return u_model, u_tokenizer, f"⚠️ Failed to load weights: {e}"
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        raise  # bare raise keeps the original traceback
# Global model state shared by the chat handler and the model-loading callbacks.
# Start as "not loaded"; both branches of the try/except below overwrite it.
model, tokenizer, model_status = None, None, "Not loaded"
# Initialize model and tokenizer once, at import time.
try:
    model, tokenizer, model_status = load_model()
    print("🚀 UstaModel and tokenizer initialized successfully!")
except Exception as e:
    # Keep the module importable so the UI can still show the error in the
    # status box and the user can load a model manually.
    print(f"❌ Failed to initialize model: {e}")
    model, tokenizer, model_status = None, None, f"❌ Error: {e}"
def load_model_from_url(url):
    """Download a checkpoint from ``url`` and make it the active model.

    The file is written to a unique temporary path, which is removed afterwards
    even if loading fails, and is then loaded via ``load_model``. The module-level
    ``model``, ``tokenizer`` and ``model_status`` are replaced with whatever
    ``load_model`` returns. That includes a randomly initialised model when the
    weights fail to load.
    NOTE(review): consider keeping the previous model in that case.

    Args:
        url: Direct link to a ``.pth`` checkpoint (e.g. a GitHub ``raw`` URL).

    Returns:
        A status message for the UI.
    """
    global model, tokenizer, model_status
    import tempfile

    url = (url or "").strip()
    if not url:
        return "❌ Please provide a URL"
    temp_path = None
    try:
        print(f"📥 Downloading model from URL: {url}")
        import requests

        headers = {
            'Accept': 'application/octet-stream',
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
        }
        # The timeout stops a stalled server from blocking this callback forever.
        response = requests.get(url, headers=headers, timeout=60)
        response.raise_for_status()
        # Check if we got a proper binary file rather than an HTML page.
        if response.content[:4] != b'PK\x03\x04' and b'<html' in response.content[:100].lower():
            return "❌ Downloaded HTML instead of binary file - check URL"
        # A unique temp file stops concurrent sessions from overwriting each
        # other's download (a fixed file name would collide).
        fd, temp_path = tempfile.mkstemp(suffix=".pth")
        with os.fdopen(fd, "wb") as f:
            f.write(response.content)
        new_model, new_tokenizer, status = load_model(temp_path)
        model, tokenizer, model_status = new_model, new_tokenizer, status
        return status
    except Exception as e:
        error_msg = f"❌ Failed to load model from URL: {e}"
        model_status = error_msg
        return error_msg
    finally:
        # Clean up the temp file on every path, including failures.
        if temp_path is not None and os.path.exists(temp_path):
            os.remove(temp_path)
def load_model_from_file(uploaded_file):
    """Load a checkpoint uploaded through the Gradio ``File`` widget.

    Args:
        uploaded_file: The widget value. This is an object with a ``.name`` path
            attribute or a plain path string, or ``None`` when nothing was
            uploaded.

    Returns:
        A status message for the UI. Once ``load_model`` returns, the
        module-level ``model``, ``tokenizer`` and ``model_status`` are replaced
        with its results.
    """
    global model, tokenizer, model_status
    if uploaded_file is None:
        return "❌ No file uploaded"
    # The upload may arrive as a temp-file wrapper (with ``.name``) or as a plain
    # path string depending on the Gradio version; accept either.
    file_path = getattr(uploaded_file, "name", uploaded_file)
    try:
        # load_model() silently falls back to the default weights when the given
        # path does not exist, which would report the default model as "loaded".
        if not file_path or not os.path.exists(file_path):
            raise FileNotFoundError(f"uploaded file not found: {file_path!r}")
        new_model, new_tokenizer, status = load_model(file_path)
        model, tokenizer, model_status = new_model, new_tokenizer, status
        return status
    except Exception as e:
        error_msg = f"❌ Failed to load uploaded model: {e}"
        model_status = error_msg
        return error_msg
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    *_extra_inputs,
):
    """Generate a reply to ``message`` with the global UstaModel.

    This is a generator that yields a single string, to fit ChatInterface's
    streaming interface. ``history``, ``system_message``, ``temperature`` and
    ``top_p`` are accepted for interface compatibility but are not used yet.

    ``*_extra_inputs`` absorbs the values of the remaining ChatInterface
    ``additional_inputs`` (model file, URL box, two buttons, status display),
    which are passed positionally after the four controls above. Without it,
    every chat turn failed with a ``TypeError`` for too many arguments.
    """
    if model is None or tokenizer is None:
        yield "Sorry, the UstaModel is not available. Please try again later."
        return

    context_length = 32  # must match the context_length the model was built with
    max_prompt_tokens = 25  # leave room in the context window for generation
    fallback = "I'm not sure how to respond to that with my geographical knowledge."

    try:
        # Keep only the most recent prompt tokens so generation fits the window.
        tokens = tokenizer.encode(message)
        if len(tokens) == 0:
            yield fallback
            return
        tokens = tokens[-max_prompt_tokens:]

        # The slider value may arrive as a float; generate() needs an int count.
        new_token_budget = min(int(max_tokens), context_length - len(tokens))
        with torch.no_grad():
            generated_tokens = model.generate(tokens, new_token_budget)

        # Decode, then strip the echoed prompt and special tokens.
        response = tokenizer.decode(generated_tokens)
        prompt_text = tokenizer.decode(tokens.tolist())
        if response.startswith(prompt_text):
            response = response[len(prompt_text):]
        response = response.replace("<unk>", "").replace("<pad>", "").strip()
        yield response or fallback
    except Exception as e:
        yield f"Sorry, I encountered an error: {str(e)}"
# Create the ChatInterface. Every additional input's value is passed to respond()
# after (message, history). The model-loading widgets at indices 4-8 are looked
# up by position in setup_events(), so keep this order in sync with it.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are Usta, a geographical knowledge assistant trained from scratch.",
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=30, value=20, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
        # Indices 4-8: model-loading controls.
        gr.File(label="Upload Model File (.pth)", file_types=[".pth", ".pt"]),
        gr.Textbox(label="Or Model URL", placeholder="https://github.com/user/repo/raw/main/model.pth"),
        gr.Button("Load from File", variant="secondary"),
        gr.Button("Load from URL", variant="secondary"),
        # Shows the status captured when the UI was built.
        gr.Textbox(label="Model Status", value=model_status, interactive=False),
    ],
    title="🤖 Usta Model Chat",
    description="Chat with a custom transformer language model built from scratch! Upload your own model file or provide a URL to load a different model.",
)
# Add event handlers after creating the interface
def setup_events():
    """Attach click handlers for the model-loading controls of ``demo``.

    The controls are taken from ``demo.additional_inputs`` by position:
    4 = file upload, 5 = URL box, 6 = "Load from File" button,
    7 = "Load from URL" button, 8 = status display.

    NOTE(review): Gradio event listeners must be registered inside a Blocks
    context while the app is being built. Previously this ran as a
    ``demo.load`` callback, i.e. per page load at request time and outside any
    Blocks context, so the buttons were never wired up.
    """
    model_file, model_url, load_file_btn, load_url_btn, status_display = (
        demo.additional_inputs[4:9]
    )
    # Re-enter the interface's Blocks context so the listeners become part of it.
    with demo:
        load_file_btn.click(
            load_model_from_file,
            inputs=[model_file],
            outputs=[status_display],
        )
        load_url_btn.click(
            load_model_from_url,
            inputs=[model_url],
            outputs=[status_display],
        )


# Register the events at build time so they exist when the app launches.
setup_events()
def _main():
    """Script entry point: start the Gradio server for the chat demo."""
    demo.launch()


if __name__ == "__main__":
    _main()