File size: 5,787 Bytes
f5776d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import random
from typing import Callable, Any

import numpy as np

import dsp
from dsp.utils import EM, F1, DPR_normalize, dotdict, has_answer, normalize_text


class Example(dotdict):
    """Basic record type for DSP examples.

    Built on ``dotdict``; fields are read and written as attributes elsewhere in
    this module (e.g. ``example.question``, ``example.augmented = True``).
    """

    demos: list[Any]

    def __init__(self, *args, **kwargs):
        # Accept at most one positional mapping; its items are merged first and
        # keyword arguments are applied on top of them.
        assert len(args) <= 1
        super().__init__()

        if args:
            self.update(args[0])

        self.update(**kwargs)

    def copy(self, **kwargs):
        """Return a new Example holding this example's fields, overridden by ``kwargs``."""
        merged = dict(self)
        merged.update(kwargs)
        return Example(**merged)

    def without(self, *keys):
        """Return a new Example containing every field except the given ``keys``."""
        excluded = frozenset(keys)
        kept = {key: value for key, value in self.items() if key not in excluded}
        return Example(kept)

    def demos_at(self, fn):
        """Return a copy whose ``demos`` have been transformed by ``fn``.

        Each demo is copied with the fields of ``fn(demo)`` (minus its
        ``augmented`` key) layered on top. If ``fn`` raises for a demo, that
        demo is copied unchanged (best-effort, errors are swallowed).
        """

        def _safe_apply(demo):
            try:
                return fn(demo).without("augmented")
            except Exception:
                # Deliberate best-effort: a failing transform leaves the demo as-is.
                return {}

        transformed = []
        for demo in self.demos:
            transformed.append(demo.copy(**_safe_apply(demo)))
        return self.copy(demos=transformed)


def annotate(*transformations):
    """Build an augment function that runs ``transformations`` over training examples.

    The returned ``do_augment(train, k=None, return_all=False)`` walks ``train``
    in order. Each element is fed through ``transformations`` one after another;
    a transformation that returns ``None`` rejects the element. Once ``k``
    accepted examples have been collected (when ``k`` is given), later elements
    are not transformed at all and count as rejected.

    Accepted results are tagged ``augmented = True``. Rejected elements are kept
    as ``dsp.Example`` copies of the raw input tagged ``augmented = False``.
    ``do_augment`` returns the accepted list, or accepted followed by rejected
    when ``return_all`` is true.
    """

    def do_augment(train, k=None, return_all=False):
        accepted = []
        rejected = []

        for item in train:
            raw_copy = dsp.Example(item)

            limit_reached = (k is not None) and len(accepted) >= k
            candidate = None if limit_reached else item

            for transform in transformations:
                if candidate is None:
                    break
                candidate = transform(candidate)

            if candidate is None:
                raw_copy.augmented = False
                rejected.append(raw_copy)
            else:
                candidate.augmented = True
                accepted.append(candidate)

        return accepted + rejected if return_all else accepted

    return do_augment


def sample(train: list[Example], k: int):
    """Return up to ``k`` examples drawn from ``train`` in shuffled order.

    The shuffle uses a ``random.Random`` seeded with ``dsp.settings.branch_idx``,
    so the same branch index yields the same sample. Each element is wrapped in
    a fresh ``dsp.Example``; ``train`` itself is not reordered.
    """
    seeded_rng = random.Random(dsp.settings.branch_idx)
    pool = list(map(dsp.Example, train))
    seeded_rng.shuffle(pool)
    return pool[:k]


def all_but(train: list[Example], x: Example) -> list[Example]:
    """Return the examples of ``train`` that share no question or history turn with ``x``.

    A candidate ``y`` is dropped when any element of ``y.get("history", [])``
    or ``y.question`` also appears among ``x.get("history", [])`` and
    ``x.question``.

    Args:
        train: candidate examples; each must support ``.get`` and ``.question``.
        x: the example to exclude (together with anything overlapping it).

    Returns:
        A new list of the non-overlapping examples, in their original order.
    """
    # x's keys are identical for every candidate, so build the set only once.
    # Unpacking (instead of ``history + [question]``) also accepts tuple histories.
    x_keys = {*x.get("history", []), x.question}

    return [
        y
        for y in train
        if x_keys.isdisjoint([*y.get("history", []), y.question])
    ]


