# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os

import numpy as np

from fairseq.data import FairseqDataset

from . import data_utils
from .collaters import Seq2SeqCollater


class AsrDataset(FairseqDataset):
    """
    A dataset representing speech and corresponding transcription.

    Args:
        aud_paths (List[str]): A list of str with paths to audio files.
        aud_durations_ms (List[int]): A list of int containing the durations of
            the audio files, in milliseconds.
        tgt (List[torch.LongTensor]): A list of LongTensors containing the indices
            of target transcriptions.
        tgt_dict (~fairseq.data.Dictionary): target vocabulary.
        ids (List[str]): A list of utterance IDs.
        speakers (List[str]): A list of speakers corresponding to utterances.
        num_mel_bins (int): Number of triangular mel-frequency bins (default: 80).
        frame_length (float): Frame length in milliseconds (default: 25.0).
        frame_shift (float): Frame shift in milliseconds (default: 10.0).
    """

    def __init__(
        self,
        aud_paths,
        aud_durations_ms,
        tgt,
        tgt_dict,
        ids,
        speakers,
        num_mel_bins=80,
        frame_length=25.0,
        frame_shift=10.0,
    ):
        assert frame_length > 0
        assert frame_shift > 0
        assert all(x > frame_length for x in aud_durations_ms)
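        # Standard framing arithmetic: a window of frame_length ms advanced by
        # frame_shift ms over a d ms utterance yields
        # 1 + floor((d - frame_length) / frame_shift) frames.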
        self.frame_sizes = [
            int(1 + (d - frame_length) / frame_shift) for d in aud_durations_ms
        ]

        assert len(aud_paths) > 0
        assert len(aud_paths) == len(aud_durations_ms)
        assert len(aud_paths) == len(tgt)
        assert len(aud_paths) == len(ids)
        assert len(aud_paths) == len(speakers)
        self.aud_paths = aud_paths
        self.tgt_dict = tgt_dict
        self.tgt = tgt
        self.ids = ids
        self.speakers = speakers
        self.num_mel_bins = num_mel_bins
        self.frame_length = frame_length
        self.frame_shift = frame_shift

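        # The positional arguments (0, 1) are the positions of the features and
        # the target indices inside the "data" list that __getitem__ returns;
        # move_eos_to_beginning lets the collater build the shifted
        # prev_output_tokens used for teacher forcing.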
        self.s2s_collater = Seq2SeqCollater(
            0,
            1,
            pad_index=self.tgt_dict.pad(),
            eos_index=self.tgt_dict.eos(),
            move_eos_to_beginning=True,
        )

    def __getitem__(self, index):
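        # Imported lazily so this module can be imported without torchaudio
        # installed; it is only needed once examples are actually loaded.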
        import torchaudio
        import torchaudio.compliance.kaldi as kaldi
        tgt_item = self.tgt[index] if self.tgt is not None else None

        path = self.aud_paths[index]
        if not os.path.exists(path):
            raise FileNotFoundError("Audio file not found: {}".format(path))
        sound, sample_rate = torchaudio.load_wav(path)
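        # load_wav returns unnormalized (int16-range) float samples, which is
        # what the Kaldi-compatible fbank below expects. Note that load_wav was
        # removed from newer torchaudio releases; there, something like
        # torchaudio.load(path, normalize=False) plus a cast to float is the
        # closest replacement. sample_rate is unused because kaldi.fbank falls
        # back to its 16 kHz default.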
        output = kaldi.fbank(
            sound,
            num_mel_bins=self.num_mel_bins,
            frame_length=self.frame_length,
            frame_shift=self.frame_shift,
        )
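        # Per-utterance mean/variance normalization of the filterbank features.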
        output_cmvn = data_utils.apply_mv_norm(output)

        return {"id": index, "data": [output_cmvn.detach(), tgt_item]}

    def __len__(self):
        return len(self.aud_paths)

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate, as returned by
                ``__getitem__``

        Returns:
            dict: a mini-batch suitable for forwarding with a Model
        """
        return self.s2s_collater.collate(samples)

    def num_tokens(self, index):
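        # Size in frames; fairseq uses this to enforce --max-tokens per batch.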
        return self.frame_sizes[index]

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return (
            self.frame_sizes[index],
            len(self.tgt[index]) if self.tgt is not None else 0,
        )

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
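        # Natural dataset order; no length-based sorting is applied here.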
        return np.arange(len(self))
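
# A minimal usage sketch, not part of the original file. It assumes a fairseq
# Dictionary built elsewhere, 16 kHz WAV files on disk, and targets encoded as
# LongTensors; every path, ID, and duration below is a placeholder.
#
#     from fairseq.data import Dictionary
#
#     tgt_dict = Dictionary.load("dict.txt")      # hypothetical vocab file
#     aud_paths = ["utt1.wav", "utt2.wav"]        # hypothetical audio files
#     aud_durations_ms = [3000, 4500]             # durations in milliseconds
#     tgt = [
#         tgt_dict.encode_line("a b c").long(),
#         tgt_dict.encode_line("d e").long(),
#     ]
#     dataset = AsrDataset(
#         aud_paths, aud_durations_ms, tgt, tgt_dict,
#         ids=["utt1", "utt2"], speakers=["spk1", "spk2"],
#     )
#     batch = dataset.collater([dataset[i] for i in dataset.ordered_indices()])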