import os

import gradio as gr
import torch

from v2.usta_model import UstaModel
from v2.usta_tokenizer import UstaTokenizer


# Load the model and tokenizer
def load_model(custom_model_path=None):
    try:
        u_tokenizer = UstaTokenizer("v2/tokenizer.json")
        print("✅ Tokenizer loaded successfully! vocab size:", len(u_tokenizer.vocab))

        # Model parameters - adjust these to match your trained model
        context_length = 32
        vocab_size = len(u_tokenizer.vocab)
        embedding_dim = 12
        num_heads = 4
        num_layers = 8
        device = "cpu"  # Use CPU for compatibility

        # Build the model
        u_model = UstaModel(
            vocab_size=vocab_size,
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            context_length=context_length,
            num_layers=num_layers,
            device=device,
        )

        # Determine which model file to use
        if custom_model_path and os.path.exists(custom_model_path):
            model_path = custom_model_path
            print(f"🎯 Using uploaded model: {model_path}")
        else:
            model_path = "v2/u_model_4000.pth"
            if not os.path.exists(model_path):
                print("❌ Model file not found at", model_path)
                # Download the model file from GitHub
                try:
                    print("📥 Downloading model weights from GitHub...")
                    import requests

                    url = "https://github.com/malibayram/llm-from-scratch/raw/main/u_model_4000.pth"
                    headers = {
                        "Accept": "application/octet-stream",
                        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
                    }
                    response = requests.get(url, headers=headers)
                    response.raise_for_status()  # Raise an exception for bad status codes

                    # Check if we got a proper binary file (PyTorch zip checkpoints
                    # start with b'PK\x03\x04'; the exact HTML test here is assumed)
                    if response.content[:4] != b"PK\x03\x04" and b"<html" in response.content[:100].lower():
                        raise Exception("Downloaded an HTML page instead of the model weights")

                    # Save the downloaded weights to the expected path
                    with open(model_path, "wb") as f:
                        f.write(response.content)
                    print("✅ Model weights downloaded successfully!")
                except Exception as e:
                    print("❌ Failed to download model weights:", e)
                    raise

        # Load the trained weights and switch to inference mode
        u_model.load_state_dict(torch.load(model_path, map_location=device))
        u_model.eval()
        print("✅ Model loaded successfully!")
        return u_model, u_tokenizer
    except Exception as e:
        print("❌ Error loading model:", e)
        return None, None


# Load the default model at startup
model, tokenizer = load_model()
model_status = "✅ Model ready" if model is not None else "❌ Model failed to load"


# NOTE: the UI below wires up these two loader helpers; their bodies are a
# minimal sketch (file names and status strings are assumed).
def load_model_from_url(url):
    """Download weights from a URL and swap them into the running app."""
    global model, tokenizer, model_status
    if not url or not url.strip():
        return "❌ Please enter a model URL"
    try:
        import requests

        response = requests.get(url.strip())
        response.raise_for_status()
        temp_path = "downloaded_model.pth"  # assumed temp-file name
        with open(temp_path, "wb") as f:
            f.write(response.content)
        new_model, new_tokenizer = load_model(temp_path)
        if new_model is None:
            return "❌ Failed to load model from URL"
        model, tokenizer = new_model, new_tokenizer
        model_status = "✅ Custom model loaded"
        return f"✅ Model loaded from {url.strip()}"
    except Exception as e:
        return f"❌ Error loading model from URL: {e}"


def load_model_from_file(file):
    """Load weights from a file uploaded through the UI."""
    global model, tokenizer, model_status
    if file is None:
        return "❌ Please upload a model file"
    try:
        new_model, new_tokenizer = load_model(file.name)
        if new_model is None:
            return "❌ Failed to load the uploaded model"
        model, tokenizer = new_model, new_tokenizer
        model_status = "✅ Custom model loaded"
        return f"✅ Model loaded from {os.path.basename(file.name)}"
    except Exception as e:
        return f"❌ Error loading uploaded model: {e}"


def chat_with_usta(message, history, max_tokens=20, temperature=1.0, top_k=40, top_p=1.0):
    """Generate a reply from the Usta model and append it to the chat history."""
    if model is None or tokenizer is None:
        history.append([message, "❌ Model is not loaded. Please load a model first."])
        return history
    try:
        # Encode the message (encode is assumed to return a tensor of token ids,
        # since the code below calls tokens.tolist())
        tokens = tokenizer.encode(message)
        if len(tokens) > 25:  # Leave some room for generation
            tokens = tokens[-25:]

        # Generate response
        with torch.no_grad():
            actual_max_tokens = min(max_tokens, 32 - len(tokens))
            generated_tokens = model.generate(
                tokens,
                max_new_tokens=actual_max_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
            )

        # Decode the generated tokens
        response = tokenizer.decode(generated_tokens)

        # Clean up the response (remove the original input)
        original_text = tokenizer.decode(tokens.tolist())
        if response.startswith(original_text):
            response = response[len(original_text):]

        # Clean up any unwanted special tokens ("<unk>"/"<eos>" are assumed
        # names; match them to the entries in v2/tokenizer.json)
        response = response.replace("<unk>", "").replace("<eos>", "").strip()

        if not response:
            response = "I'm not sure how to respond to that with my geographical knowledge."

        # Add to history
        history.append([message, response])
        return history
    except Exception as e:
        history.append([message, f"Sorry, I encountered an error: {str(e)}"])
        return history


# Create simple interface
with gr.Blocks(title="🤖 Usta Model Chat") as demo:
    gr.Markdown("# 🤖 Usta Model Chat")
    gr.Markdown(
        "Chat with a custom transformer language model built from scratch! "
        "This model specializes in geographical knowledge."
    )

    # Simple chat interface
    chatbot = gr.Chatbot(height=400)
    msg = gr.Textbox(label="Your message", placeholder="Ask about countries, capitals, or cities...")

    with gr.Row():
        send_btn = gr.Button("Send", variant="primary")
        clear_btn = gr.Button("Clear")

    # Generation settings
    gr.Markdown("## ⚙️ Generation Settings")
    with gr.Row():
        max_tokens = gr.Slider(minimum=1, maximum=30, value=20, step=1, label="Max tokens")
        temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature")
    with gr.Row():
        top_k = gr.Slider(minimum=1, maximum=64, value=40, step=1, label="Top-k")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=1.0, step=0.05, label="Top-p (nucleus sampling)")

    # Model loading (simplified)
    gr.Markdown("## 🔧 Load Custom Model (Optional)")
    with gr.Row():
        model_url = gr.Textbox(
            label="Model URL",
            placeholder="https://github.com/malibayram/llm-from-scratch/raw/main/u_model_4000.pth",
            scale=3,
        )
        load_url_btn = gr.Button("Load from URL", scale=1)
    with gr.Row():
        model_file = gr.File(label="Upload model file (.pth, .pt, .bin)")
        load_file_btn = gr.Button("Load File", scale=1)

    status = gr.Textbox(label="Status", value=model_status, interactive=False)

    # Event handlers
    def send_message(message, history, max_tok, temp, k, p):
        if not message.strip():
            return history, ""
        return chat_with_usta(message, history, max_tok, temp, k, p), ""

    send_btn.click(
        send_message,
        inputs=[msg, chatbot, max_tokens, temperature, top_k, top_p],
        outputs=[chatbot, msg],
    )
    msg.submit(
        send_message,
        inputs=[msg, chatbot, max_tokens, temperature, top_k, top_p],
        outputs=[chatbot, msg],
    )
    clear_btn.click(lambda: [], outputs=[chatbot])
    load_url_btn.click(load_model_from_url, inputs=[model_url], outputs=[status])
    load_file_btn.click(load_model_from_file, inputs=[model_file], outputs=[status])


if __name__ == "__main__":
    demo.launch()
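
# Usage note (assuming this file is saved as app.py): running `python app.py`
# serves the interface on Gradio's default address, http://127.0.0.1:7860.
# Passing share=True to demo.launch() additionally creates a temporary public
# link, which is useful when testing from another machine.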