Delete deepscreen
This view is limited to 50 files because it contains too many changes.
- deepscreen/__init__.py +0 -101
- deepscreen/__pycache__/__init__.cpython-311.pyc +0 -0
- deepscreen/__pycache__/__init__.cpython-39.pyc +0 -0
- deepscreen/__pycache__/predict.cpython-311.pyc +0 -0
- deepscreen/__pycache__/test.cpython-311.pyc +0 -0
- deepscreen/__pycache__/train.cpython-311.pyc +0 -0
- deepscreen/__pycache__/train.cpython-39.pyc +0 -0
- deepscreen/data/__init__.py +0 -0
- deepscreen/data/__pycache__/__init__.cpython-311.pyc +0 -0
- deepscreen/data/__pycache__/__init__.cpython-39.pyc +0 -0
- deepscreen/data/__pycache__/dti.cpython-311.pyc +0 -0
- deepscreen/data/__pycache__/dti_datamodule.cpython-311.pyc +0 -0
- deepscreen/data/dti.py +0 -422
- deepscreen/data/dti.py.bak +0 -369
- deepscreen/data/dti_datamodule.py +0 -314
- deepscreen/data/entity_datamodule.py +0 -167
- deepscreen/data/featurizers/__init__.py +0 -0
- deepscreen/data/featurizers/__pycache__/__init__.cpython-311.pyc +0 -0
- deepscreen/data/featurizers/__pycache__/categorical.cpython-311.pyc +0 -0
- deepscreen/data/featurizers/__pycache__/fcs.cpython-311.pyc +0 -0
- deepscreen/data/featurizers/__pycache__/graph.cpython-311.pyc +0 -0
- deepscreen/data/featurizers/__pycache__/token.cpython-311.pyc +0 -0
- deepscreen/data/featurizers/categorical.py +0 -86
- deepscreen/data/featurizers/chem.py +0 -48
- deepscreen/data/featurizers/fcs.py +0 -67
- deepscreen/data/featurizers/fingerprint/__init__.py +0 -45
- deepscreen/data/featurizers/fingerprint/__pycache__/__init__.cpython-311.pyc +0 -0
- deepscreen/data/featurizers/fingerprint/__pycache__/atompairs.cpython-311.pyc +0 -0
- deepscreen/data/featurizers/fingerprint/__pycache__/avalonfp.cpython-311.pyc +0 -0
- deepscreen/data/featurizers/fingerprint/__pycache__/estatefp.cpython-311.pyc +0 -0
- deepscreen/data/featurizers/fingerprint/__pycache__/maccskeys.cpython-311.pyc +0 -0
- deepscreen/data/featurizers/fingerprint/__pycache__/map4.cpython-311.pyc +0 -0
- deepscreen/data/featurizers/fingerprint/__pycache__/mhfp6.cpython-311.pyc +0 -0
- deepscreen/data/featurizers/fingerprint/__pycache__/morganfp.cpython-311.pyc +0 -0
- deepscreen/data/featurizers/fingerprint/__pycache__/pharmErGfp.cpython-311.pyc +0 -0
- deepscreen/data/featurizers/fingerprint/__pycache__/pharmPointfp.cpython-311.pyc +0 -0
- deepscreen/data/featurizers/fingerprint/__pycache__/pubchemfp.cpython-311.pyc +0 -0
- deepscreen/data/featurizers/fingerprint/__pycache__/rdkitfp.cpython-311.pyc +0 -0
- deepscreen/data/featurizers/fingerprint/__pycache__/torsions.cpython-311.pyc +0 -0
- deepscreen/data/featurizers/fingerprint/atompairs.py +0 -18
- deepscreen/data/featurizers/fingerprint/avalonfp.py +0 -16
- deepscreen/data/featurizers/fingerprint/estatefp.py +0 -12
- deepscreen/data/featurizers/fingerprint/maccskeys.py +0 -25
- deepscreen/data/featurizers/fingerprint/maccskeys.xlsx +0 -0
- deepscreen/data/featurizers/fingerprint/map4.py +0 -130
- deepscreen/data/featurizers/fingerprint/mhfp6.py +0 -18
- deepscreen/data/featurizers/fingerprint/mnimalfatures.fdef +0 -53
- deepscreen/data/featurizers/fingerprint/morganfp.py +0 -18
- deepscreen/data/featurizers/fingerprint/pharmErGfp.py +0 -60
- deepscreen/data/featurizers/fingerprint/pharmPointfp.py +0 -59
deepscreen/__init__.py
DELETED
@@ -1,101 +0,0 @@
-"""
-DeepScreen package initialization, registering custom objects and monkey patching for some libraries.
-"""
-import sys
-from builtins import eval
-
-import lightning.fabric.strategies.launchers.subprocess_script as subprocess_script
-import torch
-from omegaconf import OmegaConf
-
-from deepscreen.utils import get_logger
-
-log = get_logger(__name__)
-
-# Allow basic Python operations in hydra interpolation; examples:
-# `in_channels: ${eval:${model.drug_encoder.out_channels}+${model.protein_encoder.out_channels}}`
-# `subdir: ${eval:${hydra.job.override_dirname}.replace('/', '.')}`
-OmegaConf.register_new_resolver("eval", eval)
-
-
-def sanitize_path(path_str: str):
-    """
-    Sanitize a string for path creation by replacing unsafe characters and cutting length to 255 (OS limitation).
-    """
-    return path_str.replace("/", ".").replace("\\", ".").replace(":", "-")[:255]
-
-
-OmegaConf.register_new_resolver("sanitize_path", sanitize_path)
-
-
-def _hydra_subprocess_cmd(local_rank: int):
-    """
-    Monkey patching for lightning.fabric.strategies.launchers.subprocess_script._hydra_subprocess_cmd
-    Temporarily fixes the problem of unnecessarily creating log folders for DDP subprocesses in Hydra multirun/sweep.
-    """
-    import __main__  # local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
-    from hydra.core.hydra_config import HydraConfig
-    from hydra.utils import get_original_cwd, to_absolute_path
-
-    # when user is using hydra find the absolute path
-    if __main__.__spec__ is None:  # pragma: no-cover
-        command = [sys.executable, to_absolute_path(sys.argv[0])]
-    else:
-        command = [sys.executable, "-m", __main__.__spec__.name]
-
-    command += sys.argv[1:]
-
-    cwd = get_original_cwd()
-    rundir = f'"{HydraConfig.get().runtime.output_dir}"'
-    # Set output_subdir null since we don't want different subprocesses trying to write to config.yaml
-    command += [f"hydra.job.name=train_ddp_process_{local_rank}",
-                "hydra.output_subdir=null,"
-                f"hydra.runtime.output_dir={rundir}"]
-    return command, cwd
-
-
-subprocess_script._hydra_subprocess_cmd = _hydra_subprocess_cmd
-
-# from torch import Tensor
-# from lightning.fabric.utilities.distributed import _distributed_available
-# from lightning.pytorch.utilities.rank_zero import WarningCache
-# from lightning.pytorch.utilities.warnings import PossibleUserWarning
-# from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection
-
-# warning_cache = WarningCache()
-#
-# @staticmethod
-# def _get_cache(result_metric, on_step: bool):
-#     cache = None
-#     if on_step and result_metric.meta.on_step:
-#         cache = result_metric._forward_cache
-#     elif not on_step and result_metric.meta.on_epoch:
-#         if result_metric._computed is None:
-#             should = result_metric.meta.sync.should
-#             if not should and _distributed_available() and result_metric.is_tensor:
-#                 warning_cache.warn(
-#                     f"It is recommended to use `self.log({result_metric.meta.name!r}, ..., sync_dist=True)`"
-#                     " when logging on epoch level in distributed setting to accumulate the metric across"
-#                     " devices.",
-#                     category=PossibleUserWarning,
-#                 )
-#             result_metric.compute()
-#             result_metric.meta.sync.should = should
-#
-#         cache = result_metric._computed
-#
-#     if cache is not None:
-#         if isinstance(cache, Tensor):
-#             if not result_metric.meta.enable_graph:
-#                 return cache.detach()
-#
-#     return cache
-#
-#
-# _ResultCollection._get_cache = _get_cache
-
-if torch.cuda.is_available():
-    if torch.cuda.get_device_capability() >= (8, 0):
-        torch.set_float32_matmul_precision("high")
-        log.info("Your GPU supports tensor cores, "
-                 "we will enable it automatically by setting `torch.set_float32_matmul_precision('high')`")
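Aside for readers of this deletion: the `eval` resolver registered in deepscreen/__init__.py above is what enabled arithmetic inside Hydra/OmegaConf interpolations. A minimal, self-contained sketch of that behavior; the `drug_out`/`protein_out` keys are invented for illustration, only `OmegaConf.register_new_resolver("eval", eval)` and the `${eval:...}` pattern come from the deleted file:

from omegaconf import OmegaConf

# Same registration the deleted __init__.py performed at import time.
OmegaConf.register_new_resolver("eval", eval)

# Hypothetical config: `in_channels` is computed from two encoder widths,
# mirroring the `${eval:...}` example in the file's own comments.
cfg = OmegaConf.create({
    "drug_out": 128,
    "protein_out": 256,
    "in_channels": "${eval:${drug_out}+${protein_out}}",
})
print(cfg.in_channels)  # -> 384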
deepscreen/__pycache__/__init__.cpython-311.pyc
DELETED
Binary file (3.28 kB)

deepscreen/__pycache__/__init__.cpython-39.pyc
DELETED
Binary file (247 Bytes)

deepscreen/__pycache__/predict.cpython-311.pyc
DELETED
Binary file (3.37 kB)

deepscreen/__pycache__/test.cpython-311.pyc
DELETED
Binary file (4.54 kB)

deepscreen/__pycache__/train.cpython-311.pyc
DELETED
Binary file (7.14 kB)

deepscreen/__pycache__/train.cpython-39.pyc
DELETED
Binary file (2.68 kB)

deepscreen/data/__init__.py
DELETED
File without changes

deepscreen/data/__pycache__/__init__.cpython-311.pyc
DELETED
Binary file (179 Bytes)

deepscreen/data/__pycache__/__init__.cpython-39.pyc
DELETED
Binary file (161 Bytes)

deepscreen/data/__pycache__/dti.cpython-311.pyc
DELETED
Binary file (23 kB)

deepscreen/data/__pycache__/dti_datamodule.cpython-311.pyc
DELETED
Binary file (13 kB)
deepscreen/data/dti.py
DELETED
@@ -1,422 +0,0 @@
-import re
-from functools import partial
-from numbers import Number
-from pathlib import Path
-from typing import Any, Dict, Optional, Sequence, Union, Literal
-
-from lightning import LightningDataModule
-import pandas as pd
-import swifter
-from sklearn.preprocessing import LabelEncoder
-from torch.utils.data import Dataset, DataLoader
-
-from deepscreen.data.utils import label_transform, collate_fn, SafeBatchSampler
-from deepscreen.utils import get_logger
-
-log = get_logger(__name__)
-
-SMILES_PAT = r"[^A-Za-z0-9=#:+\-\[\]<>()/\\@%,.*]"
-FASTA_PAT = r"[^A-Z*\-]"
-
-
-def validate_seq_str(seq, regex):
-    if seq:
-        err_charset = set(re.findall(regex, seq))
-        if not err_charset:
-            return None
-        else:
-            return ', '.join(err_charset)
-    else:
-        return 'Empty string'
-
-
-# TODO: save a list of corrupted records
-
-def rdkit_canonicalize(smiles):
-    from rdkit import Chem
-    try:
-        mol = Chem.MolFromSmiles(smiles)
-        cano_smiles = Chem.MolToSmiles(mol)
-        return cano_smiles
-    except Exception as e:
-        log.warning(f'Failed to canonicalize SMILES using RDKIT due to {str(e)}. Returning original SMILES: {smiles}')
-        return smiles
-
-
-class DTIDataset(Dataset):
-    def __init__(
-            self,
-            task: Literal['regression', 'binary', 'multiclass'],
-            num_classes: Optional[int],
-            data_path: str | Path,
-            drug_featurizer: callable,
-            protein_featurizer: callable,
-            thresholds: Optional[Union[Number, Sequence[Number]]] = None,
-            discard_intermediate: Optional[bool] = False,
-            query: Optional[str] = 'X2'
-    ):
-        df = pd.read_csv(
-            data_path,
-            engine='python',
-            header=0,
-            usecols=lambda x: x in ['X1', 'ID1', 'X2', 'ID2', 'Y', 'U'],
-            dtype={
-                'X1': 'str',
-                'ID1': 'str',
-                'X2': 'str',
-                'ID2': 'str',
-                'Y': 'float32',
-                'U': 'str',
-            },
-        )
-        # Read the whole data table
-
-        # if 'ID1' in df:
-        #     self.x1_to_id1 = dict(zip(df['X1'], df['ID1']))
-        # if 'ID2' in df:
-        #     self.x2_to_id2 = dict(zip(df['X2'], df['ID2']))
-        #     self.id2_to_indexes = dict(zip(df['ID2'], range(len(df['ID2']))))
-        #     self.x2_to_indexes = dict(zip(df['X2'], range(len(df['X2']))))
-
-        # # train and eval mode data processing (fully labelled)
-        # if 'Y' in df.columns and df['Y'].notnull().all():
-        log.info(f"Processing data file: {data_path}")
-
-        # Forward-fill all non-label columns
-        df.loc[:, df.columns != 'Y'] = df.loc[:, df.columns != 'Y'].ffill(axis=0)
-
-        # TODO potentially allow running through the whole data validation process
-        # error = False
-
-        if 'Y' in df:
-            log.info(f"Validating labels (`Y`)...")
-            # TODO: check sklearn.utils.multiclass.check_classification_targets
-            match task:
-                case 'regression':
-                    assert all(df['Y'].swifter.apply(lambda x: isinstance(x, Number))), \
-                        f"""`Y` must be numeric for `regression` task,
-                        but it has {set(df['Y'].swifter.apply(type))}."""
-
-                case 'binary':
-                    if all(df['Y'].isin([0, 1])):
-                        assert not thresholds, \
-                            f"""`Y` is already 0 or 1 for `binary` (classification) `task`,
-                            but still got `thresholds` ({thresholds}).
-                            Double check your choices of `task` and `thresholds`, and records in the `Y` column."""
-                    else:
-                        assert thresholds, \
-                            f"""`Y` must be 0 or 1 for `binary` (classification) `task`,
-                            but it has {pd.unique(df['Y'])}.
-                            You may set `thresholds` to discretize continuous labels."""  # TODO print err idx instead
-
-                case 'multiclass':
-                    assert num_classes >= 3, f'`num_classes` for `task=multiclass` must be at least 3.'
-
-                    if all(df['Y'].swifter.apply(lambda x: x.is_integer() and x >= 0)):
-                        assert not thresholds, \
-                            f"""`Y` is already non-negative integers for
-                            `multiclass` (classification) `task`, but still got `thresholds` ({thresholds}).
-                            Double check your choice of `task`, `thresholds` and records in the `Y` column."""
-                    else:
-                        assert thresholds, \
-                            f"""`Y` must be non-negative integers for
-                            `multiclass` (classification) 'task', but it has {pd.unique(df['Y'])}.
-                            You must set `thresholds` to discretize continuous labels."""  # TODO print err idx instead
-
-            if 'U' in df.columns:
-                units = df['U']
-            else:
-                units = None
-                log.warning("Units ('U') not in the data table. "
-                            "Assuming all labels to be discrete or in p-scale (-log10[M]).")
-
-            # Transform labels
-            df['Y'] = label_transform(labels=df['Y'], units=units, thresholds=thresholds,
-                                      discard_intermediate=discard_intermediate)
-
-            # Filter out rows with a NaN in Y (missing values)
-            df.dropna(subset=['Y'], inplace=True)
-
-            match task:
-                case 'regression':
-                    df['Y'] = df['Y'].astype('float32')
-                    assert all(df['Y'].swifter.apply(lambda x: isinstance(x, Number))), \
-                        f"""`Y` must be numeric for `regression` task,
-                        but after transformation it still has {set(df['Y'].swifter.apply(type))}.
-                        Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns."""
-                    # TODO print err idx instead
-                case 'binary':
-                    df['Y'] = df['Y'].astype('int')
-                    assert all(df['Y'].isin([0, 1])), \
-                        f"""`Y` must be 0 or 1 for `task=binary`,
-                        but after transformation it still has {pd.unique(df['Y'])}.
-                        Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns."""
-                    # TODO print err idx instead
-                case 'multiclass':
-                    df['Y'] = df['Y'].astype('int')
-                    assert all(df['Y'].swifter.apply(lambda x: x.is_integer() and x >= 0)), \
-                        f"""Y must be non-negative integers for `task=multiclass`
-                        but after transformation it still has {pd.unique(df['Y'])}.
-                        Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns."""
-                    # TODO print err idx instead
-                    target_n_unique = df['Y'].nunique()
-                    assert target_n_unique == num_classes, \
-                        f"""You have set `num_classes` for `task=multiclass` to {num_classes},
-                        but after transformation Y still has {target_n_unique} unique labels.
-                        Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns."""
-
-        log.info("Validating SMILES (`X1`)...")
-        df['X1_ERR'] = df['X1'].swifter.progress_bar(
-            desc="Validating SMILES...").apply(validate_seq_str, regex=SMILES_PAT)
-        if not df['X1_ERR'].isna().all():
-            raise Exception(f"Encountered invalid SMILES:\n{df[~df['X1_ERR'].isna()][['X1', 'X1_ERR']]}")
-        df['X1^'] = df['X1'].apply(rdkit_canonicalize)  # swifter
-
-        log.info("Validating FASTA (`X2`)...")
-        df['X2'] = df['X2'].str.upper()
-        df['X2_ERR'] = df['X2'].swifter.progress_bar(
-            desc="Validating FASTA...").apply(validate_seq_str, regex=FASTA_PAT)
-        if not df['X2_ERR'].isna().all():
-            raise Exception(f"Encountered invalid FASTA:\n{df[~df['X2_ERR'].isna()][['X2', 'X2_ERR']]}")
-
-        # FASTA/SMILES indices as query for retrieval metrics like enrichment factor and hit rate
-        if query:
-            df['ID^'] = LabelEncoder().fit_transform(df[query])
-
-        self.df = df
-        self.drug_featurizer = drug_featurizer if drug_featurizer is not None else (lambda x: x)
-        self.protein_featurizer = protein_featurizer if protein_featurizer is not None else (lambda x: x)
-
-    def __len__(self):
-        return len(self.df.index)
-
-    def __getitem__(self, i):
-        sample = self.df.loc[i]
-        return {
-            'N': i,
-            'X1': sample['X1'],
-            'X1^': self.drug_featurizer(sample['X1^']),
-            'ID1': sample.get('ID1'),
-            'X2': sample['X2'],
-            'X2^': self.protein_featurizer(sample['X2']),
-            'ID2': sample.get('ID2'),
-            'Y': sample.get('Y'),
-            'ID^': sample.get('ID^'),
-        }
-
-
-class DTIDataModule(LightningDataModule):
-    """
-    DTI DataModule
-
-    A DataModule implements 5 key methods:
-
-        def prepare_data(self):
-            # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
-            # download data, pre-process, split, save to disk, etc.
-        def setup(self, stage):
-            # things to do on every process in DDP
-            # load data, set variables, etc.
-        def train_dataloader(self):
-            # return train dataloader
-        def val_dataloader(self):
-            # return validation dataloader
-        def test_dataloader(self):
-            # return test dataloader
-        def teardown(self):
-            # called on every process in DDP
-            # clean up after fit or test
-
-    This allows you to share a full dataset without explaining how to download,
-    split, transform and process the data.
-
-    Read the docs:
-        https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
-    """
-
-    def __init__(
-            self,
-            task: Literal['regression', 'binary', 'multiclass'],
-            num_classes: Optional[int],
-            batch_size: int,
-            # train: bool,
-            drug_featurizer: callable,
-            protein_featurizer: callable,
-            collator: callable = collate_fn,
-            data_dir: str = "data/",
-            data_file: Optional[str] = None,
-            train_val_test_split: Optional[Union[Sequence[Number | str]]] = None,
-            split: Optional[callable] = None,
-            thresholds: Optional[Union[Number, Sequence[Number]]] = None,
-            discard_intermediate: Optional[bool] = False,
-            num_workers: int = 0,
-            pin_memory: bool = False,
-    ):
-        super().__init__()
-
-        self.train_data: Optional[Dataset] = None
-        self.val_data: Optional[Dataset] = None
-        self.test_data: Optional[Dataset] = None
-        self.predict_data: Optional[Dataset] = None
-        self.split = split
-        self.collator = collator
-        self.dataset = partial(
-            DTIDataset,
-            task=task,
-            num_classes=num_classes,
-            drug_featurizer=drug_featurizer,
-            protein_featurizer=protein_featurizer,
-            thresholds=thresholds,
-            discard_intermediate=discard_intermediate
-        )
-
-        # this line allows to access init params with 'self.hparams' ensures init params will be stored in ckpt
-        self.save_hyperparameters(logger=False)  # ignore=['split']
-
-    def prepare_data(self):
-        """
-        Download data if needed.
-        Do not use it to assign state (e.g., self.x = x).
-        """
-
-    def setup(self, stage: Optional[str] = None, encoding: str = None):
-        """
-        Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
-        This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be
-        careful not to execute data splitting twice.
-        """
-        # load and split datasets only if not loaded in initialization
-        if not any([self.train_data, self.test_data, self.val_data, self.predict_data]):
-            if self.hparams.train_val_test_split:
-                if len(self.hparams.train_val_test_split) != 3:
-                    raise ValueError('Length of `train_val_test_split` must be 3. '
-                                     'Set the second element to None for training without validation. '
-                                     'Set the third element to None for training without testing.')
-
-                self.train_data = self.hparams.train_val_test_split[0]
-                self.val_data = self.hparams.train_val_test_split[1]
-                self.test_data = self.hparams.train_val_test_split[2]
-
-                if all([self.hparams.data_file, self.split]):
-                    if all(isinstance(split, Number) or split is None
-                           for split in self.hparams.train_val_test_split):
-                        split_data = self.split(
-                            dataset=self.dataset(data_path=Path(self.hparams.data_dir, self.hparams.data_file)),
-                            lengths=[split for split in self.hparams.train_val_test_split if split is not None]
-                        )
-                        for dataset in ['train_data', 'val_data', 'test_data']:
-                            if getattr(self, dataset) is not None:
-                                setattr(self, dataset, split_data.pop(0))
-
-                    else:
-                        raise ValueError('`train_val_test_split` must be a sequence numbers or None'
-                                         '(float for percentages and int for sample numbers) '
-                                         'if both `data_file` and `split` have been specified.')
-
-                elif (all(isinstance(split, str) or split is None
-                          for split in self.hparams.train_val_test_split)
-                      and not any([self.hparams.data_file, self.split])):
-                    for dataset in ['train_data', 'val_data', 'test_data']:
-                        if getattr(self, dataset) is not None:
-                            data_path = Path(getattr(self, dataset))
-                            if not data_path.is_absolute():
-                                data_path = Path(self.hparams.data_dir, data_path)
-                            setattr(self, dataset, self.dataset(data_path=data_path))
-
-                else:
-                    raise ValueError('For training, you must specify either all of `data_file`, `split`, '
-                                     'and `train_val_test_split` as a sequence of numbers or '
-                                     'solely `train_val_test_split` as a sequence of data file paths.')
-
-            elif self.hparams.data_file and not any([self.split, self.hparams.train_val_test_split]):
-                data_path = Path(self.hparams.data_file)
-                if not data_path.is_absolute():
-                    data_path = Path(self.hparams.data_dir, data_path)
-                self.test_data = self.predict_data = self.dataset(data_path=data_path)
-
-            else:
-                raise ValueError("For training, you must specify `train_val_test_split`. "
-                                 "For testing/predicting, you must specify only `data_file` without "
-                                 "`train_val_test_split` or `split`.")
-
-    def train_dataloader(self):
-        return DataLoader(
-            dataset=self.train_data,
-            batch_sampler=SafeBatchSampler(
-                data_source=self.train_data,
-                batch_size=self.hparams.batch_size,
-                # Dropping the last batch prevents problems caused by variable batch sizes in training, e.g.,
-                # batch_size=1 in BatchNorm, and shuffling ensures the model be trained on all samples over epochs.
-                drop_last=True,
-                shuffle=True,
-            ),
-            # batch_size=self.hparams.batch_size,
-            # shuffle=True,
-            num_workers=self.hparams.num_workers,
-            pin_memory=self.hparams.pin_memory,
-            collate_fn=self.collator,
-            persistent_workers=True if self.hparams.num_workers > 0 else False
-        )
-
-    def val_dataloader(self):
-        return DataLoader(
-            dataset=self.val_data,
-            batch_sampler=SafeBatchSampler(
-                data_source=self.val_data,
-                batch_size=self.hparams.batch_size,
-                drop_last=False,
-                shuffle=False
-            ),
-            # batch_size=self.hparams.batch_size,
-            # shuffle=False,
-            num_workers=self.hparams.num_workers,
-            pin_memory=self.hparams.pin_memory,
-            collate_fn=self.collator,
-            persistent_workers=True if self.hparams.num_workers > 0 else False
-        )
-
-    def test_dataloader(self):
-        return DataLoader(
-            dataset=self.test_data,
-            batch_sampler=SafeBatchSampler(
-                data_source=self.test_data,
-                batch_size=self.hparams.batch_size,
-                drop_last=False,
-                shuffle=False
-            ),
-            # batch_size=self.hparams.batch_size,
-            # shuffle=False,
-            num_workers=self.hparams.num_workers,
-            pin_memory=self.hparams.pin_memory,
-            collate_fn=self.collator,
-            persistent_workers=True if self.hparams.num_workers > 0 else False
-        )
-
-    def predict_dataloader(self):
-        return DataLoader(
-            dataset=self.predict_data,
-            batch_sampler=SafeBatchSampler(
-                data_source=self.predict_data,
-                batch_size=self.hparams.batch_size,
-                drop_last=False,
-                shuffle=False
-            ),
-            # batch_size=self.hparams.batch_size,
-            # shuffle=False,
-            num_workers=self.hparams.num_workers,
-            pin_memory=self.hparams.pin_memory,
-            collate_fn=self.collator,
-            persistent_workers=True if self.hparams.num_workers > 0 else False
-        )
-
-    def teardown(self, stage: Optional[str] = None):
-        """Clean up after fit or test."""
-        pass
-
-    def state_dict(self):
-        """Extra things to save to checkpoint."""
-        return {}
-
-    def load_state_dict(self, state_dict: Dict[str, Any]):
-        """Things to do when loading checkpoint."""
-        pass
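For context on the loader just removed: DTIDataset expects a CSV with columns X1 (SMILES), ID1, X2 (FASTA), ID2, Y (label), and optionally U (unit); without a `U` column it logs a warning and assumes labels are discrete or on the p-scale (-log10[M]). A small sketch of such an input, with invented molecules, sequence, and values, read with the same `usecols`/`dtype` arguments as the deleted code:

import io

import pandas as pd

# Hypothetical two-row input in the schema dti.py reads; no 'U' column, so
# the 7.2/5.1 labels would be treated as p-scale values by the deleted code.
csv = io.StringIO(
    "X1,ID1,X2,ID2,Y\n"
    "CCO,ethanol,MKTAYIAKQRQISFVK,prot1,7.2\n"
    "c1ccccc1,benzene,MKTAYIAKQRQISFVK,prot1,5.1\n"
)
df = pd.read_csv(
    csv,
    header=0,
    usecols=lambda x: x in ['X1', 'ID1', 'X2', 'ID2', 'Y', 'U'],
    dtype={'X1': 'str', 'ID1': 'str', 'X2': 'str', 'ID2': 'str',
           'Y': 'float32', 'U': 'str'},
)
print(df.dtypes)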
deepscreen/data/dti.py.bak
DELETED
@@ -1,369 +0,0 @@
-from functools import partial
-from numbers import Number
-from pathlib import Path
-from typing import Any, Dict, Optional, Sequence, Union, Literal
-
-from lightning import LightningDataModule
-import pandas as pd
-from sklearn.preprocessing import LabelEncoder
-from torch.utils.data import Dataset, DataLoader
-
-from deepscreen.data.utils import label_transform, collate_fn, SafeBatchSampler
-from deepscreen.utils import get_logger
-
-log = get_logger(__name__)
-
-
-# TODO: save a list of corrupted records
-
-
-class DTIDataset(Dataset):
-    def __init__(
-            self,
-            task: Literal['regression', 'binary', 'multiclass'],
-            n_class: Optional[int],
-            data_path: str | Path,
-            drug_featurizer: callable,
-            protein_featurizer: callable,
-            thresholds: Optional[Union[Number, Sequence[Number]]] = None,
-            discard_intermediate: Optional[bool] = False,
-    ):
-        df = pd.read_csv(
-            data_path,
-            engine='python',
-            header=0,
-            usecols=lambda x: x in ['X1', 'ID1', 'X2', 'ID2', 'Y', 'U'],
-            dtype={
-                'X1': 'str',
-                'ID1': 'str',
-                'X2': 'str',
-                'ID2': 'str',
-                'Y': 'float32',
-                'U': 'str',
-            },
-        )
-        # Read the whole data table
-
-        # if 'ID1' in df:
-        #     self.x1_to_id1 = dict(zip(df['X1'], df['ID1']))
-        # if 'ID2' in df:
-        #     self.x2_to_id2 = dict(zip(df['X2'], df['ID2']))
-        #     self.id2_to_indexes = dict(zip(df['ID2'], range(len(df['ID2']))))
-        #     self.x2_to_indexes = dict(zip(df['X2'], range(len(df['X2']))))
-
-        # # train and eval mode data processing (fully labelled)
-        # if 'Y' in df.columns and df['Y'].notnull().all():
-        log.info(f"Processing data file: {data_path}")
-
-        # Forward-fill all non-label columns
-        df.loc[:, df.columns != 'Y'] = df.loc[:, df.columns != 'Y'].ffill(axis=0)
-
-        if 'Y' in df:
-            log.info(f"Performing pre-transformation target validation.")
-            # TODO: check sklearn.utils.multiclass.check_classification_targets
-            match task:
-                case 'regression':
-                    assert all(df['Y'].apply(lambda x: isinstance(x, Number))), \
-                        f"""`Y` must be numeric for `regression` task,
-                        but it has {set(df['Y'].apply(type))}."""
-
-                case 'binary':
-                    if all(df['Y'].isin([0, 1])):
-                        assert not thresholds, \
-                            f"""`Y` is already 0 or 1 for `binary` (classification) `task`,
-                            but still got `thresholds` {thresholds}.
-                            Double check your choices of `task` and `thresholds` and records in the `Y` column."""
-                    else:
-                        assert thresholds, \
-                            f"""`Y` must be 0 or 1 for `binary` (classification) `task`,
-                            but it has {pd.unique(df['Y'])}.
-                            You must set `thresholds` to discretize continuous labels."""
-
-                case 'multiclass':
-                    assert n_class >= 3, f'`n_class` for `multiclass` (classification) `task` must be at least 3.'
-
-                    if all(df['Y'].apply(lambda x: x.is_integer() and x >= 0)):
-                        assert not thresholds, \
-                            f"""`Y` is already non-negative integers for
-                            `multiclass` (classification) `task`, but still got `thresholds` {thresholds}.
-                            Double check your choice of `task`, `thresholds` and records in the `Y` column."""
-                    else:
-                        assert thresholds, \
-                            f"""`Y` must be non-negative integers for
-                            `multiclass` (classification) 'task', but it has {pd.unique(df['Y'])}.
-                            You must set `thresholds` to discretize continuous labels."""
-
-            if 'U' in df.columns:
-                units = df['U']
-            else:
-                units = None
-                log.warning("Units ('U') not in the data table. "
-                            "Assuming all labels to be discrete or in p-scale (-log10[M]).")
-
-            # Transform labels
-            df['Y'] = label_transform(labels=df['Y'], units=units, thresholds=thresholds,
-                                      discard_intermediate=discard_intermediate)
-
-            # Filter out rows with a NaN in Y (missing values)
-            df.dropna(subset=['Y'], inplace=True)
-
-            log.info(f"Performing post-transformation target validation.")
-            match task:
-                case 'regression':
-                    df['Y'] = df['Y'].astype('float32')
-                    assert all(df['Y'].apply(lambda x: isinstance(x, Number))), \
-                        f"""`Y` must be numeric for `regression` task,
-                        but after transformation it still has {set(df['Y'].apply(type))}.
-                        Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns."""
-
-                case 'binary':
-                    df['Y'] = df['Y'].astype('int')
-                    assert all(df['Y'].isin([0, 1])), \
-                        f"""`Y` must be 0 or 1 for `binary` (classification) `task`,
-                        but after transformation it still has {pd.unique(df['Y'])}.
-                        Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns."""
-
-                case 'multiclass':
-                    df['Y'] = df['Y'].astype('int')
-                    assert all(df['Y'].apply(lambda x: x.is_integer() and x >= 0)), \
-                        f"""Y must be non-negative integers for task `multiclass` (classification)
-                        but after transformation it still has {pd.unique(df['Y'])}.
-                        Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns."""
-
-                    target_n_unique = df['Y'].nunique()
-                    assert target_n_unique == n_class, \
-                        f"""You have set `n_class` for `multiclass` (classification) `task` to {n_class},
-                        but after transformation Y still has {target_n_unique} unique labels.
-                        Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns."""
-
-        # Indexed protein/FASTA for retrieval metrics
-        df['IDX'] = LabelEncoder().fit_transform(df['X2'])
-
-        self.df = df
-        self.drug_featurizer = drug_featurizer if drug_featurizer is not None else (lambda x: x)
-        self.protein_featurizer = protein_featurizer if protein_featurizer is not None else (lambda x: x)
-
-    def __len__(self):
-        return len(self.df.index)
-
-    def __getitem__(self, i):
-        sample = self.df.loc[i]
-        return {
-            'N': i,
-            'X1': self.drug_featurizer(sample['X1']),
-            'ID1': sample.get('ID1', sample['X1']),
-            'X2': self.protein_featurizer(sample['X2']),
-            'ID2': sample.get('ID2', sample['X2']),
-            'Y': sample.get('Y'),
-            'IDX': sample['IDX'],
-        }
-
-
-class DTIDataModule(LightningDataModule):
-    """
-    DTI DataModule
-
-    A DataModule implements 5 key methods:
-
-        def prepare_data(self):
-            # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
-            # download data, pre-process, split, save to disk, etc.
-        def setup(self, stage):
-            # things to do on every process in DDP
-            # load data, set variables, etc.
-        def train_dataloader(self):
-            # return train dataloader
-        def val_dataloader(self):
-            # return validation dataloader
-        def test_dataloader(self):
-            # return test dataloader
-        def teardown(self):
-            # called on every process in DDP
-            # clean up after fit or test
-
-    This allows you to share a full dataset without explaining how to download,
-    split, transform and process the data.
-
-    Read the docs:
-        https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
-    """
-
-    def __init__(
-            self,
-            task: Literal['regression', 'binary', 'multiclass'],
-            n_class: Optional[int],
-            batch_size: int,
-            # train: bool,
-            drug_featurizer: callable,
-            protein_featurizer: callable,
-            collator: callable = collate_fn,
-            data_dir: str = "data/",
-            data_file: Optional[str] = None,
-            train_val_test_split: Optional[Union[Sequence[Number | str]]] = None,
-            split: Optional[callable] = None,
-            thresholds: Optional[Union[Number, Sequence[Number]]] = None,
-            discard_intermediate: Optional[bool] = False,
-            num_workers: int = 0,
-            pin_memory: bool = False,
-    ):
-        super().__init__()
-
-        self.train_data: Optional[Dataset] = None
-        self.val_data: Optional[Dataset] = None
-        self.test_data: Optional[Dataset] = None
-        self.predict_data: Optional[Dataset] = None
-        self.split = split
-        self.collator = collator
-        self.dataset = partial(
-            DTIDataset,
-            task=task,
-            n_class=n_class,
-            drug_featurizer=drug_featurizer,
-            protein_featurizer=protein_featurizer,
-            thresholds=thresholds,
-            discard_intermediate=discard_intermediate
-        )
-
-        if train_val_test_split:
-            # TODO test behavior for trainer.test and predict when this is passed
-            if len(train_val_test_split) not in [2, 3]:
-                raise ValueError('Length of `train_val_test_split` must be 2 (for training without testing) or 3.')
-            if all([data_file, split]):
-                if all(isinstance(split, Number) for split in train_val_test_split):
-                    pass
-                else:
-                    raise ValueError('`train_val_test_split` must be a sequence numbers '
-                                     '(float for percentages and int for sample numbers) '
-                                     'if both `data_file` and `split` have been specified.')
-            elif all(isinstance(split, str) for split in train_val_test_split) and not any([data_file, split]):
-                split_paths = []
-                for split in train_val_test_split:
-                    split = Path(split)
-                    if not split.is_absolute():
-                        split = Path(data_dir, split)
-                    split_paths.append(split)
-
-                self.train_data = self.dataset(data_path=split_paths[0])
-                self.val_data = self.dataset(data_path=split_paths[1])
-                if len(train_val_test_split) == 3:
-                    self.test_data = self.dataset(data_path=split_paths[2])
-            else:
-                raise ValueError('For training, you must specify either `data_file`, `split`, '
-                                 'and `train_val_test_split` as a sequence of numbers or '
-                                 'solely `train_val_test_split` as a sequence of data file paths.')
-
-        elif data_file and not any([split, train_val_test_split]):
-            data_file = Path(data_file)
-            if not data_file.is_absolute():
-                data_file = Path(data_dir, data_file)
-            self.test_data = self.predict_data = self.dataset(data_path=data_file)
-        else:
-            raise ValueError("For training, you must specify `train_val_test_split`. "
-                             "For testing/predicting, you must specify only `data_file` without "
-                             "`train_val_test_split` or `split`.")
-
-        # this line allows to access init params with 'self.hparams' attribute
-        # also ensures init params will be stored in ckpt
-        self.save_hyperparameters(logger=False)  # ignore=['split']
-
-    def prepare_data(self):
-        """
-        Download data if needed.
-        Do not use it to assign state (e.g., self.x = x).
-        """
-
-    def setup(self, stage: Optional[str] = None, encoding: str = None):
-        """
-        Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
-        This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be
-        careful not to execute data splitting twice.
-        """
-        # TODO test SafeBatchSampler (which skips samples with any None without introducing variable batch size)
-        # load and split datasets only if not loaded in initialization
-        if not any([self.train_data, self.test_data, self.val_data, self.predict_data]):
-            self.train_data, self.val_data, self.test_data = self.split(
-                dataset=self.dataset(data_path=Path(self.hparams.data_dir, self.hparams.data_file)),
-                lengths=self.hparams.train_val_test_split
-            )
-
-    def train_dataloader(self):
-        return DataLoader(
-            dataset=self.train_data,
-            batch_sampler=SafeBatchSampler(
-                data_source=self.train_data,
-                batch_size=self.hparams.batch_size,
-                # Dropping the last batch prevents problems caused by variable batch sizes in training, e.g.,
-                # batch_size=1 in BatchNorm, and shuffling ensures the model be trained on all samples over epochs.
-                drop_last=True,
-                shuffle=True,
-            ),
-            # batch_size=self.hparams.batch_size,
-            # shuffle=True,
-            num_workers=self.hparams.num_workers,
-            pin_memory=self.hparams.pin_memory,
-            collate_fn=self.collator,
-            persistent_workers=True if self.hparams.num_workers > 0 else False
-        )
-
-    def val_dataloader(self):
-        return DataLoader(
-            dataset=self.val_data,
-            batch_sampler=SafeBatchSampler(
-                data_source=self.val_data,
-                batch_size=self.hparams.batch_size,
-                drop_last=False,
-                shuffle=False
-            ),
-            # batch_size=self.hparams.batch_size,
-            # shuffle=False,
-            num_workers=self.hparams.num_workers,
-            pin_memory=self.hparams.pin_memory,
-            collate_fn=self.collator,
-            persistent_workers=True if self.hparams.num_workers > 0 else False
-        )
-
-    def test_dataloader(self):
-        return DataLoader(
-            dataset=self.test_data,
-            batch_sampler=SafeBatchSampler(
-                data_source=self.test_data,
-                batch_size=self.hparams.batch_size,
-                drop_last=False,
-                shuffle=False
-            ),
-            # batch_size=self.hparams.batch_size,
-            # shuffle=False,
-            num_workers=self.hparams.num_workers,
-            pin_memory=self.hparams.pin_memory,
-            collate_fn=self.collator,
-            persistent_workers=True if self.hparams.num_workers > 0 else False
-        )
-
-    def predict_dataloader(self):
-        return DataLoader(
-            dataset=self.predict_data,
-            batch_sampler=SafeBatchSampler(
-                data_source=self.predict_data,
-                batch_size=self.hparams.batch_size,
-                drop_last=False,
-                shuffle=False
-            ),
-            # batch_size=self.hparams.batch_size,
-            # shuffle=False,
-            num_workers=self.hparams.num_workers,
-            pin_memory=self.hparams.pin_memory,
-            collate_fn=self.collator,
-            persistent_workers=True if self.hparams.num_workers > 0 else False
-        )
-
-    def teardown(self, stage: Optional[str] = None):
-        """Clean up after fit or test."""
-        pass
-
-    def state_dict(self):
-        """Extra things to save to checkpoint."""
-        return {}
-
-    def load_state_dict(self, state_dict: Dict[str, Any]):
-        """Things to do when loading checkpoint."""
-        pass
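The `thresholds` argument in both versions of the dataset exists to discretize continuous affinity labels so they pass the post-transformation checks. The real `deepscreen.data.utils.label_transform` is not part of this diff, so the sketch below only illustrates the idea under an assumed single-cutoff semantics, with an invented cutoff value:

import pandas as pd

# Illustrative stand-in for the `thresholds` behavior referenced above, NOT
# the actual label_transform: continuous p-scale labels are binarized at a
# cutoff so they satisfy the `binary` task's 0/1 validation.
y = pd.Series([5.0, 6.9, 7.0, 8.3], dtype='float32')
threshold = 7.0                       # hypothetical activity cutoff
y_binary = (y >= threshold).astype('int')
print(y_binary.tolist())              # [0, 0, 1, 1]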
deepscreen/data/dti_datamodule.py
DELETED
@@ -1,314 +0,0 @@
|
|
1 |
-
# from itertools import product
|
2 |
-
from collections import namedtuple
|
3 |
-
from numbers import Number
|
4 |
-
from typing import Any, Dict, Optional, Sequence, Union, Literal
|
5 |
-
|
6 |
-
# import numpy as np
|
7 |
-
import pandas as pd
|
8 |
-
from lightning import LightningDataModule
|
9 |
-
from torch.utils.data import Dataset, DataLoader, random_split
|
10 |
-
|
11 |
-
from deepscreen.data.utils.label import label_transform
|
12 |
-
from deepscreen.data.utils.collator import collate_fn
|
13 |
-
from deepscreen.data.utils.sampler import SafeBatchSampler
|
14 |
-
|
15 |
-
|
16 |
-
class DTIDataset(Dataset):
|
17 |
-
def __init__(
|
18 |
-
self,
|
19 |
-
task: Literal['regression', 'binary', 'multiclass'],
|
20 |
-
n_classes: Optional[int],
|
21 |
-
data_dir: str,
|
22 |
-
dataset_name: str,
|
23 |
-
drug_featurizer: callable,
|
24 |
-
protein_featurizer: callable,
|
25 |
-
thresholds: Optional[Union[Number, Sequence[Number]]] = None,
|
26 |
-
discard_intermediate: Optional[bool] = False,
|
27 |
-
):
|
28 |
-
df = pd.read_csv(
|
29 |
-
f'{data_dir}{dataset_name}.csv',
|
30 |
-
header=0, sep=',',
|
31 |
-
usecols=lambda x: x in ['X1', 'ID1', 'X2', 'ID2', 'Y', 'U'],
|
32 |
-
dtype={'X1': 'str', 'ID1': 'str',
|
33 |
-
'X2': 'str', 'ID2': 'str',
|
34 |
-
'Y': 'float32', 'U': 'str'}
|
35 |
-
)
|
36 |
-
# if 'ID1' in df:
|
37 |
-
# self.x1_to_id1 = dict(zip(df['X1'], df['ID1']))
|
38 |
-
# if 'ID2' in df:
|
39 |
-
# self.x2_to_id2 = dict(zip(df['X2'], df['ID2']))
|
40 |
-
# self.id2_to_indexes = dict(zip(df['ID2'], range(len(df['ID2']))))
|
41 |
-
# self.x2_to_indexes = dict(zip(df['X2'], range(len(df['X2']))))
|
42 |
-
|
43 |
-
# # train and eval mode data processing (fully labelled)
|
44 |
-
# if 'Y' in df.columns and df['Y'].notnull().all():
|
45 |
-
|
46 |
-
# Forward-fill all non-label columns
|
47 |
-
df.loc[:, df.columns != 'Y'] = df.loc[:, df.columns != 'Y'].ffill(axis=0)
|
48 |
-
|
49 |
-
if 'Y' in df:
|
50 |
-
# Transform labels
|
51 |
-
df['Y'] = df['Y'].apply(label_transform, units=df.get('U', None), thresholds=thresholds,
|
52 |
-
discard_intermediate=discard_intermediate).astype('float32')
|
53 |
-
|
54 |
-
# Filter out rows with a NaN in Y (missing values)
|
55 |
-
df.dropna(subset=['Y'], inplace=True)
|
56 |
-
|
57 |
-
# Validate target labels for training/testing
|
58 |
-
# TODO: check sklearn.utils.multiclass.check_classification_targets
|
59 |
-
match task:
|
60 |
-
case 'regression':
|
61 |
-
assert all(df['Y'].apply(lambda x: isinstance(x, Number))), \
|
62 |
-
f"Y for task `regression` must be numeric; got {set(df['Y'].apply(type))}."
|
63 |
-
case 'binary':
|
64 |
-
assert all(df['Y'].isin([0, 1])), \
|
65 |
-
f"Y for task `binary` (classification) must be 0 or 1, but Y got {pd.unique(df['Y'])}." \
|
66 |
-
"\nYou may set `thresholds` to discretize continuous labels."
|
67 |
-
case 'multiclass':
|
68 |
-
assert n_classes >= 3, f'n_classes for task `multiclass` (classification) must be at least 3.'
|
69 |
-
assert all(df['Y'].apply(lambda x: x.is_integer() and x >= 0)), \
|
70 |
-
f"Y for task `multiclass` (classification) must be non-negative integers, " \
|
71 |
-
f"but Y got {pd.unique(df['Y'])}." \
|
72 |
-
"\nYou may set `thresholds` to discretize continuous labels."
|
73 |
-
target_n_unique = df['Y'].nunique()
|
74 |
-
assert target_n_unique == n_classes, \
|
75 |
-
f"You have set n_classes for task `multiclass` (classification) task to {n_classes}, " \
|
76 |
-
f"but Y has {target_n_unique} unique labels."
|
77 |
-
|
78 |
-
# # Predict mode data processing
|
79 |
-
# else:
|
80 |
-
# df = pd.DataFrame(product(df['X1'].dropna(), df['X2'].dropna()), columns=['X1', 'X2'])
|
81 |
-
# if hasattr(self, "x1_to_id1"):
|
82 |
-
# df['ID1'] = df['X1'].map(self.x1_to_id1)
|
83 |
-
# if hasattr(self, "x1_to_id2"):
|
84 |
-
# df['ID2'] = df['X2'].map(self.x2_to_id2)
|
85 |
-
|
86 |
-
# self.smiles = df['X1']
|
87 |
-
# self.fasta = df['X2']
|
88 |
-
# self.smiles_ids = df.get('ID1', df['X1'])
|
89 |
-
# self.fasta_ids = df.get('ID2', df['X2'])
|
90 |
-
# self.labels = df.get('Y', None)
|
91 |
-
|
92 |
-
self.df = df
|
93 |
-
self.drug_featurizer = drug_featurizer if drug_featurizer is not None else (lambda x: x)
|
94 |
-
self.protein_featurizer = protein_featurizer if protein_featurizer is not None else (lambda x: x)
|
95 |
-
self.n_classes = df['Y'].nunique()
|
96 |
-
# self.train = train
|
97 |
-
|
98 |
-
self.Data = namedtuple('Data', ['FT1', 'ID1', 'FT2', 'ID2', 'Y'])
|
99 |
-
|
100 |
-
def __len__(self):
|
101 |
-
return len(self.df.index)
|
102 |
-
|
103 |
-
def __getitem__(self, idx):
|
104 |
-
sample = self.df.loc[idx]
|
105 |
-
return self.Data(
|
106 |
-
FT1=self.drug_featurizer(sample['X1']),
|
107 |
-
ID1=sample.get('ID1', sample['X1']),
|
108 |
-
FT2=self.protein_featurizer(sample['X2']),
|
109 |
-
ID2=sample.get('ID2', sample['X2']),
|
110 |
-
Y=sample.get('Y')
|
111 |
-
)
|
112 |
-
# {
|
113 |
-
# 'FT1': self.drug_featurizer(sample['X1']),
|
114 |
-
# 'ID1': sample.get('ID1', sample['X1']),
|
115 |
-
# 'FT2': self.protein_featurizer(sample['X2']),
|
116 |
-
# 'ID2': sample.get('ID2', sample['X2']),
|
117 |
-
# 'Y': sample.get('Y')
|
118 |
-
# }
|
119 |
-
# if self.train:
|
120 |
-
# sample = self.drug_featurizer(self.smiles[idx]), self.protein_featurizer(self.fasta[idx]), self.labels[idx]
|
121 |
-
# sample = {
|
122 |
-
# 'FT1': self.drug_featurizer(self.smiles[idx]),
|
123 |
-
# 'FT2': self.protein_featurizer(self.fasta[idx]),
|
124 |
-
# 'ID2': self.smiles_ids[idx],
|
125 |
-
# }
|
126 |
-
# else:
|
127 |
-
# # sample = self.drug_featurizer(self.smiles[idx]), self.protein_featurizer(self.fasta[idx])
|
128 |
-
# sample = {
|
129 |
-
# 'FT1': self.drug_featurizer(self.smiles[idx]),
|
130 |
-
# 'FT2': self.protein_featurizer(self.fasta[idx]),
|
131 |
-
# }
|
132 |
-
#
|
133 |
-
# if all([True if n is not None else False for n in sample.values()]):
|
134 |
-
# return sample # | {
|
135 |
-
# # 'ID1': self.smiles_ids[idx],
|
136 |
-
# # 'X1': self.drug_featurizer(self.smiles[idx]),
|
137 |
-
# # 'ID2': self.fasta_ids[idx],
|
138 |
-
# # 'X2': self.protein_featurizer(self.fasta[idx]),
|
139 |
-
# # }
|
140 |
-
# else:
|
141 |
-
# return self.__getitem__(np.random.randint(0, self.size))
|
142 |
-
|
143 |
-
|
144 |
-
class DTIdatamodule(LightningDataModule):
|
145 |
-
"""
|
146 |
-
DTI DataModule
|
147 |
-
|
148 |
-
A DataModule implements 5 key methods:
|
149 |
-
|
150 |
-
def prepare_data(self):
|
151 |
-
# things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
|
152 |
-
# download data, pre-process, split, save to disk, etc.
|
153 |
-
def setup(self, stage):
|
154 |
-
# things to do on every process in DDP
|
155 |
-
# load data, set variables, etc.
|
156 |
-
def train_dataloader(self):
|
157 |
-
# return train dataloader
|
158 |
-
def val_dataloader(self):
|
159 |
-
# return validation dataloader
|
160 |
-
def test_dataloader(self):
|
161 |
-
# return test dataloader
|
162 |
-
def teardown(self):
|
163 |
-
# called on every process in DDP
|
164 |
-
# clean up after fit or test
|
165 |
-
|
166 |
-
This allows you to share a full dataset without explaining how to download,
|
167 |
-
split, transform and process the data.
|
168 |
-
|
169 |
-
Read the docs:
|
170 |
-
https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
|
171 |
-
"""
|
172 |
-
|
173 |
-
def __init__(
|
174 |
-
self,
|
175 |
-
task: Literal['regression', 'binary', 'multiclass'],
|
176 |
-
n_classes: Optional[int],
|
177 |
-
train: bool,
|
178 |
-
drug_featurizer: callable,
|
179 |
-
-            protein_featurizer: callable,
-            batch_size: int,
-            train_val_test_split: Optional[Sequence[Number]],
-            num_workers: int = 0,
-            thresholds: Optional[Union[Number, Sequence[Number]]] = None,
-            pin_memory: bool = False,
-            data_dir: str = "data/",
-            dataset_name: Optional[str] = None,
-            split: Optional[callable] = random_split,
-    ):
-        super().__init__()
-
-        # this line allows access to init params via the `self.hparams` attribute
-        # and also ensures init params will be stored in the ckpt
-        self.save_hyperparameters(logger=False)
-
-        # data processing
-        self.data_split = split
-
-        self.data_train: Optional[Dataset] = None
-        self.data_val: Optional[Dataset] = None
-        self.data_test: Optional[Dataset] = None
-        self.data_predict: Optional[Dataset] = None
-
-    def prepare_data(self):
-        """
-        Download data if needed.
-        Do not use it to assign state (e.g., self.x = x).
-        """
-
-    def setup(self, stage: Optional[str] = None, encoding: str = None):
-        """
-        Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
-        This method is called by Lightning with both `trainer.fit()` and `trainer.test()`, so be
-        careful not to execute data splitting twice.
-        """
-        # TODO: test SafeBatchSampler (which skips samples with any None without introducing a variable batch size)
-        # load and split datasets only if they were not loaded during initialization
-        if not any([self.data_train, self.data_val, self.data_test, self.data_predict]):
-            dataset = DTIDataset(
-                task=self.hparams.task,
-                n_classes=self.hparams.n_classes,
-                data_dir=self.hparams.data_dir,
-                drug_featurizer=self.hparams.drug_featurizer,
-                protein_featurizer=self.hparams.protein_featurizer,
-                dataset_name=self.hparams.dataset_name,
-                thresholds=self.hparams.thresholds,
-            )
-
-            if self.hparams.train:
-                self.data_train, self.data_val, self.data_test = self.data_split(
-                    dataset=dataset,
-                    lengths=self.hparams.train_val_test_split
-                )
-            else:
-                self.data_test = self.data_predict = dataset
-
-    def train_dataloader(self):
-        return DataLoader(
-            dataset=self.data_train,
-            batch_sampler=SafeBatchSampler(
-                data_source=self.data_train,
-                batch_size=self.hparams.batch_size,
-                drop_last=True,
-                shuffle=True,
-            ),
-            # batch_size=self.hparams.batch_size,
-            # shuffle=True,
-            num_workers=self.hparams.num_workers,
-            pin_memory=self.hparams.pin_memory,
-            collate_fn=collate_fn,
-            persistent_workers=True if self.hparams.num_workers > 0 else False
-        )
-
-    def val_dataloader(self):
-        return DataLoader(
-            dataset=self.data_val,
-            batch_sampler=SafeBatchSampler(
-                data_source=self.data_val,
-                batch_size=self.hparams.batch_size,
-                drop_last=False,
-                shuffle=False,
-            ),
-            # batch_size=self.hparams.batch_size,
-            # shuffle=False,
-            num_workers=self.hparams.num_workers,
-            pin_memory=self.hparams.pin_memory,
-            collate_fn=collate_fn,
-            persistent_workers=True if self.hparams.num_workers > 0 else False
-        )
-
-    def test_dataloader(self):
-        return DataLoader(
-            dataset=self.data_test,
-            batch_sampler=SafeBatchSampler(
-                data_source=self.data_test,
-                batch_size=self.hparams.batch_size,
-                drop_last=False,
-                shuffle=False,
-            ),
-            # batch_size=self.hparams.batch_size,
-            # shuffle=False,
-            num_workers=self.hparams.num_workers,
-            pin_memory=self.hparams.pin_memory,
-            collate_fn=collate_fn,
-            persistent_workers=True if self.hparams.num_workers > 0 else False
-        )
-
-    def predict_dataloader(self):
-        return DataLoader(
-            dataset=self.data_predict,
-            batch_sampler=SafeBatchSampler(
-                data_source=self.data_predict,
-                batch_size=self.hparams.batch_size,
-                drop_last=False,
-                shuffle=False,
-            ),
-            # batch_size=self.hparams.batch_size,
-            # shuffle=False,
-            num_workers=self.hparams.num_workers,
-            pin_memory=self.hparams.pin_memory,
-            collate_fn=collate_fn,
-            persistent_workers=True if self.hparams.num_workers > 0 else False
-        )
-
-    def teardown(self, stage: Optional[str] = None):
-        """Clean up after fit or test."""
-        pass
-
-    def state_dict(self):
-        """Extra things to save to checkpoint."""
-        return {}
-
-    def load_state_dict(self, state_dict: Dict[str, Any]):
-        """Things to do when loading checkpoint."""
-        pass
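
Every dataloader above routes through SafeBatchSampler, which the TODO describes as skipping samples with any None field without introducing a variable batch size. Below is a minimal sketch of that idea; it is an assumption about the deleted deepscreen.data.utils implementation, not its actual code:

import random
from torch.utils.data import Sampler

class SafeBatchSamplerSketch(Sampler):
    """Yields fixed-size batches of indices, skipping samples that are None."""
    def __init__(self, data_source, batch_size, shuffle=False, drop_last=False):
        self.data_source = data_source
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

    def __iter__(self):
        indices = list(range(len(self.data_source)))
        if self.shuffle:
            random.shuffle(indices)
        batch = []
        for idx in indices:
            # Probe the sample; skip it if featurization failed upstream.
            if self.data_source[idx] is None:
                continue
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if batch and not self.drop_last:
            yield batch

    def __len__(self):
        # Upper bound; the true count depends on how many samples are valid.
        return len(self.data_source) // self.batch_size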
deepscreen/data/entity_datamodule.py
DELETED
@@ -1,167 +0,0 @@
-from numbers import Number
-from pathlib import Path
-from typing import Any, Dict, Optional, Sequence, Type, Union
-
-from lightning import LightningDataModule
-from sklearn.base import TransformerMixin
-from torch.utils.data import Dataset, DataLoader
-
-from deepscreen.data.utils import collate_fn, SafeBatchSampler
-from deepscreen.data.utils.dataset import BaseEntityDataset
-
-
-class EntityDataModule(LightningDataModule):
-    """
-    def prepare_data(self):
-        # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
-        # download data, pre-process, split, save to disk, etc.
-    def setup(self, stage):
-        # things to do on every process in DDP
-        # load data, set variables, etc.
-    def train_dataloader(self):
-        # return train dataloader
-    def val_dataloader(self):
-        # return validation dataloader
-    def test_dataloader(self):
-        # return test dataloader
-    def teardown(self):
-        # called on every process in DDP
-        # clean up after fit or test
-    """
-    def __init__(
-            self,
-            dataset: type[BaseEntityDataset],
-            transformer: type[TransformerMixin],
-            train: bool,
-            batch_size: int,
-            data_dir: str = "data/",
-            data_file: Optional[str] = None,
-            train_val_test_split: Optional[Union[Sequence[Number], Sequence[str]]] = None,
-            split: Optional[callable] = None,
-            num_workers: int = 0,
-            pin_memory: bool = False,
-    ):
-        super().__init__()
-
-        # data processing
-        self.split = split
-        self.train_data: Optional[Dataset] = None
-        self.val_data: Optional[Dataset] = None
-        self.test_data: Optional[Dataset] = None
-        self.predict_data: Optional[Dataset] = None
-
-        if train:
-            if all([data_file, split]):
-                if all(isinstance(length, Number) for length in train_val_test_split):
-                    pass
-                else:
-                    raise ValueError('`train_val_test_split` must be a sequence of 3 numbers '
-                                     '(float for percentages and int for sample numbers) if '
-                                     '`data_file` and `split` have been specified.')
-            elif all(isinstance(name, str) for name in train_val_test_split) and not any([data_file, split]):
-                self.train_data = dataset(dataset_path=str(Path(data_dir) / train_val_test_split[0]))
-                self.val_data = dataset(dataset_path=str(Path(data_dir) / train_val_test_split[1]))
-                self.test_data = dataset(dataset_path=str(Path(data_dir) / train_val_test_split[2]))
-            else:
-                raise ValueError('For training (train=True), you must specify either '
                                 '`data_file` and `split` with a `train_val_test_split` of 3 numbers, or '
-                                 'solely a `train_val_test_split` of 3 data file names.')
-        else:
-            if data_file and not any([split, train_val_test_split]):
-                self.test_data = self.predict_data = dataset(dataset_path=str(Path(data_dir) / data_file))
-            else:
-                raise ValueError("For testing/predicting (train=False), you must specify only `data_file` without "
-                                 "`train_val_test_split` or `split`")
-
-        # this line allows access to init params via the `self.hparams` attribute
-        # and also ensures init params will be stored in the ckpt
-        self.save_hyperparameters(logger=False)
-
-    def prepare_data(self):
-        """
-        Download data if needed.
-        Do not use it to assign state (e.g., self.x = x).
-        """
-
-    def setup(self, stage: Optional[str] = None, encoding: str = None):
-        """
-        Load data. Set variables: `self.train_data`, `self.val_data`, `self.test_data`.
-        This method is called by Lightning with both `trainer.fit()` and `trainer.test()`, so be
-        careful not to execute data splitting twice.
-        """
-        # TODO: test SafeBatchSampler (which skips samples with any None without introducing a variable batch size)
-        # TODO: find a way to apply transformer.fit_transform only to train and transformer.transform only to val, test
-        # load and split datasets only if they were not loaded during initialization
-        if not any([self.train_data, self.test_data, self.val_data, self.predict_data]):
-            self.train_data, self.val_data, self.test_data = self.split(
-                dataset=self.hparams.dataset(
-                    dataset_path=str(Path(self.hparams.data_dir) / self.hparams.data_file)),
-                lengths=self.hparams.train_val_test_split
-            )
-
-    def train_dataloader(self):
-        return DataLoader(
-            dataset=self.train_data,
-            batch_sampler=SafeBatchSampler(
-                data_source=self.train_data,
-                batch_size=self.hparams.batch_size,
-                shuffle=True),
-            # batch_size=self.hparams.batch_size,
-            # shuffle=True,
-            num_workers=self.hparams.num_workers,
-            pin_memory=self.hparams.pin_memory,
-            collate_fn=collate_fn,
-            persistent_workers=True if self.hparams.num_workers > 0 else False
-        )
-
-    def val_dataloader(self):
-        return DataLoader(
-            dataset=self.val_data,
-            batch_sampler=SafeBatchSampler(
-                data_source=self.val_data,
-                batch_size=self.hparams.batch_size,
-                shuffle=False),
-            # batch_size=self.hparams.batch_size,
-            # shuffle=False,
-            num_workers=self.hparams.num_workers,
-            pin_memory=self.hparams.pin_memory,
-            collate_fn=collate_fn,
-            persistent_workers=True if self.hparams.num_workers > 0 else False
-        )
-
-    def test_dataloader(self):
-        return DataLoader(
-            dataset=self.test_data,
-            batch_sampler=SafeBatchSampler(
-                data_source=self.test_data,
-                batch_size=self.hparams.batch_size,
-                shuffle=False),
-            # batch_size=self.hparams.batch_size,
-            # shuffle=False,
-            num_workers=self.hparams.num_workers,
-            pin_memory=self.hparams.pin_memory,
-            collate_fn=collate_fn,
-            persistent_workers=True if self.hparams.num_workers > 0 else False
-        )
-
-    def predict_dataloader(self):
-        return DataLoader(
-            dataset=self.predict_data,
-            batch_sampler=SafeBatchSampler(
-                data_source=self.predict_data,
-                batch_size=self.hparams.batch_size,
-                shuffle=False),
-            # batch_size=self.hparams.batch_size,
-            # shuffle=False,
-            num_workers=self.hparams.num_workers,
-            pin_memory=self.hparams.pin_memory,
-            collate_fn=collate_fn,
-            persistent_workers=True if self.hparams.num_workers > 0 else False
-        )
-
-    def teardown(self, stage: Optional[str] = None):
-        """Clean up after fit or test."""
-        pass
-
-    def state_dict(self):
-        """Extra things to save to checkpoint."""
-        return {}
-
-    def load_state_dict(self, state_dict: Dict[str, Any]):
-        """Things to do when loading checkpoint."""
-        pass
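
For reference, a sketch of the two constructor configurations the validation logic above accepts for training; the dataset class, transformer, and file names are placeholders for illustration, since the real ones lived elsewhere in this repo:

from torch.utils.data import random_split

# Mode 1: a single data file plus a split callable and three numeric lengths.
dm_random = EntityDataModule(
    dataset=MyEntityDataset, transformer=MyTransformer, train=True,
    batch_size=32, data_file="all.csv",
    train_val_test_split=(0.8, 0.1, 0.1), split=random_split,
)

# Mode 2: three pre-split data files, with no `data_file` or `split`.
dm_fixed = EntityDataModule(
    dataset=MyEntityDataset, transformer=MyTransformer, train=True,
    batch_size=32,
    train_val_test_split=("train.csv", "val.csv", "test.csv"),
)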
deepscreen/data/featurizers/__init__.py
DELETED
File without changes

deepscreen/data/featurizers/__pycache__/__init__.cpython-311.pyc
DELETED
Binary file (191 Bytes)

deepscreen/data/featurizers/__pycache__/categorical.cpython-311.pyc
DELETED
Binary file (5.6 kB)

deepscreen/data/featurizers/__pycache__/fcs.cpython-311.pyc
DELETED
Binary file (4.17 kB)

deepscreen/data/featurizers/__pycache__/graph.cpython-311.pyc
DELETED
Binary file (7.21 kB)

deepscreen/data/featurizers/__pycache__/token.cpython-311.pyc
DELETED
Binary file (14.7 kB)
deepscreen/data/featurizers/categorical.py
DELETED
@@ -1,86 +0,0 @@
-import numpy as np
-
-# Sets of KNOWN characters in SMILES and FASTA sequences
-# Use an ordered sequence instead of a set to preserve character order
-SMILES_VOCAB = ('#', '%', ')', '(', '+', '-', '.', '1', '0', '3', '2', '5', '4',
-                '7', '6', '9', '8', '=', 'A', 'C', 'B', 'E', 'D', 'G', 'F', 'I',
-                'H', 'K', 'M', 'L', 'O', 'N', 'P', 'S', 'R', 'U', 'T', 'W', 'V',
-                'Y', '[', 'Z', ']', '_', 'a', 'c', 'b', 'e', 'd', 'g', 'f', 'i',
-                'h', 'm', 'l', 'o', 'n', 's', 'r', 'u', 't', 'y')
-FASTA_VOCAB = ('A', 'C', 'B', 'E', 'D', 'G', 'F', 'I', 'H', 'K', 'M', 'L', 'O',
-               'N', 'Q', 'P', 'S', 'R', 'U', 'T', 'W', 'V', 'Y', 'X', 'Z')
-
-# Check uniqueness, create character-index dicts, and add '?' for unknown characters as index 0
-assert len(SMILES_VOCAB) == len(set(SMILES_VOCAB)), 'SMILES_VOCAB has duplicate characters.'
-SMILES_CHARSET_IDX = {character: index + 1 for index, character in enumerate(SMILES_VOCAB)} | {'?': 0}
-
-assert len(FASTA_VOCAB) == len(set(FASTA_VOCAB)), 'FASTA_VOCAB has duplicate characters.'
-FASTA_CHARSET_IDX = {character: index + 1 for index, character in enumerate(FASTA_VOCAB)} | {'?': 0}
-
-
-def sequence_to_onehot(sequence: str, charset, max_sequence_length: int):
-    assert len(charset) == len(set(charset)), '`charset` contains duplicate characters.'
-    charset_idx = {character: index + 1 for index, character in enumerate(charset)} | {'?': 0}
-
-    onehot = np.zeros((max_sequence_length, len(charset_idx)), dtype=int)
-    for index, character in enumerate(sequence[:max_sequence_length]):
-        onehot[index, charset_idx.get(character, 0)] = 1
-
-    return onehot.transpose()
-
-
-def sequence_to_label(sequence: str, charset, max_sequence_length: int):
-    assert len(charset) == len(set(charset)), '`charset` contains duplicate characters.'
-    charset_idx = {character: index + 1 for index, character in enumerate(charset)} | {'?': 0}
-
-    label = np.zeros(max_sequence_length, dtype=int)
-    for index, character in enumerate(sequence[:max_sequence_length]):
-        label[index] = charset_idx.get(character, 0)
-
-    return label
-
-
-def smiles_to_onehot(smiles: str, smiles_charset=SMILES_VOCAB, max_sequence_length: int = 100):
-    return sequence_to_onehot(smiles, smiles_charset, max_sequence_length)
-
-
-def smiles_to_label(smiles: str, smiles_charset=SMILES_VOCAB, max_sequence_length: int = 100):
-    return sequence_to_label(smiles, smiles_charset, max_sequence_length)
-
-
-def fasta_to_onehot(fasta: str, fasta_charset=FASTA_VOCAB, max_sequence_length: int = 1000):
-    return sequence_to_onehot(fasta, fasta_charset, max_sequence_length)
-
-
-def fasta_to_label(fasta: str, fasta_charset=FASTA_VOCAB, max_sequence_length: int = 1000):
-    return sequence_to_label(fasta, fasta_charset, max_sequence_length)
-
-
-def one_of_k_encoding(x, allowable_set):
-    if x not in allowable_set:
-        raise ValueError("input {0} not in allowable set {1}".format(x, allowable_set))
-    return list(map(lambda s: x == s, allowable_set))
-
-
-def one_of_k_encoding_unk(x, allowable_set):
-    """Maps inputs not in the allowable set to the last element."""
-    if x not in allowable_set:
-        x = allowable_set[-1]
-    return list(map(lambda s: x == s, allowable_set))
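
A quick illustration of the two encoders above (only numpy is required):

label = sequence_to_label("CCO", SMILES_VOCAB, max_sequence_length=5)
# -> array([20, 20, 31, 0, 0]): indices of 'C', 'C', 'O', then zero padding

onehot = sequence_to_onehot("CCO", SMILES_VOCAB, max_sequence_length=5)
# -> shape (len(SMILES_VOCAB) + 1, 5): one column per sequence position,
#    with a single 1 per occupied column (the transpose of position-major)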
deepscreen/data/featurizers/chem.py
DELETED
@@ -1,48 +0,0 @@
-"""
-Mainly adapted from MolMap:
-https://github.com/shenwanxiang/bidd-molmap/tree/master/molmap/feature/fingerprint
-"""
-import numpy as np
-from rdkit import Chem, DataStructs
-from rdkit.Chem import AllChem
-from rdkit.Chem.Fingerprints import FingerprintMols
-from rdkit.Chem.rdReducedGraphs import GetErGFingerprint
-
-from deepscreen import get_logger
-
-log = get_logger(__name__)
-
-
-def smiles_to_erg(smiles):
-    try:
-        mol = Chem.MolFromSmiles(smiles)
-        features = np.array(GetErGFingerprint(mol), dtype=bool)
-        return features
-    except Exception as e:
-        log.warning(f"Failed to convert SMILES ({smiles}) to ErGFP due to {str(e)}")
-        return None
-
-
-def smiles_to_morgan(smiles, radius=2, n_bits=1024):
-    try:
-        mol = Chem.MolFromSmiles(smiles)
-        features_vec = AllChem.GetMorganFingerprintAsBitVect(mol, radius=radius, nBits=n_bits)
-        features = np.zeros((1,))
-        DataStructs.ConvertToNumpyArray(features_vec, features)
-        return features
-    except Exception as e:
-        log.warning(f"Failed to convert SMILES ({smiles}) to MorganFP due to {str(e)}")
-        return None
-
-
-def smiles_to_daylight(smiles):
-    num_finger = 2048
-    try:
-        mol = Chem.MolFromSmiles(smiles)
-        bv = FingerprintMols.FingerprintMol(mol)
-        temp = tuple(bv.GetOnBits())
-        features = np.zeros((num_finger,))
-        features[np.array(temp)] = 1
-    except Exception as e:
-        log.warning(f"RDKit could not fingerprint this SMILES ({smiles}) due to {str(e)}; returning all-zero features")
-        features = np.zeros((num_finger,))
-    return features.astype(int)
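
Usage sketch for the featurizers above (assumes RDKit is installed):

fp = smiles_to_morgan("CCO", radius=2, n_bits=1024)
# fp is a length-1024 numpy array of 0.0/1.0 bit values, or None if the
# SMILES could not be parsed or fingerprinted.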
deepscreen/data/featurizers/fcs.py
DELETED
@@ -1,67 +0,0 @@
-from importlib import resources
-
-import numpy as np
-import pandas as pd
-from subword_nmt.apply_bpe import BPE
-import codecs
-
-vocab_path = resources.files('deepscreen').parent.joinpath('resources/vocabs/ESPF/protein_codes_uniprot.txt')
-bpe_codes_protein = codecs.open(vocab_path)
-protein_bpe = BPE(bpe_codes_protein, merges=-1, separator='')
-
-sub_csv_path = resources.files('deepscreen').parent.joinpath('resources/vocabs/ESPF/subword_units_map_uniprot.csv')
-sub_csv = pd.read_csv(sub_csv_path)
-idx2word_protein = sub_csv['index'].values
-words2idx_protein = dict(zip(idx2word_protein, range(0, len(idx2word_protein))))
-
-vocab_path = resources.files('deepscreen').parent.joinpath('resources/vocabs/ESPF/drug_codes_chembl.txt')
-bpe_codes_drug = codecs.open(vocab_path)
-drug_bpe = BPE(bpe_codes_drug, merges=-1, separator='')
-
-sub_csv_path = resources.files('deepscreen').parent.joinpath('resources/vocabs/ESPF/subword_units_map_chembl.csv')
-sub_csv = pd.read_csv(sub_csv_path)
-idx2word_drug = sub_csv['index'].values
-words2idx_drug = dict(zip(idx2word_drug, range(0, len(idx2word_drug))))
-
-
-def protein_to_embedding(x, max_sequence_length):
-    max_p = max_sequence_length
-    tokens = protein_bpe.process_line(x).split()  # split into BPE subwords
-    try:
-        indices = np.asarray([words2idx_protein[token] for token in tokens])  # map subwords to indices
-    except KeyError:
-        indices = np.array([0])
-
-    length = len(indices)
-
-    if length < max_p:
-        i = np.pad(indices, (0, max_p - length), 'constant', constant_values=0)
-        input_mask = ([1] * length) + ([0] * (max_p - length))
-    else:
-        i = indices[:max_p]
-        input_mask = [1] * max_p
-
-    return i, np.asarray(input_mask)
-
-
-def drug_to_embedding(x, max_sequence_length):
-    max_d = max_sequence_length
-    tokens = drug_bpe.process_line(x).split()  # split into BPE subwords
-    try:
-        indices = np.asarray([words2idx_drug[token] for token in tokens])  # map subwords to indices
-    except KeyError:
-        indices = np.array([0])
-
-    length = len(indices)
-
-    if length < max_d:
-        i = np.pad(indices, (0, max_d - length), 'constant', constant_values=0)
-        input_mask = ([1] * length) + ([0] * (max_d - length))
-    else:
-        i = indices[:max_d]
-        input_mask = [1] * max_d
-
-    return i, np.asarray(input_mask)
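
Usage sketch for the ESPF encoders above (the vocab files under resources/vocabs/ESPF must exist on disk for the module to import):

tokens, mask = drug_to_embedding("CC(=O)Oc1ccccc1C(=O)O", max_sequence_length=50)
# tokens: int array of subword indices, zero-padded to length 50
# mask:   1 for real subwords, 0 for padding positions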
deepscreen/data/featurizers/fingerprint/__init__.py
DELETED
@@ -1,45 +0,0 @@
-from typing import Literal
-
-from .atompairs import GetAtomPairFPs
-from .avalonfp import GetAvalonFPs
-from .rdkitfp import GetRDkitFPs
-from .morganfp import GetMorganFPs
-from .estatefp import GetEstateFPs
-from .maccskeys import GetMACCSFPs
-from .pharmErGfp import GetPharmacoErGFPs
-from .pharmPointfp import GetPharmacoPFPs
-from .pubchemfp import GetPubChemFPs
-from .torsions import GetTorsionFPs
-from .mhfp6 import GetMHFP6
-# from .map4 import GetMAP4
-from rdkit import Chem
-
-from deepscreen import get_logger
-
-log = get_logger(__name__)
-
-FP_MAP = {
-    'MorganFP': GetMorganFPs,
-    'RDkitFP': GetRDkitFPs,
-    'AtomPairFP': GetAtomPairFPs,
-    'TorsionFP': GetTorsionFPs,
-    'AvalonFP': GetAvalonFPs,
-    'EstateFP': GetEstateFPs,
-    'MACCSFP': GetMACCSFPs,
-    'PharmacoErGFP': GetPharmacoErGFPs,
-    'PharmacoPFP': GetPharmacoPFPs,
-    'PubChemFP': GetPubChemFPs,
-    'MHFP6': GetMHFP6,
-    # 'MAP4': GetMAP4,
-}
-
-
-def smiles_to_fingerprint(smiles, fingerprint: Literal[tuple(FP_MAP.keys())], **kwargs):
-    func = FP_MAP[fingerprint]
-    try:
-        mol = Chem.MolFromSmiles(smiles)
-        arr = func(mol, **kwargs)
-        return arr
-    except Exception as e:
-        log.warning(f"Failed to convert SMILES ({smiles}) to {fingerprint} due to {str(e)}")
-        return None
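
Dispatch example for smiles_to_fingerprint above; any keyword arguments are forwarded to the chosen fingerprint function:

arr = smiles_to_fingerprint("c1ccccc1O", "MorganFP", nBits=1024, radius=2)
# -> boolean numpy array of length 1024, or None if featurization failed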
deepscreen/data/featurizers/fingerprint/__pycache__/__init__.cpython-311.pyc
DELETED
Binary file (2.18 kB)

deepscreen/data/featurizers/fingerprint/__pycache__/atompairs.cpython-311.pyc
DELETED
Binary file (1.03 kB)

deepscreen/data/featurizers/fingerprint/__pycache__/avalonfp.cpython-311.pyc
DELETED
Binary file (928 Bytes)

deepscreen/data/featurizers/fingerprint/__pycache__/estatefp.cpython-311.pyc
DELETED
Binary file (685 Bytes)

deepscreen/data/featurizers/fingerprint/__pycache__/maccskeys.cpython-311.pyc
DELETED
Binary file (1.3 kB)

deepscreen/data/featurizers/fingerprint/__pycache__/map4.cpython-311.pyc
DELETED
Binary file (7.61 kB)

deepscreen/data/featurizers/fingerprint/__pycache__/mhfp6.cpython-311.pyc
DELETED
Binary file (1.07 kB)

deepscreen/data/featurizers/fingerprint/__pycache__/morganfp.cpython-311.pyc
DELETED
Binary file (962 Bytes)

deepscreen/data/featurizers/fingerprint/__pycache__/pharmErGfp.cpython-311.pyc
DELETED
Binary file (2.4 kB)

deepscreen/data/featurizers/fingerprint/__pycache__/pharmPointfp.cpython-311.pyc
DELETED
Binary file (3.23 kB)

deepscreen/data/featurizers/fingerprint/__pycache__/pubchemfp.cpython-311.pyc
DELETED
Binary file (77.7 kB)

deepscreen/data/featurizers/fingerprint/__pycache__/rdkitfp.cpython-311.pyc
DELETED
Binary file (1.65 kB)

deepscreen/data/featurizers/fingerprint/__pycache__/torsions.cpython-311.pyc
DELETED
Binary file (1.04 kB)
deepscreen/data/featurizers/fingerprint/atompairs.py
DELETED
@@ -1,18 +0,0 @@
-from rdkit.Chem.AtomPairs import Pairs
-from rdkit.Chem import DataStructs
-import numpy as np
-
-_type = 'topological-based'
-
-
-def GetAtomPairFPs(mol, nBits=2048, binary=True):
-    '''
-    atom-pair fingerprints
-    '''
-    fp = Pairs.GetHashedAtomPairFingerprint(mol, nBits=nBits)
-    if binary:
-        arr = np.zeros((0,), dtype=np.bool_)
-    else:
-        arr = np.zeros((0,), dtype=np.int8)
-    DataStructs.ConvertToNumpyArray(fp, arr)
-    return arr
deepscreen/data/featurizers/fingerprint/avalonfp.py
DELETED
@@ -1,16 +0,0 @@
-from rdkit.Chem import DataStructs
-from rdkit.Avalon.pyAvalonTools import GetAvalonFP as GAFP
-import numpy as np
-
-_type = 'topological-based'
-
-
-def GetAvalonFPs(mol, nBits=2048):
-    '''
-    Avalon fingerprints: https://pubs.acs.org/doi/pdf/10.1021/ci050413p
-    '''
-    fp = GAFP(mol, nBits=nBits)
-    arr = np.zeros((0,), dtype=np.bool_)
-    DataStructs.ConvertToNumpyArray(fp, arr)
-    return arr
deepscreen/data/featurizers/fingerprint/estatefp.py
DELETED
@@ -1,12 +0,0 @@
-from rdkit.Chem.EState import Fingerprinter
-import numpy as np
-
-_type = 'Estate-based'
-
-
-def GetEstateFPs(mol):
-    '''
-    79-bit E-state fingerprints
-    '''
-    x = Fingerprinter.FingerprintMol(mol)[0]
-    return x.astype(np.bool_)
deepscreen/data/featurizers/fingerprint/maccskeys.py
DELETED
@@ -1,25 +0,0 @@
-from rdkit.Chem import AllChem
-from rdkit.Chem import DataStructs
-import numpy as np
-import pandas as pd
-import os
-
-_type = 'SMARTS-based'
-
-file_path = os.path.dirname(__file__)
-
-
-def GetMACCSFPs(mol):
-    '''
-    166 bits
-    '''
-    fp = AllChem.GetMACCSKeysFingerprint(mol)
-
-    arr = np.zeros((0,), dtype=np.bool_)
-    DataStructs.ConvertToNumpyArray(fp, arr)
-    return arr
-
-
-def GetMACCSFPInfos():
-    return pd.read_excel(os.path.join(file_path, 'maccskeys.xlsx'))
deepscreen/data/featurizers/fingerprint/maccskeys.xlsx
DELETED
Binary file (14 kB)
deepscreen/data/featurizers/fingerprint/map4.py
DELETED
@@ -1,130 +0,0 @@
-"""
-MinHashed Atom-pair Fingerprint, MAP
-original paper: Capecchi, Alice, Daniel Probst, and Jean-Louis Reymond. "One molecular fingerprint to rule them all: drugs, biomolecules, and the metabolome." Journal of Cheminformatics 12.1 (2020): 1-15.
-original code: https://github.com/reymond-group/map4, with thanks for the original work.
-
-A small bug is fixed: https://github.com/reymond-group/map4/issues/6
-"""
-
-_type = 'topological-based'
-
-import itertools
-from collections import defaultdict
-
-import tmap as tm
-from mhfp.encoder import MHFPEncoder
-from rdkit import Chem
-from rdkit.Chem import rdmolops
-from rdkit.Chem.rdmolops import GetDistanceMatrix
-
-
-def to_smiles(mol):
-    return Chem.MolToSmiles(mol, canonical=True, isomericSmiles=False)
-
-
-class MAP4Calculator:
-    def __init__(self, dimensions=2048, radius=2, is_counted=False, is_folded=False, fold_dimensions=2048):
-        """
-        MAP4 calculator class
-        """
-        self.dimensions = dimensions
-        self.radius = radius
-        self.is_counted = is_counted
-        self.is_folded = is_folded
-        self.fold_dimensions = fold_dimensions
-
-        if self.is_folded:
-            self.encoder = MHFPEncoder(dimensions)
-        else:
-            self.encoder = tm.Minhash(dimensions)
-
-    def calculate(self, mol):
-        """Calculates the atom pair minhashed fingerprint
-        Arguments:
-            mol -- rdkit mol object
-        Returns:
-            tmap VectorUint -- minhashed fingerprint
-        """
-        atom_env_pairs = self._calculate(mol)
-        if self.is_folded:
-            return self._fold(atom_env_pairs)
-        return self.encoder.from_string_array(atom_env_pairs)
-
-    def calculate_many(self, mols):
-        """Calculates the atom pair minhashed fingerprints of a list of molecules
-        Arguments:
-            mols -- list of rdkit mol objects
-        Returns:
-            list of tmap VectorUint -- minhashed fingerprints
-        """
-        atom_env_pairs_list = [self._calculate(mol) for mol in mols]
-        if self.is_folded:
-            return [self._fold(pairs) for pairs in atom_env_pairs_list]
-        return self.encoder.batch_from_string_array(atom_env_pairs_list)
-
-    def _calculate(self, mol):
-        return self._all_pairs(mol, self._get_atom_envs(mol))
-
-    def _fold(self, pairs):
-        fp_hash = self.encoder.hash(set(pairs))
-        return self.encoder.fold(fp_hash, self.fold_dimensions)
-
-    def _get_atom_envs(self, mol):
-        atoms_env = {}
-        for atom in mol.GetAtoms():
-            idx = atom.GetIdx()
-            for radius in range(1, self.radius + 1):
-                if idx not in atoms_env:
-                    atoms_env[idx] = []
-                atoms_env[idx].append(MAP4Calculator._find_env(mol, idx, radius))
-        return atoms_env
-
-    @classmethod
-    def _find_env(cls, mol, idx, radius):
-        env = rdmolops.FindAtomEnvironmentOfRadiusN(mol, radius, idx)
-        atom_map = {}
-
-        submol = Chem.PathToSubmol(mol, env, atomMap=atom_map)
-        if idx in atom_map:
-            smiles = Chem.MolToSmiles(submol, rootedAtAtom=atom_map[idx], canonical=True, isomericSmiles=False)
-            return smiles
-        return ''
-
-    def _all_pairs(self, mol, atoms_env):
-        atom_pairs = []
-        distance_matrix = GetDistanceMatrix(mol)
-        num_atoms = mol.GetNumAtoms()
-        shingle_dict = defaultdict(int)
-        for idx1, idx2 in itertools.combinations(range(num_atoms), 2):
-            dist = str(int(distance_matrix[idx1][idx2]))
-
-            for i in range(self.radius):
-                env_a = atoms_env[idx1][i]
-                env_b = atoms_env[idx2][i]
-
-                ordered = sorted([env_a, env_b])
-
-                shingle = '{}|{}|{}'.format(ordered[0], dist, ordered[1])
-
-                if self.is_counted:
-                    shingle_dict[shingle] += 1
-                    shingle += '|' + str(shingle_dict[shingle])
-
-                atom_pairs.append(shingle.encode('utf-8'))
-        return list(set(atom_pairs))
-
-
-def GetMAP4(mol, nBits=2048, radius=2, fold_dimensions=None):
-    """
-    MAP4: radius=2
-    """
-    if fold_dimensions is None:
-        fold_dimensions = nBits
-
-    calc = MAP4Calculator(dimensions=nBits, radius=radius, is_counted=False, is_folded=True,
-                          fold_dimensions=fold_dimensions)
-
-    arr = calc.calculate(mol)
-
-    return arr.astype(bool)
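
Usage sketch for GetMAP4 above (requires the tmap and mhfp packages in addition to RDKit):

from rdkit import Chem

mol = Chem.MolFromSmiles("CCO")
fp = GetMAP4(mol, nBits=1024)  # folded, binary MAP4 fingerprint of length 1024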
deepscreen/data/featurizers/fingerprint/mhfp6.py
DELETED
@@ -1,18 +0,0 @@
-"""
-Probst, Daniel, and Jean-Louis Reymond. "A probabilistic molecular fingerprint for big data settings." Journal of Cheminformatics 10.1 (2018): 66.
-
-original code: https://github.com/reymond-group/mhfp
-"""
-
-from mhfp.encoder import MHFPEncoder
-
-
-def GetMHFP6(mol, nBits=2048, radius=3):
-    """
-    MHFP6: radius=3
-    """
-    encoder = MHFPEncoder(n_permutations=nBits)
-    hash_values = encoder.encode_mol(mol, radius=radius, rings=True, kekulize=True, min_radius=1)
-    arr = encoder.fold(hash_values, nBits)
-    return arr.astype(bool)
deepscreen/data/featurizers/fingerprint/mnimalfatures.fdef
DELETED
@@ -1,53 +0,0 @@
-AtomType NDonor [N&!H0&v3,N&!H0&+1&v4,n&H1&+0]
-AtomType ChalcDonor [O,S;H1;+0]
-DefineFeature SingleAtomDonor [{NDonor},{ChalcDonor},!$([D1]-[C;D3]=[O,S,N])]
-  Family Donor
-  Weights 1
-EndFeature
-
-AtomType NAcceptor [$([N&v3;H1,H2]-[!$(*=[O,N,P,S])])]
-Atomtype NAcceptor [$([N;v3;H0])]
-AtomType NAcceptor [$([n;+0])]
-AtomType ChalcAcceptor [$([O,S;H1;v2]-[!$(*=[O,N,P,S])])]
-AtomType ChalcAcceptor [O,S;H0;v2]
-Atomtype ChalcAcceptor [O,S;-]
-Atomtype ChalcAcceptor [o,s;+0]
-AtomType HalogenAcceptor [F]
-DefineFeature SingleAtomAcceptor [{NAcceptor},{ChalcAcceptor},{HalogenAcceptor}]
-  Family Acceptor
-  Weights 1
-EndFeature
-
-# this one is delightfully easy:
-DefineFeature AcidicGroup [C,S](=[O,S,P])-[O;H1,H0&-1]
-  Family NegIonizable
-  Weights 1.0,1.0,1.0
-EndFeature
-
-AtomType CarbonOrArom_NonCarbonyl [$([C,a]);!$([C,a](=O))]
-AtomType BasicNH2 [$([N;H2&+0][{CarbonOrArom_NonCarbonyl}])]
-AtomType BasicNH1 [$([N;H1&+0]([{CarbonOrArom_NonCarbonyl}])[{CarbonOrArom_NonCarbonyl}])]
-AtomType BasicNH0 [$([N;H0&+0]([{CarbonOrArom_NonCarbonyl}])([{CarbonOrArom_NonCarbonyl}])[{CarbonOrArom_NonCarbonyl}])]
-AtomType BasicNakedN [N,n;X2;+0]
-DefineFeature BasicGroup [{BasicNH2},{BasicNH1},{BasicNH0},{BasicNakedN}]
-  Family PosIonizable
-  Weights 1.0
-EndFeature
-
-# aromatic rings of various sizes:
-DefineFeature Arom5 a1aaaa1
-  Family Aromatic
-  Weights 1.0,1.0,1.0,1.0,1.0
-EndFeature
-DefineFeature Arom6 a1aaaaa1
-  Family Aromatic
-  Weights 1.0,1.0,1.0,1.0,1.0,1.0
-EndFeature
-DefineFeature Arom7 a1aaaaaa1
-  Family Aromatic
-  Weights 1.0,1.0,1.0,1.0,1.0,1.0,1.0
-EndFeature
-DefineFeature Arom8 a1aaaaaaa1
-  Family Aromatic
-  Weights 1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
-EndFeature
deepscreen/data/featurizers/fingerprint/morganfp.py
DELETED
@@ -1,18 +0,0 @@
-from rdkit.Chem import AllChem
-from rdkit.Chem import DataStructs
-import numpy as np
-
-
-def GetMorganFPs(mol, nBits=2048, radius=2, return_bitInfo=False):
-    """
-    ECFP4: radius=2
-    """
-    bitInfo = {}
-    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=radius,
-                                               bitInfo=bitInfo, nBits=nBits)
-    arr = np.zeros((0,), dtype=np.bool_)
-    DataStructs.ConvertToNumpyArray(fp, arr)
-
-    if return_bitInfo:
-        return arr, bitInfo
-    return arr
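
With return_bitInfo=True, GetMorganFPs above also reports which atom environments set each bit, which is useful for interpreting models built on the fingerprint:

from rdkit import Chem

mol = Chem.MolFromSmiles("c1ccccc1O")
arr, bit_info = GetMorganFPs(mol, nBits=2048, radius=2, return_bitInfo=True)
# bit_info maps bit index -> tuples of (atom_idx, radius) that set that bit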
deepscreen/data/featurizers/fingerprint/pharmErGfp.py
DELETED
@@ -1,60 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-"""
-Created on Sat Aug 17 16:54:12 2019
-
-@author: [email protected]
-
-@calculate ErG fps, more info: https://pubs.acs.org/doi/full/10.1021/ci050457y#
-"""
-
-_type = 'Pharmacophore-based'
-
-import numpy as np
-from rdkit.Chem import AllChem
-
-## get info from: https://github.com/rdkit/rdkit/blob/d41752d558bf7200ab67b98cdd9e37f1bdd378de/Code/GraphMol/ReducedGraphs/ReducedGraphs.cpp
-Donor = ["[N;!H0;v3,v4&+1]", "[O,S;H1;+0]", "[n&H1&+0]"]
-
-Acceptor = ["[O,S;H1;v2;!$(*-*=[O,N,P,S])]", "[O;H0;v2]", "[O,S;v1;-]",
-            "[N;v3;!$(N-*=[O,N,P,S])]", "[n&H0&+0]", "[o;+0;!$([o]:n);!$([o]:c:n)]"]
-
-Positive = ["[#7;+]", "[N;H2&+0][$([C,a]);!$([C,a](=O))]",
-            "[N;H1&+0]([$([C,a]);!$([C,a](=O))])[$([C,a]);!$([C,a](=O))]",
-            "[N;H0&+0]([C;!$(C(=O))])([C;!$(C(=O))])[C;!$(C(=O))]"]
-
-Negative = ["[C,S](=[O,S,P])-[O;H1,-1]"]
-
-Hydrophobic = ["[C;D3,D4](-[CH3])-[CH3]", "[S;D2](-C)-C"]
-
-Aromatic = ["a"]
-
-PROPERTY_KEY = ["Donor", "Acceptor", "Positive", "Negative", "Hydrophobic", "Aromatic"]
-
-
-def GetPharmacoErGFPs(mol, fuzzIncrement=0.3, maxPath=21, binary=True, return_bitInfo=False):
-    '''
-    https://pubs.acs.org/doi/full/10.1021/ci050457y#
-    returns maxPath*21 bits
-
-    size(v) = (n(n + 1)/2) * (maxDist - minDist + 1)
-    '''
-    minPath = 1
-
-    arr = AllChem.GetErGFingerprint(mol, fuzzIncrement=fuzzIncrement, maxPath=maxPath, minPath=minPath)
-    arr = arr.astype(np.float32)
-
-    if binary:
-        arr = arr.astype(np.bool_)
-
-    if return_bitInfo:
-        bitInfo = []
-        for i in range(len(PROPERTY_KEY)):
-            for j in range(i, len(PROPERTY_KEY)):
-                for path in range(minPath, maxPath + 1):
-                    triplet = (PROPERTY_KEY[i], PROPERTY_KEY[j], path)
-                    bitInfo.append(triplet)
-        return arr, bitInfo
-
-    return arr
deepscreen/data/featurizers/fingerprint/pharmPointfp.py
DELETED
@@ -1,59 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-"""
-Created on Sat Aug 17 16:54:12 2019
-
-@author: [email protected]
-
-Combining a set of chemical features with the 2D (topological) distances between them gives a 2D pharmacophore. When the distances are binned, unique integer ids can be assigned to each of these pharmacophores and they can be stored in a fingerprint. Details of the encoding are in: https://www.rdkit.org/docs/RDKit_Book.html#ph4-figure
-"""
-
-_type = 'Pharmacophore-based'
-
-from rdkit.Chem.Pharm2D.SigFactory import SigFactory
-from rdkit.Chem.Pharm2D import Generate
-from rdkit.Chem import DataStructs
-from rdkit.Chem import ChemicalFeatures
-
-import numpy as np
-import os
-
-fdef = os.path.join(os.path.dirname(__file__), 'mnimalfatures.fdef')
-featFactory = ChemicalFeatures.BuildFeatureFactory(fdef)
-
-
-def GetPharmacoPFPs(mol,
-                    bins=[(i, i + 1) for i in range(20)],
-                    minPointCount=2,
-                    maxPointCount=2,
-                    return_bitInfo=False):
-    '''
-    Note: a maxPointCount of 3 is slow.
-
-    Use bins = [(i, i + 1) for i in range(20)] and
-    maxPointCount=2 for large-scale computation.
-    '''
-    MysigFactory = SigFactory(featFactory,
-                              trianglePruneBins=False,
-                              minPointCount=minPointCount,
-                              maxPointCount=maxPointCount)
-    MysigFactory.SetBins(bins)
-    MysigFactory.Init()
-
-    res = Generate.Gen2DFingerprint(mol, MysigFactory)
-    arr = np.array(list(res)).astype(np.bool_)
-    if return_bitInfo:
-        description = []
-        for i in range(len(res)):
-            description.append(MysigFactory.GetBitDescription(i))
-        return arr, description
-
-    return arr
-
-
-if __name__ == '__main__':
-    from rdkit import Chem
-
-    mol = Chem.MolFromSmiles('CC#CC(=O)NC1=NC=C2C(=C1)C(=NC=N2)NC3=CC(=C(C=C3)F)Cl')
-    a = GetPharmacoPFPs(mol, bins=[(i, i + 1) for i in range(20)], minPointCount=2, maxPointCount=2)