import gc
import gradio as gr
import torch
from huggingface_hub import hf_hub_download, HfApi, login, list_repo_files
from safetensors import safe_open
from safetensors.torch import save_file, load_file
import os
import shutil
import json

api = HfApi()
def info_fn(text):
    gr.Info(text)

def warning_fn(text):
    gr.Warning(text)
def load_lora_state(lora_model_name):
    """Download a LoRA adapter and load its weights and config."""
    temp_lora_dir = "/tmp/lora_adapter"
    os.makedirs(temp_lora_dir, exist_ok=True)

    # Download the adapter config
    config_path = hf_hub_download(
        repo_id=lora_model_name,
        filename="adapter_config.json",
        local_dir=temp_lora_dir,
        local_dir_use_symlinks=False
    )
    with open(config_path, 'r') as f:
        lora_config = json.load(f)

    # Download the adapter weights, preferring safetensors and falling
    # back to the legacy PyTorch .bin format if it is not available
    try:
        adapter_path = hf_hub_download(
            repo_id=lora_model_name,
            filename="adapter_model.safetensors",
            local_dir=temp_lora_dir,
            local_dir_use_symlinks=False
        )
        lora_state = load_file(adapter_path, device='cpu')
    except Exception:
        adapter_path = hf_hub_download(
            repo_id=lora_model_name,
            filename="adapter_model.bin",
            local_dir=temp_lora_dir,
            local_dir_use_symlinks=False
        )
        lora_state = torch.load(adapter_path, map_location='cpu')

    return lora_state, lora_config, temp_lora_dir
def find_lora_weights(lora_state, key):
    """Find the LoRA A and B weights corresponding to a base-model key."""
    lora_A = None
    lora_B = None

    # Remove the .weight suffix for matching. Note that str.strip('.weight')
    # would strip arbitrary trailing characters from the set {., w, e, i, g,
    # h, t}, so slice the suffix off explicitly instead.
    clean_key = key[:-len('.weight')] if key.endswith('.weight') else key

    for lora_key, lora_weight in lora_state.items():
        if clean_key in lora_key:
            if 'lora_A' in lora_key:
                lora_A = lora_weight
            elif 'lora_B' in lora_key:
                lora_B = lora_weight

    # Either both weights are found or the key has no LoRA counterpart
    if (lora_A is None) != (lora_B is None):
        return None, None
    return lora_A, lora_B
def download_and_upload_non_model_files(base_model_name, output_repo_name):
    """Download and upload non-model files (config, tokenizer, etc.)."""
    temp_config_dir = "/tmp/config_files"
    os.makedirs(temp_config_dir, exist_ok=True)
    try:
        # List all files in the repository
        files = list_repo_files(repo_id=base_model_name)

        # Filter out the safetensors weight shards
        non_model_files = [
            f for f in files
            if not (f.startswith('model') and f.endswith('.safetensors'))
        ]

        # Download and upload each non-model file
        for filename in non_model_files:
            if filename.endswith(('.gguf', '.bin')) and 'model' in filename:
                continue  # Skip weights in other model formats
            try:
                file_path = hf_hub_download(
                    repo_id=base_model_name,
                    filename=filename,
                    local_dir=temp_config_dir,
                    local_dir_use_symlinks=False
                )
                # Upload to the output repo
                api.upload_file(
                    path_or_fileobj=file_path,
                    path_in_repo=filename,
                    repo_id=output_repo_name,
                    repo_type="model"
                )
            except Exception as e:
                info_fn(f"Skipping {filename}: {e}")
    finally:
        shutil.rmtree(temp_config_dir, ignore_errors=True)
def merge_lora_efficient(hf_token, base_model_name, lora_model_name, output_repo_name,
                         scale_factor, multiplicative_lora, inverse_lora, progress=gr.Progress()):
    temp_lora_dir = None
    try:
        # Validate the scale factor
        if not (0 < scale_factor < 2):
            error_msg = "Scale factor must be in the range (0, 2)"
            warning_fn(error_msg)
            return f"❌ Error: {error_msg}"

        login(hf_token)

        progress(0.1, desc="Loading LoRA adapter...")
        info_fn("Loading LoRA adapter...")

        # Load the LoRA state (this downloads the adapter)
        lora_state, lora_config, temp_lora_dir = load_lora_state(lora_model_name)

        # Calculate the effective scale from the adapter config and user factor
        base_scale = lora_config['lora_alpha'] / lora_config['r']
        scale = base_scale * scale_factor
        info_fn(f"Using LoRA scale: {scale} (base: {base_scale:.3f} × factor: {scale_factor})")

        progress(0.2, desc="Creating output repository...")

        # Create the output repository
        try:
            repo_url = api.create_repo(repo_id=output_repo_name, exist_ok=True)
            info_fn(f"Repository created/updated: {repo_url}")
        except Exception as e:
            warning_fn(f"Repository might already exist: {e}")

        progress(0.3, desc="Uploading configuration files...")
        info_fn("Uploading configuration files...")

        # Download and upload non-model files
        download_and_upload_non_model_files(base_model_name, output_repo_name)

        progress(0.4, desc="Finding model shards...")
        info_fn("Finding model shards...")

        # Get the list of all safetensors weight shards
        all_files = list_repo_files(repo_id=base_model_name)
        shard_files = [f for f in all_files if f.startswith('model') and f.endswith('.safetensors')]

        if not shard_files:
            raise FileNotFoundError("No model safetensors files found in the repository")

        info_fn(f"Found {len(shard_files)} model shards to process")

        # Determine the merge mode
        if multiplicative_lora and inverse_lora:
            merge_mode = "Multiplicative Inverse"
        elif multiplicative_lora:
            merge_mode = "Multiplicative"
        elif inverse_lora:
            merge_mode = "Additive Inverse"
        else:
            merge_mode = "Additive"
        info_fn(f"Merge mode: {merge_mode}")

        merged_tensors = 0
        total_shards = len(shard_files)

        # Process each shard individually
        for i, shard_filename in enumerate(shard_files):
            progress(0.4 + (i / total_shards) * 0.5,
                     desc=f"Processing {shard_filename} ({i+1}/{total_shards})")
            info_fn(f"Processing shard {i+1}/{total_shards}: {shard_filename}")

            # Create a temporary directory for this shard only
            temp_shard_dir = f"/tmp/shard_{i}"
            os.makedirs(temp_shard_dir, exist_ok=True)

            try:
                # Download the current shard
                shard_path = hf_hub_download(
                    repo_id=base_model_name,
                    filename=shard_filename,
                    local_dir=temp_shard_dir,
                    local_dir_use_symlinks=False
                )

                # Process the shard
                tensors = {}
                shard_merged_count = 0
                with safe_open(shard_path, framework='pt', device='cpu') as f:
                    # Preserve the shard metadata if present
                    metadata = f.metadata() or {}
                    for key in f.keys():
                        tensor = f.get_tensor(key)

                        # Try to find corresponding LoRA weights
                        lora_A, lora_B = find_lora_weights(lora_state, key)

                        if lora_A is not None and lora_B is not None:
                            info_fn(f"Merging {merge_mode} LoRA weights for {key}")
                            shard_merged_count += 1
                            merged_tensors += 1

                            # Convert to float32 for the computation
                            original_dtype = tensor.dtype
                            tensor = tensor.to(torch.float32)
                            lora_delta = scale * lora_B.to(torch.float32) @ lora_A.to(torch.float32)

                            if multiplicative_lora:
                                # Validate dimensions for multiplicative LoRA
                                if lora_delta.shape[0] != lora_delta.shape[1]:
                                    raise ValueError(f"Multiplicative LoRA requires a square delta matrix for {key}: got shape {lora_delta.shape}")
                                if lora_delta.shape[-1] != tensor.shape[-2]:
                                    raise ValueError(f"Multiplicative LoRA dimension mismatch for {key}: {lora_delta.shape} vs {tensor.shape}")

                                if inverse_lora:
                                    # Inverse multiplicative: tensor = (I + lora_delta)^(-1) @ tensor
                                    identity = torch.eye(lora_delta.shape[0], device=lora_delta.device, dtype=torch.float32)
                                    inverse_matrix = torch.linalg.inv(identity + lora_delta)
                                    tensor = inverse_matrix @ tensor
                                else:
                                    # Forward multiplicative: tensor = (I + lora_delta) @ tensor
                                    tensor += lora_delta @ tensor
                            else:
                                # Validate dimensions for additive LoRA
                                if lora_delta.shape != tensor.shape:
                                    raise ValueError(f"Additive LoRA dimension mismatch for {key}: {lora_delta.shape} vs {tensor.shape}")

                                if inverse_lora:
                                    # Inverse additive: tensor = tensor - lora_delta
                                    tensor -= lora_delta
                                else:
                                    # Forward additive: tensor = tensor + lora_delta
                                    tensor += lora_delta

                            # Convert back to the original dtype
                            tensor = tensor.to(original_dtype)

                            # Clean up intermediate tensors
                            del lora_delta
                            if torch.cuda.is_available():
                                torch.cuda.empty_cache()

                        tensors[key] = tensor

                # Save the processed shard to a temporary file
                output_shard_path = os.path.join(temp_shard_dir, f"processed_{shard_filename}")
                save_file(tensors, output_shard_path, metadata=metadata)
                info_fn(f"Shard {shard_filename}: merged {shard_merged_count} tensors")

                # Upload the processed shard
                api.upload_file(
                    path_or_fileobj=output_shard_path,
                    path_in_repo=shard_filename,
                    repo_id=output_repo_name,
                    repo_type="model"
                )

                # Clean up this shard's data
                del tensors
                gc.collect()
            finally:
                # Always clean up the temporary shard directory
                shutil.rmtree(temp_shard_dir, ignore_errors=True)

        progress(1.0, desc="Upload completed!")
        success_msg = (
            f"✅ Successfully merged and uploaded model!\n"
            f"Model URL: https://huggingface.co/{output_repo_name}\n"
            f"Merge mode: {merge_mode}\n"
            f"Scale factor: {scale_factor}\n"
            f"Processed {total_shards} shards\n"
            f"Merged {merged_tensors} layers with LoRA weights"
        )
        info_fn("Merge completed successfully!")
        return success_msg
    except Exception as e:
        error_msg = f"❌ Error during merge: {str(e)}"
        warning_fn(error_msg)
        return error_msg
    finally:
        # Clean up the LoRA adapter directory
        if temp_lora_dir and os.path.exists(temp_lora_dir):
            shutil.rmtree(temp_lora_dir, ignore_errors=True)
        gc.collect()
INTRODUCTION_TEXT = """
## Memory-Efficient LoRA Merge

This tool merges LoRA (Low-Rank Adaptation) adapters with base models using a memory-efficient approach that processes model files individually, significantly reducing memory requirements compared to traditional methods.

### Key Features

- **Minimal Memory Usage**: Processes one model shard at a time instead of loading the entire model
- **Streaming Processing**: Downloads → Processes → Uploads → Deletes each shard sequentially
- **Automatic Cleanup**: Temporary files are automatically removed after processing
- **Progress Tracking**: Real-time status updates throughout the merge process
- **Advanced Options**: Multiplicative LoRA, inverse merging, and custom scale factors
"""
DETAILS_TEXT = """
### How It Works

LoRA enables efficient fine-tuning by adding small adapter weights rather than modifying the entire model. This tool supports four merge modes:

- **Additive LoRA**: `W_new = W + scale × B @ A`
- **Additive Inverse**: `W_new = W - scale × B @ A` (removes the LoRA effect)
- **Multiplicative LoRA**: `W_new = W + scale × B @ A @ W`
- **Multiplicative Inverse**: `W_new = (I + scale × B @ A)^(-1) @ W`
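
For intuition, here is a minimal sketch of the additive mode on toy tensors; the shapes, rank, and scale below are illustrative, not taken from any real adapter:

```python
import torch

# Toy dimensions: a 4x4 weight with a rank-2 LoRA adapter
W = torch.randn(4, 4)   # base weight
A = torch.randn(2, 4)   # lora_A: (r, in_features)
B = torch.randn(4, 2)   # lora_B: (out_features, r)
scale = 1.0             # stands in for (lora_alpha / r) * scale_factor

# Additive merge: W_new = W + scale * B @ A
W_merged = W + scale * (B @ A)

# Additive inverse recovers the original weight
W_restored = W_merged - scale * (B @ A)
assert torch.allclose(W, W_restored, atol=1e-6)
```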

### Scale Factor

The scale factor (0 < scale < 2) controls the strength of the LoRA merge:

- **1.0**: Full strength (default)
- **0.5**: Half strength
- **1.5**: 150% strength
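
As a worked example of how the effective scale is computed (the alpha and rank values here are illustrative):

```python
# Suppose the adapter was trained with lora_alpha = 16 and rank r = 8
base_scale = 16 / 8                # 2.0, read from adapter_config.json
scale_factor = 0.5                 # user-selected slider value
scale = base_scale * scale_factor  # effective merge scale = 1.0
```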

### Memory Efficiency

- **Traditional approach**: Loads the entire model (~15 GB+ for 7B-parameter models)
- **This approach**: Peak usage is determined by the largest shard size, not the total model size
- **Result**: Enables merging much larger models on limited hardware

### Example Usage

- **Base Model:** `microsoft/DialoGPT-medium`
- **LoRA Adapter:** `username/my-trained-lora`
- **Output Name:** `username/dialogpt-merged`

### Attribution

This tool builds upon excellent work from the community:

- **Base implementation:** [Weyaxi/merge-lora](https://huggingface.co/spaces/Weyaxi/merge-lora)
- **Memory-efficient method:** [qlora-pipe](https://github.com/tdrussell/qlora-pipe/blob/main/tools/merge_lora.py) by tdrussell
"""
with gr.Blocks(title="Memory-Efficient LoRA Merge", theme=gr.themes.Soft()) as demo:
    gr.Markdown(INTRODUCTION_TEXT)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Configuration")
            hf_token = gr.Textbox(
                label="Hugging Face Token",
                placeholder="hf_...",
                type="password",
                info="Token with write access to create repositories"
            )
            base_model_name = gr.Textbox(
                label="Base Model Repository",
                placeholder="microsoft/DialoGPT-medium",
                info="The original model to merge LoRA into"
            )
            lora_model_name = gr.Textbox(
                label="LoRA Adapter Repository",
                placeholder="username/my-lora-adapter",
                info="Repository containing adapter_model.safetensors"
            )
            output_repo_name = gr.Textbox(
                label="Output Repository Name",
                placeholder="username/my-merged-model",
                info="Name for the new merged model repository"
            )

            gr.Markdown("### Advanced Options")
            scale_factor = gr.Slider(
                minimum=0.01,
                maximum=1.99,
                value=1.0,
                step=0.01,
                label="Scale Factor",
                info="Strength of LoRA merge (0 < scale < 2)"
            )
            multiplicative_lora = gr.Checkbox(
                label="Multiplicative LoRA",
                value=False,
                info="Apply multiplicative LoRA instead of additive LoRA"
            )
            inverse_lora = gr.Checkbox(
                label="Inverse Merge",
                value=False,
                info="Apply inverse operation (subtract/invert the LoRA effect)"
            )
        with gr.Column(scale=1):
            gr.Markdown("### Status")
            output_text = gr.Textbox(
                label="Merge Progress & Results",
                lines=20,
                interactive=False,
                show_copy_button=True
            )

    with gr.Row():
        submit_btn = gr.Button("Start LoRA Merge", variant="primary", size="lg")

    submit_btn.click(
        fn=merge_lora_efficient,
        inputs=[hf_token, base_model_name, lora_model_name, output_repo_name,
                scale_factor, multiplicative_lora, inverse_lora],
        outputs=output_text
    )

    gr.Markdown(DETAILS_TEXT)

demo.queue()
demo.launch(show_error=True)