File size: 6,304 Bytes
6cfcfea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
import math
import os
import random
from collections.abc import Mapping
from io import IOBase
from typing import Optional

import torchaudio
from torch.nn.functional import pad
def get_torchaudio_info(file, backend=None):
    """Return torchaudio metadata (num_frames, sample_rate, ...) for ``file["audio"]``.

    When no backend is given, "soundfile" is preferred if available,
    otherwise the first backend torchaudio reports is used.
    """
    if not backend:
        available = torchaudio.list_audio_backends()
        backend = "soundfile" if "soundfile" in available else available[0]
    audio = file["audio"]
    info = torchaudio.info(audio, backend=backend)
    # Rewind file-like objects so a subsequent load() starts at byte 0.
    if isinstance(audio, IOBase):
        audio.seek(0)
    return info
class Audio:
    """Load, downmix, resample and crop audio.

    Accepts file paths, open binary streams, or mappings carrying either a
    pre-loaded ``"waveform"`` (shape ``(channel, time)`` plus ``"sample_rate"``)
    or an ``"audio"`` path/stream.

    Parameters
    ----------
    sample_rate : int, optional
        Target sample rate; when set, loaded audio is resampled to it.
    mono : {"random", "downmix"}, optional
        Multi-channel reduction strategy: pick one random channel, or
        average all channels.
    backend : str, optional
        torchaudio I/O backend; defaults to "soundfile" when available,
        otherwise the first available backend.
    """

    @staticmethod
    def power_normalize(waveform):
        """Scale ``waveform`` to unit average power (RMS), per channel.

        The epsilon avoids division by zero on all-zero signals.
        """
        rms = waveform.square().mean(dim=-1, keepdim=True).sqrt()
        return waveform / (rms + 1e-8)

    @staticmethod
    def validate_file(file):
        """Normalize ``file`` into a mapping with "waveform" or "audio".

        Raises
        ------
        ValueError
            If ``file`` is not a mapping/path/stream, if a provided waveform
            is not ``(channel, time)`` shaped, if "sample_rate" is missing
            alongside "waveform", or if neither "waveform" nor "audio" is set.
        """
        if isinstance(file, Mapping):
            # Already a dict-like description; validated below.  (Without this
            # branch, dict inputs — which every caller relies on — would raise.)
            pass
        elif isinstance(file, (str, os.PathLike)):
            file = {"audio": str(file), "uri": os.path.splitext(os.path.basename(file))[0]}
        elif isinstance(file, IOBase):
            return {"audio": file, "uri": "stream"}
        else:
            raise ValueError("file must be a mapping, a str/os.PathLike path, or a binary stream")
        if "waveform" in file:
            waveform = file["waveform"]
            # Expect (channel, time) with far fewer channels than samples.
            if len(waveform.shape) != 2 or waveform.shape[0] > waveform.shape[1]:
                raise ValueError('"waveform" must have shape (channel, time)')
            sample_rate: int = file.get("sample_rate", None)
            if sample_rate is None:
                raise ValueError('"sample_rate" is mandatory when "waveform" is provided')
            file.setdefault("uri", "waveform")
        elif "audio" in file:
            if isinstance(file["audio"], IOBase):
                return file
            path = os.path.abspath(file["audio"])
            # Default the uri to the file stem (e.g. "/a/b.wav" -> "b").
            file.setdefault("uri", os.path.splitext(os.path.basename(path))[0])
        else:
            raise ValueError('file must provide either "waveform" or "audio"')
        return file

    def __init__(self, sample_rate: Optional[int] = None, mono=None, backend: Optional[str] = None):
        super().__init__()
        self.sample_rate = sample_rate
        self.mono = mono
        if not backend:
            # Same backend-selection policy as get_torchaudio_info().
            available = torchaudio.list_audio_backends()
            backend = "soundfile" if "soundfile" in available else available[0]
        self.backend = backend

    def downmix_and_resample(self, waveform, sample_rate):
        """Apply the configured mono reduction and resampling.

        Returns the (possibly modified) ``(waveform, sample_rate)`` pair.
        """
        num_channels = waveform.shape[0]
        if num_channels > 1:
            if self.mono == "random":
                # Slice (not index) so the channel dimension is kept.
                channel = random.randint(0, num_channels - 1)
                waveform = waveform[channel : channel + 1]
            elif self.mono == "downmix":
                waveform = waveform.mean(dim=0, keepdim=True)
        if (self.sample_rate is not None) and (self.sample_rate != sample_rate):
            waveform = torchaudio.functional.resample(waveform, sample_rate, self.sample_rate)
            sample_rate = self.sample_rate
        return waveform, sample_rate

    def get_duration(self, file) -> float:
        """Return the duration of ``file`` in seconds."""
        file = self.validate_file(file)
        if "waveform" in file:
            # waveform is (channel, time): len of the transpose is the frame count.
            frames = len(file["waveform"].T)
            sample_rate = file["sample_rate"]
        else:
            # Reuse cached metadata when present to avoid re-reading the header.
            info = file["torchaudio.info"] if "torchaudio.info" in file else get_torchaudio_info(file, backend=self.backend)
            frames = info.num_frames
            sample_rate = info.sample_rate
        return frames / sample_rate

    def get_num_samples(self, duration: float, sample_rate: Optional[int] = None) -> int:
        """Return the number of samples in ``duration`` seconds.

        Raises
        ------
        ValueError
            If no sample rate is given and none was configured.
        """
        sample_rate = sample_rate or self.sample_rate
        if sample_rate is None:
            raise ValueError("sample_rate must be provided, either here or at construction time")
        return math.floor(duration * sample_rate)

    def __call__(self, file):
        """Load the whole file and return ``(waveform, sample_rate)``.

        Channel selection (``file["channel"]``), downmixing and resampling
        are applied before returning.
        """
        file = self.validate_file(file)
        if "waveform" in file:
            waveform = file["waveform"]
            sample_rate = file["sample_rate"]
        elif "audio" in file:
            waveform, sample_rate = torchaudio.load(file["audio"], backend=self.backend)
            # Rewind streams so they can be read again later.
            if isinstance(file["audio"], IOBase):
                file["audio"].seek(0)
        channel = file.get("channel", None)
        if channel is not None:
            # Slice (not index) to keep the (1, time) shape.
            waveform = waveform[channel : channel + 1]
        return self.downmix_and_resample(waveform, sample_rate)

    def crop(self, file, segment, duration: Optional[float] = None, mode: str = "raise"):
        """Extract the excerpt ``[segment.start, segment.end]`` (seconds).

        When ``duration`` is given, the excerpt is ``duration`` seconds long
        starting at ``segment.start``.  ``mode="raise"`` rejects out-of-bounds
        requests (with ~1 ms tolerance past the end, in which case the window
        is shifted back to fit); ``mode="pad"`` zero-pads instead.

        Returns ``(waveform, sample_rate)`` after downmix/resample.
        """
        file = self.validate_file(file)
        if "waveform" in file:
            waveform = file["waveform"]
            frames = waveform.shape[1]
            sample_rate = file["sample_rate"]
        elif "torchaudio.info" in file:
            # Reuse cached metadata to avoid re-reading the file header.
            info = file["torchaudio.info"]
            frames = info.num_frames
            sample_rate = info.sample_rate
        else:
            info = get_torchaudio_info(file, backend=self.backend)
            frames = info.num_frames
            sample_rate = info.sample_rate
        channel = file.get("channel", None)
        start_frame = math.floor(segment.start * sample_rate)
        if duration:
            num_frames = math.floor(duration * sample_rate)
            end_frame = start_frame + num_frames
        else:
            end_frame = math.floor(segment.end * sample_rate)
            num_frames = end_frame - start_frame
        if mode == "raise":
            if num_frames > frames:
                raise ValueError("requested excerpt is longer than the file")
            if end_frame > frames + math.ceil(0.001 * sample_rate):
                # More than ~1 ms past the end of file: genuine error.
                raise ValueError("requested excerpt ends after the end of the file")
            else:
                # Within tolerance: shift the window back so it fits.
                end_frame = min(end_frame, frames)
                start_frame = end_frame - num_frames
                if start_frame < 0:
                    raise ValueError("requested excerpt starts before the start of the file")
        elif mode == "pad":
            pad_start = -min(0, start_frame)
            pad_end = max(end_frame, frames) - frames
            start_frame = max(0, start_frame)
            end_frame = min(end_frame, frames)
            num_frames = end_frame - start_frame
        if "waveform" in file:
            data = file["waveform"][:, start_frame:end_frame]
        else:
            try:
                data, _ = torchaudio.load(file["audio"], frame_offset=start_frame, num_frames=num_frames, backend=self.backend)
                if isinstance(file["audio"], IOBase):
                    file["audio"].seek(0)
            except RuntimeError:
                if isinstance(file["audio"], IOBase):
                    # Streams cannot fall back to full decoding; re-raise the
                    # original error (bare raise keeps the traceback).
                    raise
                # Some formats do not support random access: decode the whole
                # file once and cache it for later crops.
                # NOTE(review): __call__ may resample, while start/end frames
                # were computed at the original rate — confirm upstream usage.
                waveform, sample_rate = self.__call__(file)
                data = waveform[:, start_frame:end_frame]
                file["waveform"] = waveform
                file["sample_rate"] = sample_rate
        if channel is not None:
            data = data[channel : channel + 1, :]
        if mode == "pad":
            data = pad(data, (pad_start, pad_end))
        return self.downmix_and_resample(data, sample_rate)