import json
import os
from typing import Optional

import numpy as np
import soundfile as sf
import torch
from pytorch_lightning.utilities import rank_zero_only
from rich import print
from torch.utils.data import DataLoader, Dataset


@rank_zero_only
def print_(message: str):
    """Print only from the rank-zero process to avoid duplicated logs in DDP."""
    print(message)


def normalize_tensor_wav(wav_tensor, eps=1e-8, std=None):
    """Zero-mean normalization along the time axis; a shared std can be passed in."""
    mean = wav_tensor.mean(-1, keepdim=True)
    if std is None:
        std = wav_tensor.std(-1, keepdim=True)
    return (wav_tensor - mean) / (std + eps)
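

# A quick illustration with hypothetical tensors (not from the original file):
#
#     wavs = torch.randn(2, 8000)          # two 1-second clips at 8 kHz
#     normed = normalize_tensor_wav(wavs)  # ~zero mean along the time axis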


class LRS2Dataset(Dataset):
    """LRS2 separation dataset driven by JSON lists of (wav_path, length) entries."""
def __init__(
self,
json_dir: str = "",
n_src: int = 2,
sample_rate: int = 8000,
fps: int = 25,
segment: float = 4.0,
normalize_audio: bool = False,
audio_only: bool = True,
) -> None:
super().__init__()
self.EPS = 1e-8
        if json_dir is None:
raise ValueError("JSON DIR is None!")
if n_src not in [1, 2]:
raise ValueError("{} is not in [1, 2]".format(n_src))
self.json_dir = json_dir
self.sample_rate = sample_rate
self.normalize_audio = normalize_audio
self.audio_only = audio_only
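        # segment (in seconds) sets the crop length: seg_len in audio samples,
        # fps_len in video frames for the audio-visual variant; None disables
        # cropping (test mode).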
if segment is None:
self.seg_len = None
self.fps_len = None
else:
self.seg_len = int(segment * sample_rate)
self.fps_len = int(segment * fps)
self.n_src = n_src
self.test = self.seg_len is None
mix_json = os.path.join(json_dir, "mix.json")
sources_json = [
os.path.join(json_dir, source + ".json") for source in ["s1", "s2"]
]
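        # Each JSON file is expected to list [wav_path, num_samples] pairs, as
        # implied by the indexing and length checks below.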
with open(mix_json, "r") as f:
mix_infos = json.load(f)
sources_infos = []
for src_json in sources_json:
with open(src_json, "r") as f:
sources_infos.append(json.load(f))
self.mix = []
self.sources = []
if self.n_src == 1:
orig_len = len(mix_infos) * 2
drop_utt, drop_len = 0, 0
if not self.test:
                # Iterate backwards so deletions don't shift the remaining indices.
                for i in range(len(mix_infos) - 1, -1, -1):
                    if mix_infos[i][1] < self.seg_len:
                        drop_utt += 1
                        drop_len += mix_infos[i][1]
                        del mix_infos[i]
for src_inf in sources_infos:
del src_inf[i]
else:
for src_inf in sources_infos:
self.mix.append(mix_infos[i])
self.sources.append(src_inf[i])
else:
for i in range(len(mix_infos)):
for src_inf in sources_infos:
self.mix.append(mix_infos[i])
self.sources.append(src_inf[i])
print_(
"Drop {} utts({:.2f} h) from {} (shorter than {} samples)".format(
drop_utt, drop_len / sample_rate / 3600, orig_len, self.seg_len
)
)
self.length = len(self.mix)
elif self.n_src == 2:
orig_len = len(mix_infos)
drop_utt, drop_len = 0, 0
if not self.test:
                for i in range(len(mix_infos) - 1, -1, -1):  # Go backward
                    if mix_infos[i][1] < self.seg_len:
                        drop_utt += 1
                        drop_len += mix_infos[i][1]
del mix_infos[i]
for src_inf in sources_infos:
del src_inf[i]
print_(
"Drop {} utts({:.2f} h) from {} (shorter than {} samples)".format(
                        drop_utt, drop_len / sample_rate / 3600, orig_len, self.seg_len
)
)
self.mix = mix_infos
self.sources = sources_infos
self.length = len(self.mix)

    def __len__(self):
        return self.length

    def preprocess_audio_only(self, idx: int):
if self.n_src == 1:
if self.mix[idx][1] == self.seg_len or self.test:
rand_start = 0
else:
rand_start = np.random.randint(0, self.mix[idx][1] - self.seg_len)
if self.test:
stop = None
else:
stop = rand_start + self.seg_len
# Load mixture
x, _ = sf.read(
self.mix[idx][0], start=rand_start, stop=stop, dtype="float32"
)
# Load sources
s, _ = sf.read(
self.sources[idx][0], start=rand_start, stop=stop, dtype="float32"
)
# torch from numpy
target = torch.from_numpy(s)
mixture = torch.from_numpy(x)
if self.normalize_audio:
m_std = mixture.std(-1, keepdim=True)
mixture = normalize_tensor_wav(mixture, eps=self.EPS, std=m_std)
target = normalize_tensor_wav(target, eps=self.EPS, std=m_std)
return mixture, target.unsqueeze(0), self.mix[idx][0].split("/")[-1]

        if self.n_src == 2:
if self.mix[idx][1] == self.seg_len or self.test:
rand_start = 0
else:
rand_start = np.random.randint(0, self.mix[idx][1] - self.seg_len)
if self.test:
stop = None
else:
stop = rand_start + self.seg_len
# Load mixture
x, _ = sf.read(
self.mix[idx][0], start=rand_start, stop=stop, dtype="float32"
)
# Load sources
source_arrays = []
for src in self.sources:
s, _ = sf.read(
src[idx][0], start=rand_start, stop=stop, dtype="float32"
)
source_arrays.append(s)
sources = torch.from_numpy(np.vstack(source_arrays))
mixture = torch.from_numpy(x)
if self.normalize_audio:
m_std = mixture.std(-1, keepdim=True)
mixture = normalize_tensor_wav(mixture, eps=self.EPS, std=m_std)
sources = normalize_tensor_wav(sources, eps=self.EPS, std=m_std)
return mixture, sources, self.mix[idx][0].split("/")[-1]

    def preprocess_audio_visual(self, idx: int):
if self.n_src == 1:
if self.mix[idx][1] == self.seg_len or self.test:
rand_start = 0
else:
rand_start = np.random.randint(0, self.mix[idx][1] - self.seg_len)
if self.test:
stop = None
else:
stop = rand_start + self.seg_len
mix_source, _ = sf.read(
self.mix[idx][0], start=rand_start, stop=stop, dtype="float32"
)
source = sf.read(
self.sources[idx][0], start=rand_start, stop=stop, dtype="float32"
)[0]
            # Mouth-region features are not loaded in this variant; placeholder only.
            source_mouth = None
source = torch.from_numpy(source)
mixture = torch.from_numpy(mix_source)
if self.normalize_audio:
m_std = mixture.std(-1, keepdim=True)
mixture = normalize_tensor_wav(mixture, eps=self.EPS, std=m_std)
source = normalize_tensor_wav(source, eps=self.EPS, std=m_std)
return mixture, source, source_mouth, self.mix[idx][0].split("/")[-1]

        if self.n_src == 2:
if self.mix[idx][1] == self.seg_len or self.test:
rand_start = 0
else:
rand_start = np.random.randint(0, self.mix[idx][1] - self.seg_len)
if self.test:
stop = None
else:
stop = rand_start + self.seg_len
mix_source, _ = sf.read(
self.mix[idx][0], start=rand_start, stop=stop, dtype="float32"
)
            sources = []
            for src in self.sources:
                # self.sources holds one [path, length] list per speaker; index
                # it by utterance, mirroring the audio-only branch above.
                sources.append(
                    sf.read(src[idx][0], start=rand_start, stop=stop, dtype="float32")[0]
                )
            # Mouth-region features are not loaded in this variant; placeholder only.
            sources_mouths = None
            sources = torch.stack([torch.from_numpy(source) for source in sources])
mixture = torch.from_numpy(mix_source)
if self.normalize_audio:
m_std = mixture.std(-1, keepdim=True)
mixture = normalize_tensor_wav(mixture, eps=self.EPS, std=m_std)
sources = normalize_tensor_wav(sources, eps=self.EPS, std=m_std)
return mixture, sources, sources_mouths, self.mix[idx][0].split("/")[-1]

    def __getitem__(self, index: int):
if self.audio_only:
return self.preprocess_audio_only(index)
else:
return self.preprocess_audio_visual(index)
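

# A minimal usage sketch, assuming a hypothetical directory "data/tr" holding
# mix.json, s1.json and s2.json in the [wav_path, num_samples] format above:
#
#     dataset = LRS2Dataset(json_dir="data/tr", n_src=2, segment=4.0)
#     mixture, sources, filename = dataset[0]
#     # audio-only path: mixture is (seg_len,), sources is (2, seg_len)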


class LRS2DataModule(object):
    """Plain data-module wrapper that builds train/val/test LRS2 datasets and loaders."""
def __init__(
self,
train_dir: str,
valid_dir: str,
test_dir: str,
n_src: int = 2,
sample_rate: int = 8000,
fps: int = 25,
segment: float = 4.0,
normalize_audio: bool = False,
batch_size: int = 64,
num_workers: int = 0,
pin_memory: bool = False,
persistent_workers: bool = False,
audio_only: bool = True,
) -> None:
super().__init__()
        if train_dir is None or valid_dir is None or test_dir is None:
raise ValueError("JSON DIR is None!")
if n_src not in [1, 2]:
raise ValueError("{} is not in [1, 2]".format(n_src))
        # Store the init params as plain attributes (no Lightning hparams here).
self.train_dir = train_dir
self.valid_dir = valid_dir
self.test_dir = test_dir
self.n_src = n_src
self.sample_rate = sample_rate
self.fps = fps
self.segment = segment
self.normalize_audio = normalize_audio
self.batch_size = batch_size
self.num_workers = num_workers
self.pin_memory = pin_memory
self.persistent_workers = persistent_workers
self.audio_only = audio_only
        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None

    def setup(self) -> None:
self.data_train = LRS2Dataset(
json_dir=self.train_dir,
n_src=self.n_src,
sample_rate=self.sample_rate,
fps=self.fps,
segment=self.segment,
normalize_audio=self.normalize_audio,
audio_only=self.audio_only,
)
self.data_val = LRS2Dataset(
json_dir=self.valid_dir,
n_src=self.n_src,
sample_rate=self.sample_rate,
fps=self.fps,
segment=self.segment,
normalize_audio=self.normalize_audio,
audio_only=self.audio_only,
)
self.data_test = LRS2Dataset(
json_dir=self.test_dir,
n_src=self.n_src,
sample_rate=self.sample_rate,
fps=self.fps,
segment=self.segment,
normalize_audio=self.normalize_audio,
audio_only=self.audio_only,
)

    def train_dataloader(self) -> DataLoader:
return DataLoader(
dataset=self.data_train,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers,
persistent_workers=self.persistent_workers,
pin_memory=self.pin_memory,
drop_last=True,
)

    def val_dataloader(self) -> DataLoader:
return DataLoader(
dataset=self.data_val,
shuffle=False,
batch_size=self.batch_size,
num_workers=self.num_workers,
persistent_workers=self.persistent_workers,
pin_memory=self.pin_memory,
drop_last=True,
)

    def test_dataloader(self) -> DataLoader:
return DataLoader(
dataset=self.data_test,
shuffle=False,
batch_size=self.batch_size,
num_workers=self.num_workers,
persistent_workers=self.persistent_workers,
pin_memory=self.pin_memory,
drop_last=True,
)

    @property
    def make_loader(self):
        """Return the (train, val, test) dataloaders as a tuple."""
        return self.train_dataloader(), self.val_dataloader(), self.test_dataloader()

    @property
    def make_sets(self):
        """Return the (train, val, test) datasets; call setup() first."""
        return self.data_train, self.data_val, self.data_test
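

if __name__ == "__main__":
    # Minimal smoke test (not part of the original module). The directories
    # below are hypothetical; point them at folders containing mix.json,
    # s1.json and s2.json to actually run it.
    dm = LRS2DataModule(
        train_dir="data/tr",
        valid_dir="data/cv",
        test_dir="data/tt",
        n_src=2,
        batch_size=4,
    )
    dm.setup()
    train_loader, val_loader, test_loader = dm.make_loader
    mixture, sources, filenames = next(iter(train_loader))
    print_(f"mixture: {tuple(mixture.shape)}, sources: {tuple(sources.shape)}")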