def passage_match(passages: list[str], answers: list[str]) -> bool:
    """Return True as soon as one passage contains any of ``answers``, else False."""
    for passage in passages:
        if passage_has_answers(passage, answers):
            return True
    return False


def answer_match(prediction, answers, frac=1.0):
    """Check ``prediction`` against ``answers``.

    With ``frac >= 1.0`` the result is the exact-match score ``EM``; otherwise it
    is whether the token-level ``F1`` score reaches ``frac``.
    """
    return EM(prediction, answers) if frac >= 1.0 else F1(prediction, answers) >= frac


def passage_has_answers(passage: str, answers: list[str]) -> bool:
    """Return whether ``passage`` contains one of ``answers``.

    Answers are normalized with ``normalize_text`` then ``DPR_normalize``; the
    passage with ``normalize_text``; the check itself is delegated to ``has_answer``.
    """
    normalized_answers = [DPR_normalize(normalize_text(answer)) for answer in answers]
    normalized_passage = normalize_text(passage)
    return has_answer(tokenized_answers=normalized_answers, text=normalized_passage)


def cast_naive_get_only_question_text(inp_example: Example) -> Example:
    """
    Return a copy of ``inp_example`` whose ``text_to_vectorize`` field is its ``question``.
    """
    question = inp_example.question
    return inp_example.copy(text_to_vectorize=question)


def cast_naive_get_question_and_answer(inp_example: Example) -> Example:
    """
    Return a copy of ``inp_example`` with ``text_to_vectorize`` set to
    ``"<question> Answer: <answer>"``, both parts stripped of surrounding whitespace.
    """
    question_part = inp_example.question.strip()
    answer_part = inp_example.answer.strip()
    combined = question_part + " Answer: " + answer_part
    return inp_example.copy(text_to_vectorize=combined)


def knn(
    train: list[Example],
    cast: Callable[[Example], Example] = cast_naive_get_only_question_text,
    **knn_args
) -> Callable[[Example, int], list[Example]]:
    """
    Index `train` with `dsp.settings.vectorizer` and a FAISS index for similarity search.

    Every training example is passed through `cast`, vectorized, converted to float32,
    and used both to train and to populate an index built by `create_faiss_index`.

    Args:
        train: examples to put in the index and search over later.
        cast: function that constructs the text to vectorize. By default only the
            question is used; see `cast_naive_get_question_and_answer` for an alternative.
        knn_args: extra keyword arguments forwarded unchanged to `create_faiss_index`
            (ANN/KNN options such as how many IVF clusters to probe, which does not
            affect brute-force search) — check that function for the accepted names.
    Returns:
        A function `(inp_example, k) -> list[Example]` that vectorizes `inp_example`
        the same way and returns at most `k` examples from `train`, in the order
        reported by the index. Fewer than `k` are returned when the index cannot
        supply `k` neighbours (e.g. `k > len(train)`).
    """
    from dsp.utils.ann_utils import create_faiss_index

    train_casted_to_vectorize = [cast(cur_elem) for cur_elem in train]

    vectorizer: "BaseSentenceVectorizer" = dsp.settings.vectorizer
    # FAISS works on float32 data, so the vectors are converted explicitly.
    all_vectors = vectorizer(train_casted_to_vectorize).astype(np.float32)

    index = create_faiss_index(
        emb_dim=all_vectors.shape[1], n_objects=len(train), **knn_args
    )
    index.train(all_vectors)
    index.add(all_vectors)

    def inner_knn_search(inp_example: Example, k: int) -> list[Example]:
        # Use the same dtype for queries as for the indexed vectors.
        inp_example_vector = vectorizer([cast(inp_example)]).astype(np.float32)
        _, nearest_samples_idxs = index.search(inp_example_vector, k)
        # FAISS pads result slots it cannot fill with label -1; indexing train
        # with -1 would silently return the last training example instead.
        return [train[cur_idx] for cur_idx in nearest_samples_idxs[0] if cur_idx >= 0]

    return inner_knn_search