### # Author: Kai Li # Date: 2021-06-22 12:41:36 # LastEditors: Please set LastEditors # LastEditTime: 2022-06-05 14:48:00 ### import csv from sympy import im import torch import numpy as np import logging import os import librosa from torch_mir_eval.separation import bss_eval_sources import fast_bss_eval from visqol import visqol_lib_py from visqol.pb2 import visqol_config_pb2 from visqol.pb2 import similarity_result_pb2 logger = logging.getLogger(__name__) def is_silent(wav, threshold=1e-4): return torch.sum(wav ** 2) / wav.numel() < threshold class MetricsTracker: def __init__(self, save_file: str = ""): self.all_sdrs = [] self.all_sisnrs = [] self.all_visqols = [] csv_columns = ["snt_id", "sdr", "si-snr", "visqol"] self.visqol_config = visqol_config_pb2.VisqolConfig() self.visqol_config.audio.sample_rate = 48000 self.visqol_config.options.use_speech_scoring = False svr_model_path = "libsvm_nu_svr_model.txt" self.visqol_config.options.svr_model_path = os.path.join(os.path.dirname(visqol_lib_py.__file__), "model", svr_model_path) self.visqol_api = visqol_lib_py.VisqolApi() self.visqol_api.Create(self.visqol_config) self.results_csv = open(save_file, "w") self.writer = csv.DictWriter(self.results_csv, fieldnames=csv_columns) self.writer.writeheader() def __call__(self, clean, estimate, key): sisnr = fast_bss_eval.si_sdr(clean.unsqueeze(0), estimate.unsqueeze(0), zero_mean=True).mean() sdr = fast_bss_eval.sdr(clean.unsqueeze(0), estimate.unsqueeze(0), zero_mean=True).mean() clean = librosa.resample(clean.squeeze(0).mean(0).cpu().numpy(), orig_sr=44100, target_sr=48000).astype(np.float64) estimate = librosa.resample(estimate.squeeze(0).mean(0).cpu().numpy(), orig_sr=44100, target_sr=48000).astype(np.float64) visqol = self.visqol_api.Measure(clean, estimate).moslqo # import pdb; pdb.set_trace() row = { "snt_id": key, "sdr": sdr.item(), "si-snr": sisnr.item(), "visqol": visqol } self.writer.writerow(row) # Metric Accumulation self.all_sdrs.append(sdr.item()) self.all_sisnrs.append(sisnr.item()) self.all_visqols.append(visqol) def update(self, ): return {"sdr": np.array(self.all_sdrs).mean(), "si-snr": np.array(self.all_sisnrs).mean(), "visqol": np.array(self.all_visqols).mean()} def final(self,): row = { "snt_id": "avg", "sdr": np.array(self.all_sdrs).mean(), "si-snr": np.array(self.all_sisnrs).mean(), "visqol": np.array(self.all_visqols).mean() } self.writer.writerow(row) row = { "snt_id": "std", "sdr": np.array(self.all_sdrs).std(), "si-snr": np.array(self.all_sisnrs).std(), "visqol": np.array(self.all_visqols).std() } self.writer.writerow(row) self.results_csv.close()