|
"""Library implementing convolutional neural networks. |
|
|
|
Authors |
|
* Mirco Ravanelli 2020 |
|
* Jianyuan Zhong 2020 |
|
* Cem Subakan 2021 |
|
* Davide Borra 2021 |
|
* Andreas Nautsch 2022 |
|
* Sarthak Yadav 2022 |
|
""" |
|
|
|
import logging |
|
import math |
|
from typing import Tuple |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torchaudio |
|
|
|
class SincConv(nn.Module): |
|
"""This function implements SincConv (SincNet). |
|
|
|
M. Ravanelli, Y. Bengio, "Speaker Recognition from raw waveform with |
|
SincNet", in Proc. of SLT 2018 (https://arxiv.org/abs/1808.00158) |
|
|
|
Arguments |
|
--------- |
|
out_channels : int |
|
It is the number of output channels. |
|
kernel_size: int |
|
Kernel size of the convolutional filters. |
|
input_shape : tuple |
|
The shape of the input. Alternatively use ``in_channels``. |
|
in_channels : int |
|
The number of input channels. Alternatively use ``input_shape``. |
|
stride : int |
|
Stride factor of the convolutional filters. When the stride factor > 1, |
|
a decimation in time is performed. |
|
dilation : int |
|
Dilation factor of the convolutional filters. |
|
padding : str |
|
(same, valid, causal). If "valid", no padding is performed. |
|
If "same" and stride is 1, output shape is the same as the input shape. |
|
"causal" results in causal (dilated) convolutions. |
|
padding_mode : str |
|
This flag specifies the type of padding. See torch.nn documentation |
|
for more information. |
|
sample_rate : int |
|
        Sampling rate of the input signals.
|
min_low_hz : float |
|
        Lowest possible frequency (in Hz) for a filter.
|
min_band_hz : float |
|
Lowest possible value (in Hz) for a filter bandwidth. |
|
|
|
Example |
|
------- |
|
>>> inp_tensor = torch.rand([10, 16000]) |
|
>>> conv = SincConv(input_shape=inp_tensor.shape, out_channels=25, kernel_size=11) |
|
>>> out_tensor = conv(inp_tensor) |
|
>>> out_tensor.shape |
|
torch.Size([10, 16000, 25]) |
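
    With ``padding="valid"`` no padding is applied, so the time axis shrinks
    by ``dilation * (kernel_size - 1)`` samples, while a stride > 1 decimates
    the output in time. A small sketch of both behaviors (the shapes follow
    from the forward pass below):

    >>> conv = SincConv(
    ...     input_shape=inp_tensor.shape,
    ...     out_channels=25,
    ...     kernel_size=11,
    ...     padding="valid",
    ... )
    >>> conv(inp_tensor).shape
    torch.Size([10, 15990, 25])
    >>> conv = SincConv(
    ...     input_shape=inp_tensor.shape,
    ...     out_channels=25,
    ...     kernel_size=11,
    ...     stride=2,
    ... )
    >>> conv(inp_tensor).shape
    torch.Size([10, 8000, 25])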
|
""" |
|
|
|
def __init__( |
|
self, |
|
out_channels, |
|
kernel_size, |
|
input_shape=None, |
|
in_channels=None, |
|
stride=1, |
|
dilation=1, |
|
padding="same", |
|
padding_mode="reflect", |
|
sample_rate=16000, |
|
min_low_hz=50, |
|
min_band_hz=50, |
|
): |
|
super().__init__() |
|
self.in_channels = in_channels |
|
self.out_channels = out_channels |
|
self.kernel_size = kernel_size |
|
self.stride = stride |
|
self.dilation = dilation |
|
self.padding = padding |
|
self.padding_mode = padding_mode |
|
self.sample_rate = sample_rate |
|
self.min_low_hz = min_low_hz |
|
self.min_band_hz = min_band_hz |
|
|
|
|
|
if input_shape is None and self.in_channels is None: |
|
raise ValueError("Must provide one of input_shape or in_channels") |
|
|
|
if self.in_channels is None: |
|
self.in_channels = self._check_input_shape(input_shape) |
|
|
|
if self.out_channels % self.in_channels != 0: |
|
raise ValueError( |
|
"Number of output channels must be divisible by in_channels" |
|
) |
|
|
|
|
|
self._init_sinc_conv() |
|
|
|
def forward(self, x): |
|
"""Returns the output of the convolution. |
|
|
|
Arguments |
|
--------- |
|
        x : torch.Tensor (batch, time, channel)
            Input to convolve. 2d or 3d tensors are expected.
|
|
|
Returns |
|
------- |
|
wx : torch.Tensor |
|
The convolved outputs. |
|
""" |
|
        # (batch, time, channel) -> (batch, channel, time); a no-op for 2d inputs
        x = x.transpose(1, -1)
        self.device = x.device
|
|
|
        # Adding a channel dimension to 2d (batch, time) inputs
        unsqueeze = x.ndim == 2
        if unsqueeze:
            x = x.unsqueeze(1)
|
|
|
if self.padding == "same": |
|
x = self._manage_padding( |
|
x, self.kernel_size, self.dilation, self.stride |
|
) |
|
|
|
elif self.padding == "causal": |
|
num_pad = (self.kernel_size - 1) * self.dilation |
|
x = F.pad(x, (num_pad, 0)) |
|
|
|
elif self.padding == "valid": |
|
pass |
|
|
|
else: |
|
raise ValueError( |
|
"Padding must be 'same', 'valid' or 'causal'. Got %s." |
|
% (self.padding) |
|
) |
|
|
|
sinc_filters = self._get_sinc_filters() |
|
|
|
wx = F.conv1d( |
|
x, |
|
sinc_filters, |
|
stride=self.stride, |
|
padding=0, |
|
dilation=self.dilation, |
|
groups=self.in_channels, |
|
) |
|
|
|
if unsqueeze: |
|
wx = wx.squeeze(1) |
|
|
|
wx = wx.transpose(1, -1) |
|
|
|
return wx |
|
|
|
def _check_input_shape(self, shape): |
|
"""Checks the input shape and returns the number of input channels.""" |
|
|
|
if len(shape) == 2: |
|
in_channels = 1 |
|
elif len(shape) == 3: |
|
in_channels = shape[-1] |
|
else: |
|
raise ValueError( |
|
"sincconv expects 2d or 3d inputs. Got " + str(len(shape)) |
|
) |
|
|
|
|
|
if self.kernel_size % 2 == 0: |
|
raise ValueError( |
|
"The field kernel size must be an odd number. Got %s." |
|
% (self.kernel_size) |
|
) |
|
return in_channels |
|
|
|
def _get_sinc_filters(self): |
|
"""This functions creates the sinc-filters to used for sinc-conv.""" |
|
|
|
        # Computing the low cut-off frequencies of the filters
        low = self.min_low_hz + torch.abs(self.low_hz_)
|
|
|
|
|
        # Computing the high cut-off frequencies, enforcing a minimum band
        # and clamping to the Nyquist frequency
        high = torch.clamp(
            low + self.min_band_hz + torch.abs(self.band_hz_),
            self.min_low_hz,
            self.sample_rate / 2,
        )
        band = (high - low)[:, 0]
|
|
|
|
|
        # Moving buffers to the input device and computing the phases
        # f * n for the low and high cut-off frequencies
        self.n_ = self.n_.to(self.device)
        self.window_ = self.window_.to(self.device)
        f_times_t_low = torch.matmul(low, self.n_)
        f_times_t_high = torch.matmul(high, self.n_)
|
|
|
|
|
        # Left half of the filters: difference of two low-pass sinc
        # filters, tapered by the Hamming window
        band_pass_left = (
            (torch.sin(f_times_t_high) - torch.sin(f_times_t_low))
            / (self.n_ / 2)
        ) * self.window_
|
|
|
|
|
        # Central sample of each filter (the n = 0 limit of the sinc)
        band_pass_center = 2 * band.view(-1, 1)
|
|
|
|
|
        # Right half: sinc filters are symmetric in time
        band_pass_right = torch.flip(band_pass_left, dims=[1])
|
|
|
|
|
        # Assembling left, central, and right parts of the filters
        band_pass = torch.cat(
            [band_pass_left, band_pass_center, band_pass_right], dim=1
        )
|
|
|
|
|
        # Amplitude normalization
        band_pass = band_pass / (2 * band[:, None])
|
|
|
|
|
        # (out_channels, 1, kernel_size), the layout expected by F.conv1d
        filters = band_pass.view(self.out_channels, 1, self.kernel_size)
|
|
|
return filters |
|
|
|
def _init_sinc_conv(self): |
|
"""Initializes the parameters of the sinc_conv layer.""" |
|
|
|
|
|
        # Initializing the filterbanks so that they are equally spaced
        # on the mel scale
        high_hz = self.sample_rate / 2 - (self.min_low_hz + self.min_band_hz)

        mel = torch.linspace(
            self._to_mel(self.min_low_hz),
            self._to_mel(high_hz),
            self.out_channels + 1,
        )

        hz = self._to_hz(mel)
|
|
|
|
|
        # Lower cut-off frequency and band of each filter, shape (out_channels, 1)
        self.low_hz_ = hz[:-1].unsqueeze(1)
        self.band_hz_ = (hz[1:] - hz[:-1]).unsqueeze(1)
|
|
|
|
|
        # Making the cut-off frequencies and bands learnable
        self.low_hz_ = nn.Parameter(self.low_hz_)
        self.band_hz_ = nn.Parameter(self.band_hz_)
|
|
|
|
|
        # Hamming window computed over half the kernel (the other half
        # follows by symmetry)
        n_lin = torch.linspace(
            0, (self.kernel_size / 2) - 1, steps=int((self.kernel_size / 2))
        )
        self.window_ = 0.54 - 0.46 * torch.cos(
            2 * math.pi * n_lin / self.kernel_size
        )
|
|
|
|
|
        # Left half of the filter time axis, scaled by 2 * pi / sample_rate
        # so that multiplying by a frequency in Hz yields the sinc phase
        n = (self.kernel_size - 1) / 2.0
        self.n_ = (
            2 * math.pi * torch.arange(-n, 0).view(1, -1) / self.sample_rate
        )
|
|
|
def _to_mel(self, hz): |
|
"""Converts frequency in Hz to the mel scale.""" |
|
return 2595 * np.log10(1 + hz / 700) |
|
|
|
def _to_hz(self, mel): |
|
"""Converts frequency in the mel scale to Hz.""" |
|
return 700 * (10 ** (mel / 2595) - 1) |
|
|
|
def _manage_padding(self, x, kernel_size: int, dilation: int, stride: int): |
|
"""This function performs zero-padding on the time axis |
|
such that their lengths is unchanged after the convolution. |
|
|
|
Arguments |
|
--------- |
|
x : torch.Tensor |
|
Input tensor. |
|
kernel_size : int |
|
Size of kernel. |
|
dilation : int |
|
Dilation used. |
|
stride : int |
|
Stride. |
|
|
|
Returns |
|
------- |
|
        x : torch.Tensor
            The padded outputs.
|
""" |
|
|
|
|
|
        # Reference length for the padding computation; note that with the
        # formulas in get_padding_elem the result does not depend on it
        L_in = self.in_channels

        # Computing and applying the time-axis padding
        padding = get_padding_elem(L_in, stride, kernel_size, dilation)
        x = F.pad(x, padding, mode=self.padding_mode)

        return x
|
|
|
|
|
class Conv1d(nn.Module): |
|
"""This function implements 1d convolution. |
|
|
|
Arguments |
|
--------- |
|
out_channels : int |
|
It is the number of output channels. |
|
kernel_size : int |
|
Kernel size of the convolutional filters. |
|
input_shape : tuple |
|
The shape of the input. Alternatively use ``in_channels``. |
|
in_channels : int |
|
The number of input channels. Alternatively use ``input_shape``. |
|
stride : int |
|
Stride factor of the convolutional filters. When the stride factor > 1, |
|
a decimation in time is performed. |
|
dilation : int |
|
Dilation factor of the convolutional filters. |
|
padding : str |
|
(same, valid, causal). If "valid", no padding is performed. |
|
If "same" and stride is 1, output shape is the same as the input shape. |
|
"causal" results in causal (dilated) convolutions. |
|
groups : int |
|
Number of blocked connections from input channels to output channels. |
|
bias : bool |
|
Whether to add a bias term to convolution operation. |
|
padding_mode : str |
|
This flag specifies the type of padding. See torch.nn documentation |
|
for more information. |
|
skip_transpose : bool |
|
If False, uses batch x time x channel convention of speechbrain. |
|
If True, uses batch x channel x time convention. |
|
    weight_norm : bool
        If True, uses weight normalization, to be removed with
        self.remove_weight_norm() at inference.
    conv_init : str
        Weight initialization for the convolution network
        ("kaiming", "zero", or "normal").
    default_padding : str or int
        Default padding passed to the underlying torch.nn.Conv1d
        (an int, or "same"/"valid" as supported by PyTorch).
|
|
|
Example |
|
------- |
|
>>> inp_tensor = torch.rand([10, 40, 16]) |
|
>>> cnn_1d = Conv1d( |
|
... input_shape=inp_tensor.shape, out_channels=8, kernel_size=5 |
|
... ) |
|
>>> out_tensor = cnn_1d(inp_tensor) |
|
>>> out_tensor.shape |
|
torch.Size([10, 40, 8]) |
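
    As a further sketch, "causal" padding pads only on the left, so each
    output frame depends only on current and past samples while the time
    length is preserved; with ``skip_transpose=True`` the layer instead
    expects and returns (batch, channel, time) tensors:

    >>> cnn_1d = Conv1d(
    ...     input_shape=inp_tensor.shape,
    ...     out_channels=8,
    ...     kernel_size=5,
    ...     padding="causal",
    ... )
    >>> cnn_1d(inp_tensor).shape
    torch.Size([10, 40, 8])
    >>> cnn_1d = Conv1d(
    ...     in_channels=16,
    ...     out_channels=8,
    ...     kernel_size=5,
    ...     skip_transpose=True,
    ... )
    >>> cnn_1d(inp_tensor.transpose(1, 2)).shape
    torch.Size([10, 8, 40])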
|
""" |
|
|
|
def __init__( |
|
self, |
|
out_channels, |
|
kernel_size, |
|
input_shape=None, |
|
in_channels=None, |
|
stride=1, |
|
dilation=1, |
|
padding="same", |
|
groups=1, |
|
bias=True, |
|
padding_mode="reflect", |
|
skip_transpose=False, |
|
weight_norm=False, |
|
conv_init=None, |
|
default_padding=0, |
|
): |
|
super().__init__() |
|
self.kernel_size = kernel_size |
|
self.stride = stride |
|
self.dilation = dilation |
|
self.padding = padding |
|
self.padding_mode = padding_mode |
|
self.unsqueeze = False |
|
self.skip_transpose = skip_transpose |
|
|
|
if input_shape is None and in_channels is None: |
|
raise ValueError("Must provide one of input_shape or in_channels") |
|
|
|
if in_channels is None: |
|
in_channels = self._check_input_shape(input_shape) |
|
|
|
self.in_channels = in_channels |
|
|
|
self.conv = nn.Conv1d( |
|
in_channels, |
|
out_channels, |
|
self.kernel_size, |
|
stride=self.stride, |
|
dilation=self.dilation, |
|
padding=default_padding, |
|
groups=groups, |
|
bias=bias, |
|
) |
|
|
|
if conv_init == "kaiming": |
|
nn.init.kaiming_normal_(self.conv.weight) |
|
elif conv_init == "zero": |
|
nn.init.zeros_(self.conv.weight) |
|
elif conv_init == "normal": |
|
nn.init.normal_(self.conv.weight, std=1e-6) |
|
|
|
if weight_norm: |
|
self.conv = nn.utils.weight_norm(self.conv) |
|
|
|
def forward(self, x): |
|
"""Returns the output of the convolution. |
|
|
|
Arguments |
|
--------- |
|
        x : torch.Tensor (batch, time, channel)
            Input to convolve. 2d or 3d tensors are expected, with the
            (batch, channel, time) layout if ``skip_transpose`` is True.
|
|
|
Returns |
|
------- |
|
wx : torch.Tensor |
|
The convolved outputs. |
|
""" |
|
        if not self.skip_transpose:
            # (batch, time, channel) -> (batch, channel, time)
            x = x.transpose(1, -1)
|
|
|
if self.unsqueeze: |
|
x = x.unsqueeze(1) |
|
|
|
if self.padding == "same": |
|
x = self._manage_padding( |
|
x, self.kernel_size, self.dilation, self.stride |
|
) |
|
|
|
elif self.padding == "causal": |
|
num_pad = (self.kernel_size - 1) * self.dilation |
|
x = F.pad(x, (num_pad, 0)) |
|
|
|
elif self.padding == "valid": |
|
pass |
|
|
|
else: |
|
raise ValueError( |
|
"Padding must be 'same', 'valid' or 'causal'. Got " |
|
+ self.padding |
|
) |
|
|
|
wx = self.conv(x) |
|
|
|
if self.unsqueeze: |
|
wx = wx.squeeze(1) |
|
|
|
if not self.skip_transpose: |
|
wx = wx.transpose(1, -1) |
|
|
|
return wx |
|
|
|
def _manage_padding(self, x, kernel_size: int, dilation: int, stride: int): |
|
"""This function performs zero-padding on the time axis |
|
such that their lengths is unchanged after the convolution. |
|
|
|
Arguments |
|
--------- |
|
x : torch.Tensor |
|
Input tensor. |
|
kernel_size : int |
|
Size of kernel. |
|
dilation : int |
|
Dilation used. |
|
stride : int |
|
Stride. |
|
|
|
Returns |
|
------- |
|
x : torch.Tensor |
|
The padded outputs. |
|
""" |
|
|
|
|
|
        # Reference length for the padding computation; note that with the
        # formulas in get_padding_elem the result does not depend on it
        L_in = self.in_channels

        # Computing and applying the time-axis padding
        padding = get_padding_elem(L_in, stride, kernel_size, dilation)
        x = F.pad(x, padding, mode=self.padding_mode)

        return x
|
|
|
def _check_input_shape(self, shape): |
|
"""Checks the input shape and returns the number of input channels.""" |
|
|
|
if len(shape) == 2: |
|
self.unsqueeze = True |
|
in_channels = 1 |
|
elif self.skip_transpose: |
|
in_channels = shape[1] |
|
elif len(shape) == 3: |
|
in_channels = shape[2] |
|
else: |
|
raise ValueError( |
|
"conv1d expects 2d, 3d inputs. Got " + str(len(shape)) |
|
) |
|
|
|
|
|
if not self.padding == "valid" and self.kernel_size % 2 == 0: |
|
raise ValueError( |
|
"The field kernel size must be an odd number. Got %s." |
|
% (self.kernel_size) |
|
) |
|
|
|
return in_channels |
|
|
|
def remove_weight_norm(self): |
|
"""Removes weight normalization at inference if used during training.""" |
|
self.conv = nn.utils.remove_weight_norm(self.conv) |
|
|
|
|
|
def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int): |
|
"""This function computes the number of elements to add for zero-padding. |
|
|
|
Arguments |
|
--------- |
|
    L_in : int
        Length of the input on the time axis.
    stride : int
        Stride of the convolution.
    kernel_size : int
        Size of the convolution kernel.
    dilation : int
        Dilation of the convolution.

    Returns
    -------
    padding : list
        [left_pad, right_pad], the number of elements to add on each side.
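
    Example
    -------
    A minimal sketch of both branches; the expected values follow directly
    from the formulas below:

    >>> get_padding_elem(L_in=100, stride=1, kernel_size=5, dilation=1)
    [2, 2]
    >>> get_padding_elem(L_in=100, stride=1, kernel_size=5, dilation=2)
    [4, 4]
    >>> get_padding_elem(L_in=100, stride=2, kernel_size=5, dilation=1)
    [2, 2]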
|
""" |
|
    if stride > 1:
        # With stride > 1, simply pad by half the kernel size on each side
        padding = [math.floor(kernel_size / 2), math.floor(kernel_size / 2)]

    else:
        # With stride == 1, compute the "same" output length and pad the
        # difference, split evenly between the two sides
        L_out = (
            math.floor((L_in - dilation * (kernel_size - 1) - 1) / stride) + 1
        )
        padding = [
            math.floor((L_in - L_out) / 2),
            math.floor((L_in - L_out) / 2),
        ]
|
return padding |
|
|
|
|