File size: 2,867 Bytes
f5776d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
from dspy.primitives.example import Example
class Prediction(Example):
    """An ``Example`` specialisation that can carry a set of completions.

    The constructor forwards all arguments to ``Example`` and then removes
    the ``_demos`` and ``_input_keys`` attributes from the instance
    (presumably set by ``Example.__init__`` -- confirm against the base
    class). ``_completions`` starts out as ``None`` and is only populated
    by :meth:`from_completions`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # A prediction does not keep these two attributes of the base class.
        for unwanted in ("_demos", "_input_keys"):
            delattr(self, unwanted)

        self._completions = None

    @classmethod
    def from_completions(cls, list_or_dict, signature=None):
        """Build a prediction from a list of dicts or a dict of lists.

        The input is wrapped in a ``Completions`` object, and the first
        value of every field becomes the prediction's own ``_store``.
        """
        prediction = cls()
        prediction._completions = Completions(list_or_dict, signature=signature)

        first_values = {}
        for field, values in prediction._completions.items():
            first_values[field] = values[0]
        prediction._store = first_values

        return prediction

    def __repr__(self):
        """Render the stored fields, noting how many completions are hidden."""
        store_repr = ',\n '.join(
            f"{field}={repr(value)}" for field, value in self._store.items()
        )

        # No completions attached: only the stored fields are shown.
        if self._completions is None:
            return f"Prediction(\n {store_repr}\n)"

        num_completions = len(self._completions)
        # A single completion is already fully described by the stored fields.
        if num_completions == 1:
            return f"Prediction(\n {store_repr}\n)"

        return f"Prediction(\n {store_repr},\n completions=Completions(...)\n) ({num_completions-1} completions omitted)"

    def __str__(self):
        """Use the same text as ``__repr__``."""
        return repr(self)

    @property
    def completions(self):
        """The attached ``Completions`` object, or ``None`` if there is none."""
        return self._completions
class Completions:
    """Column-oriented container for a group of alternative completions.

    Internally the data is a dict mapping each field name to a list of
    values, one value per completion. All lists must have the same length.

    Args:
        list_or_dict: Either a list of dicts (one dict per completion; the
            values are regrouped per key) or a dict whose values are
            already lists. A dict is stored as-is, not copied.
        signature: Optional object kept on ``self.signature``; it is not
            interpreted by this class.

    Raises:
        AssertionError: If any value is not a list, or the lists differ in
            length. (These checks are ``assert`` statements, so they are
            skipped when Python runs with ``-O``.)
    """

    def __init__(self, list_or_dict, signature=None):
        self.signature = signature

        if isinstance(list_or_dict, list):
            # Transpose a list of row dicts into a dict of column lists.
            kwargs = {}
            for arg in list_or_dict:
                for k, v in arg.items():
                    kwargs.setdefault(k, []).append(v)
        else:
            kwargs = list_or_dict

        assert all(isinstance(v, list) for v in kwargs.values()), "All values must be lists"

        if kwargs:
            length = len(next(iter(kwargs.values())))
            assert all(len(v) == length for v in kwargs.values()), "All lists must have the same length"

        self._completions = kwargs

    def items(self):
        """Return a view of ``(field, list_of_values)`` pairs."""
        return self._completions.items()

    def __getitem__(self, key):
        """Look up by position or by field name.

        An ``int`` key returns a ``Prediction`` built from the ``key``-th
        value of every field; negative or too-large indices raise
        ``IndexError``. Any other key returns the list stored under it.
        """
        if isinstance(key, int):
            if key < 0 or key >= len(self):
                raise IndexError("Index out of range")

            return Prediction(**{k: v[key] for k, v in self._completions.items()})

        return self._completions[key]

    def __getattr__(self, name):
        """Expose stored fields as attributes (``completions.answer``).

        Raises:
            AttributeError: If ``name`` is not a stored field.
        """
        # Read through __dict__ rather than ``self._completions``: this hook
        # also runs on instances whose __init__ has not executed (copy and
        # pickle probe attributes on a blank object), where a plain attribute
        # access would re-enter __getattr__ and recurse without end.
        completions = self.__dict__.get("_completions")
        if completions is not None and name in completions:
            return completions[name]

        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    def __len__(self):
        """Return the number of completions (0 when no fields are stored)."""
        # All lists share one length, so the first one is representative.
        # The default covers the empty container.
        return len(next(iter(self._completions.values()), ()))

    def __contains__(self, key):
        """Return whether ``key`` is a stored field name."""
        return key in self._completions

    def __repr__(self):
        """Render every field with its full list of values."""
        items_repr = ',\n '.join(f"{k}={repr(v)}" for k, v in self._completions.items())
        return f"Completions(\n {items_repr}\n)"

    def __str__(self):
        """Use the same text as ``__repr__``."""
        return self.__repr__()
|