import logging

from utils import logging_utils

# Set up logging before the remaining imports so their messages use the
# project's log format.
logging_utils.config_logger()

import random
from collections import OrderedDict

import numpy as np
import torch
import torchaudio

from data.dataloader import extract_data
from models.RagaNet import BaseRagaClassifier, ResNetRagaClassifier, Wav2VecTransformer, count_parameters

# Fix the RNG seeds so evaluation is reproducible.
np.random.seed(123)
random.seed(123)
|
class Evaluator:

    def __init__(self, params):
        self.params = params
        self.device = self.params.device

        # Build the raga <-> integer-label lookups from the dataset metadata.
        _, self.raga2label = extract_data(self.params)
        self.raga_list = list(self.raga2label.keys())
        self.label_list = list(self.raga2label.values())
|
        # Instantiate the requested architecture.
        if params.model == 'base':
            self.model = BaseRagaClassifier(params).to(self.device)
        elif params.model == 'resnet':
            self.model = ResNetRagaClassifier(params).to(self.device)
        elif params.model == 'wav2vec':
            self.model = Wav2VecTransformer(params).to(self.device)
        else:
            raise ValueError("Model must be either 'base', 'resnet', or 'wav2vec'")

        logging.info("Loading checkpoint %s", params.best_checkpoint_path)
        self.restore_checkpoint(params.best_checkpoint_path)
        self.model.eval()
|
    def normalize(self, audio):
        # Standardize each channel to zero mean and unit variance; the epsilon
        # guards against division by zero on silent audio.
        return (audio - torch.mean(audio, dim=1, keepdim=True)) / (torch.std(audio, dim=1, keepdim=True) + 1e-5)

    def pad_audio(self, audio):
        # Zero-pad the clip on the right up to one full training-clip length.
        pad = (0, self.params.sample_rate * self.params.clip_length - audio.shape[1])
        return torch.nn.functional.pad(audio, pad=pad, value=0)
|
    def inference(self, k, audio):
        sample_rate, audio_clip = audio

        # Coerce the raw samples to a (channels, samples) float tensor:
        # duplicate mono to two channels, transpose interleaved stereo.
        if len(audio_clip.shape) == 1:
            audio_clip = torch.tensor(audio_clip).unsqueeze(0).repeat(2, 1).to(torch.float32)
        else:
            audio_clip = torch.tensor(audio_clip).T.to(torch.float32)

        # Resample to the rate the model was trained at.
        resample = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.params.sample_rate)
        audio_clip = resample(audio_clip)

        if self.params.normalize:
            audio_clip = self.normalize(audio_clip)

        # Pad recordings shorter than one training clip.
        if audio_clip.shape[1] < self.params.sample_rate * self.params.clip_length:
            audio_clip = self.pad_audio(audio_clip)

        assert not torch.any(torch.isnan(audio_clip))
        audio_clip = audio_clip.to(self.device)
|
        with torch.no_grad():
            length = audio_clip.shape[1]
            train_length = self.params.sample_rate * self.params.clip_length

            pred_probs = torch.zeros((self.params.num_classes,)).to(self.device)

            # Slide over the recording in non-overlapping training-length
            # windows and average the per-window class distributions.
            num_clips = length // train_length
            for i in range(num_clips):
                clip = audio_clip[:, i * train_length:(i + 1) * train_length].unsqueeze(0)
                pred_distribution = self.model(clip).reshape(-1, self.params.num_classes)
                # softmax normalizes the window's scores into a probability
                # distribution before it is added to the running average.
                pred_probs += torch.softmax(pred_distribution, dim=1)[0] / num_clips

        # Sort the averaged probabilities and report the top-k ragas.
        pred_probs, labels = pred_probs.sort(descending=True)
        pred_probs_topk = pred_probs[:k]
        pred_ragas_topk = [self.raga_list[self.label_list.index(label.item())] for label in labels[:k]]
        return {raga: prob.item() for raga, prob in zip(pred_ragas_topk, pred_probs_topk)}
|
    def restore_checkpoint(self, checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        try:
            self.model.load_state_dict(checkpoint['model_state'])
        except RuntimeError:
            # Checkpoints saved from a DataParallel-wrapped model prefix every
            # parameter name with 'module.'; strip the prefix and retry.
            new_state_dict = OrderedDict()
            for key, value in checkpoint['model_state'].items():
                new_state_dict[key[len('module.'):]] = value
            self.model.load_state_dict(new_state_dict)

        self.iters = checkpoint['iters']
        self.startEpoch = checkpoint['epoch']
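
# --- Usage sketch ---
# A minimal, illustrative example of driving the Evaluator; it is not part of
# the original module. The SimpleNamespace fields below are assumptions: they
# mirror the attributes this file reads from `params`, with placeholder values
# taken from the checkpoint directory name where possible, and extract_data
# may require additional fields in practice. The WAV path is hypothetical.
if __name__ == '__main__':
    from types import SimpleNamespace

    from scipy.io import wavfile

    params = SimpleNamespace(
        device='cuda' if torch.cuda.is_available() else 'cpu',
        model='resnet',
        best_checkpoint_path='ckpts/resnet_0.7/150classes_alldata_cliplength30/training_checkpoints/best_ckpt.tar',
        sample_rate=16000,   # placeholder; must match the training sample rate
        clip_length=30,      # seconds per window, per the checkpoint name
        normalize=True,
        num_classes=150,     # per the checkpoint name
    )

    evaluator = Evaluator(params)
    # wavfile.read returns (sample_rate, samples), the tuple inference() expects.
    top5 = evaluator.inference(5, wavfile.read('example_recording.wav'))
    for raga, prob in top5.items():
        print('%s: %.3f' % (raga, prob))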