|
import ast

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm, spectral_norm
|
|
|
class DiscriminatorR(torch.nn.Module):
    """Single-resolution spectrogram discriminator.

    Converts a waveform into a linear-magnitude STFT at one
    ``(n_fft, hop_length, win_length)`` resolution and scores it with a stack
    of 2-D convolutions.

    Args:
        hp: Hyper-parameter config. Reads ``hp.mpd.lReLU_slope`` and
            ``hp.mrd.use_spectral_norm``.
        resolution: ``(n_fft, hop_length, win_length)`` triple for the STFT.
    """

    def __init__(self, hp, resolution):
        super().__init__()

        self.resolution = resolution
        # NOTE(review): the LeakyReLU slope is read from the *MPD* section of
        # the config, not ``hp.mrd``; kept so existing configs keep working.
        self.LRELU_SLOPE = hp.mpd.lReLU_slope

        # Weight norm by default; spectral norm only when explicitly enabled.
        norm_f = spectral_norm if hp.mrd.use_spectral_norm else weight_norm

        # Strided layers downsample only the last (frame) axis: stride (1, 2).
        self.convs = nn.ModuleList([
            norm_f(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
        ])
        self.conv_post = norm_f(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))

    def forward(self, x):
        """Score a waveform batch at this discriminator's resolution.

        Args:
            x: Waveform tensor; presumably ``(batch, 1, samples)`` — TODO
                confirm against callers.

        Returns:
            ``(fmap, logits)``. ``fmap`` is the list of activations after each
            of the five conv layers plus the raw ``conv_post`` output (six
            entries). ``logits`` is the ``conv_post`` output flattened to
            ``(batch, -1)``.
        """
        fmap = []

        x = self.spectrogram(x)
        x = x.unsqueeze(1)  # add a channel axis for Conv2d
        for conv in self.convs:
            x = F.leaky_relu(conv(x), self.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return fmap, x

    def spectrogram(self, x):
        """Return the linear STFT magnitude of ``x``.

        The signal is reflect-padded by ``(n_fft - hop_length) // 2`` samples
        on each side and transformed with ``center=False``. When
        ``n_fft - hop_length`` is even, this yields ``samples // hop_length``
        frames. A rectangular window of ``win_length`` ones is used, which is
        the same as ``torch.stft``'s ``window=None`` default. ``torch.stft``
        zero-pads that window to ``n_fft``.

        Returns:
            Real tensor of shape ``(batch, n_fft // 2 + 1, frames)``.
        """
        n_fft, hop_length, win_length = self.resolution
        pad = (n_fft - hop_length) // 2
        x = F.pad(x, (pad, pad), mode='reflect')
        x = x.squeeze(1)

        # Explicit all-ones window == stft's window=None default; passing it
        # avoids the "window was not provided" warning without changing output.
        window = torch.ones(win_length, device=x.device, dtype=x.dtype)
        # return_complex=False is deprecated; |complex STFT| equals the L2
        # norm over the (real, imag) pair that the old code computed.
        spec = torch.stft(
            x,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            window=window,
            center=False,
            return_complex=True,
        )
        return spec.abs()
|
|
|
|
|
class MultiResolutionDiscriminator(torch.nn.Module):
    """Bank of ``DiscriminatorR`` modules, one per STFT resolution.

    Args:
        hp: Hyper-parameter config. ``hp.mrd.resolutions`` holds the
            ``(n_fft, hop_length, win_length)`` triples, either as a Python
            literal string such as ``"[(1024, 120, 600), (512, 50, 240)]"`` or
            as an already-parsed sequence. The whole ``hp`` is forwarded to
            every ``DiscriminatorR``.
    """

    def __init__(self, hp):
        super().__init__()
        self.resolutions = self._parse_resolutions(hp.mrd.resolutions)
        self.discriminators = nn.ModuleList(
            [DiscriminatorR(hp, resolution) for resolution in self.resolutions]
        )

    @staticmethod
    def _parse_resolutions(resolutions):
        """Return the configured resolutions as a list of tuples.

        Strings are parsed with ``ast.literal_eval`` instead of ``eval`` so a
        config value cannot execute arbitrary code. Strings that are not plain
        Python literals raise ``ValueError`` or ``SyntaxError``.
        """
        if isinstance(resolutions, str):
            resolutions = ast.literal_eval(resolutions)
        return [tuple(resolution) for resolution in resolutions]

    def forward(self, x):
        """Run every sub-discriminator on ``x``.

        Returns:
            List with one ``(fmap, logits)`` pair per resolution, in the same
            order as ``self.resolutions``.
        """
        return [disc(x) for disc in self.discriminators]
|
|