import json
import math
import os
import random
from math import floor

import numpy as np
import torch
import torchaudio
from torch.utils.data import Dataset

# Seed RNGs for reproducible shuffling and offset sampling.
np.random.seed(123)
random.seed(123)
# Load the subset of metadata used for training.
def extract_data(params):
    with open(params.metadata_labeled_path) as f:
        metadata_labeled = json.load(f)
    if params.use_unlabeled_data:
        # Unlabeled metadata is loaded here but not used further in this module.
        with open(params.metadata_unlabeled_path) as f:
            metadata_unlabeled = json.load(f)
        n_unlabeled_samples = len(metadata_unlabeled)
    # Shuffle the data and keep only the requested fraction of files.
    random.shuffle(metadata_labeled)
    keep = math.ceil(params.use_frac * len(metadata_labeled))
    metadata_labeled = metadata_labeled[:keep]
    # Construct the raga -> label-index lookup table.
    raga2label = get_raga2label(params)
    # Drop entries whose raga is not among the top params.num_classes ragas.
    metadata_labeled_final = remove_unused_ragas(metadata_labeled, raga2label)
    return metadata_labeled_final, raga2label
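
# Each metadata entry is assumed (inferred from the field accesses in this file;
# the exact schema is an assumption, not confirmed by the source) to look like:
#   {"filename": "<raga_name>/<track>.wav", "duration": 312.4, "sample_rate": 44100}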
def get_raga2label(params):
    with open(params.num_files_per_raga_path) as f:
        num_files_per_raga = json.load(f)
    raga2label = {}
    # Assumes num_files_per_raga is ordered from most to least common raga,
    # so the first params.num_classes keys are the top ragas.
    for i, raga in enumerate(num_files_per_raga.keys()):
        raga2label[raga] = i  # assign each raga a unique index in [0, params.num_classes - 1]
        if i == params.num_classes - 1:
            break
    return raga2label
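
# For illustration (hypothetical counts): with num_classes = 2 and a file
# containing {"Bhairavi": 120, "Todi": 95, "Yaman": 80}, this yields
# raga2label == {"Bhairavi": 0, "Todi": 1}.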
def remove_unused_ragas(metadata_labeled, raga2label):
    # Keep only entries whose raga (the top-level directory of the filename)
    # is among the top num_classes ragas.
    return [entry for entry in metadata_labeled
            if entry['filename'].split("/")[0] in raga2label]
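
# For illustration (hypothetical entries): with raga2label = {"Bhairavi": 0},
# {"filename": "Bhairavi/track01.wav", ...} is kept and
# {"filename": "Todi/track03.wav", ...} is dropped.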
class RagaDataset(Dataset):
    def __init__(self, params, metadata_labeled, raga2label):
        self.params = params
        self.metadata_labeled = metadata_labeled
        self.raga2label = raga2label
        self.n_labeled_samples = len(self.metadata_labeled)
        self.transform_dict = {}  # cache of Resample transforms, keyed by source sample rate
        self.count = 0
        if params.local_rank == 0:
            print("Begin training using", self.__len__(), "audio samples of", self.params.clip_length, "seconds each.")
            print("Total number of ragas specified:", self.params.num_classes)
    def construct_label(self, raga, label_smoothing=False):
        # Construct a one-hot encoding vector for the raga.
        # Note: label_smoothing is currently accepted but not applied.
        raga_index = self.raga2label[raga]
        label = torch.zeros((self.params.num_classes,), dtype=torch.float32)
        label[raga_index] = 1
        return label
    def normalize(self, audio):
        # Per-channel standardization to zero mean and unit variance (eps avoids division by zero).
        return (audio - torch.mean(audio, dim=1, keepdim=True)) / (torch.std(audio, dim=1, keepdim=True) + 1e-5)
    def pad_audio(self, audio):
        # Right-pad with zeros up to the target clip length in samples.
        pad = (0, self.params.sample_rate * self.params.clip_length - audio.shape[1])
        return torch.nn.functional.pad(audio, pad=pad, value=0)
    def __len__(self):
        return len(self.metadata_labeled)
    def __getitem__(self, idx):
        # Get metadata for this file.
        file_info = self.metadata_labeled[idx]
        # Sample a start offset uniformly over the valid range; if the file is
        # shorter than the clip length, fall back to sampling within the file.
        rng = max(0, file_info['duration'] - self.params.clip_length)
        if rng == 0:
            rng = file_info['duration']
        seconds_offset = np.random.randint(max(1, floor(rng)))
        # Open the audio file, reading only the requested clip.
        audio_clip, sample_rate = torchaudio.load(
            os.path.join(self.params.labeled_data_dir, file_info['filename']),
            frame_offset=seconds_offset * file_info['sample_rate'],
            num_frames=self.params.clip_length * file_info['sample_rate'],
            normalize=True)
        # Keep stereo: duplicate mono to two channels, drop extra channels otherwise.
        if audio_clip.shape[0] == 1:
            audio_clip = audio_clip.repeat(2, 1)
        elif audio_clip.shape[0] > 2:
            audio_clip = audio_clip[:2]
        # Alternative: downmix to mono instead.
        # audio_clip = audio_clip.mean(dim=0, keepdim=True)
        # Cache one Resample transform per source sample rate.
        if sample_rate not in self.transform_dict:
            self.transform_dict[sample_rate] = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=self.params.sample_rate)
        # Load the cached transform and resample to the target rate.
        resample = self.transform_dict[sample_rate]
        audio_clip = resample(audio_clip)
        if self.params.normalize:
            audio_clip = self.normalize(audio_clip)
        target_len = self.params.sample_rate * self.params.clip_length
        if audio_clip.shape[1] > target_len:
            # Guard against off-by-one lengths after resampling.
            audio_clip = audio_clip[:, :target_len]
        if audio_clip.shape[1] < target_len:
            # Pad the audio with zeros if it is not long enough.
            audio_clip = self.pad_audio(audio_clip)
        # The raga name is the top-level directory of the file path.
        raga = file_info['filename'].split("/")[0]
        # Construct the one-hot label.
        label = self.construct_label(raga)
        assert not torch.any(torch.isnan(audio_clip))
        assert not torch.any(torch.isnan(label))
        assert audio_clip.shape[1] == target_len
        return audio_clip, label
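

# Minimal usage sketch. All paths and parameter values below are hypothetical
# (this module does not define them); it simply wraps the dataset in a standard
# PyTorch DataLoader to show the expected call pattern and output shape.
if __name__ == "__main__":
    from types import SimpleNamespace
    from torch.utils.data import DataLoader

    params = SimpleNamespace(
        metadata_labeled_path="metadata_labeled.json",      # hypothetical path
        metadata_unlabeled_path="metadata_unlabeled.json",  # hypothetical path
        num_files_per_raga_path="num_files_per_raga.json",  # hypothetical path
        labeled_data_dir="data/labeled",                    # hypothetical path
        use_unlabeled_data=False,
        use_frac=1.0,
        num_classes=10,
        clip_length=30,     # seconds per training clip
        sample_rate=16000,  # target sample rate in Hz
        normalize=True,
        local_rank=0,
    )
    metadata_labeled, raga2label = extract_data(params)
    dataset = RagaDataset(params, metadata_labeled, raga2label)
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
    audio, label = next(iter(loader))  # audio: (4, 2, sample_rate * clip_length)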