File size: 7,990 Bytes
f5776d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 |
from collections import Counter
from typing import Callable, Any, Optional
import dsp
from dsp.utils import zipstar, normalize_text
from dsp.primitives.inspect import FuncInspector
from dsp.utils.utils import dotdict
from dsp.templates.template_v3 import Template
from dsp.primitives.demonstrate import Example
class Completions:
    """A state object that holds the valid LM completions for a given Template.

    Acts as a read-only sequence over the held completions (iteration, indexing,
    ``len``). When it holds exactly one completion, attribute access is
    forwarded to that completion's fields.
    """

    def __init__(self, completions: list[Example], template: Template):
        self.data = completions
        self.template = template

    def __iter__(self):
        return self.data.__iter__()

    def __getitem__(self, item):
        return self.data[item]

    def __len__(self):
        return len(self.data)

    def unpack(self, key=None):
        """Return ``key`` of every completion, or ``zipstar(self.data)`` when no key is given."""
        if key:
            return [getattr(c, key) for c in self.data]
        return zipstar(self.data)

    def __getattr__(self, name):
        """Forward attribute access to the single held completion's field ``name``.

        Raises:
            AttributeError: for dunder names, and for ``data``/``template`` when
                they are not set on the instance (e.g. an object built without
                ``__init__`` during copy/pickle reconstruction).
            AssertionError: if the object does not hold exactly one completion,
                or that completion has no field ``name``.
        """
        # __getattr__ only runs after normal lookup failed. Protocol probes
        # (hasattr(obj, "__setstate__"), copy, pickle) must see AttributeError,
        # and reading self.data while it is unset would re-enter this method
        # and recurse until RecursionError.
        if name in ("data", "template") or (
            name.startswith("__") and name.endswith("__")
        ):
            raise AttributeError(name)

        # Explicit raises instead of `assert`: under `python -O` the asserts
        # were stripped and unknown names silently returned None.
        if len(self.data) != 1:
            raise AssertionError(
                f"attribute access requires exactly one completion, got {len(self.data)}"
            )

        completion = self.data[0]
        if name in completion.keys():
            return getattr(completion, name)

        raise AssertionError(name)
def generate(template: Template, **kwargs) -> Callable:
    """Build the completion-generating callable for ``template``.

    Delegates to ``dsp.predict._generate(template, **kwargs)``. When
    ``dsp.settings`` has an ``inspect`` attribute, that factory is first
    wrapped with ``dsp.settings.inspect.inspect_func`` before being called.
    """
    factory = dsp.predict._generate
    if hasattr(dsp.settings, "inspect"):
        factory = dsp.settings.inspect.inspect_func(factory)
    return factory(template, **kwargs)
