import torch.nn as nn
import torch.nn.functional as F

from torchaudio.models import wav2vec2_model


def count_parameters(model):
    # Count only trainable parameters (those with requires_grad=True).
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def conv_block(n_input, n_output, stride=1, kernel_size=80):
    # Conv1d -> BatchNorm1d -> ReLU. padding='same' (which preserves the
    # time axis) is only valid for stride == 1; strided convolutions are
    # left unpadded and shrink the sequence.
    layers = []
    if stride == 1:
        layers.append(nn.Conv1d(n_input, n_output, kernel_size=kernel_size, stride=stride, padding='same'))
    else:
        layers.append(nn.Conv1d(n_input, n_output, kernel_size=kernel_size, stride=stride))
    layers.append(nn.BatchNorm1d(n_output))
    layers.append(nn.ReLU())
    return nn.Sequential(*layers)
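

# Quick shape check for conv_block; an illustrative sketch only, with assumed
# sizes (not values used by the models below). For stride > 1 the output
# length is floor((length - kernel_size) / stride) + 1.
def _demo_conv_block():
    import torch
    block = conv_block(1, 32, stride=16, kernel_size=80)
    x = torch.randn(2, 1, 8000)   # (batch, channels, samples)
    print(block(x).shape)         # torch.Size([2, 32, 496])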


class ResidualBlock(nn.Module):
    # Two conv blocks with a skip connection. Both use stride 1 and
    # padding='same', so channel count and sequence length are preserved
    # and the residual addition is shape-safe.
    def __init__(self, n_channels, kernel_size):
        super().__init__()
        self.conv_block1 = conv_block(n_channels, n_channels, stride=1, kernel_size=kernel_size)
        self.conv_block2 = conv_block(n_channels, n_channels, stride=1, kernel_size=3)

    def forward(self, x):
        identity = x
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = x + identity
        return x
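

# Shape-preservation check for ResidualBlock (illustrative, assumed sizes).
def _demo_residual_block():
    import torch
    block = ResidualBlock(32, kernel_size=3)
    x = torch.randn(2, 32, 500)
    print(block(x).shape)   # torch.Size([2, 32, 500]) -- unchanged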


class ResNetRagaClassifier(nn.Module):
    def __init__(self, params):
        super().__init__()
        n_input = params.n_input
        n_channel = params.n_channel
        stride = params.stride
        self.n_blocks = params.n_blocks
        # A wide strided conv first, to downsample the raw waveform.
        self.conv_first = conv_block(n_input, n_channel, stride=stride, kernel_size=80)
        self.max_pool_every = params.max_pool_every

        self.res_blocks = nn.ModuleList()
        for i in range(self.n_blocks):
            self.res_blocks.append(ResidualBlock(n_channel, kernel_size=3))

        self.fc1 = nn.Linear(n_channel, params.num_classes)

    def forward(self, x):
        x = self.conv_first(x)

        # Halve the time axis every max_pool_every blocks (including after
        # the first block, since i starts at 0).
        for i, block in enumerate(self.res_blocks):
            x = block(x)
            if i % self.max_pool_every == 0:
                x = F.max_pool1d(x, 2)

        # Global average pooling over time, then a linear classifier;
        # output shape is (batch, 1, num_classes).
        x = F.avg_pool1d(x, x.shape[-1])
        x = x.permute(0, 2, 1)
        x = self.fc1(x)
        x = F.log_softmax(x, dim=-1)

        return x
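

# A hedged usage sketch: the hyperparameters below are assumptions chosen so
# the shapes work out, not values from this project's config.
def _demo_resnet_classifier():
    from types import SimpleNamespace
    import torch
    params = SimpleNamespace(n_input=1, n_channel=32, stride=16,
                             n_blocks=4, max_pool_every=2, num_classes=10)
    model = ResNetRagaClassifier(params)
    x = torch.randn(2, 1, 160_000)   # e.g. two 10 s clips at 16 kHz
    print(model(x).shape)            # torch.Size([2, 1, 10])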


class BaseRagaClassifier(nn.Module):
    # A plain four-block CNN baseline (no residual connections).
    def __init__(self, params):
        super().__init__()
        n_input = params.n_input
        n_channel = params.n_channel
        stride = params.stride

        self.conv_block1 = conv_block(n_input, n_channel, stride=stride, kernel_size=80)
        self.conv_block2 = conv_block(n_channel, n_channel, stride=1, kernel_size=3)
        self.conv_block3 = conv_block(n_channel, 2 * n_channel, stride=1, kernel_size=3)
        self.conv_block4 = conv_block(2 * n_channel, 2 * n_channel, stride=1, kernel_size=3)
        self.fc1 = nn.Linear(2 * n_channel, params.num_classes)

    def forward(self, x):
        x = self.conv_block1(x)
        x = F.max_pool1d(x, 4)
        x = self.conv_block2(x)
        x = F.max_pool1d(x, 4)
        x = self.conv_block3(x)
        x = F.max_pool1d(x, 4)
        x = self.conv_block4(x)
        # Global average pooling over time, then a linear classifier;
        # output shape is (batch, 1, num_classes).
        x = F.avg_pool1d(x, x.shape[-1])
        x = x.permute(0, 2, 1)
        x = self.fc1(x)
        x = F.log_softmax(x, dim=-1)
        return x
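

# A hedged sketch comparing model sizes with count_parameters; the parameter
# values are the same assumptions as in _demo_resnet_classifier.
def _demo_count_parameters():
    from types import SimpleNamespace
    params = SimpleNamespace(n_input=1, n_channel=32, stride=16,
                             n_blocks=4, max_pool_every=2, num_classes=10)
    for cls in (BaseRagaClassifier, ResNetRagaClassifier):
        print(cls.__name__, count_parameters(cls(params)))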


class Wav2VecTransformer(nn.Module):
    def __init__(self, params):
        super().__init__()
        self.params = params
        self.extractor_mode = params.extractor_mode
        self.extractor_conv_bias = params.extractor_conv_bias
        self.encoder_embed_dim = params.encoder_embed_dim
        self.encoder_projection_dropout = params.encoder_projection_dropout
        self.encoder_pos_conv_kernel = params.encoder_pos_conv_kernel
        self.encoder_pos_conv_groups = params.encoder_pos_conv_groups
        self.encoder_num_layers = params.encoder_num_layers
        self.encoder_num_heads = params.encoder_num_heads
        self.encoder_attention_dropout = params.encoder_attention_dropout
        self.encoder_ff_interm_features = params.encoder_ff_interm_features
        self.encoder_ff_interm_dropout = params.encoder_ff_interm_dropout
        self.encoder_dropout = params.encoder_dropout
        self.encoder_layer_norm_first = params.encoder_layer_norm_first
        self.encoder_layer_drop = params.encoder_layer_drop

        # Feature-extractor config as (out_channels, kernel_size, stride)
        # triples; hardcoded here, overriding whatever params carries.
        self.extractor_conv_layer_config = [
            (32, 80, 16),
            (64, 5, 4),
            (128, 5, 4),
            (256, 5, 4),
            (512, 3, 2),
            (512, 2, 2),
            (512, 2, 2),
        ]
        self.encoder = wav2vec2_model(
            extractor_mode=self.extractor_mode,
            extractor_conv_layer_config=self.extractor_conv_layer_config,
            extractor_conv_bias=self.extractor_conv_bias,
            encoder_embed_dim=self.encoder_embed_dim,
            encoder_projection_dropout=self.encoder_projection_dropout,
            encoder_pos_conv_kernel=self.encoder_pos_conv_kernel,
            encoder_pos_conv_groups=self.encoder_pos_conv_groups,
            encoder_num_layers=self.encoder_num_layers,
            encoder_num_heads=self.encoder_num_heads,
            encoder_attention_dropout=self.encoder_attention_dropout,
            encoder_ff_interm_features=self.encoder_ff_interm_features,
            encoder_ff_interm_dropout=self.encoder_ff_interm_dropout,
            encoder_dropout=self.encoder_dropout,
            encoder_layer_norm_first=self.encoder_layer_norm_first,
            encoder_layer_drop=self.encoder_layer_drop,
            # No auxiliary head: classification is done by the linear
            # layer below, on the flattened encoder output.
            aux_num_out=None,
        )

        # 16*4*4*4*2*2*2 = 8192 is the product of the extractor strides, so
        # audio_length // 8192 approximates the number of encoder frames.
        self.audio_length = params.sample_rate * params.clip_length
        self.classification_head = nn.Linear(
            int(self.audio_length / (16 * 4 * 4 * 4 * 2 * 2 * 2)) * params.encoder_embed_dim,
            params.num_classes,
        )

    def forward(self, x):
        # wav2vec2_model takes (batch, samples) and returns (features, lengths);
        # keep the features and flatten all frames into one vector per clip.
        x = self.encoder(x)[0]
        x = x.reshape(x.shape[0], -1)
        x = self.classification_head(x)
        x = F.log_softmax(x, dim=-1)
        return x
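

# A hedged end-to-end smoke test. Every value below is an assumption picked so
# the tensor shapes are consistent (e.g. 10 s clips at 16 kHz give 160000
# samples, and 160000 // 8192 = 19 encoder frames), not this project's actual
# configuration.
if __name__ == "__main__":
    from types import SimpleNamespace
    import torch

    _demo_conv_block()
    _demo_residual_block()
    _demo_resnet_classifier()
    _demo_count_parameters()

    w2v_params = SimpleNamespace(
        extractor_mode="group_norm",
        extractor_conv_layer_config=None,  # overridden in __init__ anyway
        extractor_conv_bias=False,
        encoder_embed_dim=256,
        encoder_projection_dropout=0.1,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=4,
        encoder_num_heads=4,
        encoder_attention_dropout=0.1,
        encoder_ff_interm_features=1024,
        encoder_ff_interm_dropout=0.1,
        encoder_dropout=0.1,
        encoder_layer_norm_first=False,
        encoder_layer_drop=0.1,
        num_classes=10,
        sample_rate=16000,
        clip_length=10,
    )
    model = Wav2VecTransformer(w2v_params)
    waveforms = torch.randn(2, 160_000)   # (batch, samples)
    print(count_parameters(model), model(waveforms).shape)  # ..., torch.Size([2, 10])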