File size: 1,014 Bytes
f5776d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from typing import List
import numpy as np
import dsp

class KNN:
    """Retrieve the training examples whose embeddings score highest against a query.

    At construction time every training example is rendered to a single
    string and embedded with ``dsp.SentenceTransformersVectorizer``. Calling
    the instance embeds the keyword arguments the same way and returns the
    ``k`` training examples with the largest dot-product score, best first.
    """

    def __init__(self, k: int, trainset: List[dsp.Example]):
        """Embed ``trainset`` once so later queries only embed the query.

        Args:
            k: Maximum number of examples to return per query.
            trainset: Examples to search over. Each must support ``.items()``
                and expose ``_input_keys``.
        """
        self.k = k
        self.trainset = trainset
        self.vectorizer = dsp.SentenceTransformersVectorizer()
        # Only the fields named in ``_input_keys`` are embedded; presumably
        # this keeps labels/outputs out of the text compared against queries.
        trainset_casted_to_vectorize = [
            self._join_fields(
                (key, value)
                for key, value in example.items()
                if key in example._input_keys
            )
            for example in self.trainset
        ]
        self.trainset_vectors = self.vectorizer(trainset_casted_to_vectorize).astype(np.float32)

    @staticmethod
    def _join_fields(pairs) -> str:
        """Render ``(key, value)`` pairs as ``"key: value | key: value"``."""
        return " | ".join(f"{key}: {value}" for key, value in pairs)

    def __call__(self, **kwargs) -> List[dsp.Example]:
        """Return up to ``k`` training examples ranked by score, best first.

        Args:
            **kwargs: Field names and values of the query; all of them are
                rendered into the query text.

        Returns:
            A list of training examples ordered from highest to lowest
            dot-product score. Empty when ``k`` is not positive or the
            training set is empty.
        """
        # A slice of ``[-0:]`` selects *everything*, so k == 0 must be handled
        # explicitly; clamping also keeps the slice within the training set.
        k = min(self.k, len(self.trainset))
        if k <= 0:
            return []

        with dsp.settings.context(vectorizer=self.vectorizer):
            input_example_vector = self.vectorizer([self._join_fields(kwargs.items())])
            # reshape(-1) always yields a 1-D score vector, even when the
            # training set holds a single example (squeeze() would produce a
            # 0-d array in that case).
            scores = np.dot(self.trainset_vectors, input_example_vector.T).reshape(-1)
            # argsort is ascending: take the last k indices, then reverse so
            # the highest score comes first.
            nearest_samples_idxs = scores.argsort()[-k:][::-1]
            return [self.trainset[cur_idx] for cur_idx in nearest_samples_idxs]