|  | import glob | 
					
						
						|  | import os | 
					
						
						|  | import random | 
					
						
						|  |  | 
					
						
						|  | import librosa | 
					
						
						|  | import numpy as np | 
					
						
						|  | import soundfile as sf | 
					
						
						|  | import torch | 
					
						
						|  | from numpy.random import default_rng | 
					
						
						|  | from pydtmc import MarkovChain | 
					
						
						|  | from sklearn.model_selection import train_test_split | 
					
						
						|  | from torch.utils.data import Dataset | 
					
						
						|  |  | 
					
						
						|  | from config import CONFIG | 
					
						
						|  |  | 
					
						
# Seed NumPy's legacy global RNG so draws made through np.random.* functions
# are reproducible from run to run.
np.random.seed(0)
# NOTE(review): default_rng() is created without a seed and does not share state
# with np.random.seed above, so draws from `rng` are not reproducible. It is not
# used anywhere in the code visible here; confirm whether other modules rely on
# it before seeding it.
rng = default_rng()
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def load_audio( | 
					
						
						|  | path, | 
					
						
						|  | sample_rate: int = 16000, | 
					
						
						|  | chunk_len=None, | 
					
						
						|  | ): | 
					
						
						|  | with sf.SoundFile(path) as f: | 
					
						
						|  | sr = f.samplerate | 
					
						
						|  | audio_len = f.frames | 
					
						
						|  |  | 
					
						
						|  | if chunk_len is not None and chunk_len < audio_len: | 
					
						
						|  | start_index = torch.randint(0, audio_len - chunk_len, (1,))[0] | 
					
						
						|  |  | 
					
						
						|  | frames = f._prepare_read(start_index, start_index + chunk_len, -1) | 
					
						
						|  | audio = f.read(frames, always_2d=True, dtype="float32") | 
					
						
						|  |  | 
					
						
						|  | else: | 
					
						
						|  | audio = f.read(always_2d=True, dtype="float32") | 
					
						
						|  |  | 
					
						
						|  | if sr != sample_rate: | 
					
						
						|  | audio = librosa.resample(np.squeeze(audio), sr, sample_rate)[:, np.newaxis] | 
					
						
						|  |  | 
					
						
						|  | return audio.T | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def pad(sig, length): | 
					
						
						|  | if sig.shape[1] < length: | 
					
						
						|  | pad_len = length - sig.shape[1] | 
					
						
						|  | sig = torch.hstack((sig, torch.zeros((sig.shape[0], pad_len)))) | 
					
						
						|  |  | 
					
						
						|  | else: | 
					
						
						|  | start = random.randint(0, sig.shape[1] - length) | 
					
						
						|  | sig = sig[:, start:start + length] | 
					
						
						|  | return sig | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class MaskGenerator: | 
					
						
						|  | def __init__(self, is_train=True, probs=((0.9, 0.1), (0.5, 0.1), (0.5, 0.5))): | 
					
						
						|  | ''' | 
					
						
						|  | is_train: if True, mask generator for training otherwise for evaluation | 
					
						
						|  | probs: a list of transition probability (p_N, p_L) for Markov Chain. Only allow 1 tuple if 'is_train=False' | 
					
						
						|  | ''' | 
					
						
						|  | self.is_train = is_train | 
					
						
						|  | self.probs = probs | 
					
						
						|  | self.mcs = [] | 
					
						
						|  | if self.is_train: | 
					
						
						|  | for prob in probs: | 
					
						
						|  | self.mcs.append(MarkovChain([[prob[0], 1 - prob[0]], [1 - prob[1], prob[1]]], ['1', '0'])) | 
					
						
						|  | else: | 
					
						
						|  | assert len(probs) == 1 | 
					
						
						|  | prob = self.probs[0] | 
					
						
						|  | self.mcs.append(MarkovChain([[prob[0], 1 - prob[0]], [1 - prob[1], prob[1]]], ['1', '0'])) | 
					
						
						|  |  | 
					
						
						|  | def gen_mask(self, length, seed=0): | 
					
						
						|  | if self.is_train: | 
					
						
						|  | mc = random.choice(self.mcs) | 
					
						
						|  | else: | 
					
						
						|  | mc = self.mcs[0] | 
					
						
						|  | mask = mc.walk(length - 1, seed=seed) | 
					
						
						|  | mask = np.array(list(map(int, mask))) | 
					
						
						|  | return mask | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class TestLoader(Dataset): | 
					
						
						|  | def __init__(self): | 
					
						
						|  | dataset_name = CONFIG.DATA.dataset | 
					
						
						|  | self.mask = CONFIG.DATA.EVAL.masking | 
					
						
						|  |  | 
					
						
						|  | self.target_root = CONFIG.DATA.data_dir[dataset_name]['root'] | 
					
						
						|  | txt_list = CONFIG.DATA.data_dir[dataset_name]['test'] | 
					
						
						|  | self.data_list = self.load_txt(txt_list) | 
					
						
						|  | if self.mask == 'real': | 
					
						
						|  | trace_txt = glob.glob(os.path.join(CONFIG.DATA.EVAL.trace_path, '*.txt')) | 
					
						
						|  | trace_txt.sort() | 
					
						
						|  | self.trace_list = [1 - np.array(list(map(int, open(txt, 'r').read().strip('\n').split('\n')))) for txt in | 
					
						
						|  | trace_txt] | 
					
						
						|  | else: | 
					
						
						|  | self.mask_generator = MaskGenerator(is_train=False, probs=CONFIG.DATA.EVAL.transition_probs) | 
					
						
						|  |  | 
					
						
						|  | self.sr = CONFIG.DATA.sr | 
					
						
						|  | self.stride = CONFIG.DATA.stride | 
					
						
						|  | self.window_size = CONFIG.DATA.window_size | 
					
						
						|  | self.audio_chunk_len = CONFIG.DATA.audio_chunk_len | 
					
						
						|  | self.p_size = CONFIG.DATA.EVAL.packet_size | 
					
						
						|  | self.hann = torch.sqrt(torch.hann_window(self.window_size)) | 
					
						
						|  |  | 
					
						
						|  | def __len__(self): | 
					
						
						|  | return len(self.data_list) | 
					
						
						|  |  | 
					
						
						|  | def load_txt(self, txt_list): | 
					
						
						|  | target = [] | 
					
						
						|  | with open(txt_list) as f: | 
					
						
						|  | for line in f: | 
					
						
						|  | target.append(os.path.join(self.target_root, line.strip('\n'))) | 
					
						
						|  | target = list(set(target)) | 
					
						
						|  | target.sort() | 
					
						
						|  | return target | 
					
						
						|  |  | 
					
						
						|  | def __getitem__(self, index): | 
					
						
						|  | target = load_audio(self.data_list[index], sample_rate=self.sr) | 
					
						
						|  | target = target[:, :(target.shape[1] // self.p_size) * self.p_size] | 
					
						
						|  |  | 
					
						
						|  | sig = np.reshape(target, (-1, self.p_size)).copy() | 
					
						
						|  | if self.mask == 'real': | 
					
						
						|  | mask = self.trace_list[index % len(self.trace_list)] | 
					
						
						|  | mask = np.repeat(mask, np.ceil(len(sig) / len(mask)), 0)[:len(sig)][:, np.newaxis] | 
					
						
						|  | else: | 
					
						
						|  | mask = self.mask_generator.gen_mask(len(sig), seed=index)[:, np.newaxis] | 
					
						
						|  | sig *= mask | 
					
						
						|  | sig = torch.tensor(sig).reshape(-1) | 
					
						
						|  |  | 
					
						
						|  | target = torch.tensor(target).squeeze(0) | 
					
						
						|  |  | 
					
						
						|  | sig_wav = sig.clone() | 
					
						
						|  | target_wav = target.clone() | 
					
						
						|  |  | 
					
						
						|  | target = torch.stft(target, self.window_size, self.stride, window=self.hann, | 
					
						
						|  | return_complex=False).permute(2, 0, 1) | 
					
						
						|  | sig = torch.stft(sig, self.window_size, self.stride, window=self.hann, return_complex=False).permute(2, 0, 1) | 
					
						
						|  | return sig.float(), target.float(), sig_wav, target_wav | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class BlindTestLoader(Dataset): | 
					
						
						|  | def __init__(self, test_dir): | 
					
						
						|  | self.data_list = glob.glob(os.path.join(test_dir, '*.wav')) | 
					
						
						|  | self.sr = CONFIG.DATA.sr | 
					
						
						|  | self.stride = CONFIG.DATA.stride | 
					
						
						|  | self.chunk_len = CONFIG.DATA.window_size | 
					
						
						|  | self.hann = torch.sqrt(torch.hann_window(self.chunk_len)) | 
					
						
						|  |  | 
					
						
						|  | def __len__(self): | 
					
						
						|  | return len(self.data_list) | 
					
						
						|  |  | 
					
						
						|  | def __getitem__(self, index): | 
					
						
						|  | sig = load_audio(self.data_list[index], sample_rate=self.sr) | 
					
						
						|  | sig = torch.from_numpy(sig).squeeze(0) | 
					
						
						|  | sig = torch.stft(sig, self.chunk_len, self.stride, window=self.hann, return_complex=False).permute(2, 0, 1) | 
					
						
						|  | return sig.float() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class TrainDataset(Dataset): | 
					
						
						|  |  | 
					
						
						|  | def __init__(self, mode='train'): | 
					
						
						|  | dataset_name = CONFIG.DATA.dataset | 
					
						
						|  | self.target_root = CONFIG.DATA.data_dir[dataset_name]['root'] | 
					
						
						|  |  | 
					
						
						|  | txt_list = CONFIG.DATA.data_dir[dataset_name]['train'] | 
					
						
						|  | self.data_list = self.load_txt(txt_list) | 
					
						
						|  |  | 
					
						
						|  | if mode == 'train': | 
					
						
						|  | self.data_list, _ = train_test_split(self.data_list, test_size=CONFIG.TRAIN.val_split, random_state=0) | 
					
						
						|  |  | 
					
						
						|  | elif mode == 'val': | 
					
						
						|  | _, self.data_list = train_test_split(self.data_list, test_size=CONFIG.TRAIN.val_split, random_state=0) | 
					
						
						|  |  | 
					
						
						|  | self.p_sizes = CONFIG.DATA.TRAIN.packet_sizes | 
					
						
						|  | self.mode = mode | 
					
						
						|  | self.sr = CONFIG.DATA.sr | 
					
						
						|  | self.window = CONFIG.DATA.audio_chunk_len | 
					
						
						|  | self.stride = CONFIG.DATA.stride | 
					
						
						|  | self.chunk_len = CONFIG.DATA.window_size | 
					
						
						|  | self.hann = torch.sqrt(torch.hann_window(self.chunk_len)) | 
					
						
						|  | self.mask_generator = MaskGenerator(is_train=True, probs=CONFIG.DATA.TRAIN.transition_probs) | 
					
						
						|  |  | 
					
						
						|  | def __len__(self): | 
					
						
						|  | return len(self.data_list) | 
					
						
						|  |  | 
					
						
						|  | def load_txt(self, txt_list): | 
					
						
						|  | target = [] | 
					
						
						|  | with open(txt_list) as f: | 
					
						
						|  | for line in f: | 
					
						
						|  | target.append(os.path.join(self.target_root, line.strip('\n'))) | 
					
						
						|  | target = list(set(target)) | 
					
						
						|  | target.sort() | 
					
						
						|  | return target | 
					
						
						|  |  | 
					
						
						|  | def fetch_audio(self, index): | 
					
						
						|  | sig = load_audio(self.data_list[index], sample_rate=self.sr, chunk_len=self.window) | 
					
						
						|  | while sig.shape[1] < self.window: | 
					
						
						|  | idx = torch.randint(0, len(self.data_list), (1,))[0] | 
					
						
						|  | pad_len = self.window - sig.shape[1] | 
					
						
						|  | if pad_len < 0.02 * self.sr: | 
					
						
						|  | padding = np.zeros((1, pad_len), dtype=np.float) | 
					
						
						|  | else: | 
					
						
						|  | padding = load_audio(self.data_list[idx], sample_rate=self.sr, chunk_len=pad_len) | 
					
						
						|  | sig = np.hstack((sig, padding)) | 
					
						
						|  | return sig | 
					
						
						|  |  | 
					
						
						|  | def __getitem__(self, index): | 
					
						
						|  | sig = self.fetch_audio(index) | 
					
						
						|  |  | 
					
						
						|  | sig = sig.reshape(-1).astype(np.float32) | 
					
						
						|  |  | 
					
						
						|  | target = torch.tensor(sig.copy()) | 
					
						
						|  | p_size = random.choice(self.p_sizes) | 
					
						
						|  |  | 
					
						
						|  | sig = np.reshape(sig, (-1, p_size)) | 
					
						
						|  | mask = self.mask_generator.gen_mask(len(sig), seed=index)[:, np.newaxis] | 
					
						
						|  | sig *= mask | 
					
						
						|  | sig = torch.tensor(sig.copy()).reshape(-1) | 
					
						
						|  |  | 
					
						
						|  | target = torch.stft(target, self.chunk_len, self.stride, window=self.hann, | 
					
						
						|  | return_complex=False).permute(2, 0, 1).float() | 
					
						
						|  | sig = torch.stft(sig, self.chunk_len, self.stride, window=self.hann, return_complex=False) | 
					
						
						|  | sig = sig.permute(2, 0, 1).float() | 
					
						
						|  | return sig, target | 
					
						
						|  |  |