|
from .predict import Predict |
|
from ..primitives.program import Module |
|
|
|
import dsp |
|
|
|
class MultiChainComparison(Module):
    """Compare several reasoning attempts and predict one final answer.

    The given signature is extended with ``M`` ``reasoning_attempt_<i>``
    fields followed by a ``rationale`` field, all inserted just before the
    signature's last field. ``forward`` fills the attempt fields from a
    collection of completions and delegates to a ``Predict`` built on the
    extended template.
    """

    def __init__(self, signature, M=3, temperature=0.7, **config):
        """Build the extended template and the predictor that uses it.

        Args:
            signature: Passed to ``Predict``; the ``.signature`` attribute of
                the resulting predictor is what gets extended.
            M: Number of reasoning attempts the template has slots for, and
                the exact number of completions ``forward`` requires.
            temperature: Forwarded to the underlying ``Predict``.
            **config: Extra keyword arguments forwarded to ``Predict``.
        """
        super().__init__()

        self.M = M

        # Round-trip through Predict so we work with the signature object it
        # exposes (with ``.kwargs`` and ``.instructions``).
        signature = Predict(signature).signature

        # Everything but the final field keeps its position; the final field
        # is re-appended after the attempts and the rationale.
        *leading_keys, last_key = signature.kwargs.keys()

        fields = {key: signature.kwargs[key] for key in leading_keys}

        for attempt_no in range(1, M + 1):
            fields[f'reasoning_attempt_{attempt_no}'] = dsp.Type(
                prefix=f"Student Attempt #{attempt_no}:",
                desc="${reasoning attempt}",
            )

        fields['rationale'] = dsp.Type(
            prefix="Accurate Reasoning: Thank you everyone. Let's now holistically",
            desc="${corrected reasoning}",
        )
        fields[last_key] = signature.kwargs[last_key]

        signature = dsp.Template(signature.instructions, **fields)

        self.predict = Predict(signature, temperature=temperature, **config)
        # Remembered so forward() knows which field of a completion to read
        # as that attempt's answer.
        self.last_key = last_key

    def forward(self, completions, **kwargs):
        """Summarise ``completions`` into attempt fields and run the predictor.

        Args:
            completions: Iterable of exactly ``self.M`` items, each exposing a
                ``rationale`` attribute and supporting ``item[self.last_key]``;
                both values must be strings.
            **kwargs: Remaining template fields. A caller-supplied
                ``reasoning_attempt_<i>`` overrides the generated one.

        Returns:
            Whatever ``self.predict`` returns for the assembled fields.

        Raises:
            AssertionError: If the number of completions differs from
                ``self.M``.
        """
        attempts = []

        for completion in completions:
            # Only the first line of each field is kept so every attempt
            # stays a compact one-liner in the prompt.
            rationale = completion.rationale.strip().split('\n')[0].strip()
            answer = completion[self.last_key].strip().split('\n')[0].strip()
            attempts.append(f"«I'm trying to {rationale} I'm not sure but my prediction is {answer}»")

        # Raised explicitly (instead of a bare ``assert``) so the check still
        # runs under ``python -O``; AssertionError is kept so existing callers
        # that catch it continue to work.
        if len(attempts) != self.M:
            raise AssertionError(
                f"expected {self.M} completions, got {len(attempts)}"
            )

        attempt_fields = {
            f'reasoning_attempt_{position}': attempt
            for position, attempt in enumerate(attempts, start=1)
        }
        # Caller kwargs come last so they win on key collisions.
        kwargs = {**attempt_fields, **kwargs}

        return self.predict(**kwargs)
|
|