Spaces:
Running
Running
| # File under the MIT license, see https://github.com/adefossez/julius/LICENSE for details. | |
| # Author: adefossez, 2020 | |
| """ | |
| FIR windowed sinc lowpass filters. | |
| """ | |
| import math | |
| from typing import Sequence, Optional | |
| import torch | |
| from torch.nn import functional as F | |
| from .core import sinc | |
| from .fftconv import fft_conv1d | |
| from .utils import simple_repr | |
class LowPassFilters(torch.nn.Module):
    """
    Bank of low pass filters. Note that a high pass or band pass filter can easily
    be implemented by subtracting a same signal processed with low pass filters with different
    frequencies (see `julius.bands.SplitBands` for instance).
    This uses a windowed sinc filter, very similar to the one used in
    `julius.resample`. However, because we do not change the sample rate here,
    this filter can be much more efficiently implemented using the FFT convolution from
    `julius.fftconv`.

    Args:
        cutoffs (list[float]): list of cutoff frequencies, in [0, 0.5] expressed as `f/f_s` where
            f_s is the samplerate and `f` is the cutoff frequency.
            The upper limit is 0.5, because a signal sampled at `f_s` contains only
            frequencies under `f_s / 2`. At least one cutoff must be given.
        stride (int): how much to decimate the output. Keep in mind that decimation
            of the output is only acceptable if the cutoff frequency is under `1/ (2 * stride)`
            of the original sampling rate.
        pad (bool): if True, pad the input on both sides by `half_size` samples, replicating
            the edge values. If `stride=1`, the output will have the same length as the input.
        zeros (float): Number of zero crossings to keep.
            Controls the receptive field of the Finite Impulse Response filter.
            For lowpass filters with low cutoff frequency, e.g. 40Hz at 44.1kHz,
            it is a bad idea to set this to a high value.
            The default is likely appropriate for most use. Lower values
            will result in a faster filter, but with a slower attenuation around the
            cutoff frequency.
        fft (bool or None): if True, uses `julius.fftconv` rather than PyTorch convolutions.
            If False, uses PyTorch convolutions. If None, either one will be chosen automatically
            depending on the effective filter size.

    Raises:
        ValueError: if `cutoffs` is empty, or contains a value outside of [0, 0.5].

    ..warning::
        All the filters will use the same filter size, aligned on the lowest
        non-zero frequency provided. If you combine a lot of filters with very diverse
        frequencies, it might be more efficient to split them over multiple modules
        with similar frequencies.

    ..note::
        A lowpass with a cutoff frequency of 0 is defined as the null function
        by convention here. This allows for a highpass with a cutoff of 0 to
        be equal to identity, as defined in `julius.filters.HighPassFilters`.

    Shape:

        - Input: `[*, T]`
        - Output: `[F, *, T']`, with `T'=T` if `pad` is True and `stride` is 1, and
            `F` is the number of cutoff frequencies.

    >>> lowpass = LowPassFilters([1/4])
    >>> x = torch.randn(4, 12, 21, 1024)
    >>> list(lowpass(x).shape)
    [1, 4, 12, 21, 1024]
    """

    def __init__(self, cutoffs: Sequence[float], stride: int = 1, pad: bool = True,
                 zeros: float = 8, fft: Optional[bool] = None):
        super().__init__()
        # Materialize once: everything below reads `self.cutoffs`, so a one-shot
        # iterable passed as `cutoffs` is consumed exactly one time.
        self.cutoffs = list(cutoffs)
        if not self.cutoffs:
            raise ValueError("At least one cutoff frequency must be provided.")
        if min(self.cutoffs) < 0:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if max(self.cutoffs) > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")
        self.stride = stride
        self.pad = pad
        self.zeros = zeros
        # The filter size is driven by the lowest non-zero cutoff. When every cutoff
        # is zero all filters are null, so a single tap is enough.
        positive_cutoffs = [c for c in self.cutoffs if c > 0]
        if positive_cutoffs:
            self.half_size = int(zeros / min(positive_cutoffs) / 2)
        else:
            self.half_size = 0
        if fft is None:
            fft = self.half_size > 32
        self.fft = fft
        window = torch.hann_window(2 * self.half_size + 1, periodic=False)
        time = torch.arange(-self.half_size, self.half_size + 1)
        filters = []
        for cutoff in self.cutoffs:
            if cutoff == 0:
                # Use `window` as the template so the null filter is floating point;
                # `time` is an integer tensor.
                filter_ = torch.zeros_like(window)
            else:
                filter_ = 2 * cutoff * window * sinc(2 * cutoff * math.pi * time)
                # Normalize filter to have sum = 1, otherwise we will have a small leakage
                # of the constant component in the input signal.
                filter_ /= filter_.sum()
            filters.append(filter_)
        # Shape [F, 1, kernel_size], i.e. F output channels over one input channel.
        self.register_buffer("filters", torch.stack(filters)[:, None])

    def forward(self, input):
        """Apply every filter of the bank along the last dimension of `input`."""
        shape = list(input.shape)
        # Flatten all leading dimensions into a batch of single channel signals.
        # `reshape` (unlike `view`) also accepts non-contiguous inputs.
        input = input.reshape(-1, 1, shape[-1])
        if self.pad and self.half_size > 0:
            input = F.pad(input, (self.half_size, self.half_size), mode='replicate')
        if self.fft:
            out = fft_conv1d(input, self.filters, stride=self.stride)
        else:
            out = F.conv1d(input, self.filters, stride=self.stride)
        # Move the filter dimension to the front and restore the leading dimensions.
        shape.insert(0, len(self.cutoffs))
        shape[-1] = out.shape[-1]
        return out.permute(1, 0, 2).reshape(shape)

    def __repr__(self):
        return simple_repr(self)
