# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Test time evaluation, either using the original SDR from [Vincent et al. 2006]
or the newer SDR definition from the MDX 2021 competition (the latter is reported
as `nsdr`, for `new sdr`).
"""

from concurrent import futures
import logging

from dora.log import LogProgress
import numpy as np
import musdb
import museval
import torch as th

from .apply import apply_model
from .audio import convert_audio, save_audio
from . import distrib
from .utils import DummyPoolExecutor


logger = logging.getLogger(__name__)


def new_sdr(references, estimates):
    """
    Compute the SDR according to the MDX challenge definition.
    Adapted from AIcrowd/music-demixing-challenge-starter-kit (MIT license).
    """
    assert references.dim() == 4
    assert estimates.dim() == 4
    delta = 1e-7  # avoid numerical errors
    num = th.sum(th.square(references), dim=(2, 3))
    den = th.sum(th.square(references - estimates), dim=(2, 3))
    num += delta
    den += delta
    scores = 10 * th.log10(num / den)
    return scores


def eval_track(references, estimates, win, hop, compute_sdr=True):
    # museval expects (sources, samples, channels); inputs come in as
    # (sources, channels, samples).
    references = references.transpose(1, 2).double()
    estimates = estimates.transpose(1, 2).double()

    # The MDX `nsdr` is cheap to compute and always reported.
    new_scores = new_sdr(references.cpu()[None], estimates.cpu()[None])[0]

    if not compute_sdr:
        return None, new_scores
    else:
        references = references.numpy()
        estimates = estimates.numpy()
        scores = museval.metrics.bss_eval(
            references, estimates,
            compute_permutation=False,
            window=win,
            hop=hop,
            framewise_filters=False,
            bsseval_sources_version=False)[:-1]
        return scores, new_scores
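
# The sketch below is illustrative only (a hypothetical helper on random
# data, not part of the evaluation flow): it shows the tensor layout
# `eval_track` expects, namely (sources, channels, samples), and the fast
# path where only the MDX `nsdr` is computed.
def _example_eval_track():
    sr = 44100
    references = th.rand(4, 2, sr)  # 4 sources, stereo, 1 second of audio
    # Slightly noisy estimates: expect an nsdr around 35 dB per source.
    estimates = references + 0.01 * th.randn_like(references)
    _, nsdr = eval_track(references, estimates, win=sr, hop=sr,
                         compute_sdr=False)
    return nsdr  # shape (4,): one score per source
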
def evaluate(solver, compute_sdr=False):
    """
    Evaluate model using museval.
    With `compute_sdr=False`, only the MDX definition of the SDR is used,
    which is much faster to evaluate.
    """

    args = solver.args
    output_dir = solver.folder / "results"
    output_dir.mkdir(exist_ok=True, parents=True)
    json_folder = solver.folder / "results/test"
    json_folder.mkdir(exist_ok=True, parents=True)

    # we load tracks from the original musdb set
    if args.test.nonhq is None:
        test_set = musdb.DB(args.dset.musdb, subsets=["test"], is_wav=True)
    else:
        test_set = musdb.DB(args.test.nonhq, subsets=["test"], is_wav=False)
    src_rate = args.dset.musdb_samplerate

    eval_device = 'cpu'

    model = solver.model
    # BSS eval is computed over 1 second windows with a 1 second hop.
    win = int(1. * model.samplerate)
    hop = int(1. * model.samplerate)

    # Shard the tracks across distributed workers: each rank evaluates
    # every `world_size`-th track.
    indexes = range(distrib.rank, len(test_set), distrib.world_size)
    indexes = LogProgress(logger, indexes, updates=args.misc.num_prints,
                          name='Eval')
    pendings = []

    pool = futures.ProcessPoolExecutor if args.test.workers else DummyPoolExecutor
    with pool(args.test.workers) as pool:
        for index in indexes:
            track = test_set.tracks[index]

            mix = th.from_numpy(track.audio).t().float()
            if mix.dim() == 1:
                mix = mix[None]
            mix = mix.to(solver.device)
            # Normalize the mixture by the stats of its mono version;
            # the same stats un-normalize the estimates below.
            ref = mix.mean(dim=0)  # mono mixture
            mix = (mix - ref.mean()) / ref.std()
            mix = convert_audio(mix, src_rate, model.samplerate, model.audio_channels)
            estimates = apply_model(model, mix[None],
                                    shifts=args.test.shifts, split=args.test.split,
                                    overlap=args.test.overlap)[0]
            estimates = estimates * ref.std() + ref.mean()
            estimates = estimates.to(eval_device)

            references = th.stack(
                [th.from_numpy(track.targets[name].audio).t() for name in model.sources])
            if references.dim() == 2:
                references = references[:, None]
            references = references.to(eval_device)
            references = convert_audio(references, src_rate,
                                       model.samplerate, model.audio_channels)
            if args.test.save:
                folder = solver.folder / "wav" / track.name
                folder.mkdir(exist_ok=True, parents=True)
                for name, estimate in zip(model.sources, estimates):
                    save_audio(estimate.cpu(), folder / (name + ".mp3"), model.samplerate)

            pendings.append((track.name, pool.submit(
                eval_track, references, estimates, win=win, hop=hop,
                compute_sdr=compute_sdr)))

        pendings = LogProgress(logger, pendings, updates=args.misc.num_prints,
                               name='Eval (BSS)')
        tracks = {}
        for track_name, pending in pendings:
            pending = pending.result()
            scores, nsdrs = pending
            tracks[track_name] = {}
            for idx, target in enumerate(model.sources):
                tracks[track_name][target] = {'nsdr': [float(nsdrs[idx])]}
            if scores is not None:
                (sdr, isr, sir, sar) = scores
                for idx, target in enumerate(model.sources):
                    values = {
                        "SDR": sdr[idx].tolist(),
                        "SIR": sir[idx].tolist(),
                        "ISR": isr[idx].tolist(),
                        "SAR": sar[idx].tolist()
                    }
                    tracks[track_name][target].update(values)

    # Gather the per-track results from all distributed workers.
    all_tracks = {}
    for src in range(distrib.world_size):
        all_tracks.update(distrib.share(tracks, src))

    result = {}
    metric_names = next(iter(all_tracks.values()))[model.sources[0]]
    for metric_name in metric_names:
        avg = 0
        avg_of_medians = 0
        for source in model.sources:
            # Per track, take the median over eval frames; across tracks,
            # report both the mean and the median of those medians.
            medians = [
                np.nanmedian(all_tracks[track][source][metric_name])
                for track in all_tracks.keys()]
            mean = np.mean(medians)
            median = np.median(medians)
            result[metric_name.lower() + "_" + source] = mean
            result[metric_name.lower() + "_med" + "_" + source] = median
            avg += mean / len(model.sources)
            avg_of_medians += median / len(model.sources)
        result[metric_name.lower()] = avg
        result[metric_name.lower() + "_med"] = avg_of_medians
    return result
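

# The sketch below is illustrative only (a hypothetical helper with made-up
# scores): it reproduces the aggregation at the end of `evaluate` on toy
# data. Per track, the median over BSS eval frames is taken; across tracks,
# both the mean and the median of those per-track medians are kept; both
# are then averaged over sources.
def _example_aggregation():
    all_tracks = {
        "track_a": {"vocals": {"SDR": [6.0, 6.4, 5.8]},
                    "drums": {"SDR": [5.1, 4.9, 5.3]}},
        "track_b": {"vocals": {"SDR": [7.2, 7.0, 6.8]},
                    "drums": {"SDR": [4.0, 4.4, 4.2]}},
    }
    sources = ["vocals", "drums"]
    result = {}
    avg = avg_of_medians = 0
    for source in sources:
        medians = [np.nanmedian(all_tracks[track][source]["SDR"])
                   for track in all_tracks]
        result["sdr_" + source] = np.mean(medians)
        result["sdr_med_" + source] = np.median(medians)
        avg += np.mean(medians) / len(sources)
        avg_of_medians += np.median(medians) / len(sources)
    result["sdr"] = avg
    result["sdr_med"] = avg_of_medians
    return result  # keys: sdr, sdr_med, sdr_vocals, sdr_med_vocals, ...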