Spaces:
Runtime error
Runtime error
| from torch import nn | |
| from TTS.vocoder.models.melgan_discriminator import MelganDiscriminator | |
class MelganMultiscaleDiscriminator(nn.Module):
    """Apply several ``MelganDiscriminator`` modules to progressively pooled input.

    ``num_scales`` identically configured ``MelganDiscriminator`` instances are
    created. In ``forward`` the first one sees the input as given, and each
    following one sees the input after one more pass through a shared
    ``nn.AvgPool1d`` layer.

    Args:
        in_channels: Forwarded unchanged to every ``MelganDiscriminator``.
        out_channels: Forwarded unchanged to every ``MelganDiscriminator``.
        num_scales: Number of discriminators to build (one per scale).
        kernel_sizes: Forwarded unchanged to every ``MelganDiscriminator``.
        base_channels: Forwarded unchanged to every ``MelganDiscriminator``.
        max_channels: Forwarded unchanged to every ``MelganDiscriminator``.
        downsample_factors: Forwarded unchanged to every
            ``MelganDiscriminator``.
        pooling_kernel_size: ``kernel_size`` of the shared ``nn.AvgPool1d``.
        pooling_stride: ``stride`` of the shared ``nn.AvgPool1d``.
        pooling_padding: ``padding`` of the shared ``nn.AvgPool1d``.
        groups_denominator: Forwarded unchanged to every
            ``MelganDiscriminator``.
    """

    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        num_scales=3,
        kernel_sizes=(5, 3),
        base_channels=16,
        max_channels=1024,
        downsample_factors=(4, 4, 4),
        pooling_kernel_size=4,
        pooling_stride=2,
        pooling_padding=2,
        groups_denominator=4,
    ):
        super().__init__()

        # One independently parameterised discriminator per scale; all of
        # them share the same hyper-parameters.
        self.discriminators = nn.ModuleList(
            [
                MelganDiscriminator(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_sizes=kernel_sizes,
                    base_channels=base_channels,
                    max_channels=max_channels,
                    downsample_factors=downsample_factors,
                    groups_denominator=groups_denominator,
                )
                for _ in range(num_scales)
            ]
        )

        # A single parameter-free pooling layer is reused between scales.
        # count_include_pad=False keeps the implicit zero padding out of each
        # window's average.
        self.pooling = nn.AvgPool1d(
            kernel_size=pooling_kernel_size,
            stride=pooling_stride,
            padding=pooling_padding,
            count_include_pad=False,
        )

    def forward(self, x):
        """Evaluate every discriminator on its own scale of ``x``.

        Args:
            x: Input tensor accepted by both ``MelganDiscriminator`` and
                ``nn.AvgPool1d`` (presumably shaped (batch, in_channels,
                time) -- TODO confirm against callers).

        Returns:
            A ``(scores, feats)`` tuple of two lists with one entry per
            discriminator, in scale order. Each entry is the first / second
            element of the pair returned by that discriminator.
        """
        scores = []
        feats = []
        for scale_idx, disc in enumerate(self.discriminators):
            # Pool only *between* scales: the first discriminator sees the
            # raw input, and nothing consumes a pooled copy after the last
            # one, so no trailing pooling pass is computed.
            if scale_idx > 0:
                x = self.pooling(x)
            score, feat = disc(x)
            scores.append(score)
            feats.append(feat)
        return scores, feats