File size: 805 Bytes
f5776d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
from typing import List
import types
import dsp
from .teleprompt import Teleprompter
from dspy.teleprompt import BootstrapFewShot
class KNNFewShot(Teleprompter):
    """Teleprompter that bootstraps few-shot demos per call from a KNN retriever.

    The returned program's ``forward`` does three things on every call:
    1. It asks the retriever built in ``__init__`` for training examples
       matching the call's keyword inputs (presumably the ``k`` nearest
       neighbours).
    2. It runs ``BootstrapFewShot`` on just those examples.
    3. It returns the freshly compiled program's answer.
    """

    def __init__(self, KNN, k: int, trainset: List[dsp.Example], **few_shot_bootstrap_args):
        """Build the retriever used to pick per-call training examples.

        Args:
            KNN: Factory (presumably a KNN class) called as ``KNN(k, trainset)``.
                The object it returns is later called with a program's keyword
                inputs, and its result is used as the bootstrap trainset.
            k: Forwarded to ``KNN`` as its first argument.
            trainset: Examples forwarded to ``KNN`` as its second argument.
            **few_shot_bootstrap_args: Optional keyword arguments forwarded to
                ``BootstrapFewShot(...)`` on every call. Empty by default, which
                matches constructing ``BootstrapFewShot()`` with no arguments.
        """
        self.KNN = KNN(k, trainset)
        self.few_shot_bootstrap_args = few_shot_bootstrap_args

    def compile(self, student, *, teacher=None, trainset=None, valset=None):
        """Return a copy of ``student`` whose ``forward`` bootstraps demos per call.

        Args:
            student: Program to optimize. Its ``reset_copy()`` is returned with
                ``forward`` replaced.
            teacher: Forwarded unchanged to ``BootstrapFewShot.compile``.
            trainset: Unused; training examples come from the retriever built
                in ``__init__``. Defaults to ``None`` so callers need not pass it.
            valset: Forwarded unchanged to ``BootstrapFewShot.compile``.

        Returns:
            The copied student with a replaced ``forward``. The new ``forward``
            accepts keyword arguments only and raises ``TypeError`` if it is
            given positional inputs.
        """
        student_copy = student.reset_copy()

        def forward_pass(_program, *args, **kwargs):
            # _program is student_copy, bound via types.MethodType below.
            # The retriever and the compiled program are both fed **kwargs only,
            # so positional inputs would otherwise be silently dropped.
            if args:
                raise TypeError(
                    "KNNFewShot-compiled programs accept keyword arguments only; "
                    f"got {len(args)} positional argument(s)"
                )
            knn_trainset = self.KNN(**kwargs)
            few_shot_bootstrap = BootstrapFewShot(**self.few_shot_bootstrap_args)
            # Compile the original `student`, not `student_copy`: the copy's
            # forward is this function, so bootstrapping it would presumably
            # recurse.
            compiled_program = few_shot_bootstrap.compile(
                student, teacher=teacher, trainset=knn_trainset, valset=valset
            )
            return compiled_program(**kwargs)

        student_copy.forward = types.MethodType(forward_pass, student_copy)
        return student_copy