Update app.py
app.py CHANGED
@@ -32,8 +32,6 @@ def load_lora_state(lora_model_name):
     with open(config_path, 'r') as f:
         lora_config = json.load(f)

-    scale = lora_config['lora_alpha'] / lora_config['r']
-
     # Download adapter weights
     try:
         adapter_path = hf_hub_download(
@@ -52,18 +50,18 @@ def load_lora_state(lora_model_name):
         )
         lora_state = torch.load(adapter_path, map_location='cpu')

-        return lora_state,
+        return lora_state, lora_config, temp_lora_dir

 def find_lora_weights(lora_state, key):
     """Find corresponding LoRA A and B weights for a given key"""
     lora_A = None
     lora_B = None

-    # Remove .weight suffix
-    clean_key = key.
+    # Remove .weight suffix for matching
+    clean_key = key.strip('.weight')

     for lora_key, lora_weight in lora_state.items():
-        if clean_key in lora_key
+        if clean_key in lora_key:
             if 'lora_A' in lora_key:
                 lora_A = lora_weight
             elif 'lora_B' in lora_key:
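A note on the matching logic above: `find_lora_weights` relies on the base-model tensor name, minus its `.weight` suffix, being a substring of the PEFT adapter key. Also worth knowing that `str.strip('.weight')` removes a *set* of characters from both ends rather than the literal suffix, so `str.removesuffix('.weight')` (Python 3.9+) is the stricter spelling. A minimal, self-contained sketch of the lookup with hypothetical key names (not taken from a real adapter):

```python
# Minimal sketch of the substring-based lookup; key names are hypothetical.
lora_state = {
    "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight": "A0",
    "base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight": "B0",
}

key = "model.layers.0.self_attn.q_proj.weight"
clean_key = key.removesuffix(".weight")  # stricter than key.strip('.weight'), which can over-strip

lora_A = lora_B = None
for lora_key, lora_weight in lora_state.items():
    if clean_key in lora_key:
        if "lora_A" in lora_key:
            lora_A = lora_weight
        elif "lora_B" in lora_key:
            lora_B = lora_weight

print(lora_A, lora_B)  # -> A0 B0
```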
@@ -118,17 +116,27 @@ def download_and_upload_non_model_files(base_model_name, output_repo_name):
     shutil.rmtree(temp_config_dir, ignore_errors=True)

 def merge_lora_efficient(hf_token, base_model_name, lora_model_name, output_repo_name,
-                         multiplicative_lora, progress=gr.Progress()):
+                         scale_factor, multiplicative_lora, inverse_lora, progress=gr.Progress()):
     temp_lora_dir = None
     try:
+        # Validate scale factor
+        if not (0 < scale_factor < 2):
+            error_msg = "Scale factor must be in the range (0, 2)"
+            warning_fn(error_msg)
+            return f"❌ Error: {error_msg}"
+
         login(hf_token)

         progress(0.1, desc="Loading LoRA adapter...")
         info_fn("Loading LoRA adapter...")

         # Load LoRA state (this downloads the adapter)
-        lora_state,
-
+        lora_state, lora_config, temp_lora_dir = load_lora_state(lora_model_name)
+
+        # Calculate scale with user factor
+        base_scale = lora_config['lora_alpha'] / lora_config['r']
+        scale = base_scale * scale_factor
+        info_fn(f"Using LoRA scale: {scale} (base: {base_scale:.3f} × factor: {scale_factor})")

         progress(0.2, desc="Creating output repository...")

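The effective merge scale introduced above is the standard PEFT scaling `lora_alpha / r` multiplied by the user-supplied slider value, with the slider value validated to lie in (0, 2). A worked sketch with assumed config values (hypothetical, not read from any real adapter):

```python
# Worked sketch of the scale computation; adapter_config.json values are assumed.
lora_config = {"lora_alpha": 32, "r": 16}   # hypothetical adapter config
scale_factor = 0.5                          # user-chosen strength from the slider

assert 0 < scale_factor < 2, "Scale factor must be in the range (0, 2)"

base_scale = lora_config["lora_alpha"] / lora_config["r"]   # 32 / 16 = 2.0
scale = base_scale * scale_factor                           # 2.0 * 0.5 = 1.0
print(f"Using LoRA scale: {scale} (base: {base_scale:.3f} × factor: {scale_factor})")
```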
@@ -157,6 +165,18 @@ def merge_lora_efficient(hf_token, base_model_name, lora_model_name, output_repo

         info_fn(f"Found {len(shard_files)} model shards to process")

+        # Determine merge mode
+        if multiplicative_lora and inverse_lora:
+            merge_mode = "Multiplicative Inverse"
+        elif multiplicative_lora:
+            merge_mode = "Multiplicative"
+        elif inverse_lora:
+            merge_mode = "Additive Inverse"
+        else:
+            merge_mode = "Additive"
+
+        info_fn(f"Merge mode: {merge_mode}")
+
         merged_tensors = 0
         total_shards = len(shard_files)

@@ -194,29 +214,47 @@ def merge_lora_efficient(hf_token, base_model_name, lora_model_name, output_repo
                 lora_A, lora_B = find_lora_weights(lora_state, key)

                 if lora_A is not None and lora_B is not None:
-
-                    info_fn(f"Merging {lora_type} LoRA weights for {key}")
+                    info_fn(f"Merging {merge_mode} LoRA weights for {key}")
                     shard_merged_count += 1
                     merged_tensors += 1

                     # Convert to float32 for computation
                     original_dtype = tensor.dtype
-
-
-                    lora_B_f32 = lora_B.to(torch.float32)
+                    tensor = tensor.to(torch.float32)
+                    lora_delta = scale * lora_B.to(torch.float32) @ lora_A.to(torch.float32)

                     if multiplicative_lora:
-                        #
-
+                        # Validate dimensions for multiplicative LoRA
+                        if lora_delta.shape[0] != lora_delta.shape[1]:
+                            raise ValueError(f"Multiplicative LoRA requires square delta matrix for {key}: got shape {lora_delta.shape}")
+                        if lora_delta.shape[-1] != tensor.shape[-2]:
+                            raise ValueError(f"Multiplicative LoRA dimension mismatch for {key}: {lora_delta.shape} vs {tensor.shape}")
+
+                        if inverse_lora:
+                            # Inverse multiplicative: tensor = (I + lora_delta)^(-1) @ tensor
+                            identity = torch.eye(lora_delta.shape[0], device=lora_delta.device, dtype=torch.float32)
+                            inverse_matrix = torch.linalg.inv(identity + lora_delta)
+                            tensor = inverse_matrix @ tensor
+                        else:
+                            # Forward multiplicative: tensor = (I + lora_delta) @ tensor
+                            tensor += lora_delta @ tensor
                     else:
-                        #
-
+                        # Validate dimensions for additive LoRA
+                        if lora_delta.shape != tensor.shape:
+                            raise ValueError(f"Additive LoRA dimension mismatch for {key}: {lora_delta.shape} vs {tensor.shape}")
+
+                        if inverse_lora:
+                            # Inverse additive: tensor = tensor - lora_delta
+                            tensor -= lora_delta
+                        else:
+                            # Forward additive: tensor = tensor + lora_delta
+                            tensor += lora_delta

                     # Convert back to original dtype
-                    tensor =
+                    tensor = tensor.to(original_dtype)

                     # Clean up intermediate tensors
-                    del
+                    del lora_delta
                     if torch.cuda.is_available():
                         torch.cuda.empty_cache()

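As a sanity check on the four branches above, here is a self-contained toy example (random square matrices, no Hugging Face access) showing that each inverse mode recovers the original weight from its forward counterpart:

```python
import torch

torch.manual_seed(0)
d, r, scale = 8, 2, 1.0

W = torch.randn(d, d)                  # toy weight matrix
lora_A = torch.randn(r, d)
lora_B = torch.randn(d, r)
lora_delta = scale * lora_B @ lora_A   # rank-r update, shape (d, d)

# Additive merge, undone by the additive inverse
W_add = W + lora_delta
assert torch.allclose(W_add - lora_delta, W, atol=1e-5)

# Multiplicative merge W_mul = (I + delta) @ W, undone by (I + delta)^(-1) @ W_mul
identity = torch.eye(d)
W_mul = W + lora_delta @ W
W_recovered = torch.linalg.inv(identity + lora_delta) @ W_mul
assert torch.allclose(W_recovered, W, atol=1e-4)

print("both inverse modes recover the original weights")
```

The multiplicative inverse is only well defined when `I + lora_delta` is invertible; the random toy matrices here almost surely are, but a real merge could in principle hit an ill-conditioned case.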
@@ -246,7 +284,7 @@ def merge_lora_efficient(hf_token, base_model_name, lora_model_name, output_repo

         progress(1.0, desc="Upload completed!")

-        success_msg = f"✅ Successfully merged and uploaded model!\nModel URL: https://huggingface.co/{output_repo_name}\nProcessed {total_shards} shards\nMerged {merged_tensors} layers with LoRA weights"
+        success_msg = f"✅ Successfully merged and uploaded model!\nModel URL: https://huggingface.co/{output_repo_name}\nMerge mode: {merge_mode}\nScale factor: {scale_factor}\nProcessed {total_shards} shards\nMerged {merged_tensors} layers with LoRA weights"
         info_fn("Merge completed successfully!")

         return success_msg
@@ -272,15 +310,23 @@ This tool merges LoRA (Low-Rank Adaptation) adapters with base models using a me
 - **Streaming Processing**: Downloads → Processes → Uploads → Deletes each shard sequentially
 - **Automatic Cleanup**: Temporary files are automatically removed after processing
 - **Progress Tracking**: Real-time status updates throughout the merge process
-- **Advanced Options**: Multiplicative LoRA
+- **Advanced Options**: Multiplicative LoRA, inverse merging, and custom scale factors
 """

 DETAILS_TEXT = """
 ### How It Works
-LoRA enables efficient fine-tuning by adding small adapter weights rather than modifying the entire model. This tool
+LoRA enables efficient fine-tuning by adding small adapter weights rather than modifying the entire model. This tool supports four merge modes:
+
+- **Additive LoRA**: `W_new = W + scale × B @ A`
+- **Additive Inverse**: `W_new = W - scale × B @ A` (removes LoRA effect)
+- **Multiplicative LoRA**: `W_new = W + scale × B @ A @ W`
+- **Multiplicative Inverse**: `W_new = (I + scale × B @ A)^(-1) @ W`

-
-
+### Scale Factor
+The scale factor (0 < scale < 2) controls the strength of the LoRA merge:
+- **1.0**: Full strength (default)
+- **0.5**: Half strength
+- **1.5**: 150% strength

 ### Memory Efficiency
 - **Traditional approach**: Loads entire model (~15GB+ for 7B parameter models)
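A rough back-of-the-envelope check of the memory figures quoted in the details text, assuming fp16/bf16 weights (2 bytes per parameter) and a ~5 GB shard, which is an assumed typical size rather than a value taken from the app:

```python
# Back-of-the-envelope memory estimate; the shard size is an assumed typical value.
params = 7e9                  # 7B-parameter model
bytes_per_param = 2           # fp16 / bf16
full_model_gb = params * bytes_per_param / 1024**3
shard_gb = 5.0                # assumed size of one safetensors shard

print(f"whole model ≈ {full_model_gb:.1f} GB of raw weights; "
      f"streaming shard-by-shard keeps the peak near {shard_gb:.0f} GB")
```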
@@ -328,10 +374,23 @@ with gr.Blocks(title="Memory-Efficient LoRA Merge", theme=gr.themes.Soft()) as d
             )

             gr.Markdown("### Advanced Options")
+            scale_factor = gr.Slider(
+                minimum=0.01,
+                maximum=1.99,
+                value=1.0,
+                step=0.01,
+                label="Scale Factor",
+                info="Strength of LoRA merge (0 < scale < 2)"
+            )
             multiplicative_lora = gr.Checkbox(
                 label="Multiplicative LoRA",
                 value=False,
-                info="Apply
+                info="Apply multiplicative LoRA instead of additive LoRA"
+            )
+            inverse_lora = gr.Checkbox(
+                label="Inverse Merge",
+                value=False,
+                info="Apply inverse operation (subtract/invert the LoRA effect)"
             )

         with gr.Column(scale=1):
@@ -348,7 +407,8 @@ with gr.Blocks(title="Memory-Efficient LoRA Merge", theme=gr.themes.Soft()) as d

     submit_btn.click(
         fn=merge_lora_efficient,
-        inputs=[hf_token, base_model_name, lora_model_name, output_repo_name,
+        inputs=[hf_token, base_model_name, lora_model_name, output_repo_name,
+                scale_factor, multiplicative_lora, inverse_lora],
         outputs=output_text
     )

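One wiring detail worth keeping in mind: Gradio passes the listed `inputs` components to the callback positionally, so the order in `inputs=[...]` has to match the parameter order of `merge_lora_efficient` (token, base model, LoRA repo, output repo, scale factor, multiplicative flag, inverse flag). A stripped-down sketch of the same pattern with a throwaway callback and illustrative component names:

```python
# Illustrative Gradio wiring sketch; component names and the callback are placeholders.
import gradio as gr

def preview(token, base, lora, out_repo, scale_factor, multiplicative_lora, inverse_lora):
    mode = "Multiplicative" if multiplicative_lora else "Additive"
    if inverse_lora:
        mode += " Inverse"
    return f"{base} + {lora} -> {out_repo} ({mode}, scale factor {scale_factor})"

with gr.Blocks() as demo:
    token = gr.Textbox(label="HF Token", type="password")
    base = gr.Textbox(label="Base model")
    lora = gr.Textbox(label="LoRA adapter")
    out_repo = gr.Textbox(label="Output repo")
    scale_factor = gr.Slider(minimum=0.01, maximum=1.99, value=1.0, step=0.01, label="Scale Factor")
    multiplicative_lora = gr.Checkbox(label="Multiplicative LoRA", value=False)
    inverse_lora = gr.Checkbox(label="Inverse Merge", value=False)
    output_text = gr.Textbox(label="Result")

    gr.Button("Preview").click(
        fn=preview,
        inputs=[token, base, lora, out_repo, scale_factor, multiplicative_lora, inverse_lora],
        outputs=output_text,
    )

# demo.launch()  # uncomment to run locally
```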