|
|
|
""" |
|
Load and use compressed models saved by compress_model.py |
|
""" |
|
|
|
import os |
|
import json |
|
import torch |
|
from transformers import AutoTokenizer |
|
from sentence_transformers import SentenceTransformer |
|
import tensorly as tl |
|
from tltorch.factorized_layers import FactorizedLinear, FactorizedEmbedding |
|
|
|
|
|
# Module-level side effect (runs on import): switch TensorLy to its PyTorch
# backend so tensor operations act on torch tensors, matching the torch
# modules (nn.Linear / nn.Embedding) the factorized layers are built from.
tl.set_backend("pytorch")
|
|
|
|
|
def reconstruct_factorized_layer(layer_info, state_dict_prefix):
    """Rebuild an empty factorized layer from its saved metadata.

    A temporary dense ``torch.nn`` layer with the recorded dimensions is
    created and converted with tltorch's ``from_linear`` / ``from_embedding``,
    so the result has the same parameter structure as the layer that was
    saved. Its values are placeholders; the caller is expected to load the
    real weights from the checkpoint afterwards.

    Args:
        layer_info: Metadata for one layer. Reads ``"type"`` (required,
            ``"FactorizedLinear"`` or ``"FactorizedEmbedding"``),
            ``"factorization"`` (default ``"cp"``) and ``"rank"`` (default 4).
            Linear layers need ``"in_features"`` / ``"out_features"`` and may
            set ``"bias"`` (default True). Embeddings need ``"num_embeddings"``
            / ``"embedding_dim"`` and may set the optional ``nn.Embedding``
            keyword options read below.
        state_dict_prefix: Dotted path of the layer inside the model. It is
            only used to identify the offending layer in error messages.

    Returns:
        A ``FactorizedLinear`` or ``FactorizedEmbedding`` instance.

    Raises:
        KeyError: If ``layer_info`` has no ``"type"`` entry.
        ValueError: If required dimensions are missing or the type is unknown.
    """
    layer_type = layer_info["type"]
    factorization = layer_info.get("factorization", "cp")
    rank = layer_info.get("rank", 4)

    if layer_type == "FactorizedLinear":
        in_features = layer_info.get("in_features")
        out_features = layer_info.get("out_features")
        if in_features is None or out_features is None:
            raise ValueError(
                "Missing in_features or out_features for FactorizedLinear "
                f"layer {state_dict_prefix!r}"
            )

        use_bias = layer_info.get("bias", True)
        # Temporary dense layer: only its shape matters. Its random weights are
        # decomposed and then overwritten when the checkpoint is loaded.
        # NOTE(review): decomposing random weights is wasted work for large
        # layers; consider constructing FactorizedLinear directly if profiling
        # shows this dominates load time.
        linear = torch.nn.Linear(in_features, out_features, bias=use_bias)

        # from_linear creates a bias by default. Forward the recorded flag so a
        # bias-free layer does not get an extra bias that the checkpoint never
        # fills (strict=False loading would silently keep its initial values).
        return FactorizedLinear.from_linear(
            linear,
            rank=rank,
            bias=use_bias,
            factorization=factorization.upper(),
            implementation="reconstructed",
        )

    if layer_type == "FactorizedEmbedding":
        num_embeddings = layer_info.get("num_embeddings")
        embedding_dim = layer_info.get("embedding_dim")
        if num_embeddings is None or embedding_dim is None:
            raise ValueError(
                "Missing num_embeddings or embedding_dim for FactorizedEmbedding "
                f"layer {state_dict_prefix!r}"
            )

        # Temporary dense embedding carrying the recorded options. As above,
        # its weights are placeholders that the checkpoint overwrites.
        embedding = torch.nn.Embedding(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            padding_idx=layer_info.get("padding_idx", None),
            max_norm=layer_info.get("max_norm", None),
            norm_type=layer_info.get("norm_type", 2.0),
            scale_grad_by_freq=layer_info.get("scale_grad_by_freq", False),
            sparse=layer_info.get("sparse", False),
        )
        return FactorizedEmbedding.from_embedding(
            embedding,
            rank=rank,
            factorization=factorization,
        )

    raise ValueError(f"Unknown factorized layer type: {layer_type}")
|
|
|
|
|
def set_module_by_path(model, path, new_module):
    """Replace the attribute at dotted ``path`` inside ``model`` with ``new_module``.

    Every component except the last is resolved with ``getattr`` starting from
    ``model``. The last component is then assigned on that parent with
    ``setattr``. A missing intermediate name raises ``AttributeError``.
    """
    *ancestor_names, leaf_name = path.split('.')

    container = model
    for name in ancestor_names:
        container = getattr(container, name)

    setattr(container, leaf_name, new_module)
