import json
import os
from typing import Optional

import numpy as np
import soundfile as sf
import torch
from pytorch_lightning.utilities import rank_zero_only
from rich import print
from torch.utils.data import DataLoader, Dataset


@rank_zero_only
def print_(message: str):
    print(message)


def normalize_tensor_wav(wav_tensor, eps=1e-8, std=None):
    # Zero-mean / unit-std normalization along the last (time) axis.
    mean = wav_tensor.mean(-1, keepdim=True)
    if std is None:
        std = wav_tensor.std(-1, keepdim=True)
    return (wav_tensor - mean) / (std + eps)


class LRS2Dataset(Dataset):
    def __init__(
        self,
        json_dir: str = "",
        n_src: int = 2,
        sample_rate: int = 8000,
        fps: int = 25,
        segment: float = 4.0,
        normalize_audio: bool = False,
        audio_only: bool = True,
    ) -> None:
        super().__init__()
        self.EPS = 1e-8
        if json_dir is None:
            raise ValueError("JSON DIR is None!")
        if n_src not in [1, 2]:
            raise ValueError("{} is not in [1, 2]".format(n_src))
        self.json_dir = json_dir
        self.sample_rate = sample_rate
        self.normalize_audio = normalize_audio
        self.audio_only = audio_only
        if segment is None:
            self.seg_len = None
            self.fps_len = None
        else:
            self.seg_len = int(segment * sample_rate)
            self.fps_len = int(segment * fps)
        self.n_src = n_src
        # No segment length means full utterances, i.e. test mode.
        self.test = self.seg_len is None

        mix_json = os.path.join(json_dir, "mix.json")
        sources_json = [
            os.path.join(json_dir, source + ".json") for source in ["s1", "s2"]
        ]
        with open(mix_json, "r") as f:
            mix_infos = json.load(f)
        sources_infos = []
        for src_json in sources_json:
            with open(src_json, "r") as f:
                sources_infos.append(json.load(f))

        self.mix = []
        self.sources = []
        if self.n_src == 1:
            # Each mixture yields one (mixture, source) pair per speaker.
            orig_len = len(mix_infos) * 2
            drop_utt, drop_len = 0, 0
            if not self.test:
                for i in range(len(mix_infos) - 1, -1, -1):  # Go backward
                    if mix_infos[i][1] < self.seg_len:
                        drop_utt += 1
                        drop_len += mix_infos[i][1]
                        del mix_infos[i]
                        for src_inf in sources_infos:
                            del src_inf[i]
                    else:
                        for src_inf in sources_infos:
                            self.mix.append(mix_infos[i])
                            self.sources.append(src_inf[i])
            else:
                for i in range(len(mix_infos)):
                    for src_inf in sources_infos:
                        self.mix.append(mix_infos[i])
                        self.sources.append(src_inf[i])
            print_(
                "Drop {} utts({:.2f} h) from {} (shorter than {} samples)".format(
                    drop_utt, drop_len / sample_rate / 3600, orig_len, self.seg_len
                )
            )
            self.length = len(self.mix)
        elif self.n_src == 2:
            orig_len = len(mix_infos)
            drop_utt, drop_len = 0, 0
            if not self.test:
                for i in range(len(mix_infos) - 1, -1, -1):  # Go backward
                    if mix_infos[i][1] < self.seg_len:
                        drop_utt += 1
                        drop_len += mix_infos[i][1]
                        del mix_infos[i]
                        for src_inf in sources_infos:
                            del src_inf[i]
            print_(
                "Drop {} utts({:.2f} h) from {} (shorter than {} samples)".format(
                    drop_utt, drop_len / sample_rate / 3600, orig_len, self.seg_len
                )
            )
            self.mix = mix_infos
            self.sources = sources_infos
            self.length = len(self.mix)

    def __len__(self):
        return self.length

    def preprocess_audio_only(self, idx: int):
        if self.n_src == 1:
            if self.mix[idx][1] == self.seg_len or self.test:
                rand_start = 0
            else:
                rand_start = np.random.randint(0, self.mix[idx][1] - self.seg_len)
            if self.test:
                stop = None
            else:
                stop = rand_start + self.seg_len
            # Load mixture
            x, _ = sf.read(
                self.mix[idx][0], start=rand_start, stop=stop, dtype="float32"
            )
            # Load source
            s, _ = sf.read(
                self.sources[idx][0], start=rand_start, stop=stop, dtype="float32"
            )
            # torch from numpy
            target = torch.from_numpy(s)
            mixture = torch.from_numpy(x)
            if self.normalize_audio:
                m_std = mixture.std(-1, keepdim=True)
                mixture = normalize_tensor_wav(mixture, eps=self.EPS, std=m_std)
                target = normalize_tensor_wav(target, eps=self.EPS, std=m_std)
            return mixture, target.unsqueeze(0), self.mix[idx][0].split("/")[-1]
        if self.n_src == 2:
            if self.mix[idx][1] == self.seg_len or self.test:
                rand_start = 0
            else:
                rand_start = np.random.randint(0, self.mix[idx][1] - self.seg_len)
            if self.test:
                stop = None
            else:
                stop = rand_start + self.seg_len
            # Load mixture
            x, _ = sf.read(
                self.mix[idx][0], start=rand_start, stop=stop, dtype="float32"
            )
            # Load sources
            source_arrays = []
            for src in self.sources:
                s, _ = sf.read(
                    src[idx][0], start=rand_start, stop=stop, dtype="float32"
                )
                source_arrays.append(s)
            sources = torch.from_numpy(np.vstack(source_arrays))
            mixture = torch.from_numpy(x)
            if self.normalize_audio:
                m_std = mixture.std(-1, keepdim=True)
                mixture = normalize_tensor_wav(mixture, eps=self.EPS, std=m_std)
                sources = normalize_tensor_wav(sources, eps=self.EPS, std=m_std)
            return mixture, sources, self.mix[idx][0].split("/")[-1]

    def preprocess_audio_visual(self, idx: int):
        if self.n_src == 1:
            if self.mix[idx][1] == self.seg_len or self.test:
                rand_start = 0
            else:
                rand_start = np.random.randint(0, self.mix[idx][1] - self.seg_len)
            if self.test:
                stop = None
            else:
                stop = rand_start + self.seg_len
            mix_source, _ = sf.read(
                self.mix[idx][0], start=rand_start, stop=stop, dtype="float32"
            )
            source = sf.read(
                self.sources[idx][0], start=rand_start, stop=stop, dtype="float32"
            )[0]
            # The visual stream (mouth ROIs) is not loaded here; a placeholder
            # is returned in its slot.
            source_mouth = None
            source = torch.from_numpy(source)
            mixture = torch.from_numpy(mix_source)
            if self.normalize_audio:
                m_std = mixture.std(-1, keepdim=True)
                mixture = normalize_tensor_wav(mixture, eps=self.EPS, std=m_std)
                source = normalize_tensor_wav(source, eps=self.EPS, std=m_std)
            return mixture, source, source_mouth, self.mix[idx][0].split("/")[-1]
        if self.n_src == 2:
            if self.mix[idx][1] == self.seg_len or self.test:
                rand_start = 0
            else:
                rand_start = np.random.randint(0, self.mix[idx][1] - self.seg_len)
            if self.test:
                stop = None
            else:
                stop = rand_start + self.seg_len
            mix_source, _ = sf.read(
                self.mix[idx][0], start=rand_start, stop=stop, dtype="float32"
            )
            # self.sources is a per-source list of manifests here, so iterate
            # over the sources and index each manifest with idx.
            sources = []
            for src in self.sources:
                sources.append(
                    sf.read(src[idx][0], start=rand_start, stop=stop, dtype="float32")[0]
                )
            # Mouth ROIs are not loaded here; a placeholder is returned.
            sources_mouths = None
            sources = torch.stack([torch.from_numpy(source) for source in sources])
            mixture = torch.from_numpy(mix_source)
            if self.normalize_audio:
                m_std = mixture.std(-1, keepdim=True)
                mixture = normalize_tensor_wav(mixture, eps=self.EPS, std=m_std)
                sources = normalize_tensor_wav(sources, eps=self.EPS, std=m_std)
            return mixture, sources, sources_mouths, self.mix[idx][0].split("/")[-1]

    def __getitem__(self, index: int):
        if self.audio_only:
            return self.preprocess_audio_only(index)
        else:
            return self.preprocess_audio_visual(index)
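
# NOTE: a sketch of the manifest layout this dataset expects, inferred from
# the loading code above (an assumption, not a spec). Each of mix.json,
# s1.json and s2.json holds a list of [wav_path, num_samples] entries,
# index-aligned across the three files; paths below are placeholders:
#
#   mix.json -> [["<root>/mix/utt1.wav", 32000], ...]
#   s1.json  -> [["<root>/s1/utt1.wav", 32000], ...]
#   s2.json  -> [["<root>/s2/utt1.wav", 32000], ...]
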
class LRS2DataModule:
    def __init__(
        self,
        train_dir: str,
        valid_dir: str,
        test_dir: str,
        n_src: int = 2,
        sample_rate: int = 8000,
        fps: int = 25,
        segment: float = 4.0,
        normalize_audio: bool = False,
        batch_size: int = 64,
        num_workers: int = 0,
        pin_memory: bool = False,
        persistent_workers: bool = False,
        audio_only: bool = True,
    ) -> None:
        super().__init__()
        if train_dir is None or valid_dir is None or test_dir is None:
            raise ValueError("JSON DIR is None!")
        if n_src not in [1, 2]:
            raise ValueError("{} is not in [1, 2]".format(n_src))
        self.train_dir = train_dir
        self.valid_dir = valid_dir
        self.test_dir = test_dir
        self.n_src = n_src
        self.sample_rate = sample_rate
        self.fps = fps
        self.segment = segment
        self.normalize_audio = normalize_audio
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.persistent_workers = persistent_workers
        self.audio_only = audio_only
        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None
    def setup(self) -> None:
        self.data_train = LRS2Dataset(
            json_dir=self.train_dir,
            n_src=self.n_src,
            sample_rate=self.sample_rate,
            fps=self.fps,
            segment=self.segment,
            normalize_audio=self.normalize_audio,
            audio_only=self.audio_only,
        )
        self.data_val = LRS2Dataset(
            json_dir=self.valid_dir,
            n_src=self.n_src,
            sample_rate=self.sample_rate,
            fps=self.fps,
            segment=self.segment,
            normalize_audio=self.normalize_audio,
            audio_only=self.audio_only,
        )
        self.data_test = LRS2Dataset(
            json_dir=self.test_dir,
            n_src=self.n_src,
            sample_rate=self.sample_rate,
            fps=self.fps,
            segment=self.segment,
            normalize_audio=self.normalize_audio,
            audio_only=self.audio_only,
        )

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            dataset=self.data_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            persistent_workers=self.persistent_workers,
            pin_memory=self.pin_memory,
            drop_last=True,
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            dataset=self.data_val,
            shuffle=False,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            persistent_workers=self.persistent_workers,
            pin_memory=self.pin_memory,
            drop_last=True,
        )

    def test_dataloader(self) -> DataLoader:
        return DataLoader(
            dataset=self.data_test,
            shuffle=False,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            persistent_workers=self.persistent_workers,
            pin_memory=self.pin_memory,
            drop_last=True,
        )

    @property
    def make_loader(self):
        return self.train_dataloader(), self.val_dataloader(), self.test_dataloader()

    @property
    def make_sets(self):
        return self.data_train, self.data_val, self.data_test
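

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original pipeline: the directory
    # paths below are placeholders, and each is assumed to contain the
    # mix.json / s1.json / s2.json manifests described above.
    dm = LRS2DataModule(
        train_dir="data/tr",  # placeholder
        valid_dir="data/cv",  # placeholder
        test_dir="data/tt",  # placeholder
        n_src=2,
        sample_rate=8000,
        segment=4.0,
        batch_size=4,
    )
    dm.setup()
    train_loader, val_loader, test_loader = dm.make_loader
    # With audio_only=True and n_src=2, a batch is
    # (mixture [B, T], sources [B, 2, T], list of filenames).
    mixture, sources, filenames = next(iter(train_loader))
    print_(f"mixture: {tuple(mixture.shape)}, sources: {tuple(sources.shape)}")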