# audio / demucs /repitch.py
# PreciousMposa's picture
# Upload 107 files
# 519d358 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Utility for on the fly pitch/tempo change for data augmentation."""
import random
import subprocess as sp
import tempfile
import torch
import torchaudio as ta
from .audio import save_audio
class RepitchedWrapper:
    """
    Wrap a dataset to apply online change of pitch / tempo.

    Every item read from `dataset` is cropped on its last axis to
    `int((1 - 0.01 * max_tempo) * in_length)` samples, so that items have
    the same length whether or not they went through `repitch`.

    Args:
        dataset: object supporting `len()` and indexing; each item must
            expose `.shape` and be iterable stream by stream.
        proba: probability that a given item is repitched.
        max_pitch: the pitch delta is drawn uniformly from the integers in
            `[-max_pitch, max_pitch]`.
        max_tempo: bound used to clamp the tempo delta (in percent); it
            also determines the cropped output length.
        tempo_std: standard deviation of the zero-mean Gaussian the tempo
            delta is drawn from, before clamping.
        vocals: indexes of the streams processed with `voice=True`.
            Defaults to `[3]`.
        same: if true, the pitch / tempo deltas drawn for the first stream
            are reused for all the other streams of the item.
    """

    def __init__(self, dataset, proba=0.2, max_pitch=2, max_tempo=12,
                 tempo_std=5, vocals=None, same=True):
        self.dataset = dataset
        self.proba = proba
        self.max_pitch = max_pitch
        self.max_tempo = max_tempo
        self.tempo_std = tempo_std
        self.same = same
        # A fresh list per instance: a list literal used as default value
        # would be shared (and mutable) across all instances.
        self.vocals = [3] if vocals is None else vocals

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return item `index`, randomly repitched and cropped in time."""
        streams = self.dataset[index]
        in_length = streams.shape[-1]
        # Common output length for both branches below.
        out_length = int((1 - 0.01 * self.max_tempo) * in_length)

        if random.random() < self.proba:
            outs = []
            for idx, stream in enumerate(streams):
                # Draw new deltas for the first stream, and for every
                # stream when `same` is disabled.
                if idx == 0 or not self.same:
                    delta_pitch = random.randint(-self.max_pitch, self.max_pitch)
                    delta_tempo = random.gauss(0, self.tempo_std)
                    # Clamp to [-max_tempo, max_tempo].
                    delta_tempo = min(max(-self.max_tempo, delta_tempo), self.max_tempo)
                stream = repitch(
                    stream,
                    delta_pitch,
                    delta_tempo,
                    voice=idx in self.vocals)
                outs.append(stream[:, :out_length])
            streams = torch.stack(outs)
        else:
            streams = streams[..., :out_length]
        return streams
def repitch(wav, pitch, tempo, voice=False, quick=False, samplerate=44100):
    """
    Change the pitch and tempo of `wav` by running the external
    `soundstretch` program on temporary wav files.

    tempo is a relative delta in percentage, so tempo=10 means tempo at 110%!
    pitch is in semi tones.
    Requires `soundstretch` to be installed, see
    https://www.surina.net/soundtouch/soundstretch.html

    Args:
        wav: audio to process, written to disk with `save_audio`.
        pitch: pitch change, passed as `-pitch=`.
        tempo: tempo change, passed as `-tempo=` with 6 decimals.
        voice: if true, add the `-speech` flag.
        quick: if true, add the `-quick` flag.
        samplerate: sample rate used to save the input; the output file
            must come back with the same rate.

    Returns:
        The processed audio as loaded by `torchaudio.load`.

    Raises:
        RuntimeError: if `soundstretch` exits with a non-zero status.
    """
    # Context managers make sure both temporary files are closed and removed
    # as soon as we are done, instead of whenever they get garbage collected.
    with tempfile.NamedTemporaryFile(suffix=".wav") as infile, \
            tempfile.NamedTemporaryFile(suffix=".wav") as outfile:
        save_audio(wav, infile.name, samplerate, clip='clamp')
        command = [
            "soundstretch",
            infile.name,
            outfile.name,
            f"-pitch={pitch}",
            f"-tempo={tempo:.6f}",
        ]
        if quick:
            command += ["-quick"]
        if voice:
            command += ["-speech"]
        try:
            sp.run(command, capture_output=True, check=True)
        except sp.CalledProcessError as error:
            # Lenient decoding: undecodable bytes in stderr must not hide
            # the actual failure behind a UnicodeDecodeError.
            stderr = error.stderr.decode('utf-8', errors='replace')
            raise RuntimeError(f"Could not change bpm because {stderr}") from error
        # Must be read before leaving the `with` block deletes the file.
        wav, sr = ta.load(outfile.name)
    assert sr == samplerate
    return wav