Spaces:
Running
Running
import os | |
import json | |
import numpy as np | |
from typing import Any, Tuple | |
import soundfile as sf | |
import torch | |
from pytorch_lightning import LightningDataModule | |
# from pytorch_lightning.core.mixins import HyperparametersMixin | |
from torch.utils.data import ConcatDataset, DataLoader, Dataset | |
from typing import Dict, Iterable, List, Iterator | |
from rich import print | |
from pytorch_lightning.utilities import rank_zero_only | |
def print_(message: str) -> None:
    """Write ``message`` using whichever ``print`` is bound at module level.

    This file imports ``print`` from ``rich``, so that is the callable used
    when the module is loaded as a whole.
    """
    print(message)
    return None
def normalize_tensor_wav(wav_tensor, eps=1e-8, std=None):
    """Centre ``wav_tensor`` along its last axis and divide by a scale.

    Args:
        wav_tensor: Tensor to normalise; statistics are taken over dim ``-1``.
        eps: Small constant added to the scale to avoid division by zero.
        std: Optional externally supplied scale. When ``None`` the standard
            deviation of ``wav_tensor`` over its last axis is used.

    Returns:
        ``(wav_tensor - mean) / (std + eps)``.
    """
    centred = wav_tensor - wav_tensor.mean(-1, keepdim=True)
    scale = wav_tensor.std(-1, keepdim=True) if std is None else std
    return centred / (scale + eps)
