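"""Gradio demo app for chatting with UstaModel, a small transformer language
model built from scratch (see v2/usta_model.py). The script loads the tokenizer
and model weights (downloading default weights from GitHub if needed), lets the
user supply custom weights via URL or file upload, and exposes a simple chat UI
with generation settings."""
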
import os
import gradio as gr
import torch
from v2.usta_model import UstaModel
from v2.usta_tokenizer import UstaTokenizer

# Load the model and tokenizer
def load_model(custom_model_path=None):
    try:
        u_tokenizer = UstaTokenizer("v2/tokenizer.json")
        print("✅ Tokenizer loaded successfully! vocab size:", len(u_tokenizer.vocab))

        # Model parameters - adjust these to match your trained model
        context_length = 32
        vocab_size = len(u_tokenizer.vocab)
        embedding_dim = 12
        num_heads = 4
        num_layers = 8
        device = "cpu"  # Use CPU for compatibility

        # Load the model
        u_model = UstaModel(
            vocab_size=vocab_size,
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            context_length=context_length,
            num_layers=num_layers,
            device=device
        )

        # Determine which model file to use
        if custom_model_path and os.path.exists(custom_model_path):
            model_path = custom_model_path
            print(f"🎯 Using uploaded model: {model_path}")
        else:
            model_path = "v2/u_model_4000.pth"

        if not os.path.exists(model_path):
            print("❌ Model file not found at", model_path)
            # Download the model file from GitHub
            try:
                print("📥 Downloading model weights from GitHub...")
                import requests

                url = "https://github.com/malibayram/llm-from-scratch/raw/main/u_model_4000.pth"
                headers = {
                    'Accept': 'application/octet-stream',
                    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
                }
                response = requests.get(url, headers=headers)
                response.raise_for_status()  # Raise an exception for bad status codes

                # Check if we got a proper binary file (PyTorch files start with specific bytes)
                if response.content[:4] != b'PK\x03\x04' and b'<html' in response.content[:100].lower():
                    raise Exception("Downloaded HTML instead of binary file - check URL")

                print(f"📦 Downloaded {len(response.content)} bytes")

                # Create v2 directory if it doesn't exist
                os.makedirs("v2", exist_ok=True)

                # Save the model weights to the local file system
                with open(model_path, "wb") as f:
                    f.write(response.content)
                print("✅ Model weights saved successfully!")
            except Exception as e:
                print(f"❌ Failed to download model weights: {e}")
                print("Using random initialization.")

        if os.path.exists(model_path):
            try:
                state_dict = torch.load(model_path, map_location="cpu", weights_only=False)

                # Handle potential key mapping issues
                if "embedding.weight" in state_dict and "embedding.embedding.weight" not in state_dict:
                    # Map old key names to new key names
                    new_state_dict = {}
                    for key, value in state_dict.items():
                        if key == "embedding.weight":
                            new_state_dict["embedding.embedding.weight"] = value
                        elif key == "pos_embedding.weight":
                            # Skip positional embedding if not expected
                            continue
                        else:
                            new_state_dict[key] = value
                    state_dict = new_state_dict

                u_model.load_state_dict(state_dict)
                u_model.eval()
                print("✅ Model weights loaded successfully!")
                return u_model, u_tokenizer, f"✅ Model loaded from: {model_path}"
            except Exception as e:
                print(f"⚠️ Warning: Could not load trained weights: {e}")
                print("Using random initialization.")
                return u_model, u_tokenizer, f"⚠️ Failed to load weights: {e}"
        else:
            print(f"⚠️ Model file not found at {model_path}. Using random initialization.")
            return u_model, u_tokenizer, "⚠️ Using random initialization"
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        raise e
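
# load_model returns (model, tokenizer, status_message); the status string is
# surfaced in the UI below. Note that when weight loading fails, the function
# still returns a randomly initialized model so the demo stays usable.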

# Global model variables
model, tokenizer, model_status = None, None, "Not loaded"

# Initialize model and tokenizer globally
try:
    model, tokenizer, model_status = load_model()
    print("🎉 UstaModel and tokenizer initialized successfully!")
except Exception as e:
    print(f"❌ Failed to initialize model: {e}")
    model, tokenizer, model_status = None, None, f"❌ Error: {e}"


def load_model_from_url(url):
    """Load model from a URL"""
    global model, tokenizer, model_status

    if not url.strip():
        return "❌ Please provide a URL"

    try:
        print(f"📥 Downloading model from URL: {url}")
        import requests

        headers = {
            'Accept': 'application/octet-stream',
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
        }
        response = requests.get(url, headers=headers)
        response.raise_for_status()

        # Check if we got a proper binary file
        if response.content[:4] != b'PK\x03\x04' and b'<html' in response.content[:100].lower():
            return "❌ Downloaded HTML instead of binary file - check URL"

        # Save temporary file
        temp_path = "temp_model.pth"
        with open(temp_path, "wb") as f:
            f.write(response.content)

        # Load the model
        new_model, new_tokenizer, status = load_model(temp_path)

        # Update global variables
        model = new_model
        tokenizer = new_tokenizer
        model_status = status

        # Clean up temp file
        if os.path.exists(temp_path):
            os.remove(temp_path)

        return status
    except Exception as e:
        error_msg = f"❌ Failed to load model from URL: {e}"
        model_status = error_msg
        return error_msg
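
# Note: the URL must return the raw .pth bytes (for GitHub, use a ".../raw/..."
# link); responses that look like HTML are rejected by the check above.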


def load_model_from_file(uploaded_file):
    """Load model from uploaded file"""
    global model, tokenizer, model_status

    if uploaded_file is None:
        return "❌ No file uploaded"

    try:
        # Check if the file path exists and is valid
        file_path = uploaded_file.name if hasattr(uploaded_file, 'name') else str(uploaded_file)

        # For HF Spaces compatibility, also try the upload path
        if not os.path.exists(file_path) and hasattr(uploaded_file, 'orig_name'):
            # Sometimes HF Spaces provides different paths
            print(f"Original path not found: {file_path}")
            print(f"Trying original name: {uploaded_file.orig_name}")
            file_path = uploaded_file.orig_name

        print(f"🔍 Attempting to load model from: {file_path}")

        # Load the new model
        new_model, new_tokenizer, status = load_model(file_path)

        # Update global variables
        model = new_model
        tokenizer = new_tokenizer
        model_status = status

        return status
    except Exception as e:
        error_msg = f"❌ Failed to load uploaded model: {e}"
        print(f"Error details: {e}")
        model_status = error_msg
        return error_msg
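
# Depending on the Gradio version, gr.File typically hands back a tempfile-backed
# object whose .name attribute is the on-disk path, which is why the .name and
# .orig_name fallbacks above are needed.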


def chat_with_usta(message, history, max_tokens=20, temperature=1.0, top_k=64, top_p=1.0):
    """Simple chat function"""
    if model is None or tokenizer is None:
        return history + [["Error", "UstaModel is not available. Please try again later."]]

    try:
        # Encode the input message
        tokens = tokenizer.encode(message)

        # Make sure we don't exceed context length
        if len(tokens) > 25:  # Leave some room for generation
            tokens = tokens[-25:]

        # Generate response
        with torch.no_grad():
            actual_max_tokens = min(max_tokens, 32 - len(tokens))
            generated_tokens = model.generate(
                tokens,
                max_new_tokens=actual_max_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p
            )

        # Decode the generated tokens
        response = tokenizer.decode(generated_tokens)

        # Clean up the response (remove the original input)
        original_text = tokenizer.decode(tokens.tolist())
        if response.startswith(original_text):
            response = response[len(original_text):]

        # Clean up any unwanted tokens
        response = response.replace("<unk>", "").replace("<pad>", "").strip()

        if not response:
            response = "I'm not sure how to respond to that with my geographical knowledge."

        # Add to history
        history.append([message, response])
        return history
    except Exception as e:
        history.append([message, f"Sorry, I encountered an error: {str(e)}"])
        return history
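
# Quick smoke test outside the UI (a sketch, assuming the model loaded above):
#   history = chat_with_usta("where is paris", [], max_tokens=10)
#   print(history[-1][1])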


# Create simple interface
with gr.Blocks(title="🤖 Usta Model Chat") as demo:
    gr.Markdown("# 🤖 Usta Model Chat")
    gr.Markdown("Chat with a custom transformer language model built from scratch! This model specializes in geographical knowledge.")

    # Simple chat interface
    chatbot = gr.Chatbot(height=400)
    msg = gr.Textbox(label="Your message", placeholder="Ask about countries, capitals, or cities...")

    with gr.Row():
        send_btn = gr.Button("Send", variant="primary")
        clear_btn = gr.Button("Clear")

    # Generation settings
    gr.Markdown("## ⚙️ Generation Settings")
    with gr.Row():
        max_tokens = gr.Slider(minimum=1, maximum=30, value=20, step=1, label="Max tokens")
        temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature")
    with gr.Row():
        top_k = gr.Slider(minimum=1, maximum=64, value=40, step=1, label="Top-k")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=1.0, step=0.05, label="Top-p (nucleus sampling)")

    # Model loading (simplified)
    gr.Markdown("## 🔧 Load Custom Model (Optional)")
    with gr.Row():
        model_url = gr.Textbox(
            label="Model URL",
            placeholder="https://github.com/malibayram/llm-from-scratch/raw/main/u_model_4000.pth",
            scale=3
        )
        load_url_btn = gr.Button("Load from URL", scale=1)
    with gr.Row():
        model_file = gr.File(label="Upload model file (.pth, .pt, .bin)")
        load_file_btn = gr.Button("Load File", scale=1)
    status = gr.Textbox(label="Status", value=model_status, interactive=False)

    # Event handlers
    def send_message(message, history, max_tok, temp, k, p):
        if not message.strip():
            return history, ""
        return chat_with_usta(message, history, max_tok, temp, k, p), ""

    send_btn.click(
        send_message,
        inputs=[msg, chatbot, max_tokens, temperature, top_k, top_p],
        outputs=[chatbot, msg]
    )
    msg.submit(
        send_message,
        inputs=[msg, chatbot, max_tokens, temperature, top_k, top_p],
        outputs=[chatbot, msg]
    )
    clear_btn.click(lambda: [], outputs=[chatbot])

    load_url_btn.click(
        load_model_from_url,
        inputs=[model_url],
        outputs=[status]
    )
    load_file_btn.click(
        load_model_from_file,
        inputs=[model_file],
        outputs=[status]
    )


if __name__ == "__main__":
    demo.launch()
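
# Running `python app.py` starts the Gradio server and prints a local URL;
# Hugging Face Spaces uses the same entry point when serving this app.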