alibayram committed on
Commit
ea11d44
·
1 Parent(s): c700703

space update

Browse files
Files changed (1) hide show
  1. app.py +21 -9
app.py CHANGED
@@ -71,7 +71,23 @@ def load_model(custom_model_path=None):
71
 
72
  if os.path.exists(model_path):
73
  try:
74
- u_model.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=False))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  u_model.eval()
76
  print("✅ Model weights loaded successfully!")
77
  return u_model, u_tokenizer, f"✅ Model loaded from: {model_path}"
@@ -188,15 +204,13 @@ with gr.Blocks(title="🤖 Usta Model Chat", theme=gr.themes.Soft()) as demo:
188
  gr.Markdown("### 📁 Model Upload (Optional)")
189
  model_file = gr.File(
190
  label="Upload your own model.pth file",
191
- file_types=[".pth", ".pt"],
192
- info="Upload a custom UstaModel checkpoint to use instead of the default model"
193
  )
194
  upload_btn = gr.Button("Load Model", variant="primary")
195
  model_status_display = gr.Textbox(
196
  label="Model Status",
197
  value=model_status,
198
- interactive=False,
199
- info="Shows the current model loading status"
200
  )
201
 
202
  with gr.Column(scale=1):
@@ -205,8 +219,7 @@ with gr.Blocks(title="🤖 Usta Model Chat", theme=gr.themes.Soft()) as demo:
205
  gr.Markdown("### ⚙️ Generation Settings")
206
  system_msg = gr.Textbox(
207
  value="You are Usta, a geographical knowledge assistant trained from scratch.",
208
- label="System message",
209
- info="Note: This model focuses on geographical knowledge"
210
  )
211
  max_tokens = gr.Slider(minimum=1, maximum=30, value=20, step=1, label="Max new tokens")
212
  temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature")
@@ -215,8 +228,7 @@ with gr.Blocks(title="🤖 Usta Model Chat", theme=gr.themes.Soft()) as demo:
215
  maximum=1.0,
216
  value=0.95,
217
  step=0.05,
218
- label="Top-p (nucleus sampling)",
219
- info="Note: This parameter is not used by UstaModel"
220
  )
221
 
222
  # Chat interface
 
71
 
72
  if os.path.exists(model_path):
73
  try:
74
+ state_dict = torch.load(model_path, map_location="cpu", weights_only=False)
75
+
76
+ # Handle potential key mapping issues
77
+ if "embedding.weight" in state_dict and "embedding.embedding.weight" not in state_dict:
78
+ # Map old key names to new key names
79
+ new_state_dict = {}
80
+ for key, value in state_dict.items():
81
+ if key == "embedding.weight":
82
+ new_state_dict["embedding.embedding.weight"] = value
83
+ elif key == "pos_embedding.weight":
84
+ # Skip positional embedding if not expected
85
+ continue
86
+ else:
87
+ new_state_dict[key] = value
88
+ state_dict = new_state_dict
89
+
90
+ u_model.load_state_dict(state_dict)
91
  u_model.eval()
92
  print("✅ Model weights loaded successfully!")
93
  return u_model, u_tokenizer, f"✅ Model loaded from: {model_path}"
 
204
  gr.Markdown("### 📁 Model Upload (Optional)")
205
  model_file = gr.File(
206
  label="Upload your own model.pth file",
207
+ file_types=[".pth", ".pt"]
 
208
  )
209
  upload_btn = gr.Button("Load Model", variant="primary")
210
  model_status_display = gr.Textbox(
211
  label="Model Status",
212
  value=model_status,
213
+ interactive=False
 
214
  )
215
 
216
  with gr.Column(scale=1):
 
219
  gr.Markdown("### ⚙️ Generation Settings")
220
  system_msg = gr.Textbox(
221
  value="You are Usta, a geographical knowledge assistant trained from scratch.",
222
+ label="System message"
 
223
  )
224
  max_tokens = gr.Slider(minimum=1, maximum=30, value=20, step=1, label="Max new tokens")
225
  temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature")
 
228
  maximum=1.0,
229
  value=0.95,
230
  step=0.05,
231
+ label="Top-p (nucleus sampling)"
 
232
  )
233
 
234
  # Chat interface