class MP3DDataset(Dataset):
    """Mixture / source audio dataset driven by JSON index files.

    ``json_dir`` must contain ``mix.json``, ``s1.json`` and ``s2.json``. Each
    file is read with ``json.load`` and indexed as ``infos[i][0]`` (the path
    passed to ``soundfile.read``) and ``infos[i][1]`` (a length that is
    compared against the segment length in samples).

    Modes:
        * ``segment`` is a number: training mode. Utterances shorter than
          ``int(segment * sample_rate)`` samples are dropped and every item
          is a random crop of exactly that many samples.
        * ``segment`` is ``None``: test mode. Nothing is dropped and whole
          files are read.

    Items:
        * ``n_src == 1``: every mixture is listed once per source and an item
          is ``(mixture, source.unsqueeze(0), mixture_path)``.
        * ``n_src == 2``: an item is ``(mixture, sources, mixture_path)``
          where ``sources`` is ``np.vstack`` of both source signals.
          (Assumes mono files so that each read is 1-D -- TODO confirm.)

    Args:
        json_dir: Directory holding the three JSON index files.
        n_src: Number of sources per item, ``1`` or ``2``.
        sample_rate: Sample rate used to convert ``segment`` to samples and
            to report the dropped duration.
        segment: Crop length in seconds, or ``None`` for test mode.
        normalize_audio: When true, mixture and sources are each centred on
            their own mean and divided by the mixture's standard deviation.

    Raises:
        ValueError: If ``json_dir`` is ``None`` or ``n_src`` is not 1 or 2.
    """

    # Base names of the per-source index files inside ``json_dir``.
    _SOURCE_NAMES = ("s1", "s2")

    def __init__(
        self,
        json_dir: str = "",
        n_src: int = 2,
        sample_rate: int = 8000,
        segment: float = 4.0,
        normalize_audio: bool = False,
    ) -> None:
        super().__init__()
        self.EPS = 1e-8
        if json_dir is None:
            raise ValueError("JSON DIR is None!")
        if n_src not in [1, 2]:
            raise ValueError("{} is not in [1, 2]".format(n_src))
        self.json_dir = json_dir
        self.sample_rate = sample_rate
        self.normalize_audio = normalize_audio
        if segment is None:
            self.seg_len = None
            # Not read anywhere in this class; kept so existing attribute
            # access on test-mode instances keeps working.
            self.fps_len = None
        else:
            self.seg_len = int(segment * sample_rate)
        self.n_src = n_src
        self.test = self.seg_len is None

        mix_infos = self._load_json(os.path.join(json_dir, "mix.json"))
        sources_infos = [
            self._load_json(os.path.join(json_dir, name + ".json"))
            for name in self._SOURCE_NAMES
        ]

        n_mix = len(mix_infos)
        drop_utt, drop_len = 0, 0
        if not self.test:
            drop_utt, drop_len = self._drop_short(
                mix_infos, sources_infos, self.seg_len
            )

        if self.n_src == 1:
            # Reported total counts one entry per (mixture, source) pair,
            # while ``drop_utt`` counts mixtures (unchanged from before).
            orig_len = n_mix * 2
            order = range(len(mix_infos))
            if not self.test:
                # Training mode has always listed the kept utterances from
                # last to first; keep that index -> utterance mapping.
                order = reversed(order)
            self.mix = []
            self.sources = []
            for i in order:
                for src_inf in sources_infos:
                    self.mix.append(mix_infos[i])
                    self.sources.append(src_inf[i])
        else:
            orig_len = n_mix
            self.mix = mix_infos
            self.sources = sources_infos

        print_(
            self._drop_message(
                drop_utt, drop_len, sample_rate, orig_len, self.seg_len
            )
        )
        self.length = len(self.mix)

    @staticmethod
    def _load_json(path):
        """Return the parsed content of the JSON file at ``path``."""
        with open(path, "r") as f:
            return json.load(f)

    @staticmethod
    def _drop_short(mix_infos, sources_infos, seg_len):
        """Delete, in place, every utterance shorter than ``seg_len`` samples.

        Returns ``(number_dropped, total_dropped_samples)``.
        """
        drop_utt, drop_len = 0, 0
        # Walk backwards so deleting index ``i`` never shifts an index that
        # still has to be visited.
        for i in range(len(mix_infos) - 1, -1, -1):
            if mix_infos[i][1] < seg_len:
                drop_utt += 1
                drop_len += mix_infos[i][1]
                del mix_infos[i]
                for src_inf in sources_infos:
                    del src_inf[i]
        return drop_utt, drop_len

    @staticmethod
    def _drop_message(drop_utt, drop_len, sample_rate, orig_len, seg_len):
        """Format the summary of dropped utterances (duration in hours)."""
        # samples / sample_rate = seconds; seconds / 3600 = hours.
        return "Drop {} utts({:.2f} h) from {} (shorter than {} samples)".format(
            drop_utt, drop_len / sample_rate / 3600, orig_len, seg_len
        )

    def _segment_bounds(self, idx: int):
        """Return the ``(start, stop)`` sample window to read for item ``idx``.

        Test mode reads the whole file (``stop`` is ``None``); otherwise a
        window of ``seg_len`` samples is drawn at a random offset.
        """
        if self.test:
            return 0, None
        n_samples = self.mix[idx][1]
        if n_samples == self.seg_len:
            start = 0
        else:
            start = np.random.randint(0, n_samples - self.seg_len)
        return start, start + self.seg_len

    def __len__(self):
        return self.length

    def preprocess_audio_only(self, idx: int):
        """Load item ``idx`` as ``(mixture, sources, mixture_path)``.

        The same sample window is read from the mixture and from its
        source(s). With ``normalize_audio`` set, each tensor is centred on
        its own mean and scaled by the mixture's standard deviation.
        """
        start, stop = self._segment_bounds(idx)
        mix_path = self.mix[idx][0]
        # Load mixture
        x, _ = sf.read(mix_path, start=start, stop=stop, dtype="float32")
        # Load sources
        if self.n_src == 1:
            # ``self.sources`` is a flat list aligned with ``self.mix``.
            s, _ = sf.read(
                self.sources[idx][0], start=start, stop=stop, dtype="float32"
            )
            sources = torch.from_numpy(s)
        else:
            # ``self.sources`` holds one index list per source.
            source_arrays = []
            for src in self.sources:
                s, _ = sf.read(
                    src[idx][0], start=start, stop=stop, dtype="float32"
                )
                source_arrays.append(s)
            sources = torch.from_numpy(np.vstack(source_arrays))
        mixture = torch.from_numpy(x)
        if self.normalize_audio:
            m_std = mixture.std(-1, keepdim=True)
            mixture = normalize_tensor_wav(mixture, eps=self.EPS, std=m_std)
            sources = normalize_tensor_wav(sources, eps=self.EPS, std=m_std)
        if self.n_src == 1:
            # Add the leading source axis after normalisation, as before.
            sources = sources.unsqueeze(0)
        return mixture, sources, mix_path

    def __getitem__(self, index: int):
        return self.preprocess_audio_only(index)
