jeevster committed · Commit 64094d4 · 1 Parent(s): 08771d5

huggingface space main commit
Browse files
- .gitignore +1 -0
- about.md +12 -0
- app.py +57 -0
- ckpts/resnet_0.7/150classes_alldata_cliplength30/training_checkpoints/best_ckpt.tar +3 -0
- config.yaml +130 -0
- data/dataloader.py +143 -0
- inference.py +117 -0
- labeled_0.7_wav_metadata.json +0 -0
- metadata_0.7.json +530 -0
- models/RagaNet.py +164 -0
- requirements.txt +3 -0
- site/tsne.jpeg +0 -0
- utils/YParams.py +47 -0
- utils/logging_utils.py +32 -0
.gitignore
ADDED
@@ -0,0 +1 @@
__pycache__/
about.md
ADDED
@@ -0,0 +1,12 @@
### About the Classifier

The classifier is a [convolutional neural network](https://en.wikipedia.org/wiki/Convolutional_neural_network) trained on over 10,000 hours of Carnatic audio sourced from this incredible [YouTube collection](https://ramanarunachalam.github.io/Music/Carnatic/carnatic.html).

### Key Features:

- Can identify **150 ragas**
- Does not require any information about the **shruthi (tonic pitch)** of the recording.
- **Compatible** with male/female vocal or instrumental recordings.

### Interpreting the Classifier:

We can gain an intuitive sense of what the classifier has learned. Here is a [t-SNE](https://en.wikipedia.org/wiki/T-distributed_stochastic_neighbor_embedding) projection of the hidden activations averaged per ragam. Each point is a ragam, and the relative distances between points indicate the degree to which the classifier considers the ragas similar. Each ragam is color-coded by the [melakartha chakra](https://en.wikipedia.org/wiki/Melakarta#Chakras) it belongs to. We observe that the classifier has learned a representation that roughly corresponds to these chakras!
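Note: no script in this commit generates the t-SNE figure. A minimal sketch of how such a plot could be produced with scikit-learn, assuming you have already collected per-clip hidden activations (`embeddings`) and their per-clip ragam names (`ragas`), both hypothetical inputs:

```python
# Sketch: per-ragam averaged hidden activations -> 2-D t-SNE coordinates.
# `embeddings` (n_clips, d) and `ragas` (n_clips,) are assumed inputs,
# not produced by any function in this repo.
import numpy as np
from sklearn.manifold import TSNE

def tsne_per_raga(embeddings, ragas):
    names = sorted(set(ragas))
    # average the clip embeddings belonging to each ragam
    means = np.stack([embeddings[np.array(ragas) == r].mean(axis=0) for r in names])
    # project the per-ragam means to 2-D for plotting
    coords = TSNE(n_components=2, perplexity=30, random_state=123).fit_transform(means)
    return dict(zip(names, coords))
```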
app.py
ADDED
@@ -0,0 +1,57 @@
import os
import argparse

import torch
import gradio as gr

from inference import Evaluator
from utils.YParams import YParams


def read_markdown_file(path):
    with open(path, 'r', encoding='utf-8') as file:
        return file.read()


if __name__ == '__main__':

    #parse args
    parser = argparse.ArgumentParser()
    parser.add_argument("--yaml_config", default='config.yaml', type=str)
    parser.add_argument("--config", default='resnet_0.7', type=str)

    args = parser.parse_args()
    params = YParams(os.path.abspath(args.yaml_config), args.config)

    #GPU stuff: fall back to CPU when no CUDA device is available
    try:
        params.device = torch.device(torch.cuda.current_device())
    except Exception:
        params.device = "cpu"

    #checkpoint stuff
    expDir = "ckpts/resnet_0.7/150classes_alldata_cliplength30"
    params['checkpoint_path'] = os.path.join(expDir, 'training_checkpoints/ckpt.tar')
    params['best_checkpoint_path'] = os.path.join(expDir, 'training_checkpoints/best_ckpt.tar')

    evaluator = Evaluator(params)

    with gr.Blocks() as demo:
        with gr.Tab("Classifier"):
            gr.Interface(
                title="Carnatic Raga Classifier",
                description="**Welcome!** This is a deep-learning based raga classifier. Simply upload or record an audio clip to test it out. \n",
                article="**Get in Touch:** Feel free to reach out to [me](https://sanjeevraja.com/) via email (sanjeevr AT berkeley DOT edu) with any questions or feedback! ",
                fn=evaluator.inference,
                inputs=[
                    gr.Slider(minimum=1, maximum=150, value=5, label="Number of displayed ragas", info="Choose number of top predictions to display"),
                    gr.Audio()
                ],
                outputs="label",
                allow_flagging="never"  #boolean False is deprecated; "never" disables flagging
            )

        with gr.Tab("About"):
            gr.Markdown(read_markdown_file('about.md'))
            gr.Image('site/tsne.jpeg', height=800, width=800)

    demo.launch()
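Note: `gr.Audio()` hands `fn` a `(sample_rate, numpy_array)` tuple, so `Evaluator.inference` can also be driven outside Gradio. A minimal sketch, assuming the checkpoint, config, and metadata files are in place and that a local `test.wav` exists; `soundfile` is a hypothetical extra dependency, not listed in requirements.txt:

```python
# Sketch: call the classifier directly, bypassing the Gradio UI.
import soundfile as sf  # assumed installed; any (samples, rate) reader works
from inference import Evaluator
from utils.YParams import YParams

params = YParams('config.yaml', 'resnet_0.7')
params['device'] = 'cpu'
params['best_checkpoint_path'] = 'ckpts/resnet_0.7/150classes_alldata_cliplength30/training_checkpoints/best_ckpt.tar'
evaluator = Evaluator(params)

samples, rate = sf.read('test.wav')             # samples: (n,) mono or (n, 2) stereo
top5 = evaluator.inference(5, (rate, samples))  # same tuple layout gr.Audio() produces
print(top5)                                     # {raga_name: probability, ...}
```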
ckpts/resnet_0.7/150classes_alldata_cliplength30/training_checkpoints/best_ckpt.tar
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:09844cbcfc6c98af632671ca14dd4878fe6656811cdb77097a847c8b324362ca
size 66309523
config.yaml
ADDED
@@ -0,0 +1,130 @@
default: &DEFAULT

  #data
  exp_dir: carnatic/ckpts
  metadata_labeled_path: labeled_small_wav_metadata.json
  metadata_unlabeled_path: unlabeled_data/unlabeled_mp3_metadata.json
  num_files_per_raga_path: metadata_small.json
  num_classes: 150
  use_frac: 1
  use_unlabeled_data: !!bool False
  labeled_data_dir: labeled_data_small
  clip_length: 30
  sample_rate: 8000
  normalize: !!bool True

  #training
  batch_size: 16
  num_data_workers: 16
  n_epochs: 100
  lr: 0.001
  class_imbalance_weights: !!bool False
  patience: 10
  train_frac: 0.8

  #model
  model: 'base'
  n_input: 2 #stereo
  stride: 16
  n_channel: 32
  max_pool_every: 1

  #logging
  save_checkpoint: !!bool False
  wandb_api_key: f7892f37dd96b5f1da5c85a410300bb661f3c4de
  log_to_wandb: !!bool False


default_0.7: &DEFAULT_0.7
  <<: *DEFAULT
  metadata_labeled_path: labeled_0.7_wav_metadata.json
  num_files_per_raga_path: metadata_0.7.json
  labeled_data_dir: labeled_data_0.7

default_0.9: &DEFAULT_0.9
  <<: *DEFAULT
  metadata_labeled_path: labeled_0.9_wav_metadata.json
  num_files_per_raga_path: metadata_0.9.json
  labeled_data_dir: labeled_data_0.9
  train_frac: 0.85
  num_classes: 200


resnet: &RESNET
  <<: *DEFAULT
  model: 'resnet'
  n_blocks: 5 #for resnet
  n_channel: 128

resnet_0.7: &RESNET_0.7
  <<: *DEFAULT_0.7
  model: 'resnet'
  n_blocks: 10 #for resnet
  n_channel: 300
  num_classes: 150

resnet_0.9: &RESNET_0.9
  <<: *DEFAULT_0.9
  model: 'resnet'
  n_blocks: 10 #for resnet
  n_channel: 350
  max_pool_every: 1 #downsample every other res block


wav2vec_0.7: &WAV2VEC_0.7
  <<: *DEFAULT_0.7

  model: 'wav2vec'
  n_input: 1 #mono

  #transformer parameters (this config leads to around 29M params)
  extractor_mode: "layer_norm"
  extractor_conv_layer_config: None #hardcoded for now, fix this at some point
  extractor_conv_bias: !!bool True
  encoder_embed_dim: 512
  encoder_projection_dropout: 0
  encoder_pos_conv_kernel: 3
  encoder_pos_conv_groups: 32
  encoder_num_layers: 12
  encoder_num_heads: 16
  encoder_attention_dropout: 0
  encoder_ff_interm_features: 1024
  encoder_ff_interm_dropout: 0
  encoder_dropout: 0
  encoder_layer_norm_first: !!bool True
  encoder_layer_drop: 0


wav2vec_0.9: &WAV2VEC_0.9
  <<: *DEFAULT_0.9

  model: 'wav2vec'
  n_input: 1 #mono

  #transformer parameters (this config leads to around 29M params)
  extractor_mode: "layer_norm"
  extractor_conv_layer_config: None #hardcoded for now, fix this at some point
  extractor_conv_bias: !!bool True
  encoder_embed_dim: 512
  encoder_projection_dropout: 0
  encoder_pos_conv_kernel: 3
  encoder_pos_conv_groups: 32
  encoder_num_layers: 12
  encoder_num_heads: 16
  encoder_attention_dropout: 0
  encoder_ff_interm_features: 1024
  encoder_ff_interm_dropout: 0
  encoder_dropout: 0
  encoder_layer_norm_first: !!bool True
  encoder_layer_drop: 0
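Note: the `<<: *DEFAULT` entries are standard YAML merge keys: each named section inherits every key of the referenced anchor and overrides only the keys it restates. A minimal sketch of how `resnet_0.7` resolves through `YParams` (the expected values follow from the file above):

```python
# Sketch: resnet_0.7 = DEFAULT keys, overlaid by DEFAULT_0.7, overlaid
# by its own keys. YParams flattens the merge when it loads the section.
from utils.YParams import YParams

params = YParams('config.yaml', 'resnet_0.7')
print(params.model)             # 'resnet'            (set in resnet_0.7)
print(params.n_channel)         # 300                 (set in resnet_0.7)
print(params.clip_length)       # 30                  (inherited from default)
print(params.labeled_data_dir)  # 'labeled_data_0.7'  (from default_0.7)
```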
data/dataloader.py
ADDED
@@ -0,0 +1,143 @@
import logging
import glob
import torch
import random
import numpy as np
from torch.utils.data import Dataset
import torchaudio
import json
import os
from math import floor
import sys
import time
import copy
import math


np.random.seed(123)
random.seed(123)

#load used subset of metadata
def extract_data(params):
    with open(params.metadata_labeled_path) as f:
        metadata_labeled = json.load(f)

    if params.use_unlabeled_data:
        with open(params.metadata_unlabeled_path) as f:
            metadata_unlabeled = json.load(f)
        n_unlabeled_samples = len(metadata_unlabeled)

    #shuffle data and remove unused files
    random.shuffle(metadata_labeled)
    keep = math.ceil(params.use_frac*len(metadata_labeled))
    metadata_labeled = metadata_labeled[0:keep]

    #construct raga label lookup table
    raga2label = get_raga2label(params)

    #remove unused ragas from the metadata dictionary
    metadata_labeled_final = remove_unused_ragas(metadata_labeled, raga2label)

    return metadata_labeled_final, raga2label

def get_raga2label(params):
    with open(params.num_files_per_raga_path) as f:
        num_files_per_raga = json.load(f)
    raga2label = {}
    for i, raga in enumerate(num_files_per_raga.keys()):
        raga2label[raga] = i #assign every raga a unique number from 0 to num_classes-1
        if i == params.num_classes-1:
            break
    return raga2label

def remove_unused_ragas(metadata_labeled, raga2label):
    temp = copy.deepcopy(metadata_labeled)
    for i, entry in enumerate(metadata_labeled):
        raga = entry['filename'].split("/")[0]
        if raga not in raga2label.keys(): #this raga is not in the top num_classes ragas
            temp.remove(entry)

    return temp

class RagaDataset(Dataset):
    def __init__(self, params, metadata_labeled, raga2label):
        self.params = params
        self.metadata_labeled = metadata_labeled
        self.raga2label = raga2label
        self.n_labeled_samples = len(self.metadata_labeled)
        self.transform_dict = {}
        self.count = 0

        if params.local_rank == 0:
            print("Begin training using ", self.__len__(), " audio samples of ", self.params.clip_length, " seconds each.")
            print("Total number of ragas specified: ", self.params.num_classes)


    def construct_label(self, raga, label_smoothing=False):
        #construct one-hot encoding vector for raga
        raga_index = self.raga2label[raga]
        label = torch.zeros((self.params.num_classes,), dtype=torch.float32)
        label[raga_index] = 1
        return label

    def normalize(self, audio):
        return (audio - torch.mean(audio, dim=1, keepdim=True))/(torch.std(audio, dim=1, keepdim=True) + 1e-5)

    def pad_audio(self, audio):
        pad = (0, self.params.sample_rate*self.params.clip_length - audio.shape[1])
        return torch.nn.functional.pad(audio, pad=pad, value=0)

    def __len__(self):
        return len(self.metadata_labeled)


    def __getitem__(self, idx):
        #get metadata
        file_info = self.metadata_labeled[idx]

        #sample offset uniformly
        rng = max(0, file_info['duration'] - self.params.clip_length)
        if rng == 0:
            rng = file_info['duration']

        seconds_offset = np.random.randint(floor(rng))

        #open audio file
        audio_clip, sample_rate = torchaudio.load(filepath=os.path.join(self.params.labeled_data_dir, file_info['filename']),
                                                  frame_offset=seconds_offset * file_info['sample_rate'],
                                                  num_frames=self.params.clip_length*file_info['sample_rate'], normalize=True)

        #repeat mono channel to get stereo if necessary
        if audio_clip.shape[0] != 2:
            audio_clip = audio_clip.repeat(2, 1)

        #keep stereo
        #audio_clip = audio_clip.mean(dim=0, keepdim=True)
        #add transform to dictionary (one cached Resample per source sample rate)
        if sample_rate not in self.transform_dict.keys():
            self.transform_dict[sample_rate] = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.params.sample_rate)

        #load cached transform
        resample = self.transform_dict[sample_rate]
        audio_clip = resample(audio_clip)

        if self.params.normalize:
            audio_clip = self.normalize(audio_clip)

        if audio_clip.size()[1] < self.params.sample_rate*self.params.clip_length:
            #pad audio with zeros if it's not long enough
            audio_clip = self.pad_audio(audio_clip)

        raga = file_info['filename'].split("/")[0]

        #construct label
        label = self.construct_label(raga)

        assert not torch.any(torch.isnan(audio_clip))
        assert not torch.any(torch.isnan(label))
        assert audio_clip.shape[1] == self.params.sample_rate*self.params.clip_length
        return audio_clip, label
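Note: a minimal sketch of wiring `RagaDataset` into a PyTorch `DataLoader`, assuming the metadata files named in the config are present. The `local_rank` key is an assumption: `RagaDataset.__init__` reads `params.local_rank`, but the shipped configs do not define it. The `random_split` train/val split is illustrative, not necessarily what training used:

```python
# Sketch: build the dataset and a shuffled training loader.
from torch.utils.data import DataLoader, random_split
from data.dataloader import extract_data, RagaDataset
from utils.YParams import YParams

params = YParams('config.yaml', 'resnet_0.7')
params['local_rank'] = 0  # assumed; normally set by the (not included) training script

metadata, raga2label = extract_data(params)          # keeps only the top-150 ragas
dataset = RagaDataset(params, metadata, raga2label)

n_train = int(params.train_frac * len(dataset))
train_set, val_set = random_split(dataset, [n_train, len(dataset) - n_train])
train_loader = DataLoader(train_set, batch_size=params.batch_size,
                          shuffle=True, num_workers=params.num_data_workers)
```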
inference.py
ADDED
@@ -0,0 +1,117 @@
import logging
from utils import logging_utils
logging_utils.config_logger()
import torch
import random
import numpy as np
from data.dataloader import extract_data
import torchaudio
from models.RagaNet import BaseRagaClassifier, ResNetRagaClassifier, Wav2VecTransformer, count_parameters
from collections import OrderedDict

np.random.seed(123)
random.seed(123)


class Evaluator():

    def __init__(self, params):
        self.params = params
        self.device = self.params.device
        #get raga to label mapping
        _, self.raga2label = extract_data(self.params)
        self.raga_list = list(self.raga2label.keys())
        self.label_list = list(self.raga2label.values())

        #initialize model
        if params.model == 'base':
            self.model = BaseRagaClassifier(params).to(self.device)
        elif params.model == 'resnet':
            self.model = ResNetRagaClassifier(params).to(self.device)
        elif params.model == 'wav2vec':
            self.model = Wav2VecTransformer(params).to(self.device)
        else:
            logging.error("Model must be either 'base', 'resnet', or 'wav2vec'")

        #load best model
        logging.info("Loading checkpoint %s"%params.best_checkpoint_path)
        self.restore_checkpoint('ckpts/resnet_0.7/150classes_alldata_cliplength30/training_checkpoints/best_ckpt.tar') #params.best_checkpoint_path
        self.model.eval()


    def normalize(self, audio):
        return (audio - torch.mean(audio, dim=1, keepdim=True))/(torch.std(audio, dim=1, keepdim=True) + 1e-5)

    def pad_audio(self, audio):
        pad = (0, self.params.sample_rate*self.params.clip_length - audio.shape[1])
        return torch.nn.functional.pad(audio, pad=pad, value=0)

    def inference(self, k, audio):
        #open audio file
        sample_rate, audio_clip = audio

        #repeat mono channel to get stereo if necessary
        if len(audio_clip.shape) == 1:
            audio_clip = torch.tensor(audio_clip).unsqueeze(0).repeat(2, 1).to(torch.float32)
        else:
            audio_clip = torch.tensor(audio_clip).T.to(torch.float32)

        #resample audio clip
        resample = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.params.sample_rate)
        audio_clip = resample(audio_clip)

        #normalize the audio clip
        if self.params.normalize:
            audio_clip = self.normalize(audio_clip)

        #pad audio with zeros if it's not long enough
        if audio_clip.size()[1] < self.params.sample_rate*self.params.clip_length:
            audio_clip = self.pad_audio(audio_clip)

        assert not torch.any(torch.isnan(audio_clip))
        audio_clip = audio_clip.to(self.device)

        with torch.no_grad():
            length = audio_clip.shape[1]
            train_length = self.params.sample_rate*self.params.clip_length

            pred_probs = torch.zeros((self.params.num_classes,)).to(self.device)

            #loop over clip_length segments and perform inference
            num_clips = int(np.floor(length/train_length))
            for i in range(num_clips):

                clip = audio_clip[:, i*train_length:(i+1)*train_length].unsqueeze(0)

                #forward pass through the model; exponentiate the log-probs and average
                pred_distribution = self.model(clip).reshape(-1, self.params.num_classes)
                pred_probs += 1 / num_clips * (torch.exp(pred_distribution)/torch.exp(pred_distribution).sum(axis=1, keepdim=True))[0]


            pred_probs, labels = pred_probs.sort(descending=True)
            pred_probs_topk = pred_probs[:k]
            pred_ragas_topk = [self.raga_list[self.label_list.index(label)] for label in labels[:k]]
            d = dict(zip(pred_ragas_topk, pred_probs_topk))
            return {k: v.item() for k, v in d.items()}

    def restore_checkpoint(self, checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        try:
            self.model.load_state_dict(checkpoint['model_state'])
        except RuntimeError:
            #loading a DDP checkpoint into a non-DDP model
            new_state_dict = OrderedDict()
            for k, v in checkpoint['model_state'].items():
                name = k[7:] # remove `module.` prefix
                new_state_dict[name] = v
            # load params
            self.model.load_state_dict(new_state_dict)

        self.iters = checkpoint['iters']
        self.startEpoch = checkpoint['epoch']
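Note: `inference` scores a long recording by cutting it into `clip_length` segments and averaging the per-segment softmax distributions, i.e. p = (1/N) * sum_i softmax(model(clip_i)). A tiny numeric sketch of that averaging step:

```python
# Sketch: the model returns log-probabilities (log_softmax), so
# inference exponentiates, renormalizes, and averages over segments.
import torch

log_probs = torch.log(torch.tensor([[0.7, 0.2, 0.1],    # segment 1
                                    [0.5, 0.4, 0.1]]))  # segment 2
probs = torch.exp(log_probs) / torch.exp(log_probs).sum(axis=1, keepdim=True)
print(probs.mean(dim=0))  # tensor([0.6000, 0.3000, 0.1000])
```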
labeled_0.7_wav_metadata.json
ADDED
The diff for this file is too large to render. See raw diff.
metadata_0.7.json
ADDED
@@ -0,0 +1,530 @@
{
  "Mohanam": 1085, "Sankarabharanam": 1084, "Panthuvarali": 1036, "Kalyani": 1032, "Bhairavi": 1032, "Hamsadhwani": 1005,
  "Sindhubhairavi": 985, "Kambhoji": 958, "Thodi": 936, "Khamas": 926, "Poorvikalyani": 920, "Madhyamavathi": 914,
  "Kapi": 907, "Karaharapriya": 881, "Hindolam": 868, "Anandabhairavi": 847, "Saveri": 801, "Nata": 788,
  "Shanmukapriya": 779, "Reethigowla": 743, "Begada": 732, "Behag": 714, "Abheri": 692, "Arabhi": 691,
  "Bilahari": 651, "Atana": 648, "Kaanada": 613, "Sri": 601, "Varali": 577, "Sahana": 572,
  "Yamunakalyani": 560, "Suruti": 544, "Mayamalavagowla": 539, "Yadukulakambhoji": 536, "Mukhari": 493, "Sriranjani": 477,
  "Abhogi": 464, "Natakurinji": 463, "Harikambhoji": 458, "Vasantha": 447, "Dhanyasi": 442, "Neelambari": 434,
  "Keeravani": 421, "Kedaragowla": 413, "Chenchurutti": 392, "Sowrashtram": 387, "Hamsanandi": 378, "Shuddhadhanyasi": 366,
  "Sama": 351, "Brindavanasaranga": 340, "Shuddhasaveri": 331, "Hamirkalyani": 323, "Devagandhari": 322, "Gowla": 313,
  "Charukesi": 310, "Darbar": 306, "Revathi": 302, "Saranga": 299, "Simhendramadhyamam": 290, "Dwijavanthi": 283,
  "Kunthalavarali": 269, "Jonpuri": 269, "Bahudhari": 269, "Huseni": 267, "Ranjani": 263, "Maand": 259,
  "Nalinakanthi": 257, "Asaveri": 257, "Kedaram": 253, "Hamsanadam": 252, "Kurinji": 250, "Kadanakuthuhalam": 236,
  "Desh": 234, "Bowli": 232, "Amrithavarshini": 232, "Saramathi": 232, "Punnagavarali": 231, "Mohanakalyani": 231,
  "Subhapanthuvarali": 224, "Nadanamakriya": 223, "Lathangi": 222, "Gambheeranata": 219, "Darbarikaanada": 218, "Malayamarutham": 208,
  "Chakravakam": 204, "Bageshree": 201, "Ahiri": 200, "Saraswathi": 200, "Kannada": 197, "Manirangu": 196,
  "Kalyanavasantham": 195, "Jaganmohini": 192, "Andolika": 190, "Poornachandrika": 184, "Devamanohari": 183, "Lalitha": 179,
  "Gowrimanohari": 178, "Valaji": 171, "Vachaspathi": 171, "Dharmavathi": 168, "Paras": 161, "Varamu": 159,
  "Janaranjani": 159, "Hemavathi": 154, "Ravichandrika": 150, "Nayaki": 143, "Navarasakannada": 143, "Gowlipanthu": 126,
  "Madhuvanthi": 124, "Manji": 124, "Thilang": 124, "Kapinarayani": 122, "Kamalamanohari": 121, "Ahirbhairavi": 119,
  "Revagupthi": 118, "Jayanthashree": 117, "Chandrajyothi": 117, "Nagaswaravali": 116, "Kannadagowla": 116, "Bindumalini": 112,
  "Durga": 111, "Vagadheeswari": 110, "Karnaranjani": 106, "Nasikabhushani": 106, "Dhenuka": 103, "Malavi": 102,
  "Natabhairavi": 101, "Ganamoorthi": 99, "Vegavahini": 97, "Sarasangi": 94, "Saraswathimanohari": 93, "Ramapriya": 87,
  "Rudrapriya": 87, "Chittaranjani": 86, "Mandari": 86, "Malahari": 81, "Shuddhabangala": 80, "Sunadhavinodhini": 79,
  "Salagabhairavi": 75, "Chinthamani": 74, "Devamruthavarshini": 73, "Brindavani": 72, "Balahamsa": 70, "Mishrayaman": 70,
  "Vakulabharanam": 70, "Bhavapriya": 70, "Vijayashree": 69, "Bhoopalam": 69, "Dhanashree": 68, "Jayamanohari": 67,
  "Peeloo": 66, "Shivaranjani": 66, "Rasikapriya": 63, "Vasanthabhairavi": 63, "Rishabhapriya": 62, "Garudadhwani": 61,
  "Natakapriya": 61, "Chalanata": 60, "Ratipatipriya": 60, "Gowdamalhar": 58, "Navaroj": 58, "Karnatakakapi": 58,
  "Kumudakriya": 58, "Bangala": 54, "Chenchukambhoji": 54, "Shuddhasarang": 53, "Kalavathi": 53, "Bhairavam": 53,
  "Narayanagowla": 52, "Shuddhaseemanthini": 52, "Hamsavinodhini": 49, "Manoranjani": 49, "Chayatharangini": 49, "Kiranavali": 49,
  "Amrithavahini": 49, "Devagandharam": 48, "Jayanthasena": 47, "Ramamanohari": 47, "Mishramaand": 47, "Nagagandhari": 47,
  "Kanthamani": 46, "Neethimathi": 46, "Lalithapanchamam": 45, "Vijayanagari": 45, "Devakriya": 45, "Rageshree": 45,
  "Suryakantam": 45, "Chandrakowns": 45, "Vasanthi": 45, "Pushpalatika": 44, "Umabharanam": 44, "Kanakangi": 43,
  "Sindhuramakriya": 42, "Gangeyabhushani": 41, "Poornashadjam": 41, "Tharangini": 39, "Lavangi": 39, "Karnatakabehag": 38,
  "Sucharithra": 38, "Neelamani": 38, "Hindolavasantham": 37, "Margahindolam": 37, "Mishrashivaranjani": 36, "Patdeep": 36,
  "Jyothiswaroopini": 35, "Phalamanjari": 35, "Urmika": 35, "Yaman": 35, "Vanaspathi": 35, "Veeravasantham": 35,
  "Gamanashrama": 34, "Kokilavarali": 33, "Pahadi": 33, "Kalanidhi": 33, "Sindhukannada": 32, "Narayani": 32,
  "Manjari": 31, "Hindusthanigandhari": 29, "Bhavani": 29, "Purvi": 29, "Prathapavarali": 29, "Vallabhi": 28,
  "Karnatakashuddhasaveri": 27, "Megharanjani": 27, "Isamanohari": 27, "Soorya": 27, "Pasupathipriya": 26, "Gambheeravani": 26,
  "Simhavahini": 26, "Madhavamanohari": 26, "Niroshta": 26, "Varunapriya": 26, "Manavathi": 26, "Naganandhini": 25,
  "Ghanta": 25, "Shivashakti": 25, "Kaikavasi": 24, "Kokilapriya": 24, "Mishrapahadi": 24, "Saindhavi": 24,
  "Jingala": 24, "Kosalam": 24, "Kalgada": 23, "Sumanesaranjani": 23, "Bhushavali": 23, "Narireethigowla": 22,
  "Rasali": 22, "Vivardhini": 22, "Jankaradhvani": 22, "Sarangatharangini": 21, "Paadi": 21, "Gowri": 21,
  "Bhujangini": 21, "Gurjari": 21, "Suposhini": 21, "Udhayaravichandrika": 20, "Rathnangi": 20, "Deepali": 20,
  "Ragavardhani": 20, "Supradeepam": 19, "Mangalakaishiki": 19, "Kannadabangala": 19, "Shadhvidhamargini": 19, "Kokiladhwani": 18,
  "Mahathi": 18, "Janasammodhini": 17, "Takka": 17, "Vasanthavarali": 17, "Thanarupi": 17, "Amrithabehag": 17,
  "Rupavathi": 16, "Ganavaridhi": 16, "Natanarayani": 15, "Jog": 15, "Manorama": 15, "Swarabhushani": 15,
  "Gopikavasantham": 15, "Shuddhadesi": 14, "Mararanjani": 14, "Deshakshi": 14, "Gundakriya": 14, "Gayakapriya": 14,
  "Chithrambari": 14, "Maruvabehag": 14, "Ramkali": 13, "Shulini": 13, "Desiyathodi": 13, "Vandanadharini": 13,
  "Malavashree": 13, "Deepakam": 12, "Kokilaravam": 12, "Vijayasaraswathi": 12, "Mishraharikambhoji": 12, "Maruvadhanyasi": 12,
  "Mohanangi": 12, "Hatakambari": 12, "Yagapriya": 12, "Suvarnangi": 12, "Samanta": 12, "Basantbahar": 12,
  "Mishrapeeloo": 11, "Raghupriya": 11, "Pavani": 11, "Navanitham": 11, "Sindhumandari": 11, "Buddhamanohari": 11,
  "Shuddhavasantha": 10, "Bhanumathi": 10, "Andhali": 10, "Jalarnavam": 10, "Puriyadhanashree": 10, "Gavambodhi": 10,
  "Vamsavathi": 10, "Sarangamalhar": 10, "Dhavalambari": 10, "Gowrivelaavali": 9, "Namanarayani": 9, "Shyamakalyani": 9,
  "Bhoopali": 9, "Vishnupriya": 8, "Jyothi": 8, "Gopikathilakam": 8, "Chayanata": 8, "Shruthiranjani": 8,
  "Santhanamanjari": 8, "Ardhradesi": 8, "Sumukham": 8, "Madhukowns": 8, "Dhivyamani": 8, "Shivakambhoji": 7,
  "Devaranji": 7, "Chayagowla": 7, "Narthaki": 7, "Rasamanjari": 7, "Vardhani": 7, "Dhurvanki": 7,
  "Phenadyuthi": 7, "Gavathi": 7, "Poornalalitha": 7, "Kunthalam": 7, "Poorvakamodari": 7, "Mahuri": 7,
  "Chandrika": 7, "Rohini": 7, "Mishrakhamaj": 7, "Salagam": 7, "Sarvashree": 7, "Bhooshavathi": 7,
  "Shreemani": 6, "Sumanapriya": 6, "Sushama": 6, "Latantapriya": 6, "Gurjarithodi": 6, "Guharanjani": 6,
  "Namadesi": 6, "Kunjari": 6, "Maargadesi": 6, "Ragapanjaram": 6, "Harinarayani": 6, "Ravikriya": 6,
  "Gopriya": 5, "Chayaranjani": 5, "Hamsadeepika": 5, "Nadavarangini": 5, "Hamsalatha": 5, "Shuddhavalaji": 5,
  "Bhogachayanata": 5, "Phalaranjani": 5, "Bhogavasantha": 5, "Senavathi": 5, "Thandavam": 5, "Swaravali": 5,
  "Geethapriya": 5, "Jayanarayani": 5, "Vishwambhari": 5, "Nagavarali": 5, "Karthyayani": 5, "Sharavathi": 5,
  "Ganasamavarali": 5, "Jaganmohinam": 5, "Shyamalangi": 5, "Nagavalli": 5, "Bhinnashadjam": 5, "Suraranjani": 5,
  "Nabhomani": 4, "Maruva": 4, "Komalangi": 4, "Dhavalangam": 4, "Shivapriya": 4, "Bhavabharanam": 4,
  "Nagabhushani": 4, "Swarasammodhini": 4, "Gomedhikapriya": 4, "Dhatuvardhani": 4, "Hamsakalyani": 4, "Samakadambari": 4,
  "Sauvira": 4, "Thanukeerthi": 4, "Ragachoodaamani": 4, "Rasikaranjani": 3, "Moharanjani": 3, "Hamsabhramari": 3,
  "Dundubi": 3, "Miyanmalhar": 3, "Murali": 3, "Shekharachandrikaa": 3, "Madhulika": 3, "Krisnaveni": 3,
  "Nagadhwani": 3, "Dakshayani": 3, "Kalasaveri": 3, "Sthavarajam": 3, "Mishrajog": 3, "Srothaswani": 3,
  "Sowrasenaa": 3, "Samudrapriya": 3, "Bhuvanagaandhaari": 3, "Hamsanantini": 3, "Bhanuchandrika": 3, "Balachandrika": 3,
  "Shuddhamukhari": 3, "Viswapriya": 3, "Gangatharangini": 3, "Sharadapriya": 3, "Chathurangini": 3, "Venkatadri": 3,
  "Narayanadri": 3, "Puriyakalyan": 3, "Hindusthanitodi": 3, "Dayavathi": 3, "Kokila": 3, "Madhyamaravali": 3,
  "Vivahapriya": 3, "Vijayavasantha": 3, "Kuvalayabharanam": 3, "Swararanjani": 3, "Mishrabilahari": 3, "Salanganata": 3,
  "Malavapanchamam": 3, "Siddhasena": 3, "Jalavarali": 3, "Haricharan": 3, "Karnatakahindolam": 3, "Senagrani": 3,
  "Nadhabrahma": 3, "Poorvagowla": 3, "Veenadhaari": 3, "Geyahejjajji": 3, "Tarani": 3, "Navarathnavilaasam": 3,
  "Bhanudhanyasi": 3, "Kolahalam": 3, "Panchamam": 2, "Bhoopalapanchamam": 2, "Rasavinodhini": 2, "Seshadri": 2,
  "Tarakagowla": 2, "Poornapanchamam": 2, "Kesari": 2, "Shailadeshakshi": 2, "Jujahuli": 2, "Churnikavinodhini": 2,
  "Rukmambari": 2, "Kalakanti": 2, "Omkaari": 2, "Hindoladarbar": 2, "Mukthidayini": 2, "Hradini": 2,
  "Velaavali": 2, "Kowshikadhwani": 2, "Hamsavahini": 2, "Srikara": 2, "Mishramanolayam": 2, "Bhanupriya": 2,
  "Manolayam": 2, "Mukundamalini": 2, "Siddhi": 2, "Ramakriya": 2, "Suranandini": 2, "Chandrapriya": 2,
  "Shuddhasaranga": 2, "Bhatiyaar": 2, "Malava": 2, "Kusumakaram": 2, "Kannadamaruva": 2, "Sowgandhini": 2,
  "Shuddhathodi": 2, "Thivravahini": 1, "Savithri": 1, "Shreekaanti": 1, "Kumarapriya": 1, "Sugunabhooshani": 1,
  "Kalyanakesari": 1, "Nagabharanam": 1, "Jadathari": 1, "Guhamanohari": 1, "Agnikopa": 1, "Pranavapriya": 1,
  "Karpoorabharani": 1, "Rojakadambari": 1, "Shuddhakambhoji": 1, "Hindhusthanibehag": 1, "Jnanachinthamani": 1, "Chandrahasitham": 1,
  "Triveni": 1, "Dravidakalavati": 1, "Kumbhini": 1, "Vasukari": 1, "Vitapi": 1, "Saranganata": 1,
  "Visharada": 1, "Vinodhini": 1, "Sarasaanana": 1, "Shuddhalalitha": 1, "Rishipriya": 1, "Dhipaka": 1,
  "Shuddhasalavi": 1, "Sutradhari": 1, "Natabharanam": 1, "Mishrabhairavi": 1, "Priyadarshani": 1, "Hamsagamini": 1,
  "Shyamalam": 1, "Jeevanthika": 1, "Alankari": 1, "Jayashuddhamaalavi": 1, "Mechabowli": 1, "Garudapriya": 1
}
models/RagaNet.py
ADDED
@@ -0,0 +1,164 @@
import torch.nn as nn
import torch.nn.functional as F
from torchaudio.models import wav2vec2_model

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

#basic conv block
def conv_block(n_input, n_output, stride=1, kernel_size=80):
    layers = []
    if stride == 1:
        layers.append(nn.Conv1d(n_input, n_output, kernel_size=kernel_size, stride=stride, padding='same')) #Conv
    else:
        layers.append(nn.Conv1d(n_input, n_output, kernel_size=kernel_size, stride=stride)) #Conv
    layers.append(nn.BatchNorm1d(n_output))
    layers.append(nn.ReLU())
    return nn.Sequential(*layers)

#basic 2-conv residual block
class ResidualBlock(nn.Module):
    def __init__(self, n_channels, kernel_size):
        super().__init__()

        self.conv_block1 = conv_block(n_channels, n_channels, stride=1, kernel_size=kernel_size)
        self.conv_block2 = conv_block(n_channels, n_channels, stride=1, kernel_size=3)

    def forward(self, x):
        identity = x
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = x + identity
        return x


class ResNetRagaClassifier(nn.Module):
    def __init__(self, params):
        super().__init__()
        n_input = params.n_input
        n_channel = params.n_channel
        stride = params.stride
        self.n_blocks = params.n_blocks
        self.conv_first = conv_block(n_input, n_channel, stride=stride, kernel_size=80)
        self.max_pool_every = params.max_pool_every

        self.res_blocks = nn.ModuleList() #Residual Blocks
        for i in range(self.n_blocks):
            self.res_blocks.append(ResidualBlock(n_channel, kernel_size=3))

        #linear classification head
        self.fc1 = nn.Linear(n_channel, params.num_classes)


    def forward(self, x):
        #initial conv
        x = self.conv_first(x)

        #residual blocks
        for i, block in enumerate(self.res_blocks):
            x = block(x)
            if i % self.max_pool_every == 0:
                x = F.max_pool1d(x, 2)

        #classification head
        x = F.avg_pool1d(x, x.shape[-1])
        x = x.permute(0, 2, 1)
        x = self.fc1(x)
        x = F.log_softmax(x, dim=-1)

        return x


class BaseRagaClassifier(nn.Module):
    def __init__(self, params):
        super().__init__()
        n_input = params.n_input
        n_channel = params.n_channel
        stride = params.stride
        self.conv_blocks = []

        self.conv_block1 = conv_block(n_input, n_channel, stride=stride, kernel_size=80)
        self.conv_block2 = conv_block(n_channel, n_channel, stride=1, kernel_size=3)
        self.conv_block3 = conv_block(n_channel, 2*n_channel, stride=1, kernel_size=3)
        self.conv_block4 = conv_block(2*n_channel, 2*n_channel, stride=1, kernel_size=3)
        self.fc1 = nn.Linear(2 * n_channel, params.num_classes)

    def forward(self, x):
        x = self.conv_block1(x)
        x = F.max_pool1d(x, 4)
        x = self.conv_block2(x)
        x = F.max_pool1d(x, 4)
        x = self.conv_block3(x)
        x = F.max_pool1d(x, 4)
        x = self.conv_block4(x)
        x = F.avg_pool1d(x, x.shape[-1])
        x = x.permute(0, 2, 1)
        x = self.fc1(x)
        x = F.log_softmax(x, dim=-1)
        return x


class Wav2VecTransformer(nn.Module):
    def __init__(self, params):
        super().__init__()
        self.params = params
        self.extractor_mode = params.extractor_mode
        self.extractor_conv_layer_config = params.extractor_conv_layer_config
        self.extractor_conv_bias = params.extractor_conv_bias
        self.encoder_embed_dim = params.encoder_embed_dim
        self.encoder_projection_dropout = params.encoder_projection_dropout
        self.encoder_pos_conv_kernel = params.encoder_pos_conv_kernel
        self.encoder_pos_conv_groups = params.encoder_pos_conv_groups
        self.encoder_num_layers = params.encoder_num_layers
        self.encoder_num_heads = params.encoder_num_heads
        self.encoder_attention_dropout = params.encoder_attention_dropout
        self.encoder_ff_interm_features = params.encoder_ff_interm_features
        self.encoder_ff_interm_dropout = params.encoder_ff_interm_dropout
        self.encoder_dropout = params.encoder_dropout
        self.encoder_layer_norm_first = params.encoder_layer_norm_first
        self.encoder_layer_drop = params.encoder_layer_drop
        self.aux_num_out = params.num_classes

        #(out_channels, kernel_size, stride) per extractor conv layer
        self.extractor_conv_layer_config = [
            (32, 80, 16),
            (64, 5, 4),
            (128, 5, 4),
            (256, 5, 4),
            (512, 3, 2),
            (512, 2, 2),
            (512, 2, 2),
        ]
        self.encoder = wav2vec2_model(self.extractor_mode,
                                      self.extractor_conv_layer_config,
                                      self.extractor_conv_bias,
                                      self.encoder_embed_dim,
                                      self.encoder_projection_dropout,
                                      self.encoder_pos_conv_kernel,
                                      self.encoder_pos_conv_groups,
                                      self.encoder_num_layers,
                                      self.encoder_num_heads,
                                      self.encoder_attention_dropout,
                                      self.encoder_ff_interm_features,
                                      self.encoder_ff_interm_dropout,
                                      self.encoder_dropout,
                                      self.encoder_layer_norm_first,
                                      self.encoder_layer_drop,
                                      aux_num_out=None)

        self.audio_length = params.sample_rate*params.clip_length
        #the extractor strides (16*4*4*4*2*2*2) determine the encoder's output length
        self.classification_head = nn.Linear(int(self.audio_length/(16*4*4*4*2*2*2))*params.encoder_embed_dim, params.num_classes)

    def forward(self, x):
        x = self.encoder(x)[0]
        x = x.reshape(x.shape[0], -1) # flatten
        x = self.classification_head(x)
        x = F.log_softmax(x, dim=-1)
        return x
requirements.txt
ADDED
@@ -0,0 +1,3 @@
torch
torchaudio
numpy
site/tsne.jpeg
ADDED
utils/YParams.py
ADDED
@@ -0,0 +1,47 @@
from ruamel.yaml import YAML
import logging

class YParams():
    """ Yaml file parser """
    def __init__(self, yaml_filename, config_name, print_params=False):
        self._yaml_filename = yaml_filename
        self._config_name = config_name
        self.params = {}

        if print_params:
            print("------------------ Configuration ------------------")

        with open(yaml_filename) as _file:

            for key, val in YAML().load(_file)[config_name].items():
                if print_params: print(key, val)
                if val == 'None': val = None

                self.params[key] = val
                self.__setattr__(key, val)

        if print_params:
            print("---------------------------------------------------")

    def __getitem__(self, key):
        return self.params[key]

    def __setitem__(self, key, val):
        self.params[key] = val
        self.__setattr__(key, val)

    def __contains__(self, key):
        return (key in self.params)

    def update_params(self, config):
        for key, val in config.items():
            self.params[key] = val
            self.__setattr__(key, val)

    def log(self):
        logging.info("------------------ Configuration ------------------")
        logging.info("Configuration file: "+str(self._yaml_filename))
        logging.info("Configuration name: "+str(self._config_name))
        for key, val in self.params.items():
            logging.info(str(key) + ' ' + str(val))
        logging.info("---------------------------------------------------")
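Note: `YParams` exposes every config key both as an attribute and via dict-style indexing, and `__setitem__` keeps the two views in sync, which is how app.py injects `checkpoint_path` at runtime. A small sketch:

```python
# Sketch: attribute access and item access see the same values.
from utils.YParams import YParams

p = YParams('config.yaml', 'resnet_0.7')
assert p.batch_size == p['batch_size']
p['batch_size'] = 32      # __setitem__ also updates the attribute
assert p.batch_size == 32
```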
utils/logging_utils.py
ADDED
@@ -0,0 +1,32 @@
import os
import logging

_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'

def config_logger(log_level=logging.INFO):
    logging.basicConfig(format=_format, level=log_level)

def log_to_file(logger_name=None, log_level=logging.INFO, log_filename='tensorflow.log'):

    #only create a directory when the filename actually contains one
    log_dir = os.path.dirname(log_filename)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)

    if logger_name is not None:
        log = logging.getLogger(logger_name)
    else:
        log = logging.getLogger()

    fh = logging.FileHandler(log_filename)
    fh.setLevel(log_level)
    fh.setFormatter(logging.Formatter(_format))
    log.addHandler(fh)

def log_versions():
    import torch
    import subprocess

    logging.info('--------------- Versions ---------------')
    logging.info('git branch: ' + str(subprocess.check_output(['git', 'branch']).strip()))
    logging.info('git hash: ' + str(subprocess.check_output(['git', 'rev-parse', 'HEAD']).strip()))
    logging.info('Torch: ' + str(torch.__version__))
    logging.info('----------------------------------------')