# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
This script creates realistic mixes with stems from different songs.
In particular, it will align BPM, sync up the first beat and perform
pitch shifting to maximize pitch overlap.
In order to limit artifacts, only parts that can be mixed with less than 15% tempo shift
and at most 3 semitones of pitch shift are mixed together.
"""
from collections import namedtuple
from concurrent.futures import ProcessPoolExecutor
import hashlib
from pathlib import Path
import random
import shutil
import tqdm
import pickle

from librosa.beat import beat_track
from librosa.feature import chroma_cqt
import numpy as np
import torch
from torch.nn import functional as F

from dora.utils import try_load
from demucs.audio import save_audio
from demucs.repitch import repitch
from demucs.pretrained import SOURCES
from demucs.wav import build_metadata, Wavset, _get_musdb_valid

MUSDB_PATH = '/checkpoint/defossez/datasets/musdbhq'
EXTRA_WAV_PATH = "/checkpoint/defossez/datasets/allstems_44"
# WARNING: OUTPATH will be completely erased.
OUTPATH = Path.home() / 'tmp/demucs_mdx/automix_musdb/'
CACHE = Path.home() / 'tmp/automix_cache'  # cache BPM and pitch information.
CHANNELS = 2
SR = 44100
MAX_PITCH = 3  # maximum allowable pitch shift, in semitones
MAX_TEMPO = 0.15  # maximum allowable relative tempo shift


Spec = namedtuple("Spec", "tempo onsets kr track index")


def rms(wav, window=10000):
    """Efficient RMS computed for each time step over a given window."""
    half = window // 2
    window = 2 * half + 1
    wav = F.pad(wav, (half, half))
    tot = wav.pow(2).cumsum(dim=-1)
    return ((tot[..., window - 1:] - tot[..., :-window + 1]) / window).sqrt()


def analyse_track(dset, index):
    """Analyse a track: extract the BPM from the drums and the chroma
    distribution of notes from the bass line."""
    track = dset[index]
    mix = track.sum(0).mean(0)
    ref = mix.std()

    starts = (abs(mix) >= 1e-2 * ref).float().argmax().item()
    track = track[..., starts:]

    cache = CACHE / dset.sig
    cache.mkdir(exist_ok=True, parents=True)
    cache_file = cache / f"{index}.pkl"
    cached = None
    if cache_file.exists():
        cached = try_load(cache_file)
        if cached is not None:
            tempo, events, hist_kr = cached

    if cached is None:
        drums = track[0].mean(0)
        if drums.std() > 1e-2 * ref:
            tempo, events = beat_track(y=drums.numpy(), units='time', sr=SR)
        else:
            print("failed drums", drums.std(), ref)
            return None, track

        bass = track[1].mean(0)
        r = rms(bass)
        peak = r.max()
        mask = r >= 0.05 * peak
        bass = bass[mask]
        if bass.std() > 1e-2 * ref:
            kr = torch.from_numpy(chroma_cqt(y=bass.numpy(), sr=SR))
            hist_kr = (kr.max(dim=0, keepdim=True)[0] == kr).float().mean(1)
        else:
            print("failed bass", bass.std(), ref)
            return None, track

    with open(cache_file, 'wb') as f:
        pickle.dump([tempo, events, hist_kr], f)
    spec = Spec(tempo, events, hist_kr, track, index)
    return spec, None


def best_pitch_shift(kr_a, kr_b):
    """Find the best pitch shift (in semitones) between two chroma distributions."""
    deltas = []
    for p in range(12):
        deltas.append((kr_a - kr_b).abs().mean())
        kr_b = kr_b.roll(1, 0)

    ps = np.argmin(deltas)
    if ps > 6:
        ps = ps - 12
    return ps
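
# Illustrative sketch (not used by the pipeline): `best_pitch_shift` compares the
# reference 12-bin chroma histogram against every circular shift of the other one
# and returns the best shift, folded into [-6, 6] semitones. The helper below is
# only an example of the behaviour and is never called by this script.
def _example_best_pitch_shift():
    """Toy check: a one-hot chroma histogram rolled up by 2 bins comes back as -2."""
    kr_a = torch.zeros(12)
    kr_a[0] = 1.
    kr_b = kr_a.roll(2, 0)  # the same energy, moved up by 2 chroma bins
    return best_pitch_shift(kr_a, kr_b)  # expected: -2 with the convention above
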
""" sources = len(stems) width = 5e-3 # grid of 10ms limit = 5 std = 2 x = torch.arange(-limit, limit + 1, 1).float() gauss = torch.exp(-x**2 / (2 * std**2)) grids = [] for wav, onsets in stems: le = wav.shape[-1] dur = le / SR grid = torch.zeros(int(le / width / SR)) for onset in onsets: pos = int(onset / width) if onset >= dur - 1: continue if onset < 1: continue grid[pos - limit:pos + limit + 1] += gauss grids.append(grid) shifts = [0] for s in range(1, sources): max_shift = int(4 / width) dots = [] for shift in range(-max_shift, max_shift): other = grids[s] ref = grids[0] if shift >= 0: other = other[shift:] else: ref = ref[shift:] le = min(len(other), len(ref)) dots.append((ref[:le].dot(other[:le]), int(shift * width * SR))) _, shift = max(dots) shifts.append(-shift) outs = [] new_zero = min(shifts) for (wav, _), shift in zip(stems, shifts): offset = shift - new_zero wav = F.pad(wav, (offset, 0)) outs.append(wav) le = min(x.shape[-1] for x in outs) outs = [w[..., :le] for w in outs] return torch.stack(outs) def find_candidate(spec_ref, catalog, pitch_match=True): """Given reference track, this finds a track in the catalog that is a potential match (pitch and tempo delta must be within the allowable limits). """ candidates = list(catalog) random.shuffle(candidates) for spec in candidates: ok = False for scale in [1/4, 1/2, 1, 2, 4]: tempo = spec.tempo * scale delta_tempo = spec_ref.tempo / tempo - 1 if abs(delta_tempo) < MAX_TEMPO: ok = True break if not ok: print(delta_tempo, spec_ref.tempo, spec.tempo, "FAILED TEMPO") # too much of a tempo difference continue spec = spec._replace(tempo=tempo) ps = 0 if pitch_match: ps = best_pitch_shift(spec_ref.kr, spec.kr) if abs(ps) > MAX_PITCH: print("Failed pitch", ps) # too much pitch difference continue return spec, delta_tempo, ps def get_part(spec, source, dt, dp): """Apply given delta of tempo and delta of pitch to a stem.""" wav = spec.track[source] if dt or dp: wav = repitch(wav, dp, dt * 100, samplerate=SR, voice=source == 3) spec = spec._replace(onsets=spec.onsets / (1 + dt)) return wav, spec def build_track(ref_index, catalog): """Given the reference track index and a catalog of track, builds a completely new track. One of the source at random from the ref track will be kept and other sources will be drawn from the catalog. 
""" order = list(range(len(SOURCES))) random.shuffle(order) stems = [None] * len(order) indexes = [None] * len(order) origs = [None] * len(order) dps = [None] * len(order) dts = [None] * len(order) first = order[0] spec_ref = catalog[ref_index] stems[first] = (spec_ref.track[first], spec_ref.onsets) indexes[first] = ref_index origs[first] = spec_ref.track[first] dps[first] = 0 dts[first] = 0 pitch_match = order != 0 for src in order[1:]: spec, dt, dp = find_candidate(spec_ref, catalog, pitch_match=pitch_match) if not pitch_match: spec_ref = spec_ref._replace(kr=spec.kr) pitch_match = True dps[src] = dp dts[src] = dt wav, spec = get_part(spec, src, dt, dp) stems[src] = (wav, spec.onsets) indexes[src] = spec.index origs.append(spec.track[src]) print("FINAL CHOICES", ref_index, indexes, dps, dts) stems = align_stems(stems) return stems, origs def get_musdb_dataset(part='train'): root = Path(MUSDB_PATH) / part ext = '.wav' metadata = build_metadata(root, SOURCES, ext=ext, normalize=False) valid_tracks = _get_musdb_valid() metadata_train = {name: meta for name, meta in metadata.items() if name not in valid_tracks} train_set = Wavset( root, metadata_train, SOURCES, samplerate=SR, channels=CHANNELS, normalize=False, ext=ext) sig = hashlib.sha1(str(root).encode()).hexdigest()[:8] train_set.sig = sig return train_set def get_wav_dataset(): root = Path(EXTRA_WAV_PATH) ext = '.wav' metadata = _build_metadata(root, SOURCES, ext=ext, normalize=False) train_set = Wavset( root, metadata, SOURCES, samplerate=SR, channels=CHANNELS, normalize=False, ext=ext) sig = hashlib.sha1(str(root).encode()).hexdigest()[:8] train_set.sig = sig return train_set def main(): random.seed(4321) if OUTPATH.exists(): shutil.rmtree(OUTPATH) OUTPATH.mkdir(exist_ok=True, parents=True) (OUTPATH / 'train').mkdir(exist_ok=True, parents=True) (OUTPATH / 'valid').mkdir(exist_ok=True, parents=True) out = OUTPATH / 'train' dset = get_musdb_dataset() # dset2 = get_wav_dataset() # dset3 = get_musdb_dataset('test') dset2 = None dset3 = None pendings = [] copies = 6 copies_rej = 2 with ProcessPoolExecutor(20) as pool: for index in range(len(dset)): pendings.append(pool.submit(analyse_track, dset, index)) if dset2: for index in range(len(dset2)): pendings.append(pool.submit(analyse_track, dset2, index)) if dset3: for index in range(len(dset3)): pendings.append(pool.submit(analyse_track, dset3, index)) catalog = [] rej = 0 for pending in tqdm.tqdm(pendings, ncols=120): spec, track = pending.result() if spec is not None: catalog.append(spec) else: mix = track.sum(0) for copy in range(copies_rej): folder = out / f'rej_{rej}_{copy}' folder.mkdir() save_audio(mix, folder / "mixture.wav", SR) for stem, source in zip(track, SOURCES): save_audio(stem, folder / f"{source}.wav", SR, clip='clamp') rej += 1 for copy in range(copies): for index in range(len(catalog)): track, origs = build_track(index, catalog) mix = track.sum(0) mx = mix.abs().max() scale = max(1, 1.01 * mx) mix = mix / scale track = track / scale folder = out / f'{copy}_{index}' folder.mkdir() save_audio(mix, folder / "mixture.wav", SR) for stem, source, orig in zip(track, SOURCES, origs): save_audio(stem, folder / f"{source}.wav", SR, clip='clamp') # save_audio(stem.std() * orig / (1e-6 + orig.std()), folder / f"{source}_orig.wav", # SR, clip='clamp') if __name__ == '__main__': main()