import io
import gc
import os
import json
import struct
import torch
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import gradio as gr
import PIL.Image
import requests
from transformers import AutoModelForCausalLM, AutoConfig
from huggingface_hub import hf_hub_download, hf_hub_url, snapshot_download
from huggingface_hub.utils import build_hf_headers
from safetensors import safe_open

# Set style for matplotlib
sns.set_theme(style="whitegrid")

# Cache for metadata only
_metadata_cache = {}


def calculate_weight_diff(base_weight, chat_weight):
    """Calculates the mean absolute difference between two tensors."""
    b_w = base_weight.float()
    c_w = chat_weight.float()
    result = torch.abs(b_w - c_w).mean().item()
    del b_w, c_w
    return result
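
# Quick sanity check (illustrative, not part of the app flow):
#   calculate_weight_diff(torch.ones(2, 2), torch.zeros(2, 2)) returns 1.0,
#   i.e. the mean of the element-wise absolute differences.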

def get_safetensor_index(repo_id, token=None):
    """Download and parse the safetensors index."""
    cache_key = f"{repo_id}_index"
    if cache_key in _metadata_cache:
        return _metadata_cache[cache_key]
    try:
        index_path = hf_hub_download(repo_id, "model.safetensors.index.json", token=token)
        with open(index_path, 'r') as f:
            index_data = json.load(f)
        weight_map = index_data.get("weight_map", {})
        _metadata_cache[cache_key] = weight_map
        return weight_map
    except Exception:
        _metadata_cache[cache_key] = None
        return None

# =============================================================================
# STREAMING MODE (Ultra Low Memory - No disk usage)
# =============================================================================

def get_safetensor_header(repo_id, filename, token=None):
    """Fetch only the header of a safetensor file using HTTP range requests."""
    cache_key = f"{repo_id}_{filename}_header"
    if cache_key in _metadata_cache:
        return _metadata_cache[cache_key]
    url = hf_hub_url(repo_id, filename)
    headers = build_hf_headers(token=token)
    # First, get the header size (first 8 bytes)
    headers["Range"] = "bytes=0-7"
    response = requests.get(url, headers=headers)
    response.raise_for_status()
    header_size = struct.unpack('<Q', response.content)[0]
    # Now get the header JSON
    headers["Range"] = f"bytes=8-{8 + header_size - 1}"
    response = requests.get(url, headers=headers)
    response.raise_for_status()
    header = json.loads(response.content.decode('utf-8'))
    result = {"header": header, "header_size": header_size, "data_offset": 8 + header_size}
    _metadata_cache[cache_key] = result
    return result
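
# The range requests above rely on the safetensors file layout: bytes 0-7 hold a
# little-endian uint64 giving the JSON header length N, the next N bytes hold the
# UTF-8 JSON header mapping tensor names to {"dtype", "shape", "data_offsets"},
# and the raw tensor data follows, with data_offsets expressed relative to the end
# of the header (hence data_offset = 8 + header_size).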

def stream_tensor_from_safetensor(repo_id, filename, tensor_name, token=None):
    """Stream a specific tensor using HTTP range requests - minimal memory usage."""
    header_info = get_safetensor_header(repo_id, filename, token)
    header = header_info["header"]
    data_offset = header_info["data_offset"]
    if tensor_name not in header:
        raise KeyError(f"Tensor {tensor_name} not found in {filename}")
    tensor_info = header[tensor_name]
    dtype_str = tensor_info["dtype"]
    shape = tensor_info["shape"]
    offsets = tensor_info["data_offsets"]
    start_offset = offsets[0] + data_offset
    end_offset = offsets[1] + data_offset - 1
    dtype_map = {
        "F32": torch.float32,
        "F16": torch.float16,
        "BF16": torch.bfloat16,
        "I64": torch.int64,
        "I32": torch.int32,
        "I16": torch.int16,
        "I8": torch.int8,
        "U8": torch.uint8,
        "BOOL": torch.bool,
    }
    torch_dtype = dtype_map.get(dtype_str, torch.float32)
    url = hf_hub_url(repo_id, filename)
    headers = build_hf_headers(token=token)
    headers["Range"] = f"bytes={start_offset}-{end_offset}"
    response = requests.get(url, headers=headers)
    response.raise_for_status()
    tensor = torch.frombuffer(bytearray(response.content), dtype=torch_dtype).reshape(shape).clone()
    del response
    return tensor
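
# Illustrative usage (repo id and tensor name are hypothetical):
#   q0 = stream_tensor_from_safetensor("org/some-model", "model.safetensors",
#                                      "model.layers.0.self_attn.q_proj.weight")
# Only the byte range covering that single tensor is transferred over HTTP.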

def load_tensor_streaming(repo_id, tensor_name, token=None):
    """Load a specific tensor using streaming - minimal memory."""
    weight_map = get_safetensor_index(repo_id, token)
    if weight_map is not None:
        if tensor_name not in weight_map:
            raise KeyError(f"Tensor {tensor_name} not found in weight map for {repo_id}")
        filename = weight_map[tensor_name]
    else:
        filename = "model.safetensors"
    return stream_tensor_from_safetensor(repo_id, filename, tensor_name, token)

def calculate_layer_diffs_streaming(base_repo, chat_repo, token=None, progress=None):
    """Ultra memory-efficient: streams individual tensors via HTTP range requests."""
    global _metadata_cache
    _metadata_cache = {}
    print("Fetching model configuration...")
    base_config = AutoConfig.from_pretrained(base_repo, token=token, trust_remote_code=True)
    num_layers = base_config.num_hidden_layers
    components_to_track = [
        ('input_layernorm', 'model.layers.{}.input_layernorm.weight'),
        ('self_attn_q_proj', 'model.layers.{}.self_attn.q_proj.weight'),
        ('self_attn_k_proj', 'model.layers.{}.self_attn.k_proj.weight'),
        ('self_attn_v_proj', 'model.layers.{}.self_attn.v_proj.weight'),
        ('self_attn_o_proj', 'model.layers.{}.self_attn.o_proj.weight'),
        ('post_attention_layernorm', 'model.layers.{}.post_attention_layernorm.weight'),
        ('mlp_gate_proj', 'model.layers.{}.mlp.gate_proj.weight'),
        ('mlp_up_proj', 'model.layers.{}.mlp.up_proj.weight'),
        ('mlp_down_proj', 'model.layers.{}.mlp.down_proj.weight'),
    ]
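    # Note: these tensor name patterns assume a Llama/Mistral-style layout
    # ("model.layers.N.<module>.weight"); components missing in a given
    # architecture are reported as 0.0 by the fallback below.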
    layer_diffs = []
    total_ops = num_layers * len(components_to_track)
    current_op = 0
    print(f"Processing {num_layers} layers in streaming mode...")
    get_safetensor_index(base_repo, token)
    get_safetensor_index(chat_repo, token)
    for layer_idx in range(num_layers):
        layer_data = {}
        for name, pattern in components_to_track:
            tensor_name = pattern.format(layer_idx)
            try:
                base_tensor = load_tensor_streaming(base_repo, tensor_name, token)
                chat_tensor = load_tensor_streaming(chat_repo, tensor_name, token)
                diff = calculate_weight_diff(base_tensor, chat_tensor)
                layer_data[name] = diff
                del base_tensor
                del chat_tensor
            except Exception as e:
                print(f"Warning: Could not load {tensor_name}: {e}")
                layer_data[name] = 0.0
            current_op += 1
            if progress is not None:
                progress(current_op / total_ops, desc=f"Layer {layer_idx + 1}/{num_layers}: {name}")
        layer_diffs.append(layer_data)
        print(f"Completed layer {layer_idx + 1}/{num_layers}")
        gc.collect()
    _metadata_cache = {}
    gc.collect()
    return layer_diffs

# =============================================================================
# DISK CACHE MODE (Low Memory - Uses disk storage)
# =============================================================================

def download_model_safetensors(repo_id, token=None, progress_callback=None):
    """Download all safetensor files for a model to disk cache."""
    print(f"Downloading safetensor files for {repo_id}...")
    # Download the entire model's safetensors files
    local_dir = snapshot_download(
        repo_id,
        token=token,
        allow_patterns=["*.safetensors", "*.json"],
        ignore_patterns=["*.bin", "*.pt", "*.ckpt"],
    )
    return local_dir
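
# snapshot_download places the files in the Hugging Face cache directory
# (HF_HOME / HF_HUB_CACHE) and returns the path to the local snapshot folder,
# so repeated runs against the same revision reuse the already-downloaded files.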

def get_local_safetensor_files(local_dir):
    """Get list of safetensor files in local directory."""
    safetensor_files = []
    for f in os.listdir(local_dir):
        if f.endswith('.safetensors'):
            safetensor_files.append(os.path.join(local_dir, f))
    return safetensor_files


def get_local_weight_map(local_dir):
    """Get weight map from local index file."""
    index_path = os.path.join(local_dir, "model.safetensors.index.json")
    if os.path.exists(index_path):
        with open(index_path, 'r') as f:
            index_data = json.load(f)
        return index_data.get("weight_map", {})
    return None

def load_tensor_from_disk(local_dir, tensor_name, weight_map=None):
    """Load a specific tensor from disk-cached safetensor files."""
    if weight_map is not None:
        if tensor_name not in weight_map:
            raise KeyError(f"Tensor {tensor_name} not found in weight map")
        filename = weight_map[tensor_name]
        file_path = os.path.join(local_dir, filename)
    else:
        # Single file model
        file_path = os.path.join(local_dir, "model.safetensors")
    with safe_open(file_path, framework="pt", device="cpu") as f:
        if tensor_name not in f.keys():
            raise KeyError(f"Tensor {tensor_name} not found in {file_path}")
        tensor = f.get_tensor(tensor_name).clone()
    return tensor
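
# safe_open lazily memory-maps the file, so get_tensor() materialises only the
# requested tensor in RAM; .clone() copies it out of the mapping before the
# context manager closes the file.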

def calculate_layer_diffs_disk_cache(base_repo, chat_repo, token=None, progress=None):
    """Disk cache mode: downloads full model to disk, loads tensors one at a time."""
    print("=" * 60)
    print("DISK CACHE MODE")
    print("Step 1: Downloading model files to disk cache...")
    print("=" * 60)
    if progress:
        progress(0.05, desc="Downloading base model to disk...")
    base_local_dir = download_model_safetensors(base_repo, token)
    print(f"Base model cached at: {base_local_dir}")
    if progress:
        progress(0.15, desc="Downloading chat model to disk...")
    chat_local_dir = download_model_safetensors(chat_repo, token)
    print(f"Chat model cached at: {chat_local_dir}")
    # Get weight maps
    base_weight_map = get_local_weight_map(base_local_dir)
    chat_weight_map = get_local_weight_map(chat_local_dir)
    # Get config
    print("\nFetching model configuration...")
    base_config = AutoConfig.from_pretrained(base_repo, token=token, trust_remote_code=True)
    num_layers = base_config.num_hidden_layers
    components_to_track = [
        ('input_layernorm', 'model.layers.{}.input_layernorm.weight'),
        ('self_attn_q_proj', 'model.layers.{}.self_attn.q_proj.weight'),
        ('self_attn_k_proj', 'model.layers.{}.self_attn.k_proj.weight'),
        ('self_attn_v_proj', 'model.layers.{}.self_attn.v_proj.weight'),
        ('self_attn_o_proj', 'model.layers.{}.self_attn.o_proj.weight'),
        ('post_attention_layernorm', 'model.layers.{}.post_attention_layernorm.weight'),
        ('mlp_gate_proj', 'model.layers.{}.mlp.gate_proj.weight'),
        ('mlp_up_proj', 'model.layers.{}.mlp.up_proj.weight'),
        ('mlp_down_proj', 'model.layers.{}.mlp.down_proj.weight'),
    ]
    layer_diffs = []
    total_ops = num_layers * len(components_to_track)
    current_op = 0
    print("=" * 60)
    print(f"Step 2: Processing {num_layers} layers from disk cache...")
    print("=" * 60)
    for layer_idx in range(num_layers):
        layer_data = {}
        for name, pattern in components_to_track:
            tensor_name = pattern.format(layer_idx)
            try:
                # Load from disk cache - only this tensor goes into RAM
                base_tensor = load_tensor_from_disk(base_local_dir, tensor_name, base_weight_map)
                chat_tensor = load_tensor_from_disk(chat_local_dir, tensor_name, chat_weight_map)
                diff = calculate_weight_diff(base_tensor, chat_tensor)
                layer_data[name] = diff
                # Free RAM immediately
                del base_tensor
                del chat_tensor
            except Exception as e:
                print(f"Warning: Could not load {tensor_name}: {e}")
                layer_data[name] = 0.0
            current_op += 1
            if progress is not None:
                # Scale progress from 0.2 to 0.9 for the processing phase
                scaled_progress = 0.2 + (current_op / total_ops) * 0.7
                progress(scaled_progress, desc=f"Layer {layer_idx + 1}/{num_layers}: {name}")
        layer_diffs.append(layer_data)
        print(f"Completed layer {layer_idx + 1}/{num_layers}")
        # Garbage collect every few layers
        if layer_idx % 5 == 0:
            gc.collect()
    gc.collect()
    print("\nProcessing complete!")
    return layer_diffs

# =============================================================================
# STANDARD MODE (Full models in memory)
# =============================================================================

def calculate_layer_diffs_standard(base_model, chat_model, progress=None):
    """Standard mode: loads full models into memory."""
    layer_diffs = []
    layers = list(zip(base_model.model.layers, chat_model.model.layers))
    total_layers = len(layers)
    components_to_track = [
        ('input_layernorm', lambda l: l.input_layernorm.weight),
        ('self_attn_q_proj', lambda l: l.self_attn.q_proj.weight),
        ('self_attn_k_proj', lambda l: l.self_attn.k_proj.weight),
        ('self_attn_v_proj', lambda l: l.self_attn.v_proj.weight),
        ('self_attn_o_proj', lambda l: l.self_attn.o_proj.weight),
        ('post_attention_layernorm', lambda l: l.post_attention_layernorm.weight),
        ('mlp_gate_proj', lambda l: l.mlp.gate_proj.weight),
        ('mlp_up_proj', lambda l: l.mlp.up_proj.weight),
        ('mlp_down_proj', lambda l: l.mlp.down_proj.weight),
    ]
    print("Calculating differences (standard mode)...")
    for idx, (base_layer, chat_layer) in enumerate(layers):
        layer_data = {}
        for name, getter in components_to_track:
            try:
                val = calculate_weight_diff(getter(base_layer), getter(chat_layer))
                layer_data[name] = val
            except AttributeError:
                layer_data[name] = 0.0
        layer_diffs.append(layer_data)
        if progress is not None:
            progress((idx + 1) / total_layers, desc=f"Processing layer {idx + 1}/{total_layers}")
    return layer_diffs

# =============================================================================
# VISUALIZATION
# =============================================================================

def visualize_2d_heatmap(layer_diffs, base_model_name, chat_model_name):
    """Generates the static 2D heatmap image."""
    if not layer_diffs:
        return None
    num_layers = len(layer_diffs)
    components = list(layer_diffs[0].keys())
    num_components = len(components)
    height = max(8, num_layers / 6)
    width = max(20, num_components * 2.5)
    if num_components > 6:
        nrows = 2
        ncols = (num_components + 1) // 2
    else:
        nrows = 1
        ncols = num_components
    fig, axs = plt.subplots(nrows, ncols, figsize=(width, height * (1.2 if nrows > 1 else 1)))
    axs = axs.flatten() if num_components > 1 else [axs]
    fig.suptitle(f"Weight Differences: {base_model_name} vs {chat_model_name}", fontsize=16, y=0.98)
    tick_font_size = max(6, min(10, 300 / num_layers))
    for i, component in enumerate(components):
        data = [[row[component]] for row in layer_diffs]
        sns.heatmap(data,
                    annot=True,
                    fmt=".6f",
                    cmap="viridis",
                    ax=axs[i],
                    cbar=False,
                    annot_kws={'size': tick_font_size * 0.8})
        axs[i].set_title(component, fontsize=12, fontweight='bold')
        axs[i].set_yticks(range(num_layers))
        axs[i].set_yticklabels(range(num_layers), fontsize=tick_font_size)
        axs[i].set_xticks([])
        axs[i].invert_yaxis()
    for j in range(i + 1, len(axs)):
        fig.delaxes(axs[j])
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
    buf.seek(0)
    plt.close(fig)
    return PIL.Image.open(buf)

def generate_3d_plot(layer_diffs):
    """Generates an interactive 3D surface plot as a Plotly Figure."""
    if not layer_diffs:
        return None
    df = pd.DataFrame(layer_diffs)
    x_labels = df.columns.tolist()
    y_labels = df.index.tolist()
    z_data = df.values
    fig = go.Figure(data=[go.Surface(z=z_data, x=x_labels, y=y_labels, colorscale='Viridis')])
    fig.update_layout(
        title='3D Landscape of Weight Differences',
        scene=dict(
            xaxis_title='Model Components',
            yaxis_title='Layer Index',
            zaxis_title='Mean Weight Diff',
            xaxis=dict(tickangle=45),
        ),
        autosize=True,
        height=700,
        margin=dict(l=65, r=50, b=65, t=90)
    )
    return fig

# =============================================================================
# MAIN PROCESSING
# =============================================================================

def process_models(base_name, chat_name, hf_token, memory_mode, progress=gr.Progress()):
    if not base_name or not chat_name:
        raise gr.Error("Please provide both model names.")
    token = hf_token if hf_token else None
    try:
        if memory_mode == "streaming":
            # Streaming mode - ultra low memory, no disk
            progress(0, desc="Starting streaming mode (ultra low memory)...")
            diffs = calculate_layer_diffs_streaming(
                base_name,
                chat_name,
                token=token,
                progress=progress
            )
        elif memory_mode == "disk_cache":
            # Disk cache mode - downloads to disk, loads tensors one at a time
            progress(0, desc="Starting disk cache mode...")
            diffs = calculate_layer_diffs_disk_cache(
                base_name,
                chat_name,
                token=token,
                progress=progress
            )
        else:
            # Standard mode - full models in memory
            progress(0, desc=f"Loading {base_name}...")
            print(f"Loading {base_name}...")
            base_model = AutoModelForCausalLM.from_pretrained(
                base_name,
                torch_dtype=torch.bfloat16,
                token=token,
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )
            progress(0.3, desc=f"Loading {chat_name}...")
            print(f"Loading {chat_name}...")
            chat_model = AutoModelForCausalLM.from_pretrained(
                chat_name,
                torch_dtype=torch.bfloat16,
                token=token,
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )
            progress(0.5, desc="Calculating differences...")
            diffs = calculate_layer_diffs_standard(base_model, chat_model, progress=None)
            del base_model
            del chat_model
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        progress(0.9, desc="Generating visualizations...")
        img_2d = visualize_2d_heatmap(diffs, base_name, chat_name)
        plot_3d = generate_3d_plot(diffs)
        progress(1.0, desc="Complete!")
        return img_2d, plot_3d
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise gr.Error(f"Error processing models: {str(e)}")

# =============================================================================
# GRADIO UI
# =============================================================================

with gr.Blocks(title="Model Diff Visualizer") as demo:
    gr.Markdown("# 🧠 LLM Weight Difference Visualizer")
    gr.Markdown("Compare the weights of a Base model vs. its Instruct/Chat tuned version layer by layer.")
    with gr.Row():
        with gr.Column(scale=1):
            base_input = gr.Textbox(
                label="Base Model Name",
                placeholder="e.g., meta-llama/Llama-3.3-70B-Instruct"
            )
            chat_input = gr.Textbox(
                label="Chat/Tuned Model Name",
                placeholder="e.g., CrucibleLab/L3.3-70B-Loki-V2.0"
            )
            token_input = gr.Textbox(
                label="Hugging Face Token (Optional)",
                type="password",
                placeholder="hf_..."
            )
            memory_mode = gr.Radio(
                label="Memory Mode",
                choices=[
                    ("🚀 Standard (Fast, High RAM)", "standard"),
                    ("💾 Disk Cache (Medium Speed, Low RAM, Uses Disk)", "disk_cache"),
                    ("🐢 Streaming (Slow, Ultra Low RAM, No Disk)", "streaming"),
                ],
                value="standard",
                info="Choose based on your available RAM and disk space"
            )
            with gr.Accordion("Memory Mode Details", open=False):
                gr.Markdown("""
### 🚀 Standard Mode
- **RAM Usage:** ~2x model size (e.g., ~280GB for 70B models)
- **Disk Usage:** HuggingFace cache only
- **Speed:** Fastest
- **Best for:** Machines with lots of RAM

### 💾 Disk Cache Mode
- **RAM Usage:** ~2-4GB (only one tensor at a time)
- **Disk Usage:** ~2x model size (downloads full safetensors)
- **Speed:** Medium (disk I/O bound)
- **Best for:** Machines with limited RAM but plenty of disk space

### 🐢 Streaming Mode
- **RAM Usage:** ~1-2GB (streams bytes directly)
- **Disk Usage:** Minimal (only metadata cached)
- **Speed:** Slowest (many HTTP requests)
- **Best for:** Very constrained environments, or when disk space is also limited
""")
            submit_btn = gr.Button("🚀 Analyze Differences", variant="primary")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 2D Layer-wise Heatmap")
            output_2d = gr.Image(label="2D Visualization", type="pil")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 3D Interactive Landscape")
            output_3d = gr.Plot(label="3D Visualization")
    submit_btn.click(
        fn=process_models,
        inputs=[base_input, chat_input, token_input, memory_mode],
        outputs=[output_2d, output_3d]
    )

if __name__ == "__main__":
    demo.launch(share=False, server_port=7860)