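"""Dataset utilities for raga classification from audio.

Loads the labeled (and optionally unlabeled) metadata, builds a raga-to-label
lookup table, and exposes a PyTorch Dataset that yields fixed-length,
resampled, normalized stereo clips together with one-hot raga labels.
"""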
import copy
import json
import math
import os
import random
from math import floor

import numpy as np
import torch
import torchaudio
from torch.utils.data import Dataset


np.random.seed(123)
random.seed(123)
#load the used subset of the metadata
def extract_data(params):
    with open(params.metadata_labeled_path) as f:
        metadata_labeled = json.load(f)

    if params.use_unlabeled_data:
        #unlabeled metadata is loaded here but not returned by this function
        with open(params.metadata_unlabeled_path) as f:
            metadata_unlabeled = json.load(f)
        n_unlabeled_samples = len(metadata_unlabeled)

    #shuffle the data and keep only the requested fraction of files
    random.shuffle(metadata_labeled)
    keep = math.ceil(params.use_frac * len(metadata_labeled))
    metadata_labeled = metadata_labeled[0:keep]

    #construct the raga -> label lookup table
    raga2label = get_raga2label(params)

    #drop entries whose raga is not in the lookup table
    metadata_labeled_final = remove_unused_ragas(metadata_labeled, raga2label)

    return metadata_labeled_final, raga2label

def get_raga2label(params):
    with open(params.num_files_per_raga_path) as f:
        num_files_per_raga = json.load(f)
    raga2label = {}
    for i, raga in enumerate(num_files_per_raga.keys()):
        raga2label[raga] = i #assign each raga a unique label from 0 to params.num_classes - 1
        if i == params.num_classes - 1:
            break
    return raga2label

def remove_unused_ragas(metadata_labeled, raga2label):
    temp = copy.deepcopy(metadata_labeled)
    for entry in metadata_labeled:
        raga = entry['filename'].split("/")[0]
        if raga not in raga2label: #this raga is not among the params.num_classes retained ragas
            temp.remove(entry)

    return temp

class RagaDataset(Dataset):
  def __init__(self, params, metadata_labeled, raga2label):
    self.params = params
    self.metadata_labeled = metadata_labeled
    self.raga2label = raga2label
    self.n_labeled_samples = len(self.metadata_labeled)
    self.transform_dict = {} #cache of Resample transforms, keyed by source sample rate
    self.count = 0

    if params.local_rank == 0:
      print("Begin training using", self.__len__(), "audio samples of", self.params.clip_length, "seconds each.")
      print("Total number of ragas specified:", self.params.num_classes)

  def construct_label(self, raga, label_smoothing=False):
    #construct a one-hot encoding vector for the raga (label_smoothing is accepted but currently unused)
    raga_index = self.raga2label[raga]
    label = torch.zeros((self.params.num_classes,), dtype=torch.float32)
    label[raga_index] = 1
    return label

  def normalize(self, audio):
    #per-channel zero-mean, unit-variance normalization with a small epsilon for stability
    return (audio - torch.mean(audio, dim=1, keepdim=True)) / (torch.std(audio, dim=1, keepdim=True) + 1e-5)

  def pad_audio(self, audio):
    #right-pad with zeros up to the full clip length in samples
    pad = (0, self.params.sample_rate * self.params.clip_length - audio.shape[1])
    return torch.nn.functional.pad(audio, pad=pad, value=0)

  def __len__(self):
    return len(self.metadata_labeled)
    

  def __getitem__(self, idx):
    #get metadata
    file_info = self.metadata_labeled[idx]

    #sample the clip start offset uniformly over the valid range
    rng = max(0, file_info['duration'] - self.params.clip_length)
    if rng == 0:
      #file is no longer than the clip length; sample an offset anywhere within it
      rng = file_info['duration']

    seconds_offset = np.random.randint(floor(rng)) if floor(rng) > 0 else 0

    #open the audio file, reading only the requested clip
    audio_clip, sample_rate = torchaudio.load(
        os.path.join(self.params.labeled_data_dir, file_info['filename']),
        frame_offset=seconds_offset * file_info['sample_rate'],
        num_frames=self.params.clip_length * file_info['sample_rate'],
        normalize=True)

    #duplicate mono recordings so every clip has two channels
    if audio_clip.shape[0] == 1:
      audio_clip = audio_clip.repeat(2, 1)

    #keep stereo
    #audio_clip = audio_clip.mean(dim=0, keepdim=True)

    #cache one Resample transform per source sample rate
    if sample_rate not in self.transform_dict:
      self.transform_dict[sample_rate] = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.params.sample_rate)
    
    #resample the clip to the target sample rate using the cached transform
    resample = self.transform_dict[sample_rate]
    audio_clip = resample(audio_clip)

    if self.params.normalize:
      audio_clip = self.normalize(audio_clip)

    if audio_clip.shape[1] < self.params.sample_rate * self.params.clip_length:
      #pad the audio with zeros if it is not long enough
      audio_clip = self.pad_audio(audio_clip)

    #the raga name is the first path component of the filename
    raga = file_info['filename'].split("/")[0]

    #construct the one-hot label
    label = self.construct_label(raga)

    assert not torch.any(torch.isnan(audio_clip))
    assert not torch.any(torch.isnan(label))
    assert audio_clip.shape[1] == self.params.sample_rate*self.params.clip_length
    return audio_clip, label
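
# Minimal usage sketch (an assumption, not part of the original training pipeline): it shows
# how extract_data and RagaDataset might be wired to a torch DataLoader. The params fields
# mirror the attributes accessed above; the paths and values here are placeholders.
if __name__ == "__main__":
  from types import SimpleNamespace
  from torch.utils.data import DataLoader

  params = SimpleNamespace(
      metadata_labeled_path="metadata_labeled.json",      #placeholder path
      metadata_unlabeled_path="metadata_unlabeled.json",  #placeholder path
      num_files_per_raga_path="num_files_per_raga.json",  #placeholder path
      labeled_data_dir="data/labeled",                     #placeholder path
      use_unlabeled_data=False,
      use_frac=1.0,          #fraction of labeled files to keep
      num_classes=10,        #number of ragas to classify
      clip_length=30,        #seconds per training clip
      sample_rate=16000,     #target sample rate after resampling
      normalize=True,
      local_rank=0,
  )

  metadata_labeled, raga2label = extract_data(params)
  dataset = RagaDataset(params, metadata_labeled, raga2label)
  loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

  audio, label = next(iter(loader))
  print(audio.shape)  #expected: (4, 2, sample_rate * clip_length)
  print(label.shape)  #expected: (4, num_classes)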