libokj committed on
Commit 9d2a161 · 1 Parent(s): 57941d6

Delete deepscreen/data/dti_datamodule.py

Files changed (1)
  1. deepscreen/data/dti_datamodule.py +0 -314
deepscreen/data/dti_datamodule.py DELETED
@@ -1,314 +0,0 @@
- # from itertools import product
- from collections import namedtuple
- from numbers import Number
- from typing import Any, Dict, Optional, Sequence, Union, Literal
-
- # import numpy as np
- import pandas as pd
- from lightning import LightningDataModule
- from torch.utils.data import Dataset, DataLoader, random_split
-
- from deepscreen.data.utils.label import label_transform
- from deepscreen.data.utils.collator import collate_fn
- from deepscreen.data.utils.sampler import SafeBatchSampler
-
-
- class DTIDataset(Dataset):
-     def __init__(
-             self,
-             task: Literal['regression', 'binary', 'multiclass'],
-             n_classes: Optional[int],
-             data_dir: str,
-             dataset_name: str,
-             drug_featurizer: callable,
-             protein_featurizer: callable,
-             thresholds: Optional[Union[Number, Sequence[Number]]] = None,
-             discard_intermediate: Optional[bool] = False,
-     ):
-         df = pd.read_csv(
-             f'{data_dir}{dataset_name}.csv',
-             header=0, sep=',',
-             usecols=lambda x: x in ['X1', 'ID1', 'X2', 'ID2', 'Y', 'U'],
-             dtype={'X1': 'str', 'ID1': 'str',
-                    'X2': 'str', 'ID2': 'str',
-                    'Y': 'float32', 'U': 'str'}
-         )
-         # if 'ID1' in df:
-         #     self.x1_to_id1 = dict(zip(df['X1'], df['ID1']))
-         # if 'ID2' in df:
-         #     self.x2_to_id2 = dict(zip(df['X2'], df['ID2']))
-         #     self.id2_to_indexes = dict(zip(df['ID2'], range(len(df['ID2']))))
-         #     self.x2_to_indexes = dict(zip(df['X2'], range(len(df['X2']))))
-
-         # # train and eval mode data processing (fully labelled)
-         # if 'Y' in df.columns and df['Y'].notnull().all():
-
-         # Forward-fill all non-label columns
-         df.loc[:, df.columns != 'Y'] = df.loc[:, df.columns != 'Y'].ffill(axis=0)
-
-         if 'Y' in df:
-             # Transform labels
-             df['Y'] = df['Y'].apply(label_transform, units=df.get('U', None), thresholds=thresholds,
-                                     discard_intermediate=discard_intermediate).astype('float32')
-
-             # Filter out rows with a NaN in Y (missing values)
-             df.dropna(subset=['Y'], inplace=True)
-
-             # Validate target labels for training/testing
-             # TODO: check sklearn.utils.multiclass.check_classification_targets
-             match task:
-                 case 'regression':
-                     assert all(df['Y'].apply(lambda x: isinstance(x, Number))), \
-                         f"Y for task `regression` must be numeric; got {set(df['Y'].apply(type))}."
-                 case 'binary':
-                     assert all(df['Y'].isin([0, 1])), \
-                         f"Y for task `binary` (classification) must be 0 or 1, but Y got {pd.unique(df['Y'])}." \
-                         "\nYou may set `thresholds` to discretize continuous labels."
-                 case 'multiclass':
-                     assert n_classes >= 3, f'n_classes for task `multiclass` (classification) must be at least 3.'
-                     assert all(df['Y'].apply(lambda x: x.is_integer() and x >= 0)), \
-                         f"Y for task `multiclass` (classification) must be non-negative integers, " \
-                         f"but Y got {pd.unique(df['Y'])}." \
-                         "\nYou may set `thresholds` to discretize continuous labels."
-                     target_n_unique = df['Y'].nunique()
-                     assert target_n_unique == n_classes, \
-                         f"You have set n_classes for task `multiclass` (classification) task to {n_classes}, " \
-                         f"but Y has {target_n_unique} unique labels."
-
-         # # Predict mode data processing
-         # else:
-         #     df = pd.DataFrame(product(df['X1'].dropna(), df['X2'].dropna()), columns=['X1', 'X2'])
-         #     if hasattr(self, "x1_to_id1"):
-         #         df['ID1'] = df['X1'].map(self.x1_to_id1)
-         #     if hasattr(self, "x1_to_id2"):
-         #         df['ID2'] = df['X2'].map(self.x2_to_id2)
-
-         # self.smiles = df['X1']
-         # self.fasta = df['X2']
-         # self.smiles_ids = df.get('ID1', df['X1'])
-         # self.fasta_ids = df.get('ID2', df['X2'])
-         # self.labels = df.get('Y', None)
-
-         self.df = df
-         self.drug_featurizer = drug_featurizer if drug_featurizer is not None else (lambda x: x)
-         self.protein_featurizer = protein_featurizer if protein_featurizer is not None else (lambda x: x)
-         self.n_classes = df['Y'].nunique()
-         # self.train = train
-
-         self.Data = namedtuple('Data', ['FT1', 'ID1', 'FT2', 'ID2', 'Y'])
-
-     def __len__(self):
-         return len(self.df.index)
-
-     def __getitem__(self, idx):
-         sample = self.df.loc[idx]
-         return self.Data(
-             FT1=self.drug_featurizer(sample['X1']),
-             ID1=sample.get('ID1', sample['X1']),
-             FT2=self.protein_featurizer(sample['X2']),
-             ID2=sample.get('ID2', sample['X2']),
-             Y=sample.get('Y')
-         )
-         # {
-         #     'FT1': self.drug_featurizer(sample['X1']),
-         #     'ID1': sample.get('ID1', sample['X1']),
-         #     'FT2': self.protein_featurizer(sample['X2']),
-         #     'ID2': sample.get('ID2', sample['X2']),
-         #     'Y': sample.get('Y')
-         # }
-         # if self.train:
-         #     sample = self.drug_featurizer(self.smiles[idx]), self.protein_featurizer(self.fasta[idx]), self.labels[idx]
-         #     sample = {
-         #         'FT1': self.drug_featurizer(self.smiles[idx]),
-         #         'FT2': self.protein_featurizer(self.fasta[idx]),
-         #         'ID2': self.smiles_ids[idx],
-         #     }
-         # else:
-         #     # sample = self.drug_featurizer(self.smiles[idx]), self.protein_featurizer(self.fasta[idx])
-         #     sample = {
-         #         'FT1': self.drug_featurizer(self.smiles[idx]),
-         #         'FT2': self.protein_featurizer(self.fasta[idx]),
-         #     }
-         #
-         # if all([True if n is not None else False for n in sample.values()]):
-         #     return sample  # | {
-         #     #     'ID1': self.smiles_ids[idx],
-         #     #     'X1': self.drug_featurizer(self.smiles[idx]),
-         #     #     'ID2': self.fasta_ids[idx],
-         #     #     'X2': self.protein_featurizer(self.fasta[idx]),
-         #     # }
-         # else:
-         #     return self.__getitem__(np.random.randint(0, self.size))
-
-
- class DTIdatamodule(LightningDataModule):
-     """
-     DTI DataModule
-
-     A DataModule implements 5 key methods:
-
-         def prepare_data(self):
-             # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
-             # download data, pre-process, split, save to disk, etc.
-         def setup(self, stage):
-             # things to do on every process in DDP
-             # load data, set variables, etc.
-         def train_dataloader(self):
-             # return train dataloader
-         def val_dataloader(self):
-             # return validation dataloader
-         def test_dataloader(self):
-             # return test dataloader
-         def teardown(self):
-             # called on every process in DDP
-             # clean up after fit or test
-
-     This allows you to share a full dataset without explaining how to download,
-     split, transform and process the data.
-
-     Read the docs:
-         https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
-     """
-
-     def __init__(
-             self,
-             task: Literal['regression', 'binary', 'multiclass'],
-             n_classes: Optional[int],
-             train: bool,
-             drug_featurizer: callable,
-             protein_featurizer: callable,
-             batch_size: int,
-             train_val_test_split: Optional[Sequence[Number]],
-             num_workers: int = 0,
-             thresholds: Optional[Union[Number, Sequence[Number]]] = None,
-             pin_memory: bool = False,
-             data_dir: str = "data/",
-             dataset_name: Optional[str] = None,
-             split: Optional[callable] = random_split,
-     ):
-         super().__init__()
-
-         # this line allows to access init params with 'self.hparams' attribute
-         # also ensures init params will be stored in ckpt
-         self.save_hyperparameters(logger=False)
-
-         # data processing
-         self.data_split = split
-
-         self.data_train: Optional[Dataset] = None
-         self.data_val: Optional[Dataset] = None
-         self.data_test: Optional[Dataset] = None
-         self.data_predict: Optional[Dataset] = None
-
-     def prepare_data(self):
-         """
-         Download data if needed.
-         Do not use it to assign state (e.g., self.x = x).
-         """
-
-     def setup(self, stage: Optional[str] = None, encoding: str = None):
-         """
-         Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
-         This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be
-         careful not to execute data splitting twice.
-         """
-         # TODO test SafeBatchSampler (which skips samples with any None without introducing variable batch size)
-         # load and split datasets only if not loaded in initialization
-         if not any([self.data_train, self.data_val, self.data_test, self.data_predict]):
-             dataset = DTIDataset(
-                 task=self.hparams.task,
-                 n_classes=self.hparams.n_classes,
-                 data_dir=self.hparams.data_dir,
-                 drug_featurizer=self.hparams.drug_featurizer,
-                 protein_featurizer=self.hparams.protein_featurizer,
-                 dataset_name=self.hparams.dataset_name,
-                 thresholds=self.hparams.thresholds,
-             )
-
-             if self.hparams.train:
-                 self.data_train, self.data_val, self.data_test = self.data_split(
-                     dataset=dataset,
-                     lengths=self.hparams.train_val_test_split
-                 )
-             else:
-                 self.data_test = self.data_predict = dataset
-
-     def train_dataloader(self):
-         return DataLoader(
-             dataset=self.data_train,
-             batch_sampler=SafeBatchSampler(
-                 data_source=self.data_train,
-                 batch_size=self.hparams.batch_size,
-                 drop_last=True,
-                 shuffle=True,
-             ),
-             # batch_size=self.hparams.batch_size,
-             # shuffle=True,
-             num_workers=self.hparams.num_workers,
-             pin_memory=self.hparams.pin_memory,
-             collate_fn=collate_fn,
-             persistent_workers=True if self.hparams.num_workers > 0 else False
-         )
-
-     def val_dataloader(self):
-         return DataLoader(
-             dataset=self.data_val,
-             batch_sampler=SafeBatchSampler(
-                 data_source=self.data_val,
-                 batch_size=self.hparams.batch_size,
-                 drop_last=False,
-                 shuffle=False,
-             ),
-             # batch_size=self.hparams.batch_size,
-             # shuffle=False,
-             num_workers=self.hparams.num_workers,
-             pin_memory=self.hparams.pin_memory,
-             collate_fn=collate_fn,
-             persistent_workers=True if self.hparams.num_workers > 0 else False
-         )
-
-     def test_dataloader(self):
-         return DataLoader(
-             dataset=self.data_test,
-             batch_sampler=SafeBatchSampler(
-                 data_source=self.data_test,
-                 batch_size=self.hparams.batch_size,
-                 drop_last=False,
-                 shuffle=False,
-             ),
-             # batch_size=self.hparams.batch_size,
-             # shuffle=False,
-             num_workers=self.hparams.num_workers,
-             pin_memory=self.hparams.pin_memory,
-             collate_fn=collate_fn,
-             persistent_workers=True if self.hparams.num_workers > 0 else False
-         )
-
-     def predict_dataloader(self):
-         return DataLoader(
-             dataset=self.data_predict,
-             batch_sampler=SafeBatchSampler(
-                 data_source=self.data_predict,
-                 batch_size=self.hparams.batch_size,
-                 drop_last=False,
-                 shuffle=False,
-             ),
-             # batch_size=self.hparams.batch_size,
-             # shuffle=False,
-             num_workers=self.hparams.num_workers,
-             pin_memory=self.hparams.pin_memory,
-             collate_fn=collate_fn,
-             persistent_workers=True if self.hparams.num_workers > 0 else False
-         )
-
-     def teardown(self, stage: Optional[str] = None):
-         """Clean up after fit or test."""
-         pass
-
-     def state_dict(self):
-         """Extra things to save to checkpoint."""
-         return {}
-
-     def load_state_dict(self, state_dict: Dict[str, Any]):
-         """Things to do when loading checkpoint."""
-         pass
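
For reference, a minimal usage sketch of the removed datamodule, as it could be driven at parent revision 57941d6 (the last revision still containing this file). Everything below is illustrative, not taken from the repo: the dataset name, split fractions, threshold, and `model` placeholder are assumptions; only the constructor signature comes from the code above.

# Hypothetical usage sketch against parent 57941d6; this import no longer exists
# after this commit.
from lightning import Trainer

from deepscreen.data.dti_datamodule import DTIdatamodule

datamodule = DTIdatamodule(
    task='binary',
    n_classes=2,
    train=True,
    drug_featurizer=None,                  # DTIDataset falls back to an identity function for None
    protein_featurizer=None,               # likewise
    batch_size=64,
    train_val_test_split=[0.8, 0.1, 0.1],  # forwarded to torch.utils.data.random_split (fractions need torch >= 1.13)
    num_workers=4,
    thresholds=7.0,                        # illustrative cutoff for binarizing continuous labels
    data_dir='data/',                      # note the plain concatenation: data_dir + dataset_name + '.csv'
    dataset_name='example_dti',            # assumed file data/example_dti.csv with X1 (SMILES), X2 (FASTA), Y columns
)

model = ...  # any LightningModule consuming the collated (FT1, ID1, FT2, ID2, Y) batches
Trainer(max_epochs=10).fit(model=model, datamodule=datamodule)

With `train=False`, `setup` skips splitting and instead serves the entire CSV through `test_dataloader` and `predict_dataloader`.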