import os

import gradio as gr
import torch

from v1.usta_model import UstaModel
from v1.usta_tokenizer import UstaTokenizer


# Load the model and tokenizer
def load_model():
    try:
        u_tokenizer = UstaTokenizer("v1/tokenizer.json")
        print("βœ… Tokenizer loaded successfully! vocab size:", len(u_tokenizer.vocab))
        
        # Model parameters - adjust these to match your trained model
        context_length = 32
        vocab_size = len(u_tokenizer.vocab)
        embedding_dim = 12
        num_heads = 4
        num_layers = 8
        
        # Load the model
        u_model = UstaModel(
            vocab_size=vocab_size, 
            embedding_dim=embedding_dim,
            num_heads=num_heads, 
            context_length=context_length, 
            num_layers=num_layers
        )        
        
        # Load the trained weights if available
        model_path = "v1/u_model.pth"

        if not os.path.exists(model_path):
            print("❌ Model file not found at", model_path)
            # Download the model file from GitHub
            try:
                print("πŸ“₯ Downloading model weights from GitHub...")
                import requests
                url = "https://github.com/malibayram/llm-from-scratch/raw/main/u_model_4000.pth"
                
                headers = {
                    'Accept': 'application/octet-stream',
                    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
                }
                
                response = requests.get(url, headers=headers)
                response.raise_for_status()  # Raise an exception for bad status codes
                
                # Guard against an HTML error page: checkpoints saved with the zip-based
                # torch.save format begin with the ZIP magic bytes PK\x03\x04
                if response.content[:4] != b'PK\x03\x04' and b'<html' in response.content[:100].lower():
                    raise Exception("Downloaded HTML instead of binary file - check URL")
                
                print(f"πŸ“¦ Downloaded {len(response.content)} bytes")
                
                # Create v1 directory if it doesn't exist
                os.makedirs("v1", exist_ok=True)
                
                # Save the model weights to the local file system
                with open(model_path, "wb") as f:
                    f.write(response.content)
                print("βœ… Model weights saved successfully!")
            except Exception as e:
                print(f"❌ Failed to download model weights: {e}")
                print("Using random initialization.")

        if os.path.exists(model_path):
            try:
                u_model.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=False))
                u_model.eval()
                print("βœ… Model weights loaded successfully!")
            except Exception as e:
                print(f"⚠️ Warning: Could not load trained weights: {e}")
                print("Using random initialization.")
        else:
            print(f"⚠️ Model file not found at {model_path}. Using random initialization.")
        
        return u_model, u_tokenizer
        
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        raise e
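
# A minimal usage sketch, kept as a comment so nothing extra runs on import.
# It assumes the v1 package and tokenizer.json sit next to this file and only
# illustrates the encode -> generate -> decode calls that respond() uses below;
# the prompt string is an arbitrary example.
#
#   m, t = load_model()
#   ids = t.encode("the capital of france is")
#   print(t.decode(m.generate(ids, 5)))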

# Initialize model and tokenizer globally
try:
    model, tokenizer = load_model()
    print("πŸš€ UstaModel and tokenizer initialized successfully!")
except Exception as e:
    print(f"❌ Failed to initialize model: {e}")
    model, tokenizer = None, None

def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """
    Generate a response using the UstaModel
    """
    if model is None or tokenizer is None:
        yield "Sorry, the UstaModel is not available. Please try again later."
        return
        
    try:
        # For UstaModel, we'll use the message directly (ignoring system_message for now)
        # since it's a simpler model focused on geographical knowledge
        
        # Encode the input message
        tokens = tokenizer.encode(message)
        
        # Make sure we don't exceed the 32-token context length
        if len(tokens) > 25:  # keep only the most recent tokens, leaving room for generation
            tokens = tokens[-25:]
        
        # Generate response
        with torch.no_grad():
            # Use max_tokens parameter, but cap it at reasonable limit for this model
            actual_max_tokens = min(max_tokens, 32 - len(tokens))
            generated_tokens = model.generate(tokens, actual_max_tokens)
        
        # Decode the generated tokens
        response = tokenizer.decode(generated_tokens)
        
        # Clean up the response (strip the echoed input prompt)
        # encode() may return a plain list or a tensor, so convert defensively
        input_ids = tokens.tolist() if hasattr(tokens, "tolist") else list(tokens)
        original_text = tokenizer.decode(input_ids)
        if response.startswith(original_text):
            response = response[len(original_text):]
        
        # Clean up any unwanted tokens
        response = response.replace("<unk>", "").replace("<pad>", "").strip()
        
        if not response:
            response = "I'm not sure how to respond to that with my geographical knowledge."
            
        # Yield the response (to maintain compatibility with streaming interface)
        yield response
        
    except Exception as e:
        yield f"Sorry, I encountered an error: {str(e)}"

"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are Usta, a geographical knowledge assistant trained from scratch.", 
            label="System message",
            info="Note: This model focuses on geographical knowledge (countries, capitals, cities)"
        ),
        gr.Slider(minimum=1, maximum=30, value=20, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
            info="Note: This parameter is not used by UstaModel but kept for interface compatibility"
        ),
    ],
    title="πŸ€– Usta Model Chat",
    description="Chat with a custom transformer language model built from scratch! This model specializes in geographical knowledge including countries, capitals, and cities."
)

if __name__ == "__main__":
    demo.launch()
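    # Optional, not used here: launch() also accepts settings such as
    # share=True for a temporary public link, e.g. demo.launch(share=True).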