# Copyright (c) 2023 Wenet Community. (authors: Dinghao Zhou)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from collections.abc import Callable
import copy
import sys
import tarfile
import logging
from typing import List, Optional

import numpy as np
import torch
from torch.utils.data import IterDataPipe, functional_datapipe
from torch.utils.data import datapipes
from torch.utils.data.datapipes.iter import Mapper
from torch.utils.data.datapipes.iter.sharding import (
    SHARDING_PRIORITIES, ShardingFilterIterDataPipe)
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn

from wenet.dataset.processor import parse_url

@functional_datapipe("map_ignore_error")
class MapperIgnoreErrorDataPipe(Mapper):

    def __init__(self,
                 dataset: IterDataPipe,
                 fn: Callable,
                 input_col=None,
                 output_col=None,
                 log_error: bool = True) -> None:
        super().__init__(dataset, fn, input_col, output_col)
        self._iter = None
        self.log_error = log_error

    def __iter__(self):
        if self._iter is None:
            self._iter = iter(self.datapipe)

        while True:
            try:
                elem = next(self._iter)
                yield self._apply_fn(elem)
            except StopIteration:
                self._iter = None
                return
            except Exception as ex:
                if self.log_error:
                    logging.warning(str(ex))
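
# Usage sketch (added; not part of the upstream file): once registered as
# "map_ignore_error", the pipe can be chained functionally onto any
# IterDataPipe. Elements whose transform raises are skipped (and optionally
# logged) instead of aborting the epoch, e.g.:
#
#   dp = TextLineDataPipe('data.list').map_ignore_error(parse_url)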

@functional_datapipe("bucket_by_sequence_length")
class BucketBySequenceLengthDataPipe(IterDataPipe):

    def __init__(
        self,
        dataset: IterDataPipe,
        elem_length_func,
        bucket_boundaries: List[int],
        bucket_batch_sizes: List[int],
        wrapper_class=None,
    ) -> None:
        super().__init__()
        _check_unpickable_fn(elem_length_func)
        assert len(bucket_batch_sizes) == len(bucket_boundaries) + 1
        self.bucket_batch_sizes = bucket_batch_sizes
        self.bucket_boundaries = bucket_boundaries + [sys.maxsize]
        self.elem_length_func = elem_length_func

        self._group_dp = GroupByWindowDataPipe(dataset,
                                               self._element_to_bucket_id,
                                               self._window_size_func,
                                               wrapper_class=wrapper_class)

    def __iter__(self):
        yield from self._group_dp

    def _element_to_bucket_id(self, elem):
        seq_len = self.elem_length_func(elem)
        bucket_id = 0
        for (i, b) in enumerate(self.bucket_boundaries):
            if seq_len < b:
                bucket_id = i
                break
        return bucket_id

    def _window_size_func(self, bucket_id):
        return self.bucket_batch_sizes[bucket_id]
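
# Usage sketch (added; the 'feat' key is an assumption about the element
# layout): with boundaries [1000, 2000] and batch sizes [32, 16, 8],
# utterances shorter than 1000 frames are grouped 32 per batch, those in
# [1000, 2000) 16 per batch, and the rest 8 per batch:
#
#   dp = dp.bucket_by_sequence_length(
#       elem_length_func=lambda sample: sample['feat'].size(0),
#       bucket_boundaries=[1000, 2000],
#       bucket_batch_sizes=[32, 16, 8],
#       wrapper_class=list)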

@functional_datapipe("group_by_window")
class GroupByWindowDataPipe(datapipes.iter.Grouper):

    def __init__(
        self,
        dataset: IterDataPipe,
        key_func,
        window_size_func,
        wrapper_class=None,
    ):
        super().__init__(dataset,
                         key_func,
                         keep_key=False,
                         group_size=None,
                         drop_remaining=False)
        _check_unpickable_fn(window_size_func)
        self.dp = dataset
        self.window_size_func = window_size_func
        if wrapper_class is not None:
            _check_unpickable_fn(wrapper_class)
            del self.wrapper_class
            self.wrapper_class = wrapper_class

    def __iter__(self):
        for x in self.datapipe:
            key = self.group_key_fn(x)

            self.buffer_elements[key].append(x)
            self.curr_buffer_size += 1

            group_size = self.window_size_func(key)
            if group_size == len(self.buffer_elements[key]):
                result = self.wrapper_class(self.buffer_elements[key])
                yield result
                self.curr_buffer_size -= len(self.buffer_elements[key])
                del self.buffer_elements[key]

            if self.curr_buffer_size == self.max_buffer_size:
                result_to_yield = self._remove_biggest_key()
                if result_to_yield is not None:
                    result = self.wrapper_class(result_to_yield)
                    yield result

        for key in tuple(self.buffer_elements.keys()):
            result = self.wrapper_class(self.buffer_elements.pop(key))
            self.curr_buffer_size -= len(result)
            yield result
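
# Note (added for clarity): this pipe backs `bucket_by_sequence_length` above.
# It reuses the Grouper buffering from torch.utils.data, but looks the group
# size up per key via `window_size_func`, so each bucket flushes at its own
# batch size; leftover groups are flushed once the source is exhausted.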

@functional_datapipe("sort")
class SortDataPipe(IterDataPipe):

    def __init__(self,
                 dataset: IterDataPipe,
                 buffer_size: int = 500,
                 key_func=None,
                 reverse=False) -> None:
        if key_func is not None:
            _check_unpickable_fn(key_func)
        self.buffer_size = buffer_size
        super().__init__()
        self.dp = dataset
        self._buffer = []
        self.key_func = key_func
        self.reverse = reverse

    def __iter__(self):
        for elem in self.dp:
            self._buffer.append(elem)
            if len(self._buffer) >= self.buffer_size:
                self._buffer.sort(key=self.key_func, reverse=self.reverse)
                for x in self._buffer:
                    yield x
                del self._buffer
                self._buffer = []
        # The samples left over
        self._buffer.sort(key=self.key_func, reverse=self.reverse)
        for x in self._buffer:
            yield x
        del self._buffer
        self._buffer = []
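
# Usage sketch (added; the 'feat' key is an assumption about the element
# layout): sorting a bounded window by feature length keeps nearby utterances
# similar in length, which reduces padding after batching:
#
#   dp = dp.sort(buffer_size=500,
#                key_func=lambda sample: sample['feat'].size(0))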

@functional_datapipe("dynamic_batch")
class DynamicBatchDataPipe(IterDataPipe):

    def __init__(self, dataset: IterDataPipe, window_class,
                 wrapper_class) -> None:
        _check_unpickable_fn(window_class)
        _check_unpickable_fn(wrapper_class)
        super().__init__()
        self.dp = dataset
        assert window_class is not None
        assert wrapper_class is not None
        self.window_class = window_class
        self._buffer = []
        self._wrappr_class = wrapper_class

    def __iter__(self):
        for elem in self.dp:
            if not self.window_class(elem, len(self._buffer)):
                self._buffer.append(elem)
            else:
                if len(self._buffer) > 0:
                    yield self._wrappr_class(self._buffer)
                del self._buffer
                self._buffer = [elem]
        if len(self._buffer) > 0:
            yield self._wrappr_class(self._buffer)
            del self._buffer
            self._buffer = []
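
# Usage sketch (added; MaxFramesWindow is a hypothetical helper, not part of
# this file): `window_class(elem, buffer_len)` should return True when `elem`
# no longer fits the current batch, which flushes the buffer and starts a new
# one with `elem`:
#
#   class MaxFramesWindow:
#       def __init__(self, max_frames=12000):
#           self.max_frames, self.acc = max_frames, 0
#       def __call__(self, sample, buffer_size):
#           total = self.acc + sample['feat'].size(0)
#           if total > self.max_frames and buffer_size > 0:
#               self.acc = sample['feat'].size(0)
#               return True
#           self.acc = total
#           return False
#
#   dp = dp.dynamic_batch(window_class=MaxFramesWindow(), wrapper_class=list)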

@functional_datapipe("prefetch")
class PrefetchDataPipe(IterDataPipe):
    """Performs prefetching"""

    def __init__(
        self,
        dataset: IterDataPipe,
        buffer_size: int = 500,
    ):
        # TODO(Mddct): support multiprocessing pool with shared-memory to
        # prefetch
        super().__init__()
        self.dp = dataset
        self._iter = None
        self._prefetch_buffer_size = buffer_size
        self._buffer = None
        if self._prefetch_buffer_size > 0:
            self._buffer = collections.deque(maxlen=self._prefetch_buffer_size)

    def __iter__(self):
        if self._prefetch_buffer_size > 0:
            if self._iter is None:
                self._iter = iter(self.dp)
            assert self._buffer is not None

            while True:
                if len(self._buffer) <= self._prefetch_buffer_size // 2:
                    while len(self._buffer) < self._prefetch_buffer_size:
                        try:
                            self._buffer.append(next(self._iter))
                        except StopIteration:
                            if len(self._buffer) != 0:
                                while len(self._buffer) > 0:
                                    yield self._buffer.popleft()
                            self._iter = None
                            return
                while len(self._buffer) > self._prefetch_buffer_size // 2:
                    elem = self._buffer.popleft()
                    yield elem
        else:
            yield from self.dp
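
# Note (added for clarity): prefetching here is single-threaded; the deque is
# refilled in bursts whenever it drains below half of `buffer_size`, which
# mainly smooths out bursty upstream pipes:
#
#   dp = dp.prefetch(buffer_size=500)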

@functional_datapipe("repeat")
class RepeatDatapipe(IterDataPipe):

    def __init__(self, dataset: IterDataPipe, count: int = -1):
        super().__init__()
        self.dp = dataset
        self.count = count

    def __iter__(self):
        if self.count == 1:
            yield from self.dp
            return
        i = 0
        while self.count < 0 or i < self.count:
            for elem in self.dp:
                new_elem = copy.copy(elem)
                yield new_elem
            i += 1
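
# Usage sketch (added): `count=-1` (the default) repeats the source forever,
# which suits step-based training loops; `count=1` is a plain pass-through:
#
#   dp = dp.repeat(2)  # two passes over the source pipe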

@functional_datapipe("shard")
class ShardDataPipe(ShardingFilterIterDataPipe):

    def __init__(self, dataset: IterDataPipe, partition: bool = False):
        super().__init__(dataset, None)
        self.partition = partition
        self.dp = dataset

    def apply_sharding(self, num_of_instances: int, instance_id: int,
                       sharding_group: SHARDING_PRIORITIES):
        if self.partition:
            return super().apply_sharding(num_of_instances, instance_id,
                                          sharding_group)
        else:
            # We can not handle uneven data for CV on DDP, so we don't
            # sample data by rank, that means every GPU gets the same
            # and all the CV data
            info = torch.utils.data.get_worker_info()
            if info is None:
                self.num_of_instances = 1
                self.instance_id = 0
            else:
                n_workers_per_device = info.num_workers
                self.num_of_instances = n_workers_per_device
                self.instance_id = info.id
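
# Note (added for clarity): with `partition=True` the stock sharding filter
# splits data across DDP ranks and dataloader workers; with `partition=False`
# only the workers inside one process are deduplicated, so every rank sees the
# full (CV) data, as the comment above explains:
#
#   dp = dp.shard(partition=True)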

@functional_datapipe("interleave")
class InterlaveDataPipe(IterDataPipe):

    def __init__(
        self,
        source_datapipes: List[IterDataPipe],
        weights: Optional[List[float]] = None,
        seed=2027,
    ):
        super().__init__()
        self.rng = np.random.default_rng(seed)
        self.source_datapipes = source_datapipes
        self.weights = weights
        if weights is None:
            self.weights = [1 / len(self.source_datapipes)] * len(
                self.source_datapipes)
        else:
            self.weights = [weight / sum(weights) for weight in weights]
        self.iters = None

    def __iter__(self):
        weights = copy.deepcopy(self.weights)
        exhausted = len(self.source_datapipes) * [False]
        if self.iters is None:
            self.iters = [(i, iter(d))
                          for i, d in enumerate(self.source_datapipes)]
        while True:
            # TODO(Mddct): rng
            index_iter = self.rng.choice(self.iters, p=weights)
            i, ite = index_iter
            try:
                elem = next(ite)
                yield elem
            except StopIteration:
                weights[i] = 0.
                exhausted[i] = True
                if all(exhausted):
                    return
                weights = [weight / sum(weights) for weight in weights]
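
# Usage sketch (added): interleave two corpora with 3:1 sampling; weights are
# renormalized as individual sources run out:
#
#   mixed = InterlaveDataPipe([dp_a, dp_b], weights=[0.75, 0.25])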

class TextLineDataPipe(IterDataPipe):
    """ Streaming text line
    """

    def __init__(self, filenames, mode='r'):
        super().__init__()
        _dp = datapipes.iter.FileLister(filenames)
        _dp = datapipes.iter.FileOpener(_dp, mode=mode)
        self.dp = _dp

    def __iter__(self):
        for fname, stream in self.dp:
            for line in stream:
                line = line.strip('\n')
                yield {"file_name": fname, "line": line}
            stream.close()
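
# Usage sketch (added): each yielded dict pairs the list file name with one
# stripped line of that file, e.g. for a WeNet data.list:
#
#   for sample in TextLineDataPipe('data.list'):
#       print(sample['file_name'], sample['line'])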

@functional_datapipe("tar_file_and_group")
class TarsDataPipe(IterDataPipe):
    """ Decode wenet's tar, yield {'txt': "...", "wav": "..."}
    """

    def __init__(self, dataset: IterDataPipe) -> None:
        super().__init__()
        self.dp = dataset

    def __iter__(self):
        from wenet.dataset.processor import AUDIO_FORMAT_SETS
        for sample in self.dp:
            assert 'file_name' in sample
            assert 'line' in sample
            assert 'stream' in sample
            try:
                with tarfile.open(fileobj=sample['stream'],
                                  mode="r:*") as stream:
                    prev_prefix = None
                    example = {
                        'file_name': sample['file_name'],
                        'tar_file_name': sample['line']
                    }
                    valid = True
                    for tarinfo in stream:
                        name = tarinfo.name
                        pos = name.rfind('.')
                        assert pos > 0
                        prefix, postfix = name[:pos], name[pos + 1:]
                        if prev_prefix is not None and prefix != prev_prefix:
                            example['key'] = prev_prefix
                            if valid:
                                yield example
                            example = {
                                'file_name': sample['file_name'],
                                'tar_file_name': sample['line']
                            }
                            valid = True
                        with stream.extractfile(tarinfo) as file_obj:
                            try:
                                if postfix == 'txt':
                                    example['txt'] = file_obj.read().decode(
                                        'utf8').strip()
                                elif postfix in AUDIO_FORMAT_SETS:
                                    example['wav'] = file_obj.read()
                                else:
                                    example[postfix] = file_obj.read()
                            except Exception as ex:
                                valid = False
                                logging.warning(
                                    'error to parse {}'.format(name))
                        prev_prefix = prefix
                    if prev_prefix is not None:
                        example['key'] = prev_prefix
                        yield example
            except Exception as ex:
                msg = 'In tar_file_and_group: {} when processing {}'.format(
                    ex, sample['line'])
                logging.warning(msg)
            finally:
                if 'process' in sample:
                    sample['process'].communicate()
                sample['stream'].close()
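
# Note (added for clarity): members of each tar are expected to be grouped by
# utterance key, e.g. `utt1.wav` + `utt1.txt`; entries sharing a prefix are
# merged into one example with 'key', 'wav' and 'txt' fields. The upstream
# pipe must have opened the archive (e.g. via `parse_url`) and stored the file
# object under 'stream'.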

class WenetRawDatasetSource(IterDataPipe):

    def __init__(self,
                 filenames: str,
                 prefetch: int = 500,
                 partition: bool = True,
                 shuffle: bool = False,
                 shuffle_size: int = 10000,
                 cycle: int = 1) -> None:
        super().__init__()
        self.dp = TextLineDataPipe(filenames)
        if shuffle:
            self.dp = self.dp.shuffle(buffer_size=shuffle_size)
        self.dp = self.dp.repeat(cycle).prefetch(prefetch)
        self.dp = self.dp.shard(partition)

    def __iter__(self):
        for d in self.dp:
            yield d
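
# Usage sketch (added): "raw" mode streams a data.list of per-utterance lines;
# each element is the dict produced by TextLineDataPipe and is normally fed to
# the parsing/feature pipes in wenet.dataset.processor:
#
#   source = WenetRawDatasetSource('data.list', shuffle=True)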

class WenetTarShardDatasetSource(IterDataPipe):

    def __init__(self,
                 filenames: str,
                 prefetch: int = 500,
                 partition: bool = True,
                 shuffle: bool = False,
                 shuffle_size: int = 10000,
                 cycle: int = 1) -> None:
        super().__init__()
        self.dp = TextLineDataPipe(filenames)
        if shuffle:
            self.dp = self.dp.shuffle(buffer_size=shuffle_size)
        self.dp = self.dp.repeat(cycle)
        self.dp = self.dp.shard(partition).map_ignore_error(
            parse_url).tar_file_and_group().prefetch(prefetch)

    def __iter__(self):
        for d in self.dp:
            yield d
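
# Usage sketch (added): "shard" mode reads the tar shards listed in data.list,
# opens each one via parse_url and groups tar members into per-utterance dicts:
#
#   source = WenetTarShardDatasetSource('shards.list', partition=True)
#   for sample in source:
#       audio_bytes, text = sample['wav'], sample['txt']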