EureCA / dspy /teleprompt /vanilla.py
tonneli's picture
Delete history
f5776d3
raw
history blame contribute delete
961 Bytes
import dsp
import random
from .teleprompt import Teleprompter
class LabeledFewShot(Teleprompter):
    """Teleprompter that attaches labeled training examples as few-shot demos.

    Each predictor of the compiled program receives up to ``k`` examples from
    ``trainset``. The examples are either drawn at random or taken from the
    front of the training set.
    """

    def __init__(self, k=16, seed=0):
        """
        Args:
            k: Maximum number of demos assigned to each predictor.
            seed: Seed for the random generator used when ``sample=True``.
                The default of 0 reproduces the original fixed-seed behavior.
        """
        self.k = k
        self.seed = seed

    def compile(self, student, *, trainset, sample=True):
        """Return ``student.reset_copy()`` with demos set on its predictors.

        Args:
            student: Program to compile. The object returned by
                ``student.reset_copy()`` is the one that gets modified.
            trainset: Sequence of labeled examples to draw demos from.
            sample: If True, each predictor gets its own random sample of
                ``min(k, len(trainset))`` examples, drawn from one generator
                stream. If False, every predictor gets the first
                ``min(k, len(trainset))`` examples.

        Returns:
            The compiled copy. If ``trainset`` is empty, the copy is returned
            without assigning any demos.
        """
        self.student = student.reset_copy()
        self.trainset = trainset

        if not self.trainset:
            return self.student

        # A fresh generator on every call keeps compilation deterministic for a
        # given seed. Successive predictors still receive different samples
        # because they consume the same stream.
        rng = random.Random(self.seed)
        num_demos = min(self.k, len(self.trainset))

        for predictor in self.student.predictors():
            if sample:
                predictor.demos = rng.sample(self.trainset, num_demos)
            else:
                predictor.demos = self.trainset[:num_demos]

        return self.student
# NOTE: I believe templatev2 keeps rdemos as long as they have the last field.
# This may change later, especially with the introduction of required vs optional fields.
# NOTE: Since we're relying on downstream code to handle the demos, the demos sampled here may be further sub-sampled downstream.