
Upload from GitHub Actions: Add Todos for using existing machine-translated datasets rather than our own ones
56adaa2
verified
import asyncio | |
import os | |
import random | |
from datasets import Dataset, load_dataset | |
from datasets_.util import _get_dataset_config_names, _load_dataset | |
from langcodes import Language, standardize_tag | |
from models import get_google_supported_languages, translate_google | |
from rich import print | |
from tqdm import tqdm | |
from tqdm.asyncio import tqdm_asyncio | |
slug_mgsm = "juletxara/mgsm" | |
tags_mgsm = { | |
standardize_tag(a, macro=True): a for a in _get_dataset_config_names(slug_mgsm) | |
} | |
slug_afrimgsm = "masakhane/afrimgsm" | |
tags_afrimgsm = { | |
standardize_tag(a, macro=True): a for a in _get_dataset_config_names(slug_afrimgsm) | |
} | |
slug_gsm8kx = "Eurolingua/gsm8kx" | |
tags_gsm8kx = { | |
standardize_tag(a, macro=True): a | |
for a in _get_dataset_config_names(slug_gsm8kx, trust_remote_code=True) | |
} | |
slug_gsm_autotranslated = "fair-forward/gsm-autotranslated" | |
tags_gsm_autotranslated = { | |
standardize_tag(a, macro=True): a | |
for a in _get_dataset_config_names(slug_gsm_autotranslated) | |
} | |
def parse_number(i): | |
if isinstance(i, int): | |
return i | |
try: | |
return int(i.replace(",", "").replace(".", "")) | |
except ValueError: | |
return None | |
def load_mgsm(language_bcp_47, nr): | |
print(f"Loading MGSM data for {language_bcp_47}...") | |
if language_bcp_47 in tags_mgsm.keys(): | |
ds = _load_dataset(slug_mgsm, subset=tags_mgsm[language_bcp_47], split="test") | |
return slug_mgsm, ds[nr], "human" | |
elif language_bcp_47 in tags_afrimgsm.keys(): | |
ds = _load_dataset( | |
slug_afrimgsm, subset=tags_afrimgsm[language_bcp_47], split="test" | |
) | |
return slug_afrimgsm, ds[nr], "human" | |
elif language_bcp_47 in tags_gsm8kx.keys(): | |
row = _load_dataset( | |
slug_gsm8kx, | |
subset=tags_gsm8kx[language_bcp_47], | |
split="test", | |
trust_remote_code=True, | |
)[nr] | |
row["answer_number"] = row["answer"].split("####")[1].strip() | |
return slug_gsm8kx, row, "machine" | |
elif language_bcp_47 in tags_gsm_autotranslated.keys(): | |
ds = _load_dataset( | |
slug_gsm_autotranslated, | |
subset=tags_gsm_autotranslated[language_bcp_47], | |
split="test", | |
) | |
return slug_gsm_autotranslated, ds[nr], "machine" | |
else: | |
return None, None, None | |
def translate_mgsm(languages): | |
human_translated = [*tags_mgsm.keys(), *tags_afrimgsm.keys()] | |
untranslated = [ | |
lang | |
for lang in languages["bcp_47"].values[:100] | |
if lang not in human_translated and lang in get_google_supported_languages() | |
] | |
en = _load_dataset(slug_mgsm, subset=tags_mgsm["en"], split="test") | |
slug = "fair-forward/gsm-autotranslated" | |
for lang in tqdm(untranslated): | |
# check if already exists on hub | |
try: | |
ds_lang = load_dataset(slug, lang, split="test") | |
except ValueError: | |
print(f"Translating {lang}...") | |
questions_tr = [translate_google(q, "en", lang) for q in en["question"]] | |
questions_tr = asyncio.run(tqdm_asyncio.gather(*questions_tr)) | |
ds_lang = Dataset.from_dict( | |
{ | |
"question": questions_tr, | |
"answer": en["answer"], | |
"answer_number": en["answer_number"], | |
"equation_solution": en["equation_solution"], | |
} | |
) | |
ds_lang.push_to_hub( | |
slug, | |
split="test", | |
config_name=lang, | |
token=os.getenv("HUGGINGFACE_ACCESS_TOKEN"), | |
) | |
ds_lang.to_json( | |
f"data/translations/mgsm/{lang}.json", lines=False, force_ascii=False, indent=2 | |
) | |