# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import json
import logging
import os
from argparse import Namespace

import numpy as np
import torch

from fairseq import metrics, options, utils
from fairseq.data import (
    AppendTokenDataset,
    ConcatDataset,
    LanguagePairDataset,
    PrependTokenDataset,
    StripTokenDataset,
    TruncateDataset,
    data_utils,
    encoders,
    indexed_dataset,
)
from fairseq.tasks.translation import TranslationTask
from fairseq.tasks import register_task, LegacyFairseqTask
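
# BLEU is scored over n-gram orders 1 through 4, so per-order match counts and
# totals are logged separately during validation (see valid_step/reduce_metrics).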
EVAL_BLEU_ORDER = 4

logger = logging.getLogger(__name__)


def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    num_buckets=0,
    shuffle=True,
    pad_to_multiple=1,
):
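    """Load a binarized parallel dataset for one split.

    Expects files named ``{split}.{src}-{tgt}.{lang}`` (or with the language
    pair reversed) under ``data_path``, as probed by ``split_exists`` below.
    Optionally truncates the source, prepends BOS, appends a ``[lang]`` id
    token, and loads word alignments, then wraps everything in a
    ``LanguagePairDataset``.
    """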

    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []
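
    # A split may be sharded on disk (e.g. train, train1, train2, ...); keep
    # loading shards until one is missing. When ``combine`` is False, only the
    # first shard is used.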
    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else "")

        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError(
                    "Dataset not found: {} ({})".format(split, data_path)
                )

        src_dataset = data_utils.load_indexed_dataset(
            prefix + src, src_dict, dataset_impl
        )
        if truncate_source:
            src_dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)

        tgt_dataset = data_utils.load_indexed_dataset(
            prefix + tgt, tgt_dict, dataset_impl
        )
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        logger.info(
            "{} {} {}-{} {} examples".format(
                data_path, split_k, src, tgt, len(src_datasets[-1])
            )
        )

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    eos = None
    if append_source_id:
        src_dataset = AppendTokenDataset(
            src_dataset, src_dict.index("[{}]".format(src))
        )
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(
                tgt_dataset, tgt_dict.index("[{}]".format(tgt))
            )
        eos = tgt_dict.index("[{}]".format(tgt))

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(data_path, "{}.align.{}-{}".format(split, src, tgt))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(
                align_path, None, dataset_impl
            )

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        align_dataset=align_dataset,
        eos=eos,
        num_buckets=num_buckets,
        shuffle=shuffle,
        pad_to_multiple=pad_to_multiple,
    )


# NOTE: the task name passed to register_task is inferred from the class name;
# the registration string used by the original authors may differ.
@register_task("translation_with_langtok")
class TranslationWithLangtokTask(LegacyFairseqTask):
| """ | |
| Translate from one (source) language to another (target) language. | |
| Args: | |
| src_dict (~fairseq.data.Dictionary): dictionary for the source language | |
| tgt_dict (~fairseq.data.Dictionary): dictionary for the target language | |
| .. note:: | |
| The translation task is compatible with :mod:`fairseq-train`, | |
| :mod:`fairseq-generate` and :mod:`fairseq-interactive`. | |
| The translation task provides the following additional command-line | |
| arguments: | |
| .. argparse:: | |
| :ref: fairseq.tasks.translation_parser | |
| :prog: | |
| """ | |

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('data', help='colon separated path to data directories list, \
                            will be iterated upon during epochs in round-robin manner; \
                            however, valid and test data are always in the first directory to \
                            avoid the need for repeating them in all directories')
        parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
                            help='source language')
        parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
                            help='target language')
        parser.add_argument('--load-alignments', action='store_true',
                            help='load the binarized alignments')
        parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL',
                            help='pad the source on the left')
        parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL',
                            help='pad the target on the left')
        parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
                            help='max number of tokens in the source sequence')
        parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
                            help='max number of tokens in the target sequence')
        parser.add_argument('--upsample-primary', default=1, type=int,
                            help='amount to upsample primary dataset')
        parser.add_argument('--truncate-source', action='store_true', default=False,
                            help='truncate source to max-source-positions')
        parser.add_argument('--num-batch-buckets', default=0, type=int, metavar='N',
                            help='if >0, then bucket source and target lengths into N '
                                 'buckets and pad accordingly; this is useful on TPUs '
                                 'to minimize the number of compilations')
        parser.add_argument('--lang-prefix-tok', default=None, type=str,
                            help='starting token in decoder')

        # options for reporting BLEU during validation
        parser.add_argument('--eval-bleu', action='store_true',
                            help='evaluation with BLEU scores')
        parser.add_argument('--eval-bleu-detok', type=str, default="space",
                            help='detokenize before computing BLEU (e.g., "moses"); '
                                 'required if using --eval-bleu; use "space" to '
                                 'disable detokenization; see fairseq.data.encoders '
                                 'for other options')
        parser.add_argument('--eval-bleu-detok-args', type=str, metavar='JSON',
                            help='args for building the tokenizer, if needed')
        parser.add_argument('--eval-tokenized-bleu', action='store_true', default=False,
                            help='compute tokenized BLEU instead of sacrebleu')
        parser.add_argument('--eval-bleu-remove-bpe', nargs='?', const='@@ ', default=None,
                            help='remove BPE before computing BLEU')
        parser.add_argument('--eval-bleu-args', type=str, metavar='JSON',
                            help='generation args for BLEU scoring, '
                                 'e.g., \'{"beam": 4, "lenpen": 0.6}\'')
        parser.add_argument('--eval-bleu-print-samples', action='store_true',
                            help='print sample generations during validation')
        # fmt: on

    def __init__(self, args, src_dict, tgt_dict):
        super().__init__(args)
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        args.left_pad_source = utils.eval_bool(args.left_pad_source)
        args.left_pad_target = utils.eval_bool(args.left_pad_target)

        paths = utils.split_paths(args.data)
        assert len(paths) > 0
        # find language pair automatically
        if args.source_lang is None or args.target_lang is None:
            args.source_lang, args.target_lang = data_utils.infer_language_pair(
                paths[0]
            )
        if args.source_lang is None or args.target_lang is None:
            raise Exception(
                "Could not infer language pair, please provide it explicitly"
            )

        # load dictionaries
        src_dict = cls.load_dictionary(
            os.path.join(paths[0], "dict.{}.txt".format(args.source_lang))
        )
        tgt_dict = cls.load_dictionary(
            os.path.join(paths[0], "dict.{}.txt".format(args.target_lang))
        )
        assert src_dict.pad() == tgt_dict.pad()
        assert src_dict.eos() == tgt_dict.eos()
        assert src_dict.unk() == tgt_dict.unk()
        logger.info("[{}] dictionary: {} types".format(args.source_lang, len(src_dict)))
        logger.info("[{}] dictionary: {} types".format(args.target_lang, len(tgt_dict)))

        return cls(args, src_dict, tgt_dict)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        if split != getattr(self.args, "train_subset", None):
            # if this is not the training split, use the first shard for valid and test
            paths = paths[:1]
        data_path = paths[(epoch - 1) % len(paths)]

        # infer langcode
        src, tgt = self.args.source_lang, self.args.target_lang

        self.datasets[split] = load_langpair_dataset(
            data_path,
            split,
            src,
            self.src_dict,
            tgt,
            self.tgt_dict,
            combine=combine,
            dataset_impl=self.args.dataset_impl,
            upsample_primary=self.args.upsample_primary,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
            load_alignments=self.args.load_alignments,
            truncate_source=self.args.truncate_source,
            num_buckets=self.args.num_batch_buckets,
            shuffle=(split != "test"),
            pad_to_multiple=self.args.required_seq_len_multiple,
        )

    def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
        return LanguagePairDataset(
            src_tokens,
            src_lengths,
            self.source_dictionary,
            tgt_dict=self.target_dictionary,
            constraints=constraints,
        )

    def build_model(self, args):
        model = super().build_model(args)
        if getattr(args, "eval_bleu", False):
            assert getattr(args, "eval_bleu_detok", None) is not None, (
                "--eval-bleu-detok is required if using --eval-bleu; "
                "try --eval-bleu-detok=moses (or --eval-bleu-detok=space "
                "to disable detokenization, e.g., when using sentencepiece)"
            )
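            # The detokenizer and the validation-time generator are both configured
            # from JSON strings on the command line, e.g. (illustrative values; the
            # beam/lenpen example comes from the --eval-bleu-args help text above):
            #   --eval-bleu-detok moses --eval-bleu-detok-args '{}'
            #   --eval-bleu-args '{"beam": 4, "lenpen": 0.6}'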
| detok_args = json.loads(getattr(args, "eval_bleu_detok_args", "{}") or "{}") | |
| self.tokenizer = encoders.build_tokenizer( | |
| Namespace( | |
| tokenizer=getattr(args, "eval_bleu_detok", None), **detok_args | |
| ) | |
| ) | |
| gen_args = json.loads(getattr(args, "eval_bleu_args", "{}") or "{}") | |
| self.sequence_generator = self.build_generator( | |
| [model], Namespace(**gen_args) | |
| ) | |
| return model | |

    def valid_step(self, sample, model, criterion):
        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
        if self.args.eval_bleu:
            bleu = self._inference_with_bleu(self.sequence_generator, sample, model)
            logging_output["_bleu_sys_len"] = bleu.sys_len
            logging_output["_bleu_ref_len"] = bleu.ref_len
            # we split counts into separate entries so that they can be
            # summed efficiently across workers using fast-stat-sync
            assert len(bleu.counts) == EVAL_BLEU_ORDER
            for i in range(EVAL_BLEU_ORDER):
                logging_output["_bleu_counts_" + str(i)] = bleu.counts[i]
                logging_output["_bleu_totals_" + str(i)] = bleu.totals[i]
        return loss, sample_size, logging_output

    def inference_step(
        self, generator, models, sample, prefix_tokens=None, constraints=None
    ):
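        # If --lang-prefix-tok is set, every hypothesis is forced to start with
        # that token (typically a target-language tag). The assert below catches
        # tags that are missing from the target dictionary, which would otherwise
        # silently map to <unk>.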
        if self.args.lang_prefix_tok is None:
            prefix_tokens = None
        else:
            prefix_tokens = self.target_dictionary.index(self.args.lang_prefix_tok)
            assert prefix_tokens != self.target_dictionary.unk_index

        with torch.no_grad():
            net_input = sample["net_input"]
            if "src_tokens" in net_input:
                src_tokens = net_input["src_tokens"]
            elif "source" in net_input:
                src_tokens = net_input["source"]
            else:
                raise Exception("expected src_tokens or source in net input")

            # bsz: total number of sentences in beam
            # Note that src_tokens may have more than 2 dimensions (e.g., audio features)
            bsz, _ = src_tokens.size()[:2]
            if prefix_tokens is not None:
                if isinstance(prefix_tokens, int):
                    prefix_tokens = torch.LongTensor([prefix_tokens]).unsqueeze(1)  # shape (1, 1)
                prefix_tokens = prefix_tokens.expand(bsz, -1)
                prefix_tokens = prefix_tokens.to(src_tokens.device)
            return generator.generate(models, sample, prefix_tokens=prefix_tokens)

    def reduce_metrics(self, logging_outputs, criterion):
        super().reduce_metrics(logging_outputs, criterion)
        if self.args.eval_bleu:

            def sum_logs(key):
                return sum(log.get(key, 0) for log in logging_outputs)
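
            # BLEU is recomputed from n-gram statistics summed over all batches
            # (and, in distributed runs, over all workers) rather than by averaging
            # per-batch BLEU scores.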
            counts, totals = [], []
            for i in range(EVAL_BLEU_ORDER):
                counts.append(sum_logs("_bleu_counts_" + str(i)))
                totals.append(sum_logs("_bleu_totals_" + str(i)))

            if max(totals) > 0:
                # log counts as numpy arrays -- log_scalar will sum them correctly
                metrics.log_scalar("_bleu_counts", np.array(counts))
                metrics.log_scalar("_bleu_totals", np.array(totals))
                metrics.log_scalar("_bleu_sys_len", sum_logs("_bleu_sys_len"))
                metrics.log_scalar("_bleu_ref_len", sum_logs("_bleu_ref_len"))

                def compute_bleu(meters):
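                    # sacrebleu renamed its smoothing argument across versions
                    # ("smooth" -> "smooth_method"), so inspect the signature and
                    # pass whichever keyword the installed version expects.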
                    import inspect
                    import sacrebleu

                    fn_sig = inspect.getfullargspec(sacrebleu.compute_bleu)[0]
                    if "smooth_method" in fn_sig:
                        smooth = {"smooth_method": "exp"}
                    else:
                        smooth = {"smooth": "exp"}
                    bleu = sacrebleu.compute_bleu(
                        correct=meters["_bleu_counts"].sum,
                        total=meters["_bleu_totals"].sum,
                        sys_len=meters["_bleu_sys_len"].sum,
                        ref_len=meters["_bleu_ref_len"].sum,
                        **smooth
                    )
                    return round(bleu.score, 2)

                metrics.log_derived("bleu", compute_bleu)

    def max_positions(self):
        """Return the max sentence length allowed by the task."""
        return (self.args.max_source_positions, self.args.max_target_positions)

    @property
    def source_dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary`."""
        return self.src_dict

    @property
    def target_dictionary(self):
        """Return the target :class:`~fairseq.data.Dictionary`."""
        return self.tgt_dict

    def _inference_with_bleu(self, generator, sample, model):
        import sacrebleu

        def decode(toks, escape_unk=False):
            s = self.tgt_dict.string(
                toks.int().cpu(),
                self.args.eval_bleu_remove_bpe,
                # The default unknown string in fairseq is `<unk>`, but
                # this is tokenized by sacrebleu as `< unk >`, inflating
                # BLEU scores. Instead, we use a somewhat more verbose
                # alternative that is unlikely to appear in the real
                # reference, but doesn't get split into multiple tokens.
                unk_string=("UNKNOWNTOKENINREF" if escape_unk else "UNKNOWNTOKENINHYP"),
            )
            if self.tokenizer:
                s = self.tokenizer.decode(s)
            return s

        gen_out = self.inference_step(generator, [model], sample, prefix_tokens=None)
        hyps, refs = [], []
        for i in range(len(gen_out)):
            hyps.append(decode(gen_out[i][0]["tokens"]))
            refs.append(
                decode(
                    utils.strip_pad(sample["target"][i], self.tgt_dict.pad()),
                    escape_unk=True,  # don't count <unk> as matches to the hypo
                )
            )
        if self.args.eval_bleu_print_samples:
            logger.info("example hypothesis: " + hyps[0])
            logger.info("example reference: " + refs[0])
        if self.args.eval_tokenized_bleu:
            return sacrebleu.corpus_bleu(hyps, [refs], tokenize="none")
        else:
            return sacrebleu.corpus_bleu(hyps, [refs])