import os

import gradio as gr
import torch

from v1.usta_model import UstaModel
from v1.usta_tokenizer import UstaTokenizer


# Load the model and tokenizer
def load_model(custom_model_path=None):
    try:
        u_tokenizer = UstaTokenizer("v1/tokenizer.json")
        print("βœ… Tokenizer loaded successfully! vocab size:", len(u_tokenizer.vocab))
        
        # Model hyperparameters - these must match the architecture of the checkpoint
        # being loaded, or load_state_dict below will fail with size mismatches
        context_length = 32
        vocab_size = len(u_tokenizer.vocab)
        embedding_dim = 12
        num_heads = 4
        num_layers = 8
        
        # Load the model
        u_model = UstaModel(
            vocab_size=vocab_size, 
            embedding_dim=embedding_dim,
            num_heads=num_heads, 
            context_length=context_length, 
            num_layers=num_layers
        )        
        
        # Determine which model file to use
        if custom_model_path and os.path.exists(custom_model_path):
            model_path = custom_model_path
            print(f"🎯 Using uploaded model: {model_path}")
        else:
            model_path = "v1/u_model.pth"
            
            if not os.path.exists(model_path):
                print("❌ Model file not found at", model_path)
                # Download the model file from GitHub
                try:
                    print("πŸ“₯ Downloading model weights from GitHub...")
                    import requests
                    url = "https://github.com/malibayram/llm-from-scratch/raw/main/u_model_4000.pth"
                    
                    headers = {
                        'Accept': 'application/octet-stream',
                        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
                    }
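                    # The Accept header asks for the raw binary; the content check below
                    # still guards against an HTML error page being returned instead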
                    
                    response = requests.get(url, headers=headers)
                    response.raise_for_status()  # Raise an exception for bad status codes
                    
                    # Sanity-check the payload: checkpoints saved by recent torch.save use the
                    # zip format (magic bytes "PK\x03\x04"); an HTML body means GitHub returned
                    # an error page instead of the weights
                    if response.content[:4] != b'PK\x03\x04' and b'<html' in response.content[:100].lower():
                        raise Exception("Downloaded HTML instead of a binary checkpoint - check the URL")
                    
                    print(f"πŸ“¦ Downloaded {len(response.content)} bytes")
                    
                    # Create v1 directory if it doesn't exist
                    os.makedirs("v1", exist_ok=True)
                    
                    # Save the model weights to the local file system
                    with open(model_path, "wb") as f:
                        f.write(response.content)
                    print("βœ… Model weights saved successfully!")
                except Exception as e:
                    print(f"❌ Failed to download model weights: {e}")
                    print("Using random initialization.")

        if os.path.exists(model_path):
            try:
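                # weights_only=False lets torch.load unpickle arbitrary objects;
                # only use it with checkpoint files from a trusted source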
                state_dict = torch.load(model_path, map_location="cpu", weights_only=False)
                
                # Handle potential key mapping issues
                if "embedding.weight" in state_dict and "embedding.embedding.weight" not in state_dict:
                    # Map old key names to new key names
                    new_state_dict = {}
                    for key, value in state_dict.items():
                        if key == "embedding.weight":
                            new_state_dict["embedding.embedding.weight"] = value
                        elif key == "pos_embedding.weight":
                            # Skip positional embedding if not expected
                            continue
                        else:
                            new_state_dict[key] = value
                    state_dict = new_state_dict
                
                u_model.load_state_dict(state_dict)
                u_model.eval()
                print("βœ… Model weights loaded successfully!")
                return u_model, u_tokenizer, f"βœ… Model loaded from: {model_path}"
            except Exception as e:
                print(f"⚠️ Warning: Could not load trained weights: {e}")
                print("Using random initialization.")
                return u_model, u_tokenizer, f"⚠️ Failed to load weights: {e}"
        else:
            print(f"⚠️ Model file not found at {model_path}. Using random initialization.")
            return u_model, u_tokenizer, "⚠️ Using random initialization"
        
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        raise e

# Global model variables
model, tokenizer, model_status = None, None, "Not loaded"

# Initialize model and tokenizer globally
try:
    model, tokenizer, model_status = load_model()
    print("πŸš€ UstaModel and tokenizer initialized successfully!")
except Exception as e:
    print(f"❌ Failed to initialize model: {e}")
    model, tokenizer, model_status = None, None, f"❌ Error: {e}"

def update_model(uploaded_file):
    """Update the model when a new file is uploaded"""
    global model, tokenizer, model_status
    
    if uploaded_file is None:
        return "❌ No file uploaded"
    
    try:
        # Load the new model
        new_model, new_tokenizer, status = load_model(uploaded_file.name)
        
        # Update global variables
        model = new_model
        tokenizer = new_tokenizer
        model_status = status
        
        return status
    except Exception as e:
        error_msg = f"❌ Failed to load uploaded model: {e}"
        model_status = error_msg
        return error_msg

def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """
    Generate a response using the UstaModel
    """
    if model is None or tokenizer is None:
        yield "Sorry, the UstaModel is not available. Please try again later."
        return
        
    try:
        # For UstaModel, we'll use the message directly (ignoring system_message for now)
        # since it's a simpler model focused on geographical knowledge
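        # Note: temperature and top_p are accepted to keep the chat interface signature,
        # but they are not passed to model.generate below.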
        
        # Encode the input message
        tokens = tokenizer.encode(message)
        
        # Keep the prompt well inside the 32-token context so there is room for generation
        if len(tokens) > 25:
            tokens = tokens[-25:]
        
        # Generate response
        with torch.no_grad():
            # Cap max_tokens so that prompt plus generated tokens fit the 32-token context
            actual_max_tokens = min(max_tokens, 32 - len(tokens))
            generated_tokens = model.generate(tokens, actual_max_tokens)
        
        # Decode the generated tokens
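        # (model.generate is assumed to return the full sequence - prompt + new tokens -
        #  which is why the original prompt text is stripped from the response below)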
        response = tokenizer.decode(generated_tokens)
        
        # Clean up the response (remove the original input)
        original_text = tokenizer.decode(tokens.tolist())
        if response.startswith(original_text):
            response = response[len(original_text):]
        
        # Clean up any unwanted tokens
        response = response.replace("<unk>", "").replace("<pad>", "").strip()
        
        if not response:
            response = "I'm not sure how to respond to that with my geographical knowledge."
            
        # Yield the response (to maintain compatibility with streaming interface)
        yield response
        
    except Exception as e:
        yield f"Sorry, I encountered an error: {str(e)}"

"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""

# Create the interface with file upload
with gr.Blocks(title="πŸ€– Usta Model Chat", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# πŸ€– Usta Model Chat")
    gr.Markdown("Chat with a custom transformer language model built from scratch! This model specializes in geographical knowledge including countries, capitals, and cities.")
    
    with gr.Row():
        with gr.Column(scale=2):
            # Model upload section
            with gr.Group():
                gr.Markdown("### πŸ“ Model Upload (Optional)")
                model_file = gr.File(
                    label="Upload your own model.pth file",
                    file_types=[".pth", ".pt"]
                )
                upload_btn = gr.Button("Load Model", variant="primary")
                model_status_display = gr.Textbox(
                    label="Model Status",
                    value=model_status,
                    interactive=False
                )
        
        with gr.Column(scale=1):
            # Settings
            with gr.Group():
                gr.Markdown("### βš™οΈ Generation Settings")
                system_msg = gr.Textbox(
                    value="You are Usta, a geographical knowledge assistant trained from scratch.", 
                    label="System message"
                )
                max_tokens = gr.Slider(minimum=1, maximum=30, value=20, step=1, label="Max new tokens")
                temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature")
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.95,
                    step=0.05,
                    label="Top-p (nucleus sampling)"
                )
    
    # Chat interface
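    # additional_inputs are forwarded to respond after (message, history),
    # in the order listed: system_msg, max_tokens, temperature, top_p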
    chatbot = gr.ChatInterface(
        respond,
        additional_inputs=[system_msg, max_tokens, temperature, top_p],
        chatbot=gr.Chatbot(height=400),
        title=None,  # We already have title above
        description=None  # We already have description above
    )
    
    # Event handlers
    upload_btn.click(
        update_model,
        inputs=[model_file],
        outputs=[model_status_display]
    )

if __name__ == "__main__":
    demo.launch()