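"""Evaluation/inference utilities for the raga classifier.

Defines an Evaluator that loads a trained model from a checkpoint,
preprocesses an incoming audio clip (channel handling, resampling,
normalization, zero-padding), runs windowed inference, and returns the
top-k predicted ragas with their averaged probabilities.
"""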
import logging
from utils import logging_utils
logging_utils.config_logger()
import torch
import random
import numpy as np
from data.dataloader import extract_data
import torchaudio
from models.RagaNet import BaseRagaClassifier, ResNetRagaClassifier, Wav2VecTransformer, count_parameters
from collections import OrderedDict

np.random.seed(123)
random.seed(123)


class Evaluator:

    def __init__(self, params):
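        """Build the raga-to-label mapping, construct the requested model,
        and restore it from the best checkpoint in evaluation mode."""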
        self.params = params
        self.device = self.params.device
        #get raga to label mapping
        _, self.raga2label = extract_data(self.params)
        self.raga_list = list(self.raga2label.keys())
        self.label_list = list(self.raga2label.values())

        #initialize model
        if params.model == 'base':
            self.model = BaseRagaClassifier(params).to(self.device)
        elif params.model == 'resnet':
            self.model = ResNetRagaClassifier(params).to(self.device)
        elif params.model == 'wav2vec':
            self.model = Wav2VecTransformer(params).to(self.device)
        else:
            logging.error("Model must be either 'base', 'resnet', or 'wav2vec'")
            raise ValueError(f"Unknown model type: {params.model}")
        
        #load best model
        logging.info("Loading checkpoint %s", params.best_checkpoint_path)
        self.restore_checkpoint(params.best_checkpoint_path)
        self.model.eval()


    def normalize(self, audio):
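        """Standardize each channel to zero mean and unit variance."""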
        return (audio - torch.mean(audio, dim=1, keepdim=True))/(torch.std(audio, dim=1, keepdim=True) + 1e-5)

    def pad_audio(self, audio):
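        """Zero-pad the clip on the right up to one full training-clip length."""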
        pad = (0, self.params.sample_rate*self.params.clip_length - audio.shape[1])
        return torch.nn.functional.pad(audio, pad = pad, value=0)

    def inference(self, k, audio):
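        """Predict the top-k ragas for an audio clip.

        Args:
            k: number of top predictions to return.
            audio: a (sample_rate, samples) tuple, e.g. as returned by
                scipy.io.wavfile.read; samples may be mono or multi-channel.

        Returns:
            A dict mapping the top-k raga names to their probabilities,
            averaged over all full-length windows of the clip.
        """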
        #unpack the (sample_rate, samples) pair
        sample_rate, audio_clip = audio

        #repeat mono channel to get stereo if necessary
        if len(audio_clip.shape) == 1:
            audio_clip = torch.tensor(audio_clip).unsqueeze(0).repeat(2,1).to(torch.float32)
        else:
            audio_clip = torch.tensor(audio_clip).T.to(torch.float32)
        
        #resample audio clip
        resample = torchaudio.transforms.Resample(orig_freq = sample_rate, new_freq = self.params.sample_rate)
        audio_clip = resample(audio_clip)

        #normalize the audio clip
        if self.params.normalize:
            audio_clip = self.normalize(audio_clip)

        #pad audio with zeros if it's not long enough
        if audio_clip.size()[1] < self.params.sample_rate*self.params.clip_length:
            audio_clip = self.pad_audio(audio_clip)
        
        assert not torch.any(torch.isnan(audio_clip))
        audio_clip = audio_clip.to(self.device)

        with torch.no_grad():
            length = audio_clip.shape[1]
            train_length = self.params.sample_rate*self.params.clip_length

            pred_probs = torch.zeros((self.params.num_classes,)).to(self.device)

            #loop over clip_length segments and perform inference
            num_clips = int(length // train_length)
            for i in range(num_clips):
                
                clip = audio_clip[:, i*train_length:(i+1)*train_length].unsqueeze(0)
        
                #perform forward pass through the model
                pred_distribution = self.model(clip).reshape(-1, self.params.num_classes)

                #exponentiate and renormalize the outputs, then average over clips
                clip_probs = torch.exp(pred_distribution)
                clip_probs = clip_probs / clip_probs.sum(dim=1, keepdim=True)
                pred_probs += clip_probs[0] / num_clips

                
        pred_probs, labels = pred_probs.sort(descending=True)
        pred_probs_topk = pred_probs[:k]
        pred_ragas_topk = [self.raga_list[self.label_list.index(label.item())] for label in labels[:k]]
        return {raga: prob.item() for raga, prob in zip(pred_ragas_topk, pred_probs_topk)}

    def restore_checkpoint(self, checkpoint_path):
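        """Restore model weights and training counters from a checkpoint,
        stripping the 'module.' prefix if the checkpoint was saved under DDP."""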
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        try:
            self.model.load_state_dict(checkpoint['model_state'])
        except RuntimeError:
            #loading a DDP checkpoint into a non-DDP model
            new_state_dict = OrderedDict()
            for k, v in checkpoint['model_state'].items():
                name = k[7:] # remove `module.`
                new_state_dict[name] = v
            # load params
            self.model.load_state_dict(new_state_dict)
            
        self.iters = checkpoint['iters']
        self.startEpoch = checkpoint['epoch']
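

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original pipeline. The fields on
    # `params` below are the ones this file actually reads; extract_data and
    # the model constructors may require additional fields, and every concrete
    # value and path here is a placeholder -- substitute your own config.
    from types import SimpleNamespace
    from scipy.io import wavfile

    params = SimpleNamespace(
        device='cuda' if torch.cuda.is_available() else 'cpu',
        model='resnet',
        best_checkpoint_path='path/to/best_ckpt.tar',  # placeholder path
        sample_rate=16000,   # assumed target sample rate
        clip_length=30,      # seconds per inference window (assumed)
        normalize=True,
        num_classes=150,
    )

    evaluator = Evaluator(params)
    # wavfile.read returns the (sample_rate, samples) pair inference() expects
    audio = wavfile.read('path/to/recording.wav')  # placeholder path
    print(evaluator.inference(k=5, audio=audio))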