import torch.nn as nn
import torch.nn.functional as F
from torchaudio.models import wav2vec2_model
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
# Basic Conv1d -> BatchNorm1d -> ReLU block.
def conv_block(n_input, n_output, stride=1, kernel_size=80):
    layers = []
    if stride == 1:
        # 'same' padding preserves the temporal length (only valid when stride is 1).
        layers.append(nn.Conv1d(n_input, n_output, kernel_size=kernel_size, stride=stride, padding='same'))
    else:
        # Strided convolutions use no padding, so the sequence is downsampled.
        layers.append(nn.Conv1d(n_input, n_output, kernel_size=kernel_size, stride=stride))
    layers.append(nn.BatchNorm1d(n_output))
    layers.append(nn.ReLU())
    return nn.Sequential(*layers)
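# Example with hypothetical values: conv_block(1, 32, stride=16, kernel_size=80)
# maps a (batch, 1, 16000) waveform to (batch, 32, 996), since without padding
# floor((16000 - 80) / 16) + 1 = 996; with stride=1 the length is unchanged.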
# Basic two-convolution residual block with an identity skip connection.
class ResidualBlock(nn.Module):
    def __init__(self, n_channels, kernel_size):
        super().__init__()
        # Both convolutions use stride=1 ('same' padding), so the temporal length
        # is preserved and the skip connection can be added elementwise.
        self.conv_block1 = conv_block(n_channels, n_channels, stride=1, kernel_size=kernel_size)
        self.conv_block2 = conv_block(n_channels, n_channels, stride=1, kernel_size=3)
def forward(self, x):
identity = x
x = self.conv_block1(x)
x = self.conv_block2(x)
x = x + identity
return x
class ResNetRagaClassifier(nn.Module):
def __init__(self, params):
super().__init__()
n_input = params.n_input
n_channel = params.n_channel
stride = params.stride
self.n_blocks = params.n_blocks
        # Initial strided convolution over the raw waveform.
        self.conv_first = conv_block(n_input, n_channel, stride=stride, kernel_size=80)
        self.max_pool_every = params.max_pool_every
        # Stack of residual blocks; the channel count stays constant throughout.
        self.res_blocks = nn.ModuleList()
        for _ in range(self.n_blocks):
            self.res_blocks.append(ResidualBlock(n_channel, kernel_size=3))
        # Linear classification head applied after global average pooling over time.
        self.fc1 = nn.Linear(n_channel, params.num_classes)
def forward(self, x):
#initial conv
x = self.conv_first(x)
#residual blocks
for i, block in enumerate(self.res_blocks):
x = block(x)
if i % self.max_pool_every == 0:
x = F.max_pool1d(x, 2)
#classification head
x = F.avg_pool1d(x, x.shape[-1])
x = x.permute(0, 2, 1)
x = self.fc1(x)
x = F.log_softmax(x, dim=-1)
return x
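# Minimal usage sketch for ResNetRagaClassifier; the params values below are
# illustrative placeholders for the expected attributes, not the trained configuration:
#
#   from types import SimpleNamespace
#   import torch
#   params = SimpleNamespace(n_input=1, n_channel=32, stride=16, n_blocks=4,
#                            max_pool_every=2, num_classes=10)
#   model = ResNetRagaClassifier(params)
#   log_probs = model(torch.randn(8, 1, 16000))   # -> shape (8, 1, 10)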
class BaseRagaClassifier(nn.Module):
def __init__(self, params):
super().__init__()
n_input = params.n_input
n_channel = params.n_channel
stride = params.stride
self.conv_block1 = conv_block(n_input, n_channel, stride=stride, kernel_size=80)
self.conv_block2 = conv_block(n_channel, n_channel, stride=1, kernel_size=3)
self.conv_block3 = conv_block(n_channel, 2*n_channel, stride=1, kernel_size=3)
self.conv_block4 = conv_block(2*n_channel, 2*n_channel, stride=1, kernel_size=3)
self.fc1 = nn.Linear(2 * n_channel, params.num_classes)
def forward(self, x):
x = self.conv_block1(x)
x = F.max_pool1d(x, 4)
x = self.conv_block2(x)
x = F.max_pool1d(x, 4)
x = self.conv_block3(x)
x = F.max_pool1d(x, 4)
x = self.conv_block4(x)
x = F.avg_pool1d(x, x.shape[-1])
x = x.permute(0, 2, 1)
x = self.fc1(x)
x = F.log_softmax(x, dim=-1)
return x
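# Shape sketch for BaseRagaClassifier with a hypothetical 16 kHz, 1 s input and stride=16:
# (B, 1, 16000) -> conv_block1 -> (B, C, 996)  -> max_pool(4) -> (B, C, 249)
#               -> conv_block2 -> (B, C, 249)  -> max_pool(4) -> (B, C, 62)
#               -> conv_block3 -> (B, 2C, 62)  -> max_pool(4) -> (B, 2C, 15)
#               -> conv_block4 -> (B, 2C, 15)  -> avg pool + permute -> (B, 1, 2C)
#               -> fc1 + log_softmax -> (B, 1, num_classes)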
class Wav2VecTransformer(nn.Module):
def __init__(self, params):
super().__init__()
self.params = params
self.extractor_mode = params.extractor_mode
self.extractor_conv_layer_config = params.extractor_conv_layer_config
self.extractor_conv_bias = params.extractor_conv_bias
self.encoder_embed_dim = params.encoder_embed_dim
self.encoder_projection_dropout = params.encoder_projection_dropout
self.encoder_pos_conv_kernel = params.encoder_pos_conv_kernel
self.encoder_pos_conv_groups = params.encoder_pos_conv_groups
self.encoder_num_layers = params.encoder_num_layers
self.encoder_num_heads = params.encoder_num_heads
self.encoder_attention_dropout = params.encoder_attention_dropout
self.encoder_ff_interm_features = params.encoder_ff_interm_features
self.encoder_ff_interm_dropout = params.encoder_ff_interm_dropout
self.encoder_dropout = params.encoder_dropout
self.encoder_layer_norm_first = params.encoder_layer_norm_first
self.encoder_layer_drop = params.encoder_layer_drop
        self.aux_num_out = params.num_classes  # stored for reference; the encoder gets aux_num_out=None and a separate head is used
        # Feature-extractor conv layers as (out_channels, kernel_size, stride);
        # this hardcoded config overrides params.extractor_conv_layer_config set above.
        self.extractor_conv_layer_config = [
            (32, 80, 16),
            (64, 5, 4),
            (128, 5, 4),
            (256, 5, 4),
            (512, 3, 2),
            (512, 2, 2),
            (512, 2, 2),
        ]
        self.encoder = wav2vec2_model(
            extractor_mode=self.extractor_mode,
            extractor_conv_layer_config=self.extractor_conv_layer_config,
            extractor_conv_bias=self.extractor_conv_bias,
            encoder_embed_dim=self.encoder_embed_dim,
            encoder_projection_dropout=self.encoder_projection_dropout,
            encoder_pos_conv_kernel=self.encoder_pos_conv_kernel,
            encoder_pos_conv_groups=self.encoder_pos_conv_groups,
            encoder_num_layers=self.encoder_num_layers,
            encoder_num_heads=self.encoder_num_heads,
            encoder_attention_dropout=self.encoder_attention_dropout,
            encoder_ff_interm_features=self.encoder_ff_interm_features,
            encoder_ff_interm_dropout=self.encoder_ff_interm_dropout,
            encoder_dropout=self.encoder_dropout,
            encoder_layer_norm_first=self.encoder_layer_norm_first,
            encoder_layer_drop=self.encoder_layer_drop,
            aux_num_out=None,
        )
        self.audio_length = params.sample_rate * params.clip_length
        # 16*4*4*4*2*2*2 = 8192 is the product of the feature-extractor strides, so
        # audio_length // 8192 approximates the number of encoder output frames.
        self.classification_head = nn.Linear(
            (self.audio_length // (16 * 4 * 4 * 4 * 2 * 2 * 2)) * params.encoder_embed_dim,
            params.num_classes,
        )
def forward(self, x):
        x = self.encoder(x)[0]  # Wav2Vec2Model returns (features, lengths); keep the features
        x = x.reshape(x.shape[0], -1)  # flatten (frames, embed_dim) into a single feature vector
x = self.classification_head(x)
x = F.log_softmax(x, dim=-1)
return x
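
if __name__ == "__main__":
    # Hypothetical smoke test: every hyperparameter value below is an illustrative
    # assumption, not the configuration used by this Space.
    from types import SimpleNamespace
    import torch

    params = SimpleNamespace(
        extractor_mode="group_norm",
        extractor_conv_layer_config=None,  # overridden by the hardcoded config in __init__
        extractor_conv_bias=False,
        encoder_embed_dim=256,
        encoder_projection_dropout=0.1,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=2,
        encoder_num_heads=4,
        encoder_attention_dropout=0.1,
        encoder_ff_interm_features=512,
        encoder_ff_interm_dropout=0.1,
        encoder_dropout=0.1,
        encoder_layer_norm_first=False,
        encoder_layer_drop=0.1,
        num_classes=10,
        sample_rate=16000,
        clip_length=10,
    )
    model = Wav2VecTransformer(params)
    waveform = torch.randn(2, params.sample_rate * params.clip_length)  # (batch, samples)
    print(model(waveform).shape)  # expected: torch.Size([2, 10])
    print(f"trainable parameters: {count_parameters(model):,}")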