jeevster committed on
Commit
64094d4
·
1 Parent(s): 08771d5

huggingface space main commit

.gitignore ADDED
@@ -0,0 +1 @@
+ __pycache__/
about.md ADDED
@@ -0,0 +1,12 @@
+ ### About the Classifier
+ The classifier is a [convolutional neural network](https://en.wikipedia.org/wiki/Convolutional_neural_network) trained on over 10,000 hours of Carnatic audio sourced from this incredible [YouTube collection](https://ramanarunachalam.github.io/Music/Carnatic/carnatic.html).
+ ### Key Features:
+ - Can identify **150 ragas**
+ - Does not require any information about the **shruthi (tonic pitch)** of the recording.
+ - **Compatible** with both male/female vocal and instrumental recordings.
+
+ ### Interpreting the Classifier:
+ We can gain an intuitive sense of what the classifier has learned. Here is a [t-SNE](https://en.wikipedia.org/wiki/T-distributed_stochastic_neighbor_embedding) projection of the hidden activations averaged per ragam. Each point is a ragam, and the relative distances between points indicate the degree to which the classifier considers the ragas similar. Each ragam is color-coded by the [melakartha chakra](https://en.wikipedia.org/wiki/Melakarta#Chakras) it belongs to. We observe that the classifier has learned a representation that roughly corresponds to these chakras!
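
A projection like the one described above can be produced with a short script. The following is a minimal sketch, not part of this commit: `features_per_raga` is a hypothetical dict of per-clip hidden activations (e.g. the pooled features just before the classification head of the model below), and scikit-learn's `TSNE` does the projection.

```python
# Sketch only: t-SNE of per-ragam averaged hidden activations.
# `features_per_raga` is a hypothetical dict: raga name -> (n_clips, hidden_dim) array.
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

def plot_raga_tsne(features_per_raga):
    names = list(features_per_raga.keys())
    # Average the hidden activations over all clips of each ragam.
    means = np.stack([features_per_raga[n].mean(axis=0) for n in names])
    # Project the per-ragam means down to 2D.
    coords = TSNE(n_components=2, perplexity=10, random_state=0).fit_transform(means)
    plt.scatter(coords[:, 0], coords[:, 1])
    for (x, y), name in zip(coords, names):
        plt.annotate(name, (x, y), fontsize=6)
    plt.savefig('site/tsne.jpeg')
```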
app.py ADDED
@@ -0,0 +1,57 @@
+ import os
+ import argparse
+ import torch
+ import gradio as gr
+ from inference import Evaluator
+ from utils.YParams import YParams
+
+ def read_markdown_file(path):
+     with open(path, 'r', encoding='utf-8') as file:
+         return file.read()
+
+
+ if __name__ == '__main__':
+
+     # parse args
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--yaml_config", default='config.yaml', type=str)
+     parser.add_argument("--config", default='resnet_0.7', type=str)
+
+     args = parser.parse_args()
+     params = YParams(os.path.abspath(args.yaml_config), args.config)
+
+     # run on GPU when available, otherwise fall back to CPU
+     if torch.cuda.is_available():
+         params.device = torch.device(torch.cuda.current_device())
+     else:
+         params.device = "cpu"
+
+     # checkpoint paths
+     expDir = "ckpts/resnet_0.7/150classes_alldata_cliplength30"
+     params['checkpoint_path'] = os.path.join(expDir, 'training_checkpoints/ckpt.tar')
+     params['best_checkpoint_path'] = os.path.join(expDir, 'training_checkpoints/best_ckpt.tar')
+
+     evaluator = Evaluator(params)
+
+     with gr.Blocks() as demo:
+         with gr.Tab("Classifier"):
+             gr.Interface(
+                 title="Carnatic Raga Classifier",
+                 description="**Welcome!** This is a deep-learning based raga classifier. Simply upload or record an audio clip to test it out. \n",
+                 article="**Get in Touch:** Feel free to reach out to [me](https://sanjeevraja.com/) via email (sanjeevr AT berkeley DOT edu) with any questions or feedback!",
+                 fn=evaluator.inference,
+                 inputs=[
+                     gr.Slider(minimum=1, maximum=150, value=5, label="Number of displayed ragas", info="Choose number of top predictions to display"),
+                     gr.Audio()
+                 ],
+                 outputs="label",
+                 allow_flagging=False
+             )
+
+         with gr.Tab("About"):
+             gr.Markdown(read_markdown_file('about.md'))
+             gr.Image('site/tsne.jpeg', height=800, width=800)
+
+     demo.launch()
ckpts/resnet_0.7/150classes_alldata_cliplength30/training_checkpoints/best_ckpt.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:09844cbcfc6c98af632671ca14dd4878fe6656811cdb77097a847c8b324362ca
+ size 66309523
config.yaml ADDED
@@ -0,0 +1,130 @@
+ default: &DEFAULT
+
+   #data
+   exp_dir: carnatic/ckpts
+   metadata_labeled_path: labeled_small_wav_metadata.json
+   metadata_unlabeled_path: unlabeled_data/unlabeled_mp3_metadata.json
+   num_files_per_raga_path: metadata_small.json
+   num_classes: 150
+   use_frac: 1
+   use_unlabeled_data: !!bool False
+   labeled_data_dir: labeled_data_small
+   clip_length: 30
+   sample_rate: 8000
+   normalize: !!bool True
+
+   #training
+   batch_size: 16
+   num_data_workers: 16
+   n_epochs: 100
+   lr: 0.001
+   class_imbalance_weights: !!bool False
+   patience: 10
+   train_frac: 0.8
+
+   #model
+   model: 'base'
+   n_input: 2 #stereo
+   stride: 16
+   n_channel: 32
+   max_pool_every: 1
+
+   #logging
+   save_checkpoint: !!bool False
+   wandb_api_key: f7892f37dd96b5f1da5c85a410300bb661f3c4de
+   log_to_wandb: !!bool False
+
+
+ default_0.7: &DEFAULT_0.7
+
+   <<: *DEFAULT
+   metadata_labeled_path: labeled_0.7_wav_metadata.json
+   num_files_per_raga_path: metadata_0.7.json
+   labeled_data_dir: labeled_data_0.7
+
+ default_0.9: &DEFAULT_0.9
+
+   <<: *DEFAULT
+   metadata_labeled_path: labeled_0.9_wav_metadata.json
+   num_files_per_raga_path: metadata_0.9.json
+   labeled_data_dir: labeled_data_0.9
+   train_frac: 0.85
+   num_classes: 200
+
+
+ resnet: &RESNET
+
+   <<: *DEFAULT
+   model: 'resnet'
+   n_blocks: 5 #for resnet
+   n_channel: 128
+
+ resnet_0.7: &RESNET_0.7
+
+   <<: *DEFAULT_0.7
+   model: 'resnet'
+   n_blocks: 10 #for resnet
+   n_channel: 300
+   num_classes: 150
+
+ resnet_0.9: &RESNET_0.9
+
+   <<: *DEFAULT_0.9
+   model: 'resnet'
+   n_blocks: 10 #for resnet
+   n_channel: 350
+   max_pool_every: 1 #downsample after every res block
+
+
+ wav2vec_0.7: &WAV2VEC_0.7
+   <<: *DEFAULT_0.7
+
+   model: 'wav2vec'
+   n_input: 1 #mono
+
+   #transformer parameters (this config leads to around 29M params)
+   extractor_mode: "layer_norm"
+   extractor_conv_layer_config: None #hardcoded for now, fix this at some point
+   extractor_conv_bias: !!bool True
+   encoder_embed_dim: 512
+   encoder_projection_dropout: 0
+   encoder_pos_conv_kernel: 3
+   encoder_pos_conv_groups: 32
+   encoder_num_layers: 12
+   encoder_num_heads: 16
+   encoder_attention_dropout: 0
+   encoder_ff_interm_features: 1024
+   encoder_ff_interm_dropout: 0
+   encoder_dropout: 0
+   encoder_layer_norm_first: !!bool True
+   encoder_layer_drop: 0
+
+
+ wav2vec_0.9: &WAV2VEC_0.9
+   <<: *DEFAULT_0.9
+
+   model: 'wav2vec'
+   n_input: 1 #mono
+
+   #transformer parameters (this config leads to around 29M params)
+   extractor_mode: "layer_norm"
+   extractor_conv_layer_config: None #hardcoded for now, fix this at some point
+   extractor_conv_bias: !!bool True
+   encoder_embed_dim: 512
+   encoder_projection_dropout: 0
+   encoder_pos_conv_kernel: 3
+   encoder_pos_conv_groups: 32
+   encoder_num_layers: 12
+   encoder_num_heads: 16
+   encoder_attention_dropout: 0
+   encoder_ff_interm_features: 1024
+   encoder_ff_interm_dropout: 0
+   encoder_dropout: 0
+   encoder_layer_norm_first: !!bool True
+   encoder_layer_drop: 0
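
A note on the structure above: the named configs rely on YAML anchors and merge keys (`<<: *DEFAULT`), so e.g. `resnet_0.7` inherits every key from `default` via `default_0.7` and overrides only what it lists. A quick way to sanity-check the resolved values is a sketch like the following, using the `YParams` loader added later in this commit:

```python
# Sketch: inspect the fully merged 'resnet_0.7' config.
from utils.YParams import YParams

params = YParams('config.yaml', 'resnet_0.7')
print(params.model)                # 'resnet'           (set in resnet_0.7)
print(params.n_blocks)             # 10                 (set in resnet_0.7)
print(params['labeled_data_dir'])  # 'labeled_data_0.7' (from default_0.7)
print(params.clip_length)          # 30                 (inherited from default)
```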
data/dataloader.py ADDED
@@ -0,0 +1,143 @@
+ import glob
+ import torch
+ import random
+ import numpy as np
+ from torch.utils.data import Dataset
+ import torchaudio
+ import json
+ import os
+ from math import floor
+ import sys
+ import time
+ import copy
+ import math
+ import logging
+
+
+ np.random.seed(123)
+ random.seed(123)
+
+ #load used subset of metadata
+ def extract_data(params):
+     with open(params.metadata_labeled_path) as f:
+         metadata_labeled = json.load(f)
+
+     if params.use_unlabeled_data:
+         with open(params.metadata_unlabeled_path) as f:
+             metadata_unlabeled = json.load(f)
+         n_unlabeled_samples = len(metadata_unlabeled)
+
+     #shuffle data and remove unused files
+     random.shuffle(metadata_labeled)
+     keep = math.ceil(params.use_frac*len(metadata_labeled))
+     metadata_labeled = metadata_labeled[0:keep]
+
+     #construct raga label lookup table
+     raga2label = get_raga2label(params)
+
+     #remove unused ragas from the metadata dictionary
+     metadata_labeled_final = remove_unused_ragas(metadata_labeled, raga2label)
+
+     return metadata_labeled_final, raga2label
+
+ def get_raga2label(params):
+     with open(params.num_files_per_raga_path) as f:
+         num_files_per_raga = json.load(f)
+     raga2label = {}
+     for i, raga in enumerate(num_files_per_raga.keys()):
+         raga2label[raga] = i #assign every raga a unique number from 0 to num_classes-1
+         if i == params.num_classes-1:
+             break
+     return raga2label
+
+ def remove_unused_ragas(metadata_labeled, raga2label):
+     temp = copy.deepcopy(metadata_labeled)
+     for entry in metadata_labeled:
+         raga = entry['filename'].split("/")[0]
+         if raga not in raga2label.keys(): #this raga is not in the top params.num_classes ragas
+             temp.remove(entry)
+
+     return temp
+
+ class RagaDataset(Dataset):
+     def __init__(self, params, metadata_labeled, raga2label):
+         self.params = params
+         self.metadata_labeled = metadata_labeled
+         self.raga2label = raga2label
+         self.n_labeled_samples = len(self.metadata_labeled)
+         self.transform_dict = {} #cache of resampling transforms, keyed by source sample rate
+         self.count = 0
+
+         if params.local_rank == 0:
+             print("Begin training using ", self.__len__(), " audio samples of ", self.params.clip_length, " seconds each.")
+             print("Total number of ragas specified: ", self.params.num_classes)
+
+
+     def construct_label(self, raga, label_smoothing=False):
+         #construct one-hot encoding vector for raga
+         raga_index = self.raga2label[raga]
+         label = torch.zeros((self.params.num_classes,), dtype=torch.float32)
+         label[raga_index] = 1
+         return label
+
+     def normalize(self, audio):
+         return (audio - torch.mean(audio, dim=1, keepdim=True))/(torch.std(audio, dim=1, keepdim=True) + 1e-5)
+
+     def pad_audio(self, audio):
+         pad = (0, self.params.sample_rate*self.params.clip_length - audio.shape[1])
+         return torch.nn.functional.pad(audio, pad=pad, value=0)
+
+     def __len__(self):
+         return len(self.metadata_labeled)
+
+
+     def __getitem__(self, idx):
+         #get metadata
+         file_info = self.metadata_labeled[idx]
+
+         #sample offset uniformly
+         rng = max(0, file_info['duration'] - self.params.clip_length)
+         if rng == 0:
+             rng = file_info['duration']
+
+         seconds_offset = np.random.randint(floor(rng))
+
+         #open audio file
+         audio_clip, sample_rate = torchaudio.load(filepath=os.path.join(self.params.labeled_data_dir, file_info['filename']),
+                                                   frame_offset=seconds_offset * file_info['sample_rate'],
+                                                   num_frames=self.params.clip_length*file_info['sample_rate'], normalize=True)
+
+         #repeat the mono channel to get stereo if necessary
+         if audio_clip.shape[0] != 2:
+             audio_clip = audio_clip.repeat(2, 1)
+
+         #keep stereo (downmix to mono disabled)
+         #audio_clip = audio_clip.mean(dim=0, keepdim=True)
+
+         #add a resampling transform to the cache if this sample rate is new
+         if sample_rate not in self.transform_dict.keys():
+             self.transform_dict[sample_rate] = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.params.sample_rate)
+
+         #load cached transform and resample
+         resample = self.transform_dict[sample_rate]
+         audio_clip = resample(audio_clip)
+
+         if self.params.normalize:
+             audio_clip = self.normalize(audio_clip)
+
+         if audio_clip.size()[1] < self.params.sample_rate*self.params.clip_length:
+             #pad audio with zeros if it's not long enough
+             audio_clip = self.pad_audio(audio_clip)
+
+         raga = file_info['filename'].split("/")[0]
+
+         #construct label
+         label = self.construct_label(raga)
+
+         assert not torch.any(torch.isnan(audio_clip))
+         assert not torch.any(torch.isnan(label))
+         assert audio_clip.shape[1] == self.params.sample_rate*self.params.clip_length
+         return audio_clip, label
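
Since the training script itself is not part of this commit, here is a hedged sketch of how `RagaDataset` would plug into a standard PyTorch `DataLoader`; `local_rank` is assumed to be set elsewhere during training.

```python
# Sketch: wire RagaDataset into a DataLoader and inspect one batch.
from torch.utils.data import DataLoader
from data.dataloader import RagaDataset, extract_data
from utils.YParams import YParams

params = YParams('config.yaml', 'resnet_0.7')
params['local_rank'] = 0  # assumed; the dataset prints stats when local_rank == 0

metadata, raga2label = extract_data(params)
dataset = RagaDataset(params, metadata, raga2label)
loader = DataLoader(dataset, batch_size=params.batch_size,
                    shuffle=True, num_workers=params.num_data_workers)

audio, labels = next(iter(loader))
print(audio.shape)   # (16, 2, 240000): batch x stereo x (8000 Hz * 30 s)
print(labels.shape)  # (16, 150): one-hot raga labels
```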
inference.py ADDED
@@ -0,0 +1,117 @@
+ import logging
+ from utils import logging_utils
+ logging_utils.config_logger()
+ import torch
+ import random
+ import numpy as np
+ from data.dataloader import extract_data
+ import torchaudio
+ from models.RagaNet import BaseRagaClassifier, ResNetRagaClassifier, Wav2VecTransformer, count_parameters
+ from collections import OrderedDict
+
+ np.random.seed(123)
+ random.seed(123)
+
+
+ class Evaluator():
+
+     def __init__(self, params):
+         self.params = params
+         self.device = self.params.device
+         #get raga-to-label mapping
+         _, self.raga2label = extract_data(self.params)
+         self.raga_list = list(self.raga2label.keys())
+         self.label_list = list(self.raga2label.values())
+
+         #initialize model
+         if params.model == 'base':
+             self.model = BaseRagaClassifier(params).to(self.device)
+         elif params.model == 'resnet':
+             self.model = ResNetRagaClassifier(params).to(self.device)
+         elif params.model == 'wav2vec':
+             self.model = Wav2VecTransformer(params).to(self.device)
+         else:
+             raise ValueError("Model must be either 'base', 'resnet', or 'wav2vec'")
+
+         #load best model
+         logging.info("Loading checkpoint %s"%params.best_checkpoint_path)
+         self.restore_checkpoint(params.best_checkpoint_path)
+         self.model.eval()
+
+
+     def normalize(self, audio):
+         return (audio - torch.mean(audio, dim=1, keepdim=True))/(torch.std(audio, dim=1, keepdim=True) + 1e-5)
+
+     def pad_audio(self, audio):
+         pad = (0, self.params.sample_rate*self.params.clip_length - audio.shape[1])
+         return torch.nn.functional.pad(audio, pad=pad, value=0)
+
+     def inference(self, k, audio):
+         #unpack the (sample_rate, waveform) tuple provided by gr.Audio
+         sample_rate, audio_clip = audio
+         k = int(k) #slider values arrive as numbers; ensure an integer index
+
+         #repeat the mono channel to get stereo if necessary
+         if len(audio_clip.shape) == 1:
+             audio_clip = torch.tensor(audio_clip).unsqueeze(0).repeat(2, 1).to(torch.float32)
+         else:
+             audio_clip = torch.tensor(audio_clip).T.to(torch.float32)
+
+         #resample audio clip to the training sample rate
+         resample = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.params.sample_rate)
+         audio_clip = resample(audio_clip)
+
+         #normalize the audio clip
+         if self.params.normalize:
+             audio_clip = self.normalize(audio_clip)
+
+         #pad audio with zeros if it's not long enough
+         if audio_clip.size()[1] < self.params.sample_rate*self.params.clip_length:
+             audio_clip = self.pad_audio(audio_clip)
+
+         assert not torch.any(torch.isnan(audio_clip))
+         audio_clip = audio_clip.to(self.device)
+
+         with torch.no_grad():
+             length = audio_clip.shape[1]
+             train_length = self.params.sample_rate*self.params.clip_length
+
+             pred_probs = torch.zeros((self.params.num_classes,)).to(self.device)
+
+             #loop over clip_length segments and average the predicted distributions
+             num_clips = int(np.floor(length/train_length))
+             for i in range(num_clips):
+
+                 clip = audio_clip[:, i*train_length:(i+1)*train_length].unsqueeze(0)
+
+                 #forward pass: the model outputs log-probabilities, so exponentiate and renormalize
+                 pred_distribution = self.model(clip).reshape(-1, self.params.num_classes)
+                 pred_probs += 1 / num_clips * (torch.exp(pred_distribution)/torch.exp(pred_distribution).sum(axis=1, keepdim=True))[0]
+
+             #sort the averaged probabilities and keep the top k ragas
+             pred_probs, labels = pred_probs.sort(descending=True)
+             pred_probs_topk = pred_probs[:k]
+             pred_ragas_topk = [self.raga_list[self.label_list.index(label)] for label in labels[:k]]
+             return {raga: prob.item() for raga, prob in zip(pred_ragas_topk, pred_probs_topk)}
+
+     def restore_checkpoint(self, checkpoint_path):
+         checkpoint = torch.load(checkpoint_path, map_location=self.device)
+         try:
+             self.model.load_state_dict(checkpoint['model_state'])
+         except RuntimeError:
+             #loading a DDP checkpoint into a non-DDP model: strip the `module.` prefix
+             new_state_dict = OrderedDict()
+             for key, val in checkpoint['model_state'].items():
+                 name = key[7:] # remove `module.`
+                 new_state_dict[name] = val
+             # load params
+             self.model.load_state_dict(new_state_dict)
+
+         self.iters = checkpoint['iters']
+         self.startEpoch = checkpoint['epoch']
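
For reference, Gradio passes the `gr.Audio()` input to `fn` as a `(sample_rate, numpy_array)` tuple, which is exactly what `Evaluator.inference` unpacks. Here is a sketch of calling it outside the UI, assuming the repo's metadata files and checkpoint are present; the sine tone is only a placeholder input.

```python
# Sketch: run the evaluator on a synthetic clip, mimicking Gradio's input format.
import os
import numpy as np
from utils.YParams import YParams
from inference import Evaluator

params = YParams(os.path.abspath('config.yaml'), 'resnet_0.7')
params.device = "cpu"
expDir = "ckpts/resnet_0.7/150classes_alldata_cliplength30"
params['best_checkpoint_path'] = os.path.join(expDir, 'training_checkpoints/best_ckpt.tar')

evaluator = Evaluator(params)

# 30 s of a 440 Hz tone at 44.1 kHz, shaped like Gradio's (rate, waveform) tuple.
sr = 44100
t = np.linspace(0, 30, 30 * sr, endpoint=False)
clip = np.sin(2 * np.pi * 440 * t).astype(np.float32)

print(evaluator.inference(5, (sr, clip)))  # top-5 {raga: probability}
```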
labeled_0.7_wav_metadata.json ADDED
The diff for this file is too large to render. See raw diff
 
metadata_0.7.json ADDED
@@ -0,0 +1,530 @@
+ {
+ "Mohanam": 1085,
+ "Sankarabharanam": 1084,
+ "Panthuvarali": 1036,
+ "Kalyani": 1032,
+ "Bhairavi": 1032,
+ "Hamsadhwani": 1005,
+ "Sindhubhairavi": 985,
+ "Kambhoji": 958,
+ "Thodi": 936,
+ "Khamas": 926,
+ "Poorvikalyani": 920,
+ "Madhyamavathi": 914,
+ "Kapi": 907,
+ "Karaharapriya": 881,
+ "Hindolam": 868,
+ "Anandabhairavi": 847,
+ "Saveri": 801,
+ "Nata": 788,
+ "Shanmukapriya": 779,
+ "Reethigowla": 743,
+ "Begada": 732,
+ "Behag": 714,
+ "Abheri": 692,
+ "Arabhi": 691,
+ "Bilahari": 651,
+ "Atana": 648,
+ "Kaanada": 613,
+ "Sri": 601,
+ "Varali": 577,
+ "Sahana": 572,
+ "Yamunakalyani": 560,
+ "Suruti": 544,
+ "Mayamalavagowla": 539,
+ "Yadukulakambhoji": 536,
+ "Mukhari": 493,
+ "Sriranjani": 477,
+ "Abhogi": 464,
+ "Natakurinji": 463,
+ "Harikambhoji": 458,
+ "Vasantha": 447,
+ "Dhanyasi": 442,
+ "Neelambari": 434,
+ "Keeravani": 421,
+ "Kedaragowla": 413,
+ "Chenchurutti": 392,
+ "Sowrashtram": 387,
+ "Hamsanandi": 378,
+ "Shuddhadhanyasi": 366,
+ "Sama": 351,
+ "Brindavanasaranga": 340,
+ "Shuddhasaveri": 331,
+ "Hamirkalyani": 323,
+ "Devagandhari": 322,
+ "Gowla": 313,
+ "Charukesi": 310,
+ "Darbar": 306,
+ "Revathi": 302,
+ "Saranga": 299,
+ "Simhendramadhyamam": 290,
+ "Dwijavanthi": 283,
+ "Kunthalavarali": 269,
+ "Jonpuri": 269,
+ "Bahudhari": 269,
+ "Huseni": 267,
+ "Ranjani": 263,
+ "Maand": 259,
+ "Nalinakanthi": 257,
+ "Asaveri": 257,
+ "Kedaram": 253,
+ "Hamsanadam": 252,
+ "Kurinji": 250,
+ "Kadanakuthuhalam": 236,
+ "Desh": 234,
+ "Bowli": 232,
+ "Amrithavarshini": 232,
+ "Saramathi": 232,
+ "Punnagavarali": 231,
+ "Mohanakalyani": 231,
+ "Subhapanthuvarali": 224,
+ "Nadanamakriya": 223,
+ "Lathangi": 222,
+ "Gambheeranata": 219,
+ "Darbarikaanada": 218,
+ "Malayamarutham": 208,
+ "Chakravakam": 204,
+ "Bageshree": 201,
+ "Ahiri": 200,
+ "Saraswathi": 200,
+ "Kannada": 197,
+ "Manirangu": 196,
+ "Kalyanavasantham": 195,
+ "Jaganmohini": 192,
+ "Andolika": 190,
+ "Poornachandrika": 184,
+ "Devamanohari": 183,
+ "Lalitha": 179,
+ "Gowrimanohari": 178,
+ "Valaji": 171,
+ "Vachaspathi": 171,
+ "Dharmavathi": 168,
+ "Paras": 161,
+ "Varamu": 159,
+ "Janaranjani": 159,
+ "Hemavathi": 154,
+ "Ravichandrika": 150,
+ "Nayaki": 143,
+ "Navarasakannada": 143,
+ "Gowlipanthu": 126,
+ "Madhuvanthi": 124,
+ "Manji": 124,
+ "Thilang": 124,
+ "Kapinarayani": 122,
+ "Kamalamanohari": 121,
+ "Ahirbhairavi": 119,
+ "Revagupthi": 118,
+ "Jayanthashree": 117,
+ "Chandrajyothi": 117,
+ "Nagaswaravali": 116,
+ "Kannadagowla": 116,
+ "Bindumalini": 112,
+ "Durga": 111,
+ "Vagadheeswari": 110,
+ "Karnaranjani": 106,
+ "Nasikabhushani": 106,
+ "Dhenuka": 103,
+ "Malavi": 102,
+ "Natabhairavi": 101,
+ "Ganamoorthi": 99,
+ "Vegavahini": 97,
+ "Sarasangi": 94,
+ "Saraswathimanohari": 93,
+ "Ramapriya": 87,
+ "Rudrapriya": 87,
+ "Chittaranjani": 86,
+ "Mandari": 86,
+ "Malahari": 81,
+ "Shuddhabangala": 80,
+ "Sunadhavinodhini": 79,
+ "Salagabhairavi": 75,
+ "Chinthamani": 74,
+ "Devamruthavarshini": 73,
+ "Brindavani": 72,
+ "Balahamsa": 70,
+ "Mishrayaman": 70,
+ "Vakulabharanam": 70,
+ "Bhavapriya": 70,
+ "Vijayashree": 69,
+ "Bhoopalam": 69,
+ "Dhanashree": 68,
+ "Jayamanohari": 67,
+ "Peeloo": 66,
+ "Shivaranjani": 66,
+ "Rasikapriya": 63,
+ "Vasanthabhairavi": 63,
+ "Rishabhapriya": 62,
+ "Garudadhwani": 61,
+ "Natakapriya": 61,
+ "Chalanata": 60,
+ "Ratipatipriya": 60,
+ "Gowdamalhar": 58,
+ "Navaroj": 58,
+ "Karnatakakapi": 58,
+ "Kumudakriya": 58,
+ "Bangala": 54,
+ "Chenchukambhoji": 54,
+ "Shuddhasarang": 53,
+ "Kalavathi": 53,
+ "Bhairavam": 53,
+ "Narayanagowla": 52,
+ "Shuddhaseemanthini": 52,
+ "Hamsavinodhini": 49,
+ "Manoranjani": 49,
+ "Chayatharangini": 49,
+ "Kiranavali": 49,
+ "Amrithavahini": 49,
+ "Devagandharam": 48,
+ "Jayanthasena": 47,
+ "Ramamanohari": 47,
+ "Mishramaand": 47,
+ "Nagagandhari": 47,
+ "Kanthamani": 46,
+ "Neethimathi": 46,
+ "Lalithapanchamam": 45,
+ "Vijayanagari": 45,
+ "Devakriya": 45,
+ "Rageshree": 45,
+ "Suryakantam": 45,
+ "Chandrakowns": 45,
+ "Vasanthi": 45,
+ "Pushpalatika": 44,
+ "Umabharanam": 44,
+ "Kanakangi": 43,
+ "Sindhuramakriya": 42,
+ "Gangeyabhushani": 41,
+ "Poornashadjam": 41,
+ "Tharangini": 39,
+ "Lavangi": 39,
+ "Karnatakabehag": 38,
+ "Sucharithra": 38,
+ "Neelamani": 38,
+ "Hindolavasantham": 37,
+ "Margahindolam": 37,
+ "Mishrashivaranjani": 36,
+ "Patdeep": 36,
+ "Jyothiswaroopini": 35,
+ "Phalamanjari": 35,
+ "Urmika": 35,
+ "Yaman": 35,
+ "Vanaspathi": 35,
+ "Veeravasantham": 35,
+ "Gamanashrama": 34,
+ "Kokilavarali": 33,
+ "Pahadi": 33,
+ "Kalanidhi": 33,
+ "Sindhukannada": 32,
+ "Narayani": 32,
+ "Manjari": 31,
+ "Hindusthanigandhari": 29,
+ "Bhavani": 29,
+ "Purvi": 29,
+ "Prathapavarali": 29,
+ "Vallabhi": 28,
+ "Karnatakashuddhasaveri": 27,
+ "Megharanjani": 27,
+ "Isamanohari": 27,
+ "Soorya": 27,
+ "Pasupathipriya": 26,
+ "Gambheeravani": 26,
+ "Simhavahini": 26,
+ "Madhavamanohari": 26,
+ "Niroshta": 26,
+ "Varunapriya": 26,
+ "Manavathi": 26,
+ "Naganandhini": 25,
+ "Ghanta": 25,
+ "Shivashakti": 25,
+ "Kaikavasi": 24,
+ "Kokilapriya": 24,
+ "Mishrapahadi": 24,
+ "Saindhavi": 24,
+ "Jingala": 24,
+ "Kosalam": 24,
+ "Kalgada": 23,
+ "Sumanesaranjani": 23,
+ "Bhushavali": 23,
+ "Narireethigowla": 22,
+ "Rasali": 22,
+ "Vivardhini": 22,
+ "Jankaradhvani": 22,
+ "Sarangatharangini": 21,
+ "Paadi": 21,
+ "Gowri": 21,
+ "Bhujangini": 21,
+ "Gurjari": 21,
+ "Suposhini": 21,
+ "Udhayaravichandrika": 20,
+ "Rathnangi": 20,
+ "Deepali": 20,
+ "Ragavardhani": 20,
+ "Supradeepam": 19,
+ "Mangalakaishiki": 19,
+ "Kannadabangala": 19,
+ "Shadhvidhamargini": 19,
+ "Kokiladhwani": 18,
+ "Mahathi": 18,
+ "Janasammodhini": 17,
+ "Takka": 17,
+ "Vasanthavarali": 17,
+ "Thanarupi": 17,
+ "Amrithabehag": 17,
+ "Rupavathi": 16,
+ "Ganavaridhi": 16,
+ "Natanarayani": 15,
+ "Jog": 15,
+ "Manorama": 15,
+ "Swarabhushani": 15,
+ "Gopikavasantham": 15,
+ "Shuddhadesi": 14,
+ "Mararanjani": 14,
+ "Deshakshi": 14,
+ "Gundakriya": 14,
+ "Gayakapriya": 14,
+ "Chithrambari": 14,
+ "Maruvabehag": 14,
+ "Ramkali": 13,
+ "Shulini": 13,
+ "Desiyathodi": 13,
+ "Vandanadharini": 13,
+ "Malavashree": 13,
+ "Deepakam": 12,
+ "Kokilaravam": 12,
+ "Vijayasaraswathi": 12,
+ "Mishraharikambhoji": 12,
+ "Maruvadhanyasi": 12,
+ "Mohanangi": 12,
+ "Hatakambari": 12,
+ "Yagapriya": 12,
+ "Suvarnangi": 12,
+ "Samanta": 12,
+ "Basantbahar": 12,
+ "Mishrapeeloo": 11,
+ "Raghupriya": 11,
+ "Pavani": 11,
+ "Navanitham": 11,
+ "Sindhumandari": 11,
+ "Buddhamanohari": 11,
+ "Shuddhavasantha": 10,
+ "Bhanumathi": 10,
+ "Andhali": 10,
+ "Jalarnavam": 10,
+ "Puriyadhanashree": 10,
+ "Gavambodhi": 10,
+ "Vamsavathi": 10,
+ "Sarangamalhar": 10,
+ "Dhavalambari": 10,
+ "Gowrivelaavali": 9,
+ "Namanarayani": 9,
+ "Shyamakalyani": 9,
+ "Bhoopali": 9,
+ "Vishnupriya": 8,
+ "Jyothi": 8,
+ "Gopikathilakam": 8,
+ "Chayanata": 8,
+ "Shruthiranjani": 8,
+ "Santhanamanjari": 8,
+ "Ardhradesi": 8,
+ "Sumukham": 8,
+ "Madhukowns": 8,
+ "Dhivyamani": 8,
+ "Shivakambhoji": 7,
+ "Devaranji": 7,
+ "Chayagowla": 7,
+ "Narthaki": 7,
+ "Rasamanjari": 7,
+ "Vardhani": 7,
+ "Dhurvanki": 7,
+ "Phenadyuthi": 7,
+ "Gavathi": 7,
+ "Poornalalitha": 7,
+ "Kunthalam": 7,
+ "Poorvakamodari": 7,
+ "Mahuri": 7,
+ "Chandrika": 7,
+ "Rohini": 7,
+ "Mishrakhamaj": 7,
+ "Salagam": 7,
+ "Sarvashree": 7,
+ "Bhooshavathi": 7,
+ "Shreemani": 6,
+ "Sumanapriya": 6,
+ "Sushama": 6,
+ "Latantapriya": 6,
+ "Gurjarithodi": 6,
+ "Guharanjani": 6,
+ "Namadesi": 6,
+ "Kunjari": 6,
+ "Maargadesi": 6,
+ "Ragapanjaram": 6,
+ "Harinarayani": 6,
+ "Ravikriya": 6,
+ "Gopriya": 5,
+ "Chayaranjani": 5,
+ "Hamsadeepika": 5,
+ "Nadavarangini": 5,
+ "Hamsalatha": 5,
+ "Shuddhavalaji": 5,
+ "Bhogachayanata": 5,
+ "Phalaranjani": 5,
+ "Bhogavasantha": 5,
+ "Senavathi": 5,
+ "Thandavam": 5,
+ "Swaravali": 5,
+ "Geethapriya": 5,
+ "Jayanarayani": 5,
+ "Vishwambhari": 5,
+ "Nagavarali": 5,
+ "Karthyayani": 5,
+ "Sharavathi": 5,
+ "Ganasamavarali": 5,
+ "Jaganmohinam": 5,
+ "Shyamalangi": 5,
+ "Nagavalli": 5,
+ "Bhinnashadjam": 5,
+ "Suraranjani": 5,
+ "Nabhomani": 4,
+ "Maruva": 4,
+ "Komalangi": 4,
+ "Dhavalangam": 4,
+ "Shivapriya": 4,
+ "Bhavabharanam": 4,
+ "Nagabhushani": 4,
+ "Swarasammodhini": 4,
+ "Gomedhikapriya": 4,
+ "Dhatuvardhani": 4,
+ "Hamsakalyani": 4,
+ "Samakadambari": 4,
+ "Sauvira": 4,
+ "Thanukeerthi": 4,
+ "Ragachoodaamani": 4,
+ "Rasikaranjani": 3,
+ "Moharanjani": 3,
+ "Hamsabhramari": 3,
+ "Dundubi": 3,
+ "Miyanmalhar": 3,
+ "Murali": 3,
+ "Shekharachandrikaa": 3,
+ "Madhulika": 3,
+ "Krisnaveni": 3,
+ "Nagadhwani": 3,
+ "Dakshayani": 3,
+ "Kalasaveri": 3,
+ "Sthavarajam": 3,
+ "Mishrajog": 3,
+ "Srothaswani": 3,
+ "Sowrasenaa": 3,
+ "Samudrapriya": 3,
+ "Bhuvanagaandhaari": 3,
+ "Hamsanantini": 3,
+ "Bhanuchandrika": 3,
+ "Balachandrika": 3,
+ "Shuddhamukhari": 3,
+ "Viswapriya": 3,
+ "Gangatharangini": 3,
+ "Sharadapriya": 3,
+ "Chathurangini": 3,
+ "Venkatadri": 3,
+ "Narayanadri": 3,
+ "Puriyakalyan": 3,
+ "Hindusthanitodi": 3,
+ "Dayavathi": 3,
+ "Kokila": 3,
+ "Madhyamaravali": 3,
+ "Vivahapriya": 3,
+ "Vijayavasantha": 3,
+ "Kuvalayabharanam": 3,
+ "Swararanjani": 3,
+ "Mishrabilahari": 3,
+ "Salanganata": 3,
+ "Malavapanchamam": 3,
+ "Siddhasena": 3,
+ "Jalavarali": 3,
+ "Haricharan": 3,
+ "Karnatakahindolam": 3,
+ "Senagrani": 3,
+ "Nadhabrahma": 3,
+ "Poorvagowla": 3,
+ "Veenadhaari": 3,
+ "Geyahejjajji": 3,
+ "Tarani": 3,
+ "Navarathnavilaasam": 3,
+ "Bhanudhanyasi": 3,
+ "Kolahalam": 3,
+ "Panchamam": 2,
+ "Bhoopalapanchamam": 2,
+ "Rasavinodhini": 2,
+ "Seshadri": 2,
+ "Tarakagowla": 2,
+ "Poornapanchamam": 2,
+ "Kesari": 2,
+ "Shailadeshakshi": 2,
+ "Jujahuli": 2,
+ "Churnikavinodhini": 2,
+ "Rukmambari": 2,
+ "Kalakanti": 2,
+ "Omkaari": 2,
+ "Hindoladarbar": 2,
+ "Mukthidayini": 2,
+ "Hradini": 2,
+ "Velaavali": 2,
+ "Kowshikadhwani": 2,
+ "Hamsavahini": 2,
+ "Srikara": 2,
+ "Mishramanolayam": 2,
+ "Bhanupriya": 2,
+ "Manolayam": 2,
+ "Mukundamalini": 2,
+ "Siddhi": 2,
+ "Ramakriya": 2,
+ "Suranandini": 2,
+ "Chandrapriya": 2,
+ "Shuddhasaranga": 2,
+ "Bhatiyaar": 2,
+ "Malava": 2,
+ "Kusumakaram": 2,
+ "Kannadamaruva": 2,
+ "Sowgandhini": 2,
+ "Shuddhathodi": 2,
+ "Thivravahini": 1,
+ "Savithri": 1,
+ "Shreekaanti": 1,
+ "Kumarapriya": 1,
+ "Sugunabhooshani": 1,
+ "Kalyanakesari": 1,
+ "Nagabharanam": 1,
+ "Jadathari": 1,
+ "Guhamanohari": 1,
+ "Agnikopa": 1,
+ "Pranavapriya": 1,
+ "Karpoorabharani": 1,
+ "Rojakadambari": 1,
+ "Shuddhakambhoji": 1,
+ "Hindhusthanibehag": 1,
+ "Jnanachinthamani": 1,
+ "Chandrahasitham": 1,
+ "Triveni": 1,
+ "Dravidakalavati": 1,
+ "Kumbhini": 1,
+ "Vasukari": 1,
+ "Vitapi": 1,
+ "Saranganata": 1,
+ "Visharada": 1,
+ "Vinodhini": 1,
+ "Sarasaanana": 1,
+ "Shuddhalalitha": 1,
+ "Rishipriya": 1,
+ "Dhipaka": 1,
+ "Shuddhasalavi": 1,
+ "Sutradhari": 1,
+ "Natabharanam": 1,
+ "Mishrabhairavi": 1,
+ "Priyadarshani": 1,
+ "Hamsagamini": 1,
+ "Shyamalam": 1,
+ "Jeevanthika": 1,
+ "Alankari": 1,
+ "Jayashuddhamaalavi": 1,
+ "Mechabowli": 1,
+ "Garudapriya": 1
+ }
models/RagaNet.py ADDED
@@ -0,0 +1,164 @@
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torchaudio.models import wav2vec2_model
+
+ def count_parameters(model):
+     return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+ #basic conv block: Conv1d -> BatchNorm -> ReLU
+ def conv_block(n_input, n_output, stride=1, kernel_size=80):
+     layers = []
+     if stride == 1:
+         layers.append(nn.Conv1d(n_input, n_output, kernel_size=kernel_size, stride=stride, padding='same'))
+     else:
+         layers.append(nn.Conv1d(n_input, n_output, kernel_size=kernel_size, stride=stride))
+     layers.append(nn.BatchNorm1d(n_output))
+     layers.append(nn.ReLU())
+     return nn.Sequential(*layers)
+
+ #basic 2-conv residual block
+ class ResidualBlock(nn.Module):
+     def __init__(self, n_channels, kernel_size):
+         super().__init__()
+
+         self.conv_block1 = conv_block(n_channels, n_channels, stride=1, kernel_size=kernel_size)
+         self.conv_block2 = conv_block(n_channels, n_channels, stride=1, kernel_size=3)
+
+     def forward(self, x):
+         identity = x
+         x = self.conv_block1(x)
+         x = self.conv_block2(x)
+         x = x + identity
+         return x
+
+
+ class ResNetRagaClassifier(nn.Module):
+     def __init__(self, params):
+         super().__init__()
+         n_input = params.n_input
+         n_channel = params.n_channel
+         stride = params.stride
+         self.n_blocks = params.n_blocks
+         self.conv_first = conv_block(n_input, n_channel, stride=stride, kernel_size=80)
+         self.max_pool_every = params.max_pool_every
+
+         self.res_blocks = nn.ModuleList() #residual blocks
+         for i in range(self.n_blocks):
+             self.res_blocks.append(ResidualBlock(n_channel, kernel_size=3))
+
+         #linear classification head
+         self.fc1 = nn.Linear(n_channel, params.num_classes)
+
+
+     def forward(self, x):
+         #initial strided conv
+         x = self.conv_first(x)
+
+         #residual blocks with periodic downsampling
+         for i, block in enumerate(self.res_blocks):
+             x = block(x)
+             if i % self.max_pool_every == 0:
+                 x = F.max_pool1d(x, 2)
+
+         #classification head: global average pool, then linear layer
+         x = F.avg_pool1d(x, x.shape[-1])
+         x = x.permute(0, 2, 1)
+         x = self.fc1(x)
+         x = F.log_softmax(x, dim=-1)
+
+         return x
+
+
+ class BaseRagaClassifier(nn.Module):
+     def __init__(self, params):
+         super().__init__()
+         n_input = params.n_input
+         n_channel = params.n_channel
+         stride = params.stride
+
+         self.conv_block1 = conv_block(n_input, n_channel, stride=stride, kernel_size=80)
+         self.conv_block2 = conv_block(n_channel, n_channel, stride=1, kernel_size=3)
+         self.conv_block3 = conv_block(n_channel, 2*n_channel, stride=1, kernel_size=3)
+         self.conv_block4 = conv_block(2*n_channel, 2*n_channel, stride=1, kernel_size=3)
+         self.fc1 = nn.Linear(2 * n_channel, params.num_classes)
+
+     def forward(self, x):
+         x = self.conv_block1(x)
+         x = F.max_pool1d(x, 4)
+         x = self.conv_block2(x)
+         x = F.max_pool1d(x, 4)
+         x = self.conv_block3(x)
+         x = F.max_pool1d(x, 4)
+         x = self.conv_block4(x)
+         x = F.avg_pool1d(x, x.shape[-1])
+         x = x.permute(0, 2, 1)
+         x = self.fc1(x)
+         x = F.log_softmax(x, dim=-1)
+         return x
+
+
+ class Wav2VecTransformer(nn.Module):
+     def __init__(self, params):
+         super().__init__()
+         self.params = params
+         self.extractor_mode = params.extractor_mode
+         self.extractor_conv_bias = params.extractor_conv_bias
+         self.encoder_embed_dim = params.encoder_embed_dim
+         self.encoder_projection_dropout = params.encoder_projection_dropout
+         self.encoder_pos_conv_kernel = params.encoder_pos_conv_kernel
+         self.encoder_pos_conv_groups = params.encoder_pos_conv_groups
+         self.encoder_num_layers = params.encoder_num_layers
+         self.encoder_num_heads = params.encoder_num_heads
+         self.encoder_attention_dropout = params.encoder_attention_dropout
+         self.encoder_ff_interm_features = params.encoder_ff_interm_features
+         self.encoder_ff_interm_dropout = params.encoder_ff_interm_dropout
+         self.encoder_dropout = params.encoder_dropout
+         self.encoder_layer_norm_first = params.encoder_layer_norm_first
+         self.encoder_layer_drop = params.encoder_layer_drop
+         self.aux_num_out = params.num_classes
+
+         #feature-extractor conv layers: (out_channels, kernel_size, stride); hardcoded for now
+         self.extractor_conv_layer_config = [
+             (32, 80, 16),
+             (64, 5, 4),
+             (128, 5, 4),
+             (256, 5, 4),
+             (512, 3, 2),
+             (512, 2, 2),
+             (512, 2, 2),
+         ]
+         self.encoder = wav2vec2_model(self.extractor_mode,
+                                       self.extractor_conv_layer_config,
+                                       self.extractor_conv_bias,
+                                       self.encoder_embed_dim,
+                                       self.encoder_projection_dropout,
+                                       self.encoder_pos_conv_kernel,
+                                       self.encoder_pos_conv_groups,
+                                       self.encoder_num_layers,
+                                       self.encoder_num_heads,
+                                       self.encoder_attention_dropout,
+                                       self.encoder_ff_interm_features,
+                                       self.encoder_ff_interm_dropout,
+                                       self.encoder_dropout,
+                                       self.encoder_layer_norm_first,
+                                       self.encoder_layer_drop,
+                                       aux_num_out=None)
+
+         #the conv extractor downsamples by the product of its strides: 16*4*4*4*2*2*2
+         self.audio_length = params.sample_rate*params.clip_length
+         self.classification_head = nn.Linear(int(self.audio_length/(16*4*4*4*2*2*2))*params.encoder_embed_dim, params.num_classes)
+
+     def forward(self, x):
+         x = self.encoder(x)[0]
+         x = x.reshape(x.shape[0], -1) #flatten all frames into one vector
+         x = self.classification_head(x)
+         x = F.log_softmax(x, dim=-1)
+         return x
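
A small sketch to sanity-check the architecture above (not part of the commit): instantiate the ResNet variant from the `resnet_0.7` config, count trainable parameters, and trace the output shape for one 30-second stereo clip at 8 kHz.

```python
# Sketch: parameter count and output shape of the ResNet classifier.
import torch
from models.RagaNet import ResNetRagaClassifier, count_parameters
from utils.YParams import YParams

params = YParams('config.yaml', 'resnet_0.7')
model = ResNetRagaClassifier(params)
print(count_parameters(model))  # number of trainable parameters

x = torch.randn(1, params.n_input, params.sample_rate * params.clip_length)
print(model(x).shape)  # torch.Size([1, 1, 150]): log-probabilities over 150 ragas
```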
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ torch
+ torchaudio
+ numpy
+ ruamel.yaml  # assumed needed: imported by utils/YParams.py (gradio ships with the Space runtime)
site/tsne.jpeg ADDED
utils/YParams.py ADDED
@@ -0,0 +1,47 @@
+ from ruamel.yaml import YAML
+ import logging
+
+ class YParams():
+     """ Yaml file parser """
+     def __init__(self, yaml_filename, config_name, print_params=False):
+         self._yaml_filename = yaml_filename
+         self._config_name = config_name
+         self.params = {}
+
+         if print_params:
+             print("------------------ Configuration ------------------")
+
+         with open(yaml_filename) as _file:
+
+             for key, val in YAML().load(_file)[config_name].items():
+                 if print_params: print(key, val)
+                 if val == 'None': val = None
+
+                 self.params[key] = val
+                 self.__setattr__(key, val)
+
+         if print_params:
+             print("---------------------------------------------------")
+
+     def __getitem__(self, key):
+         return self.params[key]
+
+     def __setitem__(self, key, val):
+         self.params[key] = val
+         self.__setattr__(key, val)
+
+     def __contains__(self, key):
+         return (key in self.params)
+
+     def update_params(self, config):
+         for key, val in config.items():
+             self.params[key] = val
+             self.__setattr__(key, val)
+
+     def log(self):
+         logging.info("------------------ Configuration ------------------")
+         logging.info("Configuration file: "+str(self._yaml_filename))
+         logging.info("Configuration name: "+str(self._config_name))
+         for key, val in self.params.items():
+             logging.info(str(key) + ' ' + str(val))
+         logging.info("---------------------------------------------------")
utils/logging_utils.py ADDED
@@ -0,0 +1,32 @@
+ import os
+ import logging
+
+ _format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+
+ def config_logger(log_level=logging.INFO):
+     logging.basicConfig(format=_format, level=log_level)
+
+ def log_to_file(logger_name=None, log_level=logging.INFO, log_filename='tensorflow.log'):
+
+     #create the log directory if the filename includes one
+     log_dir = os.path.dirname(log_filename)
+     if log_dir and not os.path.exists(log_dir):
+         os.makedirs(log_dir)
+
+     if logger_name is not None:
+         log = logging.getLogger(logger_name)
+     else:
+         log = logging.getLogger()
+
+     fh = logging.FileHandler(log_filename)
+     fh.setLevel(log_level)
+     fh.setFormatter(logging.Formatter(_format))
+     log.addHandler(fh)
+
+ def log_versions():
+     import torch
+     import subprocess
+
+     logging.info('--------------- Versions ---------------')
+     logging.info('git branch: ' + str(subprocess.check_output(['git', 'branch']).strip()))
+     logging.info('git hash: ' + str(subprocess.check_output(['git', 'rev-parse', 'HEAD']).strip()))
+     logging.info('Torch: ' + str(torch.__version__))
+     logging.info('----------------------------------------')