Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| # @Author : Xinhao Mei @CVSSP, University of Surrey | |
| # @E-mail : [email protected] | |
| """ | |
| Implementation of SpecAugment++, | |
| Adapted from Qiuqiang Kong's torchlibrosa: | |
| https://github.com/qiuqiangkong/torchlibrosa/blob/master/torchlibrosa/augmentation.py | |
| """ | |
| import torch | |
| import torch.nn as nn | |
class DropStripes:
    """Zero out randomly placed stripes along one axis of a batched tensor.

    The tensor passed to ``__call__`` is modified in place and returned.
    """

    def __init__(self, dim, drop_width, stripes_num):
        """Drop stripes.

        Args:
            dim: int, dimension along which to drop (2: time, 3: frequency).
            drop_width: int, exclusive upper bound on the width of a stripe;
                widths are drawn uniformly from ``[0, drop_width)``.
            stripes_num: int, how many stripes to drop per sample.
        """
        super().__init__()

        assert dim in [2, 3]  # dim 2: time; dim 3: frequency

        self.dim = dim
        self.drop_width = drop_width
        self.stripes_num = stripes_num

    def __call__(self, input):
        """Drop stripes from every sample of the batch.

        Args:
            input: tensor of shape (batch_size, channels, time_steps, freq_bins).

        Returns:
            The same tensor object, with the drawn stripes set to zero.
        """
        assert input.ndimension() == 4

        batch_size = input.shape[0]
        total_width = input.shape[self.dim]

        for n in range(batch_size):
            # input[n] is a view, so the stripes are written into `input`.
            self.transform_slice(input[n], total_width)

        return input

    def transform_slice(self, e, total_width):
        """Zero ``stripes_num`` random stripes of one sample, in place.

        Args:
            e: tensor of shape (channels, time_steps, freq_bins).
            total_width: int, length of the axis the stripes are drawn on.
        """
        # Clamp the width bound to the axis length: otherwise a drawn width
        # >= total_width makes the upper bound of the start-position draw
        # non-positive and torch.randint raises.
        max_width = min(self.drop_width, total_width)
        if max_width <= 0:
            # Nothing can be dropped (zero width bound or empty axis).
            return

        for _ in range(self.stripes_num):
            distance = torch.randint(low=0, high=max_width, size=(1,)).item()
            bgn = torch.randint(low=0, high=total_width - distance, size=(1,)).item()

            # `e` has no batch axis, so the target axis is `self.dim - 1`.
            index = [slice(None)] * 3
            index[self.dim - 1] = slice(bgn, bgn + distance)
            e[tuple(index)] = 0
class MixStripes:
    """Blend random stripes of each sample with another sample of the batch.

    The tensor passed to ``__call__`` is modified in place and returned.
    """

    def __init__(self, dim, mix_width, stripes_num):
        """Mix stripes.

        Args:
            dim: int, dimension along which to mix (2 or 3).
            mix_width: int, exclusive upper bound on the width of a stripe;
                widths are drawn uniformly from ``[0, mix_width)``.
            stripes_num: int, how many stripes to mix per sample.
        """
        super().__init__()

        assert dim in [2, 3]

        self.dim = dim
        self.mix_width = mix_width
        self.stripes_num = stripes_num

    def __call__(self, input):
        """Mix stripes into every sample of the batch.

        Args:
            input: tensor of shape (batch_size, channel, time_steps, freq_bins).

        Returns:
            The same tensor object, with the drawn stripes replaced by the
            equal-weight average of the sample and its randomly paired sample.
        """
        assert input.ndimension() == 4

        batch_size = input.shape[0]
        total_width = input.shape[self.dim]

        # Indexing with a permutation tensor copies the data, so the partner
        # samples are a snapshot taken before any stripe is written.
        rand_sample = input[torch.randperm(batch_size)]

        for i in range(batch_size):
            self.transform_slice(input[i], rand_sample[i], total_width)

        return input

    def transform_slice(self, input, random_sample, total_width):
        """Average ``stripes_num`` random stripes of one sample with its partner.

        Args:
            input: tensor of shape (channel, time_steps, freq_bins); modified
                in place.
            random_sample: tensor of the same shape providing the other half
                of the mixture.
            total_width: int, length of the axis the stripes are drawn on.
        """
        # Clamp the width bound to the axis length: otherwise a drawn width
        # >= total_width makes the upper bound of the start-position draw
        # non-positive and torch.randint raises.
        max_width = min(self.mix_width, total_width)
        if max_width <= 0:
            # Nothing can be mixed (zero width bound or empty axis).
            return

        for _ in range(self.stripes_num):
            distance = torch.randint(low=0, high=max_width, size=(1,)).item()
            bgn = torch.randint(low=0, high=total_width - distance, size=(1,)).item()

            # The per-sample tensors have no batch axis, hence `self.dim - 1`.
            index = [slice(None)] * 3
            index[self.dim - 1] = slice(bgn, bgn + distance)
            index = tuple(index)

            input[index] = 0.5 * input[index] + 0.5 * random_sample[index]
