alibayram committed
Commit 39dfa2d · 1 Parent(s): 0eefbc1

space update

Files changed (1)
  1. app.py +45 -92
app.py CHANGED
@@ -182,30 +182,12 @@ def load_model_from_file(uploaded_file):
         model_status = error_msg
         return error_msg
 
-def respond(
-    message,
-    history: list[tuple[str, str]],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-    model_file,
-    model_url,
-    load_file_btn,
-    load_url_btn,
-    status_display
-):
-    """
-    Generate a response using the UstaModel
-    """
+def chat_with_usta(message, history, max_tokens=20):
+    """Simple chat function"""
     if model is None or tokenizer is None:
-        yield "Sorry, the UstaModel is not available. Please try again later."
-        return
+        return history + [["Error", "UstaModel is not available. Please try again later."]]
 
     try:
-        # For UstaModel, we'll use the message directly (ignoring system_message for now)
-        # since it's a simpler model focused on geographical knowledge
-
         # Encode the input message
         tokens = tokenizer.encode(message)
 
@@ -215,7 +197,6 @@ def respond(
 
         # Generate response
         with torch.no_grad():
-            # Use max_tokens parameter, but cap it at reasonable limit for this model
             actual_max_tokens = min(max_tokens, 32 - len(tokens))
             generated_tokens = model.generate(tokens, actual_max_tokens)
 
@@ -233,104 +214,76 @@ def respond(
         if not response:
             response = "I'm not sure how to respond to that with my geographical knowledge."
 
-        # Yield the response (to maintain compatibility with streaming interface)
-        yield response
+        # Add to history
+        history.append([message, response])
+        return history
 
     except Exception as e:
-        yield f"Sorry, I encountered an error: {str(e)}"
+        history.append([message, f"Sorry, I encountered an error: {str(e)}"])
+        return history
 
-# Create a Blocks interface to properly handle events
-with gr.Blocks(title="🤖 Usta Model Chat", theme=gr.themes.Soft()) as demo:
+# Create simple interface
+with gr.Blocks(title="🤖 Usta Model Chat") as demo:
     gr.Markdown("# 🤖 Usta Model Chat")
-    gr.Markdown("Chat with a custom transformer language model built from scratch! Upload your own model file or provide a URL to load a different model.")
+    gr.Markdown("Chat with a custom transformer language model built from scratch! This model specializes in geographical knowledge.")
 
-    # Model loading section
-    with gr.Accordion("🔧 Model Loading Options", open=False):
-        with gr.Row():
-            with gr.Column():
-                gr.Markdown("### 📁 Upload Model File")
-                model_file = gr.File(label="Upload Model File (.pth)", file_types=[".pth", ".pt"])
-                load_file_btn = gr.Button("Load from File", variant="primary")
-
-            with gr.Column():
-                gr.Markdown("### 🔗 Load from URL")
-                model_url = gr.Textbox(label="Model URL", placeholder="https://github.com/malibayram/llm-from-scratch/raw/main/u_model_4000.pth")
-                load_url_btn = gr.Button("Load from URL", variant="primary")
-
-        status_display = gr.Textbox(label="Model Status", value=model_status, interactive=False)
+    # Simple chat interface
+    chatbot = gr.Chatbot(height=400)
+    msg = gr.Textbox(label="Your message", placeholder="Ask about countries, capitals, or cities...")
 
-    # Chat interface (simpler version)
-    chatbot = gr.Chatbot(label="Chat", type="messages")
-    msg = gr.Textbox(label="Message", placeholder="Type your message here...")
+    with gr.Row():
+        send_btn = gr.Button("Send", variant="primary")
+        clear_btn = gr.Button("Clear")
 
     # Generation settings
-    with gr.Accordion("⚙️ Generation Settings", open=False):
-        system_msg = gr.Textbox(
-            value="You are Usta, a geographical knowledge assistant trained from scratch.",
-            label="System message"
+    max_tokens = gr.Slider(minimum=1, maximum=30, value=20, step=1, label="Max tokens")
+
+    # Model loading (simplified)
+    gr.Markdown("## 🔧 Load Custom Model (Optional)")
+    with gr.Row():
+        model_url = gr.Textbox(
+            label="Model URL",
+            placeholder="https://github.com/malibayram/llm-from-scratch/raw/main/u_model_4000.pth",
+            scale=3
         )
-        max_tokens = gr.Slider(minimum=1, maximum=30, value=20, step=1, label="Max new tokens")
-        temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature")
-        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p")
+        load_url_btn = gr.Button("Load from URL", scale=1)
 
-    # Button controls
     with gr.Row():
-        submit_btn = gr.Button("Send", variant="primary")
-        clear_btn = gr.Button("Clear Chat", variant="secondary")
+        model_file = gr.File(label="Upload .pth file", file_types=[".pth"])
+        load_file_btn = gr.Button("Load File", scale=1)
+
+    status = gr.Textbox(label="Status", value=model_status, interactive=False)
 
     # Event handlers
-    def chat_respond(message, history, sys_msg, max_tok, temp, top_p_val):
+    def send_message(message, history, max_tok):
        if not message.strip():
            return history, ""
-
-        # Convert messages format for our respond function
-        tuple_history = [(h["role"], h["content"]) for h in history if h["role"] != "system"]
-
-        # Generate response using our existing function
-        response_gen = respond(
-            message, tuple_history, sys_msg, max_tok, temp, top_p_val,
-            None, None, None, None, None # Dummy values for unused params
-        )
-
-        # Get the response
-        response = ""
-        for r in response_gen:
-            response = r
-
-        # Add to history in messages format
-        history.append({"role": "user", "content": message})
-        history.append({"role": "assistant", "content": response})
-
-        return history, ""
+        return chat_with_usta(message, history, max_tok), ""
 
-    # Set up event handlers
-    submit_btn.click(
-        chat_respond,
-        inputs=[msg, chatbot, system_msg, max_tokens, temperature, top_p],
+    send_btn.click(
+        send_message,
+        inputs=[msg, chatbot, max_tokens],
         outputs=[chatbot, msg]
     )
 
     msg.submit(
-        chat_respond,
-        inputs=[msg, chatbot, system_msg, max_tokens, temperature, top_p],
+        send_message,
+        inputs=[msg, chatbot, max_tokens],
         outputs=[chatbot, msg]
    )
 
-    clear_btn.click(
-        lambda: ([], ""),
-        outputs=[chatbot, msg]
+    clear_btn.click(lambda: [], outputs=[chatbot])
+
+    load_url_btn.click(
+        load_model_from_url,
+        inputs=[model_url],
+        outputs=[status]
    )
 
    load_file_btn.click(
        load_model_from_file,
        inputs=[model_file],
-        outputs=[status_display]
-    )
-
-    load_url_btn.click(
-        load_model_from_url,
-        inputs=[model_url],
-        outputs=[status_display]
+        outputs=[status]
    )
 
 if __name__ == "__main__":
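
For reference, the new handler hands gr.Chatbot its history as a plain list of [user, assistant] pairs (Gradio's default pair format), instead of the role/content dicts the removed chat_respond built. Below is a minimal sketch of that flow with the model call stubbed out so it runs on its own; the fake_generate helper is a hypothetical stand-in for the tokenizer/model pipeline and is not part of the app.

# Sketch of the simplified chat flow from this commit, runnable without the model.

def fake_generate(message: str, max_tokens: int = 20) -> str:
    # Hypothetical stand-in for tokenizer.encode -> model.generate -> decode.
    return f"(stub reply to: {message[:max_tokens]})"

def chat_with_usta(message, history, max_tokens=20):
    """Append a [user, assistant] pair and return the updated history."""
    try:
        response = fake_generate(message, max_tokens)
        history.append([message, response])
    except Exception as e:
        history.append([message, f"Sorry, I encountered an error: {str(e)}"])
    return history

def send_message(message, history, max_tok):
    # Mirrors the Gradio event handler: returns (updated history, cleared textbox value).
    if not message.strip():
        return history, ""
    return chat_with_usta(message, history, max_tok), ""

if __name__ == "__main__":
    history, _ = send_message("What is the capital of France?", [], 20)
    print(history)  # [['What is the capital of France?', '(stub reply to: What is the capital )']]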