Spaces:
Build error
Build error
| #!/usr/bin/env python3 | |
| from typing import Callable, Union | |
| import torch | |
| from torch.nn import Module | |
class Concept:
    r"""
    A human-friendly, abstract notion that can be numerically encoded as
    torch tensors. A concept may be illustrated through images, text or any
    other representation. For images, the "stripes" concept might be captured
    by a collection of example pictures showing stripes in many different
    contexts. In Natural Language Processing, a concept such as "happy" might
    be captured by adjectives and words that convey happiness.
    """

    def __init__(
        self, id: int, name: str, data_iter: Union[None, torch.utils.data.DataLoader]
    ) -> None:
        r"""
        Args:
            id (int): Unique identifier of the concept.
            name (str): Unique, human-readable name of the concept.
            data_iter (DataLoader or None): A PyTorch DataLoader that pairs a
                        dataset with a sampler and yields an iterable over
                        that dataset. Only input batches are produced by
                        `data_iter`; the concept id can serve as a label
                        when one is needed. For details, see:
                        https://pytorch.org/docs/stable/data.html

        Example::

            >>> import os
            >>> # Build a Concept named "striped" whose data_iter walks
            >>> # over every file in "./concepts/striped".
            >>> concept_name = "striped"
            >>> concept_path = os.path.join("./concepts", concept_name) + "/"
            >>> concept_iter = dataset_to_dataloader(
            ...     get_tensor_from_filename, concepts_path=concept_path)
            >>> concept_object = Concept(
            ...     id=0, name=concept_name, data_iter=concept_iter)
        """
        # The constructor stores its arguments as given; no validation is done.
        self.id = id
        self.name = name
        self.data_iter = data_iter

    def identifier(self) -> str:
        """Return ``"<name>-<id>"``, with each part rendered via ``str()``."""
        return f"{self.name!s}-{self.id!s}"

    def __repr__(self) -> str:
        # Both fields are shown with repr(), id first and then name.
        return f"Concept({self.id!r}, {self.name!r})"
class ConceptInterpreter:
    r"""
    Abstract base for concept-based model interpretability algorithms.

    This base class only stores the model being interpreted and declares the
    type of the ``interpret`` attribute. A specific algorithm is expected to
    provide the actual ``interpret`` implementation.
    """

    def __init__(self, model: Module) -> None:
        r"""
        Args:
            model (torch.nn.Module): The PyTorch model to be interpreted.
        """
        self.model = model

    #: Abstract interpret method that performs concept-based model
    #: interpretability and returns the interpretation results as tensors,
    #: dictionaries or other data structures.
    #:
    #: Args:
    #:     inputs (tensor or tuple of tensors): Inputs for which concept-based
    #:                 interpretation scores are computed. They may be given as
    #:                 a single tensor or as a tuple of tensors. When several
    #:                 input tensors are given, their batch size (the first
    #:                 dimension) must match across all of them.
    #:
    #: Only a type annotation is declared here, so instances of this base
    #: class have no ``interpret`` attribute until a subclass supplies one.
    interpret: Callable