|
from collections import namedtuple |
|
import re |
|
from typing import Union, Any |
|
import dsp |
|
from dsp.primitives.demonstrate import Example |
|
from .utils import passages2text, format_answers |
|
|
|
# One parsed template field: the label shown in the prompt, the whitespace
# character between label and value, the variable read when building a prompt,
# the variable written when extracting a completion, and an optional description.
Field = namedtuple(
    "Field",
    ["name", "separator", "input_variable", "output_variable", "description"],
)
|
|
|
|
|
|
|
|
|
class TemplateV2:
    """Prompt template parsed from a plain-text specification.

    The first line of the (stripped) template is kept as ``instructions``.
    The remainder is parsed into ``Field`` records. Each field has the shape
    ``<name><whitespace>{<variable>}``, optionally followed by a whitespace
    character and a description that contains a ``${...}`` placeholder. A
    variable written as ``a -> b`` reads from ``a`` and writes to ``b``;
    otherwise the same variable is used for both.

    Instances are callable: ``template(example)`` renders the full prompt.
    """

    # Field followed by a description containing a ``${...}`` placeholder.
    _FIELD_WITH_DESCRIPTION = re.compile(r"(.*)(\s){(.*)}\s(.*\${.*})")
    # Field without a description.
    _FIELD_WITHOUT_DESCRIPTION = re.compile(r"(.*)(\s){(.*)}")
    # ``input -> output`` mapping inside the braces of a field.
    _VARIABLE_MAPPING = re.compile(r"(.*) -> (.*)")

    def __init__(
        self,
        template: str,
        format_handlers: Union[dict[str, Any], None] = None,
    ):
        """Parse ``template`` into instructions and fields.

        Args:
            template: Template text; the first line is the instructions and
                the rest describes the fields.
            format_handlers: Mapping from input-variable name to a callable
                that turns the variable's value into text. When omitted, a
                fresh default mapping is created for this instance.

        Raises:
            ValueError: If the text after the instructions cannot be parsed
                as a field.
            AttributeError: If the stripped template contains no newline
                (``re.search`` returns ``None`` for the instructions line).
        """
        if format_handlers is None:
            # Built per instance: a dict literal in the signature would be a
            # single object shared (and mutable) across all instances.
            format_handlers = {
                "passages": passages2text,
                "contexte": passages2text,
                "reponse": format_answers,
                "reponses": format_answers,
            }
        self.format_handlers = format_handlers

        template = template.strip()

        self.instructions = re.search(r"(.*)\n", template).group(1)
        template = template[len(self.instructions):].strip()

        self.fields = []
        while template:
            match = self._FIELD_WITH_DESCRIPTION.search(template)
            if match is not None:
                name, separator, variable, description = match.groups()
            else:
                match = self._FIELD_WITHOUT_DESCRIPTION.search(template)
                if match is None:
                    raise ValueError("Could not parse template")
                name, separator, variable = match.groups()
                description = None

            var_match = self._VARIABLE_MAPPING.match(variable)
            if var_match is not None:
                input_variable, output_variable = var_match.groups()
            else:
                input_variable = output_variable = variable

            self.fields.append(
                Field(
                    name=name,
                    separator=separator,
                    input_variable=input_variable,
                    output_variable=output_variable,
                    description=description,
                )
            )

            # NOTE(review): this drops len(match) characters from the start of
            # the remaining text, which presumes the match begins at index 0.
            template = template[len(match.group(0)):].strip()

    @staticmethod
    def _collapse_whitespace(value):
        """Default formatter: collapse every whitespace run into one space."""
        return " ".join(value.split())

    def query(self, example: Example, is_demo: bool = False) -> str:
        """Format the input variables found in ``example`` into a query string.

        When ``is_demo`` is false, the first field that follows a filled field
        and has no filled field at or after it is set to ``""`` on ``example``
        (the example is modified in place), so that its name is emitted with
        an empty value at the end of the query.

        Args:
            example: Mapping from variable names to values.
            is_demo: Whether ``example`` is a demonstration; demonstrations
                are rendered as-is.

        Returns:
            One ``<name><separator><value>`` entry per field present in
            ``example`` with a non-``None`` value, joined by blank lines when
            the template has augmented guidelines and ``example.augmented`` is
            truthy, and by single newlines otherwise.
        """
        result: list[str] = []

        if not is_demo:
            has_value = [
                field.input_variable in example
                and example[field.input_variable] is not None
                and example[field.input_variable] != ""
                for field in self.fields
            ]

            for i in range(1, len(has_value)):
                if has_value[i - 1] and not any(has_value[i:]):
                    example[self.fields[i].input_variable] = ""
                    break

        for field in self.fields:
            if (
                field.input_variable not in example
                or example[field.input_variable] is None
            ):
                continue

            if field.input_variable in self.format_handlers:
                format_handler = self.format_handlers[field.input_variable]
            else:
                format_handler = self._collapse_whitespace

            formatted_value = format_handler(example[field.input_variable])

            # A multi-line value goes on its own line instead of after a space.
            if field.separator == ' ' and '\n' in formatted_value:
                separator = '\n'
            else:
                separator = field.separator

            result.append(f"{field.name}{separator}{formatted_value}")

        entries = [r for r in result if r]
        if self._has_augmented_guidelines() and (
            "augmented" in example and example.augmented
        ):
            return "\n\n".join(entries)
        return "\n".join(entries)

    def guidelines(self, show_guidelines: bool = True) -> str:
        """Return the format guidelines section of the prompt.

        Returns ``""`` when ``show_guidelines`` is false or when
        ``dsp.settings.show_guidelines`` exists and is falsy. Otherwise the
        result is a fixed header followed by a query rendered from an example
        in which every input variable is set to its field's description.
        """
        if (not show_guidelines) or (
            hasattr(dsp.settings, "show_guidelines")
            and not dsp.settings.show_guidelines
        ):
            return ""

        result = "Respecte le format suivant.\n\n"

        example = dsp.Example()
        for field in self.fields:
            example[field.input_variable] = field.description
        example.augmented = self._has_augmented_guidelines()

        result += self.query(example)
        return result

    def _has_augmented_guidelines(self) -> bool:
        """Tell whether the template uses the blank-line separated layout.

        True when there are more than three fields, or when any field has a
        newline in its separator or in its description.
        """
        return len(self.fields) > 3 or any(
            "\n" in field.separator
            # Fields parsed without a description carry ``None`` here.
            or (field.description is not None and "\n" in field.description)
            for field in self.fields
        )

    def extract(
        self, example: Union[Example, dict[str, Any]], raw_pred: str
    ) -> Example:
        """Extracts the answer from the LM raw prediction using the template structure

        Args:
            example (Union[Example, dict[str, Any]]): Contains the input variables that raw_pred was completed on.
            raw_pred (str): LM generated string

        Returns:
            Example: The example with the output variables filled in
        """
        example = dsp.Example(example)

        raw_pred = raw_pred.strip()

        # Index of the first field whose input variable is missing or None.
        idx = 0
        for field in self.fields:
            if (
                field.input_variable not in example
                or example[field.input_variable] is None
            ):
                break
            idx += 1

        import dspy

        def clean(text: str) -> str:
            """Strip ``text``; for release >= 20231003 also drop trailing dashes."""
            text = text.strip()
            if dspy.settings.release >= 20231003:
                # rstrip takes a set of characters, so this removes every
                # trailing '-' and not only a literal '---' suffix.
                text = text.rstrip('---').strip()
            return text

        last = len(self.fields) - 1
        idx = min(idx, last)
        while raw_pred != "" and idx < len(self.fields):
            if idx < last:
                next_field_name = "\n" + self.fields[idx + 1].name
                offset = raw_pred.find(next_field_name)

                if offset >= 0:
                    # Text before the next field's name belongs to this field;
                    # continue with whatever follows that name.
                    example[self.fields[idx].output_variable] = clean(raw_pred[:offset])
                    raw_pred = clean(raw_pred[offset + len(next_field_name):])
                    idx += 1
                    continue
            else:
                assert idx == last, (idx, len(self.fields))

            # Either this is the last field or the next field's name does not
            # occur: the whole remaining prediction goes to this field.
            example[self.fields[idx].output_variable] = clean(raw_pred)
            break

        return example

    def __call__(self, example, show_guidelines: bool = True) -> str:
        """Render the complete prompt for ``example``.

        The prompt is assembled from the instructions, the guidelines, the
        demonstrations in ``example.demos`` and the query for ``example``;
        non-empty parts are stripped and joined with ``"\\n\\n---\\n\\n"``.
        When ``dsp.settings.query_only`` exists and is truthy, only the query
        is returned. The caller's ``example`` is copied into a
        ``dsp.Example`` before any modification.
        """
        example = dsp.Example(example)

        if hasattr(dsp.settings, 'query_only') and dsp.settings.query_only:
            return self.query(example)

        # Drop any value already present for the last field before rendering.
        last_variable = self.fields[-1].input_variable
        if last_variable in example:
            del example[last_variable]

        # Non-augmented demos that have a non-None value for the last field.
        rdemos = [
            self.query(demo, is_demo=True)
            for demo in example.demos
            if (
                ("augmented" not in demo or not demo.augmented)
                and (last_variable in demo and demo[last_variable] is not None)
            )
        ]

        ademos = [
            self.query(demo, is_demo=True)
            for demo in example.demos
            if "augmented" in demo and demo.augmented
        ]

        # A rendered rdemo that mentions the name of every field present in
        # `example` is moved over to the ademos.
        remaining_rdemos = []
        new_ademos = []
        for rdemo in rdemos:
            if all(
                (field.name in rdemo)
                for field in self.fields
                if field.input_variable in example
            ):
                # Imported lazily, only when a demo is actually moved.
                import dspy

                if dspy.settings.release >= 20230928:
                    new_ademos.append(rdemo)
                else:
                    ademos.append(rdemo)
            else:
                remaining_rdemos.append(rdemo)

        ademos = new_ademos + ademos

        long_query = self._has_augmented_guidelines()

        if long_query:
            example["augmented"] = True

        query = self.query(example)

        # More lines than fields: switch to the augmented layout as well.
        if len(query.split('\n')) > len(self.fields):
            long_query = True

            if "augmented" not in example or not example.augmented:
                example["augmented"] = True
                query = self.query(example)

        rdemos_text = "\n\n".join(remaining_rdemos)

        if rdemos_text and not ademos and not long_query:
            parts = [
                self.instructions,
                self.guidelines(show_guidelines),
                "\n\n".join([rdemos_text, query]),
            ]
        elif not rdemos_text:
            parts = [
                self.instructions,
                self.guidelines(show_guidelines),
                *ademos,
                query,
            ]
        else:
            parts = [
                self.instructions,
                rdemos_text,
                self.guidelines(show_guidelines),
                *ademos,
                query,
            ]

        prompt = "\n\n---\n\n".join([p.strip() for p in parts if p])

        return prompt.strip()
|
|