class LowPassFilter(torch.nn.Module):
    """
    Single low pass filter, implemented as a `LowPassFilters` bank holding one cutoff.
    The extra leading filter dimension of the bank output is removed.

    Shape:

        - Input: `[*, T]`
        - Output: `[*, T']`, with `T'=T` if `pad` is True and `stride` is 1.

    >>> lowpass = LowPassFilter(1/4, stride=2)
    >>> x = torch.randn(4, 124)
    >>> list(lowpass(x).shape)
    [4, 62]
    """

    def __init__(self, cutoff: float, stride: int = 1, pad: bool = True,
                 zeros: float = 8, fft: Optional[bool] = None):
        super().__init__()
        # All the actual work is delegated to a one-element filter bank.
        self._lowpasses = LowPassFilters(
            [cutoff], stride=stride, pad=pad, zeros=zeros, fft=fft)

    # NOTE(review): the accessors below are plain methods, so `obj.cutoff` evaluates to a
    # bound method rather than the value -- confirm whether `@property` was intended.

    def cutoff(self):
        """Return the only cutoff frequency of the wrapped bank."""
        bank = self._lowpasses
        return bank.cutoffs[0]

    def stride(self):
        """Return the stride of the wrapped bank."""
        bank = self._lowpasses
        return bank.stride

    def pad(self):
        """Return the `pad` flag of the wrapped bank."""
        bank = self._lowpasses
        return bank.pad

    def zeros(self):
        """Return the `zeros` setting of the wrapped bank."""
        bank = self._lowpasses
        return bank.zeros

    def fft(self):
        """Return the `fft` flag resolved by the wrapped bank."""
        bank = self._lowpasses
        return bank.fft

    def forward(self, input):
        """Filter `input` and drop the leading filter dimension added by the bank."""
        banked = self._lowpasses(input)
        return banked[0]

    def __repr__(self):
        return simple_repr(self)
def lowpass_filters(input: torch.Tensor, cutoffs: Sequence[float],
                    stride: int = 1, pad: bool = True,
                    zeros: float = 8, fft: Optional[bool] = None):
    """
    Functional counterpart of `LowPassFilters`: builds the filter bank for `cutoffs`
    and immediately applies it to `input`. See `LowPassFilters` for the arguments.
    """
    bank = LowPassFilters(cutoffs, stride, pad, zeros, fft)
    # `Module.to(tensor)` aligns the filter buffers with the dtype and device of `input`.
    bank = bank.to(input)
    return bank(input)
def lowpass_filter(input: torch.Tensor, cutoff: float,
                   stride: int = 1, pad: bool = True,
                   zeros: float = 8, fft: Optional[bool] = None):
    """
    Single cutoff variant of `lowpass_filters`.
    The result has no extra dimension inserted in the front.
    """
    # Run a one-element bank, then strip the leading filter dimension.
    banked = lowpass_filters(input, [cutoff], stride, pad, zeros, fft)
    return banked[0]