import csv
import torch
import numpy as np
import logging
from torch_mir_eval.separation import bss_eval_sources
import fast_bss_eval
from ..losses import (
    PITLossWrapper,
    pairwise_neg_sisdr,
    pairwise_neg_snr,
    singlesrc_neg_sisdr,
    PairwiseNegSDR,
)

logger = logging.getLogger(__name__)


class MetricsTracker:
    """Compute per-utterance separation metrics and log them to a CSV file.

    Each call evaluates one (mixture, references, estimates) triple, writes
    one row with SDR, SDR improvement, SI-SNR and SI-SNR improvement, and
    keeps the values in memory so running and final statistics can be
    reported by :meth:`update` and :meth:`final`.

    The tracker owns an open file handle from construction until
    :meth:`final` or :meth:`close` is called; it can also be used as a
    context manager, which closes the file on exit.
    """

    # Column order of the results CSV.
    CSV_COLUMNS = ["snt_id", "sdr", "sdr_i", "si-snr", "si-snr_i"]

    def __init__(self, save_file: str = ""):
        """Open ``save_file`` for writing and emit the CSV header.

        Args:
            save_file: Path of the CSV file to create. The file is opened
                immediately, so the default empty string is not a usable
                path and makes ``open`` raise.
        """
        self.all_sdrs = []
        self.all_sdrs_i = []
        self.all_sisnrs = []
        self.all_sisnrs_i = []

        # newline="" is required by the csv module so that it controls the
        # line terminator itself (otherwise rows end in "\r\r\n" on Windows).
        self.results_csv = open(save_file, "w", newline="")
        self.writer = csv.DictWriter(
            self.results_csv, fieldnames=self.CSV_COLUMNS
        )
        self.writer.writeheader()

        self.pit_sisnr = PITLossWrapper(
            PairwiseNegSDR("sisdr", zero_mean=False), pit_from="pw_mtx"
        )
        # Not used by this class's own methods; kept as a public attribute
        # in case external code relies on it.
        self.pit_snr = PITLossWrapper(
            PairwiseNegSDR("snr", zero_mean=False), pit_from="pw_mtx"
        )

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def close(self) -> None:
        """Close the results file; safe to call more than once."""
        if not self.results_csv.closed:
            self.results_csv.close()

    def __call__(self, mix, clean, estimate, key):
        """Evaluate one utterance, write its CSV row and accumulate metrics.

        Args:
            mix: Mixture signal. It is stacked ``clean.shape[0]`` times to
                act as a "no separation" baseline estimate (assumes a
                single-channel tensor of shape (time,) -- TODO confirm).
            clean: Reference sources (assumes shape (n_src, time) -- TODO
                confirm against caller).
            estimate: Estimated sources, same layout as ``clean``.
            key: Utterance identifier written to the ``snt_id`` column.
        """
        # The wrapped loss is treated as a *negated* SI-SNR here: its sign
        # is flipped only when the value is reported.
        neg_sisnr = self.pit_sisnr(estimate.unsqueeze(0), clean.unsqueeze(0))

        # Baseline: the unprocessed mixture repeated once per reference.
        mix = torch.stack([mix] * clean.shape[0], dim=0)
        neg_sisnr_baseline = self.pit_sisnr(
            mix.unsqueeze(0), clean.unsqueeze(0)
        )
        neg_sisnr_i = neg_sisnr - neg_sisnr_baseline

        # sdr_pit_loss is likewise negated to obtain the reported SDR.
        sdr = -fast_bss_eval.sdr_pit_loss(estimate, clean).mean()
        sdr_baseline = -fast_bss_eval.sdr_pit_loss(mix, clean).mean()
        sdr_i = sdr - sdr_baseline

        # Pull every scalar off the tensor exactly once and reuse it for
        # both the CSV row and the in-memory accumulation.
        sdr_value = sdr.item()
        sdr_i_value = sdr_i.item()
        sisnr_value = -neg_sisnr.item()
        sisnr_i_value = -neg_sisnr_i.item()

        self.writer.writerow(
            {
                "snt_id": key,
                "sdr": sdr_value,
                "sdr_i": sdr_i_value,
                "si-snr": sisnr_value,
                "si-snr_i": sisnr_i_value,
            }
        )

        # Metric accumulation
        self.all_sdrs.append(sdr_value)
        self.all_sdrs_i.append(sdr_i_value)
        self.all_sisnrs.append(sisnr_value)
        self.all_sisnrs_i.append(sisnr_i_value)

    def update(self) -> dict:
        """Return the running means of the improvement metrics.

        Returns:
            Dict with keys ``"sdr_i"`` and ``"si-snr_i"`` holding the mean
            over all utterances seen so far (NaN if none were seen).
        """
        return {
            "sdr_i": np.array(self.all_sdrs_i).mean(),
            "si-snr_i": np.array(self.all_sisnrs_i).mean(),
        }

    def _summary_row(self, label: str, reduce_fn) -> dict:
        """Build one CSV row applying ``reduce_fn`` to every metric list."""
        return {
            "snt_id": label,
            "sdr": reduce_fn(np.array(self.all_sdrs)),
            "sdr_i": reduce_fn(np.array(self.all_sdrs_i)),
            "si-snr": reduce_fn(np.array(self.all_sisnrs)),
            "si-snr_i": reduce_fn(np.array(self.all_sisnrs_i)),
        }

    def final(self) -> None:
        """Append the ``avg`` and ``std`` summary rows and close the file.

        The file is closed even if writing a summary row fails. After this
        call the tracker can no longer write rows.
        """
        try:
            self.writer.writerow(self._summary_row("avg", np.mean))
            self.writer.writerow(self._summary_row("std", np.std))
        finally:
            self.close()