#!/usr/bin/env python3
"""
Load and use compressed models saved by compress_model.py
"""
import os
import json

import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer
import tensorly as tl
from tltorch.factorized_layers import FactorizedLinear, FactorizedEmbedding

# Set TensorLy backend to PyTorch
tl.set_backend("pytorch")


def reconstruct_factorized_layer(layer_info, state_dict_prefix):
    """Reconstruct a factorized layer from saved metadata.

    Note: state_dict_prefix is the layer's dotted path in the saved state
    dict; it is currently unused here but kept for API symmetry.
    """
    layer_type = layer_info["type"]

    # Use defaults if factorization/rank not specified
    factorization = layer_info.get("factorization", "cp")  # default to CP factorization
    rank = layer_info.get("rank", 4)  # default rank of 4

    if layer_type == "FactorizedLinear":
        in_features = layer_info.get("in_features")
        out_features = layer_info.get("out_features")
        if in_features is None or out_features is None:
            raise ValueError("Missing in_features or out_features for FactorizedLinear layer")

        # Create a dummy linear layer with the original shape
        linear = nn.Linear(in_features, out_features, bias=layer_info.get("bias", True))

        # Convert to factorized using the from_linear method
        layer = FactorizedLinear.from_linear(
            linear,
            rank=rank,
            factorization=factorization.upper(),  # The method expects uppercase
            implementation='reconstructed'
        )
    elif layer_type == "FactorizedEmbedding":
        num_embeddings = layer_info.get("num_embeddings")
        embedding_dim = layer_info.get("embedding_dim")
        if num_embeddings is None or embedding_dim is None:
            raise ValueError("Missing num_embeddings or embedding_dim for FactorizedEmbedding layer")

        # Create a dummy embedding layer with the original configuration
        embedding = nn.Embedding(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            padding_idx=layer_info.get("padding_idx", None),
            max_norm=layer_info.get("max_norm", None),
            norm_type=layer_info.get("norm_type", 2.0),
            scale_grad_by_freq=layer_info.get("scale_grad_by_freq", False),
            sparse=layer_info.get("sparse", False)
        )

        # Convert to factorized using the from_embedding method
        layer = FactorizedEmbedding.from_embedding(
            embedding,
            rank=rank,
            factorization=factorization
        )
    else:
        raise ValueError(f"Unknown factorized layer type: {layer_type}")

    return layer


def set_module_by_path(model, path, new_module):
    """Set a module in the model by its dotted path."""
    parts = path.split('.')
    parent = model
    # Navigate to the parent module
    for part in parts[:-1]:
        parent = getattr(parent, part)
    # Set the new module
    setattr(parent, parts[-1], new_module)
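
# ---------------------------------------------------------------------------
# Illustrative sketch: the helpers above expect each factorization_info.json
# entry to carry the keys that reconstruct_factorized_layer() reads. The
# metadata dict, dimensions, and layer path below are hypothetical examples,
# not values taken from a real compress_model.py checkpoint.
# ---------------------------------------------------------------------------
def _example_reconstruct_one_layer():
    """Build one factorized layer from hand-written metadata (illustration only)."""
    example_layer_info = {
        "type": "FactorizedLinear",  # selects the FactorizedLinear branch above
        "factorization": "cp",       # optional; defaults to "cp"
        "rank": 4,                   # optional; defaults to 4
        "in_features": 768,          # hypothetical dimensions
        "out_features": 768,
        "bias": True,
    }
    # The second argument is the layer's dotted path in the state dict
    return reconstruct_factorized_layer(
        example_layer_info, "encoder.layer.0.attention.self.query"
    )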
def load_compressed_model(load_dir: str, device="cpu"):
    """Load a compressed model from the saved artifacts."""
    # Load factorization info
    factorization_info_path = os.path.join(load_dir, "factorization_info.json")
    if not os.path.exists(factorization_info_path):
        raise FileNotFoundError(f"No factorization_info.json found in {load_dir}")

    with open(factorization_info_path, "r") as f:
        factorized_info = json.load(f)

    # Load the saved checkpoint
    checkpoint_path = os.path.join(load_dir, "pytorch_model.bin")
    if not os.path.exists(checkpoint_path):
        # Try alternative path
        checkpoint_path = os.path.join(load_dir, "model_state.pt")
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"No model checkpoint found in {load_dir}")

    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Extract info from checkpoint
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
        is_sentence_encoder = checkpoint.get("is_sentence_encoder", False)
        model_name = checkpoint.get("model_name", "unknown")
    else:
        # Assume it's just the state dict
        state_dict = checkpoint
        is_sentence_encoder = False
        model_name = "unknown"

    print(f"Loading compressed model (sentence_encoder={is_sentence_encoder})")

    # Either way, the caller must supply the original model architecture;
    # this function only hands back the pieces needed to rebuild it.
    if is_sentence_encoder:
        print("Note: Loading sentence encoders requires the original model architecture.")
        print("The compressed weights will be loaded, but the model structure needs to be reconstructed manually.")
    else:
        print("Note: Loading compressed models requires knowing the original model class.")

    # Return the loaded components for manual reconstruction
    return {
        "state_dict": state_dict,
        "factorized_info": factorized_info,
        "is_sentence_encoder": is_sentence_encoder,
        "model_name": model_name,
    }


def load_compressed_sentence_transformer(original_model_name: str, compressed_dir: str, device="cpu"):
    """
    Load a compressed SentenceTransformer model.

    Args:
        original_model_name: Name of the original model (e.g., "nomic-ai/CodeRankEmbed")
        compressed_dir: Directory containing the compressed model
        device: Device to load the model on

    Returns:
        Compressed SentenceTransformer model
    """
    # Load the original model structure
    model = SentenceTransformer(original_model_name, device=device, trust_remote_code=True)

    # Load compression artifacts
    artifacts = load_compressed_model(compressed_dir, device)

    if not artifacts.get("is_sentence_encoder"):
        raise ValueError("The compressed model is not a sentence encoder")

    state_dict = artifacts["state_dict"]
    factorized_info = artifacts["factorized_info"]

    # Reconstruct factorized layers and swap them into the model
    for layer_path, layer_info in factorized_info.items():
        factorized_layer = reconstruct_factorized_layer(layer_info, layer_path)
        set_module_by_path(model, layer_path, factorized_layer)

    # Load the state dict; strict=False tolerates the renamed factor parameters
    model.load_state_dict(state_dict, strict=False)

    return model


def example_usage():
    """Example of how to use the compressed model loader."""
    compressed_dir = "coderank_compressed"
    original_model = "nomic-ai/CodeRankEmbed"

    print(f"Loading compressed model from {compressed_dir}")

    try:
        # For sentence transformers
        model = load_compressed_sentence_transformer(
            original_model_name=original_model,
            compressed_dir=compressed_dir,
            device="cpu"
        )

        # Test the model
        sentences = [
            "def hello_world():\n    print('Hello, World!')",
            "System.out.println(\"Hello, World!\");",
        ]
        embeddings = model.encode(sentences)
        print("✔ Successfully loaded compressed model")
        print(f"  Embedding shape: {embeddings.shape}")
    except Exception as e:
        print(f"⚠ Error loading compressed model: {e}")
        print("\nTo manually load the compressed model:")
        print("1. Load the factorization_info.json to see the compressed layer structure")
        print("2. Reconstruct the model with factorized layers based on the metadata")
        print("3. Load the state dict from pytorch_model.bin")


if __name__ == "__main__":
    example_usage()
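

# ---------------------------------------------------------------------------
# Manual-reconstruction sketch: example_usage() prints three manual steps for
# when automatic loading fails. The function below shows what those steps
# might look like in code, under the assumption that you can instantiate the
# original architecture yourself and pass it in as `model`. It is an
# illustration, not a tested code path.
# ---------------------------------------------------------------------------
def manual_load_sketch(model, load_dir: str, device="cpu"):
    """Apply a compressed checkpoint to an already-constructed `model`."""
    # Step 1: load factorization_info.json to see the compressed layer structure
    with open(os.path.join(load_dir, "factorization_info.json"), "r") as f:
        factorized_info = json.load(f)

    # Step 2: reconstruct the model's factorized layers from the metadata
    for layer_path, layer_info in factorized_info.items():
        set_module_by_path(
            model, layer_path, reconstruct_factorized_layer(layer_info, layer_path)
        )

    # Step 3: load the state dict from pytorch_model.bin
    checkpoint = torch.load(
        os.path.join(load_dir, "pytorch_model.bin"), map_location=device
    )
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict, strict=False)
    return model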