libokj committed
Commit 57941d6 · 1 Parent(s): 6b01fcc

Delete deepscreen/data/dti.py.bak

Files changed (1)
  1. deepscreen/data/dti.py.bak +0 -369
deepscreen/data/dti.py.bak DELETED
@@ -1,369 +0,0 @@
- from functools import partial
- from numbers import Number
- from pathlib import Path
- from typing import Any, Dict, Optional, Sequence, Union, Literal
-
- from lightning import LightningDataModule
- import pandas as pd
- from sklearn.preprocessing import LabelEncoder
- from torch.utils.data import Dataset, DataLoader
-
- from deepscreen.data.utils import label_transform, collate_fn, SafeBatchSampler
- from deepscreen.utils import get_logger
-
- log = get_logger(__name__)
-
-
- # TODO: save a list of corrupted records
-
-
- class DTIDataset(Dataset):
-     def __init__(
-             self,
-             task: Literal['regression', 'binary', 'multiclass'],
-             n_class: Optional[int],
-             data_path: str | Path,
-             drug_featurizer: callable,
-             protein_featurizer: callable,
-             thresholds: Optional[Union[Number, Sequence[Number]]] = None,
-             discard_intermediate: Optional[bool] = False,
-     ):
-         df = pd.read_csv(
-             data_path,
-             engine='python',
-             header=0,
-             usecols=lambda x: x in ['X1', 'ID1', 'X2', 'ID2', 'Y', 'U'],
-             dtype={
-                 'X1': 'str',
-                 'ID1': 'str',
-                 'X2': 'str',
-                 'ID2': 'str',
-                 'Y': 'float32',
-                 'U': 'str',
-             },
-         )
-         # Read the whole data table
-
-         # if 'ID1' in df:
-         #     self.x1_to_id1 = dict(zip(df['X1'], df['ID1']))
-         # if 'ID2' in df:
-         #     self.x2_to_id2 = dict(zip(df['X2'], df['ID2']))
-         #     self.id2_to_indexes = dict(zip(df['ID2'], range(len(df['ID2']))))
-         #     self.x2_to_indexes = dict(zip(df['X2'], range(len(df['X2']))))
-
-         # # train and eval mode data processing (fully labelled)
-         # if 'Y' in df.columns and df['Y'].notnull().all():
-         log.info(f"Processing data file: {data_path}")
-
-         # Forward-fill all non-label columns
-         df.loc[:, df.columns != 'Y'] = df.loc[:, df.columns != 'Y'].ffill(axis=0)
-
-         if 'Y' in df:
-             log.info(f"Performing pre-transformation target validation.")
-             # TODO: check sklearn.utils.multiclass.check_classification_targets
-             match task:
-                 case 'regression':
-                     assert all(df['Y'].apply(lambda x: isinstance(x, Number))), \
-                         f"""`Y` must be numeric for `regression` task,
-                         but it has {set(df['Y'].apply(type))}."""
-
-                 case 'binary':
-                     if all(df['Y'].isin([0, 1])):
-                         assert not thresholds, \
-                             f"""`Y` is already 0 or 1 for `binary` (classification) `task`,
-                             but still got `thresholds` {thresholds}.
-                             Double check your choices of `task` and `thresholds` and records in the `Y` column."""
-                     else:
-                         assert thresholds, \
-                             f"""`Y` must be 0 or 1 for `binary` (classification) `task`,
-                             but it has {pd.unique(df['Y'])}.
-                             You must set `thresholds` to discretize continuous labels."""
-
-                 case 'multiclass':
-                     assert n_class >= 3, f'`n_class` for `multiclass` (classification) `task` must be at least 3.'
-
-                     if all(df['Y'].apply(lambda x: x.is_integer() and x >= 0)):
-                         assert not thresholds, \
-                             f"""`Y` is already non-negative integers for
-                             `multiclass` (classification) `task`, but still got `thresholds` {thresholds}.
-                             Double check your choices of `task` and `thresholds` and records in the `Y` column."""
-                     else:
-                         assert thresholds, \
-                             f"""`Y` must be non-negative integers for
-                             `multiclass` (classification) `task`, but it has {pd.unique(df['Y'])}.
-                             You must set `thresholds` to discretize continuous labels."""
-
-             if 'U' in df.columns:
-                 units = df['U']
-             else:
-                 units = None
-                 log.warning("Units ('U') not in the data table. "
-                             "Assuming all labels to be discrete or in p-scale (-log10[M]).")
-
-             # Transform labels
-             df['Y'] = label_transform(labels=df['Y'], units=units, thresholds=thresholds,
-                                       discard_intermediate=discard_intermediate)
-
-             # Filter out rows with a NaN in Y (missing values)
-             df.dropna(subset=['Y'], inplace=True)
-
-             log.info(f"Performing post-transformation target validation.")
-             match task:
-                 case 'regression':
-                     df['Y'] = df['Y'].astype('float32')
-                     assert all(df['Y'].apply(lambda x: isinstance(x, Number))), \
-                         f"""`Y` must be numeric for `regression` task,
-                         but after transformation it still has {set(df['Y'].apply(type))}.
-                         Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns."""
-
-                 case 'binary':
-                     df['Y'] = df['Y'].astype('int')
-                     assert all(df['Y'].isin([0, 1])), \
-                         f"""`Y` must be 0 or 1 for `binary` (classification) `task`,
-                         but after transformation it still has {pd.unique(df['Y'])}.
-                         Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns."""
-
-                 case 'multiclass':
-                     df['Y'] = df['Y'].astype('int')
-                     assert all(df['Y'].apply(lambda x: x.is_integer() and x >= 0)), \
-                         f"""Y must be non-negative integers for task `multiclass` (classification)
-                         but after transformation it still has {pd.unique(df['Y'])}.
-                         Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns."""
-
-                     target_n_unique = df['Y'].nunique()
-                     assert target_n_unique == n_class, \
-                         f"""You have set `n_class` for `multiclass` (classification) `task` to {n_class},
-                         but after transformation Y still has {target_n_unique} unique labels.
-                         Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns."""
-
-         # Indexed protein/FASTA for retrieval metrics
-         df['IDX'] = LabelEncoder().fit_transform(df['X2'])
-
-         self.df = df
-         self.drug_featurizer = drug_featurizer if drug_featurizer is not None else (lambda x: x)
-         self.protein_featurizer = protein_featurizer if protein_featurizer is not None else (lambda x: x)
-
-     def __len__(self):
-         return len(self.df.index)
-
-     def __getitem__(self, i):
-         sample = self.df.loc[i]
-         return {
-             'N': i,
-             'X1': self.drug_featurizer(sample['X1']),
-             'ID1': sample.get('ID1', sample['X1']),
-             'X2': self.protein_featurizer(sample['X2']),
-             'ID2': sample.get('ID2', sample['X2']),
-             'Y': sample.get('Y'),
-             'IDX': sample['IDX'],
-         }
-
-
- class DTIDataModule(LightningDataModule):
-     """
-     DTI DataModule
-
-     A DataModule implements 5 key methods:
-
-         def prepare_data(self):
-             # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
-             # download data, pre-process, split, save to disk, etc.
-         def setup(self, stage):
-             # things to do on every process in DDP
-             # load data, set variables, etc.
-         def train_dataloader(self):
-             # return train dataloader
-         def val_dataloader(self):
-             # return validation dataloader
-         def test_dataloader(self):
-             # return test dataloader
-         def teardown(self):
-             # called on every process in DDP
-             # clean up after fit or test
-
-     This allows you to share a full dataset without explaining how to download,
-     split, transform and process the data.
-
-     Read the docs:
-     https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
-     """
-
-     def __init__(
-             self,
-             task: Literal['regression', 'binary', 'multiclass'],
-             n_class: Optional[int],
-             batch_size: int,
-             # train: bool,
-             drug_featurizer: callable,
-             protein_featurizer: callable,
-             collator: callable = collate_fn,
-             data_dir: str = "data/",
-             data_file: Optional[str] = None,
-             train_val_test_split: Optional[Sequence[Number | str]] = None,
-             split: Optional[callable] = None,
-             thresholds: Optional[Union[Number, Sequence[Number]]] = None,
-             discard_intermediate: Optional[bool] = False,
-             num_workers: int = 0,
-             pin_memory: bool = False,
-     ):
-         super().__init__()
-
-         self.train_data: Optional[Dataset] = None
-         self.val_data: Optional[Dataset] = None
-         self.test_data: Optional[Dataset] = None
-         self.predict_data: Optional[Dataset] = None
-         self.split = split
-         self.collator = collator
-         self.dataset = partial(
-             DTIDataset,
-             task=task,
-             n_class=n_class,
-             drug_featurizer=drug_featurizer,
-             protein_featurizer=protein_featurizer,
-             thresholds=thresholds,
-             discard_intermediate=discard_intermediate
-         )
-
-         if train_val_test_split:
-             # TODO test behavior for trainer.test and predict when this is passed
-             if len(train_val_test_split) not in [2, 3]:
-                 raise ValueError('Length of `train_val_test_split` must be 2 (for training without testing) or 3.')
-             if all([data_file, split]):
-                 if all(isinstance(split, Number) for split in train_val_test_split):
-                     pass
-                 else:
-                     raise ValueError('`train_val_test_split` must be a sequence of numbers '
-                                      '(float for percentages and int for sample numbers) '
-                                      'if both `data_file` and `split` have been specified.')
-             elif all(isinstance(split, str) for split in train_val_test_split) and not any([data_file, split]):
-                 split_paths = []
-                 for split in train_val_test_split:
-                     split = Path(split)
-                     if not split.is_absolute():
-                         split = Path(data_dir, split)
-                     split_paths.append(split)
-
-                 self.train_data = self.dataset(data_path=split_paths[0])
-                 self.val_data = self.dataset(data_path=split_paths[1])
-                 if len(train_val_test_split) == 3:
-                     self.test_data = self.dataset(data_path=split_paths[2])
-             else:
-                 raise ValueError('For training, you must specify either `data_file`, `split`, '
-                                  'and `train_val_test_split` as a sequence of numbers or '
-                                  'solely `train_val_test_split` as a sequence of data file paths.')
-
-         elif data_file and not any([split, train_val_test_split]):
-             data_file = Path(data_file)
-             if not data_file.is_absolute():
-                 data_file = Path(data_dir, data_file)
-             self.test_data = self.predict_data = self.dataset(data_path=data_file)
-         else:
-             raise ValueError("For training, you must specify `train_val_test_split`. "
-                              "For testing/predicting, you must specify only `data_file` without "
-                              "`train_val_test_split` or `split`.")
-
-         # this line allows access to init params with the 'self.hparams' attribute
-         # also ensures init params will be stored in ckpt
-         self.save_hyperparameters(logger=False)  # ignore=['split']
-
-     def prepare_data(self):
-         """
-         Download data if needed.
-         Do not use it to assign state (e.g., self.x = x).
-         """
-
-     def setup(self, stage: Optional[str] = None, encoding: Optional[str] = None):
-         """
-         Load data. Set variables: `self.train_data`, `self.val_data`, `self.test_data`.
-         This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be
-         careful not to execute data splitting twice.
-         """
-         # TODO test SafeBatchSampler (which skips samples with any None without introducing variable batch size)
-         # load and split datasets only if not loaded in initialization
-         if not any([self.train_data, self.test_data, self.val_data, self.predict_data]):
-             self.train_data, self.val_data, self.test_data = self.split(
-                 dataset=self.dataset(data_path=Path(self.hparams.data_dir, self.hparams.data_file)),
-                 lengths=self.hparams.train_val_test_split
-             )
-
-     def train_dataloader(self):
-         return DataLoader(
-             dataset=self.train_data,
-             batch_sampler=SafeBatchSampler(
-                 data_source=self.train_data,
-                 batch_size=self.hparams.batch_size,
-                 # Dropping the last batch prevents problems caused by variable batch sizes in training, e.g.,
-                 # batch_size=1 in BatchNorm, and shuffling ensures the model is trained on all samples over epochs.
-                 drop_last=True,
-                 shuffle=True,
-             ),
-             # batch_size=self.hparams.batch_size,
-             # shuffle=True,
-             num_workers=self.hparams.num_workers,
-             pin_memory=self.hparams.pin_memory,
-             collate_fn=self.collator,
-             persistent_workers=True if self.hparams.num_workers > 0 else False
-         )
-
-     def val_dataloader(self):
-         return DataLoader(
-             dataset=self.val_data,
-             batch_sampler=SafeBatchSampler(
-                 data_source=self.val_data,
-                 batch_size=self.hparams.batch_size,
-                 drop_last=False,
-                 shuffle=False
-             ),
-             # batch_size=self.hparams.batch_size,
-             # shuffle=False,
-             num_workers=self.hparams.num_workers,
-             pin_memory=self.hparams.pin_memory,
-             collate_fn=self.collator,
-             persistent_workers=True if self.hparams.num_workers > 0 else False
-         )
-
-     def test_dataloader(self):
-         return DataLoader(
-             dataset=self.test_data,
-             batch_sampler=SafeBatchSampler(
-                 data_source=self.test_data,
-                 batch_size=self.hparams.batch_size,
-                 drop_last=False,
-                 shuffle=False
-             ),
-             # batch_size=self.hparams.batch_size,
-             # shuffle=False,
-             num_workers=self.hparams.num_workers,
-             pin_memory=self.hparams.pin_memory,
-             collate_fn=self.collator,
-             persistent_workers=True if self.hparams.num_workers > 0 else False
-         )
-
-     def predict_dataloader(self):
-         return DataLoader(
-             dataset=self.predict_data,
-             batch_sampler=SafeBatchSampler(
-                 data_source=self.predict_data,
-                 batch_size=self.hparams.batch_size,
-                 drop_last=False,
-                 shuffle=False
-             ),
-             # batch_size=self.hparams.batch_size,
-             # shuffle=False,
-             num_workers=self.hparams.num_workers,
-             pin_memory=self.hparams.pin_memory,
-             collate_fn=self.collator,
-             persistent_workers=True if self.hparams.num_workers > 0 else False
-         )
-
-     def teardown(self, stage: Optional[str] = None):
-         """Clean up after fit or test."""
-         pass
-
-     def state_dict(self):
-         """Extra things to save to checkpoint."""
-         return {}
-
-     def load_state_dict(self, state_dict: Dict[str, Any]):
-         """Things to do when loading checkpoint."""
-         pass
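
For reference, below is a minimal usage sketch (not part of this commit) of the three construction modes the deleted DTIDataModule supported, assuming a copy of the module is still importable as deepscreen.data.dti. The CSV file names and the `random_split`-style callable are hypothetical placeholders, and passing None featurizers relies on DTIDataset's identity fallback.

from deepscreen.data.dti import DTIDataModule  # hypothetical: assumes the deleted module has been restored

# 1) Training from pre-split files: pass only `train_val_test_split` as a sequence of paths.
datamodule = DTIDataModule(
    task='binary',
    n_class=2,
    batch_size=32,
    drug_featurizer=None,      # DTIDataset falls back to an identity featurizer when None
    protein_featurizer=None,
    data_dir='data/',
    train_val_test_split=('train.csv', 'val.csv', 'test.csv'),  # hypothetical files under data/
)

# 2) Training from a single file: pass `data_file`, a `split` callable, and numeric lengths;
#    the split is then applied lazily in `setup()`.
# datamodule = DTIDataModule(..., data_file='dti.csv', split=random_split,
#                            train_val_test_split=(0.8, 0.1, 0.1))

# 3) Test/predict only: pass `data_file` alone, with no `split` or `train_val_test_split`.
# datamodule = DTIDataModule(..., data_file='screen.csv')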