Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.nn.utils import weight_norm, spectral_norm | |
class DiscriminatorR(torch.nn.Module):
    """Discriminator that scores a waveform from one STFT resolution.

    The waveform is turned into a magnitude spectrogram using the
    ``(n_fft, hop_length, win_length)`` triple stored in ``resolution`` and is
    then passed through a stack of normalised 2-D convolutions.

    Args:
        hp: Hyper-parameter object. This class reads ``hp.mpd.lReLU_slope``
            and ``hp.mrd.use_spectral_norm``.
        resolution: ``(n_fft, hop_length, win_length)`` used by
            :meth:`spectrogram`.
    """

    def __init__(self, hp, resolution):
        super().__init__()
        self.resolution = resolution
        # NOTE(review): the slope is read from the ``mpd`` section while the
        # normalisation flag below comes from ``mrd`` -- confirm this is
        # intentional before changing it.
        self.LRELU_SLOPE = hp.mpd.lReLU_slope
        # Spectral norm only when the flag is truthy; weight norm otherwise.
        norm_f = spectral_norm if hp.mrd.use_spectral_norm else weight_norm

        self.convs = nn.ModuleList([
            norm_f(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
        ])
        self.conv_post = norm_f(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))

    def forward(self, x):
        """Run the discriminator on a waveform batch.

        Args:
            x: Waveform tensor accepted by :meth:`spectrogram`.

        Returns:
            A ``(fmap, score)`` tuple. ``fmap`` is the list of activations
            after every convolution (leaky-ReLU outputs for ``self.convs``,
            followed by the raw ``conv_post`` output). ``score`` is the
            ``conv_post`` output flattened over all non-batch dimensions.
        """
        fmap = []
        # Add the single input channel expected by the first Conv2d.
        x = self.spectrogram(x).unsqueeze(1)
        for conv in self.convs:
            x = F.leaky_relu(conv(x), self.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        return fmap, torch.flatten(x, 1, -1)

    def spectrogram(self, x):
        """Return the STFT magnitude of ``x`` at this discriminator's resolution.

        Args:
            x: Waveform tensor; assumed to be ``[B, 1, T]`` (dimension 1 is
                squeezed away before the STFT) -- TODO confirm with callers.
                Reflect padding needs ``T`` to exceed the padding amount.

        Returns:
            Magnitude spectrogram of shape ``[B, F, TT]``.
        """
        n_fft, hop_length, win_length = self.resolution
        # Same amount on both sides; the STFT itself runs with center=False.
        pad = int((n_fft - hop_length) / 2)
        x = F.pad(x, (pad, pad), mode='reflect')
        x = x.squeeze(1)
        # Complex output + abs() equals the L2 norm over the (real, imag)
        # pair of the deprecated return_complex=False layout.
        spec = torch.stft(
            x,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            center=False,
            return_complex=True,
        )  # [B, F, TT] complex
        return spec.abs()  # [B, F, TT]
class MultiResolutionDiscriminator(torch.nn.Module):
    """Bank of ``DiscriminatorR`` modules, one per configured resolution.

    Args:
        hp: Hyper-parameter object. ``hp.mrd.resolutions`` is evaluated with
            ``eval`` to obtain the iterable of resolutions; ``hp`` is also
            handed unchanged to every ``DiscriminatorR``.
    """

    def __init__(self, hp):
        super().__init__()
        # NOTE(review): eval() runs whatever expression the config value
        # holds. Only load configs from trusted sources, or consider
        # ast.literal_eval if the value is always a plain literal.
        parsed = eval(hp.mrd.resolutions)
        self.resolutions = parsed
        sub_discs = [DiscriminatorR(hp, res) for res in parsed]
        self.discriminators = nn.ModuleList(sub_discs)

    def forward(self, x):
        """Apply every sub-discriminator to ``x``.

        Returns:
            A list with one entry per sub-discriminator, in order; each entry
            is whatever that sub-discriminator returns: [(feat, score), ...].
        """
        return [sub_disc(x) for sub_disc in self.discriminators]