from ..filterbanks import make_enc_dec
from ..masknn import DPTransformer
from .base_models import BaseEncoderMaskerDecoder


class DPTNet(BaseEncoderMaskerDecoder):
| """DPTNet separation model, as described in [1]. | |
| Args: | |
| n_src (int): Number of masks to estimate. | |
| out_chan (int or None): Number of bins in the estimated masks. | |
| Defaults to `in_chan`. | |
| bn_chan (int): Number of channels after the bottleneck. | |
| Defaults to 128. | |
| hid_size (int): Number of neurons in the RNNs cell state. | |
| Defaults to 128. | |
| chunk_size (int): window size of overlap and add processing. | |
| Defaults to 100. | |
| hop_size (int or None): hop size (stride) of overlap and add processing. | |
| Default to `chunk_size // 2` (50% overlap). | |
| n_repeats (int): Number of repeats. Defaults to 6. | |
| norm_type (str, optional): Type of normalization to use. To choose from | |
| - ``'gLN'``: global Layernorm | |
| - ``'cLN'``: channelwise Layernorm | |
| mask_act (str, optional): Which non-linear function to generate mask. | |
| bidirectional (bool, optional): True for bidirectional Inter-Chunk RNN | |
| (Intra-Chunk is always bidirectional). | |
| rnn_type (str, optional): Type of RNN used. Choose between ``'RNN'``, | |
| ``'LSTM'`` and ``'GRU'``. | |
| num_layers (int, optional): Number of layers in each RNN. | |
| dropout (float, optional): Dropout ratio, must be in [0,1]. | |
| in_chan (int, optional): Number of input channels, should be equal to | |
| n_filters. | |
| fb_name (str, className): Filterbank family from which to make encoder | |
| and decoder. To choose among [``'free'``, ``'analytic_free'``, | |
| ``'param_sinc'``, ``'stft'``]. | |
| n_filters (int): Number of filters / Input dimension of the masker net. | |
| kernel_size (int): Length of the filters. | |
| stride (int, optional): Stride of the convolution. | |
| If None (default), set to ``kernel_size // 2``. | |
| **fb_kwargs (dict): Additional kwards to pass to the filterbank | |
| creation. | |
| References: | |
| [1]: Jingjing Chen et al. "Dual-Path Transformer Network: Direct | |
| Context-Aware Modeling for End-to-End Monaural Speech Separation" | |
| Interspeech 2020. | |
| """ | |
    def __init__(
        self,
        n_src,
        ff_hid=256,
        chunk_size=100,
        hop_size=None,
        n_repeats=6,
        norm_type="gLN",
        ff_activation="relu",
        encoder_activation="relu",
        mask_act="relu",
        bidirectional=True,
        dropout=0,
        in_chan=None,
        fb_name="free",
        kernel_size=16,
        n_filters=64,
        stride=8,
        **fb_kwargs,
    ):
        encoder, decoder = make_enc_dec(
            fb_name, kernel_size=kernel_size, n_filters=n_filters, stride=stride, **fb_kwargs
        )
        n_feats = encoder.n_feats_out
        if in_chan is not None:
            assert in_chan == n_feats, (
                "Number of filterbank output channels"
                " and number of input channels should "
                "be the same. Received "
                f"{n_feats} and {in_chan}"
            )
        # The masker takes the filterbank output dimension (n_feats) as its input size.
        masker = DPTransformer(
            n_feats,
            n_src,
            ff_hid=ff_hid,
            ff_activation=ff_activation,
            chunk_size=chunk_size,
            hop_size=hop_size,
            n_repeats=n_repeats,
            norm_type=norm_type,
            mask_act=mask_act,
            bidirectional=bidirectional,
            dropout=dropout,
        )
        super().__init__(encoder, masker, decoder, encoder_activation=encoder_activation)
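

# A minimal usage sketch, assuming this module lives inside the Asteroid package so
# that the relative imports above resolve, and that PyTorch is installed. The batch
# size, number of sources and waveform length below are illustrative only.
if __name__ == "__main__":
    import torch

    model = DPTNet(n_src=2, fb_name="free", kernel_size=16, n_filters=64, stride=8)
    mixture = torch.randn(4, 8000)  # (batch, time) mono waveform mixtures
    est_sources = model(mixture)  # expected shape: (batch, n_src, time)
    print(est_sources.shape)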