class CutStripes:
    """Replace random stripes of each sample with those of another sample.

    The tensor passed to ``__call__`` is modified in place and returned.
    """

    def __init__(self, dim, cut_width, stripes_num):
        """Cutting stripes with another randomly selected sample in mini-batch.

        Args:
            dim: int, dimension along which to cut (2 or 3).
            cut_width: int, exclusive upper bound on the width of a stripe;
                widths are drawn uniformly from ``[0, cut_width)``.
            stripes_num: int, how many stripes to cut per sample.
        """
        super().__init__()

        assert dim in [2, 3]

        self.dim = dim
        self.cut_width = cut_width
        self.stripes_num = stripes_num

    def __call__(self, input):
        """Cut stripes into every sample of the batch.

        Args:
            input: tensor of shape (batch_size, channel, time_steps, freq_bins).

        Returns:
            The same tensor object, with the drawn stripes overwritten by the
            matching stripes of a randomly paired sample.
        """
        assert input.ndimension() == 4

        batch_size = input.shape[0]
        total_width = input.shape[self.dim]

        # Indexing with a permutation tensor copies the data, so the donor
        # samples are a snapshot taken before any stripe is written.
        rand_sample = input[torch.randperm(batch_size)]

        for i in range(batch_size):
            self.transform_slice(input[i], rand_sample[i], total_width)

        return input

    def transform_slice(self, input, random_sample, total_width):
        """Overwrite ``stripes_num`` random stripes of one sample with its donor's.

        Args:
            input: tensor of shape (channel, time_steps, freq_bins); modified
                in place.
            random_sample: tensor of the same shape the stripes are copied from.
            total_width: int, length of the axis the stripes are drawn on.
        """
        # Clamp the width bound to the axis length: otherwise a drawn width
        # >= total_width makes the upper bound of the start-position draw
        # non-positive and torch.randint raises.
        max_width = min(self.cut_width, total_width)
        if max_width <= 0:
            # Nothing can be cut (zero width bound or empty axis).
            return

        for _ in range(self.stripes_num):
            distance = torch.randint(low=0, high=max_width, size=(1,)).item()
            bgn = torch.randint(low=0, high=total_width - distance, size=(1,)).item()

            # The per-sample tensors have no batch axis, hence `self.dim - 1`.
            index = [slice(None)] * 3
            index[self.dim - 1] = slice(bgn, bgn + distance)
            index = tuple(index)

            input[index] = random_sample[index]
class SpecAugmentation:
    """Time-axis then frequency-axis stripe augmentation (SpecAugment / SpecAugment++)."""

    def __init__(self, time_drop_width, time_stripes_num, freq_drop_width, freq_stripes_num,
                 mask_type='mixture'):
        """Spec augmentation and SpecAugment++.

        [ref] Park, D.S., Chan, W., Zhang, Y., Chiu, C.C., Zoph, B., Cubuk, E.D.
        and Le, Q.V., 2019. Specaugment: A simple data augmentation method
        for automatic speech recognition. arXiv preprint arXiv:1904.08779.

        [ref] Wang H, Zou Y, Wang W., 2021. SpecAugment++: A Hidden Space
        Data Augmentation Method for Acoustic Scene Classification. arXiv
        preprint arXiv:2103.16858.

        Args:
            time_drop_width: int, width argument of the time-axis (dim 2) augmenter
            time_stripes_num: int, stripe count of the time-axis augmenter
            freq_drop_width: int, width argument of the frequency-axis (dim 3) augmenter
            freq_stripes_num: int, stripe count of the frequency-axis augmenter
            mask_type: str, mask type in SpecAugment++ (zero_value, mixture, cutting)

        Raises:
            NameError: if mask_type is not one of the three supported values.
        """
        super(SpecAugmentation, self).__init__()

        # Reject unknown mask types before touching any augmenter class.
        if mask_type not in ('zero_value', 'mixture', 'cutting'):
            raise NameError('No such mask type in SpecAugment++')

        # Each stripe class names its width argument differently, so select
        # the class together with the keyword it expects.
        if mask_type == 'zero_value':
            stripes_cls, width_kw = DropStripes, 'drop_width'
        elif mask_type == 'mixture':
            stripes_cls, width_kw = MixStripes, 'mix_width'
        else:
            stripes_cls, width_kw = CutStripes, 'cut_width'

        self.time_augmentator = stripes_cls(
            dim=2, stripes_num=time_stripes_num, **{width_kw: time_drop_width})
        self.freq_augmentator = stripes_cls(
            dim=3, stripes_num=freq_stripes_num, **{width_kw: freq_drop_width})

    def __call__(self, inputs):
        """Apply the time augmenter, then the frequency augmenter.

        Args:
            inputs: tensor of size [batch_size, channel, time_steps, freq_bins].

        Returns:
            Whatever the frequency augmenter returns for the time-augmented input.
        """
        time_augmented = self.time_augmentator(inputs)
        return self.freq_augmentator(time_augmented)