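# Gradio chat demo for UstaModel, a small transformer language model trained from
# scratch on geographical facts (countries, capitals, cities). The model and
# tokenizer implementations live in the v1/ package next to this script.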
import os

import gradio as gr
import torch

from v1.usta_model import UstaModel
from v1.usta_tokenizer import UstaTokenizer


# Load the model and tokenizer
def load_model():
    try:
        u_tokenizer = UstaTokenizer("v1/tokenizer.json")
        print("✅ Tokenizer loaded successfully! vocab size:", len(u_tokenizer.vocab))

        # Model parameters - adjust these to match your trained model
        context_length = 32
        vocab_size = len(u_tokenizer.vocab)
        embedding_dim = 12
        num_heads = 4
        num_layers = 8

        # Load the model
        u_model = UstaModel(
            vocab_size=vocab_size,
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            context_length=context_length,
            num_layers=num_layers
        )
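
        # Note: these hyperparameters must match the checkpoint exactly;
        # otherwise load_state_dict below fails with size-mismatch errors.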

        # Load the trained weights if available
        model_path = "v1/u_model.pth"
        if not os.path.exists(model_path):
            print("❌ Model file not found at", model_path)
            # Download the model file from GitHub
            try:
                print("📥 Downloading model weights from GitHub...")
                import requests

                url = "https://github.com/malibayram/llm-from-scratch/raw/main/u_model_4000.pth"
                headers = {
                    'Accept': 'application/octet-stream',
                    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
                }
                response = requests.get(url, headers=headers)
                response.raise_for_status()  # Raise an exception for bad status codes

                # Check that we got a proper binary file (torch.save checkpoints are
                # ZIP archives, so they start with the PK\x03\x04 magic bytes)
                if response.content[:4] != b'PK\x03\x04' and b'<html' in response.content[:100].lower():
                    raise Exception("Downloaded HTML instead of binary file - check URL")

                print(f"📦 Downloaded {len(response.content)} bytes")

                # Create v1 directory if it doesn't exist
                os.makedirs("v1", exist_ok=True)

                # Save the model weights to the local file system
                with open(model_path, "wb") as f:
                    f.write(response.content)
                print("✅ Model weights saved successfully!")
            except Exception as e:
                print(f"❌ Failed to download model weights: {e}")
                print("Using random initialization.")

        if os.path.exists(model_path):
            try:
                u_model.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=False))
                u_model.eval()
                print("✅ Model weights loaded successfully!")
            except Exception as e:
                print(f"⚠️ Warning: Could not load trained weights: {e}")
                print("Using random initialization.")
        else:
            print(f"⚠️ Model file not found at {model_path}. Using random initialization.")

        return u_model, u_tokenizer
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        raise e


# Initialize model and tokenizer globally
try:
    model, tokenizer = load_model()
    print("🎉 UstaModel and tokenizer initialized successfully!")
except Exception as e:
    print(f"❌ Failed to initialize model: {e}")
    model, tokenizer = None, None


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """
    Generate a response using the UstaModel
    """
    if model is None or tokenizer is None:
        yield "Sorry, the UstaModel is not available. Please try again later."
        return

    try:
        # For UstaModel, we'll use the message directly (ignoring system_message for now)
        # since it's a simpler model focused on geographical knowledge

        # Encode the input message
        tokens = tokenizer.encode(message)

        # Make sure we don't exceed the context length
        if len(tokens) > 25:  # Leave some room for generation
            tokens = tokens[-25:]

        # Generate response
        with torch.no_grad():
            # Use the max_tokens parameter, but cap it at a reasonable limit for this model
            actual_max_tokens = min(max_tokens, 32 - len(tokens))
            generated_tokens = model.generate(tokens, actual_max_tokens)

        # Decode the generated tokens
        response = tokenizer.decode(generated_tokens)

        # Clean up the response (remove the original input)
        original_text = tokenizer.decode(tokens.tolist())
        if response.startswith(original_text):
            response = response[len(original_text):]

        # Clean up any unwanted tokens
        response = response.replace("<unk>", "").replace("<pad>", "").strip()

        if not response:
            response = "I'm not sure how to respond to that with my geographical knowledge."

        # Yield the response (to maintain compatibility with streaming interface)
        yield response
    except Exception as e:
        yield f"Sorry, I encountered an error: {str(e)}"
""" | |
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface | |
""" | |
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are Usta, a geographical knowledge assistant trained from scratch.",
            label="System message",
            info="Note: This model focuses on geographical knowledge (countries, capitals, cities)"
        ),
        gr.Slider(minimum=1, maximum=30, value=20, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
            info="Note: This parameter is not used by UstaModel but kept for interface compatibility"
        ),
    ],
    title="🤖 Usta Model Chat",
    description="Chat with a custom transformer language model built from scratch! This model specializes in geographical knowledge including countries, capitals, and cities."
)

if __name__ == "__main__":
    demo.launch()
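
# To run locally (assuming this file is saved as app.py and gradio, torch, and
# requests are installed, with the v1/ package and tokenizer.json alongside it):
#   python app.py
# Gradio serves the chat UI on http://127.0.0.1:7860 by default.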