libokj committed on
Commit
3cadfde
·
1 Parent(s): 6534610

Update deepscreen/data/dti.py

Browse files
Files changed (1) hide show
  1. deepscreen/data/dti.py +21 -17
deepscreen/data/dti.py CHANGED
@@ -6,7 +6,9 @@ from typing import Any, Dict, Optional, Sequence, Union, Literal
6
 
7
  from lightning import LightningDataModule
8
  import pandas as pd
9
- import swifter
 
 
10
  from sklearn.preprocessing import LabelEncoder
11
  from torch.utils.data import Dataset, DataLoader
12
 
@@ -14,6 +16,7 @@ from deepscreen.data.utils import label_transform, collate_fn, SafeBatchSampler
14
  from deepscreen.utils import get_logger
15
 
16
  log = get_logger(__name__)
 
17
 
18
  SMILES_PAT = r"[^A-Za-z0-9=#:+\-\[\]<>()/\\@%,.*]"
19
  FASTA_PAT = r"[^A-Z*\-]"
@@ -33,14 +36,12 @@ def validate_seq_str(seq, regex):
33
  # TODO: save a list of corrupted records
34
 
35
  def rdkit_canonicalize(smiles):
36
- from rdkit import Chem
37
  try:
38
  mol = Chem.MolFromSmiles(smiles)
39
- cano_smiles = Chem.MolToSmiles(mol)
40
- return cano_smiles
41
  except Exception as e:
42
  log.warning(f'Failed to canonicalize SMILES using RDKIT due to {str(e)}. Returning original SMILES: {smiles}')
43
- return smiles
44
 
45
 
46
  class DTIDataset(Dataset):
@@ -85,6 +86,12 @@ class DTIDataset(Dataset):
85
  # Forward-fill all non-label columns
86
  df.loc[:, df.columns != 'Y'] = df.loc[:, df.columns != 'Y'].ffill(axis=0)
87
 
 
 
 
 
 
 
88
  # TODO potentially allow running through the whole data validation process
89
  # error = False
90
 
@@ -93,9 +100,9 @@ class DTIDataset(Dataset):
93
  # TODO: check sklearn.utils.multiclass.check_classification_targets
94
  match task:
95
  case 'regression':
96
- assert all(df['Y'].swifter.apply(lambda x: isinstance(x, Number))), \
97
  f"""`Y` must be numeric for `regression` task,
98
- but it has {set(df['Y'].swifter.apply(type))}."""
99
 
100
  case 'binary':
101
  if all(df['Y'].isin([0, 1])):
@@ -112,7 +119,7 @@ class DTIDataset(Dataset):
112
  case 'multiclass':
113
  assert num_classes >= 3, f'`num_classes` for `task=multiclass` must be at least 3.'
114
 
115
- if all(df['Y'].swifter.apply(lambda x: x.is_integer() and x >= 0)):
116
  assert not thresholds, \
117
  f"""`Y` is already non-negative integers for
118
  `multiclass` (classification) `task`, but still got `thresholds` ({thresholds}).
@@ -140,9 +147,9 @@ class DTIDataset(Dataset):
140
  match task:
141
  case 'regression':
142
  df['Y'] = df['Y'].astype('float32')
143
- assert all(df['Y'].swifter.apply(lambda x: isinstance(x, Number))), \
144
  f"""`Y` must be numeric for `regression` task,
145
- but after transformation it still has {set(df['Y'].swifter.apply(type))}.
146
  Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns."""
147
  # TODO print err idx instead
148
  case 'binary':
@@ -154,7 +161,7 @@ class DTIDataset(Dataset):
154
  # TODO print err idx instead
155
  case 'multiclass':
156
  df['Y'] = df['Y'].astype('int')
157
- assert all(df['Y'].swifter.apply(lambda x: x.is_integer() and x >= 0)), \
158
  f"""Y must be non-negative integers for `task=multiclass`
159
  but after transformation it still has {pd.unique(df['Y'])}.
160
  Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns."""
@@ -166,16 +173,14 @@ class DTIDataset(Dataset):
166
  Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns."""
167
 
168
  log.info("Validating SMILES (`X1`)...")
169
- df['X1_ERR'] = df['X1'].swifter.progress_bar(
170
- desc="Validating SMILES...").apply(validate_seq_str, regex=SMILES_PAT)
171
  if not df['X1_ERR'].isna().all():
172
  raise Exception(f"Encountered invalid SMILES:\n{df[~df['X1_ERR'].isna()][['X1', 'X1_ERR']]}")
173
- df['X1^'] = df['X1'].swifter.apply(rdkit_canonicalize) # swifter
174
 
175
  log.info("Validating FASTA (`X2`)...")
176
  df['X2'] = df['X2'].str.upper()
177
- df['X2_ERR'] = df['X2'].swifter.progress_bar(
178
- desc="Validating FASTA...").apply(validate_seq_str, regex=FASTA_PAT)
179
  if not df['X2_ERR'].isna().all():
