# EureCA / dsp / templates / template_v2.py
# tonneli's picture
# Delete history
# f5776d3
from collections import namedtuple
import re
from typing import Union, Any
import dsp
from dsp.primitives.demonstrate import Example
from .utils import passages2text, format_answers
# One parsed template field: its display name, the whitespace separating the
# name from the value, the variable it is read from / written to, and an
# optional description.
Field = namedtuple(
    "Field",
    ["name", "separator", "input_variable", "output_variable", "description"],
)
# TODO: de-duplicate with dsp/templates/template.py
class TemplateV2:
    """Prompt template parsed from a plain-text specification.

    The first line of the template text is stored as ``instructions``. The
    rest is parsed into an ordered list of ``Field`` records, each written as
    ``<name><whitespace>{<variable>}`` and optionally followed by a
    ``${...}`` description. A variable written as ``a -> b`` is read from
    ``a`` (``input_variable``) and written to ``b`` (``output_variable``).
    """

    # Field with a trailing ``${...}`` description.
    _FIELD_WITH_DESCRIPTION = re.compile(r"(.*)(\s){(.*)}\s(.*\${.*})")
    # Field without a description.
    _FIELD_ONLY = re.compile(r"(.*)(\s){(.*)}")
    # ``input -> output`` renaming inside the braces.
    _VARIABLE_RENAME = re.compile(r"(.*) -> (.*)")

    def __init__(self, template, format_handlers=None):
        """Parse ``template`` into instructions and fields.

        Args:
            template: Template text; surrounding whitespace is ignored. A
                template consisting of a single line yields no fields.
            format_handlers: Mapping from an input variable name to the
                callable used to render its value in ``query``. When omitted,
                a fresh mapping is built that sends ``passages``/``contexte``
                to ``passages2text`` and ``reponse``/``reponses`` to
                ``format_answers``.

        Raises:
            ValueError: If the text after the instructions cannot be parsed
                as a field.
        """
        # Build the default per instance: a dict literal in the signature
        # would be shared (and mutable) across every TemplateV2.
        if format_handlers is None:
            format_handlers = {
                "passages": passages2text,
                "contexte": passages2text,
                "reponse": format_answers,
                "reponses": format_answers,
            }
        self.format_handlers = format_handlers

        template = template.strip()
        # First line only; partition() also copes with a one-line template.
        self.instructions = template.partition("\n")[0]
        template = template[len(self.instructions):].strip()

        self.fields = []
        while template:
            match = self._FIELD_WITH_DESCRIPTION.search(template)
            if match is not None:
                description = match.group(4)
            else:
                match = self._FIELD_ONLY.search(template)
                if match is None:
                    raise ValueError("Could not parse template")
                description = None
            name, separator, variable = match.group(1, 2, 3)

            var_match = self._VARIABLE_RENAME.match(variable)
            if var_match is not None:
                input_variable, output_variable = var_match.group(1, 2)
            else:
                input_variable = output_variable = variable

            self.fields.append(
                Field(
                    name=name,
                    separator=separator,
                    input_variable=input_variable,
                    output_variable=output_variable,
                    description=description,
                )
            )
            # Consume as many characters as the matched text is long.
            template = template[len(match.group(0)):].strip()

    def query(self, example: "Example", is_demo: bool = False) -> str:
        """Render the example's input variables as a query string.

        Each field whose input variable is present and not ``None`` becomes
        ``<name><separator><formatted value>``. Unless ``is_demo`` is set,
        the field right after the last filled one is given an empty value
        (this mutates ``example``) so that its name is emitted as the cue
        for the completion.

        Args:
            example: Mapping holding the input variables.
            is_demo: When true, ``example`` is rendered as-is.

        Returns:
            The rendered lines, joined by a blank line when the template has
            augmented guidelines and the example is flagged ``augmented``,
            otherwise by a single newline.
        """

        def collapse_whitespace(x):
            # Default handler: squeeze every run of whitespace to one space.
            return " ".join(x.split())

        if not is_demo:
            has_value = [
                field.input_variable in example
                and example[field.input_variable] is not None
                and example[field.input_variable] != ""
                for field in self.fields
            ]
            for i in range(1, len(has_value)):
                if has_value[i - 1] and not any(has_value[i:]):
                    example[self.fields[i].input_variable] = ""
                    break

        result: list[str] = []
        for field in self.fields:
            key = field.input_variable
            if key not in example or example[key] is None:
                continue
            format_handler = self.format_handlers.get(key, collapse_whitespace)
            formatted_value = format_handler(example[key])
            # A multi-line value goes on its own line instead of after a space.
            if field.separator == " " and "\n" in formatted_value:
                separator = "\n"
            else:
                separator = field.separator
            result.append(f"{field.name}{separator}{formatted_value}")

        if self._has_augmented_guidelines() and (
            "augmented" in example and example.augmented
        ):
            return "\n\n".join(r for r in result if r)
        return "\n".join(r for r in result if r)

    def guidelines(self, show_guidelines=True) -> str:
        """Return the format guidelines shown in the prompt.

        The guidelines are the query rendered from an example whose values
        are the field descriptions. Returns an empty string when
        ``show_guidelines`` is false or ``dsp.settings.show_guidelines``
        exists and is falsy.
        """
        if (not show_guidelines) or not getattr(
            dsp.settings, "show_guidelines", True
        ):
            return ""

        example = dsp.Example()
        for field in self.fields:
            example[field.input_variable] = field.description
        example.augmented = self._has_augmented_guidelines()

        return "Respecte le format suivant.\n\n" + self.query(example)

    def _has_augmented_guidelines(self):
        """Tell whether the template needs the spaced-out layout.

        True when there are more than three fields, or when any field has a
        newline in its separator or in its description.
        """
        if len(self.fields) > 3:
            return True
        return any(
            "\n" in field.separator
            # Fields parsed without a ``${...}`` part have description None.
            or (field.description is not None and "\n" in field.description)
            for field in self.fields
        )

    def extract(
        self, example: "Union[Example, dict[str, Any]]", raw_pred: str
    ) -> "Example":
        """Extract the output variables from the LM raw prediction.

        Filling starts at the first field whose input variable is missing or
        ``None`` in ``example`` (or at the last field when none is). The
        prediction is split wherever a newline followed by the next field's
        name occurs; each piece is stored under the field's output variable.

        Args:
            example: Contains the input variables that raw_pred was
                completed on. It is copied, not modified.
            raw_pred: LM generated string.

        Returns:
            Example: A copy of the example with the output variables filled in.
        """
        example = dsp.Example(example)
        raw_pred = raw_pred.strip()

        idx = 0
        while idx < len(self.fields):
            key = self.fields[idx].input_variable
            if key not in example or example[key] is None:
                break
            idx += 1

        import dspy

        def clean(text):
            # NOTE: str.rstrip treats its argument as a set of characters, so
            # this removes every trailing "-", not only a "---" suffix.
            text = text.strip()
            if dspy.settings.release >= 20231003:
                text = text.rstrip('---').strip()
            return text

        last = len(self.fields) - 1
        idx = min(idx, last)
        while raw_pred != "" and idx < len(self.fields):
            field = self.fields[idx]
            if idx < last:
                next_field_name = "\n" + self.fields[idx + 1].name
                offset = raw_pred.find(next_field_name)
                if offset >= 0:
                    # Text before the next field's name belongs to this field.
                    example[field.output_variable] = clean(raw_pred[:offset])
                    raw_pred = clean(raw_pred[offset + len(next_field_name):])
                    idx += 1
                else:
                    # No further field name: the rest belongs to this field.
                    example[field.output_variable] = clean(raw_pred)
                    raw_pred = ""
                    idx += 1
                    break
            else:
                assert idx == last, (idx, len(self.fields))
                example[field.output_variable] = clean(raw_pred)
                break

        return example

    def __call__(self, example, show_guidelines=True) -> str:
        """Build the full prompt for ``example``.

        The prompt is made of the instructions, the guidelines, the rendered
        demos from ``example.demos`` and the query for ``example`` itself,
        joined by ``---`` separator lines. When ``dsp.settings.query_only``
        exists and is truthy, only the query is returned.

        Args:
            example: Mapping with the input variables and a ``demos`` list.
                It is copied, not modified.
            show_guidelines: Forwarded to ``guidelines``.

        Returns:
            The prompt text, stripped of surrounding whitespace.
        """
        example = dsp.Example(example)

        if getattr(dsp.settings, "query_only", False):
            return self.query(example)

        last_input = self.fields[-1].input_variable

        # The training data should not contain the output variable
        if last_input in example:
            del example[last_input]

        rdemos = [
            self.query(demo, is_demo=True)
            for demo in example.demos
            if ("augmented" not in demo or not demo.augmented)
            # validate that the training example has the same primitive
            # input var as the template
            and last_input in demo
            and demo[last_input] is not None
        ]
        ademos = [
            self.query(demo, is_demo=True)
            for demo in example.demos
            if "augmented" in demo and demo.augmented
        ]

        # Move the rdemos to ademos if rdemo has all the fields filled in
        remaining_rdemos = []
        new_ademos = []
        for rdemo in rdemos:
            is_complete = all(
                field.name in rdemo
                for field in self.fields
                if field.input_variable in example
            )
            if not is_complete:
                remaining_rdemos.append(rdemo)
                continue
            # Imported only when needed, as in the original control flow.
            import dspy

            if dspy.settings.release >= 20230928:
                new_ademos.append(rdemo)
            else:
                ademos.append(rdemo)
        ademos = new_ademos + ademos
        rdemos = remaining_rdemos

        long_query = self._has_augmented_guidelines()
        if long_query:
            example["augmented"] = True
        query = self.query(example)

        # if it has more lines than fields
        if len(query.split('\n')) > len(self.fields):
            long_query = True
            if "augmented" not in example or not example.augmented:
                example["augmented"] = True
                query = self.query(example)

        rdemos = "\n\n".join(rdemos)

        if len(rdemos) >= 1 and len(ademos) == 0 and not long_query:
            # Short demos and a short query are kept in one section.
            parts = [
                self.instructions,
                self.guidelines(show_guidelines),
                "\n\n".join([rdemos, query]),
            ]
        elif len(rdemos) == 0:
            parts = [
                self.instructions,
                self.guidelines(show_guidelines),
                *ademos,
                query,
            ]
        else:
            parts = [
                self.instructions,
                rdemos,
                self.guidelines(show_guidelines),
                *ademos,
                query,
            ]

        prompt = "\n\n---\n\n".join([p.strip() for p in parts if p])
        return prompt.strip()