jukofyork committed (verified)
Commit 747bd3e · Parent: e60537d

Update app.py

Files changed (1):
  1. app.py +87 -27

app.py CHANGED
@@ -32,8 +32,6 @@ def load_lora_state(lora_model_name):
     with open(config_path, 'r') as f:
         lora_config = json.load(f)
 
-    scale = lora_config['lora_alpha'] / lora_config['r']
-
     # Download adapter weights
     try:
         adapter_path = hf_hub_download(
@@ -52,18 +50,18 @@ def load_lora_state(lora_model_name):
         )
     lora_state = torch.load(adapter_path, map_location='cpu')
 
-    return lora_state, scale, temp_lora_dir
+    return lora_state, lora_config, temp_lora_dir
 
 def find_lora_weights(lora_state, key):
     """Find corresponding LoRA A and B weights for a given key"""
     lora_A = None
     lora_B = None
 
-    # Remove .weight suffix and handle potential prefixes
-    clean_key = key.replace('.weight', '')
+    # Remove .weight suffix for matching
+    clean_key = key.strip('.weight')
 
     for lora_key, lora_weight in lora_state.items():
-        if clean_key in lora_key or clean_key.replace('language_model.', '') in lora_key:
+        if clean_key in lora_key:
            if 'lora_A' in lora_key:
                lora_A = lora_weight
            elif 'lora_B' in lora_key:
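A note for anyone adapting the key-matching helper above: Python's `str.strip` removes a *set* of characters from both ends rather than a literal suffix, so `key.strip('.weight')` and `key.replace('.weight', '')` can give different results for some keys. A minimal standalone check (the key names below are illustrative only, not taken from any particular checkpoint):

```python
# Compare three ways of dropping a '.weight' suffix from state-dict keys.
keys = [
    "model.layers.0.self_attn.q_proj.weight",   # all three agree here
    "transformer.h.0.attn.c_attn.weight",       # str.strip also eats the leading 't'
]

for key in keys:
    stripped = key.strip('.weight')            # strips any of the characters ., w, e, i, g, h, t from both ends
    replaced = key.replace('.weight', '')      # removes every literal '.weight' substring
    suffixed = key.removesuffix('.weight')     # removes one trailing '.weight' (Python 3.9+)
    print(f"{key!r}: strip={stripped!r} replace={replaced!r} removesuffix={suffixed!r}")
```

Because of that character-set behaviour, keys whose stem begins or ends with one of those characters can lose extra characters, which is worth keeping in mind when the cleaned key is later used for substring matching.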
@@ -118,17 +116,27 @@ def download_and_upload_non_model_files(base_model_name, output_repo_name):
     shutil.rmtree(temp_config_dir, ignore_errors=True)
 
 def merge_lora_efficient(hf_token, base_model_name, lora_model_name, output_repo_name,
-                         multiplicative_lora, progress=gr.Progress()):
+                         scale_factor, multiplicative_lora, inverse_lora, progress=gr.Progress()):
     temp_lora_dir = None
     try:
+        # Validate scale factor
+        if not (0 < scale_factor < 2):
+            error_msg = "Scale factor must be in the range (0, 2)"
+            warning_fn(error_msg)
+            return f"✗ Error: {error_msg}"
+
         login(hf_token)
 
         progress(0.1, desc="Loading LoRA adapter...")
         info_fn("Loading LoRA adapter...")
 
         # Load LoRA state (this downloads the adapter)
-        lora_state, scale, temp_lora_dir = load_lora_state(lora_model_name)
-        info_fn(f"Using LoRA scale: {scale}")
+        lora_state, lora_config, temp_lora_dir = load_lora_state(lora_model_name)
+
+        # Calculate scale with user factor
+        base_scale = lora_config['lora_alpha'] / lora_config['r']
+        scale = base_scale * scale_factor
+        info_fn(f"Using LoRA scale: {scale} (base: {base_scale:.3f} × factor: {scale_factor})")
 
         progress(0.2, desc="Creating output repository...")
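For context on the scale arithmetic introduced above: the effective merge scale is the adapter's own `lora_alpha / r` ratio multiplied by the user-supplied factor. A small worked example with made-up adapter values (`lora_alpha=32`, `r=16` are assumptions for illustration):

```python
# Hypothetical adapter config values, for illustration only.
lora_config = {"lora_alpha": 32, "r": 16}
scale_factor = 0.5  # user-chosen strength from the UI slider

base_scale = lora_config["lora_alpha"] / lora_config["r"]  # 32 / 16 = 2.0
scale = base_scale * scale_factor                          # 2.0 * 0.5 = 1.0
print(f"Using LoRA scale: {scale} (base: {base_scale:.3f} x factor: {scale_factor})")
# -> Using LoRA scale: 1.0 (base: 2.000 x factor: 0.5)
```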
@@ -157,6 +165,18 @@ def merge_lora_efficient(hf_token, base_model_name, lora_model_name, output_repo
 
         info_fn(f"Found {len(shard_files)} model shards to process")
 
+        # Determine merge mode
+        if multiplicative_lora and inverse_lora:
+            merge_mode = "Multiplicative Inverse"
+        elif multiplicative_lora:
+            merge_mode = "Multiplicative"
+        elif inverse_lora:
+            merge_mode = "Additive Inverse"
+        else:
+            merge_mode = "Additive"
+
+        info_fn(f"Merge mode: {merge_mode}")
+
         merged_tensors = 0
         total_shards = len(shard_files)
 
@@ -194,29 +214,47 @@ def merge_lora_efficient(hf_token, base_model_name, lora_model_name, output_repo
                 lora_A, lora_B = find_lora_weights(lora_state, key)
 
                 if lora_A is not None and lora_B is not None:
-                    lora_type = "Multiplicative" if multiplicative_lora else "Additive"
-                    info_fn(f"Merging {lora_type} LoRA weights for {key}")
+                    info_fn(f"Merging {merge_mode} LoRA weights for {key}")
                     shard_merged_count += 1
                     merged_tensors += 1
 
                     # Convert to float32 for computation
                     original_dtype = tensor.dtype
-                    tensor_f32 = tensor.to(torch.float32)
-                    lora_A_f32 = lora_A.to(torch.float32)
-                    lora_B_f32 = lora_B.to(torch.float32)
+                    tensor = tensor.to(torch.float32)
+                    lora_delta = scale * lora_B.to(torch.float32) @ lora_A.to(torch.float32)
 
                     if multiplicative_lora:
-                        # Apply Multiplicative-LoRA: W = W + scale * B @ A @ W
-                        tensor_f32 += scale * lora_B_f32 @ lora_A_f32 @ tensor_f32
+                        # Validate dimensions for multiplicative LoRA
+                        if lora_delta.shape[0] != lora_delta.shape[1]:
+                            raise ValueError(f"Multiplicative LoRA requires square delta matrix for {key}: got shape {lora_delta.shape}")
+                        if lora_delta.shape[-1] != tensor.shape[-2]:
+                            raise ValueError(f"Multiplicative LoRA dimension mismatch for {key}: {lora_delta.shape} vs {tensor.shape}")
+
+                        if inverse_lora:
+                            # Inverse multiplicative: tensor = (I + lora_delta)^(-1) @ tensor
+                            identity = torch.eye(lora_delta.shape[0], device=lora_delta.device, dtype=torch.float32)
+                            inverse_matrix = torch.linalg.inv(identity + lora_delta)
+                            tensor = inverse_matrix @ tensor
+                        else:
+                            # Forward multiplicative: tensor = (I + lora_delta) @ tensor
+                            tensor += lora_delta @ tensor
                     else:
-                        # Apply standard LoRA: W = W + scale * B @ A
-                        tensor_f32 += scale * lora_B_f32 @ lora_A_f32
+                        # Validate dimensions for additive LoRA
+                        if lora_delta.shape != tensor.shape:
+                            raise ValueError(f"Additive LoRA dimension mismatch for {key}: {lora_delta.shape} vs {tensor.shape}")
+
+                        if inverse_lora:
+                            # Inverse additive: tensor = tensor - lora_delta
+                            tensor -= lora_delta
+                        else:
+                            # Forward additive: tensor = tensor + lora_delta
+                            tensor += lora_delta
 
                     # Convert back to original dtype
-                    tensor = tensor_f32.to(original_dtype)
+                    tensor = tensor.to(original_dtype)
 
                     # Clean up intermediate tensors
-                    del tensor_f32, lora_A_f32, lora_B_f32
+                    del lora_delta
                     if torch.cuda.is_available():
                         torch.cuda.empty_cache()
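The merge arithmetic in the hunk above can be sanity-checked on small random matrices. A self-contained sketch (dimensions, seed, and `scale` are arbitrary choices, and the weight is kept square so one delta serves both modes) showing that the two inverse modes undo their forward counterparts up to floating-point error:

```python
import torch

torch.manual_seed(0)
d, r, scale = 8, 2, 0.1              # small scale keeps I + delta comfortably invertible
W = torch.randn(d, d)                # stand-in for a weight tensor
lora_A = torch.randn(r, d)
lora_B = torch.randn(d, r)
delta = scale * lora_B @ lora_A      # (d, d), same construction as lora_delta above

# Additive merge and its inverse cancel out
W_add = W + delta
assert torch.allclose(W_add - delta, W, atol=1e-6)

# Multiplicative merge W_new = (I + delta) @ W, undone by the inverse mode
W_mul = W + delta @ W
eye = torch.eye(d)
assert torch.allclose(torch.linalg.inv(eye + delta) @ W_mul, W, atol=1e-5)

print("forward/inverse round-trips agree")
```

Note that the inverse-multiplicative path inverts a full `d_out × d_out` matrix per weight, which is exact but can become memory-heavy at large hidden sizes.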
@@ -246,7 +284,7 @@ def merge_lora_efficient(hf_token, base_model_name, lora_model_name, output_repo
 
         progress(1.0, desc="Upload completed!")
 
-        success_msg = f"✓ Successfully merged and uploaded model!\nModel URL: https://huggingface.co/{output_repo_name}\nProcessed {total_shards} shards\nMerged {merged_tensors} layers with LoRA weights"
+        success_msg = f"✓ Successfully merged and uploaded model!\nModel URL: https://huggingface.co/{output_repo_name}\nMerge mode: {merge_mode}\nScale factor: {scale_factor}\nProcessed {total_shards} shards\nMerged {merged_tensors} layers with LoRA weights"
         info_fn("Merge completed successfully!")
 
         return success_msg
@@ -272,15 +310,23 @@ This tool merges LoRA (Low-Rank Adaptation) adapters with base models using a me
 - **Streaming Processing**: Downloads → Processes → Uploads → Deletes each shard sequentially
 - **Automatic Cleanup**: Temporary files are automatically removed after processing
 - **Progress Tracking**: Real-time status updates throughout the merge process
-- **Advanced Options**: Multiplicative LoRA support
+- **Advanced Options**: Multiplicative LoRA, inverse merging, and custom scale factors
 """
 
 DETAILS_TEXT = """
 ### How It Works
-LoRA enables efficient fine-tuning by adding small adapter weights rather than modifying the entire model. This tool applies the LoRA transformation:
+LoRA enables efficient fine-tuning by adding small adapter weights rather than modifying the entire model. This tool supports four merge modes:
+
+- **Additive LoRA**: `W_new = W + scale × B @ A`
+- **Additive Inverse**: `W_new = W - scale × B @ A` (removes LoRA effect)
+- **Multiplicative LoRA**: `W_new = W + scale × B @ A @ W`
+- **Multiplicative Inverse**: `W_new = (I + scale × B @ A)^(-1) @ W`
 
-- **Standard Additive-LoRA**: `W_new = W + scale × B^T @ A`
-- **Multiplicative LoRA**: `W_new = W + scale × B^T @ A @ W`
+### Scale Factor
+The scale factor (0 < scale < 2) controls the strength of the LoRA merge:
+- **1.0**: Full strength (default)
+- **0.5**: Half strength
+- **1.5**: 150% strength
 
 ### Memory Efficiency
 - **Traditional approach**: Loads entire model (~15GB+ for 7B parameter models)
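As a concrete illustration of the scale-factor semantics listed above (all numbers are made up): in the additive mode the delta scales linearly, so merging at factor 0.5 twice reproduces a single factor-1.0 merge, and the additive-inverse mode with the same scale restores the original weights.

```python
import torch

torch.manual_seed(1)
W = torch.randn(4, 4)
lora_A = torch.randn(2, 4)        # hypothetical adapter factors
lora_B = torch.randn(4, 2)
base_scale = 32 / 16              # lora_alpha / r from a made-up adapter config
delta = base_scale * lora_B @ lora_A

full = W + 1.0 * delta            # scale factor 1.0: full strength
half = W + 0.5 * delta            # scale factor 0.5: half strength
assert torch.allclose(half + 0.5 * delta, full, atol=1e-6)

# Additive Inverse with the same scale removes the merge again
assert torch.allclose(full - 1.0 * delta, W, atol=1e-6)
print("scale-factor arithmetic checks out")
```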
@@ -328,10 +374,23 @@ with gr.Blocks(title="Memory-Efficient LoRA Merge", theme=gr.themes.Soft()) as d
             )
 
             gr.Markdown("### Advanced Options")
+            scale_factor = gr.Slider(
+                minimum=0.01,
+                maximum=1.99,
+                value=1.0,
+                step=0.01,
+                label="Scale Factor",
+                info="Strength of LoRA merge (0 < scale < 2)"
+            )
             multiplicative_lora = gr.Checkbox(
                 label="Multiplicative LoRA",
                 value=False,
-                info="Apply a \"multiplicative-LoRA\" instead of a standard \"additive-LoRA\""
+                info="Apply multiplicative LoRA instead of additive LoRA"
+            )
+            inverse_lora = gr.Checkbox(
+                label="Inverse Merge",
+                value=False,
+                info="Apply inverse operation (subtract/invert the LoRA effect)"
             )
 
         with gr.Column(scale=1):
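For readers unfamiliar with the Gradio pieces used above, here is a stripped-down, standalone sketch of the same advanced-options pattern: a slider plus two checkboxes feeding a single callback. The callback and component wiring are illustrative only, not the app's actual `merge_lora_efficient` plumbing.

```python
import gradio as gr

def preview_settings(scale_factor, multiplicative_lora, inverse_lora):
    # Echo the chosen options back, mirroring how merge_lora_efficient receives them.
    mode = "Multiplicative" if multiplicative_lora else "Additive"
    if inverse_lora:
        mode += " Inverse"
    return f"Mode: {mode}, scale factor: {scale_factor}"

with gr.Blocks() as demo:
    gr.Markdown("### Advanced Options")
    scale_factor = gr.Slider(minimum=0.01, maximum=1.99, value=1.0, step=0.01,
                             label="Scale Factor", info="Strength of LoRA merge (0 < scale < 2)")
    multiplicative_lora = gr.Checkbox(label="Multiplicative LoRA", value=False)
    inverse_lora = gr.Checkbox(label="Inverse Merge", value=False)
    preview_btn = gr.Button("Preview")
    output = gr.Textbox(label="Result")
    preview_btn.click(fn=preview_settings,
                      inputs=[scale_factor, multiplicative_lora, inverse_lora],
                      outputs=output)

demo.launch()
```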
@@ -348,7 +407,8 @@ with gr.Blocks(title="Memory-Efficient LoRA Merge", theme=gr.themes.Soft()) as d
 
     submit_btn.click(
         fn=merge_lora_efficient,
-        inputs=[hf_token, base_model_name, lora_model_name, output_repo_name, multiplicative_lora],
+        inputs=[hf_token, base_model_name, lora_model_name, output_repo_name,
+                scale_factor, multiplicative_lora, inverse_lora],
         outputs=output_text
     )