import random
from dataclasses import dataclass
from typing import Callable

import numpy as np
import torch
from lightning.pytorch import LightningDataModule
from torch import Generator, nn
from torch.utils.data import DataLoader, Dataset, IterableDataset

from src.dataset import *
from src.global_cfg import get_cfg


from ..misc.step_tracker import StepTracker
from ..misc.utils import get_world_size, get_rank
from . import DatasetCfgWrapper, get_dataset
from .types import DataShim, Stage
from .data_sampler import BatchedRandomSampler, MixedBatchSampler, custom_collate_fn
from .validation_wrapper import ValidationWrapper

def get_data_shim(encoder: nn.Module) -> DataShim:
    """Get functions that modify the batch. It's sometimes necessary to modify batches
    outside the data loader because GPU computations are required to modify the batch or
    because the modification depends on something outside the data loader.
    """

    shims: list[DataShim] = []
    if hasattr(encoder, "get_data_shim"):
        shims.append(encoder.get_data_shim())

    def combined_shim(batch):
        for shim in shims:
            batch = shim(batch)
        return batch

    return combined_shim
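
# Typical use (sketch only; the exact call site lives in the model wrapper code):
#   data_shim = get_data_shim(encoder)
#   batch = data_shim(batch)  # e.g. applied after the batch has been moved to the GPU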

# The sampling ratio of each dataset during training (example values).
prob_mapping = {DatasetScannetpp: 0.5, 
                DatasetDL3DV: 0.5,
                DatasetCo3d: 0.5}

@dataclass
class DataLoaderStageCfg:
    batch_size: int
    num_workers: int
    persistent_workers: bool
    seed: int | None


@dataclass
class DataLoaderCfg:
    train: DataLoaderStageCfg
    test: DataLoaderStageCfg
    val: DataLoaderStageCfg


# A DatasetShim maps a dataset and the current stage to a (possibly modified) dataset.
DatasetShim = Callable[[Dataset, Stage], Dataset]


def worker_init_fn(worker_id: int) -> None:
    # Derive a per-worker seed from torch's worker seed so that Python's and NumPy's
    # RNGs differ across dataloader workers.
    worker_seed = int(torch.utils.data.get_worker_info().seed) % (2**32 - 1)
    random.seed(worker_seed)
    np.random.seed(worker_seed)


class DataModule(LightningDataModule):
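    """Lightning data module that builds the train, validation, and test dataloaders for
    the configured datasets. Train and validation batches are drawn through a
    MixedBatchSampler with per-dataset sampling probabilities; the test split uses a
    plain DataLoader."""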
    dataset_cfgs: list[DatasetCfgWrapper]
    data_loader_cfg: DataLoaderCfg
    step_tracker: StepTracker | None
    dataset_shim: DatasetShim
    global_rank: int
    
    def __init__(
        self,
        dataset_cfgs: list[DatasetCfgWrapper],
        data_loader_cfg: DataLoaderCfg,
        step_tracker: StepTracker | None = None,
        dataset_shim: DatasetShim = lambda dataset, _: dataset,
        global_rank: int = 0,
    ) -> None:
        super().__init__()
        self.dataset_cfgs = dataset_cfgs
        self.data_loader_cfg = data_loader_cfg
        self.step_tracker = step_tracker
        self.dataset_shim = dataset_shim
        self.global_rank = global_rank
        
    def get_persistent(self, loader_cfg: DataLoaderStageCfg) -> bool | None:
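        """Persistent workers only make sense when num_workers > 0; otherwise fall back to the DataLoader default."""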
        return None if loader_cfg.num_workers == 0 else loader_cfg.persistent_workers

    def get_generator(self, loader_cfg: DataLoaderStageCfg) -> torch.Generator | None:
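        """Build a generator seeded with the configured seed plus the global rank so that
        each rank draws a different sample stream; returns None when no seed is set."""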
        if loader_cfg.seed is None:
            return None
        generator = Generator()
        generator.manual_seed(loader_cfg.seed + self.global_rank)
        self.generator = generator
        return self.generator
        
    def train_dataloader(self):
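        """Build the training dataloader. Batches are drawn through MixedBatchSampler,
        mixing datasets according to prob_mapping when more than one is configured."""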
        dataset, datasets_ls = get_dataset(self.dataset_cfgs, "train", self.step_tracker, self.dataset_shim)
        world_size = get_world_size()
        rank = get_rank()
        prob_ls = [prob_mapping[type(dataset)] for dataset in datasets_ls]
        # We assume all the datasets share the same num_context_views.
        
        if len(datasets_ls) > 1:
            prob = prob_ls
            context_num_views = [dataset.cfg.view_sampler.num_context_views for dataset in datasets_ls]
        else:
            prob = None
            dataset_key = next(iter(get_cfg()["dataset"]))
            dataset_cfg = get_cfg()["dataset"][dataset_key]
            context_num_views = dataset_cfg['view_sampler']['num_context_views']
            
        generator = self.get_generator(self.data_loader_cfg.train)
        sampler = MixedBatchSampler(datasets_ls,
                                    batch_size=self.data_loader_cfg.train.batch_size,  # not used by the sampler
                                    num_context_views=context_num_views,
                                    world_size=world_size,
                                    rank=rank,
                                    prob=prob,
                                    generator=generator)
        sampler.set_epoch(0)
        self.train_loader = DataLoader(
            dataset,
            # self.data_loader_cfg.train.batch_size,
            # shuffle=not isinstance(dataset, IterableDataset),
            batch_sampler=sampler,
            num_workers=self.data_loader_cfg.train.num_workers,
            generator=generator,
            worker_init_fn=worker_init_fn,
            # collate_fn=custom_collate_fn,
            persistent_workers=self.get_persistent(self.data_loader_cfg.train),
        )
        # Set the epoch on the train dataset and sampler if they support it.
        if hasattr(self.train_loader, "dataset") and hasattr(self.train_loader.dataset, "set_epoch"):
            print("Training: Set Epoch in DataModule")
            self.train_loader.dataset.set_epoch(0)
        if hasattr(self.train_loader, "sampler") and hasattr(self.train_loader.sampler, "set_epoch"):
            print("Training: Set Epoch in DataModule")
            self.train_loader.sampler.set_epoch(0)
        
        return self.train_loader

    def val_dataloader(self):
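        """Build the validation dataloader using the same mixed-batch sampling as training."""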
        dataset, datasets_ls = get_dataset(self.dataset_cfgs, "val", self.step_tracker, self.dataset_shim)
        world_size = get_world_size()
        rank = get_rank()
        # Pick one dataset entry from the config to read the view-sampler settings for validation.
        dataset_key = next(iter(get_cfg()["dataset"]))
        dataset_cfg = get_cfg()["dataset"][dataset_key]
        if len(datasets_ls) > 1:
            prob = [0.5] * len(datasets_ls)
        else:
            prob = None
        generator = self.get_generator(self.data_loader_cfg.val)
        sampler = MixedBatchSampler(datasets_ls,
                                    batch_size=self.data_loader_cfg.val.batch_size,  # not used by the sampler
                                    num_context_views=dataset_cfg['view_sampler']['num_context_views'],
                                    world_size=world_size,
                                    rank=rank,
                                    prob=prob,
                                    generator=generator)
        sampler.set_epoch(0)
        self.val_loader = DataLoader(
            dataset,
            batch_sampler=sampler,
            num_workers=self.data_loader_cfg.val.num_workers,
            generator=generator,
            worker_init_fn=worker_init_fn,
            persistent_workers=self.get_persistent(self.data_loader_cfg.val),
        )
        if hasattr(self.val_loader, "dataset") and hasattr(self.val_loader.dataset, "set_epoch"):
            print("Validation: Set Epoch in DataModule")
            self.val_loader.dataset.set_epoch(0)
        if hasattr(self.val_loader, "sampler") and hasattr(self.val_loader.sampler, "set_epoch"):
            print("Validation: Set Epoch in DataModule")
            self.val_loader.sampler.set_epoch(0)
        return self.val_loader

    def test_dataloader(self):
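        """Build a plain (non-mixed) dataloader over the test split."""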
        dataset, _ = get_dataset(self.dataset_cfgs, "test", self.step_tracker, self.dataset_shim)
        data_loader = DataLoader(
            dataset,
            self.data_loader_cfg.test.batch_size,
            num_workers=self.data_loader_cfg.test.num_workers,
            generator=self.get_generator(self.data_loader_cfg.test),
            worker_init_fn=worker_init_fn,
            persistent_workers=self.get_persistent(self.data_loader_cfg.test),
        )
            
        return data_loader
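
# Example wiring (illustrative sketch only; the real dataset configs and trainer are
# constructed elsewhere in this repository and may differ):
#
#   loader_cfg = DataLoaderCfg(
#       train=DataLoaderStageCfg(batch_size=4, num_workers=8, persistent_workers=True, seed=1234),
#       val=DataLoaderStageCfg(batch_size=1, num_workers=4, persistent_workers=True, seed=1234),
#       test=DataLoaderStageCfg(batch_size=1, num_workers=4, persistent_workers=False, seed=1234),
#   )
#   data_module = DataModule(dataset_cfgs, loader_cfg, step_tracker, global_rank=trainer.global_rank)
#   trainer.fit(model, datamodule=data_module)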