# Spaces: Running on Zero
import importlib
import importlib.util

import lightning as L
from torch.utils.data import DataLoader

from .dataset import LatentDataset, SampleDataset, VideoDataset, AudioDataset, MultiModalDataset, LocalDatasetConfig, collation_fn
def get_configs(audio_configs):
    """Build one ``LocalDatasetConfig`` per entry of a dataset-directory list.

    Args:
        audio_configs: Iterable of dict-like entries. Each entry must provide
            ``"id"`` and ``"path"``; ``"audio_dir"``, ``"split_path"`` and
            ``"custom_metadata_module"`` are optional. ``None`` is treated as
            an empty list so that an absent optional section does not crash.

    Returns:
        A list of ``LocalDatasetConfig`` objects in the order of the input.

    Raises:
        AssertionError: If an entry has no ``"path"``.
        KeyError: If an entry has no ``"id"``.
        ImportError: If ``"custom_metadata_module"`` points to a file for
            which no import spec can be created.
    """
    if audio_configs is None:
        return []

    configs = []
    for config in audio_configs:
        data_dir_path = config.get("path", None)
        assert data_dir_path is not None, "Path must be set for local audio directory configuration"

        # Optionally load ``get_custom_metadata`` from a user-supplied file.
        custom_metadata_fn = None
        custom_metadata_module_path = config.get("custom_metadata_module", None)
        if custom_metadata_module_path:
            spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path)
            # spec_from_file_location returns None when it cannot pick a
            # loader for the path; fail with a clear message instead of an
            # AttributeError on None further down.
            if spec is None or spec.loader is None:
                raise ImportError(
                    f"Cannot load custom metadata module from {custom_metadata_module_path!r}"
                )
            metadata_module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(metadata_module)
            custom_metadata_fn = metadata_module.get_custom_metadata

        configs.append(
            LocalDatasetConfig(
                id=config["id"],
                path=data_dir_path,
                split_path=config.get("split_path", None),
                custom_metadata_fn=custom_metadata_fn,
                audio_dir=config.get("audio_dir", None),
            )
        )
    return configs
