# Vis_Diff / app.py
import io
import gc
import os
import json
import struct
import torch
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import gradio as gr
import PIL.Image
from transformers import AutoModelForCausalLM, AutoConfig
from huggingface_hub import hf_hub_download, hf_hub_url, snapshot_download
from huggingface_hub.utils import build_hf_headers
from safetensors import safe_open
import requests
# Set style for matplotlib
sns.set_theme(style="whitegrid")
# Cache for metadata only
_metadata_cache = {}
def calculate_weight_diff(base_weight, chat_weight):
"""Calculates the mean absolute difference between two tensors."""
b_w = base_weight.float()
c_w = chat_weight.float()
result = torch.abs(b_w - c_w).mean().item()
del b_w, c_w
return result
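# Example of the metric (hypothetical tensors, not part of the app flow):
#   a = torch.zeros(4, 4)
#   b = torch.full((4, 4), 0.01)
#   calculate_weight_diff(a, b)  # -> ~0.01 (mean absolute element-wise diff, computed in fp32)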
def get_safetensor_index(repo_id, token=None):
"""Download and parse the safetensors index."""
cache_key = f"{repo_id}_index"
if cache_key in _metadata_cache:
return _metadata_cache[cache_key]
try:
index_path = hf_hub_download(repo_id, "model.safetensors.index.json", token=token)
with open(index_path, 'r') as f:
index_data = json.load(f)
weight_map = index_data.get("weight_map", {})
_metadata_cache[cache_key] = weight_map
return weight_map
except Exception:
_metadata_cache[cache_key] = None
return None
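# The weight map from the index file maps each tensor name to the shard that
# contains it, e.g. (illustrative):
#   {"model.layers.0.self_attn.q_proj.weight": "model-00001-of-00030.safetensors", ...}
# Single-file models have no index file; callers below fall back to "model.safetensors".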
# =============================================================================
# STREAMING MODE (Ultra Low Memory - No disk usage)
# =============================================================================
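# Safetensors file layout (per the format spec): the first 8 bytes are a
# little-endian uint64 giving the length of the JSON header; the JSON header
# follows, mapping tensor names to {dtype, shape, data_offsets}; data_offsets
# are byte ranges relative to the start of the data section, which begins
# immediately after the header. Two small range requests are therefore enough
# to locate and fetch any single tensor without downloading the whole file.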
def get_safetensor_header(repo_id, filename, token=None):
"""Fetch only the header of a safetensor file using HTTP range request."""
cache_key = f"{repo_id}_{filename}_header"
if cache_key in _metadata_cache:
return _metadata_cache[cache_key]
url = hf_hub_url(repo_id, filename)
headers = build_hf_headers(token=token)
# First, get the header size (first 8 bytes)
headers["Range"] = "bytes=0-7"
response = requests.get(url, headers=headers)
response.raise_for_status()
header_size = struct.unpack('<Q', response.content)[0]
# Now get the header JSON
headers["Range"] = f"bytes=8-{8 + header_size - 1}"
response = requests.get(url, headers=headers)
response.raise_for_status()
header = json.loads(response.content.decode('utf-8'))
result = {"header": header, "header_size": header_size, "data_offset": 8 + header_size}
_metadata_cache[cache_key] = result
return result
def stream_tensor_from_safetensor(repo_id, filename, tensor_name, token=None):
"""Stream a specific tensor using HTTP range requests - minimal memory usage."""
header_info = get_safetensor_header(repo_id, filename, token)
header = header_info["header"]
data_offset = header_info["data_offset"]
if tensor_name not in header:
raise KeyError(f"Tensor {tensor_name} not found in {filename}")
tensor_info = header[tensor_name]
dtype_str = tensor_info["dtype"]
shape = tensor_info["shape"]
offsets = tensor_info["data_offsets"]
start_offset = offsets[0] + data_offset
end_offset = offsets[1] + data_offset - 1
    dtype_map = {
        "F64": torch.float64,
        "F32": torch.float32,
        "F16": torch.float16,
        "BF16": torch.bfloat16,
        "I64": torch.int64,
        "I32": torch.int32,
        "I16": torch.int16,
        "I8": torch.int8,
        "U8": torch.uint8,
        "BOOL": torch.bool,
    }
    if dtype_str not in dtype_map:
        # Fail loudly rather than silently reinterpreting bytes as float32
        raise ValueError(f"Unsupported dtype {dtype_str} for tensor {tensor_name}")
    torch_dtype = dtype_map[dtype_str]
url = hf_hub_url(repo_id, filename)
headers = build_hf_headers(token=token)
headers["Range"] = f"bytes={start_offset}-{end_offset}"
response = requests.get(url, headers=headers)
response.raise_for_status()
tensor = torch.frombuffer(bytearray(response.content), dtype=torch_dtype).reshape(shape).clone()
del response
return tensor
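# Illustrative call (repo, shard, and tensor names are hypothetical):
#   t = stream_tensor_from_safetensor(
#       "org/model", "model-00001-of-00002.safetensors",
#       "model.layers.0.self_attn.q_proj.weight", token=None)
# Only the bytes for that single tensor travel over the network.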
def load_tensor_streaming(repo_id, tensor_name, token=None):
"""Load a specific tensor using streaming - minimal memory."""
weight_map = get_safetensor_index(repo_id, token)
if weight_map is not None:
if tensor_name not in weight_map:
raise KeyError(f"Tensor {tensor_name} not found in weight map for {repo_id}")
filename = weight_map[tensor_name]
else:
filename = "model.safetensors"
return stream_tensor_from_safetensor(repo_id, filename, tensor_name, token)
def calculate_layer_diffs_streaming(base_repo, chat_repo, token=None, progress=None):
"""Ultra memory-efficient: streams individual tensors via HTTP range requests."""
global _metadata_cache
_metadata_cache = {}
print("Fetching model configuration...")
base_config = AutoConfig.from_pretrained(base_repo, token=token, trust_remote_code=True)
num_layers = base_config.num_hidden_layers
components_to_track = [
('input_layernorm', 'model.layers.{}.input_layernorm.weight'),
('self_attn_q_proj', 'model.layers.{}.self_attn.q_proj.weight'),
('self_attn_k_proj', 'model.layers.{}.self_attn.k_proj.weight'),
('self_attn_v_proj', 'model.layers.{}.self_attn.v_proj.weight'),
('self_attn_o_proj', 'model.layers.{}.self_attn.o_proj.weight'),
('post_attention_layernorm', 'model.layers.{}.post_attention_layernorm.weight'),
('mlp_gate_proj', 'model.layers.{}.mlp.gate_proj.weight'),
('mlp_up_proj', 'model.layers.{}.mlp.up_proj.weight'),
('mlp_down_proj', 'model.layers.{}.mlp.down_proj.weight'),
]
layer_diffs = []
total_ops = num_layers * len(components_to_track)
current_op = 0
print(f"Processing {num_layers} layers in streaming mode...")
get_safetensor_index(base_repo, token)
get_safetensor_index(chat_repo, token)
for layer_idx in range(num_layers):
layer_data = {}
for name, pattern in components_to_track:
tensor_name = pattern.format(layer_idx)
try:
base_tensor = load_tensor_streaming(base_repo, tensor_name, token)
chat_tensor = load_tensor_streaming(chat_repo, tensor_name, token)
diff = calculate_weight_diff(base_tensor, chat_tensor)
layer_data[name] = diff
del base_tensor
del chat_tensor
            except Exception as e:
                print(f"Warning: Could not load {tensor_name}: {e}")
                layer_data[name] = 0.0
current_op += 1
if progress is not None:
progress(current_op / total_ops, desc=f"Layer {layer_idx + 1}/{num_layers}: {name}")
layer_diffs.append(layer_data)
print(f"Completed layer {layer_idx + 1}/{num_layers}")
gc.collect()
_metadata_cache = {}
gc.collect()
return layer_diffs
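# The returned structure is a list with one dict per layer, e.g. (illustrative values):
#   [{"input_layernorm": 0.0012, "self_attn_q_proj": 0.0087, ...}, ...]
# which the visualization functions below consume directly.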
# =============================================================================
# DISK CACHE MODE (Low Memory - Uses disk storage)
# =============================================================================
def download_model_safetensors(repo_id, token=None, progress_callback=None):
"""Download all safetensor files for a model to disk cache."""
print(f"Downloading safetensor files for {repo_id}...")
# Download the entire model's safetensors files
local_dir = snapshot_download(
repo_id,
token=token,
allow_patterns=["*.safetensors", "*.json"],
ignore_patterns=["*.bin", "*.pt", "*.ckpt"],
)
return local_dir
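# snapshot_download returns the local cache directory and reuses any files
# already present in the Hugging Face cache, so repeated runs against the same
# repo do not re-download the shards.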
def get_local_safetensor_files(local_dir):
"""Get list of safetensor files in local directory."""
safetensor_files = []
for f in os.listdir(local_dir):
if f.endswith('.safetensors'):
safetensor_files.append(os.path.join(local_dir, f))
return safetensor_files
def get_local_weight_map(local_dir):
"""Get weight map from local index file."""
index_path = os.path.join(local_dir, "model.safetensors.index.json")
if os.path.exists(index_path):
with open(index_path, 'r') as f:
index_data = json.load(f)
return index_data.get("weight_map", {})
return None
def load_tensor_from_disk(local_dir, tensor_name, weight_map=None):
"""Load a specific tensor from disk-cached safetensor files."""
if weight_map is not None:
if tensor_name not in weight_map:
raise KeyError(f"Tensor {tensor_name} not found in weight map")
filename = weight_map[tensor_name]
file_path = os.path.join(local_dir, filename)
else:
# Single file model
file_path = os.path.join(local_dir, "model.safetensors")
with safe_open(file_path, framework="pt", device="cpu") as f:
if tensor_name not in f.keys():
raise KeyError(f"Tensor {tensor_name} not found in {file_path}")
tensor = f.get_tensor(tensor_name).clone()
return tensor
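# safe_open reads the file lazily, so get_tensor only materializes the bytes of
# the requested tensor; the .clone() detaches the result from the underlying
# file mapping before the handle is closed.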
def calculate_layer_diffs_disk_cache(base_repo, chat_repo, token=None, progress=None):
"""Disk cache mode: downloads full model to disk, loads tensors one at a time."""
print("=" * 60)
print("DISK CACHE MODE")
print("Step 1: Downloading model files to disk cache...")
print("=" * 60)
if progress:
progress(0.05, desc="Downloading base model to disk...")
base_local_dir = download_model_safetensors(base_repo, token)
print(f"Base model cached at: {base_local_dir}")
if progress:
progress(0.15, desc="Downloading chat model to disk...")
chat_local_dir = download_model_safetensors(chat_repo, token)
print(f"Chat model cached at: {chat_local_dir}")
# Get weight maps
base_weight_map = get_local_weight_map(base_local_dir)
chat_weight_map = get_local_weight_map(chat_local_dir)
# Get config
print("\nFetching model configuration...")
base_config = AutoConfig.from_pretrained(base_repo, token=token, trust_remote_code=True)
num_layers = base_config.num_hidden_layers
components_to_track = [
('input_layernorm', 'model.layers.{}.input_layernorm.weight'),
('self_attn_q_proj', 'model.layers.{}.self_attn.q_proj.weight'),
('self_attn_k_proj', 'model.layers.{}.self_attn.k_proj.weight'),
('self_attn_v_proj', 'model.layers.{}.self_attn.v_proj.weight'),
('self_attn_o_proj', 'model.layers.{}.self_attn.o_proj.weight'),
('post_attention_layernorm', 'model.layers.{}.post_attention_layernorm.weight'),
('mlp_gate_proj', 'model.layers.{}.mlp.gate_proj.weight'),
('mlp_up_proj', 'model.layers.{}.mlp.up_proj.weight'),
('mlp_down_proj', 'model.layers.{}.mlp.down_proj.weight'),
]
layer_diffs = []
total_ops = num_layers * len(components_to_track)
current_op = 0
print("=" * 60)
print(f"Step 2: Processing {num_layers} layers from disk cache...")
print("=" * 60)
for layer_idx in range(num_layers):
layer_data = {}
for name, pattern in components_to_track:
tensor_name = pattern.format(layer_idx)
try:
# Load from disk cache - only this tensor goes into RAM
base_tensor = load_tensor_from_disk(base_local_dir, tensor_name, base_weight_map)
chat_tensor = load_tensor_from_disk(chat_local_dir, tensor_name, chat_weight_map)
diff = calculate_weight_diff(base_tensor, chat_tensor)
layer_data[name] = diff
# Free RAM immediately
del base_tensor
del chat_tensor
            except Exception as e:
                print(f"Warning: Could not load {tensor_name}: {e}")
                layer_data[name] = 0.0
current_op += 1
if progress is not None:
# Scale progress from 0.2 to 0.9 for the processing phase
scaled_progress = 0.2 + (current_op / total_ops) * 0.7
progress(scaled_progress, desc=f"Layer {layer_idx + 1}/{num_layers}: {name}")
layer_diffs.append(layer_data)
print(f"Completed layer {layer_idx + 1}/{num_layers}")
# Garbage collect every few layers
if layer_idx % 5 == 0:
gc.collect()
gc.collect()
print("\nProcessing complete!")
return layer_diffs
# =============================================================================
# STANDARD MODE (Full models in memory)
# =============================================================================
def calculate_layer_diffs_standard(base_model, chat_model, progress=None):
"""Standard mode: loads full models into memory."""
layer_diffs = []
layers = list(zip(base_model.model.layers, chat_model.model.layers))
total_layers = len(layers)
components_to_track = [
('input_layernorm', lambda l: l.input_layernorm.weight),
('self_attn_q_proj', lambda l: l.self_attn.q_proj.weight),
('self_attn_k_proj', lambda l: l.self_attn.k_proj.weight),
('self_attn_v_proj', lambda l: l.self_attn.v_proj.weight),
('self_attn_o_proj', lambda l: l.self_attn.o_proj.weight),
('post_attention_layernorm', lambda l: l.post_attention_layernorm.weight),
('mlp_gate_proj', lambda l: l.mlp.gate_proj.weight),
('mlp_up_proj', lambda l: l.mlp.up_proj.weight),
('mlp_down_proj', lambda l: l.mlp.down_proj.weight),
]
print("Calculating differences (standard mode)...")
for idx, (base_layer, chat_layer) in enumerate(layers):
layer_data = {}
for name, getter in components_to_track:
try:
val = calculate_weight_diff(getter(base_layer), getter(chat_layer))
layer_data[name] = val
except AttributeError:
layer_data[name] = 0.0
layer_diffs.append(layer_data)
if progress is not None:
progress((idx + 1) / total_layers, desc=f"Processing layer {idx + 1}/{total_layers}")
return layer_diffs
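# All three modes (streaming, disk cache, standard) return the same
# list-of-dicts structure, so the visualization code below is mode-agnostic.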
# =============================================================================
# VISUALIZATION
# =============================================================================
def visualize_2d_heatmap(layer_diffs, base_model_name, chat_model_name):
"""Generates the static 2D Heatmap image."""
if not layer_diffs:
return None
num_layers = len(layer_diffs)
components = list(layer_diffs[0].keys())
num_components = len(components)
height = max(8, num_layers / 6)
width = max(20, num_components * 2.5)
if num_components > 6:
nrows = 2
ncols = (num_components + 1) // 2
else:
nrows = 1
ncols = num_components
fig, axs = plt.subplots(nrows, ncols, figsize=(width, height * (1.2 if nrows > 1 else 1)))
axs = axs.flatten() if num_components > 1 else [axs]
fig.suptitle(f"Weight Differences: {base_model_name} vs {chat_model_name}", fontsize=16, y=0.98)
tick_font_size = max(6, min(10, 300 / num_layers))
for i, component in enumerate(components):
data = [[row[component]] for row in layer_diffs]
sns.heatmap(data,
annot=True,
fmt=".6f",
cmap="viridis",
ax=axs[i],
cbar=False,
annot_kws={'size': tick_font_size * 0.8})
axs[i].set_title(component, fontsize=12, fontweight='bold')
        # Center the layer-index tick labels on the heatmap cells
        axs[i].set_yticks([y + 0.5 for y in range(num_layers)])
        axs[i].set_yticklabels(range(num_layers), fontsize=tick_font_size)
axs[i].set_xticks([])
axs[i].invert_yaxis()
for j in range(i + 1, len(axs)):
fig.delaxes(axs[j])
plt.tight_layout(rect=[0, 0, 1, 0.96])
buf = io.BytesIO()
fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
buf.seek(0)
plt.close(fig)
return PIL.Image.open(buf)
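# The heatmap is rendered to an in-memory PNG and returned as a PIL image so
# Gradio's gr.Image component can display it directly.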
def generate_3d_plot(layer_diffs):
"""Generates an interactive 3D Surface plot as a Plotly Figure."""
if not layer_diffs:
return None
df = pd.DataFrame(layer_diffs)
x_labels = df.columns.tolist()
y_labels = df.index.tolist()
z_data = df.values
fig = go.Figure(data=[go.Surface(z=z_data, x=x_labels, y=y_labels, colorscale='Viridis')])
fig.update_layout(
title='3D Landscape of Weight Differences',
scene=dict(
xaxis_title='Model Components',
yaxis_title='Layer Index',
zaxis_title='Mean Weight Diff',
xaxis=dict(tickangle=45),
),
autosize=True,
height=700,
margin=dict(l=65, r=50, b=65, t=90)
)
return fig
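# If needed, the Plotly figure can also be exported outside the UI, e.g.
# fig.write_html("diff_surface.html") for a standalone interactive page
# (illustrative filename).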
# =============================================================================
# MAIN PROCESSING
# =============================================================================
def process_models(base_name, chat_name, hf_token, memory_mode, progress=gr.Progress()):
if not base_name or not chat_name:
raise gr.Error("Please provide both model names.")
token = hf_token if hf_token else None
try:
if memory_mode == "streaming":
# Streaming mode - ultra low memory, no disk
progress(0, desc="Starting streaming mode (ultra low memory)...")
diffs = calculate_layer_diffs_streaming(
base_name,
chat_name,
token=token,
progress=progress
)
elif memory_mode == "disk_cache":
# Disk cache mode - downloads to disk, loads tensors one at a time
progress(0, desc="Starting disk cache mode...")
diffs = calculate_layer_diffs_disk_cache(
base_name,
chat_name,
token=token,
progress=progress
)
else:
# Standard mode - full models in memory
progress(0, desc=f"Loading {base_name}...")
print(f"Loading {base_name}...")
base_model = AutoModelForCausalLM.from_pretrained(
base_name,
torch_dtype=torch.bfloat16,
token=token,
trust_remote_code=True,
low_cpu_mem_usage=True
)
progress(0.3, desc=f"Loading {chat_name}...")
print(f"Loading {chat_name}...")
chat_model = AutoModelForCausalLM.from_pretrained(
chat_name,
torch_dtype=torch.bfloat16,
token=token,
trust_remote_code=True,
low_cpu_mem_usage=True
)
progress(0.5, desc="Calculating differences...")
diffs = calculate_layer_diffs_standard(base_model, chat_model, progress=None)
del base_model
del chat_model
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
progress(0.9, desc="Generating visualizations...")
img_2d = visualize_2d_heatmap(diffs, base_name, chat_name)
plot_3d = generate_3d_plot(diffs)
progress(1.0, desc="Complete!")
return img_2d, plot_3d
except Exception as e:
import traceback
traceback.print_exc()
raise gr.Error(f"Error processing models: {str(e)}")
# =============================================================================
# GRADIO UI
# =============================================================================
with gr.Blocks(title="Model Diff Visualizer") as demo:
gr.Markdown("# 🧠 LLM Weight Difference Visualizer")
gr.Markdown("Compare the weights of a Base model vs. its Instruct/Chat tuned version layer by layer.")
with gr.Row():
with gr.Column(scale=1):
base_input = gr.Textbox(
label="Base Model Name",
placeholder="e.g., meta-llama/Llama-3.3-70B-Instruct"
)
chat_input = gr.Textbox(
label="Chat/Tuned Model Name",
placeholder="e.g., CrucibleLab/L3.3-70B-Loki-V2.0"
)
token_input = gr.Textbox(
label="Hugging Face Token (Optional)",
type="password",
placeholder="hf_..."
)
memory_mode = gr.Radio(
label="Memory Mode",
choices=[
("πŸš€ Standard (Fast, High RAM)", "standard"),
("πŸ’Ύ Disk Cache (Medium Speed, Low RAM, Uses Disk)", "disk_cache"),
("🐒 Streaming (Slow, Ultra Low RAM, No Disk)", "streaming"),
],
value="standard",
info="Choose based on your available RAM and disk space"
)
with gr.Accordion("Memory Mode Details", open=False):
gr.Markdown("""
### πŸš€ Standard Mode
- **RAM Usage:** ~2x model size (e.g., ~280GB for 70B models)
- **Disk Usage:** HuggingFace cache only
- **Speed:** Fastest
- **Best for:** Machines with lots of RAM
### πŸ’Ύ Disk Cache Mode
- **RAM Usage:** ~2-4GB (only one tensor at a time)
- **Disk Usage:** ~2x model size (downloads full safetensors)
- **Speed:** Medium (disk I/O bound)
- **Best for:** Machines with limited RAM but plenty of disk space
### 🐒 Streaming Mode
- **RAM Usage:** ~1-2GB (streams bytes directly)
- **Disk Usage:** Minimal (only metadata cached)
- **Speed:** Slowest (many HTTP requests)
- **Best for:** Very constrained environments, or when disk space is also limited
""")
submit_btn = gr.Button("πŸš€ Analyze Differences", variant="primary")
with gr.Row():
with gr.Column():
gr.Markdown("### 2D Layer-wise Heatmap")
output_2d = gr.Image(label="2D Visualization", type="pil")
with gr.Row():
with gr.Column():
gr.Markdown("### 3D Interactive Landscape")
output_3d = gr.Plot(label="3D Visualization")
submit_btn.click(
fn=process_models,
inputs=[base_input, chat_input, token_input, memory_mode],
outputs=[output_2d, output_3d]
)
if __name__ == "__main__":
demo.launch(share=False, server_port=7860)