libokj committed on
Commit
03dbdce
·
1 Parent(s): 9d2a161

Delete deepscreen/data/entity_datamodule.py

Browse files
Files changed (1) hide show
  1. deepscreen/data/entity_datamodule.py +0 -167
deepscreen/data/entity_datamodule.py DELETED
@@ -1,167 +0,0 @@
1
- from numbers import Number
2
- from pathlib import Path
3
- from typing import Any, Dict, Optional, Sequence, Type
4
-
5
- from lightning import LightningDataModule
6
- from sklearn.base import TransformerMixin
7
- from torch.utils.data import Dataset, DataLoader
8
-
9
- from deepscreen.data.utils import collate_fn, SafeBatchSampler
10
- from deepscreen.data.utils.dataset import BaseEntityDataset
11
-
12
-
13
class EntityDataModule(LightningDataModule):
    """LightningDataModule wrapping a single ``BaseEntityDataset``.

    Two usage modes, selected by ``train``:

    * ``train=True`` — either
        - pass ``data_file`` plus a ``split`` callable and a
          ``train_val_test_split`` of 3 numbers (the actual splitting is
          deferred to :meth:`setup`), or
        - pass only ``train_val_test_split`` as 3 data file names, which are
          loaded eagerly in ``__init__``.
    * ``train=False`` — pass only ``data_file``; the same dataset instance is
      used for both ``test_data`` and ``predict_data``.
    """

    def __init__(
            self,
            dataset: type[BaseEntityDataset],
            transformer: type[TransformerMixin],
            train: bool,
            batch_size: int,
            data_dir: str = "data/",
            data_file: Optional[str] = None,
            train_val_test_split: Optional[Sequence] = None,
            split: Optional[callable] = None,
            num_workers: int = 0,
            pin_memory: bool = False,
    ):
        """
        Args:
            dataset: dataset class instantiated with ``dataset_path=...``.
            transformer: sklearn-style transformer class (stored in hparams;
                not applied here — see the TODO in :meth:`setup`).
            train: whether this module is configured for training.
            batch_size: samples per batch.
            data_dir: directory containing the data files.
            data_file: single data file name (train-and-split or test/predict).
            train_val_test_split: either 3 numbers (with ``data_file`` and
                ``split``) or 3 data file names (alone).
            split: callable performing the train/val/test split in ``setup``.
            num_workers: DataLoader worker count.
            pin_memory: pass-through to DataLoader.

        Raises:
            ValueError: if the argument combination matches neither mode.
        """
        super().__init__()

        # Initialize all dataset slots up front so `setup` can safely test
        # them regardless of which branch below runs (the original code left
        # some unset, crashing `setup` with AttributeError).
        self.train_data = None
        self.val_data = None
        self.test_data = None
        self.predict_data = None

        # data processing
        self.split = split

        if train:
            if data_file and split:
                # Splitting is deferred to `setup`; only validate lengths here.
                # NOTE: loop variable must not shadow the `split` parameter.
                if not (train_val_test_split is not None
                        and all(isinstance(n, Number) for n in train_val_test_split)):
                    raise ValueError('`train_val_test_split` must be a sequence of 3 numbers '
                                     '(float for percentages and int for sample numbers) if '
                                     '`data_file` and `split` have been specified.')
            elif (train_val_test_split is not None
                  and all(isinstance(name, str) for name in train_val_test_split)
                  and not any([data_file, split])):
                # Pre-split mode: each entry names its own data file.
                self.train_data = dataset(dataset_path=str(Path(data_dir) / train_val_test_split[0]))
                self.val_data = dataset(dataset_path=str(Path(data_dir) / train_val_test_split[1]))
                self.test_data = dataset(dataset_path=str(Path(data_dir) / train_val_test_split[2]))
            else:
                raise ValueError('For training (train=True), you must specify either '
                                 '`dataset_name` and `split` with `train_val_test_split` of 3 numbers or '
                                 'solely `train_val_test_split` of 3 data file names.')
        else:
            if data_file and not any([split, train_val_test_split]):
                self.test_data = self.predict_data = dataset(dataset_path=str(Path(data_dir) / data_file))
            else:
                raise ValueError("For testing/predicting (train=False), you must specify only `data_file` without "
                                 "`train_val_test_split` or `split`")

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

    def prepare_data(self):
        """Download data if needed.

        Runs on a single process; do not assign state here (e.g. ``self.x = x``).
        """

    def setup(self, stage: Optional[str] = None, encoding: str = None):
        """Load and split data into ``train_data``/``val_data``/``test_data``.

        Called by Lightning for both ``trainer.fit()`` and ``trainer.test()``,
        so the split only runs if no dataset was loaded in ``__init__``.
        """
        # TODO test SafeBatchSampler (which skips samples with any None without introducing variable batch size)
        # TODO: find a way to apply transformer.fit_transform only to train and transformer.transform only to val, test
        # load and split datasets only if not loaded in initialization
        if not any([self.train_data, self.val_data, self.test_data, self.predict_data]):
            # Build the full dataset from `data_file` (the hparam actually
            # passed to __init__) using the same `dataset_path=` convention as
            # __init__; the original referenced a nonexistent
            # `train_dataset_name` hparam, which always raised AttributeError.
            self.train_data, self.val_data, self.test_data = self.split(
                dataset=self.hparams.dataset(
                    dataset_path=str(Path(self.hparams.data_dir) / self.hparams.data_file)),
                lengths=self.hparams.train_val_test_split
            )

    def _dataloader(self, data: Dataset, shuffle: bool) -> DataLoader:
        """Build a DataLoader over `data` (shared by all four loaders below)."""
        return DataLoader(
            dataset=data,
            # SafeBatchSampler is expected to skip unloadable (None) samples
            # without introducing a variable batch size — see TODO in `setup`.
            batch_sampler=SafeBatchSampler(
                data_source=data,
                batch_size=self.hparams.batch_size,
                shuffle=shuffle),
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=collate_fn,
            # Persistent workers avoid re-forking between epochs, but only
            # make sense when there are workers at all.
            persistent_workers=self.hparams.num_workers > 0,
        )

    def train_dataloader(self):
        """Return the training dataloader (shuffled)."""
        return self._dataloader(self.train_data, shuffle=True)

    def val_dataloader(self):
        """Return the validation dataloader (unshuffled)."""
        return self._dataloader(self.val_data, shuffle=False)

    def test_dataloader(self):
        """Return the test dataloader (unshuffled)."""
        return self._dataloader(self.test_data, shuffle=False)

    def predict_dataloader(self):
        """Return the prediction dataloader (unshuffled)."""
        return self._dataloader(self.predict_data, shuffle=False)

    def teardown(self, stage: Optional[str] = None):
        """Clean up after fit or test."""
        pass

    def state_dict(self):
        """Extra things to save to checkpoint."""
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Things to do when loading checkpoint."""
        pass