import torch
import torch.nn as nn

from torch.autograd import Function
from torch.nn.utils import weight_norm


|
class GradientReversalFunction(Function):

    @staticmethod
    def forward(ctx, x, lambda_):
        # Identity in the forward pass; stash lambda_ for the backward pass.
        ctx.lambda_ = lambda_
        return x.clone()

    @staticmethod
    def backward(ctx, grads):
        # Scale the incoming gradients by -lambda_, reversing their sign.
        lambda_ = grads.new_tensor(ctx.lambda_)
        dx = -lambda_ * grads
        # lambda_ is a plain Python scalar, so it receives no gradient.
        return dx, None


|
class GradientReversal(torch.nn.Module):
    '''Gradient Reversal Layer

    Y. Ganin, V. Lempitsky,
    "Unsupervised Domain Adaptation by Backpropagation",
    in ICML, 2015.

    The forward pass is the identity function. In the backward pass,
    upstream gradients are multiplied by -lambda (i.e. gradients are
    reversed).
    '''

    def __init__(self, lambda_reversal=1):
        super(GradientReversal, self).__init__()
        self.lambda_ = lambda_reversal

    def forward(self, x):
        return GradientReversalFunction.apply(x, self.lambda_)
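
# Usage sketch (illustrative, not part of the original module): gradients
# flowing back through GradientReversal come out negated and scaled by
# lambda_, even though the forward pass leaves x unchanged.
#
#     x = torch.ones(4, requires_grad=True)
#     y = GradientReversal(lambda_reversal=0.5)(x)
#     y.sum().backward()
#     # x.grad is now -0.5 everywhere.
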
|
class SpeakerClassifier(nn.Module):
    '''Adversarial speaker classifier: a gradient reversal layer followed by
    three weight-normalized 1-D convolutions and mean pooling over time.'''

    def __init__(self, embed_dim, spk_dim):
        super(SpeakerClassifier, self).__init__()
        self.classifier = nn.Sequential(
            GradientReversal(lambda_reversal=1),
            weight_norm(nn.Conv1d(embed_dim, embed_dim, kernel_size=5, padding=2)),
            nn.ReLU(),
            weight_norm(nn.Conv1d(embed_dim, embed_dim, kernel_size=5, padding=2)),
            nn.ReLU(),
            weight_norm(nn.Conv1d(embed_dim, spk_dim, kernel_size=5, padding=2))
        )

    def forward(self, x):
        '''Forward function of the speaker classifier.

        x: (B, embed_dim, len)
        Returns frame-averaged speaker logits of shape (B, spk_dim).
        '''
        outputs = self.classifier(x)
        # Average the frame-level logits over the time axis.
        outputs = torch.mean(outputs, dim=-1)
        return outputs
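

if __name__ == '__main__':
    # Minimal smoke test (an illustrative sketch: the batch, embedding,
    # speaker, and sequence dimensions below are assumptions, not values
    # from the original code).
    batch, embed_dim, spk_dim, length = 2, 256, 10, 100
    model = SpeakerClassifier(embed_dim, spk_dim)
    x = torch.randn(batch, embed_dim, length, requires_grad=True)

    # Padding 2 with kernel size 5 preserves length, so the conv stack maps
    # (B, embed_dim, len) to (B, spk_dim, len); mean pooling removes len.
    logits = model(x)
    assert logits.shape == (batch, spk_dim)

    # Backpropagate; the gradient reversal layer flips the sign of the
    # gradients that reach x.
    logits.sum().backward()
    print('logits:', tuple(logits.shape), 'input grad:', tuple(x.grad.shape))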