|
from typing import List |
|
import numpy as np |
|
import dsp |
|
|
|
class KNN:
    """Retrieve the ``k`` training examples most similar to a query.

    Every training example is rendered to a single string built from its
    input fields, embedded once at construction time, and kept as a float32
    matrix.  Calling the instance embeds the query the same way and ranks
    the training examples by dot product against that matrix.
    """

    def __init__(self, k: int, trainset: List[dsp.Example]):
        """Embed the training set so it can be searched later.

        Args:
            k: Maximum number of neighbours returned by each call.
            trainset: Examples to search over.  Only the fields listed in
                each example's ``_input_keys`` take part in the embedding.
        """
        self.k = k
        self.trainset = trainset
        self.vectorizer = dsp.SentenceTransformersVectorizer()

        # One "key: value | key: value" string per example, restricted to the
        # example's input fields so that labels do not influence similarity.
        trainset_casted_to_vectorize = [
            " | ".join(
                f"{key}: {value}"
                for key, value in example.items()
                if key in example._input_keys
            )
            for example in self.trainset
        ]
        # NOTE(review): presumably one row per training example -- confirm
        # against the vectorizer's return shape.
        self.trainset_vectors = self.vectorizer(trainset_casted_to_vectorize).astype(np.float32)

    def __call__(self, **kwargs) -> List[dsp.Example]:
        """Return the training examples closest to the given input fields.

        Args:
            **kwargs: Field names and values describing the query.  All of
                them are rendered into the query string, in the order given.

        Returns:
            At most ``k`` examples from ``trainset``, highest score first.
            A non-positive ``k`` yields an empty list.
        """
        with dsp.settings.context(vectorizer=self.vectorizer):
            query_text = " | ".join(f"{key}: {val}" for key, val in kwargs.items())
            input_example_vector = self.vectorizer([query_text])

            # reshape(-1) rather than squeeze(): squeeze() would collapse the
            # scores to a 0-d array when the training set has one example.
            scores = np.dot(self.trainset_vectors, input_example_vector.T).reshape(-1)

            # Clamp k explicitly.  A negative-index slice such as [-k:] treats
            # k == 0 as "take everything" and a negative k as "skip the
            # lowest |k|", neither of which is a sensible neighbour count.
            neighbour_count = max(0, min(self.k, len(scores)))
            nearest_samples_idxs = scores.argsort()[::-1][:neighbour_count]

            return [self.trainset[cur_idx] for cur_idx in nearest_samples_idxs]
|
|