# EureCA / dspy / datasets / gsm8k.py
# Page-header residue: tonneli's picture - "Delete history" (commit f5776d3)
import tqdm
import random
from datasets import load_dataset
from dspy.datasets.dataset import Dataset
class GSM8K:
    """GSM8K math word problems exposed as train/dev/test lists of ``dspy.Example``.

    On construction the ``gsm8k`` dataset (``main`` configuration) is loaded
    through ``datasets.load_dataset``, every record is parsed into a
    ``question`` / ``gold_reasoning`` / ``answer`` mapping, and both official
    splits are shuffled with a fixed seed so the resulting subsets are
    reproducible.

    Attributes set by ``__init__``:
        do_shuffle: always ``False``; not read anywhere in this class.
        train: the first 200 examples of the shuffled official train split.
        dev: examples 200..499 of the shuffled official train split.
        test: the whole shuffled official test split.
    """

    def __init__(self) -> None:
        """Load, parse, shuffle and slice the dataset.

        Raises:
            ValueError: if a record's answer does not end with
                ``#### <integer>``.
        """
        # No explicit base class, so this resolves to ``object.__init__``.
        super().__init__()
        self.do_shuffle = False

        dataset = load_dataset("gsm8k", 'main')
        official_train = self._parse_split(dataset['train'])
        official_test = self._parse_split(dataset['test'])

        # Each split gets its own generator seeded with 0, so the order of one
        # split never depends on the size of the other.
        random.Random(0).shuffle(official_train)
        random.Random(0).shuffle(official_test)

        # Kept as a function-local import, as in the original code.
        # NOTE(review): presumably avoids a circular import, since this module
        # appears to live inside the dspy package -- confirm.
        import dspy

        def as_examples(records):
            # Only ``question`` is marked as an input field.
            return [dspy.Example(**record).with_inputs('question') for record in records]

        self.train = as_examples(official_train[:200])
        self.dev = as_examples(official_train[200:500])
        self.test = as_examples(official_test)

    @staticmethod
    def _parse_example(example):
        """Turn one raw record into a question/gold_reasoning/answer dict.

        The raw ``answer`` text must end with the tokens ``####`` and the
        final number. Everything before the marker is re-joined with single
        spaces as ``gold_reasoning``; the final number has thousands
        separators removed and is normalised through ``int``.

        Raises:
            ValueError: if the ``####`` marker is not the second-to-last token
                or the final token is not an integer.
        """
        tokens = example['answer'].strip().split()
        # Explicit check instead of ``assert`` so validation still runs under
        # ``python -O`` and a too-short answer gives a clear error.
        if len(tokens) < 2 or tokens[-2] != '####':
            raise ValueError(
                f"Expected the answer to end with '#### <number>', got: {example['answer']!r}"
            )
        return dict(
            question=example['question'],
            gold_reasoning=' '.join(tokens[:-2]),
            answer=str(int(tokens[-1].replace(',', ''))),
        )

    @classmethod
    def _parse_split(cls, hf_split):
        """Parse every record of one split, showing a tqdm progress bar."""
        return [cls._parse_example(example) for example in tqdm.tqdm(hf_split)]
def parse_integer_answer(answer, only_first_line=True):
    """Extract an integer from free-form answer text.

    The last whitespace-separated token containing at least one digit is
    selected; anything from its first ``.`` onward is dropped, and the
    remaining digit characters are concatenated and converted with ``int``.
    Non-digit characters such as signs, currency symbols and thousands
    separators are therefore discarded.

    Args:
        answer: the text to scan.
        only_first_line: when true, only the first line of the stripped text
            is considered.

    Returns:
        The parsed integer, or ``0`` when no token contains a digit or the
        collected digits cannot be converted.
    """
    try:
        text = answer.strip().split('\n')[0] if only_first_line else answer
        numeric_tokens = [tok for tok in text.split() if any(ch.isdigit() for ch in tok)]
        # Keep only what precedes the first dot of the last numeric token.
        integer_part = numeric_tokens[-1].split('.')[0]
        return int(''.join(filter(str.isdigit, integer_part)))
    except (ValueError, IndexError):
        # IndexError: no numeric token; ValueError: nothing convertible left.
        return 0
def gsm8k_metric(gold, pred, trace=None):
    """Report whether the prediction's integer answer matches the gold one.

    Both ``gold.answer`` and ``pred.answer`` are converted to ``str`` and run
    through ``parse_integer_answer`` before being compared.

    Args:
        gold: object exposing the reference ``answer`` attribute.
        pred: object exposing the predicted ``answer`` attribute.
        trace: accepted but not used by this function.

    Returns:
        ``True`` when the two parsed integers are equal, otherwise ``False``.
    """
    expected = int(parse_integer_answer(str(gold.answer)))
    predicted = int(parse_integer_answer(str(pred.answer)))
    return expected == predicted