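"""Gradio Space for memory-efficient LoRA merging.

Merges a LoRA adapter into a base model one safetensors shard at a time:
each shard is downloaded, merged, uploaded to the output repository, and
deleted before the next one is processed, so peak memory is bounded by the
largest shard rather than by the full model.
"""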
import gc
import json
import os
import shutil

import gradio as gr
import torch
from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login
from safetensors import safe_open
from safetensors.torch import load_file, save_file

api = HfApi()

def info_fn(text):
    gr.Info(text)


def warning_fn(text):
    gr.Warning(text)

def load_lora_state(lora_model_name):
    """Download and load LoRA adapter weights"""
    temp_lora_dir = "/tmp/lora_adapter"
    os.makedirs(temp_lora_dir, exist_ok=True)

    # Download adapter config
    config_path = hf_hub_download(
        repo_id=lora_model_name,
        filename="adapter_config.json",
        local_dir=temp_lora_dir,
        local_dir_use_symlinks=False
    )
    with open(config_path, 'r') as f:
        lora_config = json.load(f)

    # Download adapter weights (prefer safetensors, fall back to the .bin format)
    try:
        adapter_path = hf_hub_download(
            repo_id=lora_model_name,
            filename="adapter_model.safetensors",
            local_dir=temp_lora_dir,
            local_dir_use_symlinks=False
        )
        lora_state = load_file(adapter_path, device='cpu')
    except Exception:
        adapter_path = hf_hub_download(
            repo_id=lora_model_name,
            filename="adapter_model.bin",
            local_dir=temp_lora_dir,
            local_dir_use_symlinks=False
        )
        lora_state = torch.load(adapter_path, map_location='cpu')

    return lora_state, lora_config, temp_lora_dir

def find_lora_weights(lora_state, key):
    """Find corresponding LoRA A and B weights for a given key"""
    lora_A = None
    lora_B = None

    # Remove the .weight suffix for matching (str.strip('.weight') strips
    # characters, not the suffix, and would mangle keys ending in e.g. 'gate')
    clean_key = key.removesuffix('.weight')

    for lora_key, lora_weight in lora_state.items():
        if clean_key in lora_key:
            if 'lora_A' in lora_key:
                lora_A = lora_weight
            elif 'lora_B' in lora_key:
                lora_B = lora_weight

    # Either both weights were found or neither was
    if (lora_A is None) != (lora_B is None):
        return None, None
    return lora_A, lora_B

def download_and_upload_non_model_files(base_model_name, output_repo_name):
    """Download and upload non-model files (config, tokenizer, etc.)"""
    temp_config_dir = "/tmp/config_files"
    os.makedirs(temp_config_dir, exist_ok=True)
    try:
        # List all files in the repository
        files = list_repo_files(repo_id=base_model_name)

        # Filter non-model files
        non_model_files = [
            f for f in files
            if not (f.startswith('model') and f.endswith('.safetensors'))
        ]

        # Download and upload each non-model file
        for filename in non_model_files:
            if filename.endswith(('.gguf', '.bin')) and 'model' in filename:
                continue  # Skip other model formats
            try:
                file_path = hf_hub_download(
                    repo_id=base_model_name,
                    filename=filename,
                    local_dir=temp_config_dir,
                    local_dir_use_symlinks=False
                )
                # Upload to output repo
                api.upload_file(
                    path_or_fileobj=file_path,
                    path_in_repo=filename,
                    repo_id=output_repo_name,
                    repo_type="model"
                )
            except Exception as e:
                info_fn(f"Skipping {filename}: {e}")
    finally:
        shutil.rmtree(temp_config_dir, ignore_errors=True)

def merge_lora_efficient(hf_token, base_model_name, lora_model_name, output_repo_name,
                         scale_factor, multiplicative_lora, inverse_lora, progress=gr.Progress()):
    temp_lora_dir = None
    try:
        # Validate scale factor
        if not (0 < scale_factor < 2):
            error_msg = "Scale factor must be in the range (0, 2)"
            warning_fn(error_msg)
            return f"✗ Error: {error_msg}"

        login(hf_token)

        progress(0.1, desc="Loading LoRA adapter...")
        info_fn("Loading LoRA adapter...")

        # Load LoRA state (this downloads the adapter)
        lora_state, lora_config, temp_lora_dir = load_lora_state(lora_model_name)

        # Calculate scale with user factor
        base_scale = lora_config['lora_alpha'] / lora_config['r']
        scale = base_scale * scale_factor
        info_fn(f"Using LoRA scale: {scale} (base: {base_scale:.3f} × factor: {scale_factor})")
        progress(0.2, desc="Creating output repository...")

        # Create repository
        try:
            repo_url = api.create_repo(repo_id=output_repo_name, exist_ok=True)
            info_fn(f"Repository created/updated: {repo_url}")
        except Exception as e:
            warning_fn(f"Repository might already exist: {e}")

        progress(0.3, desc="Uploading configuration files...")
        info_fn("Uploading configuration files...")

        # Download and upload non-model files
        download_and_upload_non_model_files(base_model_name, output_repo_name)

        progress(0.4, desc="Finding model shards...")
        info_fn("Finding model shards...")

        # Get list of all safetensors files
        all_files = list_repo_files(repo_id=base_model_name)
        shard_files = [f for f in all_files if f.startswith('model') and f.endswith('.safetensors')]
        if not shard_files:
            raise FileNotFoundError("No model safetensors files found in the repository")
        info_fn(f"Found {len(shard_files)} model shards to process")
        # Determine merge mode
        if multiplicative_lora and inverse_lora:
            merge_mode = "Multiplicative Inverse"
        elif multiplicative_lora:
            merge_mode = "Multiplicative"
        elif inverse_lora:
            merge_mode = "Additive Inverse"
        else:
            merge_mode = "Additive"
        info_fn(f"Merge mode: {merge_mode}")

        merged_tensors = 0
        total_shards = len(shard_files)

        # Process each shard individually
        for i, shard_filename in enumerate(shard_files):
            progress(0.4 + (i / total_shards) * 0.5,
                     desc=f"Processing {shard_filename} ({i+1}/{total_shards})")
            info_fn(f"Processing shard {i+1}/{total_shards}: {shard_filename}")

            # Create temporary directory for this shard only
            temp_shard_dir = f"/tmp/shard_{i}"
            os.makedirs(temp_shard_dir, exist_ok=True)
            try:
                # Download the current shard
                shard_path = hf_hub_download(
                    repo_id=base_model_name,
                    filename=shard_filename,
                    local_dir=temp_shard_dir,
                    local_dir_use_symlinks=False
                )

                # Process the shard
                tensors = {}
                shard_merged_count = 0
                with safe_open(shard_path, framework='pt', device='cpu') as f:
                    # Get metadata if available
                    metadata = f.metadata() if hasattr(f, 'metadata') else {}
                    for key in f.keys():
                        tensor = f.get_tensor(key)

                        # Try to find corresponding LoRA weights
                        lora_A, lora_B = find_lora_weights(lora_state, key)
                        if lora_A is not None and lora_B is not None:
                            info_fn(f"Merging {merge_mode} LoRA weights for {key}")
                            shard_merged_count += 1
                            merged_tensors += 1

                            # Convert to float32 for computation
                            original_dtype = tensor.dtype
                            tensor = tensor.to(torch.float32)
                            lora_delta = scale * lora_B.to(torch.float32) @ lora_A.to(torch.float32)
                            if multiplicative_lora:
                                # Validate dimensions for multiplicative LoRA
                                if lora_delta.shape[0] != lora_delta.shape[1]:
                                    raise ValueError(f"Multiplicative LoRA requires square delta matrix for {key}: got shape {lora_delta.shape}")
                                if lora_delta.shape[-1] != tensor.shape[-2]:
                                    raise ValueError(f"Multiplicative LoRA dimension mismatch for {key}: {lora_delta.shape} vs {tensor.shape}")

                                if inverse_lora:
                                    # Inverse multiplicative: tensor = (I + lora_delta)^(-1) @ tensor
                                    identity = torch.eye(lora_delta.shape[0], device=lora_delta.device, dtype=torch.float32)
                                    inverse_matrix = torch.linalg.inv(identity + lora_delta)
                                    tensor = inverse_matrix @ tensor
                                else:
                                    # Forward multiplicative: tensor = (I + lora_delta) @ tensor
                                    tensor += lora_delta @ tensor
                            else:
                                # Validate dimensions for additive LoRA
                                if lora_delta.shape != tensor.shape:
                                    raise ValueError(f"Additive LoRA dimension mismatch for {key}: {lora_delta.shape} vs {tensor.shape}")

                                if inverse_lora:
                                    # Inverse additive: tensor = tensor - lora_delta
                                    tensor -= lora_delta
                                else:
                                    # Forward additive: tensor = tensor + lora_delta
                                    tensor += lora_delta

                            # Convert back to original dtype
                            tensor = tensor.to(original_dtype)

                            # Clean up intermediate tensors
                            del lora_delta
                            if torch.cuda.is_available():
                                torch.cuda.empty_cache()

                        tensors[key] = tensor
                # Save processed shard to temporary file
                output_shard_path = os.path.join(temp_shard_dir, f"processed_{shard_filename}")
                save_file(tensors, output_shard_path, metadata=metadata)
                info_fn(f"Shard {shard_filename}: Merged {shard_merged_count} tensors")

                # Upload the processed shard
                api.upload_file(
                    path_or_fileobj=output_shard_path,
                    path_in_repo=shard_filename,
                    repo_id=output_repo_name,
                    repo_type="model"
                )

                # Clean up this shard's data
                del tensors
                gc.collect()
            finally:
                # Always clean up the temporary shard directory
                shutil.rmtree(temp_shard_dir, ignore_errors=True)
        progress(1.0, desc="Upload completed!")
        success_msg = f"✓ Successfully merged and uploaded model!\nModel URL: https://huggingface.co/{output_repo_name}\nMerge mode: {merge_mode}\nScale factor: {scale_factor}\nProcessed {total_shards} shards\nMerged {merged_tensors} layers with LoRA weights"
        info_fn("Merge completed successfully!")
        return success_msg

    except Exception as e:
        error_msg = f"✗ Error during merge: {str(e)}"
        warning_fn(error_msg)
        return error_msg

    finally:
        # Cleanup LoRA directory
        if temp_lora_dir and os.path.exists(temp_lora_dir):
            shutil.rmtree(temp_lora_dir, ignore_errors=True)
        gc.collect()

INTRODUCTION_TEXT = """
## Memory-Efficient LoRA Merge
This tool merges LoRA (Low-Rank Adaptation) adapters with base models using a memory-efficient approach that processes model files individually, significantly reducing memory requirements compared to traditional methods.
### Key Features
- **Minimal Memory Usage**: Processes one model shard at a time instead of loading the entire model
- **Streaming Processing**: Downloads → Processes → Uploads → Deletes each shard sequentially
- **Automatic Cleanup**: Temporary files are automatically removed after processing
- **Progress Tracking**: Real-time status updates throughout the merge process
- **Advanced Options**: Multiplicative LoRA, inverse merging, and custom scale factors
"""
DETAILS_TEXT = """
### How It Works
LoRA enables efficient fine-tuning by adding small adapter weights rather than modifying the entire model. This tool supports four merge modes, sketched in the code example below:
- **Additive LoRA**: `W_new = W + scale × B @ A`
- **Additive Inverse**: `W_new = W - scale × B @ A` (removes LoRA effect)
- **Multiplicative LoRA**: `W_new = W + scale × B @ A @ W`
- **Multiplicative Inverse**: `W_new = (I + scale × B @ A)^(-1) @ W`
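
The four updates in plain PyTorch, as a toy sketch with made-up dimensions (not code from this tool; note the multiplicative modes require a square `B @ A`):

```python
import torch

d, r = 8, 2                                   # toy hidden size and LoRA rank
W = torch.randn(d, d)                         # base weight (square so every mode applies)
A, B = torch.randn(r, d), torch.randn(d, r)   # LoRA factors
scale = 1.0

delta = scale * (B @ A)                                  # low-rank update
W_add     = W + delta                                    # Additive
W_add_inv = W - delta                                    # Additive Inverse (undo)
W_mul     = W + delta @ W                                # Multiplicative: (I + delta) @ W
W_mul_inv = torch.linalg.inv(torch.eye(d) + delta) @ W   # Multiplicative Inverse
```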
### Scale Factor
The scale factor (0 < scale < 2) controls the strength of the LoRA merge:
- **1.0**: Full strength (default)
- **0.5**: Half strength
- **1.5**: 150% strength
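
The factor multiplies the adapter's built-in `lora_alpha / r` scaling, so 1.0 reproduces the scale the adapter was trained with. As an illustrative example, an adapter with `lora_alpha = 32` and `r = 16` has a base scale of 2.0, and a factor of 0.5 would give an effective scale of 1.0.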
### Memory Efficiency
- **Traditional approach**: Loads entire model (~15GB+ for 7B parameter models)
- **This approach**: Peak usage determined by largest shard size, not total model size
- **Result**: Enables merging of much larger models on limited hardware
### Example Usage
- **Base Model:** `microsoft/DialoGPT-medium`
- **LoRA Adapter:** `username/my-trained-lora`
- **Output Name:** `username/dialogpt-merged`
### Attribution
This tool builds upon excellent work from the community:
- **Base implementation:** [Weyaxi/merge-lora](https://huggingface.co/spaces/Weyaxi/merge-lora)
- **Memory-efficient method:** [qlora-pipe](https://github.com/tdrussell/qlora-pipe/blob/main/tools/merge_lora.py) by tdrussell
"""

with gr.Blocks(title="Memory-Efficient LoRA Merge", theme=gr.themes.Soft()) as demo:
    gr.Markdown(INTRODUCTION_TEXT)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Configuration")
            hf_token = gr.Textbox(
                label="Hugging Face Token",
                placeholder="hf_...",
                type="password",
                info="Token with write access to create repositories"
            )
            base_model_name = gr.Textbox(
                label="Base Model Repository",
                placeholder="microsoft/DialoGPT-medium",
                info="The original model to merge LoRA into"
            )
            lora_model_name = gr.Textbox(
                label="LoRA Adapter Repository",
                placeholder="username/my-lora-adapter",
                info="Repository containing adapter_model.safetensors"
            )
            output_repo_name = gr.Textbox(
                label="Output Repository Name",
                placeholder="username/my-merged-model",
                info="Name for the new merged model repository"
            )

            gr.Markdown("### Advanced Options")
            scale_factor = gr.Slider(
                minimum=0.01,
                maximum=1.99,
                value=1.0,
                step=0.01,
                label="Scale Factor",
                info="Strength of LoRA merge (0 < scale < 2)"
            )
            multiplicative_lora = gr.Checkbox(
                label="Multiplicative LoRA",
                value=False,
                info="Apply multiplicative LoRA instead of additive LoRA"
            )
            inverse_lora = gr.Checkbox(
                label="Inverse Merge",
                value=False,
                info="Apply inverse operation (subtract/invert the LoRA effect)"
            )

        with gr.Column(scale=1):
            gr.Markdown("### Status")
            output_text = gr.Textbox(
                label="Merge Progress & Results",
                lines=20,
                interactive=False,
                show_copy_button=True
            )

    with gr.Row():
        submit_btn = gr.Button("Start LoRA Merge", variant="primary", size="lg")

    submit_btn.click(
        fn=merge_lora_efficient,
        inputs=[hf_token, base_model_name, lora_model_name, output_repo_name,
                scale_factor, multiplicative_lora, inverse_lora],
        outputs=output_text
    )

    gr.Markdown(DETAILS_TEXT)

demo.queue()
demo.launch(show_error=True)