Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| from huggingface_hub import model_info | |
| import argparse | |
| from copy import deepcopy | |
| import inspect | |
| from logging import warn | |
| from pathlib import Path | |
| from tqdm import tqdm | |
| import json | |
| from tuned_lens.model_surgery import get_final_norm, get_transformer_layers | |
| from tuned_lens.load_artifacts import load_lens_artifacts | |
| from tuned_lens.nn import TunedLens | |
| from transformers.models.bloom.modeling_bloom import BloomBlock | |
| from transformers import PreTrainedModel, AutoModelForCausalLM | |
| from typing import Optional, Generator, Union | |
| import torch as th | |
| from tuned_lens.stats.distance import js_divergence | |
| def instantiate_layer(model_config, layer_idx: int, model_type: str) -> th.nn.Module: | |
| if model_type == "bloom": | |
| from transformers.models.bloom.modeling_bloom import BloomBlock | |
| return _BloomBlockWrapper(BloomBlock(model_config)) # type: ignore[arg-type] | |
| if model_type == "gpt_neo": | |
| from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoBlock | |
| return GPTNeoBlock(model_config, layer_idx) | |
| if model_type == "gpt_neox": | |
| from transformers.models.gpt_neox.modeling_gpt_neox import ( | |
| GPTNeoXLayer, | |
| ) | |
| return GPTNeoXLayer(model_config) # type: ignore[arg-type] | |
| if model_type == "gpt2": | |
| from transformers.models.gpt2.modeling_gpt2 import GPT2Block | |
| return GPT2Block(model_config, layer_idx) # type: ignore[arg-type] | |
| if model_type == "opt": | |
| from transformers.models.opt.modeling_opt import OPTDecoderLayer | |
| return OPTDecoderLayer(model_config) # type: ignore[arg-type] | |
| else: | |
| raise ValueError(f"Unknown model type '{model_type}'") | |
| def maybe_wrap(layer: th.nn.Module) -> th.nn.Module: | |
| return _BloomBlockWrapper(layer) if isinstance(layer, BloomBlock) else layer | |
| # Very annoying that we have to do this. See https://bit.ly/3XSQ7W6 for context on | |
| # what we're doing here. | |
| class _BloomBlockWrapper(th.nn.Module): | |
| def __init__(self, block: BloomBlock): | |
| super().__init__() | |
| self.block = block | |
| def forward(self, x: th.Tensor) -> th.Tensor: | |
| from transformers.models.bloom.modeling_bloom import ( | |
| BloomModel, | |
| build_alibi_tensor, | |
| ) | |
| batch_size, seq_len, _ = x.shape | |
| dummy_mask = x.new_ones([batch_size, seq_len]) | |
| # Causal mask isn't created inside the block itself, so we have to do it here. | |
| # Weirdly _prepare_attn_mask doesn't depend on `self` at all but is still an | |
| # instance method for some reason, so we pass `None` as the first argument. | |
| causal_mask = BloomModel._prepare_attn_mask( | |
| None, dummy_mask, (batch_size, seq_len), 0 # type: ignore[arg-type] | |
| ) | |
| alibi = build_alibi_tensor(dummy_mask, self.block.num_heads, x.dtype) | |
| h, *_ = self.block(x, alibi, causal_mask) | |
| return h | |
| class TunedLensOld(th.nn.Module): | |
| """A tuned lens for decoding hidden states into logits.""" | |
| layer_norm: th.nn.LayerNorm | |
| unembedding: th.nn.Linear | |
| extra_layers: th.nn.Sequential | |
| layer_translators: th.nn.ModuleList | |
| def __init__( | |
| self, | |
| model: Optional[PreTrainedModel] = None, | |
| *, | |
| bias: bool = True, | |
| extra_layers: int = 0, | |
| include_input: bool = True, | |
| reuse_unembedding: bool = True, | |
| # Used when saving and loading the lens | |
| model_config: Optional[dict] = None, | |
| d_model: Optional[int] = None, | |
| num_layers: Optional[int] = None, | |
| vocab_size: Optional[int] = None, | |
| ): | |
| """Create a TunedLensOld. | |
| Args: | |
| model : A pertained model from the transformers library you wish to inspect. | |
| bias : Whether to include a bias term in the translator layers. | |
| extra_layers : The number of extra layers to apply to the hidden states | |
| before decoding into logits. | |
| include_input : Whether to include a lens that decodes the word embeddings. | |
| reuse_unembedding : Weather to reuse the unembedding matrix from the model. | |
| model_config : The config of the model. Used for saving and loading. | |
| d_model : The models hidden size. Used for saving and loading. | |
| num_layers : The number of layers in the model. Used for saving and loading. | |
| vocab_size : The size of the vocabulary. Used for saving and loading. | |
| Raises: | |
| ValueError: if neither a model or d_model, num_layers, and vocab_size, | |
| are provided. | |
| """ | |
| super().__init__() | |
| self.extra_layers = th.nn.Sequential() | |
| if ( | |
| model | |
| is None | |
| == (d_model is None or num_layers is None or vocab_size is None) | |
| ): | |
| raise ValueError( | |
| "Must provide either a model or d_model, num_layers, and vocab_size" | |
| ) | |
| # Initializing from scratch without a model | |
| if not model: | |
| assert d_model and num_layers and vocab_size | |
| self.layer_norm = th.nn.LayerNorm(d_model) | |
| self.unembedding = th.nn.Linear(d_model, vocab_size, bias=False) | |
| # Use HuggingFace methods to get decoder layers | |
| else: | |
| assert not (d_model or num_layers or vocab_size) | |
| d_model = model.config.hidden_size | |
| num_layers = model.config.num_hidden_layers | |
| vocab_size = model.config.vocab_size | |
| assert isinstance(d_model, int) and isinstance(vocab_size, int) | |
| model_config = model.config.to_dict() # type: ignore[F841] | |
| # Currently we convert the decoder to full precision | |
| self.unembedding = deepcopy(model.get_output_embeddings()).float() | |
| if ln := get_final_norm(model): | |
| self.layer_norm = deepcopy(ln).float() | |
| else: | |
| self.layer_norm = th.nn.Identity() | |
| if extra_layers: | |
| _, layers = get_transformer_layers(model) | |
| self.extra_layers.extend( | |
| [maybe_wrap(layer) for layer in layers[-extra_layers:]] | |
| ) | |
| # Save config for later | |
| config_keys = set(inspect.getfullargspec(TunedLensOld).kwonlyargs) | |
| self.config = {k: v for k, v in locals().items() if k in config_keys} | |
| del model_config | |
| # Try to prevent finetuning the decoder | |
| assert d_model and num_layers | |
| self.layer_norm.requires_grad_(False) | |
| self.unembedding.requires_grad_(False) | |
| out_features = d_model if reuse_unembedding else vocab_size | |
| translator = th.nn.Linear(d_model, out_features, bias=bias) | |
| if not reuse_unembedding: | |
| translator.weight.data = self.unembedding.weight.data.clone() | |
| translator.bias.data.zero_() | |
| else: | |
| translator.weight.data.zero_() | |
| translator.bias.data.zero_() | |
| self.add_module("input_translator", translator if include_input else None) | |
| # Don't include the final layer | |
| num_layers -= 1 | |
| self.layer_translators = th.nn.ModuleList( | |
| [deepcopy(translator) for _ in range(num_layers)] | |
| ) | |
| def __getitem__(self, item: int) -> th.nn.Module: | |
| """Get the probe module at the given index.""" | |
| if isinstance(self.input_translator, th.nn.Module): | |
| if item == 0: | |
| return self.input_translator | |
| else: | |
| item -= 1 | |
| return self.layer_translators[item] | |
| def __iter__(self) -> Generator[th.nn.Module, None, None]: | |
| """Get iterator over the translators within the lens.""" | |
| if isinstance(self.input_translator, th.nn.Module): | |
| yield self.input_translator | |
| yield from self.layer_translators | |
| def load(cls, resource_id: str, **kwargs) -> "TunedLensOld": | |
| """Load a tuned lens from a or hugging face hub. | |
| Args: | |
| resource_id : The path to the directory containing the config and checkpoint | |
| or the name of the model on the hugging face hub. | |
| **kwargs : Additional arguments to pass to torch.load. | |
| Returns: | |
| A TunedLensOld instance. | |
| """ | |
| config_path, ckpt_path = load_lens_artifacts(resource_id) | |
| # Load config | |
| with open(config_path, "r") as f: | |
| config = json.load(f) | |
| # Load parameters | |
| state = th.load(ckpt_path, **kwargs) | |
| # Backwards compatibility we really need to stop renaming things | |
| keys = list(state.keys()) | |
| for key in keys: | |
| for old_key in ["probe", "adapter"]: | |
| if old_key in key: | |
| warn( | |
| f"Loading a checkpoint with a '{old_key}' key. " | |
| "This is deprecated and may be removed in a future version. " | |
| ) | |
| new_key = key.replace(old_key, "translator") | |
| state[new_key] = state.pop(key) | |
| # Drop unrecognized config keys | |
| unrecognized = set(config) - set(inspect.getfullargspec(cls).kwonlyargs) | |
| for key in unrecognized: | |
| warn(f"Ignoring config key '{key}'") | |
| del config[key] | |
| lens = cls(**config) | |
| if num_extras := config.get("extra_layers"): | |
| # This is sort of a hack but AutoConfig doesn't appear to have a from_dict | |
| # for some reason. | |
| from transformers.models.auto import CONFIG_MAPPING | |
| model_conf_dict = config.get("model_config") | |
| del model_conf_dict["torch_dtype"] | |
| assert model_conf_dict, "Need a 'model_config' entry to load extra layers" | |
| model_type = model_conf_dict["model_type"] | |
| config_cls = CONFIG_MAPPING[model_type] | |
| model_config = config_cls.from_dict(model_conf_dict) | |
| lens.extra_layers = th.nn.Sequential( | |
| *[ | |
| instantiate_layer( | |
| model_config, model_config.num_hidden_layers - i - 1, model_type | |
| ) | |
| for i in range(num_extras) | |
| ] | |
| ) | |
| lens.load_state_dict(state) | |
| return lens | |
| def save( | |
| self, | |
| path: Union[Path, str], | |
| ckpt: str = "params.pt", | |
| config: str = "config.json", | |
| ) -> None: | |
| """Save the lens to a directory. | |
| Args: | |
| path : The path to the directory to save the lens to. | |
| ckpt : The name of the checkpoint file to save the parameters to. | |
| config : The name of the config file to save the config to. | |
| """ | |
| path = Path(path) | |
| path.mkdir(exist_ok=True, parents=True) | |
| th.save(self.state_dict(), path / ckpt) | |
| with open(path / config, "w") as f: | |
| json.dump(self.config, f) | |
| def normalize_(self): | |
| """Canonicalize the transforms by centering their weights and biases.""" | |
| for linear in self: | |
| assert isinstance(linear, th.nn.Linear) | |
| A, b = linear.weight.data, linear.bias.data | |
| A -= A.mean(dim=0, keepdim=True) | |
| b -= b.mean() | |
| def transform_hidden(self, h: th.Tensor, idx: int) -> th.Tensor: | |
| """Transform hidden state from layer `idx`.""" | |
| if not self.config["reuse_unembedding"]: | |
| raise RuntimeError("TunedLensOld.transform_hidden requires reuse_unembedding") | |
| # Note that we add the translator output residually, in contrast to the formula | |
| # in the paper. By parametrizing it this way we ensure that weight decay | |
| # regularizes the transform toward the identity, not the zero transformation. | |
| return h + self[idx](h) | |
| def to_logits(self, h: th.Tensor) -> th.Tensor: | |
| """Decode a hidden state into logits.""" | |
| h = self.extra_layers(h) | |
| while isinstance(h, tuple): | |
| h, *_ = h | |
| return self.unembedding(self.layer_norm(h)) | |
| def forward(self, h: th.Tensor, idx: int) -> th.Tensor: | |
| """Transform and then decode the hidden states into logits.""" | |
| # Sanity check to make sure we don't finetune the decoder | |
| # if any(p.requires_grad for p in self.parameters(recurse=False)): | |
| # raise RuntimeError("Make sure to freeze the decoder") | |
| # We're learning a separate unembedding for each layer | |
| if not self.config["reuse_unembedding"]: | |
| h_ = self.layer_norm(h) | |
| return self[idx](h_) | |
| h = self.transform_hidden(h, idx) | |
| return self.to_logits(h) | |
| def __len__(self) -> int: | |
| """Return the number of layer translators in the lens.""" | |
| N = len(self.layer_translators) | |
| if self.input_translator: | |
| N += 1 | |
| return N | |
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--model", type=str, default="gpt2") | |
| parser.add_argument("--resource-id", type=str, default="gpt2") | |
| parser.add_argument("--output-dir", type=str, default="lens/gpt2") | |
| args = parser.parse_args() | |
| model = AutoModelForCausalLM.from_pretrained(args.model) | |
| revision = model_info(args.model).sha | |
| model.eval() | |
| model.requires_grad_(False) | |
| device = th.device("cuda:0" if th.cuda.is_available() else "cpu") | |
| print("Loading old lens") | |
| tuned_lens_old = TunedLensOld.load(args.resource_id, map_location=device) | |
| print("Initializing new lens") | |
| tuned_lens = TunedLens.from_model( | |
| model, bias=tuned_lens_old.config['bias'], revision=revision | |
| ) | |
| for i in tqdm(range(len(tuned_lens_old)), desc="Copying parameters"): | |
| tuned_lens[i].load_state_dict(tuned_lens_old[i].state_dict()) | |
| tuned_lens = tuned_lens.to(device) | |
| tuned_lens_old = tuned_lens_old.to(device) | |
| model = model.to(device) | |
| # Fuzz the new lens against the old one's | |
| with th.no_grad(): | |
| for i in tqdm(range(len(tuned_lens)), desc="Fuzzing layers"): | |
| for _ in range(10): | |
| a = th.randn(1, 1, tuned_lens.config.d_model, device=device) | |
| logits_new = tuned_lens(a, i) | |
| logits_old = tuned_lens_old(a, i) | |
| log_ps_new = logits_new.log_softmax(-1) | |
| log_ps_old = logits_old.log_softmax(-1) | |
| print("js div", js_divergence(log_ps_new, log_ps_old)) | |
| assert (th.allclose(log_ps_new, log_ps_old, atol=1e-7)), (log_ps_new - log_ps_old).abs().max() | |
| print("Saving new lens to", args.output_dir) | |
| tuned_lens.to(th.device("cpu")).save(args.output_dir) | |