AbstractPhil committed on
Commit
3248cf5
·
1 Parent(s): 5d33a3c

peft loading fixed

Browse files
Files changed (1) hide show
  1. app.py +80 -22
app.py CHANGED
@@ -41,8 +41,8 @@ except ImportError:
41
  # -----------------------
42
  # MX format uses special dtypes - we need to handle this properly
43
  MODEL_ID = os.getenv("MODEL_ID", "openai/gpt-oss-20b")
44
- ADAPTER_ID = os.getenv("ADAPTER_ID", "AbstractPhil/mirel-gpt-oss-20b") or None
45
- ADAPTER_SUBFOLDER = os.getenv("ADAPTER_SUBFOLDER", "checkpoints/checkpoint-516") or None
46
  ATTN_IMPL = os.getenv("ATTN_IMPL", "eager")
47
  SYSTEM_DEF = os.getenv("SYSTEM_PROMPT", "You are Mirel, a memory-stable symbolic assistant.")
48
  MAX_DEF = int(os.getenv("MAX_NEW_TOKENS", "256"))
@@ -163,7 +163,7 @@ def convert_fp32_lora_to_mx_compatible(lora_state_dict: Dict[str, torch.Tensor])
163
 
164
  return converted
165
 
166
- def prepare_model_for_mx_lora(model, adapter_path: str):
167
  """
168
  Prepare and attach LoRA adapter to MX format model.
169
  Handles the special requirements of GPT-OSS MX models.
@@ -171,24 +171,80 @@ def prepare_model_for_mx_lora(model, adapter_path: str):
171
  if not _HAS_PEFT:
172
  raise RuntimeError("PEFT is required for LoRA adapters. Install with: pip install peft")
173
 
174
- print(f"[LoRA] Loading adapter from {adapter_path}")
 
 
 
 
 
 
 
 
 
 
175
 
176
- # Load the LoRA config
177
- peft_config = PeftConfig.from_pretrained(adapter_path, token=HF_TOKEN)
178
 
179
- # Load the LoRA weights
180
  from safetensors.torch import load_file
181
  import os.path as osp
 
182
 
183
- adapter_weights_path = osp.join(adapter_path, "adapter_model.safetensors")
184
- if not osp.exists(adapter_weights_path):
185
- adapter_weights_path = osp.join(adapter_path, "adapter_model.bin")
186
- if osp.exists(adapter_weights_path):
187
- adapter_weights = torch.load(adapter_weights_path, map_location="cpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  else:
189
- raise FileNotFoundError(f"No adapter weights found at {adapter_path}")
190
- else:
191
- adapter_weights = load_file(adapter_weights_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
 
193
  # Convert weights for MX compatibility
194
  print("[LoRA] Converting fp32 weights for MX format compatibility...")
@@ -203,8 +259,7 @@ def prepare_model_for_mx_lora(model, adapter_path: str):
203
  model,
204
  adapter_path,
205
  is_trainable=False,
206
- token=HF_TOKEN,
207
- # Don't specify torch_dtype here - let it match the base model
208
  )
209
 
210
  # Manually update the adapter weights with our converted versions
@@ -285,18 +340,21 @@ def _load_model_on(device_map: Optional[str]) -> AutoModelForCausalLM:
285
  if ADAPTER_ID:
286
  try:
287
  if is_mx_model:
288
- # Use special MX-compatible LoRA loading
289
- model = prepare_model_for_mx_lora(model, ADAPTER_ID)
290
  else:
291
  # Standard PEFT loading for non-MX models
292
  if not _HAS_PEFT:
293
  raise RuntimeError("PEFT is required when ADAPTER_ID is set.")
294
  print(f"[Model] Loading adapter from {ADAPTER_ID} (standard mode)...")
 
 
 
 
295
  model = PeftModel.from_pretrained(
296
  model,
297
- ADAPTER_ID,
298
- is_trainable=False,
299
- token=HF_TOKEN
300
  )
301
 
302
  print("[Model] Successfully loaded with LoRA adapter")
 
41
  # -----------------------
42
  # MX format uses special dtypes - we need to handle this properly
43
  MODEL_ID = os.getenv("MODEL_ID", "openai/gpt-oss-20b")
44
+ ADAPTER_ID = os.getenv("ADAPTER_ID", "AbstractPhil/mirel-gpt-oss-20b") # Default to your adapter
45
+ ADAPTER_SUBFOLDER = os.getenv("ADAPTER_SUBFOLDER", "checkpoints/checkpoint-516") # Default to the subfolder
46
  ATTN_IMPL = os.getenv("ATTN_IMPL", "eager")
47
  SYSTEM_DEF = os.getenv("SYSTEM_PROMPT", "You are Mirel, a memory-stable symbolic assistant.")
48
  MAX_DEF = int(os.getenv("MAX_NEW_TOKENS", "256"))
 
163
 
164
  return converted
165
 
166
+ def prepare_model_for_mx_lora(model, adapter_path: str, subfolder: Optional[str] = None):
167
  """
