Commit 248f682
Parent: ec63e8e

updated resample pipeline

Files changed:
- preprocessing/dataset.py (+5, -3)
- preprocessing/pipelines.py (+5, -10)
preprocessing/dataset.py (CHANGED)

@@ -29,6 +29,7 @@ class SongDataset(Dataset):
         audio_window_duration=6,  # seconds
         audio_window_jitter=1.0,  # seconds
         audio_durations=None,
+        target_sample_rate=16000,
     ):
         assert (
             audio_window_duration > audio_window_jitter
@@ -54,6 +55,7 @@ class SongDataset(Dataset):
         self.audio_window_duration = int(audio_window_duration)
         self.audio_start_offset = audio_start_offset
         self.audio_window_jitter = audio_window_jitter
+        self.target_sample_rate = target_sample_rate

     def __len__(self):
         return int(
@@ -125,9 +127,9 @@ class SongDataset(Dataset):
         waveform, sample_rate = ta.load(
             audio_filepath, frame_offset=frame_offset, num_frames=num_frames
         )
-        …
-        … sample_rate
-        )
+        waveform = ta.functional.resample(
+            waveform, orig_freq=sample_rate, new_freq=self.target_sample_rate
+        )
         return waveform

     def _label_from_index(self, idx: int) -> torch.Tensor:
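For context, a minimal sketch of what the new loading path does: load a window of audio with torchaudio, then resample it from the file's native rate to the dataset's target_sample_rate. The path and window values below are illustrative placeholders, not values from the repo; the alias `ta` for torchaudio follows the diff, and only torchaudio.load and torchaudio.functional.resample are assumed library APIs.

import torchaudio as ta

target_sample_rate = 16000       # the new SongDataset default
audio_filepath = "example.wav"   # hypothetical file, not from the repo

# Load a slice of the file, as the dataset does via frame_offset/num_frames.
waveform, sample_rate = ta.load(audio_filepath, frame_offset=0, num_frames=96000)

# Resample from the file's native rate to the target rate, mirroring the
# lines this commit adds before `return waveform`.
waveform = ta.functional.resample(
    waveform, orig_freq=sample_rate, new_freq=target_sample_rate
)

Resampling per item means every window the dataset yields arrives at a uniform rate, which is what lets the training pipeline below drop its own resampling step.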
preprocessing/pipelines.py (CHANGED)

@@ -7,21 +7,17 @@ import torch.nn as nn
 class WaveformTrainingPipeline(torch.nn.Module):
     def __init__(
         self,
-        input_freq=16000,
-        resample_freq=16000,
         expected_duration=6,
         snr_mean=6.0,
         noise_path=None,
     ):
         super().__init__()
-        self.input_freq = input_freq
         self.snr_mean = snr_mean
         self.noise = self.get_noise(noise_path)
-        self.resample_freq = resample_freq
-        self.resample = taT.Resample(input_freq, resample_freq)
+        self.sample_rate = 16000

         self.preprocess_waveform = WaveformPreprocessing(
-            resample_freq * expected_duration
+            self.sample_rate * expected_duration
         )

     def get_noise(self, path) -> torch.Tensor:
@@ -30,8 +26,8 @@ class WaveformTrainingPipeline(torch.nn.Module):
         noise, sr = torchaudio.load(path)
         if noise.shape[0] > 1:
             noise = noise.mean(0, keepdim=True)
-        if sr != self.resample_freq:
-            noise = taF.resample(noise, sr, self.resample_freq)
+        if sr != self.sample_rate:
+            noise = taF.resample(noise, sr, self.sample_rate)
         return noise

     def add_noise(self, waveform: torch.Tensor) -> torch.Tensor:
@@ -49,7 +45,6 @@ class WaveformTrainingPipeline(torch.nn.Module):
         return noisy_waveform

     def forward(self, waveform: torch.Tensor) -> torch.Tensor:
-        waveform = self.resample(waveform)
         waveform = self.preprocess_waveform(waveform)
         if self.noise is not None:
             waveform = self.add_noise(waveform)
@@ -63,7 +58,7 @@ class SpectrogramTrainingPipeline(WaveformTrainingPipeline):
         super().__init__(*args, **kwargs)
         self.mask_count = mask_count
         self.audio_to_spectrogram = AudioToSpectrogram(
-            sample_rate=self.resample_freq,
+            sample_rate=self.sample_rate,
         )
         self.freq_mask = taT.FrequencyMasking(freq_mask_size)
         self.time_mask = taT.TimeMasking(time_mask_size)
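Two notes on the pipeline side of this change. First, the removed taT.Resample module and the taF.resample functional form compute the same resampling, so the commit moves the work into the dataset without changing the signal. A small sketch of that equivalence, using only standard torchaudio APIs and illustrative rates:

import torch
import torchaudio.transforms as taT
import torchaudio.functional as taF

orig_freq, new_freq = 44100, 16000
waveform = torch.randn(1, orig_freq * 6)  # 6 s of synthetic mono audio

resampled_module = taT.Resample(orig_freq, new_freq)(waveform)      # old pipeline style
resampled_functional = taF.resample(waveform, orig_freq, new_freq)  # new dataset style

# The transform only precomputes the resampling kernel; outputs should match.
assert torch.allclose(resampled_module, resampled_functional)

Second, as the hunks above stand, self.noise = self.get_noise(noise_path) still runs before self.sample_rate = 16000 is assigned, and get_noise now reads self.sample_rate; constructing the pipeline with a non-None noise_path would therefore raise an AttributeError. Assigning self.sample_rate before the get_noise call would avoid that.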