import os

import gradio as gr
import torch

from v2.usta_model import UstaModel
from v2.usta_tokenizer import UstaTokenizer
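# NOTE: assumes the v2/ package (UstaModel, UstaTokenizer and tokenizer.json)
# from the llm-from-scratch repo is available alongside this script.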


# Load the model and tokenizer
def load_model(custom_model_path=None):
    try:
        u_tokenizer = UstaTokenizer("v2/tokenizer.json")
        print("βœ… Tokenizer loaded successfully! vocab size:", len(u_tokenizer.vocab))
        
        # Model parameters - adjust these to match your trained model
        context_length = 32
        vocab_size = len(u_tokenizer.vocab)
        embedding_dim = 12
        num_heads = 4
        num_layers = 8
        device = "cpu"  # Use CPU for compatibility
        
        # Load the model
        u_model = UstaModel(
            vocab_size=vocab_size, 
            embedding_dim=embedding_dim,
            num_heads=num_heads, 
            context_length=context_length, 
            num_layers=num_layers,
            device=device
        )
        
        # Determine which model file to use
        if custom_model_path and os.path.exists(custom_model_path):
            model_path = custom_model_path
            print(f"🎯 Using uploaded model: {model_path}")
        else:
            model_path = "v2/u_model_4000.pth"
            
            if not os.path.exists(model_path):
                print("❌ Model file not found at", model_path)
                # Download the model file from GitHub
                try:
                    print("πŸ“₯ Downloading model weights from GitHub...")
                    import requests
                    url = "https://github.com/malibayram/llm-from-scratch/raw/main/u_model_4000.pth"
                    
                    headers = {
                        'Accept': 'application/octet-stream',
                        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
                    }
                    
                    response = requests.get(url, headers=headers, timeout=60)
                    response.raise_for_status()  # Raise an exception for bad status codes
                    
                    # Sanity check: modern torch.save files are ZIP archives (magic b'PK\x03\x04');
                    # if the payload isn't ZIP and looks like HTML, the URL returned an error page
                    if response.content[:4] != b'PK\x03\x04' and b'<html' in response.content[:100].lower():
                        raise Exception("Downloaded HTML instead of binary file - check URL")
                    
                    print(f"πŸ“¦ Downloaded {len(response.content)} bytes")
                    
                    # Create v2 directory if it doesn't exist
                    os.makedirs("v2", exist_ok=True)
                    
                    # Save the model weights to the local file system
                    with open(model_path, "wb") as f:
                        f.write(response.content)
                    print("βœ… Model weights saved successfully!")
                except Exception as e:
                    print(f"❌ Failed to download model weights: {e}")
                    print("Using random initialization.")

        if os.path.exists(model_path):
            try:
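                # NOTE: weights_only=False unpickles arbitrary Python objects,
                # so only load checkpoints from trusted sources; recent PyTorch
                # versions accept weights_only=True for plain state dicts.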
                state_dict = torch.load(model_path, map_location="cpu", weights_only=False)
                
                # Handle checkpoints saved with an older parameter naming scheme
                if "embedding.weight" in state_dict and "embedding.embedding.weight" not in state_dict:
                    # Map old key names to new key names
                    new_state_dict = {}
                    for key, value in state_dict.items():
                        if key == "embedding.weight":
                            new_state_dict["embedding.embedding.weight"] = value
                        elif key == "pos_embedding.weight":
                            # Skip positional embedding if not expected
                            continue
                        else:
                            new_state_dict[key] = value
                    state_dict = new_state_dict
                
                u_model.load_state_dict(state_dict)
                u_model.eval()
                print("βœ… Model weights loaded successfully!")
                return u_model, u_tokenizer, f"βœ… Model loaded from: {model_path}"
            except Exception as e:
                print(f"⚠️ Warning: Could not load trained weights: {e}")
                print("Using random initialization.")
                return u_model, u_tokenizer, f"⚠️ Failed to load weights: {e}"
        else:
            print(f"⚠️ Model file not found at {model_path}. Using random initialization.")
            return u_model, u_tokenizer, "⚠️ Using random initialization"
        
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        raise

# Global model variables
model, tokenizer, model_status = None, None, "Not loaded"

# Initialize model and tokenizer globally
try:
    model, tokenizer, model_status = load_model()
    print("πŸš€ UstaModel and tokenizer initialized successfully!")
except Exception as e:
    print(f"❌ Failed to initialize model: {e}")
    model, tokenizer, model_status = None, None, f"❌ Error: {e}"

def load_model_from_url(url):
    """Load model from a URL"""
    global model, tokenizer, model_status
    
    if not url.strip():
        return "❌ Please provide a URL"
    
    try:
        print(f"πŸ“₯ Downloading model from URL: {url}")
        import requests
        
        headers = {
            'Accept': 'application/octet-stream',
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
        }
        
        response = requests.get(url, headers=headers, timeout=60)
        response.raise_for_status()
        
        # Check if we got a proper binary file
        if response.content[:4] != b'PK\x03\x04' and b'<html' in response.content[:100].lower():
            return "❌ Downloaded HTML instead of binary file - check URL"
        
        # Save temporary file
        temp_path = "temp_model.pth"
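        # NOTE: a fixed filename keeps this simple; tempfile.mkstemp would
        # avoid collisions if two requests upload at the same time.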
        with open(temp_path, "wb") as f:
            f.write(response.content)
        
        # Load the model
        new_model, new_tokenizer, status = load_model(temp_path)
        
        # Update global variables
        model = new_model
        tokenizer = new_tokenizer
        model_status = status
        
        # Clean up temp file
        if os.path.exists(temp_path):
            os.remove(temp_path)
        
        return status
    except Exception as e:
        error_msg = f"❌ Failed to load model from URL: {e}"
        model_status = error_msg
        return error_msg

def load_model_from_file(uploaded_file):
    """Load model from uploaded file"""
    global model, tokenizer, model_status
    
    if uploaded_file is None:
        return "❌ No file uploaded"
    
    try:
        # Check if the file path exists and is valid
        file_path = uploaded_file.name if hasattr(uploaded_file, 'name') else str(uploaded_file)
        
        # For HF Spaces compatibility, also try the upload path
        if not os.path.exists(file_path) and hasattr(uploaded_file, 'orig_name'):
            # Sometimes HF Spaces provides different paths
            print(f"Original path not found: {file_path}")
            print(f"Trying original name: {uploaded_file.orig_name}")
            file_path = uploaded_file.orig_name
        
        print(f"πŸ“ Attempting to load model from: {file_path}")
        
        # Load the new model
        new_model, new_tokenizer, status = load_model(file_path)
        
        # Update global variables
        model = new_model
        tokenizer = new_tokenizer
        model_status = status
        
        return status
    except Exception as e:
        error_msg = f"❌ Failed to load uploaded model: {e}"
        print(f"Error details: {e}")
        model_status = error_msg
        return error_msg

def chat_with_usta(message, history, max_tokens=20, temperature=1.0, top_k=64, top_p=1.0):
    """Simple chat function"""
    if model is None or tokenizer is None:
        return history + [["Error", "UstaModel is not available. Please try again later."]]
        
    try:
        # Encode the input message
        tokens = tokenizer.encode(message)
        
        # Make sure we don't exceed context length
        if len(tokens) > 25:  # keep at most 25 of the 32 context slots for the prompt
            tokens = tokens[-25:]
        
        # Generate response
        with torch.no_grad():
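            # Cap generation so prompt + new tokens fit the 32-token context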
            actual_max_tokens = min(max_tokens, 32 - len(tokens))
            generated_tokens = model.generate(
                tokens, 
                max_new_tokens=actual_max_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p
            )
        
        # Decode the generated tokens
        response = tokenizer.decode(generated_tokens)
        
        # Clean up the response (remove the original input)
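        # NOTE: .tolist() assumes tokenizer.encode returned a tensor; if it
        # returns a plain list, pass tokens to decode directly instead.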
        original_text = tokenizer.decode(tokens.tolist())
        if response.startswith(original_text):
            response = response[len(original_text):]
        
        # Clean up any unwanted tokens
        response = response.replace("<unk>", "").replace("<pad>", "").strip()
        
        if not response:
            response = "I'm not sure how to respond to that with my geographical knowledge."
            
        # Add to history
        history.append([message, response])
        return history
        
    except Exception as e:
        history.append([message, f"Sorry, I encountered an error: {str(e)}"])
        return history

# Create simple interface
with gr.Blocks(title="πŸ€– Usta Model Chat") as demo:
    gr.Markdown("# πŸ€– Usta Model Chat")
    gr.Markdown("Chat with a custom transformer language model built from scratch! This model specializes in geographical knowledge.")
    
    # Simple chat interface
    chatbot = gr.Chatbot(height=400)
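    # NOTE: history uses the classic list-of-[user, assistant] pairs format;
    # newer Gradio releases may expect gr.Chatbot(type="messages") instead.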
    msg = gr.Textbox(label="Your message", placeholder="Ask about countries, capitals, or cities...")
    
    with gr.Row():
        send_btn = gr.Button("Send", variant="primary")
        clear_btn = gr.Button("Clear")
    
    # Generation settings
    gr.Markdown("## βš™οΈ Generation Settings")
    with gr.Row():
        max_tokens = gr.Slider(minimum=1, maximum=30, value=20, step=1, label="Max tokens")
        temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature")
    
    with gr.Row():
        top_k = gr.Slider(minimum=1, maximum=64, value=40, step=1, label="Top-k")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=1.0, step=0.05, label="Top-p (nucleus sampling)")
    
    # Model loading (simplified)
    gr.Markdown("## πŸ”§ Load Custom Model (Optional)")
    with gr.Row():
        model_url = gr.Textbox(
            label="Model URL", 
            placeholder="https://github.com/malibayram/llm-from-scratch/raw/main/u_model_4000.pth",
            scale=3
        )
        load_url_btn = gr.Button("Load from URL", scale=1)
    
    with gr.Row():
        model_file = gr.File(label="Upload model file (.pth, .pt, .bin)")
        load_file_btn = gr.Button("Load File", scale=1)
    
    status = gr.Textbox(label="Status", value=model_status, interactive=False)
    
    # Event handlers
    def send_message(message, history, max_tok, temp, k, p):
        if not message.strip():
            return history, ""
        return chat_with_usta(message, history, max_tok, temp, k, p), ""
    
    send_btn.click(
        send_message,
        inputs=[msg, chatbot, max_tokens, temperature, top_k, top_p],
        outputs=[chatbot, msg]
    )
    
    msg.submit(
        send_message,
        inputs=[msg, chatbot, max_tokens, temperature, top_k, top_p],
        outputs=[chatbot, msg]
    )
    
    clear_btn.click(lambda: [], outputs=[chatbot])
    
    load_url_btn.click(
        load_model_from_url,
        inputs=[model_url],
        outputs=[status]
    )
    
    load_file_btn.click(
        load_model_from_file,
        inputs=[model_file],
        outputs=[status]
    )

if __name__ == "__main__":
    demo.launch()