import random from typing import Callable, Any import numpy as np import dsp from dsp.utils import EM, F1, DPR_normalize, dotdict, has_answer, normalize_text class Example(dotdict): """A primitive datatype for representing an example""" demos: list[Any] def __init__(self, *args, **kwargs): assert len(args) <= 1 super().__init__() if args: assert len(args) == 1 self.update(args[0]) self.update(**kwargs) def copy(self, **kwargs): the_copy = Example(**{**dict(self), **kwargs}) return the_copy def without(self, *keys): """Removes the provided keys from the example and returns a copy""" keys = set(keys) return Example({k: v for k, v in self.items() if k not in keys}) def demos_at(self, fn): """Returns a copy of the example with the demos stage transformed by the provided function""" def at(example): try: return fn(example).without("augmented") except Exception: return {} demos = [example.copy(**at(example)) for example in self.demos] return self.copy(demos=demos) def annotate(*transformations): """Returns an Augment function that applies the provided transformations to the Examples""" def do_augment(train, k=None, return_all=False): rdemos = [] ademos = [] for example in train: # tqdm.tqdm raw_example = dsp.Example(example) if (k is not None) and len(ademos) >= k: example = None for f in transformations: if example is None: break example = f(example) if example is not None: example.augmented = True ademos.append(example) else: raw_example.augmented = False rdemos.append(raw_example) if return_all: return ademos + rdemos return ademos return do_augment def sample(train: list[Example], k: int): """Sample k examples from train.""" rng = random.Random(dsp.settings.branch_idx) shuffled_train = [dsp.Example(example) for example in train] rng.shuffle(shuffled_train) return shuffled_train[:k] def all_but(train: list[Example], x: Example) -> list[Example]: """Removes the example x from the train set by comparing the question and history.""" output = [ y for y in train if not set.intersection( set(x.get("history", []) + [x.question]), set(y.get("history", []) + [y.question]), ) ] return output def passage_match(passages: list[str], answers: list[str]) -> bool: """Returns True if any of the passages contains the answer.""" return any(passage_has_answers(psg, answers) for psg in passages) def answer_match(prediction, answers, frac=1.0): # pred = example.prediction # answers = example.answers if frac >= 1.0: return EM(prediction, answers) return F1(prediction, answers) >= frac def passage_has_answers(passage: str, answers: list[str]) -> bool: """Returns True if the passage contains the answer.""" return has_answer( tokenized_answers=[DPR_normalize(normalize_text(ans)) for ans in answers], text=normalize_text(passage), ) def cast_naive_get_only_question_text(inp_example: Example) -> Example: """ Extracts question as a field to vectorize with Vectorizer object. `question` field is used. """ return inp_example.copy(text_to_vectorize=inp_example.question) def cast_naive_get_question_and_answer(inp_example: Example) -> Example: """ Extracts question and answer as fields to vectorize with Vectorizer object. `question` and `answer` fields are used. They will be concatenated with the word "Answer" between. """ text_to_vectorize = ( inp_example.question.strip() + " Answer: " + inp_example.answer.strip() ) return inp_example.copy(text_to_vectorize=text_to_vectorize) def knn( train: list[Example], cast: Callable[[Example], Example] = cast_naive_get_only_question_text, **knn_args ) -> Callable[[Example, int], list[Example]]: """ A function that vectorizes train data using `dsm.settings.vectorizer`, then build an ANN/KNN index to search similar questions among `train` samples. Args: train: a bunch of questions to put in index & search later cast: function that contructs text before vectorization. By default, it uses only question. Check `cast_naive_get_question_and_answer` for more details. n_probe: number of closest IVF-clusters to check for neighbours. Doesn't affect bruteforce-based search. knn_args: check `create_faiss_index` function for details on ANN/KNN arguments. Returns: function to search similar Examples from `train` in FAISS-index. """ from dsp.utils.ann_utils import create_faiss_index train_casted_to_vectorize = [cast(cur_elem) for cur_elem in train] vectorizer: "BaseSentenceVectorizer" = dsp.settings.vectorizer all_vectors = vectorizer(train_casted_to_vectorize).astype(np.float32) index = create_faiss_index( emb_dim=all_vectors.shape[1], n_objects=len(train), **knn_args ) index.train(all_vectors) index.add(all_vectors) def inner_knn_search(inp_example: Example, k: int) -> list[Example]: inp_example_vector = vectorizer([cast(inp_example)]) _, nearest_samples_idxs = index.search(inp_example_vector, k) train_sampled = [train[cur_idx] for cur_idx in nearest_samples_idxs[0]] return train_sampled return inner_knn_search