alibayram committed
Commit c700703 · 1 Parent(s): db97ce9

space update

Files changed (1):
  app.py +129 -62
app.py CHANGED
@@ -8,7 +8,7 @@ from v1.usta_tokenizer import UstaTokenizer
 
 
 # Load the model and tokenizer
-def load_model():
+def load_model(custom_model_path=None):
     try:
         u_tokenizer = UstaTokenizer("v1/tokenizer.json")
         print("✅ Tokenizer loaded successfully! vocab size:", len(u_tokenizer.vocab))
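Aside (not part of the commit): a minimal usage sketch of the new optional parameter. The three-part return value comes from the changes in the next hunk, and the checkpoint path here is hypothetical.

    # Both call forms work after this change.
    model, tok, status = load_model()                      # default weights, downloaded if missing
    model, tok, status = load_model("uploads/custom.pth")  # hypothetical caller-supplied checkpoint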
@@ -29,66 +29,96 @@ def load_model():
             num_layers=num_layers
         )
 
-        # Load the trained weights if available
-        model_path = "v1/u_model.pth"
-
-        if not os.path.exists(model_path):
-            print("❌ Model file not found at", model_path)
-            # Download the model file from GitHub
-            try:
-                print("📥 Downloading model weights from GitHub...")
-                import requests
-                url = "https://github.com/malibayram/llm-from-scratch/raw/main/u_model_4000.pth"
-
-                headers = {
-                    'Accept': 'application/octet-stream',
-                    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
-                }
-
-                response = requests.get(url, headers=headers)
-                response.raise_for_status()  # Raise an exception for bad status codes
-
-                # Check if we got a proper binary file (PyTorch files start with specific bytes)
-                if response.content[:4] != b'PK\x03\x04' and b'<html' in response.content[:100].lower():
-                    raise Exception("Downloaded HTML instead of binary file - check URL")
-
-                print(f"📦 Downloaded {len(response.content)} bytes")
-
-                # Create v1 directory if it doesn't exist
-                os.makedirs("v1", exist_ok=True)
-
-                # Save the model weights to the local file system
-                with open(model_path, "wb") as f:
-                    f.write(response.content)
-                print("✅ Model weights saved successfully!")
-            except Exception as e:
-                print(f"❌ Failed to download model weights: {e}")
-                print("Using random initialization.")
+        # Determine which model file to use
+        if custom_model_path and os.path.exists(custom_model_path):
+            model_path = custom_model_path
+            print(f"🎯 Using uploaded model: {model_path}")
+        else:
+            model_path = "v1/u_model.pth"
+
+            if not os.path.exists(model_path):
+                print("❌ Model file not found at", model_path)
+                # Download the model file from GitHub
+                try:
+                    print("📥 Downloading model weights from GitHub...")
+                    import requests
+                    url = "https://github.com/malibayram/llm-from-scratch/raw/main/u_model_4000.pth"
+
+                    headers = {
+                        'Accept': 'application/octet-stream',
+                        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
+                    }
+
+                    response = requests.get(url, headers=headers)
+                    response.raise_for_status()  # Raise an exception for bad status codes
+
+                    # Check if we got a proper binary file (PyTorch files start with specific bytes)
+                    if response.content[:4] != b'PK\x03\x04' and b'<html' in response.content[:100].lower():
+                        raise Exception("Downloaded HTML instead of binary file - check URL")
+
+                    print(f"📦 Downloaded {len(response.content)} bytes")
+
+                    # Create v1 directory if it doesn't exist
+                    os.makedirs("v1", exist_ok=True)
+
+                    # Save the model weights to the local file system
+                    with open(model_path, "wb") as f:
+                        f.write(response.content)
+                    print("✅ Model weights saved successfully!")
+                except Exception as e:
+                    print(f"❌ Failed to download model weights: {e}")
+                    print("Using random initialization.")
 
         if os.path.exists(model_path):
             try:
                 u_model.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=False))
                 u_model.eval()
                 print("✅ Model weights loaded successfully!")
+                return u_model, u_tokenizer, f"✅ Model loaded from: {model_path}"
             except Exception as e:
                 print(f"⚠️ Warning: Could not load trained weights: {e}")
                 print("Using random initialization.")
+                return u_model, u_tokenizer, f"⚠️ Failed to load weights: {e}"
         else:
             print(f"⚠️ Model file not found at {model_path}. Using random initialization.")
-
-        return u_model, u_tokenizer
+            return u_model, u_tokenizer, "⚠️ Using random initialization"
 
     except Exception as e:
         print(f"❌ Error loading model: {e}")
         raise e
 
+# Global model variables
+model, tokenizer, model_status = None, None, "Not loaded"
+
 # Initialize model and tokenizer globally
 try:
-    model, tokenizer = load_model()
+    model, tokenizer, model_status = load_model()
     print("🚀 UstaModel and tokenizer initialized successfully!")
 except Exception as e:
     print(f"❌ Failed to initialize model: {e}")
-    model, tokenizer = None, None
+    model, tokenizer, model_status = None, None, f"❌ Error: {e}"
+
+def update_model(uploaded_file):
+    """Update the model when a new file is uploaded"""
+    global model, tokenizer, model_status
+
+    if uploaded_file is None:
+        return "❌ No file uploaded"
+
+    try:
+        # Load the new model
+        new_model, new_tokenizer, status = load_model(uploaded_file.name)
+
+        # Update global variables
+        model = new_model
+        tokenizer = new_tokenizer
+        model_status = status
+
+        return status
+    except Exception as e:
+        error_msg = f"❌ Failed to load uploaded model: {e}"
+        model_status = error_msg
+        return error_msg
 
 def respond(
     message,
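Aside (not part of the commit): the download guard above checks for b'PK\x03\x04' because zip-format PyTorch checkpoints (the default serialization format since PyTorch 1.6) are ordinary ZIP archives, so an HTML error page fails the test. A minimal sketch of the same validity check using the standard library; the helper name is illustrative, not from the repo.

    import zipfile

    def looks_like_torch_zip_checkpoint(path):
        # Zip-format .pth files are regular ZIP archives, so zipfile.is_zipfile
        # is a stricter test than comparing the first four bytes by hand.
        return zipfile.is_zipfile(path)

A check like this is worth running on user uploads in particular, since torch.load with weights_only=False unpickles arbitrary objects from the file.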
@@ -145,28 +175,65 @@ def respond(
 """
 For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
 """
-demo = gr.ChatInterface(
-    respond,
-    additional_inputs=[
-        gr.Textbox(
-            value="You are Usta, a geographical knowledge assistant trained from scratch.",
-            label="System message",
-            info="Note: This model focuses on geographical knowledge (countries, capitals, cities)"
-        ),
-        gr.Slider(minimum=1, maximum=30, value=20, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-            info="Note: This parameter is not used by UstaModel but kept for interface compatibility"
-        ),
-    ],
-    title="🤖 Usta Model Chat",
-    description="Chat with a custom transformer language model built from scratch! This model specializes in geographical knowledge including countries, capitals, and cities."
-)
+
+# Create the interface with file upload
+with gr.Blocks(title="🤖 Usta Model Chat", theme=gr.themes.Soft()) as demo:
+    gr.Markdown("# 🤖 Usta Model Chat")
+    gr.Markdown("Chat with a custom transformer language model built from scratch! This model specializes in geographical knowledge including countries, capitals, and cities.")
+
+    with gr.Row():
+        with gr.Column(scale=2):
+            # Model upload section
+            with gr.Group():
+                gr.Markdown("### 📁 Model Upload (Optional)")
+                model_file = gr.File(
+                    label="Upload your own model.pth file",
+                    file_types=[".pth", ".pt"],
+                    info="Upload a custom UstaModel checkpoint to use instead of the default model"
+                )
+                upload_btn = gr.Button("Load Model", variant="primary")
+                model_status_display = gr.Textbox(
+                    label="Model Status",
+                    value=model_status,
+                    interactive=False,
+                    info="Shows the current model loading status"
+                )
+
+        with gr.Column(scale=1):
+            # Settings
+            with gr.Group():
+                gr.Markdown("### ⚙️ Generation Settings")
+                system_msg = gr.Textbox(
+                    value="You are Usta, a geographical knowledge assistant trained from scratch.",
+                    label="System message",
+                    info="Note: This model focuses on geographical knowledge"
+                )
+                max_tokens = gr.Slider(minimum=1, maximum=30, value=20, step=1, label="Max new tokens")
+                temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature")
+                top_p = gr.Slider(
+                    minimum=0.1,
+                    maximum=1.0,
+                    value=0.95,
+                    step=0.05,
+                    label="Top-p (nucleus sampling)",
+                    info="Note: This parameter is not used by UstaModel"
+                )
+
+    # Chat interface
+    chatbot = gr.ChatInterface(
+        respond,
+        additional_inputs=[system_msg, max_tokens, temperature, top_p],
+        chatbot=gr.Chatbot(height=400),
+        title=None,  # We already have title above
+        description=None  # We already have description above
+    )
+
+    # Event handlers
+    upload_btn.click(
+        update_model,
+        inputs=[model_file],
+        outputs=[model_status_display]
+    )
 
 if __name__ == "__main__":
     demo.launch()
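Aside (not part of the commit): the upload wiring above reduces to a gr.File feeding a callback whose returned string lands in a status gr.Textbox. A stripped-down sketch of that pattern, assuming a recent Gradio API; component names are illustrative.

    import gradio as gr

    def on_upload(f):
        # Gradio passes a file wrapper exposing a .name path, or None.
        return "no file uploaded" if f is None else f"loaded: {f.name}"

    with gr.Blocks() as wiring_sketch:
        file_in = gr.File(label="Checkpoint", file_types=[".pth", ".pt"])
        status = gr.Textbox(label="Status", interactive=False)
        gr.Button("Load").click(on_upload, inputs=[file_in], outputs=[status])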
 