import torch
import torch.nn as nn
import torch.nn.functional as F

from . import normalizations, activations


class _Chop1d(nn.Module):
    """Chop the last ``chop_size`` samples so a causal convolution output matches the input length."""

    def __init__(self, chop_size):
        super().__init__()
        self.chop_size = chop_size

    def forward(self, x):
        return x[..., : -self.chop_size].contiguous()


class Conv1DBlock(nn.Module):
    """TCN residual block: 1x1 conv, depthwise dilated conv, then 1x1 residual and optional skip convs."""

    def __init__(
        self,
        in_chan,
        hid_chan,
        skip_out_chan,
        kernel_size,
        padding,
        dilation,
        norm_type="gLN",
        causal=False,
    ):
        super(Conv1DBlock, self).__init__()
        self.skip_out_chan = skip_out_chan
        conv_norm = normalizations.get(norm_type)
        in_conv1d = nn.Conv1d(in_chan, hid_chan, 1)
        depth_conv1d = nn.Conv1d(
            hid_chan,
            hid_chan,
            kernel_size,
            padding=padding,
            dilation=dilation,
            groups=hid_chan,
        )
        if causal:
            depth_conv1d = nn.Sequential(depth_conv1d, _Chop1d(padding))
        self.shared_block = nn.Sequential(
            in_conv1d,
            nn.PReLU(),
            conv_norm(hid_chan),
            depth_conv1d,
            nn.PReLU(),
            conv_norm(hid_chan),
        )
        self.res_conv = nn.Conv1d(hid_chan, in_chan, 1)
        if skip_out_chan:
            self.skip_conv = nn.Conv1d(hid_chan, skip_out_chan, 1)

    def forward(self, x):
        r"""Input shape :math:`(batch, feats, seq)`."""
        shared_out = self.shared_block(x)
        res_out = self.res_conv(shared_out)
        if not self.skip_out_chan:
            return res_out
        skip_out = self.skip_conv(shared_out)
        return res_out, skip_out


class ConvNormAct(nn.Module):
    """Convolution layer followed by normalization and an activation (PReLU by default)."""

    def __init__(
        self,
        in_chan,
        out_chan,
        kernel_size,
        stride=1,
        groups=1,
        dilation=1,
        padding=0,
        norm_type="gLN",
        act_type="prelu",
    ):
        super(ConvNormAct, self).__init__()
        self.conv = nn.Conv1d(
            in_chan,
            out_chan,
            kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding,
            bias=True,
            groups=groups,
        )
        self.norm = normalizations.get(norm_type)(out_chan)
        self.act = activations.get(act_type)()

    def forward(self, x):
        output = self.conv(x)
        output = self.norm(output)
        return self.act(output)


class ConvNorm(nn.Module):
    """Convolution layer followed by normalization, without an activation."""

    def __init__(
        self,
        in_chan,
        out_chan,
        kernel_size,
        stride=1,
        groups=1,
        dilation=1,
        padding=0,
        norm_type="gLN",
    ):
        super(ConvNorm, self).__init__()
        self.conv = nn.Conv1d(
            in_chan, out_chan, kernel_size, stride, padding, dilation, bias=True, groups=groups
        )
        self.norm = normalizations.get(norm_type)(out_chan)

    def forward(self, x):
        output = self.conv(x)
        return self.norm(output)


class NormAct(nn.Module):
    """Normalization followed by an activation (PReLU by default).

    Args:
        out_chan (int): number of channels to normalize.
    """

    def __init__(self, out_chan, norm_type="gLN", act_type="prelu"):
        super(NormAct, self).__init__()
        self.norm = normalizations.get(norm_type)(out_chan)
        self.act = activations.get(act_type)()

    def forward(self, input):
        output = self.norm(input)
        return self.act(output)
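

# --- Illustrative usage sketch (not part of the original module) ---------------
# A minimal shape check for Conv1DBlock, assuming the package's normalization and
# activation registries resolve "gLN"/"prelu" as in Asteroid. For a causal block,
# padding = (kernel_size - 1) * dilation and _Chop1d trims the tail back to the
# input length; a non-causal block would use half that padding.
def _example_conv1d_block():
    block = Conv1DBlock(
        in_chan=128,
        hid_chan=256,
        skip_out_chan=128,
        kernel_size=3,
        padding=(3 - 1) * 2,
        dilation=2,
        causal=True,
    )
    x = torch.randn(2, 128, 300)  # (batch, feats, seq)
    res, skip = block(x)
    assert res.shape == x.shape and skip.shape == x.shape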


class Video1DConv(nn.Module):
    """1-D conv block for the video stream.

    Args:
        in_chan: video encoder output channels.
        out_chan: output channels of the 1x1 bottleneck/skip convs.
        kernel_size: depthwise conv kernel size.
        dilation: depthwise conv dilation.
        residual: whether to use a residual connection.
        skip_con: whether to use a skip connection.
        first_block: if True this is the first block, which has no
            pre-normalization and no residual connection.
    """

    def __init__(
        self,
        in_chan,
        out_chan,
        kernel_size,
        dilation=1,
        residual=True,
        skip_con=True,
        first_block=True,
    ):
        super(Video1DConv, self).__init__()
        self.first_block = first_block
        # The first block is not residual.
        self.residual = residual and not first_block
        self.bn = nn.BatchNorm1d(in_chan) if not first_block else None
        self.relu = nn.ReLU() if not first_block else None
        self.dconv = nn.Conv1d(
            in_chan,
            in_chan,
            kernel_size,
            groups=in_chan,
            dilation=dilation,
            padding=(dilation * (kernel_size - 1)) // 2,
            bias=True,
        )
        self.bconv = nn.Conv1d(in_chan, out_chan, 1)
        self.sconv = nn.Conv1d(in_chan, out_chan, 1)
        self.skip_con = skip_con

    def forward(self, x):
        """
        x: [B, N, T]
        out: [B, N, T]
        """
        if not self.first_block:
            y = self.bn(self.relu(x))
            y = self.dconv(y)
        else:
            y = self.dconv(x)
        # Skip connection: return the skip features alongside the (optionally residual) output.
        if self.skip_con:
            skip = self.sconv(y)
            if self.residual:
                y = y + x
            return skip, y
        y = self.bconv(y)
        if self.residual:
            y = y + x
        return y


class Concat(nn.Module):
    """Fuse audio and video features: upsample video to the audio rate, concatenate, project."""

    def __init__(self, ain_chan, vin_chan, out_chan):
        super(Concat, self).__init__()
        self.ain_chan = ain_chan
        self.vin_chan = vin_chan
        # Projection back to out_chan after concatenation.
        self.conv1d = nn.Sequential(nn.Conv1d(ain_chan + vin_chan, out_chan, 1), nn.PReLU())

    def forward(self, a, v):
        # Up-sample video features to the audio time resolution.
        v = torch.nn.functional.interpolate(v, size=a.size(-1))
        # Concatenate: B x (A + V) x Ta.
        y = torch.cat([a, v], dim=1)
        return self.conv1d(y)


class FRCNNBlock(nn.Module):
    """FRCNN block: bottom-up downsampling pyramid followed by multi-scale fusion."""

    def __init__(
        self,
        in_chan=128,
        out_chan=512,
        upsampling_depth=4,
        norm_type="gLN",
        act_type="prelu",
    ):
        super().__init__()
        self.proj_1x1 = ConvNormAct(
            in_chan,
            out_chan,
            kernel_size=1,
            stride=1,
            groups=1,
            dilation=1,
            padding=0,
            norm_type=norm_type,
            act_type=act_type,
        )
        self.depth = upsampling_depth
        self.spp_dw = nn.ModuleList([])
        self.spp_dw.append(
            ConvNorm(
                out_chan,
                out_chan,
                kernel_size=5,
                stride=1,
                groups=out_chan,
                dilation=1,
                padding=((5 - 1) // 2) * 1,
                norm_type=norm_type,
            )
        )
        # ----------Down Sample Layer----------
        for i in range(1, upsampling_depth):
            self.spp_dw.append(
                ConvNorm(
                    out_chan,
                    out_chan,
                    kernel_size=5,
                    stride=2,
                    groups=out_chan,
                    dilation=1,
                    padding=((5 - 1) // 2) * 1,
                    norm_type=norm_type,
                )
            )
        # ----------Fusion Layer----------
        # For i >= 1, fuse_layers[i][0] downsamples scale i - 1 to scale i; same-scale
        # and finer-scale entries are None (identity / interpolation in forward).
        self.fuse_layers = nn.ModuleList([])
        for i in range(upsampling_depth):
            fuse_layer = nn.ModuleList([])
            for j in range(upsampling_depth):
                if i == j or j - i == 1:
                    fuse_layer.append(None)
                elif i - j == 1:
                    fuse_layer.append(
                        ConvNorm(
                            out_chan,
                            out_chan,
                            kernel_size=5,
                            stride=2,
                            groups=out_chan,
                            dilation=1,
                            padding=((5 - 1) // 2) * 1,
                            norm_type=norm_type,
                        )
                    )
            self.fuse_layers.append(fuse_layer)
        self.concat_layer = nn.ModuleList([])
        # ----------Concat Layer----------
        # The first and last scales concatenate two feature maps, inner scales three.
        for i in range(upsampling_depth):
            if i == 0 or i == upsampling_depth - 1:
                self.concat_layer.append(
                    ConvNormAct(
                        out_chan * 2, out_chan, 1, 1, norm_type=norm_type, act_type=act_type
                    )
                )
            else:
                self.concat_layer.append(
                    ConvNormAct(
                        out_chan * 3, out_chan, 1, 1, norm_type=norm_type, act_type=act_type
                    )
                )
        self.last_layer = nn.Sequential(
            ConvNormAct(
                out_chan * upsampling_depth, out_chan, 1, 1, norm_type=norm_type, act_type=act_type
            )
        )
        self.res_conv = nn.Conv1d(out_chan, in_chan, 1)

    def forward(self, x):
        """
        :param x: input feature map
        :return: transformed feature map
        """
        residual = x.clone()
        # Reduce --> project high-dimensional feature maps to low-dimensional space.
        output1 = self.proj_1x1(x)
        output = [self.spp_dw[0](output1)]
        for k in range(1, self.depth):
            out_k = self.spp_dw[k](output[-1])
            output.append(out_k)
        # Fuse each scale with its downsampled and upsampled neighbours.
        x_fuse = []
        for i in range(len(self.fuse_layers)):
            wav_length = output[i].shape[-1]
            to_cat = []
            if i - 1 >= 0:
                to_cat.append(self.fuse_layers[i][0](output[i - 1]))
            to_cat.append(output[i])
            if i + 1 < self.depth:
                to_cat.append(F.interpolate(output[i + 1], size=wav_length, mode="nearest"))
            x_fuse.append(self.concat_layer[i](torch.cat(to_cat, dim=1)))
        # Bring every fused scale back to the finest resolution and merge.
        wav_length = output[0].shape[-1]
        for i in range(1, len(x_fuse)):
            x_fuse[i] = F.interpolate(x_fuse[i], size=wav_length, mode="nearest")
        concat = self.last_layer(torch.cat(x_fuse, dim=1))
        expanded = self.res_conv(concat)
        return expanded + residual
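

# --- Illustrative usage sketch (not part of the original module) ---------------
# FRCNNBlock is shape-preserving thanks to its residual 1x1 expansion, so the same
# block can be applied repeatedly (e.g. with shared weights). The sizes below are
# demo assumptions, not prescribed values.
def _example_frcnn_block():
    block = FRCNNBlock(in_chan=128, out_chan=512, upsampling_depth=4)
    x = torch.randn(2, 128, 3200)
    out = block(x)
    assert out.shape == x.shape  # required for the residual connection to apply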


class Bottomup(nn.Module):
    """Bottom-up half of the FRCNN block: 1x1 projection plus downsampling pyramid."""

    def __init__(
        self,
        in_chan=128,
        out_chan=512,
        upsampling_depth=4,
        norm_type="gLN",
        act_type="prelu",
    ):
        super().__init__()
        self.proj_1x1 = ConvNormAct(
            in_chan,
            out_chan,
            kernel_size=1,
            stride=1,
            groups=1,
            dilation=1,
            padding=0,
            norm_type=norm_type,
            act_type=act_type,
        )
        self.depth = upsampling_depth
        self.spp_dw = nn.ModuleList([])
        self.spp_dw.append(
            ConvNorm(
                out_chan,
                out_chan,
                kernel_size=5,
                stride=1,
                groups=out_chan,
                dilation=1,
                padding=((5 - 1) // 2) * 1,
                norm_type=norm_type,
            )
        )
        # ----------Down Sample Layer----------
        for i in range(1, upsampling_depth):
            self.spp_dw.append(
                ConvNorm(
                    out_chan,
                    out_chan,
                    kernel_size=5,
                    stride=2,
                    groups=out_chan,
                    dilation=1,
                    padding=((5 - 1) // 2) * 1,
                    norm_type=norm_type,
                )
            )

    def forward(self, x):
        residual = x.clone()
        # Reduce --> project high-dimensional feature maps to low-dimensional space.
        output1 = self.proj_1x1(x)
        output = [self.spp_dw[0](output1)]
        for k in range(1, self.depth):
            out_k = self.spp_dw[k](output[-1])
            output.append(out_k)
        return residual, output[-1], output


class BottomupTCN(nn.Module):
    """Bottom-up half of the TCN variant: every scale keeps the full temporal resolution."""

    def __init__(
        self,
        in_chan=128,
        out_chan=512,
        upsampling_depth=4,
        norm_type="gLN",
        act_type="prelu",
    ):
        super().__init__()
        self.proj_1x1 = ConvNormAct(
            in_chan,
            out_chan,
            kernel_size=1,
            stride=1,
            groups=1,
            dilation=1,
            padding=0,
            norm_type=norm_type,
            act_type=act_type,
        )
        self.depth = upsampling_depth
        self.spp_dw = nn.ModuleList([])
        self.spp_dw.append(Video1DConv(out_chan, out_chan, 3, skip_con=False, first_block=True))
        # ----------Down Sample Layer----------
        for i in range(1, upsampling_depth):
            self.spp_dw.append(
                Video1DConv(out_chan, out_chan, 3, skip_con=False, first_block=False)
            )

    def forward(self, x):
        residual = x.clone()
        # Reduce --> project high-dimensional feature maps to low-dimensional space.
        output1 = self.proj_1x1(x)
        output = [self.spp_dw[0](output1)]
        for k in range(1, self.depth):
            out_k = self.spp_dw[k](output[-1])
            output.append(out_k)
        return residual, output[-1], output
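

# --- Illustrative usage sketch (not part of the original module) ---------------
# Bottomup returns the residual input, the coarsest scale, and the whole pyramid;
# the pyramid lengths halve at each scale (stride-2 depthwise convs).
def _example_bottomup():
    enc = Bottomup(in_chan=128, out_chan=512, upsampling_depth=4)
    x = torch.randn(2, 128, 3200)
    residual, coarsest, pyramid = enc(x)
    assert coarsest is pyramid[-1]
    assert [o.shape[-1] for o in pyramid] == [3200, 1600, 800, 400]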


class Bottomup_Concat_Topdown(nn.Module):
    """Top-down half of the FRCNN block: fuse the bottom-up pyramid with a global top-down map."""

    def __init__(
        self,
        in_chan=128,
        out_chan=512,
        upsampling_depth=4,
        norm_type="gLN",
        act_type="prelu",
    ):
        super().__init__()
        # ----------Fusion Layer----------
        # For i >= 1, fuse_layers[i][0] downsamples scale i - 1 to scale i.
        self.fuse_layers = nn.ModuleList([])
        for i in range(upsampling_depth):
            fuse_layer = nn.ModuleList([])
            for j in range(upsampling_depth):
                if i == j or j - i == 1:
                    fuse_layer.append(None)
                elif i - j == 1:
                    fuse_layer.append(
                        ConvNorm(
                            out_chan,
                            out_chan,
                            kernel_size=5,
                            stride=2,
                            groups=out_chan,
                            dilation=1,
                            padding=((5 - 1) // 2) * 1,
                            norm_type=norm_type,
                        )
                    )
            self.fuse_layers.append(fuse_layer)
        self.concat_layer = nn.ModuleList([])
        # ----------Concat Layer----------
        # Each scale also receives the interpolated top-down map, hence one extra input.
        for i in range(upsampling_depth):
            if i == 0 or i == upsampling_depth - 1:
                self.concat_layer.append(
                    ConvNormAct(
                        out_chan * 3, out_chan, 1, 1, norm_type=norm_type, act_type=act_type
                    )
                )
            else:
                self.concat_layer.append(
                    ConvNormAct(
                        out_chan * 4, out_chan, 1, 1, norm_type=norm_type, act_type=act_type
                    )
                )
        self.last_layer = nn.Sequential(
            ConvNormAct(
                out_chan * upsampling_depth, out_chan, 1, 1, norm_type=norm_type, act_type=act_type
            )
        )
        self.res_conv = nn.Conv1d(out_chan, in_chan, 1)
        # ----------parameters-------------
        self.depth = upsampling_depth

    def forward(self, residual, bottomup, topdown):
        x_fuse = []
        for i in range(len(self.fuse_layers)):
            wav_length = bottomup[i].shape[-1]
            to_cat = []
            if i - 1 >= 0:
                to_cat.append(self.fuse_layers[i][0](bottomup[i - 1]))
            to_cat.append(bottomup[i])
            if i + 1 < self.depth:
                to_cat.append(F.interpolate(bottomup[i + 1], size=wav_length, mode="nearest"))
            to_cat.append(F.interpolate(topdown, size=wav_length, mode="nearest"))
            x_fuse.append(self.concat_layer[i](torch.cat(to_cat, dim=1)))
        # Bring every fused scale back to the finest resolution and merge.
        wav_length = bottomup[0].shape[-1]
        for i in range(1, len(x_fuse)):
            x_fuse[i] = F.interpolate(x_fuse[i], size=wav_length, mode="nearest")
        concat = self.last_layer(torch.cat(x_fuse, dim=1))
        expanded = self.res_conv(concat)
        return expanded + residual


class Bottomup_Concat_Topdown_TCN(nn.Module):
    """Top-down half of the TCN variant: all scales share one temporal resolution."""

    def __init__(
        self,
        in_chan=128,
        out_chan=512,
        upsampling_depth=4,
        norm_type="gLN",
        act_type="prelu",
    ):
        super().__init__()
        # ----------Fusion Layer----------
        # No resampling is needed between neighbouring scales, so all entries are None.
        self.fuse_layers = nn.ModuleList([])
        for i in range(upsampling_depth):
            fuse_layer = nn.ModuleList([])
            for j in range(upsampling_depth):
                if abs(i - j) <= 1:
                    fuse_layer.append(None)
            self.fuse_layers.append(fuse_layer)
        self.concat_layer = nn.ModuleList([])
        # ----------Concat Layer----------
        for i in range(upsampling_depth):
            if i == 0 or i == upsampling_depth - 1:
                self.concat_layer.append(
                    ConvNormAct(
                        out_chan * 3, out_chan, 1, 1, norm_type=norm_type, act_type=act_type
                    )
                )
            else:
                self.concat_layer.append(
                    ConvNormAct(
                        out_chan * 4, out_chan, 1, 1, norm_type=norm_type, act_type=act_type
                    )
                )
        self.last_layer = nn.Sequential(
            ConvNormAct(
                out_chan * upsampling_depth, out_chan, 1, 1, norm_type=norm_type, act_type=act_type
            )
        )
        self.res_conv = nn.Conv1d(out_chan, in_chan, 1)
        # ----------parameters-------------
        self.depth = upsampling_depth

    def forward(self, residual, bottomup, topdown):
        x_fuse = []
        for i in range(len(self.fuse_layers)):
            wav_length = bottomup[i].shape[-1]
            to_cat = []
            if i - 1 >= 0:
                to_cat.append(bottomup[i - 1])
            to_cat.append(bottomup[i])
            if i + 1 < self.depth:
                to_cat.append(bottomup[i + 1])
            to_cat.append(F.interpolate(topdown, size=wav_length, mode="nearest"))
            x_fuse.append(self.concat_layer[i](torch.cat(to_cat, dim=1)))
        concat = self.last_layer(torch.cat(x_fuse, dim=1))
        expanded = self.res_conv(concat)
        return expanded + residual
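

# --- Illustrative usage sketch (not part of the original module) ---------------
# The split blocks compose into one shape-preserving unit: Bottomup builds the
# pyramid and Bottomup_Concat_Topdown fuses it together with a top-down map.
# Feeding the coarsest bottom-up scale back in as `topdown` is an assumption for
# this demo; callers may inject other global context instead.
def _example_bottomup_topdown():
    down = Bottomup(in_chan=128, out_chan=512, upsampling_depth=4)
    up = Bottomup_Concat_Topdown(in_chan=128, out_chan=512, upsampling_depth=4)
    x = torch.randn(2, 128, 3200)
    residual, topdown, pyramid = down(x)
    out = up(residual, pyramid, topdown)
    assert out.shape == x.shape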


class FRCNNBlockTCN(nn.Module):
    """TCN variant of the FRCNN block: Video1DConv layers instead of downsampling convs."""

    def __init__(
        self,
        in_chan=128,
        out_chan=512,
        upsampling_depth=4,
        norm_type="gLN",
        act_type="prelu",
    ):
        super().__init__()
        self.proj_1x1 = ConvNormAct(
            in_chan,
            out_chan,
            kernel_size=1,
            stride=1,
            groups=1,
            dilation=1,
            padding=0,
            norm_type=norm_type,
            act_type=act_type,
        )
        self.depth = upsampling_depth
        self.spp_dw = nn.ModuleList([])
        self.spp_dw.append(Video1DConv(out_chan, out_chan, 3, skip_con=False, first_block=True))
        # ----------Down Sample Layer----------
        for i in range(1, upsampling_depth):
            self.spp_dw.append(
                Video1DConv(out_chan, out_chan, 3, skip_con=False, first_block=False)
            )
        # ----------Fusion Layer----------
        # All scales share one temporal resolution, so every entry is None.
        self.fuse_layers = nn.ModuleList([])
        for i in range(upsampling_depth):
            fuse_layer = nn.ModuleList([])
            for j in range(upsampling_depth):
                if abs(i - j) <= 1:
                    fuse_layer.append(None)
            self.fuse_layers.append(fuse_layer)
        self.concat_layer = nn.ModuleList([])
        # ----------Concat Layer----------
        for i in range(upsampling_depth):
            if i == 0 or i == upsampling_depth - 1:
                self.concat_layer.append(
                    ConvNormAct(
                        out_chan * 2, out_chan, 1, 1, norm_type=norm_type, act_type=act_type
                    )
                )
            else:
                self.concat_layer.append(
                    ConvNormAct(
                        out_chan * 3, out_chan, 1, 1, norm_type=norm_type, act_type=act_type
                    )
                )
        self.last_layer = nn.Sequential(
            ConvNormAct(
                out_chan * upsampling_depth, out_chan, 1, 1, norm_type=norm_type, act_type=act_type
            )
        )
        self.res_conv = nn.Conv1d(out_chan, in_chan, 1)

    def forward(self, x):
        """
        :param x: input feature map
        :return: transformed feature map
        """
        residual = x.clone()
        # Reduce --> project high-dimensional feature maps to low-dimensional space.
        output1 = self.proj_1x1(x)
        output = [self.spp_dw[0](output1)]
        for k in range(1, self.depth):
            out_k = self.spp_dw[k](output[-1])
            output.append(out_k)
        # Fuse each scale with its immediate neighbours (no resampling needed).
        x_fuse = []
        for i in range(len(self.fuse_layers)):
            to_cat = []
            if i - 1 >= 0:
                to_cat.append(output[i - 1])
            to_cat.append(output[i])
            if i + 1 < self.depth:
                to_cat.append(output[i + 1])
            x_fuse.append(self.concat_layer[i](torch.cat(to_cat, dim=1)))
        concat = self.last_layer(torch.cat(x_fuse, dim=1))
        expanded = self.res_conv(concat)
        return expanded + residual
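

# --- Illustrative usage sketch (not part of the original module) ---------------
# In the TCN variant every scale keeps the input resolution (Video1DConv pads to
# length), so neighbouring scales are concatenated directly with no resampling.
# Treating the input as video features at a low frame rate is a demo assumption.
def _example_frcnn_block_tcn():
    block = FRCNNBlockTCN(in_chan=128, out_chan=512, upsampling_depth=4)
    v = torch.randn(2, 128, 100)
    out = block(v)
    assert out.shape == v.shape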


class TAC(nn.Module):
    r"""Transform-Average-Concatenate inter-microphone-channel permutation invariant communication block [1].

    Args:
        input_dim (int): Number of features of input representation.
        hidden_dim (int, optional): Size of hidden layers in TAC operations.
        activation (str, optional): Type of activation used. See asteroid.masknn.activations.
        norm_type (str, optional): Type of normalization layer used. See asteroid.masknn.norms.

    .. note:: Supports inputs of shape :math:`(batch, mic\_channels, features, chunk\_size, n\_chunks)`
        as in FasNet-TAC. The operations are applied for each element in ``chunk_size`` and
        ``n_chunks``. Output is of the same shape as the input.

    References
        [1]: Luo, Yi, et al. "End-to-end microphone permutation and number invariant multi-channel
        speech separation." ICASSP 2020.
    """

    def __init__(self, input_dim, hidden_dim=384, activation="prelu", norm_type="gLN"):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.input_tf = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), activations.get(activation)()
        )
        self.avg_tf = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), activations.get(activation)()
        )
        self.concat_tf = nn.Sequential(
            nn.Linear(2 * hidden_dim, input_dim), activations.get(activation)()
        )
        self.norm = normalizations.get(norm_type)(input_dim)

    def forward(self, x, valid_mics=None):
        r"""
        Args:
            x (:class:`torch.Tensor`): Input multi-channel DPRNN features.
                Shape: :math:`(batch, mic\_channels, features, chunk\_size, n\_chunks)`.
            valid_mics (:class:`torch.LongTensor`): Tensor containing the effective number of
                microphones for each batch element. Batches can be composed of examples coming
                from arrays with a different number of microphones, so the ``mic_channels``
                dimension is padded. E.g. ``torch.tensor([4, 3])`` means the first example has
                4 channels and the second has 3. Shape: :math:`(batch)`.

        Returns:
            output (:class:`torch.Tensor`): Features for each mic_channel after TAC inter-channel
            processing. Shape: :math:`(batch, mic\_channels, features, chunk\_size, n\_chunks)`.
        """
        # Input is 5D because it is multi-channel DPRNN; single-channel DPRNN input is 4D.
        batch_size, nmics, channels, chunk_size, n_chunks = x.size()
        if valid_mics is None:
            valid_mics = torch.LongTensor([nmics] * batch_size)
        # First operation: transform the input for each frame and independently on each mic channel.
        output = self.input_tf(
            x.permute(0, 3, 4, 1, 2).reshape(batch_size * nmics * chunk_size * n_chunks, channels)
        ).reshape(batch_size, chunk_size, n_chunks, nmics, self.hidden_dim)

        # Mean pooling across mic channels (dim 3 after the reshape above).
        if valid_mics.max() == 0:
            # Fixed geometry array.
            mics_mean = output.mean(3)
        else:
            # Only consider valid channels in each batch element: each example can have a
            # different number of microphones.
            mics_mean = [
                output[b, :, :, : valid_mics[b]].mean(2).unsqueeze(0) for b in range(batch_size)
            ]
            mics_mean = torch.cat(mics_mean, 0)  # (batch, chunk_size, n_chunks, hidden_dim)

        # The average is processed by a non-linear transform.
        mics_mean = self.avg_tf(
            mics_mean.reshape(batch_size * chunk_size * n_chunks, self.hidden_dim)
        )
        mics_mean = (
            mics_mean.reshape(batch_size, chunk_size, n_chunks, self.hidden_dim)
            .unsqueeze(3)
            .expand_as(output)
        )
        # Concatenate the transformed average in each channel with the original features and
        # project back to the same number of features.
        output = torch.cat([output, mics_mean], -1)
        output = self.concat_tf(
            output.reshape(batch_size * chunk_size * n_chunks * nmics, -1)
        ).reshape(batch_size, chunk_size, n_chunks, nmics, -1)
        output = self.norm(
            output.permute(0, 3, 4, 1, 2).reshape(batch_size * nmics, -1, chunk_size, n_chunks)
        ).reshape(batch_size, nmics, -1, chunk_size, n_chunks)
        output += x
        return output
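

# --- Illustrative usage sketch (not part of the original module) ---------------
# TAC on padded multi-channel DPRNN features: the second batch element only has
# 3 real mics out of 4 padded channels, and `valid_mics` restricts the average
# to the real ones. Sizes are demo assumptions.
def _example_tac():
    tac = TAC(input_dim=64, hidden_dim=128)
    x = torch.randn(2, 4, 64, 50, 10)  # (batch, mics, feats, chunk_size, n_chunks)
    out = tac(x, valid_mics=torch.LongTensor([4, 3]))
    assert out.shape == x.shape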