File size: 9,205 Bytes
08f69f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import lightning as L
from .dataset import LatentDataset, SampleDataset, VideoDataset, AudioDataset, MultiModalDataset, LocalDatasetConfig, collation_fn
import importlib
from torch.utils.data import DataLoader


def get_configs(audio_configs):
    """Build ``LocalDatasetConfig`` objects for "audio_dir"-style dataset entries.

    Args:
        audio_configs: Iterable of dict-like dataset entries, or ``None``, which is
            treated as "no entries" (optional sections such as ``val_datasets`` may
            be absent from the dataset config). Each entry must provide ``id`` and
            ``path``; ``audio_dir``, ``split_path`` and ``custom_metadata_module``
            are optional. ``custom_metadata_module`` is a filesystem path to a
            Python source file whose ``get_custom_metadata`` attribute becomes the
            entry's ``custom_metadata_fn``.

    Returns:
        A list with one ``LocalDatasetConfig`` per entry, in input order.

    Raises:
        AssertionError: if an entry has no ``path``.
        KeyError: if an entry has no ``id``.
        ImportError: if ``custom_metadata_module`` cannot be loaded as a module.
    """
    # `import importlib` alone does not guarantee the `util` submodule is loaded.
    import importlib.util

    if audio_configs is None:
        return []

    configs = []
    for config in audio_configs:
        data_dir_path = config.get("path", None)
        audio_dir_path = config.get("audio_dir", None)
        split_path = config.get("split_path", None)
        assert data_dir_path is not None, "Path must be set for local audio directory configuration"

        custom_metadata_fn = None
        custom_metadata_module_path = config.get("custom_metadata_module", None)

        if custom_metadata_module_path:
            # NOTE: this executes arbitrary Python code from the configured path;
            # only load dataset configs from trusted sources.
            spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path)
            if spec is None or spec.loader is None:
                # spec_from_file_location returns None when no loader matches the
                # file's suffix; fail with a clear message instead of AttributeError.
                raise ImportError(
                    f"Cannot load custom metadata module from {custom_metadata_module_path!r}"
                )
            metadata_module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(metadata_module)
            custom_metadata_fn = metadata_module.get_custom_metadata

        configs.append(
            LocalDatasetConfig(
                id=config["id"],
                path=data_dir_path,
                split_path=split_path,
                custom_metadata_fn=custom_metadata_fn,
                audio_dir=audio_dir_path,
            )
        )
    return configs

