|
from collections import Counter |
|
from typing import Callable, Any, Optional |
|
|
|
import dsp |
|
from dsp.utils import zipstar, normalize_text |
|
from dsp.primitives.inspect import FuncInspector |
|
from dsp.utils.utils import dotdict |
|
from dsp.templates.template_v3 import Template |
|
from dsp.primitives.demonstrate import Example |
|
|
|
|
|
class Completions:
    """A state object that holds the valid LM completions for a given Template."""

    def __init__(self, completions: list[Example], template: Template):
        # One parsed Example per sampled generation, plus the template that
        # produced/extracted them.
        self.data = completions
        self.template = template

    def __iter__(self):
        return self.data.__iter__()

    def __getitem__(self, item):
        return self.data[item]

    def __len__(self):
        return len(self.data)

    def unpack(self, key=None):
        """Return the `key` field from every completion, or all fields zipped
        column-wise (via `zipstar`) when no key is given."""
        if key:
            return [getattr(c, key) for c in self.data]

        return zipstar(self.data)

    def __getattr__(self, name):
        # Direct field access is only well-defined for a single completion.
        assert len(self.data) == 1

        completion = self.data[0]

        if name in completion.keys():
            return getattr(completion, name)

        # Pluralized access (e.g. `.answers` for field `answer`) returns the
        # field's values across completions as a list. Previously this branch
        # matched and then fell through to an unconditional failure (dead code).
        if name.endswith("s") and name[:-1] in completion.keys():
            return [getattr(c, name[:-1]) for c in self.data]

        # Raise AttributeError rather than `assert False`: this is the
        # __getattr__ contract (so hasattr()/getattr(default) work), and it
        # still fails under `python -O`, where asserts are stripped.
        raise AttributeError(name)
|
|
|
|
|
def generate(template: Template, **kwargs) -> Callable:
    """Returns a callable function that generates completions for a given example using the provided template."""
    # Fast path: no inspector configured, call the raw generator directly.
    if not hasattr(dsp.settings, "inspect"):
        return dsp.predict._generate(template, **kwargs)

    # An inspector is installed: route generation through its wrapper so the
    # call can be traced/inspected.
    inspector: FuncInspector = dsp.settings.inspect
    wrapped = inspector.inspect_func(dsp.predict._generate)
    return wrapped(template, **kwargs)
