File size: 2,308 Bytes
f5776d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
from typing import Callable
from dsp.templates import TemplateV2, passages2text, format_answers, Field
class Type:
    """A primitive datatype that defines and represents a prompt label.

    Attributes:
        prefix: Label text, stored exactly as given.
        desc: Description of the label.
        format: Optional formatter stored with the label; ``None`` when
            not supplied.
    """

    def __init__(self, prefix: str, desc: str, format=None) -> None:
        self.prefix = prefix
        self.desc = desc
        self.format = format

    def __call__(self, **kwargs):
        """Return a new ``Type`` copied from this one, with overrides.

        Args:
            **kwargs: Attribute values (``prefix``, ``desc``, ``format``)
                that replace the current ones in the copy.

        Returns:
            A fresh ``Type``; this instance is left unmodified.

        Raises:
            TypeError: If a keyword is not accepted by ``Type.__init__``.
        """
        # Later keys win, so caller-supplied values override current ones.
        merged = {**self.__dict__, **kwargs}
        return Type(**merged)

    def __eq__(self, __value: object) -> bool:
        """Return True when the other object is a ``Type`` with equal attributes."""
        return isinstance(__value, Type) and self.__dict__ == __value.__dict__

    def __repr__(self) -> str:
        """Return a readable representation listing all three attributes."""
        return (
            f"{type(self).__name__}(prefix={self.prefix!r}, "
            f"desc={self.desc!r}, format={self.format!r})"
        )
class Template(TemplateV2):
    """A template datatype that represents the structure of communication with the LM.

    Args:
        instructions: Instruction text stored on the template.
        **kwargs: Field definitions keyed by variable name. Each value must
            expose ``prefix``, ``desc`` and ``format`` attributes.
    """

    def __init__(self, instructions: str, **kwargs):
        # NOTE(review): TemplateV2.__init__ is not called here, exactly as in
        # the original code; presumably deliberate -- confirm against TemplateV2.
        self.instructions = instructions
        self.kwargs = kwargs
        self.fields: list[Field] = []
        self.format_handlers: dict[str, Callable] = {
            "contexte": passages2text,
            "passages": passages2text,
            "reponses": format_answers,
        }

        for key, value in kwargs.items():
            prefix: str = value.prefix
            stripped = prefix.rstrip()

            # Trailing whitespace on the prefix is reused as the separator.
            # A non-empty prefix with no trailing whitespace gets one space;
            # an empty prefix yields an empty separator.
            if prefix and stripped == prefix:
                separator = " "
            else:
                separator = prefix[len(stripped):]

            self.fields.append(
                Field(
                    name=prefix.strip(),
                    description=value.desc,
                    input_variable=key,
                    output_variable=key,
                    separator=separator,
                )
            )

            # A per-field formatter overrides any default handler for that key.
            if value.format:
                self.format_handlers[key] = value.format

    def __eq__(self, other):
        """Compare two templates by their instructions and field definitions.

        Returns ``NotImplemented`` for non-``Template`` operands so Python
        falls back to its default comparison instead of raising
        ``AttributeError`` on the missing ``kwargs`` attribute.
        """
        if not isinstance(other, Template):
            return NotImplemented
        return self.instructions == other.instructions and self.kwargs == other.kwargs

    def __str__(self) -> str:
        """Return a short summary: the instructions and the list of field names."""
        field_names = [field.name for field in self.fields]
        return f"Template({self.instructions}, {field_names})"
|