Spaces:
Running
on
Zero
Running
on
Zero
AbstractPhil
committed on
Commit
·
3248cf5
1
Parent(s):
5d33a3c
peft loading fixed
Browse files
app.py
CHANGED
@@ -41,8 +41,8 @@ except ImportError:
|
|
41 |
# -----------------------
|
42 |
# MX format uses special dtypes - we need to handle this properly
|
43 |
MODEL_ID = os.getenv("MODEL_ID", "openai/gpt-oss-20b")
|
44 |
-
ADAPTER_ID = os.getenv("ADAPTER_ID", "AbstractPhil/mirel-gpt-oss-20b")
|
45 |
-
ADAPTER_SUBFOLDER = os.getenv("ADAPTER_SUBFOLDER", "checkpoints/checkpoint-516")
|
46 |
ATTN_IMPL = os.getenv("ATTN_IMPL", "eager")
|
47 |
SYSTEM_DEF = os.getenv("SYSTEM_PROMPT", "You are Mirel, a memory-stable symbolic assistant.")
|
48 |
MAX_DEF = int(os.getenv("MAX_NEW_TOKENS", "256"))
|
@@ -163,7 +163,7 @@ def convert_fp32_lora_to_mx_compatible(lora_state_dict: Dict[str, torch.Tensor])
|
|
163 |
|
164 |
return converted
|
165 |
|
166 |
-
def prepare_model_for_mx_lora(model, adapter_path: str):
|
167 |
"""
|
168 |
Prepare and attach LoRA adapter to MX format model.
|
169 |
Handles the special requirements of GPT-OSS MX models.
|
@@ -171,24 +171,80 @@ def prepare_model_for_mx_lora(model, adapter_path: str):
|
|
171 |
if not _HAS_PEFT:
|
172 |
raise RuntimeError("PEFT is required for LoRA adapters. Install with: pip install peft")
|
173 |
|
174 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
175 |
|
176 |
-
|
177 |
-
peft_config = PeftConfig.from_pretrained(adapter_path, token=HF_TOKEN)
|
178 |
|
179 |
-
# Load the LoRA weights
|
180 |
from safetensors.torch import load_file
|
181 |
import os.path as osp
|
|
|
182 |
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
188 |
else:
|
189 |
-
|
190 |
-
|
191 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
|
193 |
# Convert weights for MX compatibility
|
194 |
print("[LoRA] Converting fp32 weights for MX format compatibility...")
|
@@ -203,8 +259,7 @@ def prepare_model_for_mx_lora(model, adapter_path: str):
|
|
203 |
model,
|
204 |
adapter_path,
|
205 |
is_trainable=False,
|
206 |
-
token
|
207 |
-
# Don't specify torch_dtype here - let it match the base model
|
208 |
)
|
209 |
|
210 |
# Manually update the adapter weights with our converted versions
|
@@ -285,18 +340,21 @@ def _load_model_on(device_map: Optional[str]) -> AutoModelForCausalLM:
|
|
285 |
if ADAPTER_ID:
|
286 |
try:
|
287 |
if is_mx_model:
|
288 |
-
# Use special MX-compatible LoRA loading
|
289 |
-
model = prepare_model_for_mx_lora(model, ADAPTER_ID)
|
290 |
else:
|
291 |
# Standard PEFT loading for non-MX models
|
292 |
if not _HAS_PEFT:
|
293 |
raise RuntimeError("PEFT is required when ADAPTER_ID is set.")
|
294 |
print(f"[Model] Loading adapter from {ADAPTER_ID} (standard mode)...")
|
|
|
|
|
|
|
|
|
295 |
model = PeftModel.from_pretrained(
|
296 |
model,
|
297 |
-
ADAPTER_ID,
|
298 |
-
|
299 |
-
token=HF_TOKEN
|
300 |
)
|
301 |
|
302 |
print("[Model] Successfully loaded with LoRA adapter")
|
|
|
41 |
# -----------------------
|
42 |
# MX format uses special dtypes - we need to handle this properly
|
43 |
MODEL_ID = os.getenv("MODEL_ID", "openai/gpt-oss-20b")
|
44 |
+
ADAPTER_ID = os.getenv("ADAPTER_ID", "AbstractPhil/mirel-gpt-oss-20b") # Default to your adapter
|
45 |
+
ADAPTER_SUBFOLDER = os.getenv("ADAPTER_SUBFOLDER", "checkpoints/checkpoint-516") # Default to the subfolder
|
46 |
ATTN_IMPL = os.getenv("ATTN_IMPL", "eager")
|
47 |
SYSTEM_DEF = os.getenv("SYSTEM_PROMPT", "You are Mirel, a memory-stable symbolic assistant.")
|
48 |
MAX_DEF = int(os.getenv("MAX_NEW_TOKENS", "256"))
|
|
|
163 |
|
164 |
return converted
|
165 |
|
166 |
+
def prepare_model_for_mx_lora(model, adapter_path: str, subfolder: Optional[str] = None):
|
167 |
"""
|
168 |
Prepare and attach LoRA adapter to MX format model.
|
169 |
Handles the special requirements of GPT-OSS MX models.
|
|
|
171 |
if not _HAS_PEFT:
|
172 |
raise RuntimeError("PEFT is required for LoRA adapters. Install with: pip install peft")
|
173 |
|
174 |
+
# Build the full path including subfolder
|
175 |
+
full_adapter_path = adapter_path
|
176 |
+
if subfolder:
|
177 |
+
print(f"[LoRA] Loading adapter from {adapter_path} (subfolder: {subfolder})")
|
178 |
+
else:
|
179 |
+
print(f"[LoRA] Loading adapter from {adapter_path}")
|
180 |
+
|
181 |
+
# Load the LoRA config with subfolder support
|
182 |
+
peft_kwargs = {"token": HF_TOKEN}
|
183 |
+
if subfolder:
|
184 |
+
peft_kwargs["subfolder"] = subfolder
|
185 |
|
186 |
+
peft_config = PeftConfig.from_pretrained(adapter_path, **peft_kwargs)
|
|
|
187 |
|
188 |
+
# Load the LoRA weights - need to check in the right location
|
189 |
from safetensors.torch import load_file
|
190 |
import os.path as osp
|
191 |
+
from huggingface_hub import hf_hub_download
|
192 |
|
193 |
+
try:
|
194 |
+
# Try to download from HF Hub with subfolder
|
195 |
+
if subfolder:
|
196 |
+
# Download the adapter weights file
|
197 |
+
try:
|
198 |
+
adapter_weights_path = hf_hub_download(
|
199 |
+
repo_id=adapter_path,
|
200 |
+
filename="adapter_model.safetensors",
|
201 |
+
subfolder=subfolder,
|
202 |
+
token=HF_TOKEN
|
203 |
+
)
|
204 |
+
adapter_weights = load_file(adapter_weights_path)
|
205 |
+
print(f"[LoRA] Loaded safetensors weights from {subfolder}")
|
206 |
+
except Exception:
|
207 |
+
# Try .bin format
|
208 |
+
adapter_weights_path = hf_hub_download(
|
209 |
+
repo_id=adapter_path,
|
210 |
+
filename="adapter_model.bin",
|
211 |
+
subfolder=subfolder,
|
212 |
+
token=HF_TOKEN
|
213 |
+
)
|
214 |
+
adapter_weights = torch.load(adapter_weights_path, map_location="cpu")
|
215 |
+
print(f"[LoRA] Loaded bin weights from {subfolder}")
|
216 |
else:
|
217 |
+
# No subfolder - try local path first, then HF Hub
|
218 |
+
local_safetensors = osp.join(adapter_path, "adapter_model.safetensors")
|
219 |
+
local_bin = osp.join(adapter_path, "adapter_model.bin")
|
220 |
+
|
221 |
+
if osp.exists(local_safetensors):
|
222 |
+
adapter_weights = load_file(local_safetensors)
|
223 |
+
print("[LoRA] Loaded local safetensors weights")
|
224 |
+
elif osp.exists(local_bin):
|
225 |
+
adapter_weights = torch.load(local_bin, map_location="cpu")
|
226 |
+
print("[LoRA] Loaded local bin weights")
|
227 |
+
else:
|
228 |
+
# Try downloading from HF Hub
|
229 |
+
try:
|
230 |
+
adapter_weights_path = hf_hub_download(
|
231 |
+
repo_id=adapter_path,
|
232 |
+
filename="adapter_model.safetensors",
|
233 |
+
token=HF_TOKEN
|
234 |
+
)
|
235 |
+
adapter_weights = load_file(adapter_weights_path)
|
236 |
+
print("[LoRA] Downloaded safetensors weights from Hub")
|
237 |
+
except Exception:
|
238 |
+
adapter_weights_path = hf_hub_download(
|
239 |
+
repo_id=adapter_path,
|
240 |
+
filename="adapter_model.bin",
|
241 |
+
token=HF_TOKEN
|
242 |
+
)
|
243 |
+
adapter_weights = torch.load(adapter_weights_path, map_location="cpu")
|
244 |
+
print("[LoRA] Downloaded bin weights from Hub")
|
245 |
+
|
246 |
+
except Exception as e:
|
247 |
+
raise FileNotFoundError(f"Could not load adapter weights: {e}")
|
248 |
|
249 |
# Convert weights for MX compatibility
|
250 |
print("[LoRA] Converting fp32 weights for MX format compatibility...")
|
|
|
259 |
model,
|
260 |
adapter_path,
|
261 |
is_trainable=False,
|
262 |
+
**peft_kwargs # This includes token and subfolder
|
|
|
263 |
)
|
264 |
|
265 |
# Manually update the adapter weights with our converted versions
|
|
|
340 |
if ADAPTER_ID:
|
341 |
try:
|
342 |
if is_mx_model:
|
343 |
+
# Use special MX-compatible LoRA loading with subfolder support
|
344 |
+
model = prepare_model_for_mx_lora(model, ADAPTER_ID, ADAPTER_SUBFOLDER)
|
345 |
else:
|
346 |
# Standard PEFT loading for non-MX models
|
347 |
if not _HAS_PEFT:
|
348 |
raise RuntimeError("PEFT is required when ADAPTER_ID is set.")
|
349 |
print(f"[Model] Loading adapter from {ADAPTER_ID} (standard mode)...")
|
350 |
+
peft_kwargs = {"token": HF_TOKEN, "is_trainable": False}
|
351 |
+
if ADAPTER_SUBFOLDER:
|
352 |
+
peft_kwargs["subfolder"] = ADAPTER_SUBFOLDER
|
353 |
+
print(f"[Model] Using subfolder: {ADAPTER_SUBFOLDER}")
|
354 |
model = PeftModel.from_pretrained(
|
355 |
model,
|
356 |
+
ADAPTER_ID,
|
357 |
+
**peft_kwargs
|
|
|
358 |
)
|
359 |
|
360 |
print("[Model] Successfully loaded with LoRA adapter")
|