import os

import gradio as gr
import torch

from v1.usta_model import UstaModel
from v1.usta_tokenizer import UstaTokenizer


# Load the model and tokenizer
def load_model(custom_model_path=None):
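    """Load the UstaTokenizer and UstaModel.

    If custom_model_path points to an uploaded checkpoint it is used; otherwise the
    bundled v1/u_model.pth is loaded, downloading it from GitHub when it is missing.
    Returns a (model, tokenizer, status_message) tuple.
    """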
    try:
        u_tokenizer = UstaTokenizer("v1/tokenizer.json")
        print("✅ Tokenizer loaded successfully! vocab size:", len(u_tokenizer.vocab))

        # Model parameters - adjust these to match your trained model
        context_length = 32
        vocab_size = len(u_tokenizer.vocab)
        embedding_dim = 12
        num_heads = 4
        num_layers = 8

        # Load the model
        u_model = UstaModel(
            vocab_size=vocab_size,
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            context_length=context_length,
            num_layers=num_layers
        )
        # Determine which model file to use
        if custom_model_path and os.path.exists(custom_model_path):
            model_path = custom_model_path
            print(f"🎯 Using uploaded model: {model_path}")
        else:
            model_path = "v1/u_model.pth"
            if not os.path.exists(model_path):
                print("❌ Model file not found at", model_path)
                # Download the model file from GitHub
                try:
                    print("📥 Downloading model weights from GitHub...")
                    import requests

                    url = "https://github.com/malibayram/llm-from-scratch/raw/main/u_model_4000.pth"
                    headers = {
                        'Accept': 'application/octet-stream',
                        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
                    }
                    response = requests.get(url, headers=headers)
                    response.raise_for_status()  # Raise an exception for bad status codes

                    # Sanity check: torch.save writes a ZIP archive (magic bytes PK\x03\x04),
                    # so reject a response that looks like an HTML error page instead
                    if response.content[:4] != b'PK\x03\x04' and b'<html' in response.content[:100].lower():
                        raise Exception("Downloaded HTML instead of binary file - check URL")

                    print(f"📦 Downloaded {len(response.content)} bytes")

                    # Create the v1 directory if it doesn't exist
                    os.makedirs("v1", exist_ok=True)

                    # Save the model weights to the local file system
                    with open(model_path, "wb") as f:
                        f.write(response.content)
                    print("✅ Model weights saved successfully!")
                except Exception as e:
                    print(f"❌ Failed to download model weights: {e}")
                    print("Using random initialization.")
        if os.path.exists(model_path):
            try:
                state_dict = torch.load(model_path, map_location="cpu", weights_only=False)

                # Handle potential key mapping issues
                if "embedding.weight" in state_dict and "embedding.embedding.weight" not in state_dict:
                    # Map old key names to new key names
                    new_state_dict = {}
                    for key, value in state_dict.items():
                        if key == "embedding.weight":
                            new_state_dict["embedding.embedding.weight"] = value
                        elif key == "pos_embedding.weight":
                            # Skip positional embedding if not expected
                            continue
                        else:
                            new_state_dict[key] = value
                    state_dict = new_state_dict

                u_model.load_state_dict(state_dict)
                u_model.eval()
                print("✅ Model weights loaded successfully!")
                return u_model, u_tokenizer, f"✅ Model loaded from: {model_path}"
            except Exception as e:
                print(f"⚠️ Warning: Could not load trained weights: {e}")
                print("Using random initialization.")
                return u_model, u_tokenizer, f"⚠️ Failed to load weights: {e}"
        else:
            print(f"⚠️ Model file not found at {model_path}. Using random initialization.")
            return u_model, u_tokenizer, "⚠️ Using random initialization"
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        raise e


# Global model variables
model, tokenizer, model_status = None, None, "Not loaded"

# Initialize model and tokenizer globally
try:
    model, tokenizer, model_status = load_model()
    print("🚀 UstaModel and tokenizer initialized successfully!")
except Exception as e:
    print(f"❌ Failed to initialize model: {e}")
    model, tokenizer, model_status = None, None, f"❌ Error: {e}"
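

# update_model is wired to the "Load Model" button below; it replaces the global
# model, tokenizer, and model_status in place so respond() uses the new weights.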
def update_model(uploaded_file):
    """Update the model when a new file is uploaded"""
    global model, tokenizer, model_status

    if uploaded_file is None:
        return "❌ No file uploaded"

    try:
        # Load the new model
        new_model, new_tokenizer, status = load_model(uploaded_file.name)

        # Update global variables
        model = new_model
        tokenizer = new_tokenizer
        model_status = status

        return status
    except Exception as e:
        error_msg = f"❌ Failed to load uploaded model: {e}"
        model_status = error_msg
        return error_msg


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """
    Generate a response using the UstaModel
    """
    if model is None or tokenizer is None:
        yield "Sorry, the UstaModel is not available. Please try again later."
        return
    try:
        # For UstaModel, we use the message directly (ignoring system_message for now)
        # since it's a simpler model focused on geographical knowledge

        # Encode the input message
        tokens = tokenizer.encode(message)

        # Make sure we don't exceed the context length
        if len(tokens) > 25:  # Leave some room for generation
            tokens = tokens[-25:]

        # Generate the response
        with torch.no_grad():
            # Use the max_tokens parameter, but cap it at a reasonable limit for this model
            actual_max_tokens = min(max_tokens, 32 - len(tokens))
            generated_tokens = model.generate(tokens, actual_max_tokens)

        # Decode the generated tokens
        response = tokenizer.decode(generated_tokens)

        # Clean up the response (remove the original input)
        original_text = tokenizer.decode(tokens.tolist())
        if response.startswith(original_text):
            response = response[len(original_text):]

        # Clean up any unwanted tokens
        response = response.replace("<unk>", "").replace("<pad>", "").strip()

        if not response:
            response = "I'm not sure how to respond to that with my geographical knowledge."

        # Yield the response (to stay compatible with the streaming interface)
        yield response
    except Exception as e:
        yield f"Sorry, I encountered an error: {str(e)}"
""" | |
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface | |
""" | |
# Create the interface with file upload | |
with gr.Blocks(title="π€ Usta Model Chat", theme=gr.themes.Soft()) as demo: | |
gr.Markdown("# π€ Usta Model Chat") | |
gr.Markdown("Chat with a custom transformer language model built from scratch! This model specializes in geographical knowledge including countries, capitals, and cities.") | |
with gr.Row(): | |
with gr.Column(scale=2): | |
# Model upload section | |
with gr.Group(): | |
gr.Markdown("### π Model Upload (Optional)") | |
model_file = gr.File( | |
label="Upload your own model.pth file", | |
file_types=[".pth", ".pt"] | |
) | |
upload_btn = gr.Button("Load Model", variant="primary") | |
model_status_display = gr.Textbox( | |
label="Model Status", | |
value=model_status, | |
interactive=False | |
) | |
with gr.Column(scale=1): | |
# Settings | |
with gr.Group(): | |
gr.Markdown("### βοΈ Generation Settings") | |
system_msg = gr.Textbox( | |
value="You are Usta, a geographical knowledge assistant trained from scratch.", | |
label="System message" | |
) | |
max_tokens = gr.Slider(minimum=1, maximum=30, value=20, step=1, label="Max new tokens") | |
temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature") | |
top_p = gr.Slider( | |
minimum=0.1, | |
maximum=1.0, | |
value=0.95, | |
step=0.05, | |
label="Top-p (nucleus sampling)" | |
) | |
# Chat interface | |
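    # gr.ChatInterface calls respond(message, history, *additional_inputs), so the
    # additional_inputs below must stay in the same order as respond's extra
    # parameters: system_message, max_tokens, temperature, top_p.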
    chatbot = gr.ChatInterface(
        respond,
        additional_inputs=[system_msg, max_tokens, temperature, top_p],
        chatbot=gr.Chatbot(height=400),
        title=None,  # We already have a title above
        description=None  # We already have a description above
    )

    # Event handlers
    upload_btn.click(
        update_model,
        inputs=[model_file],
        outputs=[model_status_display]
    )


if __name__ == "__main__":
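    # demo.launch() uses Gradio's defaults here; options such as share=True (temporary
    # public link) or server_name="0.0.0.0" can be passed if the defaults do not fit.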
    demo.launch()