File size: 5,787 Bytes
f5776d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
import random
from typing import Callable, Any
import numpy as np
import dsp
from dsp.utils import EM, F1, DPR_normalize, dotdict, has_answer, normalize_text
class Example(dotdict):
    """Dictionary-backed record describing a single example.

    Fields are stored as ordinary mapping entries; attribute-style access
    (e.g. ``self.demos``) is presumably supplied by the ``dotdict`` base
    class -- confirm against ``dsp.utils``.
    """

    # Demonstrations attached to this example; read by ``demos_at``.
    demos: list[Any]

    def __init__(self, *args, **kwargs):
        """Build an example from at most one positional mapping plus keyword fields.

        Keyword fields are applied after the positional mapping, so they win
        when both provide the same key.
        """
        assert len(args) <= 1
        super().__init__()

        if len(args) > 0:
            self.update(args[0])

        self.update(**kwargs)

    def copy(self, **kwargs):
        """Return a new ``Example`` holding this example's fields overlaid with ``kwargs``."""
        merged = dict(self)
        merged.update(kwargs)
        return Example(**merged)

    def without(self, *keys):
        """Return a new ``Example`` that omits the given keys; ``self`` is left untouched."""
        excluded = set(keys)

        remaining = {}
        for name, value in self.items():
            if name in excluded:
                continue
            remaining[name] = value

        return Example(remaining)

    def demos_at(self, fn):
        """Return a copy of this example whose demos have been passed through ``fn``.

        For every demo, ``fn(demo)`` is called and its result (minus the
        ``"augmented"`` key) is overlaid on a copy of that demo. If ``fn`` or
        the key removal raises, the demo is copied without any changes.
        """
        rewritten = []
        for demo in self.demos:
            try:
                overrides = fn(demo).without("augmented")
            except Exception:
                # Best effort: a failing transformation leaves the demo as it was.
                overrides = {}
            rewritten.append(demo.copy(**overrides))

        return self.copy(demos=rewritten)
def annotate(*transformations):
    """Build an augmentation function that runs ``transformations`` over examples.

    The returned ``do_augment(train, k=None, return_all=False)`` feeds each
    example through the transformations in order. A transformation returning
    ``None`` rejects the example and stops the chain for it.
    """

    def do_augment(train, k=None, return_all=False):
        """Annotate the examples in ``train``.

        Args:
            train: iterable of examples.
            k: optional cap on the number of augmented examples; once reached,
                the remaining examples are not transformed.
            return_all: when true, the unaugmented examples are appended after
                the augmented ones instead of being dropped.

        Returns:
            The augmented examples (flagged ``augmented = True``), optionally
            followed by the raw ones (flagged ``augmented = False``).
        """
        augmented = []
        untouched = []

        for example in train:  # tqdm.tqdm
            # Snapshot taken before any transformation runs on the example.
            pristine = dsp.Example(example)

            quota_reached = (k is not None) and len(augmented) >= k
            candidate = None if quota_reached else example

            for transform in transformations:
                if candidate is None:
                    break
                candidate = transform(candidate)

            if candidate is None:
                pristine.augmented = False
                untouched.append(pristine)
                continue

            candidate.augmented = True
            augmented.append(candidate)

        return augmented + untouched if return_all else augmented

    return do_augment
def sample(train: list[Example], k: int):
    """Return the first ``k`` examples of a seeded shuffle of ``train``.

    The generator is seeded from ``dsp.settings.branch_idx``. Each entry of
    ``train`` is wrapped in a fresh ``dsp.Example`` and ``train`` itself is
    not reordered.
    """
    shuffler = random.Random(dsp.settings.branch_idx)

    pool = []
    for item in train:
        pool.append(dsp.Example(item))

    shuffler.shuffle(pool)
    return pool[:k]
def all_but(train: list[Example], x: Example) -> list[Example]:
    """Return the examples of ``train`` that do not overlap with ``x``.

    Every example is described by a key set: its ``question`` plus all entries
    of its optional ``history`` list. An example is dropped when its key set
    shares at least one element with the key set of ``x``, so this removes
    ``x`` itself as well as any example whose question or history overlaps it.

    Args:
        train: candidate examples; each must expose ``question`` and ``get``.
        x: the reference example to exclude (must have a ``question``).

    Returns:
        A new list with the non-overlapping examples, in their original order.
    """
    # Built once: the reference keys do not depend on the candidate examined.
    x_keys = set(x.get("history", []) + [x.question])

    return [
        y
        for y in train
        if x_keys.isdisjoint(y.get("history", []) + [y.question])
    ]
