Spaces:
Running
on
Zero
Running
on
Zero
| from models.mossformer2_se.mossformer2 import MossFormer_MaskNet | |
| import torch.nn as nn | |
class MossFormer2_SE_48K(nn.Module):
    """
    The MossFormer2_SE_48K model for speech enhancement.

    This class is a thin wrapper: it builds one TestNet (which in turn
    holds the MossFormer MaskNet) and exposes it as ``self.model``.
    ``forward`` delegates to that network without post-processing.

    Arguments
    ---------
    args : Namespace
        Configuration arguments that may include hyperparameters
        and model settings. Accepted for interface compatibility; this
        implementation does not read it.

    Example
    ---------
    >>> model = MossFormer2_SE_48K(args).model
    >>> x = torch.randn(10, 180, 2000)  # Example input (layout: see forward)
    >>> out_list = model(x)             # Forward pass, returns a list
    >>> mask = out_list[-1]             # The predicted mask
    >>> mask.shape                      # Check output shape
    """

    def __init__(self, args):
        super(MossFormer2_SE_48K, self).__init__()
        # Initialize the TestNet model, which contains the MossFormer MaskNet.
        # `args` is intentionally unused: TestNet is built with its defaults.
        self.model = TestNet()

    def forward(self, x):
        """
        Forward pass through the wrapped TestNet.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor, passed to TestNet unchanged. TestNet documents it
            as [B, N, S] (batch, 180 channels, sequence length).
            NOTE(review): confirm the layout against MossFormer_MaskNet.

        Returns
        -------
        out_list : list
            The list returned by TestNet; its single element is the mask
            tensor predicted by the MossFormer MaskNet.
        """
        # TestNet.forward returns a one-element list ([mask]). The previous
        # `outputs, mask = self.model(x)` therefore always raised
        # "ValueError: not enough values to unpack". Hand the list back as-is
        # so this wrapper behaves exactly like calling `self.model` directly.
        return self.model(x)
class TestNet(nn.Module):
    """
    Test harness network built around a single MossFormer_MaskNet.

    The network swaps the last two dimensions of its input, runs the
    result through the MaskNet, and returns the predicted mask wrapped
    in a list.

    Arguments
    ---------
    n_layers : int
        Number of layers. The value is stored on the instance but is not
        used anywhere in this class at the moment.
    """

    def __init__(self, n_layers=18):
        super().__init__()
        # Kept for future use; nothing below reads it.
        self.n_layers = n_layers
        # MaskNet configured with fixed channel sizes.
        self.mossformer = MossFormer_MaskNet(
            in_channels=180,
            out_channels=512,
            out_channels_final=961,
        )

    def forward(self, input):
        """
        Run the MaskNet on the dimension-swapped input.

        Arguments
        ---------
        input : torch.Tensor
            3-D input tensor. The original authors document it as
            [B, N, S] with N = 180 input channels.
            NOTE(review): confirm which layout MossFormer_MaskNet expects.

        Returns
        -------
        out_list : list
            One-element list holding the mask tensor predicted by the
            MossFormer_MaskNet.
        """
        # Exchange dimensions 1 and 2 before handing the tensor to the MaskNet.
        swapped = input.transpose(1, 2)
        # Callers index into the result, so the mask stays wrapped in a list.
        return [self.mossformer(swapped)]