class EchoSetDataModule(object):
    """Creates the train / validation / test datasets and their loaders.

    ``setup()`` must be called before any ``*_dataloader()`` method,
    ``make_loader()`` or ``make_sets()`` returns something usable: until
    then ``data_train``, ``data_val`` and ``data_test`` are ``None``.

    Args:
        train_dir: Directory with the training JSON index files.
        valid_dir: Directory with the validation JSON index files.
        test_dir: Directory with the test JSON index files.
        n_src: Number of sources per item, ``1`` or ``2``.
        sample_rate: Sample rate forwarded to every dataset.
        segment: Crop length in seconds for the training set only; the
            validation and test sets are always built with ``segment=None``.
        normalize_audio: Forwarded to every dataset.
        batch_size: Batch size used by all three loaders.
        num_workers: Worker process count used by all three loaders.
        pin_memory: Forwarded to every ``DataLoader``.
        persistent_workers: Forwarded to every ``DataLoader`` when
            ``num_workers > 0``; ignored otherwise, because ``DataLoader``
            rejects persistent workers without worker processes.

    Raises:
        ValueError: If any directory is ``None`` or ``n_src`` is not 1 or 2.
    """

    def __init__(
        self,
        train_dir: str,
        valid_dir: str,
        test_dir: str,
        n_src: int = 2,
        sample_rate: int = 8000,
        segment: float = 4.0,
        normalize_audio: bool = False,
        batch_size: int = 64,
        num_workers: int = 0,
        pin_memory: bool = False,
        persistent_workers: bool = False,
    ) -> None:
        super().__init__()
        if train_dir is None or valid_dir is None or test_dir is None:
            raise ValueError("JSON DIR is None!")
        if n_src not in [1, 2]:
            raise ValueError("{} is not in [1, 2]".format(n_src))

        self.train_dir = train_dir
        self.valid_dir = valid_dir
        self.test_dir = test_dir
        self.n_src = n_src
        self.sample_rate = sample_rate
        self.segment = segment
        self.normalize_audio = normalize_audio
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.persistent_workers = persistent_workers

        # Populated by setup(); None until then.
        self.data_train: Dataset = None
        self.data_val: Dataset = None
        self.data_test: Dataset = None

    def _make_dataset(self, json_dir, segment):
        """Build one ``MP3DDataset`` that shares this module's settings."""
        return MP3DDataset(
            json_dir=json_dir,
            n_src=self.n_src,
            sample_rate=self.sample_rate,
            segment=segment,
            normalize_audio=self.normalize_audio,
        )

    def setup(self) -> None:
        """Instantiate the three datasets.

        Only the training set is cropped to ``self.segment`` seconds; the
        validation and test sets are built with ``segment=None``.
        """
        self.data_train = self._make_dataset(self.train_dir, self.segment)
        self.data_val = self._make_dataset(self.valid_dir, None)
        self.data_test = self._make_dataset(self.test_dir, None)

    def _make_loader(self, dataset, shuffle: bool, drop_last: bool) -> DataLoader:
        """Wrap ``dataset`` in a ``DataLoader`` using the shared settings."""
        # DataLoader raises ValueError for persistent_workers=True when
        # num_workers == 0, so only request it when workers exist.
        persistent = bool(self.persistent_workers and self.num_workers > 0)
        return DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            persistent_workers=persistent,
            pin_memory=self.pin_memory,
            drop_last=drop_last,
        )

    def train_dataloader(self) -> DataLoader:
        """Shuffled loader over the training set; the last partial batch is dropped."""
        return self._make_loader(self.data_train, shuffle=True, drop_last=True)

    def val_dataloader(self) -> DataLoader:
        """Sequential loader over the validation set; keeps the last partial batch."""
        return self._make_loader(self.data_val, shuffle=False, drop_last=False)

    def test_dataloader(self) -> DataLoader:
        """Sequential loader over the test set; keeps the last partial batch."""
        return self._make_loader(self.data_test, shuffle=False, drop_last=False)

    def make_loader(self):
        """Return ``(train_loader, val_loader, test_loader)``."""
        return self.train_dataloader(), self.val_dataloader(), self.test_dataloader()

    def make_sets(self):
        """Return ``(data_train, data_val, data_test)``."""
        return self.data_train, self.data_val, self.data_test