jeevster
committed
Commit: 64094d4
Parent(s): 08771d5

huggingface space main commit
Browse files
- .gitignore +1 -0
- about.md +12 -0
- app.py +57 -0
- ckpts/resnet_0.7/150classes_alldata_cliplength30/training_checkpoints/best_ckpt.tar +3 -0
- config.yaml +130 -0
- data/dataloader.py +143 -0
- inference.py +117 -0
- labeled_0.7_wav_metadata.json +0 -0
- metadata_0.7.json +530 -0
- models/RagaNet.py +164 -0
- requirements.txt +3 -0
- site/tsne.jpeg +0 -0
- utils/YParams.py +47 -0
- utils/logging_utils.py +32 -0
.gitignore
ADDED
@@ -0,0 +1 @@
+__pycache__/
about.md
ADDED
@@ -0,0 +1,12 @@
+### About the Classifier
+The classifier is a [convolutional neural network](https://en.wikipedia.org/wiki/Convolutional_neural_network) trained on over 10,000 hours of Carnatic audio sourced from this incredible [YouTube collection](https://ramanarunachalam.github.io/Music/Carnatic/carnatic.html).
+### Key Features:
+- Can identify **150 ragas**.
+- Does not require any information about the **shruthi (tonic pitch)** of the recording.
+- **Compatible** with male/female vocal or instrumental recordings.
+
+### Interpreting the Classifier:
+We can gain an intuitive sense of what the classifier has learned. Here is a [t-SNE](https://en.wikipedia.org/wiki/T-distributed_stochastic_neighbor_embedding) projection of the hidden activations averaged per ragam. Each point is a ragam, and the relative distances between points indicate the degree to which the classifier considers the ragas similar. Each ragam is color-coded by the [melakartha chakra](https://en.wikipedia.org/wiki/Melakarta#Chakras) it belongs to. We observe that the classifier has learned a representation that roughly corresponds to these chakras!
+
+
+
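(Editor's note: a figure like site/tsne.jpeg can be produced along the lines below. This is a minimal sketch, not part of the commit: the helper `get_hidden_activations` and the two input dictionaries are assumptions, standing in for whatever script extracted the pre-classifier features.)

# Hypothetical sketch: t-SNE of per-raga mean hidden activations.
# Assumes activations_per_raga maps raga -> (n_clips, n_channel) array of
# pooled features (the input to ResNetRagaClassifier.fc1), and
# chakra_per_raga maps raga -> integer chakra id. Neither is in the commit.
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

def plot_raga_tsne(activations_per_raga, chakra_per_raga):
    ragas = list(activations_per_raga.keys())
    # average the hidden activations per ragam, then project to 2D
    means = np.stack([activations_per_raga[r].mean(axis=0) for r in ragas])
    coords = TSNE(n_components=2, perplexity=30).fit_transform(means)
    colors = [chakra_per_raga[r] for r in ragas]  # one color per chakra
    plt.scatter(coords[:, 0], coords[:, 1], c=colors, cmap='tab20')
    plt.savefig('site/tsne.jpeg')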
app.py
ADDED
@@ -0,0 +1,57 @@
+import os
+from inference import Evaluator
+import argparse
+from utils.YParams import YParams
+import torch
+import gradio as gr
+
+def read_markdown_file(path):
+    with open(path, 'r', encoding='utf-8') as file:
+        return file.read()
+
+
+if __name__ == '__main__':
+
+    #parse args
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--yaml_config", default='config.yaml', type=str)
+    parser.add_argument("--config", default='resnet_0.7', type=str)
+
+    args = parser.parse_args()
+    params = YParams(os.path.abspath(args.yaml_config), args.config)
+
+    #GPU stuff: fall back to CPU when no CUDA device is available
+    try:
+        params.device = torch.device(torch.cuda.current_device())
+    except Exception:
+        params.device = "cpu"
+
+    #checkpoint stuff
+    expDir = "ckpts/resnet_0.7/150classes_alldata_cliplength30"
+    params['checkpoint_path'] = os.path.join(expDir, 'training_checkpoints/ckpt.tar')
+    params['best_checkpoint_path'] = os.path.join(expDir, 'training_checkpoints/best_ckpt.tar')
+
+    evaluator = Evaluator(params)
+
+    with gr.Blocks() as demo:
+        with gr.Tab("Classifier"):
+            gr.Interface(
+                title="Carnatic Raga Classifier",
+                description="**Welcome!** This is a deep-learning based raga classifier. Simply upload or record an audio clip to test it out. \n",
+                article="**Get in Touch:** Feel free to reach out to [me](https://sanjeevraja.com/) via email (sanjeevr AT berkeley DOT edu) with any questions or feedback! ",
+                fn=evaluator.inference,
+                inputs=[
+                    gr.Slider(minimum=1, maximum=150, value=5, label="Number of displayed ragas", info="Choose number of top predictions to display"),
+                    gr.Audio()
+                ],
+                outputs="label",
+                allow_flagging=False
+            )
+
+        with gr.Tab("About"):
+            gr.Markdown(read_markdown_file('about.md'))
+            gr.Image('site/tsne.jpeg', height=800, width=800)
+
+    demo.launch()
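(Editor's note: the Gradio Audio component hands `Evaluator.inference` a `(sample_rate, numpy_array)` tuple. A minimal smoke test, assuming it is run from the repo root after the setup above so that `evaluator` exists, could mimic that contract with synthetic audio.)

# Sketch: drive the same entry point Gradio calls, without the UI.
import numpy as np

sr = 8000  # matches params.sample_rate, so resampling is a no-op
fake_clip = np.sin(2 * np.pi * 440 * np.arange(30 * sr) / sr).astype(np.float32)
preds = evaluator.inference(5, (sr, fake_clip))  # same (sr, data) tuple Gradio passes
print(preds)  # dict mapping the top-5 raga names to probabilities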
ckpts/resnet_0.7/150classes_alldata_cliplength30/training_checkpoints/best_ckpt.tar
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:09844cbcfc6c98af632671ca14dd4878fe6656811cdb77097a847c8b324362ca
+size 66309523
config.yaml
ADDED
@@ -0,0 +1,130 @@
+default: &DEFAULT
+
+  #data
+  exp_dir: carnatic/ckpts
+  metadata_labeled_path: labeled_small_wav_metadata.json
+  metadata_unlabeled_path: unlabeled_data/unlabeled_mp3_metadata.json
+  num_files_per_raga_path: metadata_small.json
+  num_classes: 150
+  use_frac: 1
+  use_unlabeled_data: !!bool False
+  labeled_data_dir: labeled_data_small
+  clip_length: 30
+  sample_rate: 8000
+  normalize: !!bool True
+
+  #training
+  batch_size: 16
+  num_data_workers: 16
+  n_epochs: 100
+  lr: 0.001
+  class_imbalance_weights: !!bool False
+  patience: 10
+  train_frac: 0.8
+
+  #model
+  model: 'base'
+  n_input: 2 #stereo
+  stride: 16
+  n_channel: 32
+  max_pool_every: 1
+
+  #logging
+  save_checkpoint: !!bool False
+  wandb_api_key: f7892f37dd96b5f1da5c85a410300bb661f3c4de
+  log_to_wandb: !!bool False
+
+
+default_0.7: &DEFAULT_0.7
+
+  <<: *DEFAULT
+  metadata_labeled_path: labeled_0.7_wav_metadata.json
+  num_files_per_raga_path: metadata_0.7.json
+  labeled_data_dir: labeled_data_0.7
+
+default_0.9: &DEFAULT_0.9
+
+  <<: *DEFAULT
+  metadata_labeled_path: labeled_0.9_wav_metadata.json
+  num_files_per_raga_path: metadata_0.9.json
+  labeled_data_dir: labeled_data_0.9
+  train_frac: 0.85
+  num_classes: 200
+
+
+resnet: &RESNET
+
+  <<: *DEFAULT
+  model: 'resnet'
+  n_blocks: 5 #for resnet
+  n_channel: 128
+
+resnet_0.7: &RESNET_0.7
+
+  <<: *DEFAULT_0.7
+  model: 'resnet'
+  n_blocks: 10 #for resnet
+  n_channel: 300
+  num_classes: 150
+
+resnet_0.9: &RESNET_0.9
+
+  <<: *DEFAULT_0.9
+  model: 'resnet'
+  n_blocks: 10 #for resnet
+  n_channel: 350
+  max_pool_every: 1 #downsample after every res block
+
+
+wav2vec_0.7: &WAV2VEC_0.7
+  <<: *DEFAULT_0.7
+
+  model: 'wav2vec'
+  n_input: 1 #mono
+
+  #transformer parameters (this config leads to around 29M params)
+  extractor_mode: "layer_norm"
+  extractor_conv_layer_config: None #hardcoded for now, fix this at some point
+  extractor_conv_bias: !!bool True
+  encoder_embed_dim: 512
+  encoder_projection_dropout: 0
+  encoder_pos_conv_kernel: 3
+  encoder_pos_conv_groups: 32
+  encoder_num_layers: 12
+  encoder_num_heads: 16
+  encoder_attention_dropout: 0
+  encoder_ff_interm_features: 1024
+  encoder_ff_interm_dropout: 0
+  encoder_dropout: 0
+  encoder_layer_norm_first: !!bool True
+  encoder_layer_drop: 0
+
+
+wav2vec_0.9: &WAV2VEC_0.9
+  <<: *DEFAULT_0.9
+
+  model: 'wav2vec'
+  n_input: 1 #mono
+
+  #transformer parameters (this config leads to around 29M params)
+  extractor_mode: "layer_norm"
+  extractor_conv_layer_config: None #hardcoded for now, fix this at some point
+  extractor_conv_bias: !!bool True
+  encoder_embed_dim: 512
+  encoder_projection_dropout: 0
+  encoder_pos_conv_kernel: 3
+  encoder_pos_conv_groups: 32
+  encoder_num_layers: 12
+  encoder_num_heads: 16
+  encoder_attention_dropout: 0
+  encoder_ff_interm_features: 1024
+  encoder_ff_interm_dropout: 0
+  encoder_dropout: 0
+  encoder_layer_norm_first: !!bool True
+  encoder_layer_drop: 0
+
+
+
+
+
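(Editor's note: the config leans on YAML anchors (`&DEFAULT`) and merge keys (`<<: *DEFAULT`): each named config inherits every key from its parent and overrides a few. A minimal sketch of how the `resnet_0.7` entry resolves when loaded with ruamel.yaml, as YParams does.)

# Sketch: resolving the merged resnet_0.7 config.
from ruamel.yaml import YAML

with open('config.yaml') as f:
    cfg = YAML().load(f)['resnet_0.7']

print(cfg['model'])        # 'resnet'  (overridden in resnet_0.7)
print(cfg['n_channel'])    # 300       (overridden in resnet_0.7)
print(cfg['sample_rate'])  # 8000      (inherited from default)
print(cfg['num_files_per_raga_path'])  # 'metadata_0.7.json' (from default_0.7)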
data/dataloader.py
ADDED
@@ -0,0 +1,143 @@
+import logging
+import torch
+import random
+import numpy as np
+from torch.utils.data import Dataset
+import torchaudio
+import json
+import os
+from math import floor
+import copy
+import math
+
+
+np.random.seed(123)
+random.seed(123)
+
+#load the used subset of the metadata
+def extract_data(params):
+    with open(params.metadata_labeled_path) as f:
+        metadata_labeled = json.load(f)
+
+    if params.use_unlabeled_data:
+        with open(params.metadata_unlabeled_path) as f:
+            metadata_unlabeled = json.load(f)
+        n_unlabeled_samples = len(metadata_unlabeled)
+
+    #shuffle data and remove unused files
+    random.shuffle(metadata_labeled)
+    keep = math.ceil(params.use_frac*len(metadata_labeled))
+    metadata_labeled = metadata_labeled[0:keep]
+
+    #construct raga label lookup table
+    raga2label = get_raga2label(params)
+
+    #remove unused ragas from the metadata dictionary
+    metadata_labeled_final = remove_unused_ragas(metadata_labeled, raga2label)
+
+    return metadata_labeled_final, raga2label
+
+def get_raga2label(params):
+    with open(params.num_files_per_raga_path) as f:
+        num_files_per_raga = json.load(f)
+    raga2label = {}
+    for i, raga in enumerate(num_files_per_raga.keys()):
+        raga2label[raga] = i #assign every raga a unique number from 0 to num_classes-1
+        if i == params.num_classes-1:
+            break
+    return raga2label
+
+def remove_unused_ragas(metadata_labeled, raga2label):
+    temp = copy.deepcopy(metadata_labeled)
+    for i, entry in enumerate(metadata_labeled):
+        raga = entry['filename'].split("/")[0]
+        if raga not in raga2label.keys(): #this raga is not in the top params.num_classes ragas
+            temp.remove(entry)
+
+    return temp
+
+class RagaDataset(Dataset):
+    def __init__(self, params, metadata_labeled, raga2label):
+        self.params = params
+        self.metadata_labeled = metadata_labeled
+        self.raga2label = raga2label
+        self.n_labeled_samples = len(self.metadata_labeled)
+        self.transform_dict = {}
+        self.count = 0
+
+        if params.local_rank == 0:
+            print("Begin training using ", self.__len__(), " audio samples of ", self.params.clip_length, " seconds each.")
+            print("Total number of ragas specified: ", self.params.num_classes)
+
+
+    def construct_label(self, raga, label_smoothing=False):
+        #construct one-hot encoding vector for the raga
+        raga_index = self.raga2label[raga]
+        label = torch.zeros((self.params.num_classes,), dtype=torch.float32)
+        label[raga_index] = 1
+        return label
+
+    def normalize(self, audio):
+        return (audio - torch.mean(audio, dim=1, keepdim=True))/(torch.std(audio, dim=1, keepdim=True) + 1e-5)
+
+    def pad_audio(self, audio):
+        pad = (0, self.params.sample_rate*self.params.clip_length - audio.shape[1])
+        return torch.nn.functional.pad(audio, pad=pad, value=0)
+
+    def __len__(self):
+        return len(self.metadata_labeled)
+
+
+    def __getitem__(self, idx):
+        #get metadata
+        file_info = self.metadata_labeled[idx]
+
+        #sample the clip offset uniformly
+        rng = max(0, file_info['duration'] - self.params.clip_length)
+        if rng == 0:
+            rng = file_info['duration']
+
+        seconds_offset = np.random.randint(floor(rng))
+
+        #open audio file
+        audio_clip, sample_rate = torchaudio.load(filepath=os.path.join(self.params.labeled_data_dir, file_info['filename']),
+                                                  frame_offset=seconds_offset * file_info['sample_rate'],
+                                                  num_frames=self.params.clip_length*file_info['sample_rate'], normalize=True)
+
+        #repeat mono channel to get stereo if necessary
+        if audio_clip.shape[0] != 2:
+            audio_clip = audio_clip.repeat(2, 1)
+
+        #keep stereo (mono mixdown left commented out)
+        #audio_clip = audio_clip.mean(dim=0, keepdim=True)
+
+        #cache one Resample transform per source sample rate
+        if sample_rate not in self.transform_dict.keys():
+            self.transform_dict[sample_rate] = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.params.sample_rate)
+
+        #load cached transform
+        resample = self.transform_dict[sample_rate]
+        audio_clip = resample(audio_clip)
+
+        if self.params.normalize:
+            audio_clip = self.normalize(audio_clip)
+
+        if audio_clip.size()[1] < self.params.sample_rate*self.params.clip_length:
+            #pad audio with zeros if it's not long enough
+            audio_clip = self.pad_audio(audio_clip)
+
+        raga = file_info['filename'].split("/")[0]
+
+        #construct label
+        label = self.construct_label(raga)
+
+        assert not torch.any(torch.isnan(audio_clip))
+        assert not torch.any(torch.isnan(label))
+        assert audio_clip.shape[1] == self.params.sample_rate*self.params.clip_length
+        return audio_clip, label
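(Editor's note: a minimal sketch of wiring RagaDataset into a training DataLoader. Assumptions: `params` was loaded via YParams as in app.py, the `labeled_data_0.7` audio directory exists locally, and `local_rank` is normally set by a training/DDP launcher that is not part of this commit.)

# Sketch: building a training DataLoader from RagaDataset.
from torch.utils.data import DataLoader
from data.dataloader import extract_data, RagaDataset

params['local_rank'] = 0                      # normally set by the DDP launcher
metadata, raga2label = extract_data(params)
dataset = RagaDataset(params, metadata, raga2label)

loader = DataLoader(dataset, batch_size=params.batch_size,
                    shuffle=True, num_workers=params.num_data_workers)
clips, labels = next(iter(loader))            # (B, 2, 240000), (B, 150)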
inference.py
ADDED
@@ -0,0 +1,117 @@
+import logging
+from utils import logging_utils
+logging_utils.config_logger()
+import torch
+import random
+import numpy as np
+from data.dataloader import extract_data
+import torchaudio
+from models.RagaNet import BaseRagaClassifier, ResNetRagaClassifier, Wav2VecTransformer, count_parameters
+from collections import OrderedDict
+
+np.random.seed(123)
+random.seed(123)
+
+
+class Evaluator():
+
+    def __init__(self, params):
+        self.params = params
+        self.device = self.params.device
+        #get raga-to-label mapping
+        _, self.raga2label = extract_data(self.params)
+        self.raga_list = list(self.raga2label.keys())
+        self.label_list = list(self.raga2label.values())
+
+        #initialize model
+        if params.model == 'base':
+            self.model = BaseRagaClassifier(params).to(self.device)
+        elif params.model == 'resnet':
+            self.model = ResNetRagaClassifier(params).to(self.device)
+        elif params.model == 'wav2vec':
+            self.model = Wav2VecTransformer(params).to(self.device)
+        else:
+            logging.error("Model must be either 'base', 'resnet', or 'wav2vec'")
+
+        #load best model
+        logging.info("Loading checkpoint %s"%params.best_checkpoint_path)
+        self.restore_checkpoint(params.best_checkpoint_path)
+        self.model.eval()
+
+
+    def normalize(self, audio):
+        return (audio - torch.mean(audio, dim=1, keepdim=True))/(torch.std(audio, dim=1, keepdim=True) + 1e-5)
+
+    def pad_audio(self, audio):
+        pad = (0, self.params.sample_rate*self.params.clip_length - audio.shape[1])
+        return torch.nn.functional.pad(audio, pad=pad, value=0)
+
+    def inference(self, k, audio):
+        #unpack the (sample_rate, data) tuple provided by the Gradio Audio component
+        sample_rate, audio_clip = audio
+
+        #repeat mono channel to get stereo if necessary
+        if len(audio_clip.shape) == 1:
+            audio_clip = torch.tensor(audio_clip).unsqueeze(0).repeat(2, 1).to(torch.float32)
+        else:
+            audio_clip = torch.tensor(audio_clip).T.to(torch.float32)
+
+        #resample audio clip
+        resample = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.params.sample_rate)
+        audio_clip = resample(audio_clip)
+
+        #normalize the audio clip
+        if self.params.normalize:
+            audio_clip = self.normalize(audio_clip)
+
+        #pad audio with zeros if it's not long enough
+        if audio_clip.size()[1] < self.params.sample_rate*self.params.clip_length:
+            audio_clip = self.pad_audio(audio_clip)
+
+        assert not torch.any(torch.isnan(audio_clip))
+        audio_clip = audio_clip.to(self.device)
+
+        with torch.no_grad():
+            length = audio_clip.shape[1]
+            train_length = self.params.sample_rate*self.params.clip_length
+
+            pred_probs = torch.zeros((self.params.num_classes,)).to(self.device)
+
+            #loop over clip_length segments and average the predicted probabilities
+            num_clips = int(np.floor(length/train_length))
+            for i in range(num_clips):
+                clip = audio_clip[:, i*train_length:(i+1)*train_length].unsqueeze(0)
+
+                #forward pass: the model returns log-probabilities, so exponentiate and renormalize
+                pred_distribution = self.model(clip).reshape(-1, self.params.num_classes)
+                pred_probs += 1 / num_clips * (torch.exp(pred_distribution)/torch.exp(pred_distribution).sum(axis=1, keepdim=True))[0]
+
+            pred_probs, labels = pred_probs.sort(descending=True)
+            pred_probs_topk = pred_probs[:k]
+            pred_ragas_topk = [self.raga_list[self.label_list.index(label)] for label in labels[:k]]
+            d = dict(zip(pred_ragas_topk, pred_probs_topk))
+            return {raga: prob.item() for raga, prob in d.items()}
+
+    def restore_checkpoint(self, checkpoint_path):
+        checkpoint = torch.load(checkpoint_path, map_location=self.device)
+        try:
+            self.model.load_state_dict(checkpoint['model_state'])
+        except RuntimeError:
+            #loading a DDP checkpoint into a non-DDP model: strip the 'module.' prefix
+            new_state_dict = OrderedDict()
+            for key, val in checkpoint['model_state'].items():
+                name = key[7:]  # remove `module.`
+                new_state_dict[name] = val
+            #load params
+            self.model.load_state_dict(new_state_dict)
+
+        self.iters = checkpoint['iters']
+        self.startEpoch = checkpoint['epoch']
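(Editor's note: `inference` expects the Gradio convention of a `(sample_rate, samples)` tuple with samples shaped `(frames, channels)`, while `torchaudio.load` returns `(channels, frames)`. A minimal sketch of bridging the two, assuming a local stereo file `example.wav` and an `evaluator` built as in app.py.)

# Sketch: run Evaluator.inference on a wav file without the Gradio UI.
import torchaudio

waveform, sr = torchaudio.load('example.wav')  # (channels, frames), float32
audio_tuple = (sr, waveform.T.numpy())         # (frames, channels), as Gradio passes
top5 = evaluator.inference(5, audio_tuple)     # {raga_name: probability}
print(sorted(top5.items(), key=lambda kv: -kv[1]))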
labeled_0.7_wav_metadata.json
ADDED
(The diff for this file is too large to render; see the raw file.)
metadata_0.7.json
ADDED
@@ -0,0 +1,530 @@
+{
+"Mohanam": 1085,
+"Sankarabharanam": 1084,
+"Panthuvarali": 1036,
+"Kalyani": 1032,
+"Bhairavi": 1032,
+"Hamsadhwani": 1005,
+"Sindhubhairavi": 985,
+"Kambhoji": 958,
+"Thodi": 936,
+"Khamas": 926,
+"Poorvikalyani": 920,
+"Madhyamavathi": 914,
+"Kapi": 907,
+"Karaharapriya": 881,
+"Hindolam": 868,
+"Anandabhairavi": 847,
+"Saveri": 801,
+"Nata": 788,
+"Shanmukapriya": 779,
+"Reethigowla": 743,
+"Begada": 732,
+"Behag": 714,
+"Abheri": 692,
+"Arabhi": 691,
+"Bilahari": 651,
+"Atana": 648,
+"Kaanada": 613,
+"Sri": 601,
+"Varali": 577,
+"Sahana": 572,
+"Yamunakalyani": 560,
+"Suruti": 544,
+"Mayamalavagowla": 539,
+"Yadukulakambhoji": 536,
+"Mukhari": 493,
+"Sriranjani": 477,
+"Abhogi": 464,
+"Natakurinji": 463,
+"Harikambhoji": 458,
+"Vasantha": 447,
+"Dhanyasi": 442,
+"Neelambari": 434,
+"Keeravani": 421,
+"Kedaragowla": 413,
+"Chenchurutti": 392,
+"Sowrashtram": 387,
+"Hamsanandi": 378,
+"Shuddhadhanyasi": 366,
+"Sama": 351,
+"Brindavanasaranga": 340,
+"Shuddhasaveri": 331,
+"Hamirkalyani": 323,
+"Devagandhari": 322,
+"Gowla": 313,
+"Charukesi": 310,
+"Darbar": 306,
+"Revathi": 302,
+"Saranga": 299,
+"Simhendramadhyamam": 290,
+"Dwijavanthi": 283,
+"Kunthalavarali": 269,
+"Jonpuri": 269,
+"Bahudhari": 269,
+"Huseni": 267,
+"Ranjani": 263,
+"Maand": 259,
+"Nalinakanthi": 257,
+"Asaveri": 257,
+"Kedaram": 253,
+"Hamsanadam": 252,
+"Kurinji": 250,
+"Kadanakuthuhalam": 236,
+"Desh": 234,
+"Bowli": 232,
+"Amrithavarshini": 232,
+"Saramathi": 232,
+"Punnagavarali": 231,
+"Mohanakalyani": 231,
+"Subhapanthuvarali": 224,
+"Nadanamakriya": 223,
+"Lathangi": 222,
+"Gambheeranata": 219,
+"Darbarikaanada": 218,
+"Malayamarutham": 208,
+"Chakravakam": 204,
+"Bageshree": 201,
+"Ahiri": 200,
+"Saraswathi": 200,
+"Kannada": 197,
+"Manirangu": 196,
+"Kalyanavasantham": 195,
+"Jaganmohini": 192,
+"Andolika": 190,
+"Poornachandrika": 184,
+"Devamanohari": 183,
+"Lalitha": 179,
+"Gowrimanohari": 178,
+"Valaji": 171,
+"Vachaspathi": 171,
+"Dharmavathi": 168,
+"Paras": 161,
+"Varamu": 159,
+"Janaranjani": 159,
+"Hemavathi": 154,
+"Ravichandrika": 150,
+"Nayaki": 143,
+"Navarasakannada": 143,
+"Gowlipanthu": 126,
+"Madhuvanthi": 124,
+"Manji": 124,
+"Thilang": 124,
+"Kapinarayani": 122,
+"Kamalamanohari": 121,
+"Ahirbhairavi": 119,
+"Revagupthi": 118,
+"Jayanthashree": 117,
+"Chandrajyothi": 117,
+"Nagaswaravali": 116,
+"Kannadagowla": 116,
+"Bindumalini": 112,
+"Durga": 111,
+"Vagadheeswari": 110,
+"Karnaranjani": 106,
+"Nasikabhushani": 106,
+"Dhenuka": 103,
+"Malavi": 102,
+"Natabhairavi": 101,
+"Ganamoorthi": 99,
+"Vegavahini": 97,
+"Sarasangi": 94,
+"Saraswathimanohari": 93,
+"Ramapriya": 87,
+"Rudrapriya": 87,
+"Chittaranjani": 86,
+"Mandari": 86,
+"Malahari": 81,
+"Shuddhabangala": 80,
+"Sunadhavinodhini": 79,
+"Salagabhairavi": 75,
+"Chinthamani": 74,
+"Devamruthavarshini": 73,
+"Brindavani": 72,
+"Balahamsa": 70,
+"Mishrayaman": 70,
+"Vakulabharanam": 70,
+"Bhavapriya": 70,
+"Vijayashree": 69,
+"Bhoopalam": 69,
+"Dhanashree": 68,
+"Jayamanohari": 67,
+"Peeloo": 66,
+"Shivaranjani": 66,
+"Rasikapriya": 63,
+"Vasanthabhairavi": 63,
+"Rishabhapriya": 62,
+"Garudadhwani": 61,
+"Natakapriya": 61,
+"Chalanata": 60,
+"Ratipatipriya": 60,
+"Gowdamalhar": 58,
+"Navaroj": 58,
+"Karnatakakapi": 58,
+"Kumudakriya": 58,
+"Bangala": 54,
+"Chenchukambhoji": 54,
+"Shuddhasarang": 53,
+"Kalavathi": 53,
+"Bhairavam": 53,
+"Narayanagowla": 52,
+"Shuddhaseemanthini": 52,
+"Hamsavinodhini": 49,
+"Manoranjani": 49,
+"Chayatharangini": 49,
+"Kiranavali": 49,
+"Amrithavahini": 49,
+"Devagandharam": 48,
+"Jayanthasena": 47,
+"Ramamanohari": 47,
+"Mishramaand": 47,
+"Nagagandhari": 47,
+"Kanthamani": 46,
+"Neethimathi": 46,
+"Lalithapanchamam": 45,
+"Vijayanagari": 45,
+"Devakriya": 45,
+"Rageshree": 45,
+"Suryakantam": 45,
+"Chandrakowns": 45,
+"Vasanthi": 45,
+"Pushpalatika": 44,
+"Umabharanam": 44,
+"Kanakangi": 43,
+"Sindhuramakriya": 42,
+"Gangeyabhushani": 41,
+"Poornashadjam": 41,
+"Tharangini": 39,
+"Lavangi": 39,
+"Karnatakabehag": 38,
+"Sucharithra": 38,
+"Neelamani": 38,
+"Hindolavasantham": 37,
+"Margahindolam": 37,
+"Mishrashivaranjani": 36,
+"Patdeep": 36,
+"Jyothiswaroopini": 35,
+"Phalamanjari": 35,
+"Urmika": 35,
+"Yaman": 35,
+"Vanaspathi": 35,
+"Veeravasantham": 35,
+"Gamanashrama": 34,
+"Kokilavarali": 33,
+"Pahadi": 33,
+"Kalanidhi": 33,
+"Sindhukannada": 32,
+"Narayani": 32,
+"Manjari": 31,
+"Hindusthanigandhari": 29,
+"Bhavani": 29,
+"Purvi": 29,
+"Prathapavarali": 29,
+"Vallabhi": 28,
+"Karnatakashuddhasaveri": 27,
+"Megharanjani": 27,
+"Isamanohari": 27,
+"Soorya": 27,
+"Pasupathipriya": 26,
+"Gambheeravani": 26,
+"Simhavahini": 26,
+"Madhavamanohari": 26,
+"Niroshta": 26,
+"Varunapriya": 26,
+"Manavathi": 26,
+"Naganandhini": 25,
+"Ghanta": 25,
+"Shivashakti": 25,
+"Kaikavasi": 24,
+"Kokilapriya": 24,
+"Mishrapahadi": 24,
+"Saindhavi": 24,
+"Jingala": 24,
+"Kosalam": 24,
+"Kalgada": 23,
+"Sumanesaranjani": 23,
+"Bhushavali": 23,
+"Narireethigowla": 22,
+"Rasali": 22,
+"Vivardhini": 22,
+"Jankaradhvani": 22,
+"Sarangatharangini": 21,
+"Paadi": 21,
+"Gowri": 21,
+"Bhujangini": 21,
+"Gurjari": 21,
+"Suposhini": 21,
+"Udhayaravichandrika": 20,
+"Rathnangi": 20,
+"Deepali": 20,
+"Ragavardhani": 20,
+"Supradeepam": 19,
+"Mangalakaishiki": 19,
+"Kannadabangala": 19,
+"Shadhvidhamargini": 19,
+"Kokiladhwani": 18,
+"Mahathi": 18,
+"Janasammodhini": 17,
+"Takka": 17,
+"Vasanthavarali": 17,
+"Thanarupi": 17,
+"Amrithabehag": 17,
+"Rupavathi": 16,
+"Ganavaridhi": 16,
+"Natanarayani": 15,
+"Jog": 15,
+"Manorama": 15,
+"Swarabhushani": 15,
+"Gopikavasantham": 15,
+"Shuddhadesi": 14,
+"Mararanjani": 14,
+"Deshakshi": 14,
+"Gundakriya": 14,
+"Gayakapriya": 14,
+"Chithrambari": 14,
+"Maruvabehag": 14,
+"Ramkali": 13,
+"Shulini": 13,
+"Desiyathodi": 13,
+"Vandanadharini": 13,
+"Malavashree": 13,
+"Deepakam": 12,
+"Kokilaravam": 12,
+"Vijayasaraswathi": 12,
+"Mishraharikambhoji": 12,
+"Maruvadhanyasi": 12,
+"Mohanangi": 12,
+"Hatakambari": 12,
+"Yagapriya": 12,
+"Suvarnangi": 12,
+"Samanta": 12,
+"Basantbahar": 12,
+"Mishrapeeloo": 11,
+"Raghupriya": 11,
+"Pavani": 11,
+"Navanitham": 11,
+"Sindhumandari": 11,
+"Buddhamanohari": 11,
+"Shuddhavasantha": 10,
+"Bhanumathi": 10,
+"Andhali": 10,
+"Jalarnavam": 10,
+"Puriyadhanashree": 10,
+"Gavambodhi": 10,
+"Vamsavathi": 10,
+"Sarangamalhar": 10,
+"Dhavalambari": 10,
+"Gowrivelaavali": 9,
+"Namanarayani": 9,
+"Shyamakalyani": 9,
+"Bhoopali": 9,
+"Vishnupriya": 8,
+"Jyothi": 8,
+"Gopikathilakam": 8,
+"Chayanata": 8,
+"Shruthiranjani": 8,
+"Santhanamanjari": 8,
+"Ardhradesi": 8,
+"Sumukham": 8,
+"Madhukowns": 8,
+"Dhivyamani": 8,
+"Shivakambhoji": 7,
+"Devaranji": 7,
+"Chayagowla": 7,
+"Narthaki": 7,
+"Rasamanjari": 7,
+"Vardhani": 7,
+"Dhurvanki": 7,
+"Phenadyuthi": 7,
+"Gavathi": 7,
+"Poornalalitha": 7,
+"Kunthalam": 7,
+"Poorvakamodari": 7,
+"Mahuri": 7,
+"Chandrika": 7,
+"Rohini": 7,
+"Mishrakhamaj": 7,
+"Salagam": 7,
+"Sarvashree": 7,
+"Bhooshavathi": 7,
+"Shreemani": 6,
+"Sumanapriya": 6,
+"Sushama": 6,
+"Latantapriya": 6,
+"Gurjarithodi": 6,
+"Guharanjani": 6,
+"Namadesi": 6,
+"Kunjari": 6,
+"Maargadesi": 6,
+"Ragapanjaram": 6,
+"Harinarayani": 6,
+"Ravikriya": 6,
+"Gopriya": 5,
+"Chayaranjani": 5,
+"Hamsadeepika": 5,
+"Nadavarangini": 5,
+"Hamsalatha": 5,
+"Shuddhavalaji": 5,
+"Bhogachayanata": 5,
+"Phalaranjani": 5,
+"Bhogavasantha": 5,
+"Senavathi": 5,
+"Thandavam": 5,
+"Swaravali": 5,
+"Geethapriya": 5,
+"Jayanarayani": 5,
+"Vishwambhari": 5,
+"Nagavarali": 5,
+"Karthyayani": 5,
+"Sharavathi": 5,
+"Ganasamavarali": 5,
+"Jaganmohinam": 5,
+"Shyamalangi": 5,
+"Nagavalli": 5,
+"Bhinnashadjam": 5,
+"Suraranjani": 5,
+"Nabhomani": 4,
+"Maruva": 4,
+"Komalangi": 4,
+"Dhavalangam": 4,
+"Shivapriya": 4,
+"Bhavabharanam": 4,
+"Nagabhushani": 4,
+"Swarasammodhini": 4,
+"Gomedhikapriya": 4,
+"Dhatuvardhani": 4,
+"Hamsakalyani": 4,
+"Samakadambari": 4,
+"Sauvira": 4,
+"Thanukeerthi": 4,
+"Ragachoodaamani": 4,
+"Rasikaranjani": 3,
+"Moharanjani": 3,
+"Hamsabhramari": 3,
+"Dundubi": 3,
+"Miyanmalhar": 3,
+"Murali": 3,
+"Shekharachandrikaa": 3,
+"Madhulika": 3,
+"Krisnaveni": 3,
+"Nagadhwani": 3,
+"Dakshayani": 3,
+"Kalasaveri": 3,
+"Sthavarajam": 3,
+"Mishrajog": 3,
+"Srothaswani": 3,
+"Sowrasenaa": 3,
+"Samudrapriya": 3,
+"Bhuvanagaandhaari": 3,
+"Hamsanantini": 3,
+"Bhanuchandrika": 3,
+"Balachandrika": 3,
+"Shuddhamukhari": 3,
+"Viswapriya": 3,
+"Gangatharangini": 3,
+"Sharadapriya": 3,
+"Chathurangini": 3,
+"Venkatadri": 3,
+"Narayanadri": 3,
+"Puriyakalyan": 3,
+"Hindusthanitodi": 3,
+"Dayavathi": 3,
+"Kokila": 3,
+"Madhyamaravali": 3,
+"Vivahapriya": 3,
+"Vijayavasantha": 3,
+"Kuvalayabharanam": 3,
+"Swararanjani": 3,
+"Mishrabilahari": 3,
+"Salanganata": 3,
+"Malavapanchamam": 3,
+"Siddhasena": 3,
+"Jalavarali": 3,
+"Haricharan": 3,
+"Karnatakahindolam": 3,
+"Senagrani": 3,
+"Nadhabrahma": 3,
+"Poorvagowla": 3,
+"Veenadhaari": 3,
+"Geyahejjajji": 3,
+"Tarani": 3,
+"Navarathnavilaasam": 3,
+"Bhanudhanyasi": 3,
+"Kolahalam": 3,
+"Panchamam": 2,
+"Bhoopalapanchamam": 2,
+"Rasavinodhini": 2,
+"Seshadri": 2,
+"Tarakagowla": 2,
+"Poornapanchamam": 2,
+"Kesari": 2,
+"Shailadeshakshi": 2,
+"Jujahuli": 2,
+"Churnikavinodhini": 2,
+"Rukmambari": 2,
+"Kalakanti": 2,
+"Omkaari": 2,
+"Hindoladarbar": 2,
+"Mukthidayini": 2,
+"Hradini": 2,
+"Velaavali": 2,
+"Kowshikadhwani": 2,
+"Hamsavahini": 2,
+"Srikara": 2,
+"Mishramanolayam": 2,
+"Bhanupriya": 2,
+"Manolayam": 2,
+"Mukundamalini": 2,
+"Siddhi": 2,
+"Ramakriya": 2,
+"Suranandini": 2,
+"Chandrapriya": 2,
+"Shuddhasaranga": 2,
+"Bhatiyaar": 2,
+"Malava": 2,
+"Kusumakaram": 2,
+"Kannadamaruva": 2,
+"Sowgandhini": 2,
+"Shuddhathodi": 2,
+"Thivravahini": 1,
+"Savithri": 1,
+"Shreekaanti": 1,
+"Kumarapriya": 1,
+"Sugunabhooshani": 1,
+"Kalyanakesari": 1,
+"Nagabharanam": 1,
+"Jadathari": 1,
+"Guhamanohari": 1,
+"Agnikopa": 1,
+"Pranavapriya": 1,
+"Karpoorabharani": 1,
+"Rojakadambari": 1,
+"Shuddhakambhoji": 1,
+"Hindhusthanibehag": 1,
+"Jnanachinthamani": 1,
+"Chandrahasitham": 1,
+"Triveni": 1,
+"Dravidakalavati": 1,
+"Kumbhini": 1,
+"Vasukari": 1,
+"Vitapi": 1,
+"Saranganata": 1,
+"Visharada": 1,
+"Vinodhini": 1,
+"Sarasaanana": 1,
+"Shuddhalalitha": 1,
+"Rishipriya": 1,
+"Dhipaka": 1,
+"Shuddhasalavi": 1,
+"Sutradhari": 1,
+"Natabharanam": 1,
+"Mishrabhairavi": 1,
+"Priyadarshani": 1,
+"Hamsagamini": 1,
+"Shyamalam": 1,
+"Jeevanthika": 1,
+"Alankari": 1,
+"Jayashuddhamaalavi": 1,
+"Mechabowli": 1,
+"Garudapriya": 1
+}
models/RagaNet.py
ADDED
@@ -0,0 +1,164 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from torchaudio.models import wav2vec2_model
+
+def count_parameters(model):
+    return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+#basic conv block: Conv1d -> BatchNorm -> ReLU
+def conv_block(n_input, n_output, stride=1, kernel_size=80):
+    layers = []
+    if stride == 1:
+        layers.append(nn.Conv1d(n_input, n_output, kernel_size=kernel_size, stride=stride, padding='same'))
+    else:
+        layers.append(nn.Conv1d(n_input, n_output, kernel_size=kernel_size, stride=stride))
+    layers.append(nn.BatchNorm1d(n_output))
+    layers.append(nn.ReLU())
+    return nn.Sequential(*layers)
+
+#basic 2-conv residual block
+class ResidualBlock(nn.Module):
+    def __init__(self, n_channels, kernel_size):
+        super().__init__()
+
+        self.conv_block1 = conv_block(n_channels, n_channels, stride=1, kernel_size=kernel_size)
+        self.conv_block2 = conv_block(n_channels, n_channels, stride=1, kernel_size=3)
+
+    def forward(self, x):
+        identity = x
+        x = self.conv_block1(x)
+        x = self.conv_block2(x)
+        x = x + identity
+        return x
+
+
+class ResNetRagaClassifier(nn.Module):
+    def __init__(self, params):
+        super().__init__()
+        n_input = params.n_input
+        n_channel = params.n_channel
+        stride = params.stride
+        self.n_blocks = params.n_blocks
+        self.conv_first = conv_block(n_input, n_channel, stride=stride, kernel_size=80)
+        self.max_pool_every = params.max_pool_every
+
+        self.res_blocks = nn.ModuleList()  #residual blocks
+        for i in range(self.n_blocks):
+            self.res_blocks.append(ResidualBlock(n_channel, kernel_size=3))
+
+        #linear classification head
+        self.fc1 = nn.Linear(n_channel, params.num_classes)
+
+
+    def forward(self, x):
+        #initial strided conv
+        x = self.conv_first(x)
+
+        #residual blocks, downsampling every max_pool_every blocks
+        for i, block in enumerate(self.res_blocks):
+            x = block(x)
+            if i % self.max_pool_every == 0:
+                x = F.max_pool1d(x, 2)
+
+        #classification head
+        x = F.avg_pool1d(x, x.shape[-1])
+        x = x.permute(0, 2, 1)
+        x = self.fc1(x)
+        x = F.log_softmax(x, dim=-1)
+
+        return x
+
+
+class BaseRagaClassifier(nn.Module):
+    def __init__(self, params):
+        super().__init__()
+        n_input = params.n_input
+        n_channel = params.n_channel
+        stride = params.stride
+
+        self.conv_block1 = conv_block(n_input, n_channel, stride=stride, kernel_size=80)
+        self.conv_block2 = conv_block(n_channel, n_channel, stride=1, kernel_size=3)
+        self.conv_block3 = conv_block(n_channel, 2*n_channel, stride=1, kernel_size=3)
+        self.conv_block4 = conv_block(2*n_channel, 2*n_channel, stride=1, kernel_size=3)
+        self.fc1 = nn.Linear(2 * n_channel, params.num_classes)
+
+    def forward(self, x):
+        x = self.conv_block1(x)
+        x = F.max_pool1d(x, 4)
+        x = self.conv_block2(x)
+        x = F.max_pool1d(x, 4)
+        x = self.conv_block3(x)
+        x = F.max_pool1d(x, 4)
+        x = self.conv_block4(x)
+        x = F.avg_pool1d(x, x.shape[-1])
+        x = x.permute(0, 2, 1)
+        x = self.fc1(x)
+        x = F.log_softmax(x, dim=-1)
+        return x
+
+
+class Wav2VecTransformer(nn.Module):
+    def __init__(self, params):
+        super().__init__()
+        self.params = params
+        self.extractor_mode = params.extractor_mode
+        self.extractor_conv_layer_config = params.extractor_conv_layer_config
+        self.extractor_conv_bias = params.extractor_conv_bias
+        self.encoder_embed_dim = params.encoder_embed_dim
+        self.encoder_projection_dropout = params.encoder_projection_dropout
+        self.encoder_pos_conv_kernel = params.encoder_pos_conv_kernel
+        self.encoder_pos_conv_groups = params.encoder_pos_conv_groups
+        self.encoder_num_layers = params.encoder_num_layers
+        self.encoder_num_heads = params.encoder_num_heads
+        self.encoder_attention_dropout = params.encoder_attention_dropout
+        self.encoder_ff_interm_features = params.encoder_ff_interm_features
+        self.encoder_ff_interm_dropout = params.encoder_ff_interm_dropout
+        self.encoder_dropout = params.encoder_dropout
+        self.encoder_layer_norm_first = params.encoder_layer_norm_first
+        self.encoder_layer_drop = params.encoder_layer_drop
+        self.aux_num_out = params.num_classes
+
+        #(out_channels, kernel_size, stride) for each feature-extractor conv layer
+        self.extractor_conv_layer_config = [
+            (32, 80, 16),
+            (64, 5, 4),
+            (128, 5, 4),
+            (256, 5, 4),
+            (512, 3, 2),
+            (512, 2, 2),
+            (512, 2, 2),
+        ]
+        self.encoder = wav2vec2_model(self.extractor_mode,
+                                      self.extractor_conv_layer_config,
+                                      self.extractor_conv_bias,
+                                      self.encoder_embed_dim,
+                                      self.encoder_projection_dropout,
+                                      self.encoder_pos_conv_kernel,
+                                      self.encoder_pos_conv_groups,
+                                      self.encoder_num_layers,
+                                      self.encoder_num_heads,
+                                      self.encoder_attention_dropout,
+                                      self.encoder_ff_interm_features,
+                                      self.encoder_ff_interm_dropout,
+                                      self.encoder_dropout,
+                                      self.encoder_layer_norm_first,
+                                      self.encoder_layer_drop,
+                                      aux_num_out=None)
+
+        #the extractor strides (16*4*4*4*2*2*2 = 16384) set the encoder sequence length
+        self.audio_length = params.sample_rate*params.clip_length
+        self.classification_head = nn.Linear(int(self.audio_length/(16*4*4*4*2*2*2))*params.encoder_embed_dim, params.num_classes)
+
+    def forward(self, x):
+        x = self.encoder(x)[0]
+        x = x.reshape(x.shape[0], -1)  # flatten
+        x = self.classification_head(x)
+        x = F.log_softmax(x, dim=-1)
+        return x
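(Editor's note: a quick shape sanity check for ResNetRagaClassifier under the `resnet_0.7` config (n_input=2, n_channel=300, stride=16, n_blocks=10, max_pool_every=1, num_classes=150); a minimal sketch assuming config.yaml is present.)

# Sketch: forward a dummy batch and inspect the output shape.
import torch
from utils.YParams import YParams
from models.RagaNet import ResNetRagaClassifier, count_parameters

params = YParams('config.yaml', 'resnet_0.7')
model = ResNetRagaClassifier(params)
print(count_parameters(model))           # trainable parameter count

x = torch.randn(4, 2, 8000 * 30)         # batch of 4 stereo 30 s clips at 8 kHz
log_probs = model(x)
print(log_probs.shape)                   # torch.Size([4, 1, 150]) of log-probabilities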
requirements.txt
ADDED
@@ -0,0 +1,3 @@
+torch
+torchaudio
+numpy
site/tsne.jpeg
ADDED
(binary image: the t-SNE projection of per-raga hidden activations shown in about.md)
utils/YParams.py
ADDED
@@ -0,0 +1,47 @@
+from ruamel.yaml import YAML
+import logging
+
+class YParams():
+    """ Yaml file parser """
+    def __init__(self, yaml_filename, config_name, print_params=False):
+        self._yaml_filename = yaml_filename
+        self._config_name = config_name
+        self.params = {}
+
+        if print_params:
+            print("------------------ Configuration ------------------")
+
+        with open(yaml_filename) as _file:
+
+            for key, val in YAML().load(_file)[config_name].items():
+                if print_params: print(key, val)
+                if val == 'None': val = None
+
+                self.params[key] = val
+                self.__setattr__(key, val)
+
+        if print_params:
+            print("---------------------------------------------------")
+
+    def __getitem__(self, key):
+        return self.params[key]
+
+    def __setitem__(self, key, val):
+        self.params[key] = val
+        self.__setattr__(key, val)
+
+    def __contains__(self, key):
+        return (key in self.params)
+
+    def update_params(self, config):
+        for key, val in config.items():
+            self.params[key] = val
+            self.__setattr__(key, val)
+
+    def log(self):
+        logging.info("------------------ Configuration ------------------")
+        logging.info("Configuration file: " + str(self._yaml_filename))
+        logging.info("Configuration name: " + str(self._config_name))
+        for key, val in self.params.items():
+            logging.info(str(key) + ' ' + str(val))
+        logging.info("---------------------------------------------------")
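(Editor's note: YParams exposes each YAML key both as an attribute and via dict-style indexing, which is how app.py injects the checkpoint paths at runtime. A minimal usage sketch; the assigned path below is illustrative, not from the commit.)

# Sketch: attribute and dict-style access on YParams.
from utils.YParams import YParams

params = YParams('config.yaml', 'resnet_0.7', print_params=True)
print(params.clip_length)                  # 30, attribute access
print(params['sample_rate'])               # 8000, dict-style access
params['best_checkpoint_path'] = 'ckpts/best_ckpt.tar'  # hypothetical runtime override
print('best_checkpoint_path' in params)    # True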
utils/logging_utils.py
ADDED
@@ -0,0 +1,32 @@
+import os
+import logging
+
+_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+
+def config_logger(log_level=logging.INFO):
+    logging.basicConfig(format=_format, level=log_level)
+
+def log_to_file(logger_name=None, log_level=logging.INFO, log_filename='tensorflow.log'):
+
+    #create the log directory if the filename includes one
+    log_dir = os.path.dirname(log_filename)
+    if log_dir and not os.path.exists(log_dir):
+        os.makedirs(log_dir)
+
+    if logger_name is not None:
+        log = logging.getLogger(logger_name)
+    else:
+        log = logging.getLogger()
+
+    fh = logging.FileHandler(log_filename)
+    fh.setLevel(log_level)
+    fh.setFormatter(logging.Formatter(_format))
+    log.addHandler(fh)
+
+def log_versions():
+    import torch
+    import subprocess
+
+    logging.info('--------------- Versions ---------------')
+    logging.info('git branch: ' + str(subprocess.check_output(['git', 'branch']).strip()))
+    logging.info('git hash: ' + str(subprocess.check_output(['git', 'rev-parse', 'HEAD']).strip()))
+    logging.info('Torch: ' + str(torch.__version__))
+    logging.info('----------------------------------------')
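(Editor's note: a minimal usage sketch for the logging helpers. `config_logger()` is already invoked at the top of inference.py; `log_to_file` attaches a file handler, and the `logs/run.log` path here is illustrative, not from the commit.)

# Sketch: configuring console and file logging.
import logging
from utils import logging_utils

logging_utils.config_logger(logging.DEBUG)
logging_utils.log_to_file(log_filename='logs/run.log')  # hypothetical path
logging.info("raga classifier started")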