libokj committed
Commit 4537a38 · 1 Parent(s): 03dbdce

Delete deepscreen/data/single_entity.py

Files changed (1)
  1. deepscreen/data/single_entity.py +0 -195
deepscreen/data/single_entity.py DELETED
@@ -1,195 +0,0 @@
- from numbers import Number
- from pathlib import Path
- from typing import Any, Callable, Dict, Optional, Sequence, Union, Literal
-
- from lightning import LightningDataModule
- from sklearn.base import TransformerMixin
- from torch.utils.data import DataLoader, random_split
-
- from deepscreen.data.utils.dataset import SingleEntitySingleTargetDataset, BaseEntityDataset
- from deepscreen.data.utils.collator import collate_fn
- from deepscreen.data.utils.sampler import SafeBatchSampler
-
-
- class EntityDataModule(LightningDataModule):
-     """
-     DTI DataModule
-
-     A DataModule implements 6 key methods:
-
-         def prepare_data(self):
-             # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
-             # download data, pre-process, split, save to disk, etc.
-         def setup(self, stage):
-             # things to do on every process in DDP
-             # load data, set variables, etc.
-         def train_dataloader(self):
-             # return train dataloader
-         def val_dataloader(self):
-             # return validation dataloader
-         def test_dataloader(self):
-             # return test dataloader
-         def teardown(self):
-             # called on every process in DDP
-             # clean up after fit or test
-
-     This allows you to share a full dataset without explaining how to download,
-     split, transform and process the data.
-
-     Read the docs:
-         https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
-     """
-
-     def __init__(
-             self,
-             dataset: type[BaseEntityDataset],
-             task: Literal['regression', 'binary', 'multiclass'],
-             n_classes: Optional[int],
-             train: bool,
-             batch_size: int,
-             num_workers: int = 0,
-             thresholds: Optional[Union[Number, Sequence[Number]]] = None,
-             pin_memory: bool = False,
-             data_dir: str = "data/",
-             data_file: Optional[str] = None,
-             train_val_test_split: Optional[Union[Sequence[Number], Sequence[str]]] = None,
-             split: Optional[Callable] = random_split,
-             # read back in `setup()` via `self.hparams`
-             transformer: Optional[TransformerMixin] = None,
-             featurizer: Optional[Callable] = None,
-     ):
-         super().__init__()
-         # allow access to init params with the `self.hparams` attribute
-         # and ensure init params are stored in the ckpt
-         self.save_hyperparameters(logger=False)
-
-         # data processing
-         self.split = split
-         self.data_train = self.data_val = self.data_test = self.data_predict = None
-
-         if train:
-             if all([data_file, split, train_val_test_split]):
-                 if not all(isinstance(n, Number) for n in train_val_test_split):
-                     raise ValueError('`train_val_test_split` must be a sequence of 3 numbers '
-                                      '(float for percentages and int for sample numbers) if '
-                                      '`data_file` and `split` have been specified.')
-             elif (train_val_test_split and all(isinstance(n, str) for n in train_val_test_split)
-                   and not any([data_file, split])):
-                 self.data_train = dataset(dataset_path=str(Path(data_dir) / train_val_test_split[0]))
-                 self.data_val = dataset(dataset_path=str(Path(data_dir) / train_val_test_split[1]))
-                 self.data_test = dataset(dataset_path=str(Path(data_dir) / train_val_test_split[2]))
-             else:
-                 raise ValueError('For training (train=True), you must specify either '
-                                  '`data_file` and `split` with `train_val_test_split` of 3 numbers or '
-                                  'solely `train_val_test_split` of 3 data file names.')
-         else:
-             if data_file and not any([split, train_val_test_split]):
-                 self.data_test = self.data_predict = dataset(dataset_path=str(Path(data_dir) / data_file))
-             else:
-                 raise ValueError('For testing/predicting (train=False), you must specify only `data_file` '
-                                  'without `train_val_test_split` or `split`.')
-
-     def prepare_data(self):
-         """
-         Download data if needed.
-         Do not use this method to assign state (e.g., self.x = x).
-         """
-
-     def setup(self, stage: Optional[str] = None):
-         """
-         Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
-         This method is called by Lightning with both `trainer.fit()` and `trainer.test()`, so be
-         careful not to execute data splitting twice.
-         """
-         # load and split datasets only if they were not loaded during initialization
-         if not any([self.data_train, self.data_val, self.data_test, self.data_predict]):
-             dataset = SingleEntitySingleTargetDataset(
-                 task=self.hparams.task,
-                 n_classes=self.hparams.n_classes,
-                 dataset_path=Path(self.hparams.data_dir) / self.hparams.data_file,
-                 transformer=self.hparams.transformer,
-                 featurizer=self.hparams.featurizer,
-                 thresholds=self.hparams.thresholds,
-             )
-
-             if self.hparams.train:
-                 self.data_train, self.data_val, self.data_test = self.split(
-                     dataset=dataset,
-                     lengths=self.hparams.train_val_test_split
-                 )
-             else:
-                 self.data_test = self.data_predict = dataset
-
-     def _dataloader(self, dataset, shuffle: bool) -> DataLoader:
-         """Shared DataLoader construction for the train/val/test/predict splits."""
-         return DataLoader(
-             dataset=dataset,
-             batch_sampler=SafeBatchSampler(
-                 data_source=dataset,
-                 batch_size=self.hparams.batch_size,
-                 shuffle=shuffle),
-             num_workers=self.hparams.num_workers,
-             pin_memory=self.hparams.pin_memory,
-             collate_fn=collate_fn,
-             persistent_workers=self.hparams.num_workers > 0
-         )
-
-     def train_dataloader(self):
-         return self._dataloader(self.data_train, shuffle=True)
-
-     def val_dataloader(self):
-         return self._dataloader(self.data_val, shuffle=False)
-
-     def test_dataloader(self):
-         return self._dataloader(self.data_test, shuffle=False)
-
-     def predict_dataloader(self):
-         return self._dataloader(self.data_predict, shuffle=False)
-
-     def teardown(self, stage: Optional[str] = None):
-         """Clean up after fit or test."""
-
-     def state_dict(self):
-         """Extra things to save to checkpoint."""
-         return {}
-
-     def load_state_dict(self, state_dict: Dict[str, Any]):
-         """Things to do when loading checkpoint."""
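
For context, here is a minimal usage sketch of the module as it existed before this deletion, following the DataModule contract described in its docstring. The data file name, batch size, and trainer settings are illustrative assumptions, not taken from the repository:

from lightning import Trainer
from torch.utils.data import random_split

from deepscreen.data.single_entity import EntityDataModule  # this file, prior to this commit
from deepscreen.data.utils.dataset import SingleEntitySingleTargetDataset

# Train mode: a single data file, split 80/10/10 by `random_split` in `setup()`.
datamodule = EntityDataModule(
    dataset=SingleEntitySingleTargetDataset,
    task='binary',
    n_classes=2,
    train=True,
    batch_size=32,
    data_dir='data/',
    data_file='example.csv',  # illustrative file name
    train_val_test_split=[0.8, 0.1, 0.1],
    split=random_split,
)

trainer = Trainer(max_epochs=10)
# trainer.fit(model, datamodule=datamodule)  # `model`: any compatible LightningModule

Note the design choice in the dataloaders: because a `batch_sampler` is supplied, `batch_size` and `shuffle` must not also be passed to `DataLoader` itself, so both are routed through `SafeBatchSampler` instead.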