Spaces:
Build error
Build error
| #!/usr/bin/env python3 | |
| import warnings | |
| from functools import reduce | |
| import torch | |
| from torch.nn import Module | |
class InterpretableEmbeddingBase(Module):
    r"""
    Stand-in for a model's embedding layer that forwards precomputed
    embeddings instead of computing them from indices.

    Some embedding vectors (e.g. word embeddings) are created inside the
    embedding layers of PyTorch models. To attribute with respect to those
    embeddings we need to compute them outside of the model, e.g. to
    subtract a baseline. The original embedding layer is therefore replaced
    by an `InterpretableEmbeddingBase` instance: its `forward` returns the
    embedding tensor it is given unchanged, while the wrapped layer stays
    reachable through `indices_to_embeddings`.

    Attributes set by the constructor:
        num_embeddings: `embedding.num_embeddings`, or None if the wrapped
            layer has no such attribute.
        embedding_dim: `embedding.embedding_dim`, or None if the wrapped
            layer has no such attribute.
        embedding: The wrapped (original) embedding layer.
        full_name: The name under which the wrapped layer is stored, as
            passed to the constructor.
    """

    def __init__(self, embedding, full_name) -> None:
        r"""
        Args:
            embedding (Any): The original embedding layer (or callable)
                being wrapped.
            full_name (str): Name of the wrapped layer; kept so the
                original layer can be put back later.
        """
        super().__init__()
        # Not every wrapped layer exposes these attributes; fall back to None.
        self.num_embeddings = getattr(embedding, "num_embeddings", None)
        self.embedding_dim = getattr(embedding, "embedding_dim", None)

        self.embedding = embedding
        self.full_name = full_name

    def forward(self, *inputs, **kwargs):
        r"""
        Returns the precomputed embedding tensor it receives, unmodified.
        This lets embeddings be created outside of the model and handed
        to the layers that consume this layer's output.

        Args:
            *inputs (Any, optional): Positional arguments of the forward
                call. When any are given, the first one is taken to be the
                embedding tensor; the remaining ones are ignored.
            **kwargs (Any, optional): Keyword arguments of the forward call.
                When no positional arguments are given, the first keyword
                argument (in call order) is taken to be the embedding
                tensor. This relies on keyword arguments preserving call
                order, which holds from Python 3.6 onwards. If this
                convention does not fit a use case, override
                `InterpretableEmbeddingBase` in a descendant class.

        Returns:
            embedding_tensor (Any):
                The first positional argument if there is one, otherwise
                the first keyword argument value. It is returned as-is.

        Raises:
            AssertionError: If neither positional nor keyword arguments
                are provided.
        """
        # Raised explicitly (instead of via `assert`) so the check is not
        # stripped under `python -O`; the exception type is unchanged.
        if not inputs and not kwargs:
            raise AssertionError(
                "No input arguments are provided to `InterpretableEmbeddingBase`. "
                "Input embedding tensor has to be provided as first argument to "
                "forward function either through inputs argument or kwargs."
            )
        if inputs:
            return inputs[0]
        return next(iter(kwargs.values()))

    def indices_to_embeddings(self, *input, **kwargs):
        r"""
        Maps indices to the corresponding embedding vectors (e.g. word
        embeddings) by calling the wrapped embedding layer.

        Args:
            *input (Any, optional): Tensor(s) of input indices, or any other
                positional arguments needed to compute the embeddings.
                A typical example is word or token indices.
            **kwargs (Any, optional): Any keyword arguments needed to
                compute the final embedding tensor.

        Returns:
            Any:
                Whatever the wrapped embedding layer returns for the given
                arguments; typically a tensor of embeddings for the
                specified indices.
        """
        return self.embedding(*input, **kwargs)
class TokenReferenceBase:
    r"""
    Base class that builds a reference (a.k.a. baseline) tensor for a
    sequence of tokens by repeating one reference token, e.g. `PAD`.
    The vocabulary index of that token is supplied at construction time.
    """

    def __init__(self, reference_token_idx=0) -> None:
        r"""
        Args:
            reference_token_idx (int, optional): Vocabulary index of the
                reference token. Defaults to 0.
        """
        self.reference_token_idx = reference_token_idx

    def generate_reference(self, sequence_length, device):
        r"""
        Builds a reference tensor holding `reference_token_idx` repeated
        `sequence_length` times.

        Args:
            sequence_length (int): Number of reference tokens to generate.
            device (torch.device): Device on which the reference tensor is
                created.

        Returns:
            tensor:
                A 1D tensor of reference token indices with shape:
                [sequence_length]
        """
        repeated_token_ids = [self.reference_token_idx] * sequence_length
        reference = torch.tensor(repeated_token_ids, device=device)
        return reference
def _get_deep_layer_name(obj, layer_names):
    r"""
    Resolves a dot-separated attribute path (e.g. "encoder.embedding")
    starting at `obj` and returns the attribute found at the end of it.
    """
    current = obj
    for attr_name in layer_names.split("."):
        current = getattr(current, attr_name)
    return current
def _set_deep_layer_value(obj, layer_names, value):
    r"""
    Resolves a dot-separated attribute path starting at `obj` and assigns
    `value` to the final attribute of that path. All components before the
    last one are looked up; only the last one is written.
    """
    *parent_names, leaf_name = layer_names.split(".")
    owner = obj
    for attr_name in parent_names:
        owner = getattr(owner, attr_name)
    setattr(owner, leaf_name, value)
