|
import os |
|
import math |
|
import random |
|
import torchaudio |
|
|
|
from io import IOBase |
|
from torch.nn.functional import pad |
|
|
|
def get_torchaudio_info(file, backend = None):
    """Return the ``torchaudio.info`` metadata of ``file["audio"]``.

    Args:
        file: Mapping whose "audio" entry is passed to ``torchaudio.info``.
        backend: torchaudio backend name. When falsy, "soundfile" is used if
            torchaudio lists it, otherwise the first listed backend.

    Returns:
        Whatever ``torchaudio.info`` returns for the audio source.
    """
    if not backend:
        available = torchaudio.list_audio_backends()
        backend = "soundfile" if "soundfile" in available else available[0]

    source = file["audio"]
    metadata = torchaudio.info(source, backend=backend)

    # File-like sources are rewound so that a later read starts from the beginning.
    if isinstance(source, IOBase):
        source.seek(0)

    return metadata
|
|
|
class Audio:
    """Validate, load, crop, downmix and resample audio.

    Audio may be given as a path, an open file-like object, or a mutable
    mapping holding either an "audio" entry (path or file-like object) or an
    in-memory "waveform" tensor together with its "sample_rate".
    """

    @staticmethod
    def power_normalize(waveform):
        """Divide ``waveform`` by its root-mean-square over the last dimension.

        1e-8 is added to the RMS so that an all-zero input does not divide by zero.
        """
        rms = waveform.square().mean(dim=-1, keepdim=True).sqrt()
        return waveform / (rms + 1e-8)

    @staticmethod
    def validate_file(file):
        """Normalize ``file`` into a mapping and check its content.

        Args:
            file: A ``str``/``os.PathLike`` path, an ``IOBase`` stream, or a
                mutable mapping with an "audio" or a "waveform" key.

        Returns:
            A mapping with at least "uri" and one of "audio" / "waveform".
            Mapping inputs are returned as-is (a missing "uri" is filled in place).

        Raises:
            ValueError: if ``file`` has an unsupported type, has neither
                "audio" nor "waveform", or carries a malformed waveform.
        """
        # Function-scope import so this block does not depend on the file's import section.
        from collections.abc import MutableMapping

        if isinstance(file, MutableMapping):
            # Already in mapping form: validated below.
            pass
        elif isinstance(file, (str, os.PathLike)):
            file = {"audio": str(file), "uri": os.path.splitext(os.path.basename(file))[0]}
        elif isinstance(file, IOBase):
            return {"audio": file, "uri": "stream"}
        else:
            raise ValueError(
                "Audio files must be provided as a path, a file-like object, "
                "or a mapping with an 'audio' or 'waveform' key."
            )

        if "waveform" in file:
            waveform = file["waveform"]

            # Must be 2D with the first (channel) dimension no larger than the second.
            if len(waveform.shape) != 2 or waveform.shape[0] > waveform.shape[1]:
                raise ValueError("'waveform' must be provided as a (channel, time) tensor.")

            sample_rate: int = file.get("sample_rate", None)
            if sample_rate is None:
                raise ValueError("'waveform' must be provided together with its 'sample_rate'.")

            file.setdefault("uri", "waveform")

        elif "audio" in file:
            # Streams have no path to derive a "uri" from.
            if isinstance(file["audio"], IOBase):
                return file

            path = os.path.abspath(file["audio"])
            file.setdefault("uri", os.path.splitext(os.path.basename(path))[0])

        else:
            raise ValueError("Neither 'waveform' nor 'audio' is available for this file.")

        return file

    def __init__(self, sample_rate: int = None, mono=None, backend: str = None):
        """Configure target sample rate, downmixing strategy and torchaudio backend.

        Args:
            sample_rate: Target sample rate; ``None`` disables resampling.
            mono: "random" keeps one randomly chosen channel, "downmix"
                averages channels; any other value leaves channels untouched.
            backend: torchaudio backend name. When falsy, "soundfile" is used
                if torchaudio lists it, otherwise the first listed backend.
        """
        super().__init__()
        self.sample_rate = sample_rate
        self.mono = mono

        if not backend:
            backends = torchaudio.list_audio_backends()
            backend = "soundfile" if "soundfile" in backends else backends[0]

        self.backend = backend

    def downmix_and_resample(self, waveform, sample_rate):
        """Apply the configured channel reduction and resampling.

        Args:
            waveform: Tensor whose first dimension indexes channels.
            sample_rate: Sample rate of ``waveform``.

        Returns:
            ``(waveform, sample_rate)`` after processing.
        """
        num_channels = waveform.shape[0]

        if num_channels > 1:
            if self.mono == "random":
                channel = random.randint(0, num_channels - 1)
                waveform = waveform[channel : channel + 1]
            elif self.mono == "downmix":
                waveform = waveform.mean(dim=0, keepdim=True)

        if (self.sample_rate is not None) and (self.sample_rate != sample_rate):
            waveform = torchaudio.functional.resample(waveform, sample_rate, self.sample_rate)
            sample_rate = self.sample_rate

        return waveform, sample_rate

    def get_duration(self, file):
        """Return the number of frames of ``file`` divided by its sample rate."""
        file = self.validate_file(file)

        if "waveform" in file:
            frames = file["waveform"].shape[1]
            sample_rate = file["sample_rate"]
        else:
            # Reuse precomputed metadata when the caller provides it.
            info = file["torchaudio.info"] if "torchaudio.info" in file else get_torchaudio_info(file, backend=self.backend)
            frames = info.num_frames
            sample_rate = info.sample_rate

        return frames / sample_rate

    def get_num_samples(self, duration, sample_rate = None):
        """Return ``floor(duration * sample_rate)``.

        ``sample_rate`` falls back to ``self.sample_rate`` when falsy.

        Raises:
            ValueError: if no sample rate is available.
        """
        sample_rate = sample_rate or self.sample_rate
        if sample_rate is None:
            raise ValueError("`sample_rate` must be provided to compute the number of samples.")

        return math.floor(duration * sample_rate)

    def __call__(self, file):
        """Load the whole ``file`` and return ``(waveform, sample_rate)``.

        The optional "channel" entry of ``file`` selects a single channel
        before the configured downmixing/resampling is applied.
        """
        file = self.validate_file(file)

        if "waveform" in file:
            waveform = file["waveform"]
            sample_rate = file["sample_rate"]
        else:
            # validate_file guarantees "audio" is present when "waveform" is not.
            waveform, sample_rate = torchaudio.load(file["audio"], backend=self.backend)
            if isinstance(file["audio"], IOBase):
                file["audio"].seek(0)

        channel = file.get("channel", None)
        if channel is not None:
            waveform = waveform[channel : channel + 1]

        return self.downmix_and_resample(waveform, sample_rate)

    def crop(self, file, segment, duration = None, mode="raise"):
        """Extract the part of ``file`` covered by ``segment``.

        Args:
            file: Anything accepted by ``validate_file``.
            segment: Object with ``start`` and ``end`` attributes; they are
                multiplied by the sample rate to obtain frame indices.
            duration: When truthy, overrides the segment length: the crop
                spans ``floor(duration * sample_rate)`` frames from the start.
            mode: "raise" rejects out-of-bounds requests (with a tolerance of
                0.001 * sample_rate frames past the end); "pad" zero-pads the
                out-of-bounds part. Other values apply no bounds handling.

        Returns:
            ``(waveform, sample_rate)`` after downmixing/resampling.

        Raises:
            ValueError: in "raise" mode, if the requested chunk does not fit.
            RuntimeError: if seek-and-read fails on a file-like object.
        """
        file = self.validate_file(file)

        if "waveform" in file:
            frames = file["waveform"].shape[1]
            sample_rate = file["sample_rate"]
        else:
            # Reuse precomputed metadata when the caller provides it.
            info = file["torchaudio.info"] if "torchaudio.info" in file else get_torchaudio_info(file, backend=self.backend)
            frames = info.num_frames
            sample_rate = info.sample_rate

        channel = file.get("channel", None)
        start_frame = math.floor(segment.start * sample_rate)

        if duration:
            num_frames = math.floor(duration * sample_rate)
            end_frame = start_frame + num_frames
        else:
            end_frame = math.floor(segment.end * sample_rate)
            num_frames = end_frame - start_frame

        if mode == "raise":
            if num_frames > frames:
                raise ValueError(
                    f"requested chunk ({num_frames} frames) is longer than the file ({frames} frames)."
                )

            if end_frame > frames + math.ceil(0.001 * sample_rate):
                raise ValueError(
                    f"requested chunk ends at frame {end_frame}, past the end of the file ({frames} frames)."
                )

            # Within tolerance: shift the window back so it ends inside the file.
            end_frame = min(end_frame, frames)
            start_frame = end_frame - num_frames

            if start_frame < 0:
                raise ValueError(f"requested chunk starts before the file (frame {start_frame}).")

        elif mode == "pad":
            pad_start = -min(0, start_frame)
            pad_end = max(end_frame, frames) - frames

            start_frame = max(0, start_frame)
            end_frame = min(end_frame, frames)
            num_frames = end_frame - start_frame

        if "waveform" in file:
            data = file["waveform"][:, start_frame:end_frame]
        else:
            try:
                data, _ = torchaudio.load(file["audio"], frame_offset=start_frame, num_frames=num_frames, backend=self.backend)
                if isinstance(file["audio"], IOBase):
                    file["audio"].seek(0)
            except RuntimeError as error:
                if isinstance(file["audio"], IOBase):
                    raise RuntimeError("torchaudio failed to seek-and-read in file-like object.") from error

                # Seek-and-read failed: load the whole file *unprocessed*.
                # start_frame/end_frame are expressed at the file's native
                # sample rate and channel selection happens below, so the
                # waveform must not be resampled, downmixed or channel-selected
                # before slicing.
                waveform, sample_rate = torchaudio.load(file["audio"], backend=self.backend)
                data = waveform[:, start_frame:end_frame]

                # Cache the raw waveform on the mapping for subsequent calls.
                file["waveform"] = waveform
                file["sample_rate"] = sample_rate

        if channel is not None:
            data = data[channel : channel + 1, :]

        if mode == "pad":
            data = pad(data, (pad_start, pad_end))

        return self.downmix_and_resample(data, sample_rate)