# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import json
import os
import random
import sys
from io import StringIO

import torch
import torch.nn.functional as F

import fairseq.distributed.utils as distributed_utils
from fairseq import options, utils
from fairseq.data import Dictionary
from fairseq.data.language_pair_dataset import collate
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.models import (
    FairseqEncoder,
    FairseqEncoderDecoderModel,
    FairseqIncrementalDecoder,
)
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.tasks import LegacyFairseqTask
from fairseq_cli import generate, interactive, preprocess, train, validate
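

# Shared helpers for fairseq's unit tests: dummy dictionaries and datasets,
# small preprocessing/training/generation drivers built on the fairseq_cli
# entry points, and toy encoder/decoder models whose output probabilities are
# scripted via args.beam_probs / args.probs.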


def dummy_dictionary(vocab_size, prefix="token_"):
    d = Dictionary()
    for i in range(vocab_size):
        token = prefix + str(i)
        d.add_symbol(token)
    d.finalize(padding_factor=1)  # don't add extra padding symbols
    return d
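

# Illustrative check (not used by the tests themselves): fairseq's Dictionary
# reserves four special symbols (bos/pad/eos/unk), so dummy_dictionary(N) has
# N + 4 entries and the first added token lands at index 4.
#
#   d = dummy_dictionary(vocab_size=2)
#   assert len(d) == 6
#   assert d.index("token_0") == 4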


def dummy_dataloader(
    samples,
    padding_idx=1,
    eos_idx=2,
    batch_size=None,
):
    if batch_size is None:
        batch_size = len(samples)

    # add any missing data to samples
    for i, sample in enumerate(samples):
        if "id" not in sample:
            sample["id"] = i

    # create dataloader
    dataset = TestDataset(samples)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=(lambda samples: collate(samples, padding_idx, eos_idx)),
    )
    return iter(dataloader)
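

# dummy_dataloader() expects samples shaped like language_pair_dataset items;
# a sketch (tensor values are arbitrary, eos index 2 assumed):
#
#   samples = [
#       {"source": torch.LongTensor([4, 5, 2]), "target": torch.LongTensor([4, 2])},
#   ]
#   batch = next(dummy_dataloader(samples))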


def sequence_generator_setup():
    # construct dummy dictionary
    d = dummy_dictionary(vocab_size=2)

    eos = d.eos()
    w1 = 4
    w2 = 5

    # construct source data
    src_tokens = torch.LongTensor([[w1, w2, eos], [w1, w2, eos]])
    src_lengths = torch.LongTensor([2, 2])

    args = argparse.Namespace()
    unk = 0.0
    args.beam_probs = [
        # step 0:
        torch.FloatTensor(
            [
                # eos      w1   w2
                # sentence 1:
                [0.0, unk, 0.9, 0.1],  # beam 1
                [0.0, unk, 0.9, 0.1],  # beam 2
                # sentence 2:
                [0.0, unk, 0.7, 0.3],
                [0.0, unk, 0.7, 0.3],
            ]
        ),
        # step 1:
        torch.FloatTensor(
            [
                # eos       w1    w2   prefix
                # sentence 1:
                [1.0, unk, 0.0, 0.0],  # w1: 0.9  (emit: w1 <eos>: 0.9*1.0)
                [0.0, unk, 0.9, 0.1],  # w2: 0.1
                # sentence 2:
                [0.25, unk, 0.35, 0.4],  # w1: 0.7  (don't emit: w1 <eos>: 0.7*0.25)
                [0.00, unk, 0.10, 0.9],  # w2: 0.3
            ]
        ),
        # step 2:
        torch.FloatTensor(
            [
                # eos       w1   w2    prefix
                # sentence 1:
                [0.0, unk, 0.1, 0.9],  # w2 w1: 0.1*0.9
                [0.6, unk, 0.2, 0.2],  # w2 w2: 0.1*0.1  (emit: w2 w2 <eos>: 0.1*0.1*0.6)
                # sentence 2:
                [0.60, unk, 0.4, 0.00],  # w1 w2: 0.7*0.4  (emit: w1 w2 <eos>: 0.7*0.4*0.6)
                [0.01, unk, 0.0, 0.99],  # w2 w2: 0.3*0.9
            ]
        ),
        # step 3:
        torch.FloatTensor(
            [
                # eos      w1   w2    prefix
                # sentence 1:
                [1.0, unk, 0.0, 0.0],  # w2 w1 w2: 0.1*0.9*0.9  (emit: w2 w1 w2 <eos>: 0.1*0.9*0.9*1.0)
                [1.0, unk, 0.0, 0.0],  # w2 w1 w1: 0.1*0.9*0.1  (emit: w2 w1 w1 <eos>: 0.1*0.9*0.1*1.0)
                # sentence 2:
                [0.1, unk, 0.5, 0.4],  # w2 w2 w2: 0.3*0.9*0.99  (emit: w2 w2 w2 <eos>: 0.3*0.9*0.99*0.1)
                [1.0, unk, 0.0, 0.0],  # w1 w2 w1: 0.7*0.4*0.4  (emit: w1 w2 w1 <eos>: 0.7*0.4*0.4*1.0)
            ]
        ),
    ]

    task = TestTranslationTask.setup_task(args, d, d)
    model = task.build_model(args)
    tgt_dict = task.target_dictionary

    return tgt_dict, w1, w2, src_tokens, src_lengths, model
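

# sequence_generator_setup() is meant to feed a beam-search test; a sketch of
# typical usage (SequenceGenerator import and exact constructor arguments
# assumed from fairseq.sequence_generator):
#
#   tgt_dict, w1, w2, src_tokens, src_lengths, model = sequence_generator_setup()
#   generator = SequenceGenerator([model], tgt_dict, beam_size=2)
#   hypos = generator.forward(
#       {"net_input": {"src_tokens": src_tokens, "src_lengths": src_lengths}}
#   )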


def create_dummy_data(data_dir, num_examples=100, maxlen=20, alignment=False):
    def _create_dummy_data(filename):
        data = torch.rand(num_examples * maxlen)
        data = 97 + torch.floor(26 * data).int()
        with open(os.path.join(data_dir, filename), "w") as h:
            offset = 0
            for _ in range(num_examples):
                ex_len = random.randint(1, maxlen)
                ex_str = " ".join(map(chr, data[offset : offset + ex_len]))
                print(ex_str, file=h)
                offset += ex_len

    def _create_dummy_alignment_data(filename_src, filename_tgt, filename):
        with open(os.path.join(data_dir, filename_src), "r") as src_f, open(
            os.path.join(data_dir, filename_tgt), "r"
        ) as tgt_f, open(os.path.join(data_dir, filename), "w") as h:
            for src, tgt in zip(src_f, tgt_f):
                src_len = len(src.split())
                tgt_len = len(tgt.split())
                avg_len = (src_len + tgt_len) // 2
                num_alignments = random.randint(avg_len // 2, 2 * avg_len)
                src_indices = torch.floor(torch.rand(num_alignments) * src_len).int()
                tgt_indices = torch.floor(torch.rand(num_alignments) * tgt_len).int()
                ex_str = " ".join(
                    [
                        "{}-{}".format(src, tgt)
                        for src, tgt in zip(src_indices, tgt_indices)
                    ]
                )
                print(ex_str, file=h)

    _create_dummy_data("train.in")
    _create_dummy_data("train.out")
    _create_dummy_data("valid.in")
    _create_dummy_data("valid.out")
    _create_dummy_data("test.in")
    _create_dummy_data("test.out")

    if alignment:
        _create_dummy_alignment_data("train.in", "train.out", "train.align")
        _create_dummy_alignment_data("valid.in", "valid.out", "valid.align")
        _create_dummy_alignment_data("test.in", "test.out", "test.align")


def preprocess_lm_data(data_dir):
    preprocess_parser = options.get_preprocessing_parser()
    preprocess_args = preprocess_parser.parse_args(
        [
            "--only-source",
            "--trainpref",
            os.path.join(data_dir, "train.out"),
            "--validpref",
            os.path.join(data_dir, "valid.out"),
            "--testpref",
            os.path.join(data_dir, "test.out"),
            "--destdir",
            data_dir,
        ]
    )
    preprocess.main(preprocess_args)


def preprocess_translation_data(data_dir, extra_flags=None):
    preprocess_parser = options.get_preprocessing_parser()
    preprocess_args = preprocess_parser.parse_args(
        [
            "--source-lang",
            "in",
            "--target-lang",
            "out",
            "--trainpref",
            os.path.join(data_dir, "train"),
            "--validpref",
            os.path.join(data_dir, "valid"),
            "--testpref",
            os.path.join(data_dir, "test"),
            "--thresholdtgt",
            "0",
            "--thresholdsrc",
            "0",
            "--destdir",
            data_dir,
        ]
        + (extra_flags or []),
    )
    preprocess.main(preprocess_args)
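

# End-to-end sketch of how the preprocessing/training/generation helpers are
# typically combined in a test (directory and arch names illustrative;
# `tempfile` would need to be imported by the calling test module):
#
#   with tempfile.TemporaryDirectory("test_fconv") as data_dir:
#       create_dummy_data(data_dir)
#       preprocess_translation_data(data_dir)
#       train_translation_model(data_dir, "fconv_iwslt_de_en")
#       generate_main(data_dir)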


def preprocess_summarization_data(data_dir, extra_flags=None):
    preprocess_parser = options.get_preprocessing_parser()
    preprocess_args = preprocess_parser.parse_args(
        [
            "--source-lang",
            "in",
            "--target-lang",
            "out",
            "--trainpref",
            os.path.join(data_dir, "train"),
            "--validpref",
            os.path.join(data_dir, "valid"),
            "--testpref",
            os.path.join(data_dir, "test"),
            "--thresholdtgt",
            "0",
            "--thresholdsrc",
            "0",
            "--joined-dictionary",
            "--destdir",
            data_dir,
        ]
        + (extra_flags or []),
    )
    preprocess.main(preprocess_args)


def create_laser_data_and_config_json(data_dir):
    src_langs = ["de", "fr", "ru", "tr", "zh"]
    tgt_langs = ["en", "es"]
    config_json = {}
    config_train_json = []
    src_vocab = None
    tgt_vocab = None
    for src_lang in src_langs:
        for tgt_lang in tgt_langs:
            langpair_folder = f"{src_lang}-{tgt_lang}"

            langpair_path = os.path.join(data_dir, langpair_folder)
            os.mkdir(langpair_path)
            create_dummy_data(langpair_path)
            preprocess_translation_data(langpair_path, ["--dataset-impl", "cached"])

            src_vocab = os.path.join(langpair_path, "dict.in.txt")
            tgt_vocab = os.path.join(langpair_path, "dict.out.txt")
            config_train_json.append(
                {
                    "id": 0 if tgt_lang == "en" else 1,
                    "src": os.path.join(langpair_path, "train.in-out.in"),
                    "tgt": os.path.join(langpair_path, "train.in-out.out"),
                }
            )

    config_json["src_vocab"] = src_vocab
    config_json["tgt_vocab"] = tgt_vocab
    config_json["train"] = config_train_json

    with open(os.path.join(data_dir, "laserconfig.json"), "w") as config_file:
        json.dump(config_json, config_file)

    return config_file
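

# Note: the handle returned by create_laser_data_and_config_json() is already
# closed by the `with` block; it remains useful for its `.name` attribute,
# which holds the path to laserconfig.json.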


def train_translation_model(
    data_dir,
    arch,
    extra_flags=None,
    task="translation",
    run_validation=False,
    lang_flags=None,
    extra_valid_flags=None,
    world_size=1,
):
    if lang_flags is None:
        lang_flags = [
            "--source-lang",
            "in",
            "--target-lang",
            "out",
        ]

    train_parser = options.get_training_parser()
    train_args = options.parse_args_and_arch(
        train_parser,
        [
            "--task",
            task,
            data_dir,
            "--save-dir",
            data_dir,
            "--arch",
            arch,
            "--optimizer",
            "nag",
            "--lr",
            "0.05",
            "--max-tokens",
            "500",
            "--max-epoch",
            "1",
            "--no-progress-bar",
            "--distributed-world-size",
            str(world_size),
            "--num-workers",
            "0",
        ]
        + lang_flags
        + (extra_flags or []),
    )
    cfg = convert_namespace_to_omegaconf(train_args)
    distributed_utils.call_main(cfg, train.main)

    if run_validation:
        # test validation
        validate_parser = options.get_validation_parser()
        validate_args = options.parse_args_and_arch(
            validate_parser,
            [
                "--task",
                task,
                data_dir,
                "--path",
                os.path.join(data_dir, "checkpoint_last.pt"),
                "--valid-subset",
                "valid",
                "--max-tokens",
                "500",
                "--no-progress-bar",
                "--num-workers",
                "0",
            ]
            + lang_flags
            + (extra_valid_flags or []),
        )
        validate.main(validate_args)
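

# train_translation_model() trains for a single epoch via
# distributed_utils.call_main (optionally across world_size workers) and, if
# requested, re-loads the last checkpoint for a validation pass.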


def generate_main(data_dir, extra_flags=None, path=None):
    if extra_flags is None:
        extra_flags = [
            "--print-alignment",
        ]
    if path is None:
        path = os.path.join(data_dir, "checkpoint_last.pt")
    generate_parser = options.get_generation_parser()
    generate_args = options.parse_args_and_arch(
        generate_parser,
        [
            data_dir,
            "--path",
            path,
            "--beam",
            "3",
            "--batch-size",
            "64",
            "--max-len-b",
            "5",
            "--gen-subset",
            "valid",
            "--no-progress-bar",
            "--num-workers",
            "0",
        ]
        + (extra_flags or []),
    )

    # evaluate model in batch mode
    generate.main(generate_args)

    # evaluate model interactively
    generate_args.buffer_size = 0
    generate_args.input = "-"
    generate_args.batch_size = None
    orig_stdin = sys.stdin
    sys.stdin = StringIO("h e l l o\n")
    interactive.main(generate_args)
    sys.stdin = orig_stdin
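

# generate_main() exercises both code paths: batch generation (fairseq-generate)
# and interactive generation (fairseq-interactive), temporarily swapping
# sys.stdin for a single dummy sentence.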


class TestDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        super().__init__()
        self.data = data
        self.sizes = None

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


class TestTranslationTask(LegacyFairseqTask):
    def __init__(self, args, src_dict, tgt_dict, model):
        super().__init__(args)
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict
        self.model = model

    @classmethod
    def setup_task(cls, args, src_dict=None, tgt_dict=None, model=None):
        return cls(args, src_dict, tgt_dict, model)

    def build_model(self, args):
        return TestModel.build_model(args, self)

    @property
    def source_dictionary(self):
        return self.src_dict

    @property
    def target_dictionary(self):
        return self.tgt_dict


class TestModel(FairseqEncoderDecoderModel):
    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @classmethod
    def build_model(cls, args, task):
        encoder = TestEncoder(args, task.source_dictionary)
        decoder = TestIncrementalDecoder(args, task.target_dictionary)
        return cls(encoder, decoder)


class TestEncoder(FairseqEncoder):
    def __init__(self, args, dictionary):
        super().__init__(dictionary)
        self.args = args

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        return EncoderOut(
            encoder_out=src_tokens,
            encoder_padding_mask=None,
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )

    def reorder_encoder_out(self, encoder_out, new_order):
        return EncoderOut(
            encoder_out=encoder_out.encoder_out.index_select(0, new_order),
            encoder_padding_mask=None,
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )


class TestIncrementalDecoder(FairseqIncrementalDecoder):
    def __init__(self, args, dictionary):
        super().__init__(dictionary)
        assert hasattr(args, "beam_probs") or hasattr(args, "probs")
        args.max_decoder_positions = getattr(args, "max_decoder_positions", 100)
        self.args = args

    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
        bbsz = prev_output_tokens.size(0)
        vocab = len(self.dictionary)
        src_len = encoder_out.encoder_out.size(1)
        tgt_len = prev_output_tokens.size(1)

        # determine number of steps
        if incremental_state is not None:
            # cache step number
            step = utils.get_incremental_state(self, incremental_state, "step")
            if step is None:
                step = 0
            utils.set_incremental_state(self, incremental_state, "step", step + 1)
            steps = [step]
        else:
            steps = list(range(tgt_len))

        # define output in terms of raw probs
        if hasattr(self.args, "probs"):
            assert (
                self.args.probs.dim() == 3
            ), "expected probs to have size bsz*steps*vocab"
            probs = self.args.probs.index_select(1, torch.LongTensor(steps))
        else:
            probs = torch.FloatTensor(bbsz, len(steps), vocab).zero_()
            for i, step in enumerate(steps):
                # args.beam_probs gives the probability for every vocab element,
                # starting with eos, then unknown, and then the rest of the vocab
                if step < len(self.args.beam_probs):
                    probs[:, i, self.dictionary.eos() :] = self.args.beam_probs[step]
                else:
                    probs[:, i, self.dictionary.eos()] = 1.0

        # random attention
        attn = torch.rand(bbsz, tgt_len, src_len)
        dev = prev_output_tokens.device
        return probs.to(dev), {"attn": [attn.to(dev)]}

    def get_normalized_probs(self, net_output, log_probs, _):
        # the decoder returns probabilities directly
        probs = net_output[0]
        if log_probs:
            return probs.log()
        else:
            return probs

    def max_positions(self):
        return self.args.max_decoder_positions
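

# TestIncrementalDecoder reads each step's distribution from args.beam_probs;
# the four columns fill vocabulary positions starting at eos (index 2), i.e.
# eos, unk, then the two dummy tokens w1=4 and w2=5.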


class TestReshapingEncoder(FairseqEncoder):
    def __init__(self, args, dictionary):
        super().__init__(dictionary)
        self.args = args

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        b_sz, t_sz = src_tokens.shape
        padding_needed = t_sz % 2
        x = src_tokens
        if padding_needed > 0:
            padding_needed = 2 - padding_needed
            x = F.pad(x, (0, padding_needed))

        return EncoderOut(
            encoder_out=x.view(b_sz, -1, 2),
            encoder_padding_mask=None,
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )

    def reorder_encoder_out(self, encoder_out, new_order):
        return EncoderOut(
            encoder_out=encoder_out.encoder_out.index_select(0, new_order),
            encoder_padding_mask=None,
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )


class TestReshapingModel(FairseqEncoderDecoderModel):
    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @classmethod
    def build_model(cls, args, task):
        encoder = TestReshapingEncoder(args, task.source_dictionary)
        decoder = TestIncrementalDecoder(args, task.target_dictionary)
        return cls(encoder, decoder)


class TestAdditionalInputEncoder(FairseqEncoder):
    def __init__(self, args, dictionary):
        super().__init__(dictionary)
        self.args = args

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        assert "fancy_other_input" in kwargs
        assert kwargs["fancy_other_input"] is not None
        return EncoderOut(
            encoder_out=src_tokens,
            encoder_padding_mask=None,
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )

    def reorder_encoder_out(self, encoder_out, new_order):
        return EncoderOut(
            encoder_out=encoder_out.encoder_out.index_select(0, new_order),
            encoder_padding_mask=None,
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )


class TestAdditionalInputModel(FairseqEncoderDecoderModel):
    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @classmethod
    def build_model(cls, args, task):
        encoder = TestAdditionalInputEncoder(args, task.source_dictionary)
        decoder = TestIncrementalDecoder(args, task.target_dictionary)
        return cls(encoder, decoder)

    def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        decoder_out = self.decoder(
            prev_output_tokens, encoder_out=encoder_out, **kwargs
        )
        return decoder_out


def train_language_model(
    data_dir,
    arch,
    extra_flags=None,
    run_validation=False,
    extra_valid_flags=None,
    task="language_modeling",
    world_size=1,
):
    train_parser = options.get_training_parser()
    train_args = options.parse_args_and_arch(
        train_parser,
        [
            "--task",
            task,
            data_dir,
            "--arch",
            arch,
            "--optimizer",
            "adam",
            "--lr",
            "0.0001",
            "--max-tokens",
            "500",
            "--tokens-per-sample",
            "500",
            "--save-dir",
            data_dir,
            "--max-epoch",
            "1",
            "--no-progress-bar",
            "--distributed-world-size",
            str(world_size),
            "--ddp-backend",
            "no_c10d",
            "--num-workers",
            "0",
        ]
        + (extra_flags or []),
    )
    cfg = convert_namespace_to_omegaconf(train_args)
    distributed_utils.call_main(cfg, train.main)

    if run_validation:
        # test validation
        validate_parser = options.get_validation_parser()
        validate_args = options.parse_args_and_arch(
            validate_parser,
            [
                "--task",
                task,
                data_dir,
                "--path",
                os.path.join(data_dir, "checkpoint_last.pt"),
                "--valid-subset",
                "valid",
                "--max-tokens",
                "500",
                "--no-progress-bar",
                "--num-workers",
                "0",
            ]
            + (extra_valid_flags or []),
        )
        validate.main(validate_args)