import os
import gradio as gr
import torch
from v1.usta_model import UstaModel
from v1.usta_tokenizer import UstaTokenizer
# Load the model and tokenizer
def load_model():
    try:
        u_tokenizer = UstaTokenizer("v1/tokenizer.json")
        print("✅ Tokenizer loaded successfully! vocab size:", len(u_tokenizer.vocab))

        # Model parameters - adjust these to match your trained model
        context_length = 32
        vocab_size = len(u_tokenizer.vocab)
        embedding_dim = 12
        num_heads = 4
        num_layers = 8

        # Build the model
        u_model = UstaModel(
            vocab_size=vocab_size,
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            context_length=context_length,
            num_layers=num_layers
        )

        # Load the trained weights if available
        model_path = "v1/u_model.pth"
        if not os.path.exists(model_path):
            print("❌ Model file not found at", model_path)

            # Download the model file from GitHub
            try:
                print("📥 Downloading model weights from GitHub...")
                import requests

                url = "https://github.com/malibayram/llm-from-scratch/raw/main/u_model_4000.pth"
                headers = {
                    'Accept': 'application/octet-stream',
                    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
                }
                response = requests.get(url, headers=headers)
                response.raise_for_status()  # Raise an exception for bad status codes

                # Check that we got a proper binary file: PyTorch zip-format checkpoints start
                # with the ZIP magic bytes PK\x03\x04, while an error page starts with HTML
                if response.content[:4] != b'PK\x03\x04' and b'<html' in response.content[:100].lower():
                    raise Exception("Downloaded HTML instead of binary file - check URL")

                print(f"📦 Downloaded {len(response.content)} bytes")

                # Create the v1 directory if it doesn't exist
                os.makedirs("v1", exist_ok=True)

                # Save the model weights to the local file system
                with open(model_path, "wb") as f:
                    f.write(response.content)
                print("✅ Model weights saved successfully!")
            except Exception as e:
                print(f"❌ Failed to download model weights: {e}")
                print("Using random initialization.")

        if os.path.exists(model_path):
            try:
                u_model.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=False))
                u_model.eval()
                print("✅ Model weights loaded successfully!")
            except Exception as e:
                print(f"⚠️ Warning: Could not load trained weights: {e}")
                print("Using random initialization.")
        else:
            print(f"⚠️ Model file not found at {model_path}. Using random initialization.")

        return u_model, u_tokenizer
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        raise e
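
# Optional sanity check (sketch, not called by the app): if load_state_dict fails with a
# size-mismatch error, the hyperparameters above likely don't match the checkpoint. Assuming
# the file is a plain state_dict saved with torch.save, its tensor shapes can be inspected:
#
#   state = torch.load("v1/u_model.pth", map_location="cpu")
#   for name, tensor in state.items():
#       print(name, tuple(tensor.shape))
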
# Initialize model and tokenizer globally
try:
    model, tokenizer = load_model()
    print("🚀 UstaModel and tokenizer initialized successfully!")
except Exception as e:
    print(f"❌ Failed to initialize model: {e}")
    model, tokenizer = None, None
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """
    Generate a response using the UstaModel
    """
    if model is None or tokenizer is None:
        yield "Sorry, the UstaModel is not available. Please try again later."
        return

    try:
        # For UstaModel, we'll use the message directly (ignoring system_message for now)
        # since it's a simpler model focused on geographical knowledge

        # Encode the input message
        tokens = tokenizer.encode(message)

        # Make sure we don't exceed context length
        if len(tokens) > 25:  # Leave some room for generation
            tokens = tokens[-25:]

        # Generate response
        with torch.no_grad():
            # Use max_tokens parameter, but cap it at a reasonable limit for this model
            actual_max_tokens = min(max_tokens, 32 - len(tokens))
            generated_tokens = model.generate(tokens, actual_max_tokens)

        # Decode the generated tokens
        response = tokenizer.decode(generated_tokens)

        # Clean up the response (remove the original input)
        original_text = tokenizer.decode(tokens.tolist())
        if response.startswith(original_text):
            response = response[len(original_text):]

        # Clean up any unwanted tokens
        response = response.replace("<unk>", "").replace("<pad>", "").strip()

        if not response:
            response = "I'm not sure how to respond to that with my geographical knowledge."

        # Yield the response (to maintain compatibility with the streaming interface)
        yield response
    except Exception as e:
        yield f"Sorry, I encountered an error: {str(e)}"
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are Usta, a geographical knowledge assistant trained from scratch.",
            label="System message",
            info="Note: This model focuses on geographical knowledge (countries, capitals, cities)"
        ),
        gr.Slider(minimum=1, maximum=30, value=20, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
            info="Note: This parameter is not used by UstaModel but kept for interface compatibility"
        ),
    ],
    title="🤖 Usta Model Chat",
    description="Chat with a custom transformer language model built from scratch! This model specializes in geographical knowledge including countries, capitals, and cities."
)
if __name__ == "__main__":
    demo.launch()
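    # Alternative launch options (sketch): launch() also accepts server_name and server_port,
    # e.g. for running inside a container:
    # demo.launch(server_name="0.0.0.0", server_port=7860)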