class DataModule(L.LightningDataModule):
    """Lightning data module that builds train/val/predict datasets from a config dict.

    Keys read from ``dataset_config``:

    * ``dataset_type`` (required): one of ``"audio_dir"``, ``"latent_dir"``,
      ``"video_dataset"`` or ``"multimodal_dir"``.
    * ``datasets``: training entries for every type except ``"multimodal_dir"``.
    * ``audio_datasets`` / ``video_datasets``: training entries for ``"multimodal_dir"``.
    * ``val_datasets`` / ``test_datasets`` (optional; missing means no entries).
    * ``random_crop`` (default ``True``), ``input_type`` (default ``"video"``) and
      ``fps`` (default ``4``), forwarded to every dataset constructor.

    Each entry is a dict with ``id`` and ``path`` plus optional ``audio_dir``,
    ``split_path`` and ``extra_cot`` (``custom_metadata_module`` instead of
    ``extra_cot`` for ``"audio_dir"``).
    """

    def __init__(self, dataset_config, batch_size, test_batch_size, sample_size, sample_rate,
                 audio_channels=2, num_workers=4, repeat_num=5):
        """Parse ``dataset_config`` into per-split lists of ``LocalDatasetConfig``.

        Args:
            dataset_config: Dict-like dataset configuration (see class docstring).
            batch_size: Batch size for the train and validation dataloaders.
            test_batch_size: Batch size for the predict dataloader.
            sample_size: Forwarded to every dataset as ``sample_size``.
            sample_rate: Forwarded to every dataset as ``sample_rate``.
            audio_channels: 1 -> ``"mono"``, 2 -> ``"stereo"``, anything else ->
                ``"foa"``; passed to datasets as ``force_channels``.
            num_workers: Worker processes for every dataloader.
            repeat_num: How many times the video dataset is passed to
                ``MultiModalDataset`` for ``"multimodal_dir"`` training.
        """
        super().__init__()
        dataset_type = dataset_config.get("dataset_type", None)
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.test_batch_size = test_batch_size
        self.repeat_num = repeat_num
        assert dataset_type is not None, "Dataset type must be specified in dataset config"

        if audio_channels == 1:
            force_channels = "mono"
        elif audio_channels == 2:
            force_channels = "stereo"
        else:
            # Presumably first-order ambisonics; any channel count other than 1/2 maps here.
            force_channels = "foa"

        # Validation/test sections are optional; treat a missing section as empty.
        val_dir_configs = dataset_config.get("val_datasets", None) or []
        test_dir_configs = dataset_config.get("test_datasets", None) or []

        configs = []
        val_configs = []
        test_configs = []
        if dataset_type == "audio_dir":
            audio_dir_configs = dataset_config.get("datasets", None)
            assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]"
            configs = get_configs(audio_dir_configs)
            val_configs = get_configs(val_dir_configs)
            test_configs = get_configs(test_dir_configs)
        elif dataset_type in ("latent_dir", "video_dataset"):
            audio_dir_configs = dataset_config.get("datasets", None)
            assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]"
            configs = [self._make_local_config(c) for c in audio_dir_configs]
            val_configs = [self._make_local_config(c) for c in val_dir_configs]
            test_configs = [self._make_local_config(c) for c in test_dir_configs]
        elif dataset_type == "multimodal_dir":
            audio_dir_configs = dataset_config.get("audio_datasets", None)
            video_dir_configs = dataset_config.get("video_datasets", None)
            assert audio_dir_configs is not None and video_dir_configs is not None, "Directory configuration must be specified in video_datasets and audio_datasets"
            # Training data for this type lives in audio_configs/video_configs;
            # `configs` stays empty.
            self.audio_configs = [self._make_local_config(c) for c in audio_dir_configs]
            self.video_configs = [self._make_local_config(c) for c in video_dir_configs]
            val_configs = [self._make_local_config(c) for c in val_dir_configs]
            test_configs = [self._make_local_config(c) for c in test_dir_configs]

        self.dataset_type = dataset_type
        self.configs = configs
        self.val_configs = val_configs
        self.test_configs = test_configs
        self.sample_rate = sample_rate
        self.sample_size = sample_size
        self.random_crop = dataset_config.get("random_crop", True)
        self.input_type = dataset_config.get("input_type", "video")
        self.fps = dataset_config.get("fps", 4)
        self.force_channels = force_channels

    @staticmethod
    def _make_local_config(config):
        """Build a ``LocalDatasetConfig`` (including ``extra_cot``) from one entry dict."""
        data_dir_path = config.get("path", None)
        audio_dir_path = config.get("audio_dir", None)
        split_path = config.get("split_path", None)
        assert data_dir_path is not None, "Path must be set for local audio directory configuration"
        return LocalDatasetConfig(
            id=config["id"],
            path=data_dir_path,
            split_path=split_path,
            audio_dir=audio_dir_path,
            extra_cot=config.get("extra_cot", None),
        )

    def _build_dataset(self, dataset_class, configs, random_crop):
        """Instantiate ``dataset_class`` over ``configs`` with this module's shared settings."""
        return dataset_class(
            configs,
            sample_rate=self.sample_rate,
            sample_size=self.sample_size,
            random_crop=random_crop,
            input_type=self.input_type,
            fps=self.fps,
            force_channels=self.force_channels,
        )

    def setup(self, stage: str):
        """Create the datasets needed for ``stage``.

        ``'fit'`` builds ``train_set`` and ``val_set``, ``'validate'`` builds
        ``val_set`` and ``'predict'`` builds ``test_set``; other stages are no-ops.

        Raises:
            ValueError: if ``dataset_type`` is not one of the supported types.
        """
        dataset_classes = {
            'audio_dir': SampleDataset,
            'latent_dir': LatentDataset,
            'video_dataset': VideoDataset,
            'multimodal_dir': VideoDataset,
        }
        dataset_class = dataset_classes.get(self.dataset_type)
        if dataset_class is None:
            raise ValueError(f"Unsupported dataset_type: {self.dataset_type!r}")

        if stage == 'fit':
            if self.dataset_type != 'multimodal_dir':
                self.train_set = self._build_dataset(dataset_class, self.configs, self.random_crop)
            else:
                self.video_set = self._build_dataset(VideoDataset, self.video_configs, self.random_crop)
                self.audio_set = self._build_dataset(AudioDataset, self.audio_configs, self.random_crop)
                # The video set is passed repeat_num times (presumably to up-weight
                # video samples relative to audio-only ones).
                self.train_set = MultiModalDataset([self.video_set] * self.repeat_num, [self.audio_set])
            self.val_set = self._build_dataset(dataset_class, self.val_configs, random_crop=False)
        elif stage == 'validate':
            self.val_set = self._build_dataset(dataset_class, self.val_configs, random_crop=False)
        elif stage == 'predict':
            self.test_set = self._build_dataset(dataset_class, self.test_configs, random_crop=False)

    def train_dataloader(self):
        """Shuffled training loader; incomplete final batches are dropped."""
        return DataLoader(self.train_set, self.batch_size, shuffle=True,
                          num_workers=self.num_workers,
                          # DataLoader rejects persistent_workers=True when num_workers == 0.
                          persistent_workers=self.num_workers > 0,
                          pin_memory=True, drop_last=True, collate_fn=collation_fn)

    def val_dataloader(self):
        """Deterministic validation loader using the training batch size."""
        return DataLoader(self.val_set, self.batch_size, shuffle=False,
                          num_workers=self.num_workers, persistent_workers=False,
                          pin_memory=False, drop_last=False, collate_fn=collation_fn)

    def predict_dataloader(self):
        """Deterministic prediction loader over ``test_set`` using ``test_batch_size``."""
        return DataLoader(self.test_set, batch_size=self.test_batch_size, shuffle=False,
                          num_workers=self.num_workers, persistent_workers=False,
                          pin_memory=False, drop_last=False, collate_fn=collation_fn)