def passage_match(passages: list[str], answers: list[str]) -> bool:
    """Tell whether at least one passage contains one of the answers.

    Each passage is checked with ``passage_has_answers``; checking stops at
    the first hit. An empty ``passages`` yields ``False``.
    """
    for passage in passages:
        if passage_has_answers(passage, answers):
            return True
    return False
def answer_match(prediction, answers, frac=1.0):
    """Judge whether ``prediction`` matches ``answers``.

    Args:
        prediction: the predicted answer.
        answers: the reference answers.
        frac: acceptance threshold. At ``1.0`` or above the result of
            ``EM(prediction, answers)`` is returned as is; below that, the
            prediction is accepted when ``F1(prediction, answers)`` reaches
            ``frac``.
    """
    if not frac >= 1.0:
        score = F1(prediction, answers)
        return score >= frac

    return EM(prediction, answers)
def passage_has_answers(passage: str, answers: list[str]) -> bool:
    """Tell whether ``passage`` contains one of ``answers``.

    Each answer goes through ``normalize_text`` and then ``DPR_normalize``;
    the passage goes through ``normalize_text`` only. The decision itself is
    delegated to ``has_answer``.
    """
    prepared_answers = []
    for answer in answers:
        prepared_answers.append(DPR_normalize(normalize_text(answer)))

    prepared_passage = normalize_text(passage)

    return has_answer(tokenized_answers=prepared_answers, text=prepared_passage)
def cast_naive_get_only_question_text(inp_example: Example) -> Example:
    """Prepare an example for vectorization using its question alone.

    Returns a copy of ``inp_example`` with a ``text_to_vectorize`` field set
    to the value of its ``question`` field.
    """
    question_text = inp_example.question
    return inp_example.copy(text_to_vectorize=question_text)
def cast_naive_get_question_and_answer(inp_example: Example) -> Example:
    """Prepare an example for vectorization using its question and answer.

    The ``question`` and ``answer`` fields are stripped of surrounding
    whitespace and joined with the separator ``" Answer: "``. The result is
    stored in the ``text_to_vectorize`` field of a copy of ``inp_example``.
    """
    question_part = inp_example.question.strip()
    answer_part = inp_example.answer.strip()
    combined = " Answer: ".join((question_part, answer_part))
    return inp_example.copy(text_to_vectorize=combined)
def knn(
    train: list[Example],
    cast: Callable[[Example], Example] = cast_naive_get_only_question_text,
    **knn_args
) -> Callable[[Example, int], list[Example]]:
    """Build a nearest-neighbour search function over ``train``.

    The training examples are vectorized with ``dsp.settings.vectorizer`` and
    stored in an ANN/KNN index built by ``create_faiss_index``.

    Args:
        train: the examples to index and to search among later.
        cast: function that constructs the text to vectorize for an example.
            By default only the question is used; see
            `cast_naive_get_question_and_answer` for an alternative.
        knn_args: forwarded unchanged to `create_faiss_index`; check that
            function for the supported ANN/KNN arguments (the original notes
            mention `n_probe`, the number of closest IVF clusters to check,
            which does not affect brute-force search -- presumably passed
            here; confirm against `create_faiss_index`).

    Returns:
        A function ``(example, k)`` returning up to ``k`` examples from
        ``train`` that are closest to ``example``. Fewer than ``k`` examples
        are returned when the index cannot supply ``k`` neighbours.
    """
    # Imported here rather than at module level, as in the original code.
    from dsp.utils.ann_utils import create_faiss_index

    train_casted_to_vectorize = [cast(cur_elem) for cur_elem in train]

    vectorizer: "BaseSentenceVectorizer" = dsp.settings.vectorizer
    all_vectors = vectorizer(train_casted_to_vectorize).astype(np.float32)

    index = create_faiss_index(
        emb_dim=all_vectors.shape[1], n_objects=len(train), **knn_args
    )
    index.train(all_vectors)
    index.add(all_vectors)

    def inner_knn_search(inp_example: Example, k: int) -> list[Example]:
        """Return up to ``k`` training examples nearest to ``inp_example``."""
        # Same dtype as the indexed vectors.
        inp_example_vector = vectorizer([cast(inp_example)]).astype(np.float32)
        _, nearest_samples_idxs = index.search(inp_example_vector, k)
        # FAISS pads missing neighbours with label -1; as a list index that
        # would silently pick the last training example, so skip such labels.
        return [
            train[cur_idx] for cur_idx in nearest_samples_idxs[0] if cur_idx >= 0
        ]

    return inner_knn_search
|