import gc
import json
import os
import shutil

import gradio as gr
import torch
from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login
from safetensors import safe_open
from safetensors.torch import load_file, save_file

api = HfApi()


def info_fn(text):
    gr.Info(text)


def warning_fn(text):
    gr.Warning(text)


def load_lora_state(lora_model_name):
    """Download and load LoRA adapter weights."""
    temp_lora_dir = "/tmp/lora_adapter"
    os.makedirs(temp_lora_dir, exist_ok=True)

    # Download adapter config
    config_path = hf_hub_download(
        repo_id=lora_model_name,
        filename="adapter_config.json",
        local_dir=temp_lora_dir,
        local_dir_use_symlinks=False
    )
    with open(config_path, 'r') as f:
        lora_config = json.load(f)

    # Download adapter weights (prefer safetensors, fall back to the .bin format)
    try:
        adapter_path = hf_hub_download(
            repo_id=lora_model_name,
            filename="adapter_model.safetensors",
            local_dir=temp_lora_dir,
            local_dir_use_symlinks=False
        )
        lora_state = load_file(adapter_path, device='cpu')
    except Exception:
        adapter_path = hf_hub_download(
            repo_id=lora_model_name,
            filename="adapter_model.bin",
            local_dir=temp_lora_dir,
            local_dir_use_symlinks=False
        )
        lora_state = torch.load(adapter_path, map_location='cpu')

    return lora_state, lora_config, temp_lora_dir


def find_lora_weights(lora_state, key):
    """Find the corresponding LoRA A and B weights for a given base-model key."""
    lora_A = None
    lora_B = None

    # Remove the ".weight" suffix for matching (str.strip would remove a character
    # set rather than the suffix and could mangle key names)
    clean_key = key.removesuffix('.weight')

    for lora_key, lora_weight in lora_state.items():
        if clean_key in lora_key:
            if 'lora_A' in lora_key:
                lora_A = lora_weight
            elif 'lora_B' in lora_key:
                lora_B = lora_weight

    # Either both matrices are found or the key has no LoRA weights
    if (lora_A is None) != (lora_B is None):
        return None, None

    return lora_A, lora_B
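# Illustrative sketch (not called by the app) of how find_lora_weights pairs a base
# weight with its adapter matrices. The key names and shapes below are hypothetical
# examples of the PEFT-style naming typically found in adapter_model.safetensors.
def _find_lora_weights_example():
    lora_state = {
        # lora_A is (r, in_features), lora_B is (out_features, r)
        "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight": torch.randn(8, 64),
        "base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight": torch.randn(64, 8),
    }
    # "model.layers.0.self_attn.q_proj" is a substring of both adapter keys once
    # ".weight" is removed, so both matrices are returned.
    return find_lora_weights(lora_state, "model.layers.0.self_attn.q_proj.weight")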
def download_and_upload_non_model_files(base_model_name, output_repo_name):
    """Download and upload non-model files (config, tokenizer, etc.)."""
    temp_config_dir = "/tmp/config_files"
    os.makedirs(temp_config_dir, exist_ok=True)

    try:
        # List all files in the repository
        files = list_repo_files(repo_id=base_model_name)

        # Filter out the model shards
        non_model_files = [
            f for f in files
            if not (f.startswith('model') and f.endswith('.safetensors'))
        ]

        # Download and upload each non-model file
        for filename in non_model_files:
            if filename.endswith(('.gguf', '.bin')) and 'model' in filename:
                continue  # Skip other model formats

            try:
                file_path = hf_hub_download(
                    repo_id=base_model_name,
                    filename=filename,
                    local_dir=temp_config_dir,
                    local_dir_use_symlinks=False
                )
                # Upload to output repo
                api.upload_file(
                    path_or_fileobj=file_path,
                    path_in_repo=filename,
                    repo_id=output_repo_name,
                    repo_type="model"
                )
            except Exception as e:
                info_fn(f"Skipping {filename}: {e}")
    finally:
        shutil.rmtree(temp_config_dir, ignore_errors=True)


def merge_lora_efficient(hf_token, base_model_name, lora_model_name, output_repo_name,
                         scale_factor, multiplicative_lora, inverse_lora,
                         progress=gr.Progress()):
    temp_lora_dir = None
    try:
        # Validate scale factor
        if not (0 < scale_factor < 2):
            error_msg = "Scale factor must be in the range (0, 2)"
            warning_fn(error_msg)
            return f"✗ Error: {error_msg}"

        login(hf_token)

        progress(0.1, desc="Loading LoRA adapter...")
        info_fn("Loading LoRA adapter...")

        # Load LoRA state (this downloads the adapter)
        lora_state, lora_config, temp_lora_dir = load_lora_state(lora_model_name)

        # Calculate scale with user factor
        base_scale = lora_config['lora_alpha'] / lora_config['r']
        scale = base_scale * scale_factor
        info_fn(f"Using LoRA scale: {scale} (base: {base_scale:.3f} × factor: {scale_factor})")

        progress(0.2, desc="Creating output repository...")

        # Create repository
        try:
            repo_url = api.create_repo(repo_id=output_repo_name, exist_ok=True)
            info_fn(f"Repository created/updated: {repo_url}")
        except Exception as e:
            warning_fn(f"Repository might already exist: {e}")

        progress(0.3, desc="Uploading configuration files...")
        info_fn("Uploading configuration files...")

        # Download and upload non-model files
        download_and_upload_non_model_files(base_model_name, output_repo_name)

        progress(0.4, desc="Finding model shards...")
        info_fn("Finding model shards...")

        # Get list of all safetensors shard files
        all_files = list_repo_files(repo_id=base_model_name)
        shard_files = [f for f in all_files if f.startswith('model') and f.endswith('.safetensors')]

        if not shard_files:
            raise FileNotFoundError("No model safetensors files found in the repository")

        info_fn(f"Found {len(shard_files)} model shards to process")

        # Determine merge mode
        if multiplicative_lora and inverse_lora:
            merge_mode = "Multiplicative Inverse"
        elif multiplicative_lora:
            merge_mode = "Multiplicative"
        elif inverse_lora:
            merge_mode = "Additive Inverse"
        else:
            merge_mode = "Additive"

        info_fn(f"Merge mode: {merge_mode}")

        merged_tensors = 0
        total_shards = len(shard_files)

        # Process each shard individually
        for i, shard_filename in enumerate(shard_files):
            progress(0.4 + (i / total_shards) * 0.5,
                     desc=f"Processing {shard_filename} ({i+1}/{total_shards})")
            info_fn(f"Processing shard {i+1}/{total_shards}: {shard_filename}")

            # Create a temporary directory for this shard only
            temp_shard_dir = f"/tmp/shard_{i}"
            os.makedirs(temp_shard_dir, exist_ok=True)

            try:
                # Download the current shard
                shard_path = hf_hub_download(
                    repo_id=base_model_name,
                    filename=shard_filename,
                    local_dir=temp_shard_dir,
                    local_dir_use_symlinks=False
                )

                # Process the shard
                tensors = {}
                shard_merged_count = 0

                with safe_open(shard_path, framework='pt', device='cpu') as f:
                    # Get metadata if available
                    metadata = f.metadata() if hasattr(f, 'metadata') else {}

                    for key in f.keys():
                        tensor = f.get_tensor(key)

                        # Try to find corresponding LoRA weights
                        lora_A, lora_B = find_lora_weights(lora_state, key)

                        if lora_A is not None and lora_B is not None:
                            info_fn(f"Merging {merge_mode} LoRA weights for {key}")
                            shard_merged_count += 1
                            merged_tensors += 1

                            # Convert to float32 for computation
                            original_dtype = tensor.dtype
                            tensor = tensor.to(torch.float32)
                            lora_delta = scale * lora_B.to(torch.float32) @ lora_A.to(torch.float32)

                            if multiplicative_lora:
                                # Validate dimensions for multiplicative LoRA
                                if lora_delta.shape[0] != lora_delta.shape[1]:
                                    raise ValueError(
                                        f"Multiplicative LoRA requires square delta matrix for {key}: "
                                        f"got shape {lora_delta.shape}"
                                    )
                                if lora_delta.shape[-1] != tensor.shape[-2]:
                                    raise ValueError(
                                        f"Multiplicative LoRA dimension mismatch for {key}: "
                                        f"{lora_delta.shape} vs {tensor.shape}"
                                    )

                                if inverse_lora:
                                    # Inverse multiplicative: tensor = (I + lora_delta)^(-1) @ tensor
                                    identity = torch.eye(lora_delta.shape[0],
                                                         device=lora_delta.device,
                                                         dtype=torch.float32)
                                    inverse_matrix = torch.linalg.inv(identity + lora_delta)
                                    tensor = inverse_matrix @ tensor
                                else:
                                    # Forward multiplicative: tensor = (I + lora_delta) @ tensor
                                    tensor += lora_delta @ tensor
                            else:
                                # Validate dimensions for additive LoRA
                                if lora_delta.shape != tensor.shape:
                                    raise ValueError(
                                        f"Additive LoRA dimension mismatch for {key}: "
                                        f"{lora_delta.shape} vs {tensor.shape}"
                                    )

                                if inverse_lora:
                                    # Inverse additive: tensor = tensor - lora_delta
                                    tensor -= lora_delta
                                else:
                                    # Forward additive: tensor = tensor + lora_delta
                                    tensor += lora_delta

                            # Convert back to original dtype
                            tensor = tensor.to(original_dtype)

                            # Clean up intermediate tensors
                            del lora_delta
                            if torch.cuda.is_available():
                                torch.cuda.empty_cache()

                        tensors[key] = tensor

                # Save the processed shard to a temporary file
                output_shard_path = os.path.join(temp_shard_dir, f"processed_{shard_filename}")
                save_file(tensors, output_shard_path, metadata=metadata)

                info_fn(f"Shard {shard_filename}: Merged {shard_merged_count} tensors")

                # Upload the processed shard
                api.upload_file(
                    path_or_fileobj=output_shard_path,
                    path_in_repo=shard_filename,
                    repo_id=output_repo_name,
                    repo_type="model"
                )

                # Clean up this shard's data
                del tensors
                gc.collect()

            finally:
                # Always clean up the temporary shard directory
                shutil.rmtree(temp_shard_dir, ignore_errors=True)

        progress(1.0, desc="Upload completed!")

        success_msg = (
            f"✓ Successfully merged and uploaded model!\n"
            f"Model URL: https://huggingface.co/{output_repo_name}\n"
            f"Merge mode: {merge_mode}\n"
            f"Scale factor: {scale_factor}\n"
            f"Processed {total_shards} shards\n"
            f"Merged {merged_tensors} layers with LoRA weights"
        )
        info_fn("Merge completed successfully!")
        return success_msg

    except Exception as e:
        error_msg = f"✗ Error during merge: {str(e)}"
        warning_fn(error_msg)
        return error_msg

    finally:
        # Clean up the LoRA adapter directory
        if temp_lora_dir and os.path.exists(temp_lora_dir):
            shutil.rmtree(temp_lora_dir, ignore_errors=True)
        gc.collect()
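# Illustrative sketch (not called by the app) of the per-tensor arithmetic performed
# inside the shard loop above, on small random matrices with hypothetical shapes
# (W: out×in, lora_A: r×in, lora_B: out×r). The additive inverse mode subtracts the
# same delta, so it recovers the original weight up to floating-point error.
def _additive_merge_example():
    W = torch.randn(16, 32)
    lora_A = torch.randn(4, 32)
    lora_B = torch.randn(16, 4)
    scale = (8 / 4) * 1.0               # (lora_alpha / r) × scale_factor
    delta = scale * lora_B @ lora_A     # same shape as W
    W_merged = W + delta                # additive merge
    W_restored = W_merged - delta       # additive inverse merge
    assert torch.allclose(W, W_restored, atol=1e-5)
    return W_merged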
INTRODUCTION_TEXT = """
## Memory-Efficient LoRA Merge

This tool merges LoRA (Low-Rank Adaptation) adapters with base models using a memory-efficient approach that processes model files individually, significantly reducing memory requirements compared to traditional methods.

### Key Features
- **Minimal Memory Usage**: Processes one model shard at a time instead of loading the entire model
- **Streaming Processing**: Downloads → Processes → Uploads → Deletes each shard sequentially
- **Automatic Cleanup**: Temporary files are automatically removed after processing
- **Progress Tracking**: Real-time status updates throughout the merge process
- **Advanced Options**: Multiplicative LoRA, inverse merging, and custom scale factors
"""

DETAILS_TEXT = """
### How It Works
LoRA enables efficient fine-tuning by adding small adapter weights rather than modifying the entire model. This tool supports four merge modes:

- **Additive LoRA**: `W_new = W + scale × B @ A`
- **Additive Inverse**: `W_new = W - scale × B @ A` (removes LoRA effect)
- **Multiplicative LoRA**: `W_new = W + scale × B @ A @ W`
- **Multiplicative Inverse**: `W_new = (I + scale × B @ A)^(-1) @ W`

### Scale Factor
The scale factor (0 < scale < 2) controls the strength of the LoRA merge:
- **1.0**: Full strength (default)
- **0.5**: Half strength
- **1.5**: 150% strength

### Memory Efficiency
- **Traditional approach**: Loads the entire model (~15GB+ for 7B-parameter models)
- **This approach**: Peak usage is determined by the largest shard size, not the total model size
- **Result**: Enables merging of much larger models on limited hardware

### Example Usage
- **Base Model:** `microsoft/DialoGPT-medium`
- **LoRA Adapter:** `username/my-trained-lora`
- **Output Name:** `username/dialogpt-merged`

### Attribution
This tool builds upon excellent work from the community:
- **Base implementation:** [Weyaxi/merge-lora](https://huggingface.co/spaces/Weyaxi/merge-lora)
- **Memory-efficient method:** [qlora-pipe](https://github.com/tdrussell/qlora-pipe/blob/main/tools/merge_lora.py) by tdrussell
"""
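# Illustrative sketch (not called by the app) of the "Multiplicative Inverse" mode
# listed above: applying (I + delta) and then its inverse recovers the original weight
# up to floating-point error. Shapes are hypothetical (W: out×in, delta: out×out); the
# small 0.01 factor keeps I + delta well-conditioned for this illustration.
def _multiplicative_inverse_example():
    W = torch.randn(16, 32)
    lora_A = torch.randn(4, 16)
    lora_B = torch.randn(16, 4)
    delta = 0.01 * (lora_B @ lora_A)    # square (out×out), as multiplicative LoRA requires
    identity = torch.eye(16)
    W_merged = (identity + delta) @ W                             # multiplicative merge
    W_restored = torch.linalg.inv(identity + delta) @ W_merged    # multiplicative inverse
    assert torch.allclose(W, W_restored, atol=1e-4)
    return W_restored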
with gr.Blocks(title="Memory-Efficient LoRA Merge", theme=gr.themes.Soft()) as demo:
    gr.Markdown(INTRODUCTION_TEXT)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Configuration")

            hf_token = gr.Textbox(
                label="Hugging Face Token",
                placeholder="hf_...",
                type="password",
                info="Token with write access to create repositories"
            )
            base_model_name = gr.Textbox(
                label="Base Model Repository",
                placeholder="microsoft/DialoGPT-medium",
                info="The original model to merge LoRA into"
            )
            lora_model_name = gr.Textbox(
                label="LoRA Adapter Repository",
                placeholder="username/my-lora-adapter",
                info="Repository containing adapter_model.safetensors"
            )
            output_repo_name = gr.Textbox(
                label="Output Repository Name",
                placeholder="username/my-merged-model",
                info="Name for the new merged model repository"
            )

            gr.Markdown("### Advanced Options")

            scale_factor = gr.Slider(
                minimum=0.01,
                maximum=1.99,
                value=1.0,
                step=0.01,
                label="Scale Factor",
                info="Strength of LoRA merge (0 < scale < 2)"
            )
            multiplicative_lora = gr.Checkbox(
                label="Multiplicative LoRA",
                value=False,
                info="Apply multiplicative LoRA instead of additive LoRA"
            )
            inverse_lora = gr.Checkbox(
                label="Inverse Merge",
                value=False,
                info="Apply inverse operation (subtract/invert the LoRA effect)"
            )

        with gr.Column(scale=1):
            gr.Markdown("### Status")

            output_text = gr.Textbox(
                label="Merge Progress & Results",
                lines=20,
                interactive=False,
                show_copy_button=True
            )

    with gr.Row():
        submit_btn = gr.Button("Start LoRA Merge", variant="primary", size="lg")

    submit_btn.click(
        fn=merge_lora_efficient,
        inputs=[hf_token, base_model_name, lora_model_name, output_repo_name,
                scale_factor, multiplicative_lora, inverse_lora],
        outputs=output_text
    )

    gr.Markdown(DETAILS_TEXT)

demo.queue()
demo.launch(show_error=True)