File size: 1,604 Bytes
f5776d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
from .predict import Predict
from ..primitives.program import Module
import dsp
class MultiChainComparison(Module):
    """Compare several earlier reasoning attempts and predict a final output.

    The constructor widens the given signature with ``M`` extra
    ``reasoning_attempt_<i>`` fields and a ``rationale`` field, all placed
    just before the signature's last field, and wraps the resulting
    template in a ``Predict`` instance.
    """

    def __init__(self, signature, M=3, temperature=0.7, **config):
        """Build the extended template and the underlying predictor.

        Args:
            signature: Anything ``Predict`` accepts as a signature.
            M: Number of reasoning attempts the template has slots for.
            temperature: Forwarded to the inner ``Predict``.
            **config: Extra keyword arguments forwarded to ``Predict``.
        """
        super().__init__()

        self.M = M

        # Go through Predict to obtain its `.signature` object
        # (presumably this normalises the input form -- confirm in Predict).
        base_signature = Predict(signature).signature
        *leading_keys, last_key = base_signature.kwargs.keys()

        # Field order matters: leading fields, then the attempts, then the
        # rationale, and finally the original last field.
        extended_kwargs = {
            name: base_signature.kwargs[name] for name in leading_keys
        }

        for attempt_no in range(1, M + 1):
            extended_kwargs[f'reasoning_attempt_{attempt_no}'] = dsp.Type(
                prefix=f"Student Attempt #{attempt_no}:",
                desc="${reasoning attempt}",
            )

        extended_kwargs['rationale'] = dsp.Type(
            prefix="Accurate Reasoning: Thank you everyone. Let's now holistically",
            desc="${corrected reasoning}",
        )
        extended_kwargs[last_key] = base_signature.kwargs[last_key]

        template = dsp.Template(base_signature.instructions, **extended_kwargs)

        self.predict = Predict(template, temperature=temperature, **config)
        self.last_key = last_key

    def forward(self, completions, **kwargs):
        """Summarise each completion and run the comparison predictor.

        Args:
            completions: Iterable of objects exposing a ``rationale``
                attribute and item access by ``self.last_key``; exactly
                ``self.M`` of them are required.
            **kwargs: Additional fields for the predictor. These take
                precedence over the generated ``reasoning_attempt_<i>``
                entries when names collide.

        Returns:
            Whatever ``self.predict`` returns.

        Raises:
            AssertionError: If the number of completions differs from
                ``self.M``.
        """

        def first_line(text):
            # Only the first line of the stripped text is kept.
            return text.strip().split('\n')[0].strip()

        attempts = [
            f"«I'm trying to {first_line(c.rationale)} I'm not sure but my prediction is {first_line(c[self.last_key])}»"
            for c in completions
        ]

        assert len(attempts) == self.M, len(attempts)

        attempt_fields = {
            f'reasoning_attempt_{position}': summary
            for position, summary in enumerate(attempts, start=1)
        }

        # Caller-supplied kwargs are merged last so they win on conflicts.
        kwargs = {**attempt_fields, **kwargs}
        return self.predict(**kwargs)
|