Spaces:
Configuration error
Configuration error
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
"""Utility for on the fly pitch/tempo change for data augmentation.""" | |
import random | |
import subprocess as sp | |
import tempfile | |
import torch | |
import torchaudio as ta | |
from .audio import save_audio | |
class RepitchedWrapper:
    """
    Wrap a dataset to apply online change of pitch / tempo.

    With probability `proba`, every stream of a fetched item is passed
    through `repitch` with a random pitch shift and tempo change. Whether or
    not the augmentation is applied, the item is truncated along its last
    dimension to `int((1 - 0.01 * max_tempo) * in_length)` samples, so the
    output length does not depend on the random tempo that was drawn.

    Args:
        dataset: indexable object with a length. Each item must expose
            `.shape`, be iterable stream by stream and support slicing
            (presumably a tensor shaped (sources, channels, time) --
            TODO confirm against the wrapped dataset).
        proba: probability of augmenting a given item.
        max_pitch: largest absolute pitch shift, in semitones. The shift is
            drawn uniformly from the integers in [-max_pitch, max_pitch].
        max_tempo: largest absolute tempo change, in percent.
        tempo_std: standard deviation of the Gaussian from which the tempo
            change is drawn, before clamping to [-max_tempo, max_tempo].
        vocals: indexes of the streams processed with `voice=True`.
        same: if True, a single pitch/tempo pair is drawn per item and
            reused for all its streams; otherwise one pair is drawn per
            stream.
    """

    def __init__(self, dataset, proba=0.2, max_pitch=2, max_tempo=12,
                 tempo_std=5, vocals=[3], same=True):
        self.dataset = dataset
        self.proba = proba
        self.max_pitch = max_pitch
        self.max_tempo = max_tempo
        self.tempo_std = tempo_std
        self.same = same
        # Copy instead of aliasing: the default list object is created once
        # and would otherwise be shared by every instance using the default.
        self.vocals = list(vocals)

    def __len__(self):
        """Return the number of items of the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return item `index`, randomly repitched, cut to the fixed length."""
        streams = self.dataset[index]
        in_length = streams.shape[-1]
        # Length still available after the strongest possible speed-up; used
        # for every item so that augmented and plain items line up.
        out_length = int((1 - 0.01 * self.max_tempo) * in_length)

        if random.random() < self.proba:
            outs = []
            for idx, stream in enumerate(streams):
                # Draw new parameters for the first stream only, unless each
                # stream is supposed to get its own (`same=False`).
                if idx == 0 or not self.same:
                    delta_pitch = random.randint(-self.max_pitch, self.max_pitch)
                    delta_tempo = random.gauss(0, self.tempo_std)
                    delta_tempo = min(max(-self.max_tempo, delta_tempo), self.max_tempo)
                stream = repitch(
                    stream,
                    delta_pitch,
                    delta_tempo,
                    voice=idx in self.vocals)
                outs.append(stream[:, :out_length])
            streams = torch.stack(outs)
        else:
            streams = streams[..., :out_length]
        return streams
def repitch(wav, pitch, tempo, voice=False, quick=False, samplerate=44100):
    """
    Change the pitch and tempo of `wav` with the `soundstretch` command line tool.

    tempo is a relative delta in percentage, so tempo=10 means tempo at 110%!
    pitch is in semi tones.
    Requires `soundstretch` to be installed, see
    https://www.surina.net/soundtouch/soundstretch.html

    The waveform is written to a temporary wav file, processed by
    `soundstretch` into a second temporary file, and loaded back.

    Args:
        wav: waveform handed to `save_audio` (presumably a (channels, time)
            tensor -- TODO confirm against `save_audio`).
        pitch: pitch shift in semitones, passed as `-pitch=`.
        tempo: tempo change in percent, passed as `-tempo=` with 6 decimals.
        voice: if True, add the `-speech` flag.
        quick: if True, add the `-quick` flag.
        samplerate: sample rate used to write the input file; the output
            file is expected to have the same one.

    Returns:
        The processed waveform as loaded by `torchaudio.load`.

    Raises:
        RuntimeError: if `soundstretch` exits with a non-zero status.
    """
    # Context managers guarantee both temporary files are closed and removed
    # even when saving, the subprocess, or loading fails.
    with tempfile.NamedTemporaryFile(suffix=".wav") as infile, \
            tempfile.NamedTemporaryFile(suffix=".wav") as outfile:
        save_audio(wav, infile.name, samplerate, clip='clamp')
        command = [
            "soundstretch",
            infile.name,
            outfile.name,
            f"-pitch={pitch}",
            f"-tempo={tempo:.6f}",
        ]
        if quick:
            command += ["-quick"]
        if voice:
            command += ["-speech"]
        try:
            sp.run(command, capture_output=True, check=True)
        except sp.CalledProcessError as error:
            # Undecodable bytes must not mask the real failure.
            stderr = error.stderr.decode('utf-8', errors='replace')
            raise RuntimeError(f"Could not change bpm because {stderr}") from error
        # Load before leaving the `with` block: the file is deleted on exit.
        wav, sr = ta.load(outfile.name)
    assert sr == samplerate
    return wav