def configure_interpretable_embedding_layer(model, embedding_layer_name="embedding"):
    r"""
    Wraps a model's embedding layer in an `InterpretableEmbeddingBase`
    layer and installs the wrapper on the model in its place. The wrapper
    gives access to the embeddings through their indices, while the model
    itself then expects precomputed embeddings as input to that layer.

    A warning is emitted on every call as a reminder that the original
    layer must be restored with `remove_interpretable_embedding_layer`
    once interpretation is finished.

    Args:
        model (torch.nn.Module): An instance of PyTorch model that contains
            embeddings.
        embedding_layer_name (str, optional): The name of the embedding
            layer in the `model` that we would like to make interpretable.
            Nested layers are addressed with a dot-separated path.
            Defaults to "embedding".

    Returns:
        interpretable_emb (InterpretableEmbeddingBase): The wrapper that
            now replaces the model's embedding layer accessed through
            `embedding_layer_name`.

    Raises:
        AssertionError: If the layer at `embedding_layer_name` is already
            an `InterpretableEmbeddingBase` (or a subclass of it).

    Examples::

        >>> # Let's assume that we have a DocumentClassifier model that
        >>> # has a word embedding layer named 'embedding'.
        >>> # To make that layer interpretable we need to execute the
        >>> # following command:
        >>> net = DocumentClassifier()
        >>> interpretable_emb = configure_interpretable_embedding_layer(net,
        ...    'embedding')
        >>> # then we can use interpretable embedding to convert our
        >>> # word indices into embeddings.
        >>> # Let's assume that we have the following word indices
        >>> input_indices = torch.tensor([1, 0, 2])
        >>> # we can access word embeddings for those indices with the command
        >>> # line stated below.
        >>> input_emb = interpretable_emb.indices_to_embeddings(input_indices)
        >>> # Let's assume that we want to apply integrated gradients to
        >>> # our model and that target attribution class is 3
        >>> ig = IntegratedGradients(net)
        >>> attribution = ig.attribute(input_emb, target=3)
        >>> # after we finish the interpretation we need to remove
        >>> # interpretable embedding layer with the following command:
        >>> remove_interpretable_embedding_layer(net, interpretable_emb)
    """
    embedding_layer = _get_deep_layer_name(model, embedding_layer_name)
    # isinstance (not an exact class comparison) so that descendant classes
    # of InterpretableEmbeddingBase are also recognised as already wrapped.
    # Raised explicitly so the guard is not stripped under `python -O`;
    # the exception type is the same as the previous `assert`.
    if isinstance(embedding_layer, InterpretableEmbeddingBase):
        raise AssertionError(
            "InterpretableEmbeddingBase has already been configured for layer {}".format(
                embedding_layer_name
            )
        )
    warnings.warn(
        "In order to make embedding layers more interpretable they will "
        "be replaced with an interpretable embedding layer which wraps the "
        "original embedding layer and takes word embedding vectors as inputs of "
        "the forward function. This allows us to generate baselines for word "
        "embeddings and compute attributions for each embedding dimension. "
        "The original embedding layer must be set "
        "back by calling `remove_interpretable_embedding_layer` function "
        "after model interpretation is finished. "
    )
    interpretable_emb = InterpretableEmbeddingBase(
        embedding_layer, embedding_layer_name
    )
    # Swap the wrapper in; the original layer stays reachable through
    # `interpretable_emb.embedding`.
    _set_deep_layer_value(model, embedding_layer_name, interpretable_emb)
    return interpretable_emb
def remove_interpretable_embedding_layer(model, interpretable_emb):
    r"""
    Puts the original embedding layer back into the model, replacing the
    interpretable embedding layer that was installed in its place.

    Args:
        model (torch.nn.Module): An instance of PyTorch model that contains
            embeddings.
        interpretable_emb (InterpretableEmbeddingBase): The wrapper that was
            created by `configure_interpretable_embedding_layer` and has to
            be removed after interpretation is finished. Its `full_name`
            gives the location of the layer in `model` and its `embedding`
            is the original layer that gets restored.

    Examples::

        >>> # Let's assume that we have a DocumentClassifier model that
        >>> # has a word embedding layer named 'embedding'.
        >>> # To make that layer interpretable we need to execute the
        >>> # following command:
        >>> net = DocumentClassifier()
        >>> interpretable_emb = configure_interpretable_embedding_layer(net,
        ...    'embedding')
        >>> # then we can use interpretable embedding to convert our
        >>> # word indices into embeddings.
        >>> # Let's assume that we have the following word indices
        >>> input_indices = torch.tensor([1, 0, 2])
        >>> # we can access word embeddings for those indices with the command
        >>> # line stated below.
        >>> input_emb = interpretable_emb.indices_to_embeddings(input_indices)
        >>> # Let's assume that we want to apply integrated gradients to
        >>> # our model and that target attribution class is 3
        >>> ig = IntegratedGradients(net)
        >>> attribution = ig.attribute(input_emb, target=3)
        >>> # after we finish the interpretation we need to remove
        >>> # interpretable embedding layer with the following command:
        >>> remove_interpretable_embedding_layer(net, interpretable_emb)
    """
    layer_path = interpretable_emb.full_name
    original_layer = interpretable_emb.embedding
    _set_deep_layer_value(model, layer_path, original_layer)