# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

from dataclasses import dataclass, field
import logging
import os
import math
import torch
from typing import Dict, Optional

from fairseq import search
from fairseq.data import FairseqDataset, iterators
from fairseq.optim.amp_optimizer import AMPOptimizer
from fairseq.dataclass import FairseqDataclass
from fairseq.tasks import FairseqTask, register_task
from omegaconf import DictConfig

logger = logging.getLogger(__name__)

@dataclass
class OFAConfig(FairseqDataclass):
    data: Optional[str] = field(
        default=None,
        metadata={
            "help": "comma separated path to data list, will be iterated upon during epochs "
                    "in round-robin manner; valid data are always in the last"
        },
    )
    selected_cols: Optional[str] = field(
        default=None,
        metadata={"help": "selected cols"},
    )
    bpe: Optional[str] = field(
        default='gpt2',
        metadata={"help": "which bpe to use"},
    )
    bpe_dir: Optional[str] = field(
        default=None,
        metadata={"help": "bpe dir"},
    )
    max_source_positions: int = field(
        default=1024, metadata={"help": "max number of tokens in the source sequence"}
    )
    max_target_positions: int = field(
        default=1024, metadata={"help": "max number of tokens in the target sequence"}
    )
    max_src_length: int = field(
        default=128, metadata={"help": "the maximum src sequence length"}
    )
    max_tgt_length: int = field(
        default=30, metadata={"help": "the maximum target sequence length"}
    )

    code_dict_size: int = field(
        default=8192, metadata={"help": "code dict size"}
    )
    patch_image_size: int = field(
        default=480, metadata={"help": "patch image size"}
    )
    orig_patch_image_size: int = field(
        default=256, metadata={"help": "original patch image size"}
    )
    num_bins: int = field(
        default=1000, metadata={"help": "number of quantization bins"}
    )

    imagenet_default_mean_and_std: bool = field(
        default=False,
        metadata={"help": "imagenet normalize"},
    )
    constraint_range: Optional[str] = field(
        default=None,
        metadata={"help": "constraint range"}
    )
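
# Added note (illustrative, not part of the original file): fairseq generates
# command-line options from these dataclass fields, so e.g. `patch_image_size`
# is typically exposed as `--patch-image-size` and `num_bins` as `--num-bins`
# when the task is driven from the fairseq CLI or hydra configs.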

class OFATask(FairseqTask):
    def __init__(self, cfg: OFAConfig, src_dict, tgt_dict):
        super().__init__(cfg)
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict

    @classmethod
    def setup_task(cls, cfg: DictConfig, **kwargs):
        """Setup the task."""

        # load dictionaries
        src_dict = cls.load_dictionary(
            os.path.join(cfg.bpe_dir, "dict.txt")
        )
        tgt_dict = cls.load_dictionary(
            os.path.join(cfg.bpe_dir, "dict.txt")
        )
        src_dict.add_symbol("<mask>")
        tgt_dict.add_symbol("<mask>")
        for i in range(cfg.code_dict_size):
            src_dict.add_symbol("<code_{}>".format(i))
            tgt_dict.add_symbol("<code_{}>".format(i))
        # quantization
        for i in range(cfg.num_bins):
            src_dict.add_symbol("<bin_{}>".format(i))
            tgt_dict.add_symbol("<bin_{}>".format(i))

        logger.info("source dictionary: {} types".format(len(src_dict)))
        logger.info("target dictionary: {} types".format(len(tgt_dict)))
        return cls(cfg, src_dict, tgt_dict)
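
    # Added note (illustrative): after setup_task the shared dictionary holds,
    # in order, the base BPE vocabulary, "<mask>", cfg.code_dict_size "<code_k>"
    # symbols used for discrete image codes, and cfg.num_bins "<bin_k>" symbols.
    # In OFA the "<bin_k>" symbols express quantized coordinates (e.g. bounding
    # boxes) as ordinary target tokens; for example (hypothetical values), a
    # normalized coordinate x in [0, 1] would map to
    # "<bin_{}>".format(int(x * (cfg.num_bins - 1))).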

    def get_batch_iterator(
        self,
        dataset,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
        data_buffer_size=0,
        disable_iterator_cache=False,
    ):
        assert isinstance(dataset, FairseqDataset)

        # initialize the dataset with the correct starting epoch
        dataset.set_epoch(epoch)

        # create mini-batches with given size constraints
        batch_sampler = [
            [j for j in range(i, min(i + max_sentences, len(dataset)))]
            for i in range(0, len(dataset), max_sentences)
        ]
        total_row_count = dataset.dataset.get_total_row_count()
        num_batches = math.ceil(math.ceil(total_row_count / num_shards) / max_sentences)
        if len(batch_sampler) < num_batches:
            batch_sampler.append([])

        # return a reusable, sharded iterator
        epoch_iter = iterators.EpochBatchIterator(
            dataset=dataset,
            collate_fn=dataset.collater,
            batch_sampler=batch_sampler,
            seed=seed,
            num_shards=1,
            shard_id=0,
            num_workers=num_workers,
            epoch=epoch,
            buffer_size=data_buffer_size
        )

        return epoch_iter
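
    # Added note (illustrative): sharding is handled by the wrapped dataset
    # itself (each worker only sees its own rows), which is why the
    # EpochBatchIterator is built with num_shards=1 and shard_id=0. The empty
    # batch appended above appears to equalize the batch count across shards so
    # distributed workers stay in sync at the end of an epoch.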

    def build_model(self, cfg: FairseqDataclass):
        model = super().build_model(cfg)
        if self.cfg.bpe == 'bert':
            bpe_dict = {
                "_name": "bert",
                "bpe_vocab_file": os.path.join(self.cfg.bpe_dir, "vocab.txt"),
                "bpe_cased": False
            }
            bpe_dict = DictConfig(bpe_dict)
            self.bpe = self.build_bpe(bpe_dict)
        else:
            bpe_dict = {
                "_name": "gpt2",
                "gpt2_encoder_json": os.path.join(self.cfg.bpe_dir, "encoder.json"),
                "gpt2_vocab_bpe": os.path.join(self.cfg.bpe_dir, "vocab.bpe")
            }
            bpe_dict = DictConfig(bpe_dict)
            self.bpe = self.build_bpe(bpe_dict)
        return model
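
    # Added note (illustrative): downstream datasets typically use self.bpe to
    # turn raw text into BPE tokens before binarizing with the dictionary,
    # e.g. something like self.tgt_dict.encode_line(self.bpe.encode(" a caption"));
    # the exact call pattern depends on the dataset implementation.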

    def build_generator(
        self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None, prefix_allowed_tokens_fn=None,
    ):
        """
        Build a :class:`~fairseq.SequenceGenerator` instance for this
        task.

        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models
            args (fairseq.dataclass.configs.GenerationConfig):
                configuration object (dataclass) for generation
            extra_gen_cls_kwargs (Dict[str, Any]): extra options to pass
                through to SequenceGenerator
            prefix_allowed_tokens_fn (Callable[[int, torch.Tensor], List[int]]):
                If provided, this function constrains the beam search to
                allowed tokens only at each step. The provided function
                should take 2 arguments: the batch ID (`batch_id: int`)
                and a unidimensional tensor of token ids (`inputs_ids:
                torch.Tensor`). It has to return a `List[int]` with the
                allowed tokens for the next generation step conditioned
                on the previously generated tokens (`inputs_ids`) and
                the batch ID (`batch_id`). This argument is useful for
                constrained generation conditioned on the prefix, as
                described in "Autoregressive Entity Retrieval"
                (https://arxiv.org/abs/2010.00904) and
                https://github.com/facebookresearch/GENRE.
        """
        if getattr(args, "score_reference", False):
            from fairseq.sequence_scorer import SequenceScorer

            return SequenceScorer(
                self.target_dictionary,
                compute_alignment=getattr(args, "print_alignment", False),
            )

        from fairseq.sequence_generator import (
            # SequenceGenerator,
            SequenceGeneratorWithAlignment,
        )
        from models.sequence_generator import SequenceGenerator

        # Choose search strategy. Defaults to Beam Search.
        sampling = getattr(args, "sampling", False)
        sampling_topk = getattr(args, "sampling_topk", -1)
        sampling_topp = getattr(args, "sampling_topp", -1.0)
        diverse_beam_groups = getattr(args, "diverse_beam_groups", -1)
        diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5)
        match_source_len = getattr(args, "match_source_len", False)
        diversity_rate = getattr(args, "diversity_rate", -1)
        constrained = getattr(args, "constraints", False)
        if prefix_allowed_tokens_fn is None:
            prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None)
        if (
            sum(
                int(cond)
                for cond in [
                    sampling,
                    diverse_beam_groups > 0,
                    match_source_len,
                    diversity_rate > 0,
                ]
            )
            > 1
        ):
            raise ValueError("Provided Search parameters are mutually exclusive.")
        assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling"
        assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling"

        if sampling:
            search_strategy = search.Sampling(
                self.target_dictionary, sampling_topk, sampling_topp
            )
        elif diverse_beam_groups > 0:
            search_strategy = search.DiverseBeamSearch(
                self.target_dictionary, diverse_beam_groups, diverse_beam_strength
            )
        elif match_source_len:
            # this is useful for tagging applications where the output
            # length should match the input length, so we hardcode the
            # length constraints for simplicity
            search_strategy = search.LengthConstrainedBeamSearch(
                self.target_dictionary,
                min_len_a=1,
                min_len_b=0,
                max_len_a=1,
                max_len_b=0,
            )
        elif diversity_rate > -1:
            search_strategy = search.DiverseSiblingsSearch(
                self.target_dictionary, diversity_rate
            )
        elif constrained:
            search_strategy = search.LexicallyConstrainedBeamSearch(
                self.target_dictionary, args.constraints
            )
        elif prefix_allowed_tokens_fn:
            search_strategy = search.PrefixConstrainedBeamSearch(
                self.target_dictionary, prefix_allowed_tokens_fn
            )
        else:
            search_strategy = search.BeamSearch(self.target_dictionary)

        extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}
        if seq_gen_cls is None:
            if getattr(args, "print_alignment", False):
                seq_gen_cls = SequenceGeneratorWithAlignment
                extra_gen_cls_kwargs["print_alignment"] = args.print_alignment
            else:
                seq_gen_cls = SequenceGenerator

        return seq_gen_cls(
            models,
            self.target_dictionary,
            beam_size=getattr(args, "beam", 5),
            max_len_a=getattr(args, "max_len_a", 0),
            max_len_b=getattr(args, "max_len_b", 200),
            min_len=getattr(args, "min_len", 1),
            normalize_scores=(not getattr(args, "unnormalized", False)),
            len_penalty=getattr(args, "lenpen", 1),
            unk_penalty=getattr(args, "unkpen", 0),
            temperature=getattr(args, "temperature", 1.0),
            match_source_len=getattr(args, "match_source_len", False),
            no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
            search_strategy=search_strategy,
            constraint_range=self.cfg.constraint_range,
            **extra_gen_cls_kwargs,
        )
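
    # Added note (illustrative): `constraint_range` is forwarded to the custom
    # SequenceGenerator in models/sequence_generator.py; in OFA it is usually a
    # string of the form "start,end" naming a contiguous slice of vocabulary
    # indices that generation is restricted to.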

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False, **extra_kwargs
    ):
        """
        Do forward and backward, and return the loss as computed by *criterion*
        for the given *model* and *sample*.

        Args:
            sample (dict): the mini-batch. The format is defined by the
                :class:`~fairseq.data.FairseqDataset`.
            model (~fairseq.models.BaseFairseqModel): the model
            criterion (~fairseq.criterions.FairseqCriterion): the criterion
            optimizer (~fairseq.optim.FairseqOptimizer): the optimizer
            update_num (int): the current update
            ignore_grad (bool): multiply loss by 0 if this is set to True

        Returns:
            tuple:
                - the loss
                - the sample size, which is used as the denominator for the
                  gradient
                - logging outputs to display while training
        """
        model.train()
        model.set_num_updates(update_num)
        with torch.autograd.profiler.record_function("forward"):
            with torch.cuda.amp.autocast(enabled=(isinstance(optimizer, AMPOptimizer))):
                loss, sample_size, logging_output = criterion(model, sample, update_num=update_num)
        if ignore_grad:
            loss *= 0
        with torch.autograd.profiler.record_function("backward"):
            optimizer.backward(loss)
        return loss, sample_size, logging_output
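
    # Added note (illustrative): autocast is enabled only when fairseq's
    # AMPOptimizer is in use, so full-precision training is unaffected.
    # `ignore_grad` zeroes the loss so fairseq can run a dummy forward/backward
    # (e.g. when recovering from an out-of-memory batch) without updating weights.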

    def max_positions(self):
        """Return the max sentence length allowed by the task."""
        return (self.cfg.max_source_positions, self.cfg.max_target_positions)

    @property
    def source_dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary`."""
        return self.src_dict

    @property
    def target_dictionary(self):
        """Return the target :class:`~fairseq.data.Dictionary`."""
        return self.tgt_dict
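
# Added sketch (illustrative only, not part of the original file): OFATask is a
# base class; concrete tasks typically subclass it, register themselves with
# fairseq, and implement load_dataset. The names below ("my_task", MyTask) are
# hypothetical placeholders.
#
# @register_task("my_task", dataclass=OFAConfig)
# class MyTask(OFATask):
#     def load_dataset(self, split, epoch=1, combine=False, **kwargs):
#         # build a FairseqDataset for `split` from self.cfg.data
#         ...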