class DataModule(L.LightningDataModule):
    """Lightning data module that turns a dataset config dict into data loaders.

    ``dataset_config["dataset_type"]`` selects how the config is read:

    * ``"audio_dir"``: training entries come from ``"datasets"`` and are
      converted with ``get_configs``; ``setup`` uses ``SampleDataset``.
    * ``"latent_dir"`` / ``"video_dataset"``: training entries come from
      ``"datasets"``; ``setup`` uses ``LatentDataset`` / ``VideoDataset``.
    * ``"multimodal_dir"``: training entries come from ``"audio_datasets"``
      and ``"video_datasets"``; training uses a ``MultiModalDataset`` while
      validation and prediction use ``VideoDataset``.

    ``"val_datasets"`` and ``"test_datasets"`` are optional in every mode and
    default to empty lists.
    """

    def __init__(self, dataset_config, batch_size, test_batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4, repeat_num=5):
        """Parse ``dataset_config`` and store the loader/dataset settings.

        Args:
            dataset_config: Dict-like config. Must contain ``"dataset_type"``;
                ``"random_crop"`` (default ``True``), ``"input_type"``
                (default ``"video"``) and ``"fps"`` (default ``4``) are
                optional.
            batch_size: Batch size of the train and validation loaders.
            test_batch_size: Batch size of the predict loader.
            sample_size: Forwarded to the dataset classes as ``sample_size``.
            sample_rate: Forwarded to the dataset classes as ``sample_rate``.
            audio_channels: ``1`` selects ``"mono"``, ``2`` selects
                ``"stereo"``, any other value selects ``"foa"``.
            num_workers: Number of worker processes for every loader.
            repeat_num: How many times the video dataset is repeated in the
                list handed to ``MultiModalDataset``.

        Raises:
            AssertionError: If ``"dataset_type"`` or a required dataset list
                is missing, or a dataset entry has no ``"path"``.
        """
        super().__init__()
        dataset_type = dataset_config.get("dataset_type", None)
        assert dataset_type is not None, "Dataset type must be specified in dataset config"

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.test_batch_size = test_batch_size
        self.repeat_num = repeat_num

        if audio_channels == 1:
            force_channels = "mono"
        elif audio_channels == 2:
            force_channels = "stereo"
        else:
            force_channels = "foa"

        # Validation / test sections are optional: fall back to empty lists
        # so that iterating them below never hits ``None``.
        val_dir_configs = dataset_config.get("val_datasets", None) or []
        test_dir_configs = dataset_config.get("test_datasets", None) or []

        configs = []
        val_configs = []
        test_configs = []
        if dataset_type == "audio_dir":
            audio_dir_configs = dataset_config.get("datasets", None)
            assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]"
            configs = get_configs(audio_dir_configs)
            val_configs = get_configs(val_dir_configs)
            test_configs = get_configs(test_dir_configs)
        elif dataset_type == "latent_dir" or dataset_type == "video_dataset":
            audio_dir_configs = dataset_config.get("datasets", None)
            assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]"
            configs = [self._make_local_config(entry) for entry in audio_dir_configs]
            val_configs = [self._make_local_config(entry) for entry in val_dir_configs]
            test_configs = [self._make_local_config(entry) for entry in test_dir_configs]
        elif dataset_type == "multimodal_dir":
            audio_dir_configs = dataset_config.get("audio_datasets", None)
            video_dir_configs = dataset_config.get("video_datasets", None)
            assert audio_dir_configs is not None and video_dir_configs is not None, "Directory configuration must be specified in video_datasets and audio_datasets"
            self.audio_configs = [
                self._make_local_config(entry, log_extra_cot=True) for entry in audio_dir_configs
            ]
            self.video_configs = [
                self._make_local_config(entry, log_extra_cot=True) for entry in video_dir_configs
            ]
            val_configs = [
                self._make_local_config(entry, log_extra_cot=True) for entry in val_dir_configs
            ]
            test_configs = [
                self._make_local_config(entry, log_extra_cot=True) for entry in test_dir_configs
            ]

        self.dataset_type = dataset_type
        self.configs = configs
        self.val_configs = val_configs
        self.test_configs = test_configs
        self.sample_rate = sample_rate
        self.sample_size = sample_size
        self.random_crop = dataset_config.get("random_crop", True)
        self.input_type = dataset_config.get("input_type", "video")
        self.fps = dataset_config.get("fps", 4)
        self.force_channels = force_channels

    @staticmethod
    def _make_local_config(config, log_extra_cot=False):
        """Convert one dataset entry of the config into a ``LocalDatasetConfig``.

        Args:
            config: Dict-like entry with ``"id"`` and ``"path"`` and optional
                ``"audio_dir"``, ``"split_path"`` and ``"extra_cot"``.
            log_extra_cot: When true, print the entry's ``"extra_cot"`` value.

        Raises:
            AssertionError: If the entry has no ``"path"``.
        """
        data_dir_path = config.get("path", None)
        assert data_dir_path is not None, "Path must be set for local audio directory configuration"
        if log_extra_cot:
            print(f'extra cot: {config.get("extra_cot", None)}')
        return LocalDatasetConfig(
            id=config["id"],
            path=data_dir_path,
            split_path=config.get("split_path", None),
            audio_dir=config.get("audio_dir", None),
            extra_cot=config.get("extra_cot", None),
        )

    def setup(self, stage: str):
        """Instantiate the datasets needed for ``stage``.

        ``"fit"`` builds ``train_set`` and ``val_set``, ``"validate"`` builds
        ``val_set`` and ``"predict"`` builds ``test_set``; any other stage is
        a no-op.

        Raises:
            ValueError: If a dataset has to be created for a
                ``dataset_type`` that has no dataset class.
        """
        dataset_classes = {
            'audio_dir': SampleDataset,
            'latent_dir': LatentDataset,
            'video_dataset': VideoDataset,
            'multimodal_dir': VideoDataset,
        }
        dataset_class = dataset_classes.get(self.dataset_type)

        def dataset_kwargs(random_crop):
            # ``fps`` must carry the configured frame rate, not the input type.
            return dict(
                sample_rate=self.sample_rate,
                sample_size=self.sample_size,
                random_crop=random_crop,
                input_type=self.input_type,
                fps=self.fps,
                force_channels=self.force_channels,
            )

        def create_dataset(configs, random_crop):
            if dataset_class is None:
                raise ValueError(f"Unsupported dataset type: {self.dataset_type!r}")
            return dataset_class(configs, **dataset_kwargs(random_crop))

        if stage == 'fit':
            if self.dataset_type != 'multimodal_dir':
                self.train_set = create_dataset(self.configs, random_crop=self.random_crop)
            else:
                self.video_set = VideoDataset(self.video_configs, **dataset_kwargs(self.random_crop))
                self.audio_set = AudioDataset(self.audio_configs, **dataset_kwargs(self.random_crop))
                self.train_set = MultiModalDataset([self.video_set] * self.repeat_num, [self.audio_set])
            self.val_set = create_dataset(self.val_configs, random_crop=False)
        elif stage == 'validate':
            self.val_set = create_dataset(self.val_configs, random_crop=False)
        elif stage == 'predict':
            self.test_set = create_dataset(self.test_configs, random_crop=False)

    def train_dataloader(self):
        """Return the shuffled training loader over ``train_set``."""
        return DataLoader(
            self.train_set,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            # DataLoader rejects persistent workers when there are no workers.
            persistent_workers=self.num_workers > 0,
            pin_memory=True,
            drop_last=True,
            collate_fn=collation_fn,
        )

    def val_dataloader(self):
        """Return the unshuffled validation loader over ``val_set``."""
        return DataLoader(
            self.val_set,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            persistent_workers=False,
            pin_memory=False,
            drop_last=False,
            collate_fn=collation_fn,
        )

    def predict_dataloader(self):
        """Return the unshuffled prediction loader over ``test_set``."""
        return DataLoader(
            self.test_set,
            batch_size=self.test_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            persistent_workers=False,
            pin_memory=False,
            drop_last=False,
            collate_fn=collation_fn,
        )
# def predict_dataloader(self): | |
# return DataLoader(self.mnist_predict, batch_size=self.batch_size) | |
# def teardown(self, stage: str): | |
# # Used to clean-up when the run is finished | |
# ... |