168
  Prepare and attach LoRA adapter to MX format model.
169
  Handles the special requirements of GPT-OSS MX models.
 
171
  if not _HAS_PEFT:
172
  raise RuntimeError("PEFT is required for LoRA adapters. Install with: pip install peft")
173
 
174
+ # Build the full path including subfolder
175
+ full_adapter_path = adapter_path
176
+ if subfolder:
177
+ print(f"[LoRA] Loading adapter from {adapter_path} (subfolder: {subfolder})")
178
+ else:
179
+ print(f"[LoRA] Loading adapter from {adapter_path}")
180
+
181
+ # Load the LoRA config with subfolder support
182
+ peft_kwargs = {"token": HF_TOKEN}
183
+ if subfolder:
184
+ peft_kwargs["subfolder"] = subfolder
185
 
186
+ peft_config = PeftConfig.from_pretrained(adapter_path, **peft_kwargs)
 
187
 
188
+ # Load the LoRA weights - need to check in the right location
189
  from safetensors.torch import load_file
190
  import os.path as osp
191
+ from huggingface_hub import hf_hub_download
192
 
193
+ try:
194
+ # Try to download from HF Hub with subfolder
195
+ if subfolder:
196
+ # Download the adapter weights file
197
+ try:
198
+ adapter_weights_path = hf_hub_download(
199
+ repo_id=adapter_path,
200
+ filename="adapter_model.safetensors",
201
+ subfolder=subfolder,
202
+ token=HF_TOKEN
203
+ )
204
+ adapter_weights = load_file(adapter_weights_path)
205
+ print(f"[LoRA] Loaded safetensors weights from {subfolder}")
206
+ except Exception:
207
+ # Try .bin format
208
+ adapter_weights_path = hf_hub_download(
209
+ repo_id=adapter_path,
210
+ filename="adapter_model.bin",
211
+ subfolder=subfolder,
212
+ token=HF_TOKEN
213
+ )
214
+ adapter_weights = torch.load(adapter_weights_path, map_location="cpu")
215
+ print(f"[LoRA] Loaded bin weights from {subfolder}")
216
  else:
217
+ # No subfolder - try local path first, then HF Hub
218
+ local_safetensors = osp.join(adapter_path, "adapter_model.safetensors")
219
+ local_bin = osp.join(adapter_path, "adapter_model.bin")
220
+
221
+ if osp.exists(local_safetensors):
222
+ adapter_weights = load_file(local_safetensors)
223
+ print("[LoRA] Loaded local safetensors weights")
224
+ elif osp.exists(local_bin):
225
+ adapter_weights = torch.load(local_bin, map_location="cpu")
226
+ print("[LoRA] Loaded local bin weights")
227
+ else:
228
+ # Try downloading from HF Hub
229
+ try:
230
+ adapter_weights_path = hf_hub_download(
231
+ repo_id=adapter_path,
232
+ filename="adapter_model.safetensors",
233
+ token=HF_TOKEN
234
+ )
235
+ adapter_weights = load_file(adapter_weights_path)
236
+ print("[LoRA] Downloaded safetensors weights from Hub")
237
+ except Exception:
238
+ adapter_weights_path = hf_hub_download(
239
+ repo_id=adapter_path,
240
+ filename="adapter_model.bin",
241
+ token=HF_TOKEN
242
+ )
243
+ adapter_weights = torch.load(adapter_weights_path, map_location="cpu")
244
+ print("[LoRA] Downloaded bin weights from Hub")
245
+
246
+ except Exception as e:
247
+ raise FileNotFoundError(f"Could not load adapter weights: {e}")
248
 
249
  # Convert weights for MX compatibility
250
  print("[LoRA] Converting fp32 weights for MX format compatibility...")
 
259
  model,
260
  adapter_path,
261
  is_trainable=False,
262
+ **peft_kwargs # This includes token and subfolder
 
263
  )
264
 
265
  # Manually update the adapter weights with our converted versions
 
340
  if ADAPTER_ID:
341
  try:
342
  if is_mx_model:
343
+ # Use special MX-compatible LoRA loading with subfolder support
344
+ model = prepare_model_for_mx_lora(model, ADAPTER_ID, ADAPTER_SUBFOLDER)
345
  else:
346
  # Standard PEFT loading for non-MX models
347
  if not _HAS_PEFT:
348
  raise RuntimeError("PEFT is required when ADAPTER_ID is set.")
349
  print(f"[Model] Loading adapter from {ADAPTER_ID} (standard mode)...")
350
+ peft_kwargs = {"token": HF_TOKEN, "is_trainable": False}
351
+ if ADAPTER_SUBFOLDER:
352
+ peft_kwargs["subfolder"] = ADAPTER_SUBFOLDER
353
+ print(f"[Model] Using subfolder: {ADAPTER_SUBFOLDER}")
354
  model = PeftModel.from_pretrained(
355
  model,
356
+ ADAPTER_ID,
357
+ **peft_kwargs
 
358
  )
359
 
360
  print("[Model] Successfully loaded with LoRA adapter")