180
  raise Exception(f"Encountered invalid FASTA:\n{df[~df['X2_ERR'].isna()][['X2', 'X2_ERR']]}")
181
 
@@ -425,4 +430,3 @@ class DTIDataModule(LightningDataModule):
425
  def load_state_dict(self, state_dict: Dict[str, Any]):
426
  """Things to do when loading checkpoint."""
427
  pass
428
-
 
6
 
7
  from lightning import LightningDataModule
8
  import pandas as pd
9
+ from pandarallel import pandarallel
10
+ from rdkit import Chem
11
+ #import swifter
12
  from sklearn.preprocessing import LabelEncoder
13
  from torch.utils.data import Dataset, DataLoader
14
 
 
16
  from deepscreen.utils import get_logger
17
 
18
  log = get_logger(__name__)
19
+ pandarallel.initialize(progress_bar=True)
20
 
21
  SMILES_PAT = r"[^A-Za-z0-9=#:+\-\[\]<>()/\\@%,.*]"
22
  FASTA_PAT = r"[^A-Z*\-]"
 
36
  # TODO: save a list of corrupted records
37
 
38
  def rdkit_canonicalize(smiles):
 
39
  try:
40
  mol = Chem.MolFromSmiles(smiles)
41
+ smiles = Chem.MolToSmiles(mol)
 
42
  except Exception as e:
43
  log.warning(f'Failed to canonicalize SMILES using RDKIT due to {str(e)}. Returning original SMILES: {smiles}')
44
+ return smiles
45
 
46
 
47
  class DTIDataset(Dataset):
 
86
  # Forward-fill all non-label columns
87
  df.loc[:, df.columns != 'Y'] = df.loc[:, df.columns != 'Y'].ffill(axis=0)
88
 
89
+ # Fill NAs in string cols with an empty string to prevent wrong type inference by pytorch collator
90
+ for col in df.columns:
91
+ if df[col].dtype == 'object':
92
+ df[col] = df[col].fillna('')
93
+
94
+
95
  # TODO potentially allow running through the whole data validation process
96
  # error = False
97
 
 
100
  # TODO: check sklearn.utils.multiclass.check_classification_targets
101
  match task:
102
  case 'regression':
103
+ assert all(df['Y'].parallel_apply(lambda x: isinstance(x, Number))), \
104
  f"""`Y` must be numeric for `regression` task,
105
+ but it has {set(df['Y'].parallel_apply(type))}."""
106
 
107
  case 'binary':
108
  if all(df['Y'].isin([0, 1])):
 
119
  case 'multiclass':
120
  assert num_classes >= 3, f'`num_classes` for `task=multiclass` must be at least 3.'
121
 
122
+ if all(df['Y'].parallel_apply(lambda x: x.is_integer() and x >= 0)):
123
  assert not thresholds, \
124
  f"""`Y` is already non-negative integers for
125
  `multiclass` (classification) `task`, but still got `thresholds` ({thresholds}).
 
147
  match task:
148
  case 'regression':
149
  df['Y'] = df['Y'].astype('float32')
150
+ assert all(df['Y'].parallel_apply(lambda x: isinstance(x, Number))), \
151
  f"""`Y` must be numeric for `regression` task,
152
+ but after transformation it still has {set(df['Y'].parallel_apply(type))}.
153
  Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns."""
154
  # TODO print err idx instead
155
  case 'binary':
 
161
  # TODO print err idx instead
162
  case 'multiclass':
163
  df['Y'] = df['Y'].astype('int')
164
+ assert all(df['Y'].parallel_apply(lambda x: x.is_integer() and x >= 0)), \
165
  f"""Y must be non-negative integers for `task=multiclass`
166
  but after transformation it still has {pd.unique(df['Y'])}.
167
  Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns."""
 
173
  Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns."""
174
 
175
  log.info("Validating SMILES (`X1`)...")
176
+ df['X1_ERR'] = df['X1'].parallel_apply(validate_seq_str, regex=SMILES_PAT)
 
177
  if not df['X1_ERR'].isna().all():
178
  raise Exception(f"Encountered invalid SMILES:\n{df[~df['X1_ERR'].isna()][['X1', 'X1_ERR']]}")
179
+ df['X1^'] = df['X1'].parallel_apply(rdkit_canonicalize)
180
 
181
  log.info("Validating FASTA (`X2`)...")
182
  df['X2'] = df['X2'].str.upper()
183
+ df['X2_ERR'] = df['X2'].parallel_apply(validate_seq_str, regex=FASTA_PAT)
 
184
  if not df['X2_ERR'].isna().all():
185
  raise Exception(f"Encountered invalid FASTA:\n{df[~df['X2_ERR'].isna()][['X2', 'X2_ERR']]}")
186
 
 
430
  def load_state_dict(self, state_dict: Dict[str, Any]):
431
  """Things to do when loading checkpoint."""
432
  pass