|
|
|
|
|
def _generate(template: Template, **kwargs) -> Callable:
    """Returns a callable function that generates completions for a given example using the provided template."""
    if not dsp.settings.lm:
        raise AssertionError("No LM is loaded.")

    # Bind the currently-configured LM; `kwargs` are forwarded to it per call.
    generator = dsp.settings.lm

    def do_generate(
        example: Example, stage: str, max_depth: int = 2, original_example=None
    ):
        """Generate completions for `example`, retrying truncated generations.

        Returns `(example, completions)` where `example` carries the stage
        results and `completions` is a `Completions` object.
        """
        if not dsp.settings.lm:
            raise AssertionError("No LM is loaded.")
        # Preserve the caller's original example across recursive retries so
        # compiling metadata (below) reflects the true inputs.
        original_example = original_example or example
        assert stage is not None

        # Select the `stage` slice of each demo (presumably demos are keyed by
        # pipeline stage — confirm against Example.demos_at).
        example = example.demos_at(lambda d: d[stage])

        # Render the prompt, sample completions, and parse each raw string
        # back into the template's fields.
        prompt = template(example)
        completions: list[dict[str, Any]] = generator(prompt, **kwargs)
        completions: list[Example] = [template.extract(example, p) for p in completions]

        # NOTE: `input_variable` is used here for all fields, including
        # outputs — this mirrors how extract() keys the parsed values.
        field_names: list[str] = [field.input_variable for field in template.fields]

        # Keep only the completions that filled in the most fields, scanning
        # fields in template order. `last_field_idx` ends as the count of
        # consecutive fields that at least one completion provided.
        last_field_idx = 0
        for field_idx, key in enumerate(field_names):
            completions_ = [
                c for c in completions if key in c.keys() and c[key] is not None
            ]

            # If no completion has this field, stop narrowing (keep the
            # survivors from the previous field).
            if len(completions_):
                completions = completions_
                last_field_idx = field_idx + 1

        # If no completion filled every field, the generation was likely cut
        # off by the token limit: retry greedily from the furthest completion.
        if last_field_idx < len(field_names):

            # Seed the first missing field with "" so the re-rendered prompt
            # continues from that point.
            completion = completions[0]
            completion[field_names[last_field_idx]] = ""

            # Shrink the budget (but never below 75 tokens, and never above
            # the caller's budget) and switch to a single greedy sample.
            max_tokens = kwargs.get("max_tokens", dsp.settings.lm.kwargs["max_tokens"])
            max_tokens = min(max(75, max_tokens // 2), max_tokens)
            new_kwargs = {
                **kwargs,
                "max_tokens": max_tokens,
                "n": 1,
                "temperature": 0.0,
            }

            # Recurse with a reduced depth budget; exhausting it is a bug.
            assert max_depth > 0
            return generate(template, **new_kwargs)(
                completion,
                stage=stage,
                max_depth=max_depth - 1,
                original_example=original_example,
            )

        completions = Completions(completions, template=template)
        example = example.copy(completions=completions)

        if len(completions) == 1:
            completion = completions[0]
            example[stage] = example.copy(**completion)

            # While compiling, record this stage's template and its
            # input/output field split for later program reconstruction.
            if dsp.settings.compiling:
                inputs_ = set(original_example.keys())
                inputs = [
                    f.input_variable
                    for f in template.fields
                    if f.input_variable in inputs_
                ]
                outputs = [
                    f.output_variable
                    for f in template.fields
                    if f.input_variable not in inputs_
                ]

                example.compiling_stages = example.get("compiling_stages", [])
                example.compiling_stages.append(
                    {
                        "name": stage,
                        "template": template,
                        "inputs": inputs,
                        "outputs": outputs,
                    }
                )
        else:
            # With n > 1 completions, expose them as a bundle rather than a
            # single stage result. (Compiling with n > 1 is not handled here.)
            example[stage] = dotdict(completions=completions)

        return example, completions

    return do_generate
|
|
|
|
|
def generate_sc(
    example, prompt, normalize=True, extract=None, prediction_field=None, **kwargs
):
    """Sample multiple completions and return the self-consistency answer.

    Samples from the loaded LM (defaults: temperature=0.7, n=20,
    max_tokens=150, overridable via `kwargs`), extracts a final answer from
    each completion, and majority-votes over the answers.
    """
    if not dsp.settings.lm:
        raise AssertionError("No LM is loaded.")

    # Caller-supplied kwargs take precedence over the sampling defaults.
    sampling_kwargs = {"temperature": 0.7, "n": 20, "max_tokens": 150}
    sampling_kwargs.update(kwargs)

    raw_completions = dsp.settings.lm(prompt, **sampling_kwargs)
    answers = extract_final_answer(example, raw_completions, extract=extract)
    return majority_vote_(
        answers, normalize=normalize, prediction_field=prediction_field
    )
|
|
|
|
|
def extract_final_answer(example, completions, extract=None):
    """Pull a final answer string out of each raw completion.

    Uses the caller's `extract(example, completion)` when provided; otherwise
    takes the last line of each completion and strips everything up to the
    first ":" on that line. Also records the extracted answers on the LM's
    most recent history entry.
    """
    if not dsp.settings.lm:
        raise AssertionError("No LM is loaded.")

    if extract:
        answers = [extract(example, text) for text in completions]
    else:
        answers = []
        for text in completions:
            last_line = text.strip().split("\n")[-1]
            answers.append(last_line.split(":", 1)[-1].strip())

    # Attach the extracted answers to a copy of the latest history record.
    history = dsp.settings.lm.history
    history.append({**history[-1], "completions": answers})

    return answers
|
|
|
|
|
def majority(
    completions: Completions, normalize: bool = True, field: Optional[str] = None
):
    """Returns the most common completion for the target field or the last field in the template."""
    # Default to the template's final output field.
    if field is None:
        field = completions.template.fields[-1].output_variable

    winners = majority_vote_(completions, normalize=normalize, prediction_field=field)
    return Completions(winners, template=completions.template)
|
|
|
|
|
def majority_vote_(completions: Completions, normalize: bool, prediction_field: str):
    """Core logic for majority vote.

    Tallies the `prediction_field` values across completions and returns a
    one-element list containing the winner. Also appends the tally to the
    LM's history.
    """

    if not dsp.settings.lm:
        raise AssertionError("No LM is loaded.")

    # Maps each normalized string to the first original completion that
    # produced it, so the winner can be returned in un-normalized form.
    normalized_to_original = {}
    if normalize:
        original_completions = completions
        completions_ = []
        for pred in completions:
            # Completions missing the target field vote as the empty string.
            if prediction_field in pred:
                completions_.append(normalize_text(pred[prediction_field]))
            else:
                completions_.append("")
        completions = completions_

        for completion, normalized_completion in zip(original_completions, completions):
            if normalized_completion not in normalized_to_original:
                normalized_to_original[normalized_completion] = completion

    # Discard empty/falsy votes — unless *every* vote is empty, in which case
    # keep them so there is still something to return.
    completions_ = [x for x in completions if x]

    if completions_:
        completions = completions_

    # Counter.most_common breaks ties by first-seen (insertion) order.
    topk = Counter(completions).most_common()
    pred, _ = topk[0]

    # Map the winning normalized string back to its original completion.
    if normalize:
        pred = normalized_to_original[pred]

    # Record the full tally and the chosen answer on the latest history entry.
    dsp.settings.lm.history.append(
        {**dsp.settings.lm.history[-1], "topk": topk, "completions": [pred]}
    )

    return [pred]
|
|