# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os

from fairseq import options, utils
from fairseq.data import (
    ConcatDataset,
    LanguagePairDataset,
    data_utils,
)
from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationTask, load_langpair_dataset

from ..data import SubsampleLanguagePairDataset

logger = logging.getLogger(__name__)


def concat_language_pair_dataset(*language_pair_datasets, up_sample_ratio=None,
                                 all_dataset_upsample_ratio=None):
    """Concatenate LanguagePairDatasets that share dictionaries and padding settings."""
    logger.info("Concatenating the language pair datasets")
    dataset_number = len(language_pair_datasets)
    if dataset_number == 1:
        return language_pair_datasets[0]
    elif dataset_number < 1:
        raise ValueError("concat_language_pair_dataset needs at least one dataset")

    # Every dataset must share dictionaries and padding conventions with the first
    # one (checked by the assertions below).
    src_list = [language_pair_datasets[0].src]
    tgt_list = [language_pair_datasets[0].tgt]
    src_dict = language_pair_datasets[0].src_dict
    tgt_dict = language_pair_datasets[0].tgt_dict
    left_pad_source = language_pair_datasets[0].left_pad_source
    left_pad_target = language_pair_datasets[0].left_pad_target

    logger.info("Constructing the source and target dataset lists")
    for dataset in language_pair_datasets[1:]:
        assert dataset.src_dict == src_dict
        assert dataset.tgt_dict == tgt_dict
        assert dataset.left_pad_source == left_pad_source
        assert dataset.left_pad_target == left_pad_target
        src_list.append(dataset.src)
        tgt_list.append(dataset.tgt)
    logger.info("Constructed the source and target dataset lists")

    if all_dataset_upsample_ratio is None:
        # Upsample only the first (primary) dataset; default to 1 (no upsampling)
        # when no ratio is given.
        sample_ratio = [1] * len(src_list)
        sample_ratio[0] = up_sample_ratio if up_sample_ratio is not None else 1
    else:
        # Per-dataset ratios given as a comma-separated string, e.g. "2,1,1".
        sample_ratio = [int(t) for t in all_dataset_upsample_ratio.strip().split(",")]
        assert len(sample_ratio) == len(src_list)

    src_dataset = ConcatDataset(src_list, sample_ratios=sample_ratio)
    tgt_dataset = ConcatDataset(tgt_list, sample_ratios=sample_ratio)
    res = LanguagePairDataset(
        src_dataset, src_dataset.sizes, src_dict,
        tgt_dataset, tgt_dataset.sizes, tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
    )
    logger.info("Created the concatenated language pair dataset")
    return res
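

# A minimal usage sketch of concat_language_pair_dataset (illustrative only; the
# dataset variables below are assumptions, not defined in this file). With
# up_sample_ratio=2 the first dataset is repeated twice relative to the others:
#
#     mixed = concat_language_pair_dataset(
#         parallel_dataset, mono_dataset_1, mono_dataset_2,
#         up_sample_ratio=2,
#     )
#
# Passing all_dataset_upsample_ratio="2,1,1" instead sets one ratio per dataset.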


# The registration name below is assumed from the class name; the original
# registration string is not shown in this file.
@register_task("translation_with_mono")
class TranslationWithMonoTask(TranslationTask):
    """
    Translate from one (source) language to another (target) language,
    optionally mixing subsampled monolingual (bitext-formatted) data into the
    training set.

    Args:
        src_dict (~fairseq.data.Dictionary): dictionary for the source language
        tgt_dict (~fairseq.data.Dictionary): dictionary for the target language

    .. note::

        The translation task is compatible with :mod:`fairseq-train`,
        :mod:`fairseq-generate` and :mod:`fairseq-interactive`.

    The translation task provides the following additional command-line
    arguments:

    .. argparse::
        :ref: fairseq.tasks.translation_parser
        :prog:
    """
    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # fmt: off
        TranslationTask.add_args(parser)
        parser.add_argument('--mono-data', default=None,
                            help='monolingual data directories, split by ":"')
        parser.add_argument('--mono-one-split-each-epoch', action='store_true', default=False,
                            help='use one split of the monolingual data at each epoch')
        parser.add_argument('--parallel-ratio', default=1.0, type=float,
                            help='subsample ratio of the parallel data')
        parser.add_argument('--mono-ratio', default=1.0, type=float,
                            help='subsample ratio of the monolingual data')
        # fmt: on
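
    # A hedged example of how these options could be combined on the command
    # line (the task name and paths are assumptions for illustration, not taken
    # from this file):
    #
    #   fairseq-train data-bin/parallel \
    #       --task translation_with_mono \
    #       --mono-data data-bin/mono1:data-bin/mono2 \
    #       --parallel-ratio 1.0 --mono-ratio 0.5 \
    #       --mono-one-split-each-epoch ...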

    def __init__(self, args, src_dict, tgt_dict):
        super().__init__(args, src_dict, tgt_dict)
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict
        self.update_number = 0

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)
        if getattr(args, 'raw_text', False):
            utils.deprecation_warning('--raw-text is deprecated, please use --dataset-impl=raw')
            args.dataset_impl = 'raw'
        elif getattr(args, 'lazy_load', False):
            utils.deprecation_warning('--lazy-load is deprecated, please use --dataset-impl=lazy')
            args.dataset_impl = 'lazy'

        paths = utils.split_paths(args.data)
        assert len(paths) > 0
        # infer the language pair automatically if it was not given explicitly
        if args.source_lang is None or args.target_lang is None:
            args.source_lang, args.target_lang = data_utils.infer_language_pair(paths[0])
        if args.source_lang is None or args.target_lang is None:
            raise Exception('Could not infer language pair, please provide it explicitly')

        # load dictionaries; source and target must share the special symbols
        src_dict = cls.load_dictionary(os.path.join(paths[0], 'dict.{}.txt'.format(args.source_lang)))
        tgt_dict = cls.load_dictionary(os.path.join(paths[0], 'dict.{}.txt'.format(args.target_lang)))
        assert src_dict.pad() == tgt_dict.pad()
        assert src_dict.eos() == tgt_dict.eos()
        assert src_dict.unk() == tgt_dict.unk()
        logger.info('| [{}] dictionary: {} types'.format(args.source_lang, len(src_dict)))
        logger.info('| [{}] dictionary: {} types'.format(args.target_lang, len(tgt_dict)))

        return cls(args, src_dict, tgt_dict)

    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        logger.info("Loading the {} split".format(split))
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        if split != getattr(self.args, "train_subset", None):
            # if not the training set, use the first shard for valid and test
            paths = paths[:1]
        data_path = paths[(epoch - 1) % len(paths)]
        mono_paths = utils.split_paths(self.args.mono_data)

        # infer langcode
        src, tgt = self.args.source_lang, self.args.target_lang

        # the parallel bitext for this split
        parallel_data = load_langpair_dataset(
            data_path, split, src, self.src_dict, tgt, self.tgt_dict,
            combine=combine, dataset_impl=self.args.dataset_impl,
            upsample_primary=self.args.upsample_primary,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
            load_alignments=self.args.load_alignments,
            num_buckets=self.args.num_batch_buckets,
            shuffle=(split != "test"),
            pad_to_multiple=self.args.required_seq_len_multiple,
        )
        if split == "train":
            # subsample the parallel data, then mix in the monolingual data
            parallel_data = SubsampleLanguagePairDataset(
                parallel_data,
                size_ratio=self.args.parallel_ratio,
                seed=self.args.seed,
                epoch=epoch,
            )
            if self.args.mono_one_split_each_epoch:
                # rotate through the monolingual shards, one shard per epoch
                mono_path = mono_paths[(epoch - 1) % len(mono_paths)]
                mono_data = load_langpair_dataset(
                    mono_path, split, src, self.src_dict, tgt, self.tgt_dict,
                    combine=combine, dataset_impl=self.args.dataset_impl,
                    upsample_primary=self.args.upsample_primary,
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                    max_source_positions=self.args.max_source_positions,
                    max_target_positions=self.args.max_target_positions,
                    shuffle=(split != "test"),
                )
                mono_data = SubsampleLanguagePairDataset(
                    mono_data,
                    size_ratio=self.args.mono_ratio,
                    seed=self.args.seed,
                    epoch=epoch,
                )
                all_dataset = [parallel_data, mono_data]
            else:
                # load and subsample every monolingual shard
                mono_datas = []
                for mono_path in mono_paths:
                    mono_data = load_langpair_dataset(
                        mono_path, split, src, self.src_dict, tgt, self.tgt_dict,
                        combine=combine, dataset_impl=self.args.dataset_impl,
                        upsample_primary=self.args.upsample_primary,
                        left_pad_source=self.args.left_pad_source,
                        left_pad_target=self.args.left_pad_target,
                        max_source_positions=self.args.max_source_positions,
                        max_target_positions=self.args.max_target_positions,
                        shuffle=(split != "test"),
                    )
                    mono_data = SubsampleLanguagePairDataset(
                        mono_data,
                        size_ratio=self.args.mono_ratio,
                        seed=self.args.seed,
                        epoch=epoch,
                    )
                    mono_datas.append(mono_data)
                all_dataset = [parallel_data] + mono_datas
            self.datasets[split] = ConcatDataset(all_dataset)
        else:
            self.datasets[split] = parallel_data
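
        # Note (added for clarity): the "monolingual" directories are read with
        # load_langpair_dataset, so each --mono-data path is expected to contain
        # binarized bitext in the usual fairseq layout (for example, data produced
        # by back-translating the monolingual side). Assuming that
        # SubsampleLanguagePairDataset keeps about size_ratio * len(dataset)
        # examples, --mono-ratio 0.25 on a 4M-pair shard (an illustrative number)
        # would contribute roughly 1M pairs to the concatenated training set.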