# ThinkSound: think_sound/data/datamodule.py
import lightning as L
from .dataset import LatentDataset, SampleDataset, VideoDataset, AudioDataset, MultiModalDataset, LocalDatasetConfig, collation_fn
import importlib.util
from torch.utils.data import DataLoader
def get_configs(audio_configs):
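    """Build a list of LocalDatasetConfig entries from raw dataset config dicts."""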
configs = []
for config in audio_configs:
data_dir_path = config.get("path", None)
audio_dir_path = config.get("audio_dir", None)
split_path = config.get("split_path", None)
assert data_dir_path is not None, "Path must be set for local audio directory configuration"
custom_metadata_fn = None
custom_metadata_module_path = config.get("custom_metadata_module", None)
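        # Dynamically import the user-supplied module and take its get_custom_metadata hook.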
if custom_metadata_module_path:
spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path)
metadata_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(metadata_module)
custom_metadata_fn = metadata_module.get_custom_metadata
configs.append(
LocalDatasetConfig(
id=config["id"],
path=data_dir_path,
split_path=split_path,
custom_metadata_fn=custom_metadata_fn,
audio_dir=audio_dir_path
)
)
return configs
class DataModule(L.LightningDataModule):
    def __init__(self, dataset_config, batch_size, test_batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4, repeat_num=5):
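        """Parse dataset_config and build train/val/test LocalDatasetConfig lists according to dataset_type."""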
super().__init__()
dataset_type = dataset_config.get("dataset_type", None)
self.batch_size = batch_size
self.num_workers = num_workers
self.test_batch_size = test_batch_size
self.repeat_num = repeat_num
assert dataset_type is not None, "Dataset type must be specified in dataset config"
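        # Map audio_channels to a layout string: 1 -> mono, 2 -> stereo, anything else -> first-order ambisonics.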
if audio_channels == 1:
force_channels = "mono"
elif audio_channels == 2:
force_channels = "stereo"
else:
force_channels = "foa"
val_dir_configs = dataset_config.get("val_datasets", None)
test_dir_configs = dataset_config.get("test_datasets", None)
configs = []
val_configs = []
test_configs = []
if dataset_type == "audio_dir":
audio_dir_configs = dataset_config.get("datasets", None)
            assert audio_dir_configs is not None, "Directory configuration must be specified in dataset_config[\"datasets\"]"
configs = get_configs(audio_dir_configs)
val_configs = get_configs(val_dir_configs)
test_configs = get_configs(test_dir_configs)
elif dataset_type == "latent_dir" or dataset_type == "video_dataset":
audio_dir_configs = dataset_config.get("datasets", None)
            assert audio_dir_configs is not None, "Directory configuration must be specified in dataset_config[\"datasets\"]"
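            # Iterate the train/val/test config lists in order; the index i selects the target bucket below.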
for i, dataset in enumerate((audio_dir_configs, val_dir_configs, test_dir_configs)):
for config in dataset:
data_dir_path = config.get("path", None)
audio_dir_path = config.get("audio_dir", None)
split_path = config.get("split_path", None)
assert data_dir_path is not None, "Path must be set for local audio directory configuration"
content = LocalDatasetConfig(
id=config["id"],
path=data_dir_path,
split_path=split_path,
audio_dir=audio_dir_path,
extra_cot=config.get("extra_cot", None)
)
if i == 0:
configs.append(content)
elif i == 1:
val_configs.append(content)
else:
test_configs.append(content)
elif dataset_type == "multimodal_dir":
self.audio_configs = []
self.video_configs = []
audio_dir_configs = dataset_config.get("audio_datasets", None)
video_dir_configs = dataset_config.get("video_datasets", None)
assert audio_dir_configs is not None and video_dir_configs is not None, "Directory configuration must be specified in video_datasets and audio_datasets"
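            # Iterate the audio/video/val/test config lists in order; the index i selects the target bucket below.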
for i, dataset in enumerate((audio_dir_configs, video_dir_configs, val_dir_configs, test_dir_configs)):
for config in dataset:
data_dir_path = config.get("path", None)
audio_dir_path = config.get("audio_dir", None)
split_path = config.get("split_path", None)
assert data_dir_path is not None, "Path must be set for local audio directory configuration"
print(f'extra cot: {config.get("extra_cot", None)}')
content = LocalDatasetConfig(
id=config["id"],
path=data_dir_path,
split_path=split_path,
audio_dir=audio_dir_path,
extra_cot=config.get("extra_cot", None)
)
if i == 0:
self.audio_configs.append(content)
elif i == 1:
self.video_configs.append(content)
elif i == 2:
val_configs.append(content)
else:
test_configs.append(content)
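        # Shared settings used by setup() when instantiating the datasets.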
self.dataset_type = dataset_type
self.configs = configs
self.val_configs = val_configs
self.test_configs = test_configs
self.sample_rate = sample_rate
self.sample_size = sample_size
self.random_crop = dataset_config.get("random_crop", True)
self.input_type = dataset_config.get("input_type", "video")
self.fps = dataset_config.get("fps", 4)
self.force_channels = force_channels
def setup(self, stage: str):
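        """Instantiate the train/val/test datasets for the requested Lightning stage."""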
if self.dataset_type == 'audio_dir':
dataset_class = SampleDataset
elif self.dataset_type == 'latent_dir':
dataset_class = LatentDataset
elif self.dataset_type == 'video_dataset':
dataset_class = VideoDataset
elif self.dataset_type == 'multimodal_dir':
dataset_class = VideoDataset
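        # Helper shared by all stages; for multimodal training the val/test splits are still plain VideoDatasets.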
def create_dataset(configs, random_crop):
return dataset_class(
configs,
sample_rate=self.sample_rate,
sample_size=self.sample_size,
random_crop=random_crop,
input_type=self.input_type,
                fps=self.fps,
force_channels=self.force_channels
)
if stage == 'fit':
if self.dataset_type != 'multimodal_dir':
self.train_set = create_dataset(self.configs, random_crop=self.random_crop)
else:
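                # Build separate video and audio datasets and pair them via MultiModalDataset;
                # the video set is repeated repeat_num times.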
self.video_set = VideoDataset(
self.video_configs,
sample_rate=self.sample_rate,
sample_size=self.sample_size,
random_crop=self.random_crop,
input_type=self.input_type,
                    fps=self.fps,
force_channels=self.force_channels
)
self.audio_set = AudioDataset(
self.audio_configs,
sample_rate=self.sample_rate,
sample_size=self.sample_size,
random_crop=self.random_crop,
input_type=self.input_type,
                    fps=self.fps,
force_channels=self.force_channels
)
                self.train_set = MultiModalDataset([self.video_set] * self.repeat_num, [self.audio_set])
self.val_set = create_dataset(self.val_configs, random_crop=False)
elif stage == 'validate':
self.val_set = create_dataset(self.val_configs, random_crop=False)
elif stage == 'predict':
self.test_set = create_dataset(self.test_configs, random_crop=False)
def train_dataloader(self):
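        # Shuffle and drop the last incomplete batch for training; keep workers alive between epochs.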
return DataLoader(self.train_set, self.batch_size, shuffle=True,
num_workers=self.num_workers, persistent_workers=True, pin_memory=True, drop_last=True, collate_fn=collation_fn)
def val_dataloader(self):
return DataLoader(self.val_set, self.batch_size, shuffle=False,
num_workers=self.num_workers, persistent_workers=False, pin_memory=False, drop_last=False, collate_fn=collation_fn)
def predict_dataloader(self):
return DataLoader(self.test_set, batch_size=self.test_batch_size, shuffle=False,
num_workers=self.num_workers, persistent_workers=False, pin_memory=False, drop_last=False, collate_fn=collation_fn)