import os

import gradio as gr
import torch

from v1.usta_model import UstaModel
from v1.usta_tokenizer import UstaTokenizer
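
# Gradio chat demo for UstaModel, a small transformer language model built from
# scratch. Weights can come from the bundled v1/u_model.pth checkpoint, a file
# uploaded through the UI, or a URL pointing at a raw .pth file.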

# Load the model and tokenizer
def load_model(custom_model_path=None):
    try:
        u_tokenizer = UstaTokenizer("v1/tokenizer.json")
        print("✅ Tokenizer loaded successfully! vocab size:", len(u_tokenizer.vocab))

        # Model parameters - adjust these to match your trained model
        context_length = 32
        vocab_size = len(u_tokenizer.vocab)
        embedding_dim = 12
        num_heads = 4
        num_layers = 8
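        # These hyperparameters must match the checkpoint being loaded; if they
        # differ, load_state_dict below fails with a size-mismatch error. Note
        # that embedding_dim is typically expected to be divisible by num_heads.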

        # Load the model
        u_model = UstaModel(
            vocab_size=vocab_size,
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            context_length=context_length,
            num_layers=num_layers
        )

        # Determine which model file to use
        if custom_model_path and os.path.exists(custom_model_path):
            model_path = custom_model_path
            print(f"🎯 Using uploaded model: {model_path}")
        else:
            model_path = "v1/u_model.pth"

            if not os.path.exists(model_path):
                print("❌ Model file not found at", model_path)

                # Download the model file from GitHub
                try:
                    print("📥 Downloading model weights from GitHub...")
                    import requests

                    url = "https://github.com/malibayram/llm-from-scratch/raw/main/u_model_4000.pth"
                    headers = {
                        'Accept': 'application/octet-stream',
                        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
                    }
                    response = requests.get(url, headers=headers)
                    response.raise_for_status()  # Raise an exception for bad status codes
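                    # torch.save writes checkpoints as ZIP archives, so a valid
                    # .pth file should begin with the PK\x03\x04 magic bytes;
                    # GitHub sometimes serves an HTML error page instead of the
                    # raw asset, which the check below catches.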
                    # Check if we got a proper binary file (PyTorch files start with specific bytes)
                    if response.content[:4] != b'PK\x03\x04' and b'<html' in response.content[:100].lower():
                        raise Exception("Downloaded HTML instead of binary file - check URL")

                    print(f"📦 Downloaded {len(response.content)} bytes")

                    # Create v1 directory if it doesn't exist
                    os.makedirs("v1", exist_ok=True)

                    # Save the model weights to the local file system
                    with open(model_path, "wb") as f:
                        f.write(response.content)
                    print("✅ Model weights saved successfully!")
                except Exception as e:
                    print(f"❌ Failed to download model weights: {e}")
                    print("Using random initialization.")

        if os.path.exists(model_path):
            try:
                state_dict = torch.load(model_path, map_location="cpu", weights_only=False)

                # Handle potential key mapping issues
                if "embedding.weight" in state_dict and "embedding.embedding.weight" not in state_dict:
                    # Map old key names to new key names
                    new_state_dict = {}
                    for key, value in state_dict.items():
                        if key == "embedding.weight":
                            new_state_dict["embedding.embedding.weight"] = value
                        elif key == "pos_embedding.weight":
                            # Skip positional embedding if not expected
                            continue
                        else:
                            new_state_dict[key] = value
                    state_dict = new_state_dict

                u_model.load_state_dict(state_dict)
                u_model.eval()
                print("✅ Model weights loaded successfully!")
                return u_model, u_tokenizer, f"✅ Model loaded from: {model_path}"
            except Exception as e:
                print(f"⚠️ Warning: Could not load trained weights: {e}")
                print("Using random initialization.")
                return u_model, u_tokenizer, f"⚠️ Failed to load weights: {e}"
        else:
            print(f"⚠️ Model file not found at {model_path}. Using random initialization.")
            return u_model, u_tokenizer, "⚠️ Using random initialization"
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        raise e


# Global model variables
model, tokenizer, model_status = None, None, "Not loaded"

# Initialize model and tokenizer globally
try:
    model, tokenizer, model_status = load_model()
    print("🚀 UstaModel and tokenizer initialized successfully!")
except Exception as e:
    print(f"❌ Failed to initialize model: {e}")
    model, tokenizer, model_status = None, None, f"❌ Error: {e}"


def load_model_from_url(url):
    """Load model from a URL"""
    global model, tokenizer, model_status

    if not url.strip():
        return "❌ Please provide a URL"

    try:
        print(f"📥 Downloading model from URL: {url}")
        import requests

        headers = {
            'Accept': 'application/octet-stream',
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
        }
        response = requests.get(url, headers=headers)
        response.raise_for_status()

        # Check if we got a proper binary file
        if response.content[:4] != b'PK\x03\x04' and b'<html' in response.content[:100].lower():
            return "❌ Downloaded HTML instead of binary file - check URL"

        # Save temporary file
        temp_path = "temp_model.pth"
        with open(temp_path, "wb") as f:
            f.write(response.content)

        # Load the model
        new_model, new_tokenizer, status = load_model(temp_path)

        # Update global variables
        model = new_model
        tokenizer = new_tokenizer
        model_status = status

        # Clean up temp file
        if os.path.exists(temp_path):
            os.remove(temp_path)

        return status
    except Exception as e:
        error_msg = f"❌ Failed to load model from URL: {e}"
        model_status = error_msg
        return error_msg
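
# Note: the URL should point at the raw binary asset, e.g. a GitHub link of the
# form https://github.com/<user>/<repo>/raw/<branch>/model.pth (the <user>,
# <repo> and <branch> parts are placeholders); a regular repository page URL
# returns HTML and is rejected by the check above.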


def load_model_from_file(uploaded_file):
    """Load model from uploaded file"""
    global model, tokenizer, model_status

    if uploaded_file is None:
        return "❌ No file uploaded"

    try:
        # Load the new model
        new_model, new_tokenizer, status = load_model(uploaded_file.name)

        # Update global variables
        model = new_model
        tokenizer = new_tokenizer
        model_status = status

        return status
    except Exception as e:
        error_msg = f"❌ Failed to load uploaded model: {e}"
        model_status = error_msg
        return error_msg


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """
    Generate a response using the UstaModel
    """
    if model is None or tokenizer is None:
        yield "Sorry, the UstaModel is not available. Please try again later."
        return

    try:
        # For UstaModel, we'll use the message directly (ignoring system_message for now)
        # since it's a simpler model focused on geographical knowledge

        # Encode the input message
        tokens = tokenizer.encode(message)

        # Make sure we don't exceed context length
        if len(tokens) > 25:  # Leave some room for generation
            tokens = tokens[-25:]

        # Generate response
        with torch.no_grad():
            # Use the max_tokens parameter, but cap it at a reasonable limit for this model
            actual_max_tokens = min(max_tokens, 32 - len(tokens))
            generated_tokens = model.generate(tokens, actual_max_tokens)

        # Decode the generated tokens
        response = tokenizer.decode(generated_tokens)

        # Clean up the response (remove the original input)
        original_text = tokenizer.decode(tokens.tolist())
        if response.startswith(original_text):
            response = response[len(original_text):]

        # Clean up any unwanted tokens
        response = response.replace("<unk>", "").replace("<pad>", "").strip()

        if not response:
            response = "I'm not sure how to respond to that with my geographical knowledge."

        # Yield the response (to maintain compatibility with streaming interface)
        yield response
    except Exception as e:
        yield f"Sorry, I encountered an error: {str(e)}"


# Create the simple ChatInterface with additional inputs for model loading
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are Usta, a geographical knowledge assistant trained from scratch.",
            label="System message"
        ),
        gr.Slider(minimum=1, maximum=30, value=20, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)"
        ),
        gr.File(label="Upload Model File (.pth)", file_types=[".pth", ".pt"]),
        gr.Textbox(label="Or Model URL", placeholder="https://github.com/user/repo/raw/main/model.pth"),
        gr.Button("Load from File", variant="secondary"),
        gr.Button("Load from URL", variant="secondary"),
        gr.Textbox(label="Model Status", value=model_status, interactive=False)
    ],
    title="🤖 Usta Model Chat",
    description="Chat with a custom transformer language model built from scratch! Upload your own model file or provide a URL to load a different model."
)
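
# The additional_inputs above are rendered by ChatInterface (by default in a
# collapsible section below the chat box); the two buttons do nothing on their
# own and are wired to the model-loading functions in setup_events() below.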

# Add event handlers after creating the interface
def setup_events():
    # Get the additional inputs
    inputs = demo.additional_inputs
    model_file = inputs[4]  # File upload
    model_url = inputs[5]  # URL input
    load_file_btn = inputs[6]  # Load from file button
    load_url_btn = inputs[7]  # Load from URL button
    status_display = inputs[8]  # Status display

    # Set up event handlers
    load_file_btn.click(
        load_model_from_file,
        inputs=[model_file],
        outputs=[status_display]
    )
    load_url_btn.click(
        load_model_from_url,
        inputs=[model_url],
        outputs=[status_display]
    )


# Set up events after interface creation
demo.load(setup_events)
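# Note: depending on the Gradio version, .click() handlers may need to be
# registered inside a Blocks/ChatInterface context rather than from a load
# callback, so this deferred wiring can be a fragile pattern.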

if __name__ == "__main__":
    demo.launch()