import logging
import random
from pathlib import Path

import numpy as np
import torch
import torchaudio
import torchaudio.functional as AF
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset as DatasetBase

from ..hparams import HParams
from .distorter import Distorter
from .utils import rglob_audio_files

logger = logging.getLogger(__name__)


def _normalize(x):
    # Peak-normalize to roughly [-1, 1]; the epsilon guards against all-zero input.
    return x / (np.abs(x).max() + 1e-7)


def _collate(batch, key, tensor=True, pad=True):
    # Gather `key` from every sample; optionally convert to tensors and
    # zero-pad to the longest item in the batch.
    l = [d[key] for d in batch]
    if l[0] is None:
        return None
    if tensor:
        l = [torch.from_numpy(x) for x in l]
    if pad:
        assert tensor, "Can't pad non-tensor"
        l = pad_sequence(l, batch_first=True)
    return l
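
# Example (illustrative, not part of the original module): two mono clips of
# 16000 and 8000 samples collated with pad=True come back as one zero-padded
# float tensor of shape (2, 16000):
#
#   batch = [{"fg_wav": np.zeros(16000, np.float32)},
#            {"fg_wav": np.zeros(8000, np.float32)}]
#   fg = _collate(batch, "fg_wav")  # -> torch.Size([2, 16000])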


def praat_augment(wav, sr):
    try:
        import parselmouth
    except ImportError:
        raise ImportError("Please install parselmouth>=0.5.0 to use Praat augmentation")
    # "praat-parselmouth @ git+https://github.com/YannickJadoul/Parselmouth@0bbcca69705ed73322f3712b19d71bb3694b2540",
    # https://github.com/YannickJadoul/Parselmouth/issues/68
    # note that this function may hang if the praat version is 0.4.3
    assert wav.ndim == 1, f"wav.ndim must be 1 but got {wav.ndim}"
    # Randomly shift formants and rescale the pitch range via Praat's "Change gender" command.
    sound = parselmouth.Sound(wav, sr)
    formant_shift_ratio = random.uniform(1.1, 1.5)
    pitch_range_factor = random.uniform(0.5, 2.0)
    sound = parselmouth.praat.call(sound, "Change gender", 75, 600, formant_shift_ratio, 0, pitch_range_factor, 1.0)
    wav = np.array(sound.values)[0].astype(np.float32)
    return wav


class Dataset(DatasetBase):
    def __init__(
        self,
        fg_paths: list[Path],
        hp: HParams,
        training=True,
        max_retries=100,
        silent_fg_prob=0.01,
        mode=False,
    ):
        super().__init__()
        # mode has no usable default; callers must pass "enhancer" or "denoiser" explicitly.
        assert mode in ("enhancer", "denoiser"), f"Invalid mode: {mode}"
        self.hp = hp
        self.fg_paths = fg_paths
        self.bg_paths = rglob_audio_files(hp.bg_dir)

        if len(self.fg_paths) == 0:
            raise ValueError(f"No foreground audio files found in {hp.fg_dir}")

        if len(self.bg_paths) == 0:
            raise ValueError(f"No background audio files found in {hp.bg_dir}")

        logger.info(f"Found {len(self.fg_paths)} foreground files and {len(self.bg_paths)} background files")

        self.training = training
        self.max_retries = max_retries
        self.silent_fg_prob = silent_fg_prob
        self.mode = mode
        self.distorter = Distorter(hp, training=training, mode=mode)

    def _load_wav(self, path, length=None, random_crop=True):
        wav, sr = torchaudio.load(path)

        # Resample to the target rate with a Kaiser-windowed sinc filter.
        wav = AF.resample(
            waveform=wav,
            orig_freq=sr,
            new_freq=self.hp.wav_rate,
            lowpass_filter_width=64,
            rolloff=0.9475937167399596,
            resampling_method="sinc_interp_kaiser",
            beta=14.769656459379492,
        )

        wav = wav.float().numpy()

        # Downmix multi-channel audio to mono.
        if wav.ndim == 2:
            wav = np.mean(wav, axis=0)

        # During training, crop (or pad) to a fixed clip length.
        if length is None and self.training:
            length = int(self.hp.training_seconds * self.hp.wav_rate)

        if length is not None:
            if random_crop:
                start = random.randint(0, max(0, len(wav) - length))
                wav = wav[start : start + length]
            else:
                wav = wav[:length]

        if length is not None and len(wav) < length:
            wav = np.pad(wav, (0, length - len(wav)))

        wav = _normalize(wav)

        return wav

    def _getitem_unsafe(self, index: int):
        fg_path = self.fg_paths[index]

        if self.training and random.random() < self.silent_fg_prob:
            # Occasionally substitute an all-zero (silent) foreground clip during training.
            fg_wav = np.zeros(int(self.hp.training_seconds * self.hp.wav_rate), dtype=np.float32)
        else:
            fg_wav = self._load_wav(fg_path)
            if random.random() < self.hp.praat_augment_prob and self.training:
                fg_wav = praat_augment(fg_wav, self.hp.wav_rate)

        if self.hp.load_fg_only:
            bg_wav = None
            fg_dwav = None
            bg_dwav = None
        else:
            # Distorted copies of the clean foreground/background signals.
            fg_dwav = _normalize(self.distorter(fg_wav, self.hp.wav_rate)).astype(np.float32)
            if self.training:
                bg_path = random.choice(self.bg_paths)
            else:
                # Deterministic for validation
                bg_path = self.bg_paths[index % len(self.bg_paths)]
            bg_wav = self._load_wav(bg_path, length=len(fg_wav), random_crop=self.training)
            bg_dwav = _normalize(self.distorter(bg_wav, self.hp.wav_rate)).astype(np.float32)

        return dict(
            fg_wav=fg_wav,
            bg_wav=bg_wav,
            fg_dwav=fg_dwav,
            bg_dwav=bg_dwav,
        )

    def __getitem__(self, index: int):
        # Retry with a fresh random index if a sample fails to load.
        for i in range(self.max_retries):
            try:
                return self._getitem_unsafe(index)
            except Exception as e:
                if i == self.max_retries - 1:
                    raise RuntimeError(f"Failed to load {self.fg_paths[index]} after {self.max_retries} retries") from e
                logger.debug(f"Error loading {self.fg_paths[index]}: {e}, skipping")
                index = np.random.randint(0, len(self))

    def __len__(self):
        return len(self.fg_paths)


def collate_fn(batch):
    return dict(
        fg_wavs=_collate(batch, "fg_wav"),
        bg_wavs=_collate(batch, "bg_wav"),
        fg_dwavs=_collate(batch, "fg_dwav"),
        bg_dwavs=_collate(batch, "bg_dwav"),
    )
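

# Usage sketch (assumption, not part of the original module): wiring the
# Dataset and collate_fn into a torch DataLoader. The `HParams()` construction
# and the `fg_dir` attribute are assumed to mirror the `bg_dir` conventions
# used in __init__ above.
#
#   from torch.utils.data import DataLoader
#
#   hp = HParams()                                  # hypothetical default config
#   fg_paths = rglob_audio_files(hp.fg_dir)         # mirrors the bg_dir lookup in __init__
#   dataset = Dataset(fg_paths, hp, training=True, mode="denoiser")
#   loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
#   batch = next(iter(loader))                      # dict with fg_wavs/bg_wavs/fg_dwavs/bg_dwavs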