# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
| """ | |
| This module contains collection of classes which implement | |
| collate functionalities for various tasks. | |
| Collaters should know what data to expect for each sample | |
| and they should pack / collate them into batches | |
| """ | |

from __future__ import absolute_import, division, print_function, unicode_literals

import numpy as np
import torch

from fairseq.data import data_utils as fairseq_data_utils


class Seq2SeqCollater(object):
    """
    Implements a collate function, mainly for seq2seq tasks.

    It expects each sample to contain features (src_tokens) and targets.
    This collater is also used for the aligned training task.
    """

    def __init__(
        self,
        feature_index=0,
        label_index=1,
        pad_index=1,  # matches fairseq's Dictionary default pad index
        eos_index=2,  # matches fairseq's Dictionary default eos index
        move_eos_to_beginning=True,
    ):
        self.feature_index = feature_index
        self.label_index = label_index
        self.pad_index = pad_index
        self.eos_index = eos_index
        self.move_eos_to_beginning = move_eos_to_beginning

    def _collate_frames(self, frames):
        """Convert a list of 2d frames into a padded 3d tensor.

        Args:
            frames (list): list of 2d frames of size L[i]*f_dim, where L[i] is
                the length of the i-th frame and f_dim is the static feature
                dimension.

        Returns:
            3d tensor of size len(frames)*len_max*f_dim, where len_max is the
            maximum of L[i].
        """
        len_max = max(frame.size(0) for frame in frames)
        f_dim = frames[0].size(1)
        res = frames[0].new(len(frames), len_max, f_dim).fill_(0.0)
        for i, v in enumerate(frames):
            res[i, : v.size(0)] = v
        return res
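
    # Shape illustration (feature dim of 80 is an assumption, not from the
    # original file): padding frames of shape (3, 80) and (5, 80) yields a
    # (2, 5, 80) tensor in which res[0, 3:] remains zero.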

    def collate(self, samples):
        """
        Utility function to collate samples into a batch for speech
        recognition.
        """
        if len(samples) == 0:
            return {}

        # parse samples into torch tensors
        parsed_samples = []
        for s in samples:
            # skip invalid samples
            if s["data"][self.feature_index] is None:
                continue
            source = s["data"][self.feature_index]
            if isinstance(source, (np.ndarray, np.generic)):
                source = torch.from_numpy(source)
            target = s["data"][self.label_index]
            if isinstance(target, (np.ndarray, np.generic)):
                target = torch.from_numpy(target).long()
            elif isinstance(target, list):
                target = torch.LongTensor(target)

            parsed_sample = {"id": s["id"], "source": source, "target": target}
            parsed_samples.append(parsed_sample)
        samples = parsed_samples

        id = torch.LongTensor([s["id"] for s in samples])
        frames = self._collate_frames([s["source"] for s in samples])
        # sort samples by descending number of frames
        frames_lengths = torch.LongTensor([s["source"].size(0) for s in samples])
        frames_lengths, sort_order = frames_lengths.sort(descending=True)
        id = id.index_select(0, sort_order)
        frames = frames.index_select(0, sort_order)

        target = None
        target_lengths = None
        prev_output_tokens = None
        if samples[0].get("target", None) is not None:
            ntokens = sum(len(s["target"]) for s in samples)
            target = fairseq_data_utils.collate_tokens(
                [s["target"] for s in samples],
                self.pad_index,
                self.eos_index,
                left_pad=False,
                move_eos_to_beginning=False,
            )
            target = target.index_select(0, sort_order)
            target_lengths = torch.LongTensor(
                [s["target"].size(0) for s in samples]
            ).index_select(0, sort_order)
            # with move_eos_to_beginning=True this produces the target shifted
            # right (eos first), i.e. the decoder input for teacher forcing
            prev_output_tokens = fairseq_data_utils.collate_tokens(
                [s["target"] for s in samples],
                self.pad_index,
                self.eos_index,
                left_pad=False,
                move_eos_to_beginning=self.move_eos_to_beginning,
            )
            prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
        else:
            ntokens = sum(len(s["source"]) for s in samples)

        batch = {
            "id": id,
            "ntokens": ntokens,
            "net_input": {"src_tokens": frames, "src_lengths": frames_lengths},
            "target": target,
            "target_lengths": target_lengths,
            "nsentences": len(samples),
        }
        if prev_output_tokens is not None:
            batch["net_input"]["prev_output_tokens"] = prev_output_tokens
        return batch
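

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): collates two hypothetical
# samples with random 40-dim feature frames and short integer target
# sequences. The feature dimension, token ids, and sample ids are arbitrary
# assumptions for illustration only; targets end with eos_index (2) as
# fairseq's collate_tokens expects when moving eos to the beginning.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    collater = Seq2SeqCollater(feature_index=0, label_index=1)
    samples = [
        {"id": 0, "data": [np.random.randn(7, 40).astype(np.float32), [4, 5, 2]]},
        {"id": 1, "data": [np.random.randn(3, 40).astype(np.float32), [6, 2]]},
    ]
    batch = collater.collate(samples)
    # samples are sorted by descending frame count, so sample 0 comes first
    print(batch["net_input"]["src_tokens"].shape)  # torch.Size([2, 7, 40])
    print(batch["net_input"]["src_lengths"])  # tensor([7, 3])
    print(batch["target"])  # targets right-padded with pad_index (1)
    print(batch["net_input"]["prev_output_tokens"])  # eos-first decoder input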