import math
import os
import random
from collections.abc import Mapping
from io import IOBase

import torchaudio
from torch.nn.functional import pad

def get_torchaudio_info(file, backend=None):
    """Probe ``file["audio"]`` with torchaudio and return its metadata.

    When no backend is given, "soundfile" is preferred if installed;
    otherwise the first backend torchaudio reports is used.
    """
    if not backend:
        available = torchaudio.list_audio_backends()
        backend = "soundfile" if "soundfile" in available else available[0]

    info = torchaudio.info(file["audio"], backend=backend)

    # Probing consumes the stream position; rewind file-like objects so the
    # audio itself can still be read afterwards.
    if isinstance(file["audio"], IOBase):
        file["audio"].seek(0)

    return info
class Audio:
    """Load, validate, downmix, resample, and crop audio files.

    A "file" may be a path (``str`` / ``os.PathLike``), a file-like object,
    or a mapping with either an ``"audio"`` key (path or file-like) or a
    pre-loaded ``"waveform"`` (channel, time) tensor plus its
    ``"sample_rate"``.

    Parameters
    ----------
    sample_rate : int, optional
        Target sample rate; loaded audio is resampled to it when different.
    mono : {"random", "downmix"}, optional
        Multi-channel handling: "random" keeps one randomly chosen channel,
        "downmix" averages all channels; any other value keeps all channels.
    backend : str, optional
        torchaudio I/O backend; defaults to "soundfile" when available,
        otherwise the first backend torchaudio reports.
    """

    @staticmethod
    def power_normalize(waveform):
        """Scale `waveform` to unit power along its last (time) axis.

        The 1e-8 epsilon avoids division by zero on all-zero input.
        """
        rms = waveform.square().mean(dim=-1, keepdim=True).sqrt()
        return waveform / (rms + 1e-8)

    @staticmethod
    def validate_file(file):
        """Normalize `file` into a mapping with an "audio" or "waveform" key.

        Returns
        -------
        dict
            With "audio" (path or file-like) or "waveform" + "sample_rate",
            and a "uri" key (derived from the file name when possible).

        Raises
        ------
        ValueError
            If `file` is of an unsupported type, provides neither "audio"
            nor "waveform", or provides a malformed waveform.
        """
        if isinstance(file, Mapping):
            # Already a dict-like description: fall through to key checks.
            pass
        elif isinstance(file, (str, os.PathLike)):
            file = {
                "audio": str(file),
                "uri": os.path.splitext(os.path.basename(file))[0],
            }
        elif isinstance(file, IOBase):
            return {"audio": file, "uri": "stream"}
        else:
            raise ValueError(
                "file must be a mapping, a str/os.PathLike path, or a file-like object"
            )

        if "waveform" in file:
            waveform = file["waveform"]
            # Expect (channel, time): far fewer channels than samples.
            if len(waveform.shape) != 2 or waveform.shape[0] > waveform.shape[1]:
                raise ValueError('"waveform" must have shape (channel, time)')

            sample_rate: int = file.get("sample_rate", None)
            if sample_rate is None:
                raise ValueError('"waveform" requires a "sample_rate" key')

            file.setdefault("uri", "waveform")

        elif "audio" in file:
            if isinstance(file["audio"], IOBase):
                # Consistent with the bare-IOBase branch above.
                file.setdefault("uri", "stream")
                return file

            path = os.path.abspath(file["audio"])
            file.setdefault("uri", os.path.splitext(os.path.basename(path))[0])

        else:
            raise ValueError('file must provide either an "audio" or a "waveform" key')

        return file

    def __init__(self, sample_rate: int = None, mono=None, backend: str = None):
        super().__init__()
        self.sample_rate = sample_rate
        self.mono = mono

        if not backend:
            backends = torchaudio.list_audio_backends()
            # Prefer soundfile when installed, else the first reported backend.
            backend = "soundfile" if "soundfile" in backends else backends[0]

        self.backend = backend

    def downmix_and_resample(self, waveform, sample_rate):
        """Apply the configured mono reduction and resampling.

        Parameters
        ----------
        waveform : torch.Tensor
            (channel, time) tensor.
        sample_rate : int
            Sample rate of `waveform`.

        Returns
        -------
        (waveform, sample_rate) after processing.
        """
        num_channels = waveform.shape[0]
        if num_channels > 1:
            if self.mono == "random":
                channel = random.randint(0, num_channels - 1)
                waveform = waveform[channel : channel + 1]
            elif self.mono == "downmix":
                waveform = waveform.mean(dim=0, keepdim=True)

        if (self.sample_rate is not None) and (self.sample_rate != sample_rate):
            waveform = torchaudio.functional.resample(
                waveform, sample_rate, self.sample_rate
            )
            sample_rate = self.sample_rate

        return waveform, sample_rate

    def get_duration(self, file) -> float:
        """Return the duration of `file` in seconds."""
        file = self.validate_file(file)

        if "waveform" in file:
            frames = len(file["waveform"].T)
            sample_rate = file["sample_rate"]
        else:
            # Reuse cached metadata when a previous call stored it.
            info = (
                file["torchaudio.info"]
                if "torchaudio.info" in file
                else get_torchaudio_info(file, backend=self.backend)
            )
            frames = info.num_frames
            sample_rate = info.sample_rate

        return frames / sample_rate

    def get_num_samples(self, duration: float, sample_rate: int = None) -> int:
        """Convert a duration in seconds into a number of samples."""
        sample_rate = sample_rate or self.sample_rate
        if sample_rate is None:
            raise ValueError(
                "`sample_rate` must be provided when Audio has no configured sample rate"
            )

        return math.floor(duration * sample_rate)

    def __call__(self, file):
        """Load `file` entirely and return (waveform, sample_rate).

        The waveform is downmixed/resampled per this instance's settings;
        an optional "channel" key in `file` selects a single channel first.
        """
        file = self.validate_file(file)

        if "waveform" in file:
            waveform = file["waveform"]
            sample_rate = file["sample_rate"]
        elif "audio" in file:
            waveform, sample_rate = torchaudio.load(file["audio"], backend=self.backend)
            # Rewind streams so they can be read again later.
            if isinstance(file["audio"], IOBase):
                file["audio"].seek(0)

        channel = file.get("channel", None)
        if channel is not None:
            waveform = waveform[channel : channel + 1]

        return self.downmix_and_resample(waveform, sample_rate)

    def crop(self, file, segment, duration=None, mode="raise"):
        """Extract the `segment` of `file` as a (waveform, sample_rate) pair.

        Parameters
        ----------
        file : see `validate_file`
        segment : object with `start` and `end` attributes, in seconds
        duration : float, optional
            When provided, crop exactly this many seconds starting at
            `segment.start`, overriding `segment.end`.
        mode : {"raise", "pad"}
            Out-of-bounds handling: "raise" raises ValueError, "pad"
            zero-pads the missing frames.
        """
        file = self.validate_file(file)

        if "waveform" in file:
            waveform = file["waveform"]
            frames = waveform.shape[1]
            sample_rate = file["sample_rate"]
        elif "torchaudio.info" in file:
            # Cached metadata avoids re-probing the file.
            info = file["torchaudio.info"]
            frames = info.num_frames
            sample_rate = info.sample_rate
        else:
            info = get_torchaudio_info(file, backend=self.backend)
            frames = info.num_frames
            sample_rate = info.sample_rate

        channel = file.get("channel", None)
        start_frame = math.floor(segment.start * sample_rate)

        if duration:
            num_frames = math.floor(duration * sample_rate)
            end_frame = start_frame + num_frames
        else:
            end_frame = math.floor(segment.end * sample_rate)
            num_frames = end_frame - start_frame

        if mode == "raise":
            if num_frames > frames:
                raise ValueError(
                    f"requested {num_frames} frames but file only has {frames}"
                )

            # Tolerate ~1 ms of rounding slack past the end of file.
            if end_frame > frames + math.ceil(0.001 * sample_rate):
                raise ValueError(
                    f"requested frame {end_frame} lies beyond end of file ({frames} frames)"
                )
            else:
                end_frame = min(end_frame, frames)
                start_frame = end_frame - num_frames

            if start_frame < 0:
                raise ValueError("requested segment starts before the beginning of file")

        elif mode == "pad":
            pad_start = -min(0, start_frame)
            pad_end = max(end_frame, frames) - frames

            start_frame = max(0, start_frame)
            end_frame = min(end_frame, frames)

            num_frames = end_frame - start_frame

        if "waveform" in file:
            data = file["waveform"][:, start_frame:end_frame]
        else:
            try:
                data, _ = torchaudio.load(
                    file["audio"],
                    frame_offset=start_frame,
                    num_frames=num_frames,
                    backend=self.backend,
                )
                if isinstance(file["audio"], IOBase):
                    file["audio"].seek(0)
            except RuntimeError:
                if isinstance(file["audio"], IOBase):
                    # Streams cannot be reloaded from scratch below; re-raise
                    # the original error instead of masking it with a blank one.
                    raise

                # Fall back to loading the whole file and slicing in memory.
                # NOTE(review): __call__ may resample to self.sample_rate while
                # start/end frames were computed at the file's native rate —
                # confirm these match whenever resampling is active.
                waveform, sample_rate = self.__call__(file)
                data = waveform[:, start_frame:end_frame]

                # Cache so subsequent crops of this file skip the slow path.
                file["waveform"] = waveform
                file["sample_rate"] = sample_rate

        if channel is not None:
            data = data[channel : channel + 1, :]

        if mode == "pad":
            data = pad(data, (pad_start, pad_end))

        return self.downmix_and_resample(data, sample_rate)