#!/usr/bin/env python3
"""
Load and use compressed models saved by compress_model.py
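
The loader expects these artifacts in the save directory:
    factorization_info.json               - metadata for each factorized layer
    pytorch_model.bin (or model_state.pt) - checkpoint with the compressed weights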
"""

import os
import json
import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer
import tensorly as tl
from tltorch.factorized_layers import FactorizedLinear, FactorizedEmbedding

# Set TensorLy backend to PyTorch
tl.set_backend("pytorch")


def reconstruct_factorized_layer(layer_info):
    """Reconstruct a factorized layer from one entry of the saved metadata."""
    layer_type = layer_info["type"]
    
    # Use defaults if factorization/rank not specified
    factorization = layer_info.get("factorization", "cp")  # default to CP factorization
    rank = layer_info.get("rank", 4)  # default rank of 4
    
    if layer_type == "FactorizedLinear":
        # Create a regular linear layer first
        in_features = layer_info.get("in_features")
        out_features = layer_info.get("out_features")
        
        if in_features is None or out_features is None:
            raise ValueError("Missing in_features or out_features for FactorizedLinear layer")
        
        # Build a plain linear layer with the original shape; its weights are
        # placeholders that the compressed checkpoint will overwrite later
        linear = nn.Linear(in_features, out_features, bias=layer_info.get("bias", True))
        
        # Convert to a factorized layer with tltorch's from_linear helper
        # (factorization names are matched case-insensitively by tltorch)
        layer = FactorizedLinear.from_linear(
            linear,
            rank=rank,
            factorization=factorization,
            implementation='reconstructed'
        )
    
    elif layer_type == "FactorizedEmbedding":
        # Create a regular embedding layer first
        num_embeddings = layer_info.get("num_embeddings")
        embedding_dim = layer_info.get("embedding_dim")
        
        if num_embeddings is None or embedding_dim is None:
            raise ValueError("Missing num_embeddings or embedding_dim for FactorizedEmbedding layer")
        
        # Build a plain embedding layer with the original shape and options
        embedding = nn.Embedding(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            padding_idx=layer_info.get("padding_idx", None),
            max_norm=layer_info.get("max_norm", None),
            norm_type=layer_info.get("norm_type", 2.0),
            scale_grad_by_freq=layer_info.get("scale_grad_by_freq", False),
            sparse=layer_info.get("sparse", False)
        )
        
        # Convert to factorized using the from_embedding method
        layer = FactorizedEmbedding.from_embedding(
            embedding,
            rank=rank,
            factorization=factorization
        )
    
    else:
        raise ValueError(f"Unknown factorized layer type: {layer_type}")
    
    return layer


def set_module_by_path(model, path, new_module):
    """Set a module in the model by its dotted path."""
    parts = path.split('.')
    parent = model
    
    # Navigate to the parent module
    for part in parts[:-1]:
        parent = getattr(parent, part)
    
    # Set the new module
    setattr(parent, parts[-1], new_module)


def load_compressed_model(load_dir: str, device="cpu"):
    """Load a compressed model from the saved artifacts."""
    
    # Load factorization info
    factorization_info_path = os.path.join(load_dir, "factorization_info.json")
    if not os.path.exists(factorization_info_path):
        raise FileNotFoundError(f"No factorization_info.json found in {load_dir}")
    
    with open(factorization_info_path, "r") as f:
        factorized_info = json.load(f)
    
    # Load the saved checkpoint
    checkpoint_path = os.path.join(load_dir, "pytorch_model.bin")
    if not os.path.exists(checkpoint_path):
        # Try alternative path
        checkpoint_path = os.path.join(load_dir, "model_state.pt")
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"No model checkpoint found in {load_dir}")
    
    checkpoint = torch.load(checkpoint_path, map_location=device)
    
    # Extract info from checkpoint
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
        is_sentence_encoder = checkpoint.get("is_sentence_encoder", False)
        model_name = checkpoint.get("model_name", "unknown")
    else:
        # Assume it's just the state dict
        state_dict = checkpoint
        is_sentence_encoder = False
        model_name = "unknown"
    
    print(f"Loading compressed model (sentence_encoder={is_sentence_encoder})")
    
    # Sentence encoders cannot be rebuilt from the state dict alone, so return
    # the raw artifacts; load_compressed_sentence_transformer() below performs
    # the full reconstruction against the original architecture
    if is_sentence_encoder:
        print("Note: Loading sentence encoders requires the original model architecture.")
        print("Use load_compressed_sentence_transformer() to rebuild the full model.")
        
        # Return the loaded components for manual reconstruction
        return {
            "state_dict": state_dict,
            "factorized_info": factorized_info,
            "is_sentence_encoder": True,
            "model_name": model_name,
        }
    
    else:
        # Standard transformers models likewise need the original model class;
        # return the artifacts so the caller can rebuild it (see the sketch
        # after this function)
        print("Note: Loading compressed models requires knowing the original model architecture.")
        
        return {
            "state_dict": state_dict,
            "factorized_info": factorized_info,
            "is_sentence_encoder": False,
            "model_name": model_name,
        }
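

# A minimal sketch of manual reconstruction for a standard transformers model,
# assuming the original architecture is known (AutoModel and the directory
# name below are illustrative, not part of the saved artifacts):
#
#   from transformers import AutoModel
#
#   artifacts = load_compressed_model("my_compressed_dir")
#   model = AutoModel.from_pretrained(artifacts["model_name"])
#   for path, info in artifacts["factorized_info"].items():
#       set_module_by_path(model, path, reconstruct_factorized_layer(info))
#   model.load_state_dict(artifacts["state_dict"], strict=False)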


def load_compressed_sentence_transformer(original_model_name: str, compressed_dir: str, device="cpu"):
    """
    Load a compressed SentenceTransformer model.
    
    Args:
        original_model_name: Name of the original model (e.g., "nomic-ai/CodeRankEmbed")
        compressed_dir: Directory containing the compressed model
        device: Device to load the model on
    
    Returns:
        Compressed SentenceTransformer model
    """
    # Load the original model structure
    model = SentenceTransformer(original_model_name, device=device, trust_remote_code=True)
    
    # Load compression artifacts
    artifacts = load_compressed_model(compressed_dir, device)
    
    if not artifacts.get("is_sentence_encoder"):
        raise ValueError("The compressed model is not a sentence encoder")
    
    # Load the compressed state dict
    state_dict = artifacts["state_dict"]
    factorized_info = artifacts["factorized_info"]
    
    # Reconstruct factorized layers
    for layer_path, layer_info in factorized_info.items():
        # Create the factorized layer
        factorized_layer = reconstruct_factorized_layer(layer_info)
        
        # Set it in the model
        set_module_by_path(model, layer_path, factorized_layer)
    
    # Load the compressed weights; strict=False tolerates keys that do not
    # line up exactly with the rebuilt module tree
    model.load_state_dict(state_dict, strict=False)
    
    return model


def example_usage():
    """Example of how to use the compressed model loader."""
    
    compressed_dir = "coderank_compressed"
    original_model = "nomic-ai/CodeRankEmbed"
    
    print(f"Loading compressed model from {compressed_dir}")
    
    try:
        # For sentence transformers
        model = load_compressed_sentence_transformer(
            original_model_name=original_model,
            compressed_dir=compressed_dir,
            device="cpu"
        )
        
        # Test the model
        sentences = ["def hello_world():\n    print('Hello, World!')", "System.out.println('Hello, World!');"]
        embeddings = model.encode(sentences)
        
        print(f"✔ Successfully loaded compressed model")
        print(f"  Embedding shape: {embeddings.shape}")
        
    except Exception as e:
        print(f"⚠ Error loading compressed model: {e}")
        print("\nTo manually load the compressed model:")
        print("1. Load the factorization_info.json to see the compressed layer structure")
        print("2. Reconstruct the model with factorized layers based on the metadata")
        print("3. Load the state dict from pytorch_model.bin")


if __name__ == "__main__":
    example_usage()