|
|
|
|
|
def load_compressed_model(load_dir: str, device="cpu"):
    """Load the raw artifacts of a compressed model from ``load_dir``.

    Reads ``factorization_info.json`` and the checkpoint. The checkpoint is
    ``pytorch_model.bin`` if present, otherwise ``model_state.pt``. It may be
    a dict with a ``"state_dict"`` entry plus optional ``"is_sentence_encoder"``
    / ``"model_name"`` metadata, or a bare state dict.

    Args:
        load_dir: Directory written by the compression step.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        A dict with keys ``"state_dict"``, ``"factorized_info"``,
        ``"is_sentence_encoder"`` (bool) and ``"model_name"`` (``"unknown"``
        when the checkpoint carries no metadata).

    Raises:
        FileNotFoundError: If the JSON metadata or both checkpoint files are
            missing.
    """
    factorization_info_path = os.path.join(load_dir, "factorization_info.json")
    if not os.path.exists(factorization_info_path):
        raise FileNotFoundError(f"No factorization_info.json found in {load_dir}")

    # JSON is UTF-8 by specification; don't depend on the platform's locale.
    with open(factorization_info_path, "r", encoding="utf-8") as f:
        factorized_info = json.load(f)

    # Prefer pytorch_model.bin and fall back to model_state.pt.
    checkpoint_path = None
    for filename in ("pytorch_model.bin", "model_state.pt"):
        candidate = os.path.join(load_dir, filename)
        if os.path.exists(candidate):
            checkpoint_path = candidate
            break
    if checkpoint_path is None:
        raise FileNotFoundError(f"No model checkpoint found in {load_dir}")

    # NOTE(security): torch.load unpickles the file (unless the installed torch
    # defaults to weights_only=True), so only load checkpoints from trusted
    # directories.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
        is_sentence_encoder = checkpoint.get("is_sentence_encoder", False)
        model_name = checkpoint.get("model_name", "unknown")
    else:
        # Bare state dict: no metadata was saved alongside the weights.
        state_dict = checkpoint
        is_sentence_encoder = False
        model_name = "unknown"

    print(f"Loading compressed model (sentence_encoder={is_sentence_encoder})")

    # Rebuilding the module tree is the caller's job; these notes say so.
    if is_sentence_encoder:
        print("Note: Loading sentence encoders requires the original model architecture.")
        print("The compressed weights will be loaded, but the model structure needs to be reconstructed manually.")
    else:
        print("Note: Loading compressed models requires knowing the original model architecture.")

    return {
        "state_dict": state_dict,
        "factorized_info": factorized_info,
        "is_sentence_encoder": bool(is_sentence_encoder),
        "model_name": model_name,
    }
|
|
|
|
|
def load_compressed_sentence_transformer(original_model_name: str, compressed_dir: str, device="cpu"):
    """
    Load a compressed SentenceTransformer model.

    The original model is instantiated. Each layer listed in the compressed
    metadata is replaced by a reconstructed factorized layer, and the
    compressed weights are then loaded on top.

    Args:
        original_model_name: Name of the original model (e.g., "nomic-ai/CodeRankEmbed")
        compressed_dir: Directory containing the compressed model
        device: Device to load the model on

    Returns:
        Compressed SentenceTransformer model

    Raises:
        FileNotFoundError: If the compressed artifacts are missing.
        ValueError: If the compressed checkpoint is not a sentence encoder, or
            a layer's metadata is invalid.
    """
    # Validate the compressed artifacts first so a bad directory fails fast,
    # before the (possibly large) original model is downloaded and built.
    artifacts = load_compressed_model(compressed_dir, device)
    if not artifacts.get("is_sentence_encoder"):
        raise ValueError("The compressed model is not a sentence encoder")

    state_dict = artifacts["state_dict"]
    factorized_info = artifacts["factorized_info"]

    model = SentenceTransformer(original_model_name, device=device, trust_remote_code=True)

    # Record where the original weights live before swapping layers. The
    # reconstructed layers are built from freshly constructed torch.nn layers
    # and do not follow the model's device on their own; load_state_dict only
    # copies values into existing parameters and never moves them.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    for layer_path, layer_info in factorized_info.items():
        factorized_layer = reconstruct_factorized_layer(layer_info, layer_path)
        set_module_by_path(model, layer_path, factorized_layer)

    if target_device is not None:
        model.to(target_device)

    # strict=False tolerates keys that don't line up with the module tree.
    # Silently leaving layers at their placeholder weights would yield
    # meaningless embeddings, though, so report anything that was skipped.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys:
        print(
            f"Warning: {len(load_result.missing_keys)} model parameter(s) were not found in "
            f"the compressed checkpoint and keep their initial values: "
            f"{load_result.missing_keys[:10]}"
        )
    if load_result.unexpected_keys:
        print(
            f"Warning: {len(load_result.unexpected_keys)} checkpoint key(s) did not match any "
            f"model parameter and were ignored: {load_result.unexpected_keys[:10]}"
        )

    return model
|
|
|
|
|
def example_usage():
    """Demo: load the compressed CodeRankEmbed model and embed two code snippets.

    Failures are not raised. The error and a short list of manual-loading
    steps are printed to stdout instead.
    """
    compressed_dir, original_model = "coderank_compressed", "nomic-ai/CodeRankEmbed"
    print(f"Loading compressed model from {compressed_dir}")

    try:
        model = load_compressed_sentence_transformer(
            original_model_name=original_model,
            compressed_dir=compressed_dir,
            device="cpu",
        )
        code_snippets = [
            "def hello_world():\n print('Hello, World!')",
            "System.out.println('Hello, World!');",
        ]
        vectors = model.encode(code_snippets)
        print("✔ Successfully loaded compressed model")
        print(f" Embedding shape: {vectors.shape}")
    except Exception as exc:
        # Demo entry point: report the failure and point at the manual route.
        fallback_lines = (
            f"⚠ Error loading compressed model: {exc}",
            "\nTo manually load the compressed model:",
            "1. Load the factorization_info.json to see the compressed layer structure",
            "2. Reconstruct the model with factorized layers based on the metadata",
            "3. Load the state dict from pytorch_model.bin",
        )
        for line in fallback_lines:
            print(line)
|
|
|
|
|
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    example_usage()