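"""Gradio chat demo for UstaModel, a small transformer language model built from scratch.

The app loads UstaModel and UstaTokenizer from the v1 package, downloading the default
weights from GitHub when no local checkpoint is present, and also lets users load
alternative .pth weights from an uploaded file or a URL through the chat UI.
"""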
import os

import gradio as gr
import torch

from v1.usta_model import UstaModel
from v1.usta_tokenizer import UstaTokenizer


# Load the model and tokenizer
def load_model(custom_model_path=None):
    try:
        u_tokenizer = UstaTokenizer("v1/tokenizer.json")
        print("βœ… Tokenizer loaded successfully! vocab size:", len(u_tokenizer.vocab))
        
        # Model parameters - adjust these to match your trained model
        context_length = 32
        vocab_size = len(u_tokenizer.vocab)
        embedding_dim = 12
        num_heads = 4
        num_layers = 8
        
        # Load the model
        u_model = UstaModel(
            vocab_size=vocab_size,
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            context_length=context_length,
            num_layers=num_layers
        )
        
        # Determine which model file to use
        if custom_model_path and os.path.exists(custom_model_path):
            model_path = custom_model_path
            print(f"🎯 Using uploaded model: {model_path}")
        else:
            model_path = "v1/u_model.pth"
            
            if not os.path.exists(model_path):
                print("❌ Model file not found at", model_path)
                # Download the model file from GitHub
                try:
                    print("πŸ“₯ Downloading model weights from GitHub...")
                    import requests
                    url = "https://github.com/malibayram/llm-from-scratch/raw/main/u_model_4000.pth"
                    
                    headers = {
                        'Accept': 'application/octet-stream',
                        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
                    }
                    
                    response = requests.get(url, headers=headers)
                    response.raise_for_status()  # Raise an exception for bad status codes
                    
                    # Check if we got a proper binary file (PyTorch files start with specific bytes)
                    if response.content[:4] != b'PK\x03\x04' and b'<html' in response.content[:100].lower():
                        raise Exception("Downloaded HTML instead of binary file - check URL")
                    
                    print(f"πŸ“¦ Downloaded {len(response.content)} bytes")
                    
                    # Create v1 directory if it doesn't exist
                    os.makedirs("v1", exist_ok=True)
                    
                    # Save the model weights to the local file system
                    with open(model_path, "wb") as f:
                        f.write(response.content)
                    print("βœ… Model weights saved successfully!")
                except Exception as e:
                    print(f"❌ Failed to download model weights: {e}")
                    print("Using random initialization.")

        if os.path.exists(model_path):
            try:
                state_dict = torch.load(model_path, map_location="cpu", weights_only=False)
                
                # Handle potential key mapping issues
                if "embedding.weight" in state_dict and "embedding.embedding.weight" not in state_dict:
                    # Map old key names to new key names
                    new_state_dict = {}
                    for key, value in state_dict.items():
                        if key == "embedding.weight":
                            new_state_dict["embedding.embedding.weight"] = value
                        elif key == "pos_embedding.weight":
                            # Skip positional embedding if not expected
                            continue
                        else:
                            new_state_dict[key] = value
                    state_dict = new_state_dict
                
                u_model.load_state_dict(state_dict)
                u_model.eval()
                print("βœ… Model weights loaded successfully!")
                return u_model, u_tokenizer, f"βœ… Model loaded from: {model_path}"
            except Exception as e:
                print(f"⚠️ Warning: Could not load trained weights: {e}")
                print("Using random initialization.")
                return u_model, u_tokenizer, f"⚠️ Failed to load weights: {e}"
        else:
            print(f"⚠️ Model file not found at {model_path}. Using random initialization.")
            return u_model, u_tokenizer, "⚠️ Using random initialization"
        
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        raise e

# Global model variables
model, tokenizer, model_status = None, None, "Not loaded"

# Initialize model and tokenizer globally
try:
    model, tokenizer, model_status = load_model()
    print("πŸš€ UstaModel and tokenizer initialized successfully!")
except Exception as e:
    print(f"❌ Failed to initialize model: {e}")
    model, tokenizer, model_status = None, None, f"❌ Error: {e}"

def load_model_from_url(url):
    """Load model from a URL"""
    global model, tokenizer, model_status
    
    if not url.strip():
        return "❌ Please provide a URL"
    
    try:
        print(f"πŸ“₯ Downloading model from URL: {url}")
        import requests
        
        headers = {
            'Accept': 'application/octet-stream',
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
        }
        
        response = requests.get(url, headers=headers)
        response.raise_for_status()
        
        # Check if we got a proper binary file
        if response.content[:4] != b'PK\x03\x04' and b'<html' in response.content[:100].lower():
            return "❌ Downloaded HTML instead of binary file - check URL"
        
        # Save temporary file
        temp_path = "temp_model.pth"
        with open(temp_path, "wb") as f:
            f.write(response.content)
        
        # Load the model
        new_model, new_tokenizer, status = load_model(temp_path)
        
        # Update global variables
        model = new_model
        tokenizer = new_tokenizer
        model_status = status
        
        # Clean up temp file
        if os.path.exists(temp_path):
            os.remove(temp_path)
        
        return status
    except Exception as e:
        error_msg = f"❌ Failed to load model from URL: {e}"
        model_status = error_msg
        return error_msg

def load_model_from_file(uploaded_file):
    """Load model from uploaded file"""
    global model, tokenizer, model_status
    
    if uploaded_file is None:
        return "❌ No file uploaded"
    
    try:
        # Load the new model
        new_model, new_tokenizer, status = load_model(uploaded_file.name)
        
        # Update global variables
        model = new_model
        tokenizer = new_tokenizer
        model_status = status
        
        return status
    except Exception as e:
        error_msg = f"❌ Failed to load uploaded model: {e}"
        model_status = error_msg
        return error_msg

def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    *_extra_inputs,  # ChatInterface passes every additional input to fn; absorb the model-loading widgets here
):
    """
    Generate a response using the UstaModel
    """
    if model is None or tokenizer is None:
        yield "Sorry, the UstaModel is not available. Please try again later."
        return
        
    try:
        # For UstaModel, use the message directly: system_message is ignored and
        # temperature/top_p are not passed to generate, since this is a simple model
        # focused on geographical knowledge
        
        # Encode the input message
        tokens = tokenizer.encode(message)
        
        # Make sure we don't exceed context length
        if len(tokens) > 25:  # Leave some room for generation
            tokens = tokens[-25:]
        
        # Generate response
        with torch.no_grad():
            # Use max_tokens parameter, but cap it at reasonable limit for this model
            actual_max_tokens = min(max_tokens, 32 - len(tokens))
            generated_tokens = model.generate(tokens, actual_max_tokens)
        
        # Decode the generated tokens
        response = tokenizer.decode(generated_tokens)
        
        # Clean up the response (remove the original input)
        original_text = tokenizer.decode(tokens.tolist())
        if response.startswith(original_text):
            response = response[len(original_text):]
        
        # Clean up any unwanted tokens
        response = response.replace("<unk>", "").replace("<pad>", "").strip()
        
        if not response:
            response = "I'm not sure how to respond to that with my geographical knowledge."
            
        # Yield the response (to maintain compatibility with streaming interface)
        yield response
        
    except Exception as e:
        yield f"Sorry, I encountered an error: {str(e)}"

# Create the simple ChatInterface with additional inputs for model loading
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are Usta, a geographical knowledge assistant trained from scratch.", 
            label="System message"
        ),
        gr.Slider(minimum=1, maximum=30, value=20, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)"
        ),
        gr.File(label="Upload Model File (.pth)", file_types=[".pth", ".pt"]),
        gr.Textbox(label="Or Model URL", placeholder="https://github.com/user/repo/raw/main/model.pth"),
        gr.Button("Load from File", variant="secondary"),
        gr.Button("Load from URL", variant="secondary"),
        gr.Textbox(label="Model Status", value=model_status, interactive=False)
    ],
    title="πŸ€– Usta Model Chat",
    description="Chat with a custom transformer language model built from scratch! Upload your own model file or provide a URL to load a different model."
)

# Add event handlers after creating the interface
def setup_events():
    # Get the additional inputs
    inputs = demo.additional_inputs
    model_file = inputs[4]  # File upload
    model_url = inputs[5]   # URL input
    load_file_btn = inputs[6]  # Load from file button
    load_url_btn = inputs[7]   # Load from URL button
    status_display = inputs[8]  # Status display
    
    # Set up event handlers
    load_file_btn.click(
        load_model_from_file,
        inputs=[model_file],
        outputs=[status_display]
    )
    
    load_url_btn.click(
        load_model_from_url,
        inputs=[model_url],
        outputs=[status_display]
    )

# Register the event handlers by re-entering the Blocks context; binding them inside a
# demo.load callback would run too late for the click handlers to be wired up
with demo:
    setup_events()

if __name__ == "__main__":
    demo.launch()