|
from collections import Counter

from dsp.utils import normalize_text
from dspy.primitives.prediction import Prediction, Completions
|
|
|
|
|
def default_normalize(s):
    """Normalize ``s`` with ``normalize_text``, mapping falsy results to ``None``.

    ``majority`` ignores completions whose normalized value is ``None``, so a
    value whose normalized form is falsy (for example an empty string) is
    excluded from the vote.
    """
    return normalize_text(s) or None
|
|
|
|
|
def majority(prediction_or_completions, normalize=default_normalize, field=None):
    """Return the most common completion for the target field.

    The vote is taken over the normalized values of ``field`` across all
    completions. Values that normalize to ``None`` are ignored, unless every
    value does, in which case all of them are counted. On a tie, the value
    that was encountered first wins, and the earliest completion carrying the
    winning value is the one that is returned.

    Args:
        prediction_or_completions: A ``Prediction`` (its ``.completions`` are
            used), a ``Completions`` object, or a list of completions that
            support ``completion[field]`` lookup.
        normalize: Callable applied to each field value before counting.
            A falsy value (e.g. ``None``) disables normalization.
        field: Name of the field to vote on. When falsy, it defaults to the
            ``output_variable`` of the last field of the completions'
            signature or, if that cannot be read, to the last key of the
            first completion.

    Returns:
        The result of ``Prediction.from_completions`` called with a
        one-element list holding the winning completion and the signature
        found on the completions (``None`` if there is none). This is what is
        returned for every accepted input type.

    Raises:
        AssertionError: If the input is not a ``Prediction``, ``Completions``
            or ``list`` (only when assertions are enabled).
        ValueError: If there are no completion values to count.
    """
    # Kept as an assertion (rather than raising TypeError) so the exception
    # type existing callers may observe does not change.
    assert isinstance(prediction_or_completions, (Prediction, Completions, list)), (
        "majority() expects a Prediction, Completions or list"
    )

    if isinstance(prediction_or_completions, Prediction):
        completions = prediction_or_completions.completions
    else:
        completions = prediction_or_completions

    # Plain lists have no ``signature`` attribute.
    signature = getattr(completions, "signature", None)

    if not field:
        try:
            field = signature.fields[-1].output_variable
        except Exception:
            # No usable signature: fall back to the last key of the first
            # completion.
            field = list(completions[0].keys())[-1]

    normalize = normalize if normalize else (lambda x: x)
    normalized_values = [normalize(completion[field]) for completion in completions]

    # Drop values that normalized to None; if nothing survives, vote over
    # everything so a winner can still be picked.
    kept_values = [value for value in normalized_values if value is not None]
    counts = Counter(kept_values or normalized_values)

    # Counter keeps first-seen order and max() returns the first maximal key,
    # so ties resolve in favour of the earliest value.
    majority_value = max(counts, key=counts.get)

    # Pick the first completion whose normalized value is the winner, reusing
    # the values computed above instead of normalizing a second time.
    for completion, value in zip(completions, normalized_values):
        if value == majority_value:
            break

    return Prediction.from_completions([completion], signature=signature)
|
|
|
|