Spaces:
Running
Running
space update
Browse files
app.py
CHANGED
@@ -71,7 +71,23 @@ def load_model(custom_model_path=None):
|
|
71 |
|
72 |
if os.path.exists(model_path):
|
73 |
try:
|
74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
u_model.eval()
|
76 |
print("✅ Model weights loaded successfully!")
|
77 |
return u_model, u_tokenizer, f"✅ Model loaded from: {model_path}"
|
@@ -188,15 +204,13 @@ with gr.Blocks(title="🤖 Usta Model Chat", theme=gr.themes.Soft()) as demo:
|
|
188 |
gr.Markdown("### 📁 Model Upload (Optional)")
|
189 |
model_file = gr.File(
|
190 |
label="Upload your own model.pth file",
|
191 |
-
file_types=[".pth", ".pt"]
|
192 |
-
info="Upload a custom UstaModel checkpoint to use instead of the default model"
|
193 |
)
|
194 |
upload_btn = gr.Button("Load Model", variant="primary")
|
195 |
model_status_display = gr.Textbox(
|
196 |
label="Model Status",
|
197 |
value=model_status,
|
198 |
-
interactive=False
|
199 |
-
info="Shows the current model loading status"
|
200 |
)
|
201 |
|
202 |
with gr.Column(scale=1):
|
@@ -205,8 +219,7 @@ with gr.Blocks(title="🤖 Usta Model Chat", theme=gr.themes.Soft()) as demo:
|
|
205 |
gr.Markdown("### ⚙️ Generation Settings")
|
206 |
system_msg = gr.Textbox(
|
207 |
value="You are Usta, a geographical knowledge assistant trained from scratch.",
|
208 |
-
label="System message"
|
209 |
-
info="Note: This model focuses on geographical knowledge"
|
210 |
)
|
211 |
max_tokens = gr.Slider(minimum=1, maximum=30, value=20, step=1, label="Max new tokens")
|
212 |
temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature")
|
@@ -215,8 +228,7 @@ with gr.Blocks(title="🤖 Usta Model Chat", theme=gr.themes.Soft()) as demo:
|
|
215 |
maximum=1.0,
|
216 |
value=0.95,
|
217 |
step=0.05,
|
218 |
-
label="Top-p (nucleus sampling)"
|
219 |
-
info="Note: This parameter is not used by UstaModel"
|
220 |
)
|
221 |
|
222 |
# Chat interface
|
|
|
71 |
|
72 |
if os.path.exists(model_path):
|
73 |
try:
|
74 |
+
state_dict = torch.load(model_path, map_location="cpu", weights_only=False)
|
75 |
+
|
76 |
+
# Handle potential key mapping issues
|
77 |
+
if "embedding.weight" in state_dict and "embedding.embedding.weight" not in state_dict:
|
78 |
+
# Map old key names to new key names
|
79 |
+
new_state_dict = {}
|
80 |
+
for key, value in state_dict.items():
|
81 |
+
if key == "embedding.weight":
|
82 |
+
new_state_dict["embedding.embedding.weight"] = value
|
83 |
+
elif key == "pos_embedding.weight":
|
84 |
+
# Skip positional embedding if not expected
|
85 |
+
continue
|
86 |
+
else:
|
87 |
+
new_state_dict[key] = value
|
88 |
+
state_dict = new_state_dict
|
89 |
+
|
90 |
+
u_model.load_state_dict(state_dict)
|
91 |
u_model.eval()
|
92 |
print("✅ Model weights loaded successfully!")
|
93 |
return u_model, u_tokenizer, f"✅ Model loaded from: {model_path}"
|
|
|
204 |
gr.Markdown("### 📁 Model Upload (Optional)")
|
205 |
model_file = gr.File(
|
206 |
label="Upload your own model.pth file",
|
207 |
+
file_types=[".pth", ".pt"]
|
|
|
208 |
)
|
209 |
upload_btn = gr.Button("Load Model", variant="primary")
|
210 |
model_status_display = gr.Textbox(
|
211 |
label="Model Status",
|
212 |
value=model_status,
|
213 |
+
interactive=False
|
|
|
214 |
)
|
215 |
|
216 |
with gr.Column(scale=1):
|
|
|
219 |
gr.Markdown("### ⚙️ Generation Settings")
|
220 |
system_msg = gr.Textbox(
|
221 |
value="You are Usta, a geographical knowledge assistant trained from scratch.",
|
222 |
+
label="System message"
|
|
|
223 |
)
|
224 |
max_tokens = gr.Slider(minimum=1, maximum=30, value=20, step=1, label="Max new tokens")
|
225 |
temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature")
|
|
|
228 |
maximum=1.0,
|
229 |
value=0.95,
|
230 |
step=0.05,
|
231 |
+
label="Top-p (nucleus sampling)"
|
|
|
232 |
)
|
233 |
|
234 |
# Chat interface
|