def _generate(template: Template, **kwargs) -> Callable:
    """Returns a callable function that generates completions for a given example using the provided template.

    ``kwargs`` are forwarded to every LM call made by the returned function
    (the retry path overrides ``max_tokens``, ``n`` and ``temperature``).

    Raises:
        AssertionError: if ``dsp.settings.lm`` is not set when this is called.
    """
    if not dsp.settings.lm:
        raise AssertionError("No LM is loaded.")

    # The LM object is captured here, at factory time. do_generate re-checks
    # dsp.settings.lm but still calls this captured `generator`.
    generator = dsp.settings.lm

    def do_generate(
        example: Example, stage: str, max_depth: int = 2, original_example=None
    ):
        """Complete ``example`` with ``template`` and record the result under ``stage``.

        Args:
            example: The example to complete. Its demos are replaced via
                ``example.demos_at(lambda d: d[stage])`` before prompting.
            stage: Key looked up in each demo, and under which the result is
                written back into the returned example.
            max_depth: How many more greedy retries are allowed when no
                completion fills every template field.
            original_example: The example as first passed in. It is carried
                across retries so compiling can tell which fields were inputs.

        Returns:
            ``(example, completions)``. ``example`` is a copy carrying
            ``completions`` and ``example[stage]``, and ``completions`` is a
            ``Completions`` object.

        Raises:
            AssertionError: if no LM is loaded, if ``stage`` is None, or (with
                assertions enabled) if a retry would be needed with
                ``max_depth <= 0``.
        """
        if not dsp.settings.lm:
            raise AssertionError("No LM is loaded.")

        original_example = original_example or example
        assert stage is not None

        # Look up the appropriate fields in each demonstration.
        example = example.demos_at(lambda d: d[stage])

        # Generate and extract the fields.
        prompt = template(example)
        completions: list[dict[str, Any]] = generator(prompt, **kwargs)
        completions: list[Example] = [template.extract(example, p) for p in completions]

        # Find the completions that are most complete.
        field_names: list[str] = [field.input_variable for field in template.fields]

        # Walk the fields in template order, narrowing `completions` to those
        # that have each field set (non-None). The narrowing only applies when
        # at least one survivor has it. `last_field_idx` ends as one past the
        # last field that at least one surviving completion filled.
        last_field_idx = 0
        for field_idx, key in enumerate(field_names):
            completions_ = [
                c for c in completions if key in c.keys() and c[key] is not None
            ]

            # Filter out completions that are missing fields that are present in at least one completion.
            if len(completions_):
                completions = completions_
                last_field_idx = field_idx + 1

        # If none of the completions is completed (i.e., none has the final field set).
        if last_field_idx < len(field_names):
            # Pick the first completion that has gone farthest.
            # NOTE(review): raises IndexError if the LM returned no completions.
            completion = completions[0]
            # Set the first missing field to "" (presumably so the rendered
            # prompt ends at that field and the LM continues from there).
            completion[field_names[last_field_idx]] = ""

            # Recurse with greedy decoding and a shorter length.
            # The fallback reads the *current* dsp.settings.lm, not `generator`.
            max_tokens = kwargs.get("max_tokens", dsp.settings.lm.kwargs["max_tokens"])
            # Halve the budget, floored at 75, but never above the original value.
            max_tokens = min(max(75, max_tokens // 2), max_tokens)
            new_kwargs = {
                **kwargs,
                "max_tokens": max_tokens,
                "n": 1,
                "temperature": 0.0,
            }

            # Bounds the retries: each level recurses with max_depth - 1.
            assert max_depth > 0
            # Goes through the public `generate` so an installed inspector
            # also wraps the retry.
            return generate(template, **new_kwargs)(
                completion,
                stage=stage,
                max_depth=max_depth - 1,
                original_example=original_example,
            )

        completions = Completions(completions, template=template)
        example = example.copy(completions=completions)

        if len(completions) == 1:
            completion = completions[0]
            example[stage] = example.copy(**completion)

            if dsp.settings.compiling:
                # Fields whose input_variable was already present on the
                # caller's example are inputs. The output_variable of every
                # other field is an output of this stage.
                inputs_ = set(original_example.keys())
                inputs = [
                    f.input_variable
                    for f in template.fields
                    if f.input_variable in inputs_
                ]
                outputs = [
                    f.output_variable
                    for f in template.fields
                    if f.input_variable not in inputs_
                ]

                example.compiling_stages = example.get("compiling_stages", [])
                example.compiling_stages.append(
                    {
                        "name": stage,
                        "template": template,
                        "inputs": inputs,
                        "outputs": outputs,
                    }
                )
        else:
            # Several completions: store them all under the stage instead of a
            # single filled-in copy; nothing is recorded for compiling here.
            # assert not dsp.settings.compiling, "TODO: At this point, cannot compile n>1 generations"
            example[stage] = dotdict(completions=completions)

        return example, completions

    return do_generate
def generate_sc(
    example, prompt, normalize=True, extract=None, prediction_field=None, **kwargs
):
    """Sample several completions for ``prompt`` and majority-vote the answer.

    Sampling defaults are temperature=0.7, n=20 and max_tokens=150; any
    ``kwargs`` override them. Raw completions are reduced to answers with
    ``extract_final_answer`` and then voted on with ``majority_vote_``.

    Returns:
        The single-element list produced by ``majority_vote_``.

    Raises:
        AssertionError: if no LM is configured in ``dsp.settings.lm``.
    """
    lm = dsp.settings.lm
    if not lm:
        raise AssertionError("No LM is loaded.")

    sampling_kwargs = dict(temperature=0.7, n=20, max_tokens=150)
    sampling_kwargs.update(kwargs)

    raw_completions = lm(prompt, **sampling_kwargs)
    answers = extract_final_answer(example, raw_completions, extract=extract)
    return majority_vote_(
        answers, normalize=normalize, prediction_field=prediction_field
    )
def extract_final_answer(example, completions, extract=None):
    """Reduce every raw completion to its final answer and log the answers.

    With ``extract``, each answer is ``extract(example, completion)``.
    Otherwise the completion is stripped, and the answer is the text after the
    first ':' on its last line (the whole last line when it has no ':'),
    stripped again.

    A copy of the latest ``dsp.settings.lm.history`` entry, with
    ``"completions"`` replaced by the answers, is appended to that history.

    Raises:
        AssertionError: if no LM is configured in ``dsp.settings.lm``.
    """
    if not dsp.settings.lm:
        raise AssertionError("No LM is loaded.")

    def _answer_from_last_line(text):
        last_line = text.strip().split("\n")[-1]
        return last_line.split(":", 1)[-1].strip()

    answers = (
        [extract(example, raw) for raw in completions]
        if extract
        else [_answer_from_last_line(raw) for raw in completions]
    )

    # TODO: make thread-safe?
    history = dsp.settings.lm.history
    history.append({**history[-1], "completions": answers})
    return answers
def majority(
    completions: Completions, normalize: bool = True, field: Optional[str] = None
):
    """Return the most common completion, wrapped in a new ``Completions``.

    The vote is taken on ``field``. When ``field`` is None, it is taken on the
    ``output_variable`` of the last field of ``completions.template``. The
    result shares the input's template.
    """
    template = completions.template
    if field is None:
        field = template.fields[-1].output_variable

    winners = majority_vote_(completions, normalize=normalize, prediction_field=field)
    return Completions(winners, template=template)
def majority_vote_(completions: Completions, normalize: bool, prediction_field: str):
    """Core logic for majority vote.

    Picks the most frequent prediction and returns it as a one-element list.
    Empty predictions are ignored unless every prediction is empty. A copy of
    the latest ``dsp.settings.lm.history`` entry, extended with the vote
    (``"topk"``) and the winner (``"completions"``), is appended to that history.

    Args:
        completions: Candidates. Each is either a plain string (as produced by
            ``extract_final_answer``) or a mapping-like completion that
            supports ``in`` and ``[]``.
        normalize: If true, candidates are compared by ``normalize_text`` of
            their prediction, and the first original candidate with the winning
            normalized value is returned. If false, candidates are counted
            as-is.
        prediction_field: Field voted on for mapping-like candidates. It is
            ignored for string candidates. Mapping candidates lacking the
            field (or when it is None) count as "".

    Returns:
        ``[winner]``.

    Raises:
        AssertionError: if no LM is configured in ``dsp.settings.lm``.
        IndexError: if ``completions`` is empty.
    """
    if not dsp.settings.lm:
        raise AssertionError("No LM is loaded.")

    # Materialize once: the candidates are traversed twice when normalizing,
    # which would silently drop the second pass for a one-shot iterable.
    original_completions = list(completions)

    def _normalized_prediction(pred):
        # String candidates are the prediction itself. Testing
        # `prediction_field in pred` on them was a substring check, and it
        # raised TypeError for prediction_field=None (generate_sc's default).
        if isinstance(pred, str):
            return normalize_text(pred)
        if prediction_field is not None and prediction_field in pred:
            return normalize_text(pred[prediction_field])
        return ""

    normalized_to_original = {}
    if normalize:
        candidates = [_normalized_prediction(p) for p in original_completions]
        # Keep the first original candidate seen for each normalized value.
        for original, normalized in zip(original_completions, candidates):
            normalized_to_original.setdefault(normalized, original)
    else:
        candidates = original_completions

    # Ignore empty predictions unless all of them are empty.
    non_empty = [c for c in candidates if c]
    if non_empty:
        candidates = non_empty

    # most_common() breaks ties by first occurrence.
    topk = Counter(candidates).most_common()
    winner, _ = topk[0]

    if normalize:
        winner = normalized_to_original[winner]

    dsp.settings.lm.history.append(
        {**dsp.settings.lm.history[-1], "topk": topk, "completions": [winner]}
    )

    return [winner]
|