import torch
from scipy.signal import get_window

# from asteroid_test.losses import PITLossWrapper
from torch import nn

'''
class LambdaOverlapAdd(torch.nn.Module):
    """Overlap-add with lambda transform on segments.

    Segment the input signal, apply a lambda function (a neural network,
    for example) to each segment and combine the results with OLA.

    Args:
        nnet (callable): Function to apply to each segment.
        n_src (int): Number of sources in the output of nnet.
        window_size (int): Size of segmenting window.
        hop_size (int): Segmentation hop size.
        window (str): Name of the window (see scipy.signal.get_window) used
            for the synthesis.
        reorder_chunks (bool): Whether to reorder each consecutive segment.
            This might be useful when `nnet` is permutation invariant, as
            source assignments might change output channel from one segment
            to the next (in classic speech separation, for example).
            Reordering is performed based on the correlation between
            the overlapped parts of consecutive segments.
        enable_grad (bool): Whether to run the forward pass with gradients
            enabled (defaults to False).

    Examples:
        >>> from asteroid_test import ConvTasNet
        >>> nnet = ConvTasNet(n_src=2)
        >>> continuous_nnet = LambdaOverlapAdd(
        >>>     nnet=nnet,
        >>>     n_src=2,
        >>>     window_size=64000,
        >>>     hop_size=None,
        >>>     window="hann",
        >>>     reorder_chunks=True,
        >>>     enable_grad=False,
        >>> )
        >>> wav = torch.randn(1, 1, 500000)
        >>> out_wavs = continuous_nnet.forward(wav)
    """
    def __init__(
        self,
        nnet,
        n_src,
        window_size,
        hop_size=None,
        window="hann",
        reorder_chunks=True,
        enable_grad=False,
    ):
        super().__init__()
        assert window_size % 2 == 0, "Window size must be even"
        self.nnet = nnet
        self.window_size = window_size
        self.hop_size = hop_size if hop_size is not None else window_size // 2
        self.n_src = n_src

        if window:
            window = get_window(window, self.window_size).astype("float32")
            window = torch.from_numpy(window)
            self.use_window = True
        else:
            self.use_window = False
        self.register_buffer("window", window)

        self.reorder_chunks = reorder_chunks
        self.enable_grad = enable_grad

    def ola_forward(self, x):
        """Heart of the class: segment signal, apply func, combine with OLA."""
        assert x.ndim == 3
        batch, channels, n_frames = x.size()
        # Segment the signal:
        # [batch, chans, n_frames] -> [batch, chans * win_size, n_chunks]
        unfolded = torch.nn.functional.unfold(
            x.unsqueeze(-1),
            kernel_size=(self.window_size, 1),
            padding=(self.window_size, 0),
            stride=(self.hop_size, 1),
        )

        out = []
        n_chunks = unfolded.shape[-1]
        for frame_idx in range(n_chunks):  # for loop to spare memory
            frame = self.nnet(unfolded[..., frame_idx])
            # user must handle multichannel by reshaping to batch
            if frame_idx == 0:
                assert frame.ndim == 3, "nnet should return (batch, n_src, time)"
                assert frame.shape[1] == self.n_src, "nnet should return (batch, n_src, time)"
            frame = frame.reshape(batch * self.n_src, -1)

            if frame_idx != 0 and self.reorder_chunks:
                # we determine the best permutation based on the cross-correlation
                # with the previous sources
                frame = _reorder_sources(
                    frame, out[-1], self.n_src, self.window_size, self.hop_size
                )

            if self.use_window:
                frame = frame * self.window
            else:
                frame = frame / (self.window_size / self.hop_size)
            out.append(frame)

        out = torch.stack(out).reshape(n_chunks, batch * self.n_src, self.window_size)
        out = out.permute(1, 2, 0)

        out = torch.nn.functional.fold(
            out,
            (n_frames, 1),
            kernel_size=(self.window_size, 1),
            padding=(self.window_size, 0),
            stride=(self.hop_size, 1),
        )
        return out.squeeze(-1).reshape(batch, self.n_src, -1)

    def forward(self, x):
        """Forward module: segment signal, apply func, combine with OLA.

        Args:
            x (:class:`torch.Tensor`): waveform signal of shape (batch, 1, time).

        Returns:
            :class:`torch.Tensor`: The output of the lambda OLA.
        """
        # Here we can do the reshaping
        with torch.autograd.set_grad_enabled(self.enable_grad):
            olad = self.ola_forward(x)
            return olad

def _reorder_sources(
    current: torch.FloatTensor,
    previous: torch.FloatTensor,
    n_src: int,
    window_size: int,
    hop_size: int,
):
    """Reorder sources in current chunk to maximize correlation with previous chunk.

    Used for Continuous Source Separation. Standard DSP correlation is used
    for reordering.

    Args:
        current (:class:`torch.Tensor`): current chunk, tensor
            of shape (batch, n_src, window_size).
        previous (:class:`torch.Tensor`): previous chunk, tensor
            of shape (batch, n_src, window_size).
        n_src (:class:`int`): number of sources.
        window_size (:class:`int`): window_size, equal to last dimension of
            both current and previous.
        hop_size (:class:`int`): hop_size between current and previous tensors.

    Returns:
        current (:class:`torch.Tensor`): reordered current chunk, same shape
            as the input.
    """
    batch, frames = current.size()
    current = current.reshape(-1, n_src, frames)
    previous = previous.reshape(-1, n_src, frames)

    overlap_f = window_size - hop_size

    def reorder_func(x, y):
        x = x[..., :overlap_f]
        y = y[..., -overlap_f:]
        # Mean normalization
        x = x - x.mean(-1, keepdim=True)
        y = y - y.mean(-1, keepdim=True)
        # Negative mean correlation
        return -torch.sum(x.unsqueeze(1) * y.unsqueeze(2), dim=-1)

    # We maximize the correlation-like measure between previous and current.
    pit = PITLossWrapper(reorder_func)
    current = pit(current, previous, return_est=True)[1]
    return current.reshape(batch, frames)
'''
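

# Minimal sketch, not part of the original module: `LambdaOverlapAdd.ola_forward`
# above rests on the fact that `torch.nn.functional.unfold` / `fold` with
# matching kernel, padding and stride implement a perfect overlap-add round
# trip once the result is divided by the overlap factor. A self-contained
# check (the helper name is illustrative only):
def _ola_identity_sketch(window_size=8, hop_size=4, n_frames=32):
    x = torch.arange(n_frames, dtype=torch.float32).view(1, 1, n_frames)
    # Segment: (batch, 1, n_frames) -> (batch, window_size, n_chunks)
    unfolded = torch.nn.functional.unfold(
        x.unsqueeze(-1),
        kernel_size=(window_size, 1),
        padding=(window_size, 0),
        stride=(hop_size, 1),
    )
    # Overlap-add the segments back to (batch, 1, n_frames, 1).
    folded = torch.nn.functional.fold(
        unfolded,
        (n_frames, 1),
        kernel_size=(window_size, 1),
        padding=(window_size, 0),
        stride=(hop_size, 1),
    )
    # Every valid sample was summed window_size / hop_size times.
    folded = folded.squeeze(-1) / (window_size / hop_size)
    assert torch.allclose(folded, x)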


class DualPathProcessing(nn.Module):
    """Perform dual-path processing via overlap-add, as in DPRNN [1].

    Args:
        chunk_size (int): Size of segmenting window.
        hop_size (int): Segmentation hop size.

    References:
        [1] "Dual-path RNN: efficient long sequence modeling for
        time-domain single-channel speech separation", Yi Luo, Zhuo Chen
        and Takuya Yoshioka. https://arxiv.org/abs/1910.06379
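
    Examples:
        >>> # Illustrative round trip through `unfold` and `fold`.
        >>> dsp = DualPathProcessing(chunk_size=100, hop_size=50)
        >>> x = torch.randn(4, 32, 10000)  # (batch, channels, time)
        >>> segmented = dsp.unfold(x)  # (batch, channels, chunk_size, n_chunks)
        >>> dsp.fold(segmented).shape  # back to (batch, channels, time)
        torch.Size([4, 32, 10000])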
""" | |

    def __init__(self, chunk_size, hop_size):
        super(DualPathProcessing, self).__init__()
        self.chunk_size = chunk_size
        self.hop_size = hop_size
        self.n_orig_frames = None

    def unfold(self, x):
        """Unfold the feature tensor from (batch, channels, time) to
        (batch, channels, chunk_size, n_chunks).

        Args:
            x (:class:`torch.Tensor`): feature tensor of shape (batch, channels, time).

        Returns:
            :class:`torch.Tensor`: spliced feature tensor of shape
                (batch, channels, chunk_size, n_chunks).
        """
        assert x.ndim == 3
        # x is (batch, chan, frames)
        batch, chan, frames = x.size()
        self.n_orig_frames = x.shape[-1]
        unfolded = torch.nn.functional.unfold(
            x.unsqueeze(-1),
            kernel_size=(self.chunk_size, 1),
            padding=(self.chunk_size, 0),
            stride=(self.hop_size, 1),
        )
        return unfolded.reshape(
            batch, chan, self.chunk_size, -1
        )  # (batch, chan, chunk_size, n_chunks)

    def fold(self, x, output_size=None):
        """Folds back the spliced feature tensor.

        Input shape (batch, channels, chunk_size, n_chunks) to original shape
        (batch, channels, time) using overlap-add.

        Args:
            x (:class:`torch.Tensor`): spliced feature tensor of shape
                (batch, channels, chunk_size, n_chunks).
            output_size (int, optional): sequence length of original feature tensor.
                If None, the original length cached by the previous call of `unfold`
                will be used.

        Returns:
            :class:`torch.Tensor`: feature tensor of shape (batch, channels, time).

        .. note:: `fold` relies on the original length cached by the previous
            call to `unfold`, unless `output_size` is given.
        """
        output_size = output_size if output_size is not None else self.n_orig_frames
        # x is (batch, chan, chunk_size, n_chunks)
        batch, chan, chunk_size, n_chunks = x.size()
        to_unfold = x.reshape(batch, chan * self.chunk_size, n_chunks)
        x = torch.nn.functional.fold(
            to_unfold,
            (output_size, 1),
            kernel_size=(self.chunk_size, 1),
            padding=(self.chunk_size, 0),
            stride=(self.hop_size, 1),
        )
        # Normalize the overlap-add: `fold` summed every frame
        # chunk_size / hop_size times.
        x /= self.chunk_size / self.hop_size

        return x.reshape(batch, chan, output_size)

    @staticmethod
    def intra_process(x, module):
        """Performs intra-chunk processing.

        Args:
            x (:class:`torch.Tensor`): spliced feature tensor of shape
                (batch, channels, chunk_size, n_chunks).
            module (:class:`torch.nn.Module`): module one wishes to apply to each
                chunk of the spliced feature tensor.

        Returns:
            :class:`torch.Tensor`: processed spliced feature tensor of shape
                (batch, channels, chunk_size, n_chunks).

        .. note:: the module should have the channel first convention and accept
            a 3D tensor of shape (batch, channels, time).
        """
        # x is (batch, channels, chunk_size, n_chunks)
        batch, channels, chunk_size, n_chunks = x.size()
        # Reshape to (batch * n_chunks, channels, chunk_size) so `module`
        # sees each chunk as an independent channel-first sequence.
        x = x.transpose(1, -1).reshape(batch * n_chunks, chunk_size, channels).transpose(1, -1)
        x = module(x)
        x = x.reshape(batch, n_chunks, channels, chunk_size).transpose(1, -1).transpose(1, 2)
        return x

    @staticmethod
    def inter_process(x, module):
        """Performs inter-chunk processing.

        Args:
            x (:class:`torch.Tensor`): spliced feature tensor of shape
                (batch, channels, chunk_size, n_chunks).
            module (:class:`torch.nn.Module`): module one wishes to apply between
                the chunks of the spliced feature tensor.

        Returns:
            :class:`torch.Tensor`: processed spliced feature tensor of shape
                (batch, channels, chunk_size, n_chunks).

        .. note:: the module should have the channel first convention and accept
            a 3D tensor of shape (batch, channels, time).
        """
        batch, channels, chunk_size, n_chunks = x.size()
        # Reshape to (batch * chunk_size, channels, n_chunks) so `module`
        # processes the sequence of chunks at each intra-chunk position.
        x = x.transpose(1, 2).reshape(batch * chunk_size, channels, n_chunks)
        x = module(x)
        x = x.reshape(batch, chunk_size, channels, n_chunks).transpose(1, 2)
        return x
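

# Hedged usage sketch, added for illustration (not part of the original file):
# run the full dual-path loop with `nn.Identity` standing in for the intra-
# and inter-chunk networks, and check that `fold` inverts `unfold` once the
# overlap-add normalization is applied.
if __name__ == "__main__":
    dsp = DualPathProcessing(chunk_size=100, hop_size=50)
    feats = torch.randn(2, 64, 10000)  # (batch, channels, time)
    chunks = dsp.unfold(feats)  # (2, 64, 100, n_chunks)
    # An intra-chunk model would process each chunk along chunk_size...
    chunks = DualPathProcessing.intra_process(chunks, nn.Identity())
    # ...and an inter-chunk model would process across n_chunks.
    chunks = DualPathProcessing.inter_process(chunks, nn.Identity())
    out = dsp.fold(chunks)  # (2, 64, 10000)
    assert out.shape == feats.shape
    assert torch.allclose(out, feats, atol=1e-6)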