diff --git a/deepscreen/__init__.py b/deepscreen/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9379f1b400eb74e7407963697105ab2dca542b3
--- /dev/null
+++ b/deepscreen/__init__.py
@@ -0,0 +1,101 @@
+"""
+DeepScreen package initialization, registering custom objects and monkey patching some libraries.
+"""
+import sys
+from builtins import eval
+
+import lightning.fabric.strategies.launchers.subprocess_script as subprocess_script
+import torch
+from omegaconf import OmegaConf
+
+from deepscreen.utils import get_logger
+
+log = get_logger(__name__)
+
+# Allow basic Python operations in Hydra interpolation; examples:
+# `in_channels: ${eval:${model.drug_encoder.out_channels}+${model.protein_encoder.out_channels}}`
+# `subdir: ${eval:${hydra.job.override_dirname}.replace('/', '.')}`
+OmegaConf.register_new_resolver("eval", eval)
+
+
+def sanitize_path(path_str: str):
+    """
+    Sanitize a string for path creation by replacing unsafe characters and truncating it to 255 characters (OS limit).
+    """
+    return path_str.replace("/", ".").replace("\\", ".").replace(":", "-")[:255]
+
+
+OmegaConf.register_new_resolver("sanitize_path", sanitize_path)
+
+
+def _hydra_subprocess_cmd(local_rank: int):
+    """
+    Monkey patch for lightning.fabric.strategies.launchers.subprocess_script._hydra_subprocess_cmd.
+    Temporarily fixes the problem of unnecessarily creating log folders for DDP subprocesses in Hydra multirun/sweep.
+    """
+    import __main__  # local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
+    from hydra.core.hydra_config import HydraConfig
+    from hydra.utils import get_original_cwd, to_absolute_path
+
+    # when the user is using Hydra, find the absolute path
+    if __main__.__spec__ is None:  # pragma: no cover
+        command = [sys.executable, to_absolute_path(sys.argv[0])]
+    else:
+        command = [sys.executable, "-m", __main__.__spec__.name]
+
+    command += sys.argv[1:]
+
+    cwd = get_original_cwd()
+    rundir = f'"{HydraConfig.get().runtime.output_dir}"'
+    # Set output_subdir to null since we don't want different subprocesses trying to write to config.yaml
+    command += [f"hydra.job.name=train_ddp_process_{local_rank}",
+                "hydra.output_subdir=null",
+                f"hydra.runtime.output_dir={rundir}"]
+    return command, cwd
+
+
+subprocess_script._hydra_subprocess_cmd = _hydra_subprocess_cmd
+
+# from torch import Tensor
+# from lightning.fabric.utilities.distributed import _distributed_available
+# from lightning.pytorch.utilities.rank_zero import WarningCache
+# from lightning.pytorch.utilities.warnings import PossibleUserWarning
+# from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection
+
+# warning_cache = WarningCache()
+#
+# @staticmethod
+# def _get_cache(result_metric, on_step: bool):
+#     cache = None
+#     if on_step and result_metric.meta.on_step:
+#         cache = result_metric._forward_cache
+#     elif not on_step and result_metric.meta.on_epoch:
+#         if result_metric._computed is None:
+#             should = result_metric.meta.sync.should
+#             if not should and _distributed_available() and result_metric.is_tensor:
+#                 warning_cache.warn(
+#                     f"It is recommended to use `self.log({result_metric.meta.name!r}, ..., sync_dist=True)`"
+#                     " when logging on epoch level in distributed setting to accumulate the metric across"
+#                     " devices.",
+#                     category=PossibleUserWarning,
+#                 )
+#             result_metric.compute()
+#             result_metric.meta.sync.should = should
+#
+#         cache = result_metric._computed
+#
+#     if cache is not None:
+#         if isinstance(cache, Tensor):
+#             if not result_metric.meta.enable_graph:
+#                 return cache.detach()
+#
+#     return cache
+#
+#
+# _ResultCollection._get_cache = _get_cache
+
+if torch.cuda.is_available():
+    if torch.cuda.get_device_capability() >= (8, 0):
+        torch.set_float32_matmul_precision("high")
+        log.info("Your GPU supports Tensor Cores; "
+                 "enabling them automatically by setting `torch.set_float32_matmul_precision('high')`.")
diff --git a/deepscreen/__pycache__/__init__.cpython-311.pyc b/deepscreen/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f716a0fe682e10bbd19b657e449841053938c805
Binary files /dev/null and b/deepscreen/__pycache__/__init__.cpython-311.pyc differ
diff --git a/deepscreen/__pycache__/predict.cpython-311.pyc b/deepscreen/__pycache__/predict.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..36ecc3f738005c62fa82d2d50947c96604afb506
Binary files /dev/null and b/deepscreen/__pycache__/predict.cpython-311.pyc differ
diff --git a/deepscreen/data/__init__.py b/deepscreen/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/deepscreen/data/__pycache__/__init__.cpython-311.pyc b/deepscreen/data/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e0c09d97afa3c2a4034a852cc250d2fa20bdc490
Binary files /dev/null and b/deepscreen/data/__pycache__/__init__.cpython-311.pyc differ
diff --git a/deepscreen/data/__pycache__/dti.cpython-311.pyc b/deepscreen/data/__pycache__/dti.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a91110329d8403a13aa5195248050b47c3e6bba1
Binary files /dev/null and b/deepscreen/data/__pycache__/dti.cpython-311.pyc differ
diff --git a/deepscreen/data/dti.py b/deepscreen/data/dti.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a94d2b6efb179bdee32e9362874889e300e0597
--- /dev/null
+++ b/deepscreen/data/dti.py
@@ -0,0 +1,422 @@
+import re
+from functools import partial
+from numbers import Number
+from pathlib import Path
+from typing import Any, Dict, Optional, Sequence, Union, Literal
+
+from lightning import LightningDataModule
+import pandas as pd
+import swifter
+from sklearn.preprocessing import LabelEncoder
+from torch.utils.data import Dataset, DataLoader
+
+from deepscreen.data.utils import label_transform, collate_fn, SafeBatchSampler
+from deepscreen.utils import get_logger
+
+log = get_logger(__name__)
+
+SMILES_PAT = r"[^A-Za-z0-9=#:+\-\[\]<>()/\\@%,.*]"
+FASTA_PAT = r"[^A-Z*\-]"
+
+
+def validate_seq_str(seq, regex):
+    if seq:
+        err_charset = set(re.findall(regex, seq))
+        if not err_charset:
+            return None
+        else:
+            return ', '.join(err_charset)
+    else:
+        return 'Empty string'
+
+
+# TODO: save a list of corrupted records
+
+def rdkit_canonicalize(smiles):
+    from rdkit import Chem
+    try:
+        mol = Chem.MolFromSmiles(smiles)
+        cano_smiles = Chem.MolToSmiles(mol)
+        return cano_smiles
+    except Exception as e:
+        log.warning(f'Failed to canonicalize SMILES with RDKit due to {str(e)}. Returning original SMILES: {smiles}')
+        return smiles
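# A minimal sanity check for the two helpers above; an illustrative sketch assuming an
# interactive session where deepscreen and RDKit are importable. The SMILES strings
# are arbitrary examples, not project data.
from deepscreen.data.dti import SMILES_PAT, validate_seq_str, rdkit_canonicalize

assert validate_seq_str('CC(=O)O', SMILES_PAT) is None      # valid SMILES: no offending characters
assert validate_seq_str('CC(=O)O!', SMILES_PAT) == '!'      # '!' is outside the SMILES alphabet
assert validate_seq_str('', SMILES_PAT) == 'Empty string'   # empty inputs are flagged explicitly
print(rdkit_canonicalize('OC(C)=O'))                        # -> 'CC(=O)O' (RDKit canonical form)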
+
+
+class DTIDataset(Dataset):
+    def __init__(
+            self,
+            task: Literal['regression', 'binary', 'multiclass'],
+            num_classes: Optional[int],
+            data_path: str | Path,
+            drug_featurizer: callable,
+            protein_featurizer: callable,
+            thresholds: Optional[Union[Number, Sequence[Number]]] = None,
+            discard_intermediate: Optional[bool] = False,
+            query: Optional[str] = 'X2'
+    ):
+        df = pd.read_csv(
+            data_path,
+            engine='python',
+            header=0,
+            usecols=lambda x: x in ['X1', 'ID1', 'X2', 'ID2', 'Y', 'U'],
+            dtype={
+                'X1': 'str',
+                'ID1': 'str',
+                'X2': 'str',
+                'ID2': 'str',
+                'Y': 'float32',
+                'U': 'str',
+            },
+        )
+        # Read the whole data table
+
+        # if 'ID1' in df:
+        #     self.x1_to_id1 = dict(zip(df['X1'], df['ID1']))
+        # if 'ID2' in df:
+        #     self.x2_to_id2 = dict(zip(df['X2'], df['ID2']))
+        #     self.id2_to_indexes = dict(zip(df['ID2'], range(len(df['ID2']))))
+        #     self.x2_to_indexes = dict(zip(df['X2'], range(len(df['X2']))))
+
+        # # train and eval mode data processing (fully labelled)
+        # if 'Y' in df.columns and df['Y'].notnull().all():
+        log.info(f"Processing data file: {data_path}")
+
+        # Forward-fill all non-label columns
+        df.loc[:, df.columns != 'Y'] = df.loc[:, df.columns != 'Y'].ffill(axis=0)
+
+        # TODO potentially allow running through the whole data validation process
+        # error = False
+
+        if 'Y' in df:
+            log.info("Validating labels (`Y`)...")
+            # TODO: check sklearn.utils.multiclass.check_classification_targets
+            match task:
+                case 'regression':
+                    assert all(df['Y'].swifter.apply(lambda x: isinstance(x, Number))), \
+                        f"""`Y` must be numeric for the `regression` task,
+                        but it has {set(df['Y'].swifter.apply(type))}."""
+
+                case 'binary':
+                    if all(df['Y'].isin([0, 1])):
+                        assert not thresholds, \
+                            f"""`Y` is already 0 or 1 for the `binary` (classification) `task`,
+                            but still got `thresholds` ({thresholds}).
+                            Double check your choices of `task` and `thresholds`, and records in the `Y` column."""
+                    else:
+                        assert thresholds, \
+                            f"""`Y` must be 0 or 1 for the `binary` (classification) `task`,
+                            but it has {pd.unique(df['Y'])}.
+                            You may set `thresholds` to discretize continuous labels."""  # TODO print err idx instead
+
+                case 'multiclass':
+                    assert num_classes >= 3, '`num_classes` for `task=multiclass` must be at least 3.'
+
+                    if all(df['Y'].swifter.apply(lambda x: x.is_integer() and x >= 0)):
+                        assert not thresholds, \
+                            f"""`Y` is already non-negative integers for the
+                            `multiclass` (classification) `task`, but still got `thresholds` ({thresholds}).
+                            Double check your choices of `task`, `thresholds`, and records in the `Y` column."""
+                    else:
+                        assert thresholds, \
+                            f"""`Y` must be non-negative integers for the
+                            `multiclass` (classification) `task`, but it has {pd.unique(df['Y'])}.
+                            You must set `thresholds` to discretize continuous labels."""  # TODO print err idx instead
+
+            if 'U' in df.columns:
+                units = df['U']
+            else:
+                units = None
+                log.warning("Units ('U') not in the data table. "
+                            "Assuming all labels to be discrete or in p-scale (-log10[M]).")
+
+            # Transform labels
+            df['Y'] = label_transform(labels=df['Y'], units=units, thresholds=thresholds,
+                                      discard_intermediate=discard_intermediate)
+
+            # Filter out rows with a NaN in Y (missing values)
+            df.dropna(subset=['Y'], inplace=True)
+
+            match task:
+                case 'regression':
+                    df['Y'] = df['Y'].astype('float32')
+                    assert all(df['Y'].swifter.apply(lambda x: isinstance(x, Number))), \
+                        f"""`Y` must be numeric for the `regression` task,
+                        but after transformation it still has {set(df['Y'].swifter.apply(type))}.
+                        Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns."""
+                    # TODO print err idx instead
+                case 'binary':
+                    df['Y'] = df['Y'].astype('int')
+                    assert all(df['Y'].isin([0, 1])), \
+                        f"""`Y` must be 0 or 1 for `task=binary`,
+                        but after transformation it still has {pd.unique(df['Y'])}.
+                        Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns."""
+                    # TODO print err idx instead
+                case 'multiclass':
+                    df['Y'] = df['Y'].astype('int')
+                    assert all(df['Y'].swifter.apply(lambda x: x.is_integer() and x >= 0)), \
+                        f"""`Y` must be non-negative integers for `task=multiclass`,
+                        but after transformation it still has {pd.unique(df['Y'])}.
+                        Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns."""
+                    # TODO print err idx instead
+                    target_n_unique = df['Y'].nunique()
+                    assert target_n_unique == num_classes, \
+                        f"""You have set `num_classes` for `task=multiclass` to {num_classes},
+                        but after transformation `Y` still has {target_n_unique} unique labels.
+                        Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns."""
+
+        log.info("Validating SMILES (`X1`)...")
+        df['X1_ERR'] = df['X1'].swifter.progress_bar(
+            desc="Validating SMILES...").apply(validate_seq_str, regex=SMILES_PAT)
+        if not df['X1_ERR'].isna().all():
+            raise Exception(f"Encountered invalid SMILES:\n{df[~df['X1_ERR'].isna()][['X1', 'X1_ERR']]}")
+        df['X1^'] = df['X1'].apply(rdkit_canonicalize)  # swifter
+
+        log.info("Validating FASTA (`X2`)...")
+        df['X2'] = df['X2'].str.upper()
+        df['X2_ERR'] = df['X2'].swifter.progress_bar(
+            desc="Validating FASTA...").apply(validate_seq_str, regex=FASTA_PAT)
+        if not df['X2_ERR'].isna().all():
+            raise Exception(f"Encountered invalid FASTA:\n{df[~df['X2_ERR'].isna()][['X2', 'X2_ERR']]}")
+
+        # FASTA/SMILES indices as query for retrieval metrics like enrichment factor and hit rate
+        if query:
+            df['ID^'] = LabelEncoder().fit_transform(df[query])
+
+        self.df = df
+        self.drug_featurizer = drug_featurizer if drug_featurizer is not None else (lambda x: x)
+        self.protein_featurizer = protein_featurizer if protein_featurizer is not None else (lambda x: x)
+
+    def __len__(self):
+        return len(self.df.index)
+
+    def __getitem__(self, i):
+        sample = self.df.loc[i]
+        return {
+            'N': i,
+            'X1': sample['X1'],
+            'X1^': self.drug_featurizer(sample['X1^']),
+            'ID1': sample.get('ID1'),
+            'X2': sample['X2'],
+            'X2^': self.protein_featurizer(sample['X2']),
+            'ID2': sample.get('ID2'),
+            'Y': sample.get('Y'),
+            'ID^': sample.get('ID^'),
+        }
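# A sketch of standalone dataset construction, assuming a CSV with X1 (SMILES),
# X2 (FASTA), Y and optional U columns; the path is hypothetical and the featurizers
# are left as None (identity) purely for illustration.
from deepscreen.data.dti import DTIDataset

dataset = DTIDataset(
    task='binary',
    num_classes=None,
    data_path='data/example.csv',  # hypothetical file
    drug_featurizer=None,          # None falls back to the identity function
    protein_featurizer=None,
    thresholds=7.0,                # e.g., discretize p-scale labels at 7 (-log10[M])
)
sample = dataset[0]  # dict with keys N, X1, X1^, ID1, X2, X2^, ID2, Y, ID^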
" + "Assuming all labels to be discrete or in p-scale (-log10[M]).") + + # Transform labels + df['Y'] = label_transform(labels=df['Y'], units=units, thresholds=thresholds, + discard_intermediate=discard_intermediate) + + # Filter out rows with a NaN in Y (missing values) + df.dropna(subset=['Y'], inplace=True) + + match task: + case 'regression': + df['Y'] = df['Y'].astype('float32') + assert all(df['Y'].swifter.apply(lambda x: isinstance(x, Number))), \ + f"""`Y` must be numeric for `regression` task, + but after transformation it still has {set(df['Y'].swifter.apply(type))}. + Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns.""" + # TODO print err idx instead + case 'binary': + df['Y'] = df['Y'].astype('int') + assert all(df['Y'].isin([0, 1])), \ + f"""`Y` must be 0 or 1 for `task=binary`, " + but after transformation it still has {pd.unique(df['Y'])}. + Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns.""" + # TODO print err idx instead + case 'multiclass': + df['Y'] = df['Y'].astype('int') + assert all(df['Y'].swifter.apply(lambda x: x.is_integer() and x >= 0)), \ + f"""Y must be non-negative integers for `task=multiclass` + but after transformation it still has {pd.unique(df['Y'])}. + Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns.""" + # TODO print err idx instead + target_n_unique = df['Y'].nunique() + assert target_n_unique == num_classes, \ + f"""You have set `num_classes` for `task=multiclass` to {num_classes}, + but after transformation Y still has {target_n_unique} unique labels. + Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns.""" + + log.info("Validating SMILES (`X1`)...") + df['X1_ERR'] = df['X1'].swifter.progress_bar( + desc="Validating SMILES...").apply(validate_seq_str, regex=SMILES_PAT) + if not df['X1_ERR'].isna().all(): + raise Exception(f"Encountered invalid SMILES:\n{df[~df['X1_ERR'].isna()][['X1', 'X1_ERR']]}") + df['X1^'] = df['X1'].apply(rdkit_canonicalize) # swifter + + log.info("Validating FASTA (`X2`)...") + df['X2'] = df['X2'].str.upper() + df['X2_ERR'] = df['X2'].swifter.progress_bar( + desc="Validating FASTA...").apply(validate_seq_str, regex=FASTA_PAT) + if not df['X2_ERR'].isna().all(): + raise Exception(f"Encountered invalid FASTA:\n{df[~df['X2_ERR'].isna()][['X2', 'X2_ERR']]}") + + # FASTA/SMILES indices as query for retrieval metrics like enrichment factor and hit rate + if query: + df['ID^'] = LabelEncoder().fit_transform(df[query]) + + self.df = df + self.drug_featurizer = drug_featurizer if drug_featurizer is not None else (lambda x: x) + self.protein_featurizer = protein_featurizer if protein_featurizer is not None else (lambda x: x) + + def __len__(self): + return len(self.df.index) + + def __getitem__(self, i): + sample = self.df.loc[i] + return { + 'N': i, + 'X1': sample['X1'], + 'X1^': self.drug_featurizer(sample['X1^']), + 'ID1': sample.get('ID1'), + 'X2': sample['X2'], + 'X2^': self.protein_featurizer(sample['X2']), + 'ID2': sample.get('ID2'), + 'Y': sample.get('Y'), + 'ID^': sample.get('ID^'), + } + + +class DTIDataModule(LightningDataModule): + """ + DTI DataModule + + A DataModule implements 5 key methods: + + def prepare_data(self): + # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP) + # download data, pre-process, split, save to disk, etc. + def setup(self, stage): + # things to do on every process in DDP + # load data, set variables, etc. 
+
+    def prepare_data(self):
+        """
+        Download data if needed.
+        Do not use it to assign state (e.g., self.x = x).
+        """
+
+    def setup(self, stage: Optional[str] = None, encoding: Optional[str] = None):
+        """
+        Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
+        This method is called by Lightning with both `trainer.fit()` and `trainer.test()`, so be
+        careful not to execute data splitting twice.
+        """
+        # load and split datasets only if not loaded in initialization
+        if not any([self.train_data, self.test_data, self.val_data, self.predict_data]):
+            if self.hparams.train_val_test_split:
+                if len(self.hparams.train_val_test_split) != 3:
+                    raise ValueError('Length of `train_val_test_split` must be 3. '
+                                     'Set the second element to None for training without validation. '
+                                     'Set the third element to None for training without testing.')
+
+                self.train_data = self.hparams.train_val_test_split[0]
+                self.val_data = self.hparams.train_val_test_split[1]
+                self.test_data = self.hparams.train_val_test_split[2]
+
+                if all([self.hparams.data_file, self.split]):
+                    if all(isinstance(split, Number) or split is None
+                           for split in self.hparams.train_val_test_split):
+                        split_data = self.split(
+                            dataset=self.dataset(data_path=Path(self.hparams.data_dir, self.hparams.data_file)),
+                            lengths=[split for split in self.hparams.train_val_test_split if split is not None]
+                        )
+                        for dataset in ['train_data', 'val_data', 'test_data']:
+                            if getattr(self, dataset) is not None:
+                                setattr(self, dataset, split_data.pop(0))
+
+                    else:
+                        raise ValueError('`train_val_test_split` must be a sequence of numbers or None '
+                                         '(float for percentages and int for sample numbers) '
+                                         'if both `data_file` and `split` have been specified.')
+
+                elif (all(isinstance(split, str) or split is None
+                          for split in self.hparams.train_val_test_split)
+                      and not any([self.hparams.data_file, self.split])):
+                    for dataset in ['train_data', 'val_data', 'test_data']:
+                        if getattr(self, dataset) is not None:
+                            data_path = Path(getattr(self, dataset))
+                            if not data_path.is_absolute():
+                                data_path = Path(self.hparams.data_dir, data_path)
+                            setattr(self, dataset, self.dataset(data_path=data_path))
+
+                else:
+                    raise ValueError('For training, you must specify either all of `data_file`, `split`, '
+                                     'and `train_val_test_split` as a sequence of numbers, or '
+                                     'solely `train_val_test_split` as a sequence of data file paths.')
+
+            elif self.hparams.data_file and not any([self.split, self.hparams.train_val_test_split]):
+                data_path = Path(self.hparams.data_file)
+                if not data_path.is_absolute():
+                    data_path = Path(self.hparams.data_dir, data_path)
+                self.test_data = self.predict_data = self.dataset(data_path=data_path)
+
+            else:
+                raise ValueError("For training, you must specify `train_val_test_split`. "
+                                 "For testing/predicting, you must specify only `data_file` without "
+                                 "`train_val_test_split` or `split`.")
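# The alternative pre-split configuration accepted by setup() above: file paths in
# `train_val_test_split`, with `data_file` and `split` omitted (names hypothetical).
from deepscreen.data.dti import DTIDataModule

datamodule_presplit = DTIDataModule(
    task='regression',
    num_classes=None,
    batch_size=64,
    drug_featurizer=None,
    protein_featurizer=None,
    data_dir='data/',
    train_val_test_split=('train.csv', 'val.csv', None),  # None disables the test stage
)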
+
+    def train_dataloader(self):
+        return DataLoader(
+            dataset=self.train_data,
+            batch_sampler=SafeBatchSampler(
+                data_source=self.train_data,
+                batch_size=self.hparams.batch_size,
+                # Dropping the last batch prevents problems caused by variable batch sizes in training, e.g.,
+                # batch_size=1 in BatchNorm, and shuffling ensures the model is trained on all samples over epochs.
+                drop_last=True,
+                shuffle=True,
+            ),
+            # batch_size=self.hparams.batch_size,
+            # shuffle=True,
+            num_workers=self.hparams.num_workers,
+            pin_memory=self.hparams.pin_memory,
+            collate_fn=self.collator,
+            persistent_workers=True if self.hparams.num_workers > 0 else False
+        )
+
+    def val_dataloader(self):
+        return DataLoader(
+            dataset=self.val_data,
+            batch_sampler=SafeBatchSampler(
+                data_source=self.val_data,
+                batch_size=self.hparams.batch_size,
+                drop_last=False,
+                shuffle=False
+            ),
+            # batch_size=self.hparams.batch_size,
+            # shuffle=False,
+            num_workers=self.hparams.num_workers,
+            pin_memory=self.hparams.pin_memory,
+            collate_fn=self.collator,
+            persistent_workers=True if self.hparams.num_workers > 0 else False
+        )
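# SafeBatchSampler is imported from deepscreen.data.utils and not shown in this diff;
# judging from the call sites above and the TODO notes elsewhere, it behaves like an
# (optionally shuffling) batch sampler that skips samples that failed featurization.
# A rough sketch of that assumed contract, not the actual implementation:
import random

class SafeBatchSamplerSketch:
    def __init__(self, data_source, batch_size, drop_last, shuffle):
        self.data_source = data_source
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.shuffle = shuffle

    def __iter__(self):
        indexes = list(range(len(self.data_source)))
        if self.shuffle:
            random.shuffle(indexes)
        batch = []
        for i in indexes:
            if self.data_source[i] is not None:  # drop corrupted records, keep batch size stable
                batch.append(i)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if batch and not self.drop_last:
            yield batch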
+
+    def test_dataloader(self):
+        return DataLoader(
+            dataset=self.test_data,
+            batch_sampler=SafeBatchSampler(
+                data_source=self.test_data,
+                batch_size=self.hparams.batch_size,
+                drop_last=False,
+                shuffle=False
+            ),
+            # batch_size=self.hparams.batch_size,
+            # shuffle=False,
+            num_workers=self.hparams.num_workers,
+            pin_memory=self.hparams.pin_memory,
+            collate_fn=self.collator,
+            persistent_workers=True if self.hparams.num_workers > 0 else False
+        )
+
+    def predict_dataloader(self):
+        return DataLoader(
+            dataset=self.predict_data,
+            batch_sampler=SafeBatchSampler(
+                data_source=self.predict_data,
+                batch_size=self.hparams.batch_size,
+                drop_last=False,
+                shuffle=False
+            ),
+            # batch_size=self.hparams.batch_size,
+            # shuffle=False,
+            num_workers=self.hparams.num_workers,
+            pin_memory=self.hparams.pin_memory,
+            collate_fn=self.collator,
+            persistent_workers=True if self.hparams.num_workers > 0 else False
+        )
+
+    def teardown(self, stage: Optional[str] = None):
+        """Clean up after fit or test."""
+        pass
+
+    def state_dict(self):
+        """Extra things to save to checkpoint."""
+        return {}
+
+    def load_state_dict(self, state_dict: Dict[str, Any]):
+        """Things to do when loading checkpoint."""
+        pass
diff --git a/deepscreen/data/dti.py.bak b/deepscreen/data/dti.py.bak
new file mode 100644
index 0000000000000000000000000000000000000000..a531948e5acf64b3e1f2404ce9a82cbafbcd84d4
--- /dev/null
+++ b/deepscreen/data/dti.py.bak
@@ -0,0 +1,369 @@
+from functools import partial
+from numbers import Number
+from pathlib import Path
+from typing import Any, Dict, Optional, Sequence, Union, Literal
+
+from lightning import LightningDataModule
+import pandas as pd
+from sklearn.preprocessing import LabelEncoder
+from torch.utils.data import Dataset, DataLoader
+
+from deepscreen.data.utils import label_transform, collate_fn, SafeBatchSampler
+from deepscreen.utils import get_logger
+
+log = get_logger(__name__)
+
+
+# TODO: save a list of corrupted records
+
+
+class DTIDataset(Dataset):
+    def __init__(
+            self,
+            task: Literal['regression', 'binary', 'multiclass'],
+            n_class: Optional[int],
+            data_path: str | Path,
+            drug_featurizer: callable,
+            protein_featurizer: callable,
+            thresholds: Optional[Union[Number, Sequence[Number]]] = None,
+            discard_intermediate: Optional[bool] = False,
+    ):
+        df = pd.read_csv(
+            data_path,
+            engine='python',
+            header=0,
+            usecols=lambda x: x in ['X1', 'ID1', 'X2', 'ID2', 'Y', 'U'],
+            dtype={
+                'X1': 'str',
+                'ID1': 'str',
+                'X2': 'str',
+                'ID2': 'str',
+                'Y': 'float32',
+                'U': 'str',
+            },
+        )
+        # Read the whole data table
+
+        # if 'ID1' in df:
+        #     self.x1_to_id1 = dict(zip(df['X1'], df['ID1']))
+        # if 'ID2' in df:
+        #     self.x2_to_id2 = dict(zip(df['X2'], df['ID2']))
+        #     self.id2_to_indexes = dict(zip(df['ID2'], range(len(df['ID2']))))
+        #     self.x2_to_indexes
= dict(zip(df['X2'], range(len(df['X2'])))) + + # # train and eval mode data processing (fully labelled) + # if 'Y' in df.columns and df['Y'].notnull().all(): + log.info(f"Processing data file: {data_path}") + + # Forward-fill all non-label columns + df.loc[:, df.columns != 'Y'] = df.loc[:, df.columns != 'Y'].ffill(axis=0) + + if 'Y' in df: + log.info(f"Performing pre-transformation target validation.") + # TODO: check sklearn.utils.multiclass.check_classification_targets + match task: + case 'regression': + assert all(df['Y'].apply(lambda x: isinstance(x, Number))), \ + f"""`Y` must be numeric for `regression` task, + but it has {set(df['Y'].apply(type))}.""" + + case 'binary': + if all(df['Y'].isin([0, 1])): + assert not thresholds, \ + f"""`Y` is already 0 or 1 for `binary` (classification) `task`, + but still got `thresholds` {thresholds}. + Double check your choices of `task` and `thresholds` and records in the `Y` column.""" + else: + assert thresholds, \ + f"""`Y` must be 0 or 1 for `binary` (classification) `task`, + but it has {pd.unique(df['Y'])}. + You must set `thresholds` to discretize continuous labels.""" + + case 'multiclass': + assert n_class >= 3, f'`n_class` for `multiclass` (classification) `task` must be at least 3.' + + if all(df['Y'].apply(lambda x: x.is_integer() and x >= 0)): + assert not thresholds, \ + f"""`Y` is already non-negative integers for + `multiclass` (classification) `task`, but still got `thresholds` {thresholds}. + Double check your choice of `task`, `thresholds` and records in the `Y` column.""" + else: + assert thresholds, \ + f"""`Y` must be non-negative integers for + `multiclass` (classification) 'task',but it has {pd.unique(df['Y'])}. + You must set `thresholds` to discretize continuous labels.""" + + if 'U' in df.columns: + units = df['U'] + else: + units = None + log.warning("Units ('U') not in the data table. " + "Assuming all labels to be discrete or in p-scale (-log10[M]).") + + # Transform labels + df['Y'] = label_transform(labels=df['Y'], units=units, thresholds=thresholds, + discard_intermediate=discard_intermediate) + + # Filter out rows with a NaN in Y (missing values) + df.dropna(subset=['Y'], inplace=True) + + log.info(f"Performing post-transformation target validation.") + match task: + case 'regression': + df['Y'] = df['Y'].astype('float32') + assert all(df['Y'].apply(lambda x: isinstance(x, Number))), \ + f"""`Y` must be numeric for `regression` task, + but after transformation it still has {set(df['Y'].apply(type))}. + Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns.""" + + case 'binary': + df['Y'] = df['Y'].astype('int') + assert all(df['Y'].isin([0, 1])), \ + f"""`Y` must be 0 or 1 for `binary` (classification) `task`, " + but after transformation it still has {pd.unique(df['Y'])}. + Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns.""" + + case 'multiclass': + df['Y'] = df['Y'].astype('int') + assert all(df['Y'].apply(lambda x: x.is_integer() and x >= 0)), \ + f"""Y must be non-negative integers for task `multiclass` (classification) + but after transformation it still has {pd.unique(df['Y'])}. + Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns.""" + + target_n_unique = df['Y'].nunique() + assert target_n_unique == n_class, \ + f"""You have set `n_class` for `multiclass` (classification) `task` to {n_class}, + but after transformation Y still has {target_n_unique} unique labels. 
+ Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns.""" + + # Indexed protein/FASTA for retrieval metrics + df['IDX'] = LabelEncoder().fit_transform(df['X2']) + + self.df = df + self.drug_featurizer = drug_featurizer if drug_featurizer is not None else (lambda x: x) + self.protein_featurizer = protein_featurizer if protein_featurizer is not None else (lambda x: x) + + def __len__(self): + return len(self.df.index) + + def __getitem__(self, i): + sample = self.df.loc[i] + return { + 'N': i, + 'X1': self.drug_featurizer(sample['X1']), + 'ID1': sample.get('ID1', sample['X1']), + 'X2': self.protein_featurizer(sample['X2']), + 'ID2': sample.get('ID2', sample['X2']), + 'Y': sample.get('Y'), + 'IDX': sample['IDX'], + } + + +class DTIDataModule(LightningDataModule): + """ + DTI DataModule + + A DataModule implements 5 key methods: + + def prepare_data(self): + # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP) + # download data, pre-process, split, save to disk, etc. + def setup(self, stage): + # things to do on every process in DDP + # load data, set variables, etc. + def train_dataloader(self): + # return train dataloader + def val_dataloader(self): + # return validation dataloader + def test_dataloader(self): + # return test dataloader + def teardown(self): + # called on every process in DDP + # clean up after fit or test + + This allows you to share a full dataset without explaining how to download, + split, transform and process the data. + + Read the docs: + https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html + """ + + def __init__( + self, + task: Literal['regression', 'binary', 'multiclass'], + n_class: Optional[int], + batch_size: int, + # train: bool, + drug_featurizer: callable, + protein_featurizer: callable, + collator: callable = collate_fn, + data_dir: str = "data/", + data_file: Optional[str] = None, + train_val_test_split: Optional[Union[Sequence[Number | str]]] = None, + split: Optional[callable] = None, + thresholds: Optional[Union[Number, Sequence[Number]]] = None, + discard_intermediate: Optional[bool] = False, + num_workers: int = 0, + pin_memory: bool = False, + ): + super().__init__() + + self.train_data: Optional[Dataset] = None + self.val_data: Optional[Dataset] = None + self.test_data: Optional[Dataset] = None + self.predict_data: Optional[Dataset] = None + self.split = split + self.collator = collator + self.dataset = partial( + DTIDataset, + task=task, + n_class=n_class, + drug_featurizer=drug_featurizer, + protein_featurizer=protein_featurizer, + thresholds=thresholds, + discard_intermediate=discard_intermediate + ) + + if train_val_test_split: + # TODO test behavior for trainer.test and predict when this is passed + if len(train_val_test_split) not in [2, 3]: + raise ValueError('Length of `train_val_test_split` must be 2 (for training without testing) or 3.') + if all([data_file, split]): + if all(isinstance(split, Number) for split in train_val_test_split): + pass + else: + raise ValueError('`train_val_test_split` must be a sequence numbers ' + '(float for percentages and int for sample numbers) ' + 'if both `data_file` and `split` have been specified.') + elif all(isinstance(split, str) for split in train_val_test_split) and not any([data_file, split]): + split_paths = [] + for split in train_val_test_split: + split = Path(split) + if not split.is_absolute(): + split = Path(data_dir, split) + split_paths.append(split) + + self.train_data = self.dataset(data_path=split_paths[0]) + 
self.val_data = self.dataset(data_path=split_paths[1]) + if len(train_val_test_split) == 3: + self.test_data = self.dataset(data_path=split_paths[2]) + else: + raise ValueError('For training, you must specify either `data_file`, `split`, ' + 'and `train_val_test_split` as a sequence of numbers or ' + 'solely `train_val_test_split` as a sequence of data file paths.') + + elif data_file and not any([split, train_val_test_split]): + data_file = Path(data_file) + if not data_file.is_absolute(): + data_file = Path(data_dir, data_file) + self.test_data = self.predict_data = self.dataset(data_path=data_file) + else: + raise ValueError("For training, you must specify `train_val_test_split`. " + "For testing/predicting, you must specify only `data_file` without " + "`train_val_test_split` or `split`.") + + # this line allows to access init params with 'self.hparams' attribute + # also ensures init params will be stored in ckpt + self.save_hyperparameters(logger=False) # ignore=['split'] + + def prepare_data(self): + """ + Download data if needed. + Do not use it to assign state (e.g., self.x = x). + """ + + def setup(self, stage: Optional[str] = None, encoding: str = None): + """ + Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`. + This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be + careful not to execute data splitting twice. + """ + # TODO test SafeBatchSampler (which skips samples with any None without introducing variable batch size) + # load and split datasets only if not loaded in initialization + if not any([self.train_data, self.test_data, self.val_data, self.predict_data]): + self.train_data, self.val_data, self.test_data = self.split( + dataset=self.dataset(data_path=Path(self.hparams.data_dir, self.hparams.data_file)), + lengths=self.hparams.train_val_test_split + ) + + def train_dataloader(self): + return DataLoader( + dataset=self.train_data, + batch_sampler=SafeBatchSampler( + data_source=self.train_data, + batch_size=self.hparams.batch_size, + # Dropping the last batch prevents problems caused by variable batch sizes in training, e.g., + # batch_size=1 in BatchNorm, and shuffling ensures the model be trained on all samples over epochs. 
+ drop_last=True, + shuffle=True, + ), + # batch_size=self.hparams.batch_size, + # shuffle=True, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + collate_fn=self.collator, + persistent_workers=True if self.hparams.num_workers > 0 else False + ) + + def val_dataloader(self): + return DataLoader( + dataset=self.val_data, + batch_sampler=SafeBatchSampler( + data_source=self.val_data, + batch_size=self.hparams.batch_size, + drop_last=False, + shuffle=False + ), + # batch_size=self.hparams.batch_size, + # shuffle=False, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + collate_fn=self.collator, + persistent_workers=True if self.hparams.num_workers > 0 else False + ) + + def test_dataloader(self): + return DataLoader( + dataset=self.test_data, + batch_sampler=SafeBatchSampler( + data_source=self.test_data, + batch_size=self.hparams.batch_size, + drop_last=False, + shuffle=False + ), + # batch_size=self.hparams.batch_size, + # shuffle=False, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + collate_fn=self.collator, + persistent_workers=True if self.hparams.num_workers > 0 else False + ) + + def predict_dataloader(self): + return DataLoader( + dataset=self.predict_data, + batch_sampler=SafeBatchSampler( + data_source=self.predict_data, + batch_size=self.hparams.batch_size, + drop_last=False, + shuffle=False + ), + # batch_size=self.hparams.batch_size, + # shuffle=False, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + collate_fn=self.collator, + persistent_workers=True if self.hparams.num_workers > 0 else False + ) + + def teardown(self, stage: Optional[str] = None): + """Clean up after fit or test.""" + pass + + def state_dict(self): + """Extra things to save to checkpoint.""" + return {} + + def load_state_dict(self, state_dict: Dict[str, Any]): + """Things to do when loading checkpoint.""" + pass diff --git a/deepscreen/data/dti_datamodule.py b/deepscreen/data/dti_datamodule.py new file mode 100644 index 0000000000000000000000000000000000000000..35dc78ed63aeaec390b06cf3dcb318e6e835414d --- /dev/null +++ b/deepscreen/data/dti_datamodule.py @@ -0,0 +1,314 @@ +# from itertools import product +from collections import namedtuple +from numbers import Number +from typing import Any, Dict, Optional, Sequence, Union, Literal + +# import numpy as np +import pandas as pd +from lightning import LightningDataModule +from torch.utils.data import Dataset, DataLoader, random_split + +from deepscreen.data.utils.label import label_transform +from deepscreen.data.utils.collator import collate_fn +from deepscreen.data.utils.sampler import SafeBatchSampler + + +class DTIDataset(Dataset): + def __init__( + self, + task: Literal['regression', 'binary', 'multiclass'], + n_classes: Optional[int], + data_dir: str, + dataset_name: str, + drug_featurizer: callable, + protein_featurizer: callable, + thresholds: Optional[Union[Number, Sequence[Number]]] = None, + discard_intermediate: Optional[bool] = False, + ): + df = pd.read_csv( + f'{data_dir}{dataset_name}.csv', + header=0, sep=',', + usecols=lambda x: x in ['X1', 'ID1', 'X2', 'ID2', 'Y', 'U'], + dtype={'X1': 'str', 'ID1': 'str', + 'X2': 'str', 'ID2': 'str', + 'Y': 'float32', 'U': 'str'} + ) + # if 'ID1' in df: + # self.x1_to_id1 = dict(zip(df['X1'], df['ID1'])) + # if 'ID2' in df: + # self.x2_to_id2 = dict(zip(df['X2'], df['ID2'])) + # self.id2_to_indexes = dict(zip(df['ID2'], range(len(df['ID2'])))) + # self.x2_to_indexes = dict(zip(df['X2'], 
range(len(df['X2'])))) + + # # train and eval mode data processing (fully labelled) + # if 'Y' in df.columns and df['Y'].notnull().all(): + + # Forward-fill all non-label columns + df.loc[:, df.columns != 'Y'] = df.loc[:, df.columns != 'Y'].ffill(axis=0) + + if 'Y' in df: + # Transform labels + df['Y'] = df['Y'].apply(label_transform, units=df.get('U', None), thresholds=thresholds, + discard_intermediate=discard_intermediate).astype('float32') + + # Filter out rows with a NaN in Y (missing values) + df.dropna(subset=['Y'], inplace=True) + + # Validate target labels for training/testing + # TODO: check sklearn.utils.multiclass.check_classification_targets + match task: + case 'regression': + assert all(df['Y'].apply(lambda x: isinstance(x, Number))), \ + f"Y for task `regression` must be numeric; got {set(df['Y'].apply(type))}." + case 'binary': + assert all(df['Y'].isin([0, 1])), \ + f"Y for task `binary` (classification) must be 0 or 1, but Y got {pd.unique(df['Y'])}." \ + "\nYou may set `thresholds` to discretize continuous labels." + case 'multiclass': + assert n_classes >= 3, f'n_classes for task `multiclass` (classification) must be at least 3.' + assert all(df['Y'].apply(lambda x: x.is_integer() and x >= 0)), \ + f"Y for task `multiclass` (classification) must be non-negative integers, " \ + f"but Y got {pd.unique(df['Y'])}." \ + "\nYou may set `thresholds` to discretize continuous labels." + target_n_unique = df['Y'].nunique() + assert target_n_unique == n_classes, \ + f"You have set n_classes for task `multiclass` (classification) task to {n_classes}, " \ + f"but Y has {target_n_unique} unique labels." + + # # Predict mode data processing + # else: + # df = pd.DataFrame(product(df['X1'].dropna(), df['X2'].dropna()), columns=['X1', 'X2']) + # if hasattr(self, "x1_to_id1"): + # df['ID1'] = df['X1'].map(self.x1_to_id1) + # if hasattr(self, "x1_to_id2"): + # df['ID2'] = df['X2'].map(self.x2_to_id2) + + # self.smiles = df['X1'] + # self.fasta = df['X2'] + # self.smiles_ids = df.get('ID1', df['X1']) + # self.fasta_ids = df.get('ID2', df['X2']) + # self.labels = df.get('Y', None) + + self.df = df + self.drug_featurizer = drug_featurizer if drug_featurizer is not None else (lambda x: x) + self.protein_featurizer = protein_featurizer if protein_featurizer is not None else (lambda x: x) + self.n_classes = df['Y'].nunique() + # self.train = train + + self.Data = namedtuple('Data', ['FT1', 'ID1', 'FT2', 'ID2', 'Y']) + + def __len__(self): + return len(self.df.index) + + def __getitem__(self, idx): + sample = self.df.loc[idx] + return self.Data( + FT1=self.drug_featurizer(sample['X1']), + ID1=sample.get('ID1', sample['X1']), + FT2=self.protein_featurizer(sample['X2']), + ID2=sample.get('ID2', sample['X2']), + Y=sample.get('Y') + ) + # { + # 'FT1': self.drug_featurizer(sample['X1']), + # 'ID1': sample.get('ID1', sample['X1']), + # 'FT2': self.protein_featurizer(sample['X2']), + # 'ID2': sample.get('ID2', sample['X2']), + # 'Y': sample.get('Y') + # } + # if self.train: + # sample = self.drug_featurizer(self.smiles[idx]), self.protein_featurizer(self.fasta[idx]), self.labels[idx] + # sample = { + # 'FT1': self.drug_featurizer(self.smiles[idx]), + # 'FT2': self.protein_featurizer(self.fasta[idx]), + # 'ID2': self.smiles_ids[idx], + # } + # else: + # # sample = self.drug_featurizer(self.smiles[idx]), self.protein_featurizer(self.fasta[idx]) + # sample = { + # 'FT1': self.drug_featurizer(self.smiles[idx]), + # 'FT2': self.protein_featurizer(self.fasta[idx]), + # } + # + # if all([True if n is not 
None else False for n in sample.values()]): + # return sample # | { + # # 'ID1': self.smiles_ids[idx], + # # 'X1': self.drug_featurizer(self.smiles[idx]), + # # 'ID2': self.fasta_ids[idx], + # # 'X2': self.protein_featurizer(self.fasta[idx]), + # # } + # else: + # return self.__getitem__(np.random.randint(0, self.size)) + + +class DTIdatamodule(LightningDataModule): + """ + DTI DataModule + + A DataModule implements 5 key methods: + + def prepare_data(self): + # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP) + # download data, pre-process, split, save to disk, etc. + def setup(self, stage): + # things to do on every process in DDP + # load data, set variables, etc. + def train_dataloader(self): + # return train dataloader + def val_dataloader(self): + # return validation dataloader + def test_dataloader(self): + # return test dataloader + def teardown(self): + # called on every process in DDP + # clean up after fit or test + + This allows you to share a full dataset without explaining how to download, + split, transform and process the data. + + Read the docs: + https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html + """ + + def __init__( + self, + task: Literal['regression', 'binary', 'multiclass'], + n_classes: Optional[int], + train: bool, + drug_featurizer: callable, + protein_featurizer: callable, + batch_size: int, + train_val_test_split: Optional[Sequence[Number]], + num_workers: int = 0, + thresholds: Optional[Union[Number, Sequence[Number]]] = None, + pin_memory: bool = False, + data_dir: str = "data/", + dataset_name: Optional[str] = None, + split: Optional[callable] = random_split, + ): + super().__init__() + + # this line allows to access init params with 'self.hparams' attribute + # also ensures init params will be stored in ckpt + self.save_hyperparameters(logger=False) + + # data processing + self.data_split = split + + self.data_train: Optional[Dataset] = None + self.data_val: Optional[Dataset] = None + self.data_test: Optional[Dataset] = None + self.data_predict: Optional[Dataset] = None + + def prepare_data(self): + """ + Download data if needed. + Do not use it to assign state (e.g., self.x = x). + """ + + def setup(self, stage: Optional[str] = None, encoding: str = None): + """ + Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`. + This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be + careful not to execute data splitting twice. 
+ """ + # TODO test SafeBatchSampler (which skips samples with any None without introducing variable batch size) + # load and split datasets only if not loaded in initialization + if not any([self.data_train, self.data_val, self.data_test, self.data_predict]): + dataset = DTIDataset( + task=self.hparams.task, + n_classes=self.hparams.n_classes, + data_dir=self.hparams.data_dir, + drug_featurizer=self.hparams.drug_featurizer, + protein_featurizer=self.hparams.protein_featurizer, + dataset_name=self.hparams.dataset_name, + thresholds=self.hparams.thresholds, + ) + + if self.hparams.train: + self.data_train, self.data_val, self.data_test = self.data_split( + dataset=dataset, + lengths=self.hparams.train_val_test_split + ) + else: + self.data_test = self.data_predict = dataset + + def train_dataloader(self): + return DataLoader( + dataset=self.data_train, + batch_sampler=SafeBatchSampler( + data_source=self.data_train, + batch_size=self.hparams.batch_size, + drop_last=True, + shuffle=True, + ), + # batch_size=self.hparams.batch_size, + # shuffle=True, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + collate_fn=collate_fn, + persistent_workers=True if self.hparams.num_workers > 0 else False + ) + + def val_dataloader(self): + return DataLoader( + dataset=self.data_val, + batch_sampler=SafeBatchSampler( + data_source=self.data_val, + batch_size=self.hparams.batch_size, + drop_last=False, + shuffle=False, + ), + # batch_size=self.hparams.batch_size, + # shuffle=False, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + collate_fn=collate_fn, + persistent_workers=True if self.hparams.num_workers > 0 else False + ) + + def test_dataloader(self): + return DataLoader( + dataset=self.data_test, + batch_sampler=SafeBatchSampler( + data_source=self.data_test, + batch_size=self.hparams.batch_size, + drop_last=False, + shuffle=False, + ), + # batch_size=self.hparams.batch_size, + # shuffle=False, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + collate_fn=collate_fn, + persistent_workers=True if self.hparams.num_workers > 0 else False + ) + + def predict_dataloader(self): + return DataLoader( + dataset=self.data_predict, + batch_sampler=SafeBatchSampler( + data_source=self.data_predict, + batch_size=self.hparams.batch_size, + drop_last=False, + shuffle=False, + ), + # batch_size=self.hparams.batch_size, + # shuffle=False, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + collate_fn=collate_fn, + persistent_workers=True if self.hparams.num_workers > 0 else False + ) + + def teardown(self, stage: Optional[str] = None): + """Clean up after fit or test.""" + pass + + def state_dict(self): + """Extra things to save to checkpoint.""" + return {} + + def load_state_dict(self, state_dict: Dict[str, Any]): + """Things to do when loading checkpoint.""" + pass diff --git a/deepscreen/data/entity_datamodule.py b/deepscreen/data/entity_datamodule.py new file mode 100644 index 0000000000000000000000000000000000000000..8b6b2d46d251c5f07bbf68a5639d538b37f6432b --- /dev/null +++ b/deepscreen/data/entity_datamodule.py @@ -0,0 +1,167 @@ +from numbers import Number +from pathlib import Path +from typing import Any, Dict, Optional, Sequence, Type + +from lightning import LightningDataModule +from sklearn.base import TransformerMixin +from torch.utils.data import Dataset, DataLoader + +from deepscreen.data.utils import collate_fn, SafeBatchSampler +from deepscreen.data.utils.dataset import BaseEntityDataset + 
+
+
+class EntityDataModule(LightningDataModule):
+    """
+    def prepare_data(self):
+        # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
+        # download data, pre-process, split, save to disk, etc.
+    def setup(self, stage):
+        # things to do on every process in DDP
+        # load data, set variables, etc.
+    def train_dataloader(self):
+        # return train dataloader
+    def val_dataloader(self):
+        # return validation dataloader
+    def test_dataloader(self):
+        # return test dataloader
+    def teardown(self):
+        # called on every process in DDP
+        # clean up after fit or test
+    """
+    def __init__(
+            self,
+            dataset: type[BaseEntityDataset],
+            transformer: type[TransformerMixin],
+            train: bool,
+            batch_size: int,
+            data_dir: str = "data/",
+            data_file: Optional[str] = None,
+            train_val_test_split: Optional[Sequence[Number | str]] = None,
+            split: Optional[callable] = None,
+            num_workers: int = 0,
+            pin_memory: bool = False,
+    ):
+        super().__init__()
+
+        # data processing
+        self.split = split
+
+        if train:
+            if all([data_file, split]):
+                if all(isinstance(split, Number) for split in train_val_test_split):
+                    pass
+                else:
+                    raise ValueError('`train_val_test_split` must be a sequence of 3 numbers '
+                                     '(float for percentages and int for sample numbers) if '
+                                     '`data_file` and `split` have been specified.')
+            elif all(isinstance(split, str) for split in train_val_test_split) and not any([data_file, split]):
+                self.train_data = dataset(dataset_path=str(Path(data_dir) / train_val_test_split[0]))
+                self.val_data = dataset(dataset_path=str(Path(data_dir) / train_val_test_split[1]))
+                self.test_data = dataset(dataset_path=str(Path(data_dir) / train_val_test_split[2]))
+            else:
+                raise ValueError('For training (train=True), you must specify either '
+                                 '`dataset_name` and `split` with `train_val_test_split` of 3 numbers, or '
+                                 'solely `train_val_test_split` of 3 data file names.')
+        else:
+            if data_file and not any([split, train_val_test_split]):
+                self.test_data = self.predict_data = dataset(dataset_path=str(Path(data_dir) / data_file))
+            else:
+                raise ValueError("For testing/predicting (train=False), you must specify only `data_file` without "
+                                 "`train_val_test_split` or `split`.")
+
+        # this line allows access to init params with the `self.hparams` attribute
+        # and also ensures init params will be stored in ckpt
+        self.save_hyperparameters(logger=False)
+
+    def prepare_data(self):
+        """
+        Download data if needed.
+        Do not use it to assign state (e.g., self.x = x).
+        """
+
+    def setup(self, stage: Optional[str] = None, encoding: Optional[str] = None):
+        """
+        Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
+        This method is called by Lightning with both `trainer.fit()` and `trainer.test()`, so be
+        careful not to execute data splitting twice.
+ """ + # TODO test SafeBatchSampler (which skips samples with any None without introducing variable batch size) + # TODO: find a way to apply transformer.fit_transform only to train and transformer.transform only to val, test + # load and split datasets only if not loaded in initialization + if not any([self.train_data, self.test_data, self.val_data, self.predict_data]): + self.train_data, self.val_data, self.test_data = self.split( + dataset=self.hparams.dataset(data_dir=self.hparams.data_dir, + dataset_name=self.hparams.train_dataset_name), + lengths=self.hparams.train_val_test_split + ) + + def train_dataloader(self): + return DataLoader( + dataset=self.train_data, + batch_sampler=SafeBatchSampler( + data_source=self.train_data, + batch_size=self.hparams.batch_size, + shuffle=True), + # batch_size=self.hparams.batch_size, + # shuffle=True, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + collate_fn=collate_fn, + persistent_workers=True if self.hparams.num_workers > 0 else False + ) + + def val_dataloader(self): + return DataLoader( + dataset=self.val_data, + batch_sampler=SafeBatchSampler( + data_source=self.val_data, + batch_size=self.hparams.batch_size, + shuffle=False), + # batch_size=self.hparams.batch_size, + # shuffle=False, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + collate_fn=collate_fn, + persistent_workers=True if self.hparams.num_workers > 0 else False + ) + + def test_dataloader(self): + return DataLoader( + dataset=self.test_data, + batch_sampler=SafeBatchSampler( + data_source=self.test_data, + batch_size=self.hparams.batch_size, + shuffle=False), + # batch_size=self.hparams.batch_size, + # shuffle=False, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + collate_fn=collate_fn, + persistent_workers=True if self.hparams.num_workers > 0 else False + ) + + def predict_dataloader(self): + return DataLoader( + dataset=self.predict_data, + batch_sampler=SafeBatchSampler( + data_source=self.predict_data, + batch_size=self.hparams.batch_size, + shuffle=False), + # batch_size=self.hparams.batch_size, + # shuffle=False, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + collate_fn=collate_fn, + persistent_workers=True if self.hparams.num_workers > 0 else False + ) + + def teardown(self, stage: Optional[str] = None): + """Clean up after fit or test.""" + pass + + def state_dict(self): + """Extra things to save to checkpoint.""" + return {} + + def load_state_dict(self, state_dict: Dict[str, Any]): + """Things to do when loading checkpoint.""" + pass diff --git a/deepscreen/data/featurizers/__init__.py b/deepscreen/data/featurizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/deepscreen/data/featurizers/__pycache__/__init__.cpython-311.pyc b/deepscreen/data/featurizers/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02522f17ba285baea39eddce0b53111f20d16a8e Binary files /dev/null and b/deepscreen/data/featurizers/__pycache__/__init__.cpython-311.pyc differ diff --git a/deepscreen/data/featurizers/__pycache__/categorical.cpython-311.pyc b/deepscreen/data/featurizers/__pycache__/categorical.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f5e8e6f2243f6bfd75e95d81ed3d5547a1662ba Binary files /dev/null and b/deepscreen/data/featurizers/__pycache__/categorical.cpython-311.pyc differ 
diff --git a/deepscreen/data/featurizers/__pycache__/token.cpython-311.pyc b/deepscreen/data/featurizers/__pycache__/token.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..77f1a92dead63d52aa40a7e2a6add8f6ed5dcf66
Binary files /dev/null and b/deepscreen/data/featurizers/__pycache__/token.cpython-311.pyc differ
diff --git a/deepscreen/data/featurizers/categorical.py b/deepscreen/data/featurizers/categorical.py
new file mode 100644
index 0000000000000000000000000000000000000000..85c7d18f91ec1fbf53338c871bafb284ad5ac263
--- /dev/null
+++ b/deepscreen/data/featurizers/categorical.py
@@ -0,0 +1,86 @@
+import numpy as np
+
+# Tuples of KNOWN characters in SMILES and FASTA sequences
+# Use a tuple instead of a set to preserve character order
+SMILES_VOCAB = ('#', '%', ')', '(', '+', '-', '.', '1', '0', '3', '2', '5', '4',
+                '7', '6', '9', '8', '=', 'A', 'C', 'B', 'E', 'D', 'G', 'F', 'I',
+                'H', 'K', 'M', 'L', 'O', 'N', 'P', 'S', 'R', 'U', 'T', 'W', 'V',
+                'Y', '[', 'Z', ']', '_', 'a', 'c', 'b', 'e', 'd', 'g', 'f', 'i',
+                'h', 'm', 'l', 'o', 'n', 's', 'r', 'u', 't', 'y')
+FASTA_VOCAB = ('A', 'C', 'B', 'E', 'D', 'G', 'F', 'I', 'H', 'K', 'M', 'L', 'O',
+               'N', 'Q', 'P', 'S', 'R', 'U', 'T', 'W', 'V', 'Y', 'X', 'Z')
+
+# Check uniqueness, create character-index dicts, and add '?' for unknown characters as index 0
+assert len(SMILES_VOCAB) == len(set(SMILES_VOCAB)), 'SMILES_VOCAB has duplicate characters.'
+SMILES_CHARSET_IDX = {character: index + 1 for index, character in enumerate(SMILES_VOCAB)} | {'?': 0}
+
+assert len(FASTA_VOCAB) == len(set(FASTA_VOCAB)), 'FASTA_VOCAB has duplicate characters.'
+FASTA_CHARSET_IDX = {character: index + 1 for index, character in enumerate(FASTA_VOCAB)} | {'?': 0}
+
+
+def sequence_to_onehot(sequence: str, charset, max_sequence_length: int):
+    assert len(charset) == len(set(charset)), '`charset` contains duplicate characters.'
+    charset_idx = {character: index + 1 for index, character in enumerate(charset)} | {'?': 0}
+
+    onehot = np.zeros((max_sequence_length, len(charset_idx)), dtype=int)
+    for index, character in enumerate(sequence[:max_sequence_length]):
+        onehot[index, charset_idx.get(character, 0)] = 1
+
+    return onehot.transpose()
+
+
+def sequence_to_label(sequence: str, charset, max_sequence_length: int):
+    assert len(charset) == len(set(charset)), '`charset` contains duplicate characters.'
+    charset_idx = {character: index + 1 for index, character in enumerate(charset)} | {'?': 0}
+
+    label = np.zeros(max_sequence_length, dtype=int)
+    for index, character in enumerate(sequence[:max_sequence_length]):
+        label[index] = charset_idx.get(character, 0)
+
+    return label
+
+
+def smiles_to_onehot(smiles: str, smiles_charset=SMILES_VOCAB, max_sequence_length: int = 100):  # , in_channels: int = len(SMILES_CHARSET)
+    # assert len(SMILES_CHARSET) == len(set(SMILES_CHARSET)), 'SMILES_CHARSET has duplicate characters.'
+    # onehot = np.zeros((max_sequence_length, len(SMILES_CHARSET_IDX)))
+    # for index, character in enumerate(smiles[:max_sequence_length]):
+    #     onehot[index, SMILES_CHARSET_IDX.get(character, 0)] = 1
+    # return onehot.transpose()
+    return sequence_to_onehot(smiles, smiles_charset, max_sequence_length)
+
+
+def smiles_to_label(smiles: str, smiles_charset=SMILES_VOCAB, max_sequence_length: int = 100):  # , in_channels: int = len(SMILES_CHARSET)
+    # label = np.zeros(max_sequence_length)
+    # for index, character in enumerate(smiles[:max_sequence_length]):
+    #     label[index] = SMILES_CHARSET_IDX.get(character, 0)
+    # return label
+    return sequence_to_label(smiles, smiles_charset, max_sequence_length)
+
+
+def fasta_to_onehot(fasta: str, fasta_charset=FASTA_VOCAB, max_sequence_length: int = 1000):  # in_channels: int = len(FASTA_CHARSET)
+    # onehot = np.zeros((max_sequence_length, len(FASTA_CHARSET_IDX)))
+    # for index, character in enumerate(fasta[:max_sequence_length]):
+    #     onehot[index, FASTA_CHARSET_IDX.get(character, 0)] = 1
+    # return onehot.transpose()
+    return sequence_to_onehot(fasta, fasta_charset, max_sequence_length)
+
+
+def fasta_to_label(fasta: str, fasta_charset=FASTA_VOCAB, max_sequence_length: int = 1000):  # in_channels: int = len(FASTA_CHARSET)
+    # label = np.zeros(max_sequence_length)
+    # for index, character in enumerate(fasta[:max_sequence_length]):
+    #     label[index] = FASTA_CHARSET_IDX.get(character, 0)
+    # return label
+    return sequence_to_label(fasta, fasta_charset, max_sequence_length)
+
+
+def one_of_k_encoding(x, allowable_set):
+    if x not in allowable_set:
+        raise Exception("input {0} not in allowable set {1}".format(x, allowable_set))
+    return list(map(lambda s: x == s, allowable_set))
+
+
+def one_of_k_encoding_unk(x, allowable_set):
+    """Maps inputs not in the allowable set to the last element."""
+    if x not in allowable_set:
+        x = allowable_set[-1]
+    return list(map(lambda s: x == s, allowable_set))
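# Quick, runnable shape check for the categorical featurizers above:
from deepscreen.data.featurizers.categorical import smiles_to_onehot, smiles_to_label, SMILES_VOCAB

onehot = smiles_to_onehot('CC(=O)O', max_sequence_length=100)
print(onehot.shape)  # (len(SMILES_VOCAB) + 1, 100): channels x positions; '?' (unknown) is channel 0
labels = smiles_to_label('CC(=O)O', max_sequence_length=100)
print(labels.shape)  # (100,): zero-padded integer indices into the vocabulary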
diff --git a/deepscreen/data/featurizers/fcs.py b/deepscreen/data/featurizers/fcs.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f34b2cfb8f9b6079de8a4c51bbefc00cd9cd51c
--- /dev/null
+++ b/deepscreen/data/featurizers/fcs.py
@@ -0,0 +1,67 @@
+from importlib import resources
+
+import codecs
+
+import numpy as np
+import pandas as pd
+from subword_nmt.apply_bpe import BPE
+
+# ESPF BPE codes and subword vocabularies for proteins (UniProt) and drugs (ChEMBL)
+vocab_path = resources.files('deepscreen').parent.joinpath('resources/vocabs/ESPF/protein_codes_uniprot.txt')
+with codecs.open(vocab_path) as bpe_codes_protein:
+    protein_bpe = BPE(bpe_codes_protein, merges=-1, separator='')
+
+sub_csv_path = resources.files('deepscreen').parent.joinpath('resources/vocabs/ESPF/subword_units_map_uniprot.csv')
+sub_csv = pd.read_csv(sub_csv_path)
+idx2word_protein = sub_csv['index'].values
+words2idx_protein = dict(zip(idx2word_protein, range(0, len(idx2word_protein))))
+
+vocab_path = resources.files('deepscreen').parent.joinpath('resources/vocabs/ESPF/drug_codes_chembl.txt')
+with codecs.open(vocab_path) as bpe_codes_drug:
+    drug_bpe = BPE(bpe_codes_drug, merges=-1, separator='')
+
+sub_csv_path = resources.files('deepscreen').parent.joinpath('resources/vocabs/ESPF/subword_units_map_chembl.csv')
+sub_csv = pd.read_csv(sub_csv_path)
+idx2word_drug = sub_csv['index'].values
+words2idx_drug = dict(zip(idx2word_drug, range(0, len(idx2word_drug))))
+
+
+def _bpe_to_embedding(x, bpe, words2idx, max_sequence_length):
+    """Tokenize x with the given BPE model, map subwords to indices, and pad or truncate."""
+    tokens = bpe.process_line(x).split()
+    try:
+        indices = np.asarray([words2idx[token] for token in tokens])
+    except KeyError:
+        # Unknown subword: fall back to a single zero index (original behaviour preserved)
+        indices = np.array([0])
+
+    seq_len = len(indices)
+    if seq_len < max_sequence_length:
+        padded = np.pad(indices, (0, max_sequence_length - seq_len), 'constant', constant_values=0)
+        input_mask = [1] * seq_len + [0] * (max_sequence_length - seq_len)
+    else:
+        padded = indices[:max_sequence_length]
+        input_mask = [1] * max_sequence_length
+
+    return padded, np.asarray(input_mask)
+
+
+def protein_to_embedding(x, max_sequence_length):
+    return _bpe_to_embedding(x, protein_bpe, words2idx_protein, max_sequence_length)
+
+
+def drug_to_embedding(x, max_sequence_length):
+    return _bpe_to_embedding(x, drug_bpe, words2idx_drug, max_sequence_length)
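+
+
+if __name__ == '__main__':
+    # Illustrative usage only, assuming the ESPF vocab files resolved above are
+    # present on disk: both featurizers return an (index_array, mask_array) pair
+    # of fixed length, ready for an embedding layer plus attention mask.
+    indices, mask = drug_to_embedding('CC(=O)OC1=CC=CC=C1C(=O)O', max_sequence_length=50)
+    print(indices.shape, mask.shape)  # (50,) (50,)
+    indices, mask = protein_to_embedding('MKT' * 10, max_sequence_length=64)
+    print(int(mask.sum()))  # number of real (non-padding) subword positions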
diff --git a/deepscreen/data/featurizers/fingerprint/__init__.py b/deepscreen/data/featurizers/fingerprint/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6315b75d3f65a6d20ff3b440d72db558c038856a
--- /dev/null
+++ b/deepscreen/data/featurizers/fingerprint/__init__.py
@@ -0,0 +1,45 @@
+from typing import Literal
+
+from .atompairs import GetAtomPairFPs
+from .avalonfp import GetAvalonFPs
+from .rdkitfp import GetRDkitFPs
+from .morganfp import GetMorganFPs
+from .estatefp import GetEstateFPs
+from .maccskeys import GetMACCSFPs
+from .pharmErGfp import GetPharmacoErGFPs
+from .pharmPointfp import GetPharmacoPFPs
+from .pubchemfp import GetPubChemFPs
+from .torsions import GetTorsionFPs
+from .mhfp6 import GetMHFP6
+# from .map4 import GetMAP4  # requires the optional `tmap` dependency
+from rdkit import Chem
+
+from deepscreen import get_logger
+
+log = get_logger(__name__)
+
+FP_MAP = {
+    'MorganFP': GetMorganFPs,
+    'RDkitFP': GetRDkitFPs,
+    'AtomPairFP': GetAtomPairFPs,
+    'TorsionFP': GetTorsionFPs,
+    'AvalonFP': GetAvalonFPs,
+    'EstateFP': GetEstateFPs,
+    'MACCSFP': GetMACCSFPs,
+    'PharmacoErGFP': GetPharmacoErGFPs,
+    'PharmacoPFP': GetPharmacoPFPs,
+    'PubChemFP': GetPubChemFPs,
+    'MHFP6': GetMHFP6,
+    # 'MAP4': GetMAP4,
+}
+
+
+def smiles_to_fingerprint(smiles, fingerprint: Literal[tuple(FP_MAP.keys())], **kwargs):
+    func = FP_MAP[fingerprint]
+    try:
+        mol = Chem.MolFromSmiles(smiles)
+        arr = func(mol, **kwargs)
+        return arr
+    except Exception as e:
+        log.warning(f"Failed to convert SMILES ({smiles}) to {fingerprint} due to {str(e)}")
+        return None
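+
+
+if __name__ == '__main__':
+    # Sketch of the intended dispatch pattern (illustrative, not part of the
+    # library API): any key of FP_MAP can be requested by name, and extra
+    # keyword arguments are forwarded unchanged to the underlying fingerprint
+    # function. The SMILES below (phenol) is an arbitrary example.
+    fp = smiles_to_fingerprint('c1ccccc1O', 'MorganFP', nBits=1024, radius=2)
+    if fp is not None:
+        print(fp.shape)  # (1024,)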
diff --git a/deepscreen/data/featurizers/fingerprint/atompairs.py b/deepscreen/data/featurizers/fingerprint/atompairs.py
new file mode 100644
index 0000000000000000000000000000000000000000..336a90777ccec7664740b88da14be8cb07855dd1
--- /dev/null
+++ b/deepscreen/data/featurizers/fingerprint/atompairs.py
@@ -0,0 +1,18 @@
+from rdkit.Chem.AtomPairs import Pairs
+from rdkit.Chem import DataStructs
+import numpy as np
+
+_type = 'topological-based'
+
+
+def GetAtomPairFPs(mol, nBits=2048, binary=True):
+    '''
+    Hashed atom-pair fingerprint: boolean array if binary, else int8 counts.
+    '''
+    fp = Pairs.GetHashedAtomPairFingerprint(mol, nBits=nBits)
+    if binary:
+        arr = np.zeros((0,), dtype=np.bool_)
+    else:
+        arr = np.zeros((0,), dtype=np.int8)
+    DataStructs.ConvertToNumpyArray(fp, arr)
+    return arr
diff --git a/deepscreen/data/featurizers/fingerprint/avalonfp.py b/deepscreen/data/featurizers/fingerprint/avalonfp.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e70f05f265993fadbed9075574a24034ae9b9ad
--- /dev/null
+++ b/deepscreen/data/featurizers/fingerprint/avalonfp.py
@@ -0,0 +1,16 @@
+from rdkit.Chem import DataStructs
+from rdkit.Avalon.pyAvalonTools import GetAvalonFP as GAFP
+import numpy as np
+
+_type = 'topological-based'
+
+
+def GetAvalonFPs(mol, nBits=2048):
+    '''
+    Avalon fingerprint: https://pubs.acs.org/doi/pdf/10.1021/ci050413p
+    '''
+    fp = GAFP(mol, nBits=nBits)
+    arr = np.zeros((0,), dtype=np.bool_)
+    DataStructs.ConvertToNumpyArray(fp, arr)
+    return arr
diff --git a/deepscreen/data/featurizers/fingerprint/estatefp.py b/deepscreen/data/featurizers/fingerprint/estatefp.py
new file mode 100644
index 0000000000000000000000000000000000000000..369b20756b947c66496ecb29719915e9c806253f
--- /dev/null
+++ b/deepscreen/data/featurizers/fingerprint/estatefp.py
@@ -0,0 +1,12 @@
+from rdkit.Chem.EState import Fingerprinter
+import numpy as np
+
+_type = 'Estate-based'
+
+
+def GetEstateFPs(mol):
+    '''
+    79-bit E-state fingerprint
+    '''
+    x = Fingerprinter.FingerprintMol(mol)[0]
+    return x.astype(np.bool_)
diff --git a/deepscreen/data/featurizers/fingerprint/maccskeys.py b/deepscreen/data/featurizers/fingerprint/maccskeys.py
new file mode 100644
index 0000000000000000000000000000000000000000..79ba3030c49ca405b1e91031b92742a9a49753b3
--- /dev/null
+++ b/deepscreen/data/featurizers/fingerprint/maccskeys.py
@@ -0,0 +1,25 @@
+from rdkit.Chem import AllChem
+from rdkit.Chem import DataStructs
+import numpy as np
+import pandas as pd
+import os
+
+_type = 'SMARTS-based'
+
+file_path = os.path.dirname(__file__)
+
+
+def GetMACCSFPs(mol):
+    '''
+    166 MACCS keys (RDKit returns a 167-bit vector whose bit 0 is always unset)
+    '''
+    fp = AllChem.GetMACCSKeysFingerprint(mol)
+
+    arr = np.zeros((0,), dtype=np.bool_)
+    DataStructs.ConvertToNumpyArray(fp, arr)
+    return arr
+
+
+def GetMACCSFPInfos():
+    return pd.read_excel(os.path.join(file_path, 'maccskeys.xlsx'))
diff --git a/deepscreen/data/featurizers/fingerprint/maccskeys.xlsx b/deepscreen/data/featurizers/fingerprint/maccskeys.xlsx
new file mode 100644
index 0000000000000000000000000000000000000000..3037fa603e73843000a97e6ed886b44cfcd5ebd9
Binary files /dev/null and b/deepscreen/data/featurizers/fingerprint/maccskeys.xlsx differ
diff --git a/deepscreen/data/featurizers/fingerprint/map4.py b/deepscreen/data/featurizers/fingerprint/map4.py
new file mode 100644
index 0000000000000000000000000000000000000000..e60c9267d5ad273766657b7b6ac5263f167bc1ec
--- /dev/null
+++ b/deepscreen/data/featurizers/fingerprint/map4.py
@@ -0,0 +1,130 @@
+"""
+MinHashed Atom-pair Fingerprint, MAP4
+Original paper: Capecchi, Alice, Daniel Probst, and Jean-Louis Reymond. "One molecular fingerprint to rule
+them all: drugs, biomolecules, and the metabolome." Journal of Cheminformatics 12.1 (2020): 1-15.
+Original code: https://github.com/reymond-group/map4 (with thanks to the original authors).
+
+A small bug is fixed: https://github.com/reymond-group/map4/issues/6
+"""
+
+_type = 'topological-based'
+
+import itertools
+from collections import defaultdict
+
+import tmap as tm
+from mhfp.encoder import MHFPEncoder
+from rdkit import Chem
+from rdkit.Chem import rdmolops
+from rdkit.Chem.rdmolops import GetDistanceMatrix
+
+
+def to_smiles(mol):
+    return Chem.MolToSmiles(mol, canonical=True, isomericSmiles=False)
+
+
+class MAP4Calculator:
+    def __init__(self, dimensions=2048, radius=2, is_counted=False, is_folded=False, fold_dimensions=2048):
+        """
+        MAP4 calculator class
+        """
+        self.dimensions = dimensions
+        self.radius = radius
+        self.is_counted = is_counted
+        self.is_folded = is_folded
+        self.fold_dimensions = fold_dimensions
+
+        if self.is_folded:
+            self.encoder = MHFPEncoder(dimensions)
+        else:
+            self.encoder = tm.Minhash(dimensions)
+
+    def calculate(self, mol):
+        """Calculates the atom pair minhashed fingerprint
+        Arguments:
+            mol -- rdkit mol object
+        Returns:
+            tmap VectorUint -- minhashed fingerprint
+        """
+        atom_env_pairs = self._calculate(mol)
+        if self.is_folded:
+            return self._fold(atom_env_pairs)
+        return self.encoder.from_string_array(atom_env_pairs)
+
+    def calculate_many(self, mols):
+        """Calculates the atom pair minhashed fingerprints of a list of molecules
+        Arguments:
+            mols -- list of rdkit mol objects
+        Returns:
+            list of tmap VectorUint -- minhashed fingerprints
+        """
+        atom_env_pairs_list = [self._calculate(mol) for mol in mols]
+        if self.is_folded:
+            return [self._fold(pairs) for pairs in atom_env_pairs_list]
+        return self.encoder.batch_from_string_array(atom_env_pairs_list)
+
+    def _calculate(self, mol):
+        return self._all_pairs(mol, self._get_atom_envs(mol))
+
+    def _fold(self, pairs):
+        fp_hash = self.encoder.hash(set(pairs))
+        return self.encoder.fold(fp_hash, self.fold_dimensions)
+
+    def _get_atom_envs(self, mol):
+        # For every atom, collect the circular environments of radius 1..self.radius
+        atoms_env = {}
+        for atom in mol.GetAtoms():
+            idx = atom.GetIdx()
+            for radius in range(1, self.radius + 1):
+                if idx not in atoms_env:
+                    atoms_env[idx] = []
+                atoms_env[idx].append(MAP4Calculator._find_env(mol, idx, radius))
+        return atoms_env
+
+    @classmethod
+    def _find_env(cls, mol, idx, radius):
+        env = rdmolops.FindAtomEnvironmentOfRadiusN(mol, radius, idx)
+        atom_map = {}
+
+        submol = Chem.PathToSubmol(mol, env, atomMap=atom_map)
+        if idx in atom_map:
+            smiles = Chem.MolToSmiles(submol, rootedAtAtom=atom_map[idx], canonical=True, isomericSmiles=False)
+            return smiles
+        return ''
+
+    def _all_pairs(self, mol, atoms_env):
+        # Build "env_a|topological distance|env_b" shingles for every atom pair
+        atom_pairs = []
+        distance_matrix = GetDistanceMatrix(mol)
+        num_atoms = mol.GetNumAtoms()
+        shingle_dict = defaultdict(int)
+        for idx1, idx2 in itertools.combinations(range(num_atoms), 2):
+            dist = str(int(distance_matrix[idx1][idx2]))
+
+            for i in range(self.radius):
+                env_a = atoms_env[idx1][i]
+                env_b = atoms_env[idx2][i]
+
+                ordered = sorted([env_a, env_b])
+
+                shingle = '{}|{}|{}'.format(ordered[0], dist, ordered[1])
+
+                if self.is_counted:
+                    shingle_dict[shingle] += 1
+                    shingle += '|' + str(shingle_dict[shingle])
+
+                atom_pairs.append(shingle.encode('utf-8'))
+        return list(set(atom_pairs))
+
+
+def GetMAP4(mol, nBits=2048, radius=2, fold_dimensions=None):
+    """
+    MAP4: radius=2
+    """
+    if fold_dimensions is None:
+        fold_dimensions = nBits
+
+    calc = MAP4Calculator(dimensions=nBits, radius=radius, is_counted=False, is_folded=True,
+                          fold_dimensions=fold_dimensions)
+
+    arr = calc.calculate(mol)
+
+    return arr.astype(bool)
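+
+
+if __name__ == '__main__':
+    # Hedged example: running this module needs the optional `tmap` package
+    # (imported at the top), which is why the MAP4 entry in the package
+    # __init__'s FP_MAP stays commented out. The SMILES below (paracetamol)
+    # is an arbitrary illustration.
+    mol = Chem.MolFromSmiles('CC(=O)Nc1ccc(O)cc1')
+    print(GetMAP4(mol, nBits=1024).shape)  # (1024,) boolean, folded MinHash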
diff --git a/deepscreen/data/featurizers/fingerprint/mhfp6.py b/deepscreen/data/featurizers/fingerprint/mhfp6.py
new file mode 100644
index 0000000000000000000000000000000000000000..2befc6bd970a8cb6ac5738020cddec4e459802ff
--- /dev/null
+++ b/deepscreen/data/featurizers/fingerprint/mhfp6.py
@@ -0,0 +1,18 @@
+"""
+Probst, Daniel, and Jean-Louis Reymond. "A probabilistic molecular fingerprint for big data settings."
+Journal of Cheminformatics 10.1 (2018): 66.
+
+Original code: https://github.com/reymond-group/mhfp
+"""
+
+from mhfp.encoder import MHFPEncoder
+
+
+def GetMHFP6(mol, nBits=2048, radius=3):
+    """
+    MHFP6: radius=3
+    """
+    encoder = MHFPEncoder(n_permutations=nBits)
+    hash_values = encoder.encode_mol(mol, radius=radius, rings=True, kekulize=True, min_radius=1)
+    arr = encoder.fold(hash_values, nBits)
+    return arr.astype(bool)
diff --git a/deepscreen/data/featurizers/fingerprint/mnimalfatures.fdef b/deepscreen/data/featurizers/fingerprint/mnimalfatures.fdef
new file mode 100644
index 0000000000000000000000000000000000000000..ed695f788d5c9b4911de46a52417ba0fa94f2afd
--- /dev/null
+++ b/deepscreen/data/featurizers/fingerprint/mnimalfatures.fdef
@@ -0,0 +1,53 @@
+AtomType NDonor [N&!H0&v3,N&!H0&+1&v4,n&H1&+0]
+AtomType ChalcDonor [O,S;H1;+0]
+DefineFeature SingleAtomDonor [{NDonor},{ChalcDonor},!$([D1]-[C;D3]=[O,S,N])]
+  Family Donor
+  Weights 1
+EndFeature
+
+AtomType NAcceptor [$([N&v3;H1,H2]-[!$(*=[O,N,P,S])])]
+AtomType NAcceptor [$([N;v3;H0])]
+AtomType NAcceptor [$([n;+0])]
+AtomType ChalcAcceptor [$([O,S;H1;v2]-[!$(*=[O,N,P,S])])]
+AtomType ChalcAcceptor [O,S;H0;v2]
+AtomType ChalcAcceptor [O,S;-]
+AtomType ChalcAcceptor [o,s;+0]
+AtomType HalogenAcceptor [F]
+DefineFeature SingleAtomAcceptor [{NAcceptor},{ChalcAcceptor},{HalogenAcceptor}]
+  Family Acceptor
+  Weights 1
+EndFeature
+
+# this one is delightfully easy:
+DefineFeature AcidicGroup [C,S](=[O,S,P])-[O;H1,H0&-1]
+  Family NegIonizable
+  Weights 1.0,1.0,1.0
+EndFeature
+
+AtomType CarbonOrArom_NonCarbonyl [$([C,a]);!$([C,a](=O))]
+AtomType BasicNH2 [$([N;H2&+0][{CarbonOrArom_NonCarbonyl}])]
+AtomType BasicNH1 [$([N;H1&+0]([{CarbonOrArom_NonCarbonyl}])[{CarbonOrArom_NonCarbonyl}])]
+AtomType BasicNH0 [$([N;H0&+0]([{CarbonOrArom_NonCarbonyl}])([{CarbonOrArom_NonCarbonyl}])[{CarbonOrArom_NonCarbonyl}])]
+AtomType BasicNakedN [N,n;X2;+0]
+DefineFeature BasicGroup [{BasicNH2},{BasicNH1},{BasicNH0},{BasicNakedN}]
+  Family PosIonizable
+  Weights 1.0
+EndFeature
+
+# aromatic rings of various sizes:
+DefineFeature Arom5 a1aaaa1
+  Family Aromatic
+  Weights 1.0,1.0,1.0,1.0,1.0
+EndFeature
+DefineFeature Arom6 a1aaaaa1
+  Family Aromatic
+  Weights 1.0,1.0,1.0,1.0,1.0,1.0
+EndFeature
+DefineFeature Arom7 a1aaaaaa1
+  Family Aromatic
+  Weights 1.0,1.0,1.0,1.0,1.0,1.0,1.0
+EndFeature
+DefineFeature Arom8 a1aaaaaaa1
+  Family Aromatic
+  Weights 1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
+EndFeature
diff --git 
a/deepscreen/data/featurizers/fingerprint/morganfp.py b/deepscreen/data/featurizers/fingerprint/morganfp.py new file mode 100644 index 0000000000000000000000000000000000000000..43bef977bd69572fdbd491a7047e5f60d209805a --- /dev/null +++ b/deepscreen/data/featurizers/fingerprint/morganfp.py @@ -0,0 +1,18 @@ +from rdkit.Chem import AllChem +from rdkit.Chem import DataStructs +import numpy as np + + +def GetMorganFPs(mol, nBits=2048, radius=2, return_bitInfo=False): + """ + ECFP4: radius=2 + """ + bitInfo = {} + fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=radius, + bitInfo=bitInfo, nBits=nBits) + arr = np.zeros((0,), dtype=np.bool_) + DataStructs.ConvertToNumpyArray(fp, arr) + + if return_bitInfo: + return arr, bitInfo + return arr diff --git a/deepscreen/data/featurizers/fingerprint/pharmErGfp.py b/deepscreen/data/featurizers/fingerprint/pharmErGfp.py new file mode 100644 index 0000000000000000000000000000000000000000..092dc850b1b354c8edfeffc8e83591a61d66a88e --- /dev/null +++ b/deepscreen/data/featurizers/fingerprint/pharmErGfp.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Sat Aug 17 16:54:12 2019 + +@author: wanxiang.shen@u.nus.edu + +@calculate ErG fps, more info: https://pubs.acs.org/doi/full/10.1021/ci050457y# +""" + +_type = 'Pharmacophore-based' + +import numpy as np +from rdkit.Chem import AllChem + +## get info from : https://github.com/rdkit/rdkit/blob/d41752d558bf7200ab67b98cdd9e37f1bdd378de/Code/GraphMol/ReducedGraphs/ReducedGraphs.cpp +Donor = ["[N;!H0;v3,v4&+1]", "[O,S;H1;+0]", "[n&H1&+0]"] + +Acceptor = ["[O,S;H1;v2;!$(*-*=[O,N,P,S])]", "[O;H0;v2]", "[O,S;v1;-]", + "[N;v3;!$(N-*=[O,N,P,S])]", "[n&H0&+0]", "[o;+0;!$([o]:n);!$([o]:c:n)]"] + +Positive = ["[#7;+]", "[N;H2&+0][$([C,a]);!$([C,a](=O))]", + "[N;H1&+0]([$([C,a]);!$([C,a](=O))])[$([C,a]);!$([C,a](=O))]", + "[N;H0&+0]([C;!$(C(=O))])([C;!$(C(=O))])[C;!$(C(=O))]"] + +Negative = ["[C,S](=[O,S,P])-[O;H1,-1]"] + +Hydrophobic = ["[C;D3,D4](-[CH3])-[CH3]", "[S;D2](-C)-C"] + +Aromatic = ["a"] + +PROPERTY_KEY = ["Donor", "Acceptor", "Positive", "Negative", "Hydrophobic", "Aromatic"] + + +def GetPharmacoErGFPs(mol, fuzzIncrement=0.3, maxPath=21, binary=True, return_bitInfo=False): + ''' + https://pubs.acs.org/doi/full/10.1021/ci050457y# + return maxPath*21 bits + + size(v) = (n(n + 1)/2) * (maxDist - minDist + 1) + + ''' + minPath = 1 + + arr = AllChem.GetErGFingerprint(mol, fuzzIncrement=fuzzIncrement, maxPath=maxPath, minPath=minPath) + arr = arr.astype(np.float32) + + if binary: + arr = arr.astype(np.bool_) + + if return_bitInfo: + bitInfo = [] + for i in range(len(PROPERTY_KEY)): + for j in range(i, len(PROPERTY_KEY)): + for path in range(minPath, maxPath + 1): + triplet = (PROPERTY_KEY[i], PROPERTY_KEY[j], path) + bitInfo.append(triplet) + return arr, bitInfo + + return arr diff --git a/deepscreen/data/featurizers/fingerprint/pharmPointfp.py b/deepscreen/data/featurizers/fingerprint/pharmPointfp.py new file mode 100644 index 0000000000000000000000000000000000000000..543a518dc686ab33e3e88bfd6d526cb9f481f9db --- /dev/null +++ b/deepscreen/data/featurizers/fingerprint/pharmPointfp.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Sat Aug 17 16:54:12 2019 + +@author: wanxiang.shen@u.nus.edu + +Combining a set of chemical features with the 2D (topological) distances between them gives a 2D pharmacophore. When the distances are binned, unique integer ids can be assigned to each of these pharmacophores and they can be stored in a fingerprint. 
Details of the encoding are in: https://www.rdkit.org/docs/RDKit_Book.html#ph4-figure
+"""
+
+_type = 'Pharmacophore-based'
+
+from rdkit.Chem.Pharm2D.SigFactory import SigFactory
+from rdkit.Chem.Pharm2D import Generate
+from rdkit.Chem import ChemicalFeatures
+
+import numpy as np
+import os
+
+fdef = os.path.join(os.path.dirname(__file__), 'mnimalfatures.fdef')
+featFactory = ChemicalFeatures.BuildFeatureFactory(fdef)
+
+
+def GetPharmacoPFPs(mol,
+                    bins=[(i, i + 1) for i in range(20)],
+                    minPointCount=2,
+                    maxPointCount=2,
+                    return_bitInfo=False):
+    '''
+    Note: maxPointCount=3 is slow; keep the defaults
+    (bins=[(i, i + 1) for i in range(20)], maxPointCount=2) for large-scale computation.
+    '''
+    MysigFactory = SigFactory(featFactory,
+                              trianglePruneBins=False,
+                              minPointCount=minPointCount,
+                              maxPointCount=maxPointCount)
+    MysigFactory.SetBins(bins)
+    MysigFactory.Init()
+
+    res = Generate.Gen2DFingerprint(mol, MysigFactory)
+    arr = np.array(list(res)).astype(np.bool_)
+    if return_bitInfo:
+        description = []
+        for i in range(len(res)):
+            description.append(MysigFactory.GetBitDescription(i))
+        return arr, description
+
+    return arr
+
+
+if __name__ == '__main__':
+    from rdkit import Chem
+
+    mol = Chem.MolFromSmiles('CC#CC(=O)NC1=NC=C2C(=C1)C(=NC=N2)NC3=CC(=C(C=C3)F)Cl')
+    a = GetPharmacoPFPs(mol, bins=[(i, i + 1) for i in range(20)], minPointCount=2, maxPointCount=2)
+    print(a.shape)
diff --git a/deepscreen/data/featurizers/fingerprint/pubchemfp.py b/deepscreen/data/featurizers/fingerprint/pubchemfp.py
new file mode 100644
index 0000000000000000000000000000000000000000..1865a3772d19c6d2c4d12d343fba3f69662d6498
--- /dev/null
+++ b/deepscreen/data/featurizers/fingerprint/pubchemfp.py
@@ -0,0 +1,1731 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on Sun Aug 25 20:29:36 2019
+
+@author: charleshen
+
+@Note: The code is copied from PyBioMed, with a minor repair
+
+https://www.ncbi.nlm.nih.gov/pubmed/29556758
+
+These are SMARTS patterns corresponding to the PubChem fingerprints:
+https://astro.temple.edu/~tua87106/list_fingerprints.pdf
+ftp://ftp.ncbi.nlm.nih.gov/pubchem/specifications/pubchem_fingerprints.txt
+"""
+
+_type = 'SMARTS-based'
+
+import numpy as np
+from rdkit import Chem
+from rdkit import DataStructs
+import os
+import pandas as pd
+
+smartsPatts = {
+    1: ('[H]', 3),  # 1-115
+    2: ('[H]', 7),
+    3: ('[H]', 15),
+    4: ('[H]', 31),
+    5: ('[Li]', 0),
+    6: ('[Li]', 1),
+    7: ('[B]', 0),
+    8: ('[B]', 1),
+    9: ('[B]', 3),
+    10: ('[C]', 1),
+    11: ('[C]', 3),
+    12: ('[C]', 7),
+    13: ('[C]', 15),
+    14: ('[C]', 31),
+    15: ('[N]', 0),
+    16: ('[N]', 1),
+    17: ('[N]', 3),
+    18: ('[N]', 7),
+    19: ('[O]', 0),
+    20: ('[O]', 1),
+    21: ('[O]', 3),
+    22: ('[O]', 7),
+    23: ('[O]', 15),
+    24: ('[F]', 0),
+    25: ('[F]', 1),
+    26: ('[F]', 3),
+    27: ('[Na]', 0),
+    28: ('[Na]', 1),
+    29: ('[Si]', 0),
+    30: ('[Si]', 1),
+    31: ('[P]', 0),
+    32: ('[P]', 1),
+    33: ('[P]', 3),
+    34: ('[S]', 0),
+    35: ('[S]', 1),
+    36: ('[S]', 3),
+    37: ('[S]', 7),
+    38: ('[Cl]', 0),
+    39: ('[Cl]', 1),
+    40: ('[Cl]', 3),
+    41: ('[Cl]', 7),
+    42: ('[K]', 0),
+    43: ('[K]', 1),
+    44: ('[Br]', 0),
+    45: ('[Br]', 1),
+    46: ('[Br]', 3),
+    47: ('[I]', 0),
+    48: ('[I]', 1),
+    49: ('[I]', 3),
+    50: ('[Be]', 0),
+    51: ('[Mg]', 0),
+    52: ('[Al]', 0),
+    53: ('[Ca]', 0),
+    54: ('[Sc]', 0),
+    55: ('[Ti]', 0),
+    56: ('[V]', 0),
+    57: ('[Cr]', 0),
+    58: ('[Mn]', 0),
+    59: ('[Fe]', 0),
+    60: ('[Co]', 0),  # cobalt; '[CO]' in the PyBioMed copy can never match
+    61: ('[Ni]', 0),
+    62: ('[Cu]', 0),
+    63: ('[Zn]', 0),
+    64: ('[Ga]', 0),
+    65: ('[Ge]', 0),
+    66: ('[As]',
0), + 67: ('[Se]', 0), + 68: ('[Kr]', 0), + 69: ('[Rb]', 0), + 70: ('[Sr]', 0), + 71: ('[Y]', 0), + 72: ('[Zr]', 0), + 73: ('[Nb]', 0), + 74: ('[Mo]', 0), + 75: ('[Ru]', 0), + 76: ('[Rh]', 0), + 77: ('[Pd]', 0), + 78: ('[Ag]', 0), + 79: ('[Cd]', 0), + 80: ('[In]', 0), + 81: ('[Sn]', 0), + 82: ('[Sb]', 0), + 83: ('[Te]', 0), + 84: ('[Xe]', 0), + 85: ('[Cs]', 0), + 86: ('[Ba]', 0), + 87: ('[Lu]', 0), + 88: ('[Hf]', 0), + 89: ('[Ta]', 0), + 90: ('[W]', 0), + 91: ('[Re]', 0), + 92: ('[Os]', 0), + 93: ('[Ir]', 0), + 94: ('[Pt]', 0), + 95: ('[Au]', 0), + 96: ('[Hg]', 0), + 97: ('[Tl]', 0), + 98: ('[Pb]', 0), + 99: ('[Bi]', 0), + 100: ('[La]', 0), + 101: ('[Ce]', 0), + 102: ('[Pr]', 0), + 103: ('[Nd]', 0), + 104: ('[Pm]', 0), + 105: ('[Sm]', 0), + 106: ('[Eu]', 0), + 107: ('[Gd]', 0), + 108: ('[Tb]', 0), + 109: ('[Dy]', 0), + 110: ('[Ho]', 0), + 111: ('[Er]', 0), + 112: ('[Tm]', 0), + 113: ('[Yb]', 0), + 114: ('[Tc]', 0), + 115: ('[U]', 0), + 116: ('[Li&!H0]', 0), # 264-881 + 117: ('[Li]~[Li]', 0), + 118: ('[Li]~[#5]', 0), + 119: ('[Li]~[#6]', 0), + 120: ('[Li]~[#8]', 0), + 121: ('[Li]~[F]', 0), + 122: ('[Li]~[#15]', 0), + 123: ('[Li]~[#16]', 0), + 124: ('[Li]~[Cl]', 0), + 125: ('[#5&!H0]', 0), + 126: ('[#5]~[#5]', 0), + 127: ('[#5]~[#6]', 0), + 128: ('[#5]~[#7]', 0), + 129: ('[#5]~[#8]', 0), + 130: ('[#5]~[F]', 0), + 131: ('[#5]~[#14]', 0), + 132: ('[#5]~[#15]', 0), + 133: ('[#5]~[#16]', 0), + 134: ('[#5]~[Cl]', 0), + 135: ('[#5]~[Br]', 0), + 136: ('[#6&!H0]', 0), + 137: ('[#6]~[#6]', 0), + 138: ('[#6]~[#7]', 0), + 139: ('[#6]~[#8]', 0), + 140: ('[#6]~[F]', 0), + 141: ('[#6]~[Na]', 0), + 142: ('[#6]~[Mg]', 0), + 143: ('[#6]~[Al]', 0), + 144: ('[#6]~[#14]', 0), + 145: ('[#6]~[#15]', 0), + 146: ('[#6]~[#16]', 0), + 147: ('[#6]~[Cl]', 0), + 148: ('[#6]~[#33]', 0), + 149: ('[#6]~[#34]', 0), + 150: ('[#6]~[Br]', 0), + 151: ('[#6]~[I]', 0), + 152: ('[#7&!H0]', 0), + 153: ('[#7]~[#7]', 0), + 154: ('[#7]~[#8]', 0), + 155: ('[#7]~[F]', 0), + 156: ('[#7]~[#14]', 0), + 157: ('[#7]~[#15]', 0), + 158: ('[#7]~[#16]', 0), + 159: ('[#7]~[Cl]', 0), + 160: ('[#7]~[Br]', 0), + 161: ('[#8&!H0]', 0), + 162: ('[#8]~[#8]', 0), + 163: ('[#8]~[Mg]', 0), + 164: ('[#8]~[Na]', 0), + 165: ('[#8]~[Al]', 0), + 166: ('[#8]~[#14]', 0), + 167: ('[#8]~[#15]', 0), + 168: ('[#8]~[K]', 0), + 169: ('[F]~[#15]', 0), + 170: ('[F]~[#16]', 0), + 171: ('[Al&!H0]', 0), + 172: ('[Al]~[Cl]', 0), + 173: ('[#14&!H0]', 0), + 174: ('[#14]~[#14]', 0), + 175: ('[#14]~[Cl]', 0), + 176: ('[#15&!H0]', 0), + 177: ('[#15]~[#15]', 0), + 178: ('[#33&!H0]', 0), + 179: ('[#33]~[#33]', 0), + 180: ('[#6](~Br)(~[#6])', 0), + 181: ('[#6](~Br)(~[#6])(~[#6])', 0), + 182: ('[#6&!H0]~[Br]', 0), + 183: ('[#6](~[Br])(:[c])', 0), + 184: ('[#6](~[Br])(:[n])', 0), + 185: ('[#6](~[#6])(~[#6])', 0), + 186: ('[#6](~[#6])(~[#6])(~[#6])', 0), + 187: ('[#6](~[#6])(~[#6])(~[#6])(~[#6])', 0), + 188: ('[#6H1](~[#6])(~[#6])(~[#6])', 0), + 189: ('[#6](~[#6])(~[#6])(~[#6])(~[#7])', 0), + 190: ('[#6](~[#6])(~[#6])(~[#6])(~[#8])', 0), + 191: ('[#6H1](~[#6])(~[#6])(~[#7])', 0), + 192: ('[#6H1](~[#6])(~[#6])(~[#8])', 0), + 193: ('[#6](~[#6])(~[#6])(~[#7])', 0), + 194: ('[#6](~[#6])(~[#6])(~[#8])', 0), + 195: ('[#6](~[#6])(~[Cl])', 0), + 196: ('[#6&!H0](~[#6])(~[Cl])', 0), + 197: ('[#6H,#6H2,#6H3,#6H4]~[#6]', 0), + 198: ('[#6&!H0](~[#6])(~[#7])', 0), + 199: ('[#6&!H0](~[#6])(~[#8])', 0), + 200: ('[#6H1](~[#6])(~[#8])(~[#8])', 0), + 201: ('[#6&!H0](~[#6])(~[#15])', 0), + 202: ('[#6&!H0](~[#6])(~[#16])', 0), + 203: ('[#6](~[#6])(~[I])', 0), + 204: ('[#6](~[#6])(~[#7])', 0), + 205: 
('[#6](~[#6])(~[#8])', 0), + 206: ('[#6](~[#6])(~[#16])', 0), + 207: ('[#6](~[#6])(~[#14])', 0), + 208: ('[#6](~[#6])(:c)', 0), + 209: ('[#6](~[#6])(:c)(:c)', 0), + 210: ('[#6](~[#6])(:c)(:n)', 0), + 211: ('[#6](~[#6])(:n)', 0), + 212: ('[#6](~[#6])(:n)(:n)', 0), + 213: ('[#6](~[Cl])(~[Cl])', 0), + 214: ('[#6&!H0](~[Cl])', 0), + 215: ('[#6](~[Cl])(:c)', 0), + 216: ('[#6](~[F])(~[F])', 0), + 217: ('[#6](~[F])(:c)', 0), + 218: ('[#6&!H0](~[#7])', 0), + 219: ('[#6&!H0](~[#8])', 0), + 220: ('[#6&!H0](~[#8])(~[#8])', 0), + 221: ('[#6&!H0](~[#16])', 0), + 222: ('[#6&!H0](~[#14])', 0), + 223: ('[#6&!H0]:c', 0), + 224: ('[#6&!H0](:c)(:c)', 0), + 225: ('[#6&!H0](:c)(:n)', 0), + 226: ('[#6&!H0](:n)', 0), + 227: ('[#6H3]', 0), + 228: ('[#6](~[#7])(~[#7])', 0), + 229: ('[#6](~[#7])(:c)', 0), + 230: ('[#6](~[#7])(:c)(:c)', 0), + 231: ('[#6](~[#7])(:c)(:n)', 0), + 232: ('[#6](~[#7])(:n)', 0), + 233: ('[#6](~[#8])(~[#8])', 0), + 234: ('[#6](~[#8])(:c)', 0), + 235: ('[#6](~[#8])(:c)(:c)', 0), + 236: ('[#6](~[#16])(:c)', 0), + 237: ('[#6](:c)(:c)', 0), + 238: ('[#6](:c)(:c)(:c)', 0), + 239: ('[#6](:c)(:c)(:n)', 0), + 240: ('[#6](:c)(:n)', 0), + 241: ('[#6](:c)(:n)(:n)', 0), + 242: ('[#6](:n)(:n)', 0), + 243: ('[#7](~[#6])(~[#6])', 0), + 244: ('[#7](~[#6])(~[#6])(~[#6])', 0), + 245: ('[#7&!H0](~[#6])(~[#6])', 0), + 246: ('[#7&!H0](~[#6])', 0), + 247: ('[#7&!H0](~[#6])(~[#7])', 0), + 248: ('[#7](~[#6])(~[#8])', 0), + 249: ('[#7](~[#6])(:c)', 0), + 250: ('[#7](~[#6])(:c)(:c)', 0), + 251: ('[#7&!H0](~[#7])', 0), + 252: ('[#7&!H0](:c)', 0), + 253: ('[#7&!H0](:c)(:c)', 0), + 254: ('[#7](~[#8])(~[#8])', 0), + 255: ('[#7](~[#8])(:o)', 0), + 256: ('[#7](:c)(:c)', 0), + 257: ('[#7](:c)(:c)(:c)', 0), + 258: ('[#8](~[#6])(~[#6])', 0), + 259: ('[#8&!H0](~[#6])', 0), + 260: ('[#8](~[#6])(~[#15])', 0), + 261: ('[#8&!H0](~[#16])', 0), + 262: ('[#8](:c)(:c)', 0), + 263: ('[#15](~[#6])(~[#6])', 0), + 264: ('[#15](~[#8])(~[#8])', 0), + 265: ('[#16](~[#6])(~[#6])', 0), + 266: ('[#16&!H0](~[#6])', 0), + 267: ('[#16](~[#6])(~[#8])', 0), + 268: ('[#14](~[#6])(~[#6])', 0), + 269: ('[#6]=,:[#6]', 0), + 270: ('[#6]#[#6]', 0), + 271: ('[#6]=,:[#7]', 0), + 272: ('[#6]#[#7]', 0), + 273: ('[#6]=,:[#8]', 0), + 274: ('[#6]=,:[#16]', 0), + 275: ('[#7]=,:[#7]', 0), + 276: ('[#7]=,:[#8]', 0), + 277: ('[#7]=,:[#15]', 0), + 278: ('[#15]=,:[#8]', 0), + 279: ('[#15]=,:[#15]', 0), + 280: ('[#6](#[#6])(-,:[#6])', 0), + 281: ('[#6&!H0](#[#6])', 0), + 282: ('[#6](#[#7])(-,:[#6])', 0), + 283: ('[#6](-,:[#6])(-,:[#6])(=,:[#6])', 0), + 284: ('[#6](-,:[#6])(-,:[#6])(=,:[#7])', 0), + 285: ('[#6](-,:[#6])(-,:[#6])(=,:[#8])', 0), + 286: ('[#6](-,:[#6])([Cl])(=,:[#8])', 0), + 287: ('[#6&!H0](-,:[#6])(=,:[#6])', 0), + 288: ('[#6&!H0](-,:[#6])(=,:[#7])', 0), + 289: ('[#6&!H0](-,:[#6])(=,:[#8])', 0), + 290: ('[#6](-,:[#6])(-,:[#7])(=,:[#6])', 0), + 291: ('[#6](-,:[#6])(-,:[#7])(=,:[#7])', 0), + 292: ('[#6](-,:[#6])(-,:[#7])(=,:[#8])', 0), + 293: ('[#6](-,:[#6])(-,:[#8])(=,:[#8])', 0), + 294: ('[#6](-,:[#6])(=,:[#6])', 0), + 295: ('[#6](-,:[#6])(=,:[#7])', 0), + 296: ('[#6](-,:[#6])(=,:[#8])', 0), + 297: ('[#6]([Cl])(=,:[#8])', 0), + 298: ('[#6&!H0](-,:[#7])(=,:[#6])', 0), + 299: ('[#6&!H0](=,:[#6])', 0), + 300: ('[#6&!H0](=,:[#7])', 0), + 301: ('[#6&!H0](=,:[#8])', 0), + 302: ('[#6](-,:[#7])(=,:[#6])', 0), + 303: ('[#6](-,:[#7])(=,:[#7])', 0), + 304: ('[#6](-,:[#7])(=,:[#8])', 0), + 305: ('[#6](-,:[#8])(=,:[#8])', 0), + 306: ('[#7](-,:[#6])(=,:[#6])', 0), + 307: ('[#7](-,:[#6])(=,:[#8])', 0), + 308: ('[#7](-,:[#8])(=,:[#8])', 0), + 309: 
('[#15](-,:[#8])(=,:[#8])', 0), + 310: ('[#16](-,:[#6])(=,:[#8])', 0), + 311: ('[#16](-,:[#8])(=,:[#8])', 0), + 312: ('[#16](=,:[#8])(=,:[#8])', 0), + 313: ('[#6]-,:[#6]-,:[#6]#[#6]', 0), + 314: ('[#8]-,:[#6]-,:[#6]=,:[#7]', 0), + 315: ('[#8]-,:[#6]-,:[#6]=,:[#8]', 0), + 316: ('[#7]:[#6]-,:[#16&!H0]', 0), + 317: ('[#7]-,:[#6]-,:[#6]=,:[#6]', 0), + 318: ('[#8]=,:[#16]-,:[#6]-,:[#6]', 0), + 319: ('[#7]#[#6]-,:[#6]=,:[#6]', 0), + 320: ('[#6]=,:[#7]-,:[#7]-,:[#6]', 0), + 321: ('[#8]=,:[#16]-,:[#6]-,:[#7]', 0), + 322: ('[#16]-,:[#16]-,:[#6]:[#6]', 0), + 323: ('[#6]:[#6]-,:[#6]=,:[#6]', 0), + 324: ('[#16]:[#6]:[#6]:[#6]', 0), + 325: ('[#6]:[#7]:[#6]-,:[#6]', 0), + 326: ('[#16]-,:[#6]:[#7]:[#6]', 0), + 327: ('[#16]:[#6]:[#6]:[#7]', 0), + 328: ('[#16]-,:[#6]=,:[#7]-,:[#6]', 0), + 329: ('[#6]-,:[#8]-,:[#6]=,:[#6]', 0), + 330: ('[#7]-,:[#7]-,:[#6]:[#6]', 0), + 331: ('[#16]-,:[#6]=,:[#7&!H0]', 0), + 332: ('[#16]-,:[#6]-,:[#16]-,:[#6]', 0), + 333: ('[#6]:[#16]:[#6]-,:[#6]', 0), + 334: ('[#8]-,:[#16]-,:[#6]:[#6]', 0), + 335: ('[#6]:[#7]-,:[#6]:[#6]', 0), + 336: ('[#7]-,:[#16]-,:[#6]:[#6]', 0), + 337: ('[#7]-,:[#6]:[#7]:[#6]', 0), + 338: ('[#7]:[#6]:[#6]:[#7]', 0), + 339: ('[#7]-,:[#6]:[#7]:[#7]', 0), + 340: ('[#7]-,:[#6]=,:[#7]-,:[#6]', 0), + 341: ('[#7]-,:[#6]=,:[#7&!H0]', 0), + 342: ('[#7]-,:[#6]-,:[#16]-,:[#6]', 0), + 343: ('[#6]-,:[#6]-,:[#6]=,:[#6]', 0), + 344: ('[#6]-,:[#7]:[#6&!H0]', 0), + 345: ('[#7]-,:[#6]:[#8]:[#6]', 0), + 346: ('[#8]=,:[#6]-,:[#6]:[#6]', 0), + 347: ('[#8]=,:[#6]-,:[#6]:[#7]', 0), + 348: ('[#6]-,:[#7]-,:[#6]:[#6]', 0), + 349: ('[#7]:[#7]-,:[#6&!H0]', 0), + 350: ('[#8]-,:[#6]:[#6]:[#7]', 0), + 351: ('[#8]-,:[#6]=,:[#6]-,:[#6]', 0), + 352: ('[#7]-,:[#6]:[#6]:[#7]', 0), + 353: ('[#6]-,:[#16]-,:[#6]:[#6]', 0), + 354: ('[Cl]-,:[#6]:[#6]-,:[#6]', 0), + 355: ('[#7]-,:[#6]=,:[#6&!H0]', 0), + 356: ('[Cl]-,:[#6]:[#6&!H0]', 0), + 357: ('[#7]:[#6]:[#7]-,:[#6]', 0), + 358: ('[Cl]-,:[#6]:[#6]-,:[#8]', 0), + 359: ('[#6]-,:[#6]:[#7]:[#6]', 0), + 360: ('[#6]-,:[#6]-,:[#16]-,:[#6]', 0), + 361: ('[#16]=,:[#6]-,:[#7]-,:[#6]', 0), + 362: ('[Br]-,:[#6]:[#6]-,:[#6]', 0), + 363: ('[#7&!H0]-,:[#7&!H0]', 0), + 364: ('[#16]=,:[#6]-,:[#7&!H0]', 0), + 365: ('[#6]-,:[#33]-[#8&!H0]', 0), + 366: ('[#16]:[#6]:[#6&!H0]', 0), + 367: ('[#8]-,:[#7]-,:[#6]-,:[#6]', 0), + 368: ('[#7]-,:[#7]-,:[#6]-,:[#6]', 0), + 369: ('[#6H,#6H2,#6H3]=,:[#6H,#6H2,#6H3]', 0), + 370: ('[#7]-,:[#7]-,:[#6]-,:[#7]', 0), + 371: ('[#8]=,:[#6]-,:[#7]-,:[#7]', 0), + 372: ('[#7]=,:[#6]-,:[#7]-,:[#6]', 0), + 373: ('[#6]=,:[#6]-,:[#6]:[#6]', 0), + 374: ('[#6]:[#7]-,:[#6&!H0]', 0), + 375: ('[#6]-,:[#7]-,:[#7&!H0]', 0), + 376: ('[#7]:[#6]:[#6]-,:[#6]', 0), + 377: ('[#6]-,:[#6]=,:[#6]-,:[#6]', 0), + 378: ('[#33]-,:[#6]:[#6&!H0]', 0), + 379: ('[Cl]-,:[#6]:[#6]-,:[Cl]', 0), + 380: ('[#6]:[#6]:[#7&!H0]', 0), + 381: ('[#7&!H0]-,:[#6&!H0]', 0), + 382: ('[Cl]-,:[#6]-,:[#6]-,:[Cl]', 0), + 383: ('[#7]:[#6]-,:[#6]:[#6]', 0), + 384: ('[#16]-,:[#6]:[#6]-,:[#6]', 0), + 385: ('[#16]-,:[#6]:[#6&!H0]', 0), + 386: ('[#16]-,:[#6]:[#6]-,:[#7]', 0), + 387: ('[#16]-,:[#6]:[#6]-,:[#8]', 0), + 388: ('[#8]=,:[#6]-,:[#6]-,:[#6]', 0), + 389: ('[#8]=,:[#6]-,:[#6]-,:[#7]', 0), + 390: ('[#8]=,:[#6]-,:[#6]-,:[#8]', 0), + 391: ('[#7]=,:[#6]-,:[#6]-,:[#6]', 0), + 392: ('[#7]=,:[#6]-,:[#6&!H0]', 0), + 393: ('[#6]-,:[#7]-,:[#6&!H0]', 0), + 394: ('[#8]-,:[#6]:[#6]-,:[#6]', 0), + 395: ('[#8]-,:[#6]:[#6&!H0]', 0), + 396: ('[#8]-,:[#6]:[#6]-,:[#7]', 0), + 397: ('[#8]-,:[#6]:[#6]-,:[#8]', 0), + 398: ('[#7]-,:[#6]:[#6]-,:[#6]', 0), + 399: ('[#7]-,:[#6]:[#6&!H0]', 0), + 400: 
('[#7]-,:[#6]:[#6]-,:[#7]', 0), + 401: ('[#8]-,:[#6]-,:[#6]:[#6]', 0), + 402: ('[#7]-,:[#6]-,:[#6]:[#6]', 0), + 403: ('[Cl]-,:[#6]-,:[#6]-,:[#6]', 0), + 404: ('[Cl]-,:[#6]-,:[#6]-,:[#8]', 0), + 405: ('[#6]:[#6]-,:[#6]:[#6]', 0), + 406: ('[#8]=,:[#6]-,:[#6]=,:[#6]', 0), + 407: ('[Br]-,:[#6]-,:[#6]-,:[#6]', 0), + 408: ('[#7]=,:[#6]-,:[#6]=,:[#6]', 0), + 409: ('[#6]=,:[#6]-,:[#6]-,:[#6]', 0), + 410: ('[#7]:[#6]-,:[#8&!H0]', 0), + 411: ('[#8]=,:[#7]-,:c:c', 0), + 412: ('[#8]-,:[#6]-,:[#7&!H0]', 0), + 413: ('[#7]-,:[#6]-,:[#7]-,:[#6]', 0), + 414: ('[Cl]-,:[#6]-,:[#6]=,:[#8]', 0), + 415: ('[Br]-,:[#6]-,:[#6]=,:[#8]', 0), + 416: ('[#8]-,:[#6]-,:[#8]-,:[#6]', 0), + 417: ('[#6]=,:[#6]-,:[#6]=,:[#6]', 0), + 418: ('[#6]:[#6]-,:[#8]-,:[#6]', 0), + 419: ('[#8]-,:[#6]-,:[#6]-,:[#7]', 0), + 420: ('[#8]-,:[#6]-,:[#6]-,:[#8]', 0), + 421: ('N#[#6]-,:[#6]-,:[#6]', 0), + 422: ('[#7]-,:[#6]-,:[#6]-,:[#7]', 0), + 423: ('[#6]:[#6]-,:[#6]-,:[#6]', 0), + 424: ('[#6&!H0]-,:[#8&!H0]', 0), + 425: ('n:c:n:c', 0), + 426: ('[#8]-,:[#6]-,:[#6]=,:[#6]', 0), + 427: ('[#8]-,:[#6]-,:[#6]:[#6]-,:[#6]', 0), + 428: ('[#8]-,:[#6]-,:[#6]:[#6]-,:[#8]', 0), + 429: ('[#7]=,:[#6]-,:[#6]:[#6&!H0]', 0), + 430: ('c:c-,:[#7]-,:c:c', 0), + 431: ('[#6]-,:[#6]:[#6]-,:c:c', 0), + 432: ('[#8]=,:[#6]-,:[#6]-,:[#6]-,:[#6]', 0), + 433: ('[#8]=,:[#6]-,:[#6]-,:[#6]-,:[#7]', 0), + 434: ('[#8]=,:[#6]-,:[#6]-,:[#6]-,:[#8]', 0), + 435: ('[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]', 0), + 436: ('[Cl]-,:[#6]:[#6]-,:[#8]-,:[#6]', 0), + 437: ('c:c-,:[#6]=,:[#6]-,:[#6]', 0), + 438: ('[#6]-,:[#6]:[#6]-,:[#7]-,:[#6]', 0), + 439: ('[#6]-,:[#16]-,:[#6]-,:[#6]-,:[#6]', 0), + 440: ('[#7]-,:[#6]:[#6]-,:[#8&!H0]', 0), + 441: ('[#8]=,:[#6]-,:[#6]-,:[#6]=,:[#8]', 0), + 442: ('[#6]-,:[#6]:[#6]-,:[#8]-,:[#6]', 0), + 443: ('[#6]-,:[#6]:[#6]-,:[#8&!H0]', 0), + 444: ('[Cl]-,:[#6]-,:[#6]-,:[#6]-,:[#6]', 0), + 445: ('[#7]-,:[#6]-,:[#6]-,:[#6]-,:[#6]', 0), + 446: ('[#7]-,:[#6]-,:[#6]-,:[#6]-,:[#7]', 0), + 447: ('[#6]-,:[#8]-,:[#6]-,:[#6]=,:[#6]', 0), + 448: ('c:c-,:[#6]-,:[#6]-,:[#6]', 0), + 449: ('[#7]=,:[#6]-,:[#7]-,:[#6]-,:[#6]', 0), + 450: ('[#8]=,:[#6]-,:[#6]-,:c:c', 0), + 451: ('[Cl]-,:[#6]:[#6]:[#6]-,:[#6]', 0), + 452: ('[#6H,#6H2,#6H3]-,:[#6]=,:[#6H,#6H2,#6H3]', 0), + 453: ('[#7]-,:[#6]:[#6]:[#6]-,:[#6]', 0), + 454: ('[#7]-,:[#6]:[#6]:[#6]-,:[#7]', 0), + 455: ('[#8]=,:[#6]-,:[#6]-,:[#7]-,:[#6]', 0), + 456: ('[#6]-,:c:c:[#6]-,:[#6]', 0), + 457: ('[#6]-,:[#8]-,:[#6]-,:[#6]:c', 0), + 458: ('[#8]=,:[#6]-,:[#6]-,:[#8]-,:[#6]', 0), + 459: ('[#8]-,:[#6]:[#6]-,:[#6]-,:[#6]', 0), + 460: ('[#7]-,:[#6]-,:[#6]-,:[#6]:c', 0), + 461: ('[#6]-,:[#6]-,:[#6]-,:[#6]:c', 0), + 462: ('[Cl]-,:[#6]-,:[#6]-,:[#7]-,:[#6]', 0), + 463: ('[#6]-,:[#8]-,:[#6]-,:[#8]-,:[#6]', 0), + 464: ('[#7]-,:[#6]-,:[#6]-,:[#7]-,:[#6]', 0), + 465: ('[#7]-,:[#6]-,:[#8]-,:[#6]-,:[#6]', 0), + 466: ('[#6]-,:[#7]-,:[#6]-,:[#6]-,:[#6]', 0), + 467: ('[#6]-,:[#6]-,:[#8]-,:[#6]-,:[#6]', 0), + 468: ('[#7]-,:[#6]-,:[#6]-,:[#8]-,:[#6]', 0), + 469: ('c:c:n:n:c', 0), + 470: ('[#6]-,:[#6]-,:[#6]-,:[#8&!H0]', 0), + 471: ('c:[#6]-,:[#6]-,:[#6]:c', 0), + 472: ('[#8]-,:[#6]-,:[#6]=,:[#6]-,:[#6]', 0), + 473: ('c:c-,:[#8]-,:[#6]-,:[#6]', 0), + 474: ('[#7]-,:[#6]:c:c:n', 0), + 475: ('[#8]=,:[#6]-,:[#8]-,:[#6]:c', 0), + 476: ('[#8]=,:[#6]-,:[#6]:[#6]-,:[#6]', 0), + 477: ('[#8]=,:[#6]-,:[#6]:[#6]-,:[#7]', 0), + 478: ('[#8]=,:[#6]-,:[#6]:[#6]-,:[#8]', 0), + 479: ('[#6]-,:[#8]-,:[#6]:[#6]-,:[#6]', 0), + 480: ('[#8]=,:[#33]-,:[#6]:c:c', 0), + 481: ('[#6]-,:[#7]-,:[#6]-,:[#6]:c', 0), + 482: ('[#16]-,:[#6]:c:c-,:[#7]', 0), + 483: 
('[#8]-,:[#6]:[#6]-,:[#8]-,:[#6]', 0), + 484: ('[#8]-,:[#6]:[#6]-,:[#8&!H0]', 0), + 485: ('[#6]-,:[#6]-,:[#8]-,:[#6]:c', 0), + 486: ('[#7]-,:[#6]-,:[#6]:[#6]-,:[#6]', 0), + 487: ('[#6]-,:[#6]-,:[#6]:[#6]-,:[#6]', 0), + 488: ('[#7]-,:[#7]-,:[#6]-,:[#7&!H0]', 0), + 489: ('[#6]-,:[#7]-,:[#6]-,:[#7]-,:[#6]', 0), + 490: ('[#8]-,:[#6]-,:[#6]-,:[#6]-,:[#6]', 0), + 491: ('[#8]-,:[#6]-,:[#6]-,:[#6]-,:[#7]', 0), + 492: ('[#8]-,:[#6]-,:[#6]-,:[#6]-,:[#8]', 0), + 493: ('[#6]=,:[#6]-,:[#6]-,:[#6]-,:[#6]', 0), + 494: ('[#8]-,:[#6]-,:[#6]-,:[#6]=,:[#6]', 0), + 495: ('[#8]-,:[#6]-,:[#6]-,:[#6]=,:[#8]', 0), + 496: ('[#6&!H0]-,:[#6]-,:[#7&!H0]', 0), + 497: ('[#6]-,:[#6]=,:[#7]-,:[#7]-,:[#6]', 0), + 498: ('[#8]=,:[#6]-,:[#7]-,:[#6]-,:[#6]', 0), + 499: ('[#8]=,:[#6]-,:[#7]-,:[#6&!H0]', 0), + 500: ('[#8]=,:[#6]-,:[#7]-,:[#6]-,:[#7]', 0), + 501: ('[#8]=,:[#7]-,:[#6]:[#6]-,:[#7]', 0), + 502: ('[#8]=,:[#7]-,:c:c-,:[#8]', 0), + 503: ('[#8]=,:[#6]-,:[#7]-,:[#6]=,:[#8]', 0), + 504: ('[#8]-,:[#6]:[#6]:[#6]-,:[#6]', 0), + 505: ('[#8]-,:[#6]:[#6]:[#6]-,:[#7]', 0), + 506: ('[#8]-,:[#6]:[#6]:[#6]-,:[#8]', 0), + 507: ('[#7]-,:[#6]-,:[#7]-,:[#6]-,:[#6]', 0), + 508: ('[#8]-,:[#6]-,:[#6]-,:[#6]:c', 0), + 509: ('[#6]-,:[#6]-,:[#7]-,:[#6]-,:[#6]', 0), + 510: ('[#6]-,:[#7]-,:[#6]:[#6]-,:[#6]', 0), + 511: ('[#6]-,:[#6]-,:[#16]-,:[#6]-,:[#6]', 0), + 512: ('[#8]-,:[#6]-,:[#6]-,:[#7]-,:[#6]', 0), + 513: ('[#6]-,:[#6]=,:[#6]-,:[#6]-,:[#6]', 0), + 514: ('[#8]-,:[#6]-,:[#8]-,:[#6]-,:[#6]', 0), + 515: ('[#8]-,:[#6]-,:[#6]-,:[#8]-,:[#6]', 0), + 516: ('[#8]-,:[#6]-,:[#6]-,:[#8&!H0]', 0), + 517: ('[#6]-,:[#6]=,:[#6]-,:[#6]=,:[#6]', 0), + 518: ('[#7]-,:[#6]:[#6]-,:[#6]-,:[#6]', 0), + 519: ('[#6]=,:[#6]-,:[#6]-,:[#8]-,:[#6]', 0), + 520: ('[#6]=,:[#6]-,:[#6]-,:[#8&!H0]', 0), + 521: ('[#6]-,:[#6]:[#6]-,:[#6]-,:[#6]', 0), + 522: ('[Cl]-,:[#6]:[#6]-,:[#6]=,:[#8]', 0), + 523: ('[Br]-,:[#6]:c:c-,:[#6]', 0), + 524: ('[#8]=,:[#6]-,:[#6]=,:[#6]-,:[#6]', 0), + 525: ('[#8]=,:[#6]-,:[#6]=,:[#6&!H0]', 0), + 526: ('[#8]=,:[#6]-,:[#6]=,:[#6]-,:[#7]', 0), + 527: ('[#7]-,:[#6]-,:[#7]-,:[#6]:c', 0), + 528: ('[Br]-,:[#6]-,:[#6]-,:[#6]:c', 0), + 529: ('[#7]#[#6]-,:[#6]-,:[#6]-,:[#6]', 0), + 530: ('[#6]-,:[#6]=,:[#6]-,:[#6]:c', 0), + 531: ('[#6]-,:[#6]-,:[#6]=,:[#6]-,:[#6]', 0), + 532: ('[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]', 0), + 533: ('[#8]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]', 0), + 534: ('[#8]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#8]', 0), + 535: ('[#8]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#7]', 0), + 536: ('[#7]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]', 0), + 537: ('[#8]=,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]', 0), + 538: ('[#8]=,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#7]', 0), + 539: ('[#8]=,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#8]', 0), + 540: ('[#8]=,:[#6]-,:[#6]-,:[#6]-,:[#6]=,:[#8]', 0), + 541: ('[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]', 0), + 542: ('[#8]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]', 0), + 543: ('[#8]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#8]', 0), + 544: ('[#8]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#7]', 0), + 545: ('[#8]=,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]', 0), + 546: ('[#8]=,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#8]', 0), + 547: ('[#8]=,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]=,:[#8]', 0), + 548: ('[#8]=,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#7]', 0), + 549: ('[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]', 0), + 550: ('[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6](-,:[#6])-,:[#6]', 0), + 551: ('[#8]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]', 0), + 552: ('[#8]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6](-,:[#6])-,:[#6]', 0), + 553: 
('[#8]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#8]-,:[#6]', 0), + 554: ('[#8]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6](-,:[#8])-,:[#6]', 0), + 555: ('[#8]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#7]-,:[#6]', 0), + 556: ('[#8]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6](-,:[#7])-,:[#6]', 0), + 557: ('[#8]=,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]', 0), + 558: ('[#8]=,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6](-,:[#8])-,:[#6]', 0), + 559: ('[#8]=,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6](=,:[#8])-,:[#6]', 0), + 560: ('[#8]=,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6](-,:[#7])-,:[#6]', 0), + 561: ('[#6]-,:[#6](-,:[#6])-,:[#6]-,:[#6]', 0), + 562: ('[#6]-,:[#6](-,:[#6])-,:[#6]-,:[#6]-,:[#6]', 0), + 563: ('[#6]-,:[#6]-,:[#6](-,:[#6])-,:[#6]-,:[#6]', 0), + 564: ('[#6]-,:[#6](-,:[#6])(-,:[#6])-,:[#6]-,:[#6]', 0), + 565: ('[#6]-,:[#6](-,:[#6])-,:[#6](-,:[#6])-,:[#6]', 0), + 566: ('[#6]c1ccc([#6])cc1', 0), + 567: ('[#6]c1ccc([#8])cc1', 0), + 568: ('[#6]c1ccc([#16])cc1', 0), + 569: ('[#6]c1ccc([#7])cc1', 0), + 570: ('[#6]c1ccc(Cl)cc1', 0), + 571: ('[#6]c1ccc(Br)cc1', 0), + 572: ('[#8]c1ccc([#8])cc1', 0), + 573: ('[#8]c1ccc([#16])cc1', 0), + 574: ('[#8]c1ccc([#7])cc1', 0), + 575: ('[#8]c1ccc(Cl)cc1', 0), + 576: ('[#8]c1ccc(Br)cc1', 0), + 577: ('[#16]c1ccc([#16])cc1', 0), + 578: ('[#16]c1ccc([#7])cc1', 0), + 579: ('[#16]c1ccc(Cl)cc1', 0), + 580: ('[#16]c1ccc(Br)cc1', 0), + 581: ('[#7]c1ccc([#7])cc1', 0), + 582: ('[#7]c1ccc(Cl)cc1', 0), + 583: ('[#7]c1ccc(Br)cc1', 0), + 584: ('Clc1ccc(Cl)cc1', 0), + 585: ('Clc1ccc(Br)cc1', 0), + 586: ('Brc1ccc(Br)cc1', 0), + 587: ('[#6]c1cc([#6])ccc1', 0), + 588: ('[#6]c1cc([#8])ccc1', 0), + 589: ('[#6]c1cc([#16])ccc1', 0), + 590: ('[#6]c1cc([#7])ccc1', 0), + 591: ('[#6]c1cc(Cl)ccc1', 0), + 592: ('[#6]c1cc(Br)ccc1', 0), + 593: ('[#8]c1cc([#8])ccc1', 0), + 594: ('[#8]c1cc([#16])ccc1', 0), + 595: ('[#8]c1cc([#7])ccc1', 0), + 596: ('[#8]c1cc(Cl)ccc1', 0), + 597: ('[#8]c1cc(Br)ccc1', 0), + 598: ('[#16]c1cc([#16])ccc1', 0), + 599: ('[#16]c1cc([#7])ccc1', 0), + 600: ('[#16]c1cc(Cl)ccc1', 0), + 601: ('[#16]c1cc(Br)ccc1', 0), + 602: ('[#7]c1cc([#7])ccc1', 0), + 603: ('[#7]c1cc(Cl)ccc1', 0), + 604: ('[#7]c1cc(Br)ccc1', 0), + 605: ('Clc1cc(Cl)ccc1', 0), + 606: ('Clc1cc(Br)ccc1', 0), + 607: ('Brc1cc(Br)ccc1', 0), + 608: ('[#6]c1c([#6])cccc1', 0), + 609: ('[#6]c1c([#8])cccc1', 0), + 610: ('[#6]c1c([#16])cccc1', 0), + 611: ('[#6]c1c([#7])cccc1', 0), + 612: ('[#6]c1c(Cl)cccc1', 0), + 613: ('[#6]c1c(Br)cccc1', 0), + 614: ('[#8]c1c([#8])cccc1', 0), + 615: ('[#8]c1c([#16])cccc1', 0), + 616: ('[#8]c1c([#7])cccc1', 0), + 617: ('[#8]c1c(Cl)cccc1', 0), + 618: ('[#8]c1c(Br)cccc1', 0), + 619: ('[#16]c1c([#16])cccc1', 0), + 620: ('[#16]c1c([#7])cccc1', 0), + 621: ('[#16]c1c(Cl)cccc1', 0), + 622: ('[#16]c1c(Br)cccc1', 0), + 623: ('[#7]c1c([#7])cccc1', 0), + 624: ('[#7]c1c(Cl)cccc1', 0), + 625: ('[#7]c1c(Br)cccc1', 0), + 626: ('Clc1c(Cl)cccc1', 0), + 627: ('Clc1c(Br)cccc1', 0), + 628: ('Brc1c(Br)cccc1', 0), + 629: ('[#6][#6]1[#6][#6][#6]([#6])[#6][#6]1', 0), + 630: ('[#6][#6]1[#6][#6][#6]([#8])[#6][#6]1', 0), + 631: ('[#6][#6]1[#6][#6][#6]([#16])[#6][#6]1', 0), + 632: ('[#6][#6]1[#6][#6][#6]([#7])[#6][#6]1', 0), + 633: ('[#6][#6]1[#6][#6][#6](Cl)[#6][#6]1', 0), + 634: ('[#6][#6]1[#6][#6][#6](Br)[#6][#6]1', 0), + 635: ('[#8][#6]1[#6][#6][#6]([#8])[#6][#6]1', 0), + 636: ('[#8][#6]1[#6][#6][#6]([#16])[#6][#6]1', 0), + 637: ('[#8][#6]1[#6][#6][#6]([#7])[#6][#6]1', 0), + 638: ('[#8][#6]1[#6][#6][#6](Cl)[#6][#6]1', 0), + 639: ('[#8][#6]1[#6][#6][#6](Br)[#6][#6]1', 0), + 640: ('[#16][#6]1[#6][#6][#6]([#16])[#6][#6]1', 0), + 641: 
('[#16][#6]1[#6][#6][#6]([#7])[#6][#6]1', 0), + 642: ('[#16][#6]1[#6][#6][#6](Cl)[#6][#6]1', 0), + 643: ('[#16][#6]1[#6][#6][#6](Br)[#6][#6]1', 0), + 644: ('[#7][#6]1[#6][#6][#6]([#7])[#6][#6]1', 0), + 645: ('[#7][#6]1[#6][#6][#6](Cl)[#6][#6]1', 0), + 646: ('[#7][#6]1[#6][#6][#6](Br)[#6][#6]1', 0), + 647: ('Cl[#6]1[#6][#6][#6](Cl)[#6][#6]1', 0), + 648: ('Cl[#6]1[#6][#6][#6](Br)[#6][#6]1', 0), + 649: ('Br[#6]1[#6][#6][#6](Br)[#6][#6]1', 0), + 650: ('[#6][#6]1[#6][#6]([#6])[#6][#6][#6]1', 0), + 651: ('[#6][#6]1[#6][#6]([#8])[#6][#6][#6]1', 0), + 652: ('[#6][#6]1[#6][#6]([#16])[#6][#6][#6]1', 0), + 653: ('[#6][#6]1[#6][#6]([#7])[#6][#6][#6]1', 0), + 654: ('[#6][#6]1[#6][#6](Cl)[#6][#6][#6]1', 0), + 655: ('[#6][#6]1[#6][#6](Br)[#6][#6][#6]1', 0), + 656: ('[#8][#6]1[#6][#6]([#8])[#6][#6][#6]1', 0), + 657: ('[#8][#6]1[#6][#6]([#16])[#6][#6][#6]1', 0), + 658: ('[#8][#6]1[#6][#6]([#7])[#6][#6][#6]1', 0), + 659: ('[#8][#6]1[#6][#6](Cl)[#6][#6][#6]1', 0), + 660: ('[#8][#6]1[#6][#6](Br)[#6][#6][#6]1', 0), + 661: ('[#16][#6]1[#6][#6]([#16])[#6][#6][#6]1', 0), + 662: ('[#16][#6]1[#6][#6]([#7])[#6][#6][#6]1', 0), + 663: ('[#16][#6]1[#6][#6](Cl)[#6][#6][#6]1', 0), + 664: ('[#16][#6]1[#6][#6](Br)[#6][#6][#6]1', 0), + 665: ('[#7][#6]1[#6][#6]([#7])[#6][#6][#6]1', 0), + 666: ('[#7][#6]1[#6][#6](Cl)[#6][#6][#6]1', 0), + 667: ('[#7][#6]1[#6][#6](Br)[#6][#6][#6]1', 0), + 668: ('Cl[#6]1[#6][#6](Cl)[#6][#6][#6]1', 0), + 669: ('Cl[#6]1[#6][#6](Br)[#6][#6][#6]1', 0), + 670: ('Br[#6]1[#6][#6](Br)[#6][#6][#6]1', 0), + 671: ('[#6][#6]1[#6]([#6])[#6][#6][#6][#6]1', 0), + 672: ('[#6][#6]1[#6]([#8])[#6][#6][#6][#6]1', 0), + 673: ('[#6][#6]1[#6]([#16])[#6][#6][#6][#6]1', 0), + 674: ('[#6][#6]1[#6]([#7])[#6][#6][#6][#6]1', 0), + 675: ('[#6][#6]1[#6](Cl)[#6][#6][#6][#6]1', 0), + 676: ('[#6][#6]1[#6](Br)[#6][#6][#6][#6]1', 0), + 677: ('[#8][#6]1[#6]([#8])[#6][#6][#6][#6]1', 0), + 678: ('[#8][#6]1[#6]([#16])[#6][#6][#6][#6]1', 0), + 679: ('[#8][#6]1[#6]([#7])[#6][#6][#6][#6]1', 0), + 680: ('[#8][#6]1[#6](Cl)[#6][#6][#6][#6]1', 0), + 681: ('[#8][#6]1[#6](Br)[#6][#6][#6][#6]1', 0), + 682: ('[#16][#6]1[#6]([#16])[#6][#6][#6][#6]1', 0), + 683: ('[#16][#6]1[#6]([#7])[#6][#6][#6][#6]1', 0), + 684: ('[#16][#6]1[#6](Cl)[#6][#6][#6][#6]1', 0), + 685: ('[#16][#6]1[#6](Br)[#6][#6][#6][#6]1', 0), + 686: ('[#7][#6]1[#6]([#7])[#6][#6][#6][#6]1', 0), + 687: ('[#7][#6]1[#6](Cl)[#6][#6][#6][#6]1', 0), + 688: ('[#7][#6]1[#6](Br)[#6][#6][#6][#6]1', 0), + 689: ('Cl[#6]1[#6](Cl)[#6][#6][#6][#6]1', 0), + 690: ('Cl[#6]1[#6](Br)[#6][#6][#6][#6]1', 0), + 691: ('Br[#6]1[#6](Br)[#6][#6][#6][#6]1', 0), + 692: ('[#6][#6]1[#6][#6]([#6])[#6][#6]1', 0), + 693: ('[#6][#6]1[#6][#6]([#8])[#6][#6]1', 0), + 694: ('[#6][#6]1[#6][#6]([#16])[#6][#6]1', 0), + 695: ('[#6][#6]1[#6][#6]([#7])[#6][#6]1', 0), + 696: ('[#6][#6]1[#6][#6](Cl)[#6][#6]1', 0), + 697: ('[#6][#6]1[#6][#6](Br)[#6][#6]1', 0), + 698: ('[#8][#6]1[#6][#6]([#8])[#6][#6]1', 0), + 699: ('[#8][#6]1[#6][#6]([#16])[#6][#6]1', 0), + 700: ('[#8][#6]1[#6][#6]([#7])[#6][#6]1', 0), + 701: ('[#8][#6]1[#6][#6](Cl)[#6][#6]1', 0), + 702: ('[#8][#6]1[#6][#6](Br)[#6][#6]1', 0), + 703: ('[#16][#6]1[#6][#6]([#16])[#6][#6]1', 0), + 704: ('[#16][#6]1[#6][#6]([#7])[#6][#6]1', 0), + 705: ('[#16][#6]1[#6][#6](Cl)[#6][#6]1', 0), + 706: ('[#16][#6]1[#6][#6](Br)[#6][#6]1', 0), + 707: ('[#7][#6]1[#6][#6]([#7])[#6][#6]1', 0), + 708: ('[#7][#6]1[#6][#6](Cl)[#6][#6]1', 0), + 709: ('[#7][#6]1[#6][#6](Br)[#6][#6]1', 0), + 710: ('Cl[#6]1[#6][#6](Cl)[#6][#6]1', 0), + 711: ('Cl[#6]1[#6][#6](Br)[#6][#6]1', 0), + 712: 
('Br[#6]1[#6][#6](Br)[#6][#6]1', 0), + 713: ('[#6][#6]1[#6]([#6])[#6][#6][#6]1', 0), + 714: ('[#6][#6]1[#6]([#8])[#6][#6][#6]1', 0), + 715: ('[#6][#6]1[#6]([#16])[#6][#6][#6]1', 0), + 716: ('[#6][#6]1[#6]([#7])[#6][#6][#6]1', 0), + 717: ('[#6][#6]1[#6](Cl)[#6][#6][#6]1', 0), + 718: ('[#6][#6]1[#6](Br)[#6][#6][#6]1', 0), + 719: ('[#8][#6]1[#6]([#8])[#6][#6][#6]1', 0), + 720: ('[#8][#6]1[#6]([#16])[#6][#6][#6]1', 0), + 721: ('[#8][#6]1[#6]([#7])[#6][#6][#6]1', 0), + 722: ('[#8][#6]1[#6](Cl)[#6][#6][#6]1', 0), + 723: ('[#8][#6]1[#6](Br)[#6][#6][#6]1', 0), + 724: ('[#16][#6]1[#6]([#16])[#6][#6][#6]1', 0), + 725: ('[#16][#6]1[#6]([#7])[#6][#6][#6]1', 0), + 726: ('[#16][#6]1[#6](Cl)[#6][#6][#6]1', 0), + 727: ('[#16][#6]1[#6](Br)[#6][#6][#6]1', 0), + 728: ('[#7][#6]1[#6]([#7])[#6][#6][#6]1', 0), + 729: ('[#7][#6]1[#6](Cl)[#6][#6]1', 0), + 730: ('[#7][#6]1[#6](Br)[#6][#6][#6]1', 0), + 731: ('Cl[#6]1[#6](Cl)[#6][#6][#6]1', 0), + 732: ('Cl[#6]1[#6](Br)[#6][#6][#6]1', 0), + 733: ('Br[#6]1[#6](Br)[#6][#6][#6]1', 0)} + +PubchemKeys = None + + +def InitKeys(keyList, keyDict): + """ *Internal Use Only* + generates SMARTS patterns for the keys, run once + """ + assert len(keyList) == len(keyDict.keys()), 'length mismatch' + for key in keyDict.keys(): + patt, count = keyDict[key] + if patt != '?': + sma = Chem.MolFromSmarts(patt) + if not sma: + print('SMARTS parser error for key #%d: %s' % (key, patt)) + else: + keyList[key - 1] = sma, count + + +def calcPubChemFingerPart1(mol, **kwargs): + """ Calculate PubChem Fingerprints (1-115; 263-881) + **Arguments** + - mol: the molecule to be fingerprinted + - any extra keyword arguments are ignored + **Returns** + a _DataStructs.SparseBitVect_ containing the fingerprint. + >>> m = Chem.MolFromSmiles('CNO') + >>> bv = PubChemFingerPart1(m) + >>> tuple(bv.GetOnBits()) + (24, 68, 69, 71, 93, 94, 102, 124, 131, 139, 151, 158, 160, 161, 164) + >>> bv = PubChemFingerPart1(Chem.MolFromSmiles('CCC')) + >>> tuple(bv.GetOnBits()) + (74, 114, 149, 155, 160) + """ + global PubchemKeys + if PubchemKeys is None: + PubchemKeys = [(None, 0)] * len(smartsPatts.keys()) + InitKeys(PubchemKeys, smartsPatts) + ctor = kwargs.get('ctor', DataStructs.SparseBitVect) + res = ctor(len(PubchemKeys) + 1) + for i, (patt, count) in enumerate(PubchemKeys): + if patt is not None: + if count == 0: + res[i + 1] = mol.HasSubstructMatch(patt) + else: + matches = mol.GetSubstructMatches(patt) + if len(matches) > count: + res[i + 1] = 1 + return res + + +def func_1(mol, bits): + """ *Internal Use Only* + Calculate PubChem Fingerprints (116-263) + """ + ringSize = [] + temp = {3: 0, 4: 0, 5: 0, 6: 0, 7: 0, 8: 0, 9: 0, 10: 0} + AllRingsAtom = mol.GetRingInfo().AtomRings() + for ring in AllRingsAtom: + ringSize.append(len(ring)) + for k, v in temp.items(): + if len(ring) == k: + temp[k] += 1 + if temp[3] >= 2: + bits[0] = 1 + bits[7] = 1 + elif temp[3] == 1: + bits[0] = 1 + else: + pass + if temp[4] >= 2: + bits[14] = 1 + bits[21] = 1 + elif temp[4] == 1: + bits[14] = 1 + else: + pass + if temp[5] >= 5: + bits[28] = 1 + bits[35] = 1 + bits[42] = 1 + bits[49] = 1 + bits[56] = 1 + elif temp[5] == 4: + bits[28] = 1 + bits[35] = 1 + bits[42] = 1 + bits[49] = 1 + elif temp[5] == 3: + bits[28] = 1 + bits[35] = 1 + bits[42] = 1 + elif temp[5] == 2: + bits[28] = 1 + bits[35] = 1 + elif temp[5] == 1: + bits[28] = 1 + else: + pass + if temp[6] >= 5: + bits[63] = 1 + bits[70] = 1 + bits[77] = 1 + bits[84] = 1 + bits[91] = 1 + elif temp[6] == 4: + bits[63] = 1 + bits[70] = 1 + bits[77] = 1 + bits[84] = 1 + elif 
temp[6] == 3: + bits[63] = 1 + bits[70] = 1 + bits[77] = 1 + elif temp[6] == 2: + bits[63] = 1 + bits[70] = 1 + elif temp[6] == 1: + bits[63] = 1 + else: + pass + if temp[7] >= 2: + bits[98] = 1 + bits[105] = 1 + elif temp[7] == 1: + bits[98] = 1 + else: + pass + if temp[8] >= 2: + bits[112] = 1 + bits[119] = 1 + elif temp[8] == 1: + bits[112] = 1 + else: + pass + if temp[9] >= 1: + bits[126] = 1 + else: + pass + if temp[10] >= 1: + bits[133] = 1 + else: + pass + + return ringSize, bits + + +def func_2(mol, bits): + """ *Internal Use Only* + saturated or aromatic carbon-only ring + """ + AllRingsBond = mol.GetRingInfo().BondRings() + ringSize = [] + temp = {3: 0, 4: 0, 5: 0, 6: 0, 7: 0, 8: 0, 9: 0, 10: 0} + for ring in AllRingsBond: + ######### saturated + nonsingle = False + for bondIdx in ring: + if mol.GetBondWithIdx(bondIdx).GetBondType().name != 'SINGLE': + nonsingle = True + break + if nonsingle == False: + ringSize.append(len(ring)) + for k, v in temp.items(): + if len(ring) == k: + temp[k] += 1 + ######## aromatic carbon-only + aromatic = True + AllCarb = True + for bondIdx in ring: + if mol.GetBondWithIdx(bondIdx).GetBondType().name != 'AROMATIC': + aromatic = False + break + for bondIdx in ring: + BeginAtom = mol.GetBondWithIdx(bondIdx).GetBeginAtom() + EndAtom = mol.GetBondWithIdx(bondIdx).GetEndAtom() + if BeginAtom.GetAtomicNum() != 6 or EndAtom.GetAtomicNum() != 6: + AllCarb = False + break + if aromatic == True and AllCarb == True: + ringSize.append(len(ring)) + for k, v in temp.items(): + if len(ring) == k: + temp[k] += 1 + if temp[3] >= 2: + bits[1] = 1 + bits[8] = 1 + elif temp[3] == 1: + bits[1] = 1 + else: + pass + if temp[4] >= 2: + bits[15] = 1 + bits[22] = 1 + elif temp[4] == 1: + bits[15] = 1 + else: + pass + if temp[5] >= 5: + bits[29] = 1 + bits[36] = 1 + bits[43] = 1 + bits[50] = 1 + bits[57] = 1 + elif temp[5] == 4: + bits[29] = 1 + bits[36] = 1 + bits[43] = 1 + bits[50] = 1 + elif temp[5] == 3: + bits[29] = 1 + bits[36] = 1 + bits[43] = 1 + elif temp[5] == 2: + bits[29] = 1 + bits[36] = 1 + elif temp[5] == 1: + bits[29] = 1 + else: + pass + if temp[6] >= 5: + bits[64] = 1 + bits[71] = 1 + bits[78] = 1 + bits[85] = 1 + bits[92] = 1 + elif temp[6] == 4: + bits[64] = 1 + bits[71] = 1 + bits[78] = 1 + bits[85] = 1 + elif temp[6] == 3: + bits[64] = 1 + bits[71] = 1 + bits[78] = 1 + elif temp[6] == 2: + bits[64] = 1 + bits[71] = 1 + elif temp[6] == 1: + bits[64] = 1 + else: + pass + if temp[7] >= 2: + bits[99] = 1 + bits[106] = 1 + elif temp[7] == 1: + bits[99] = 1 + else: + pass + if temp[8] >= 2: + bits[113] = 1 + bits[120] = 1 + elif temp[8] == 1: + bits[113] = 1 + else: + pass + if temp[9] >= 1: + bits[127] = 1 + else: + pass + if temp[10] >= 1: + bits[134] = 1 + else: + pass + return ringSize, bits + + +def func_3(mol, bits): + """ *Internal Use Only* + saturated or aromatic nitrogen-containing + """ + AllRingsBond = mol.GetRingInfo().BondRings() + ringSize = [] + temp = {3: 0, 4: 0, 5: 0, 6: 0, 7: 0, 8: 0, 9: 0, 10: 0} + for ring in AllRingsBond: + ######### saturated + nonsingle = False + for bondIdx in ring: + if mol.GetBondWithIdx(bondIdx).GetBondType().name != 'SINGLE': + nonsingle = True + break + if nonsingle == False: + ringSize.append(len(ring)) + for k, v in temp.items(): + if len(ring) == k: + temp[k] += 1 + ######## aromatic nitrogen-containing + aromatic = True + ContainNitro = False + for bondIdx in ring: + if mol.GetBondWithIdx(bondIdx).GetBondType().name != 'AROMATIC': + aromatic = False + break + for bondIdx in ring: + BeginAtom = 
mol.GetBondWithIdx(bondIdx).GetBeginAtom()
+            EndAtom = mol.GetBondWithIdx(bondIdx).GetEndAtom()
+            if BeginAtom.GetAtomicNum() == 7 or EndAtom.GetAtomicNum() == 7:
+                ContainNitro = True
+                break
+        if aromatic and ContainNitro:
+            ringSize.append(len(ring))
+            if len(ring) in temp:
+                temp[len(ring)] += 1
+    _set_ring_count_bits(temp, bits, offset=0)
+    return ringSize, bits
+
+
+# Each ring size maps to a (base bit, maximum multiplicity) pair; the i-th ring
+# of that size sets bit (base + offset + 7 * i). func_3 to func_7 share this
+# table with offsets 0 to 4, reproducing the original if/elif cascades exactly.
+_RING_BIT_TABLE = {3: (2, 2), 4: (16, 2), 5: (30, 5), 6: (65, 5),
+                   7: (100, 2), 8: (114, 2), 9: (128, 1), 10: (135, 1)}
+
+
+def _set_ring_count_bits(temp, bits, offset):
+    """ *Internal Use Only*
+    Translate per-ring-size counts into the corresponding fingerprint bits.
+    """
+    for size, (base, cap) in _RING_BIT_TABLE.items():
+        for i in range(min(temp[size], cap)):
+            bits[base + offset + 7 * i] = 1
+
+
+def func_4(mol, bits):
+    """ *Internal Use Only*
+    Ring counts: saturated or aromatic heteroatom-containing rings.
+    """
+    AllRingsBond = mol.GetRingInfo().BondRings()
+    ringSize = []
+    temp = {3: 0, 4: 0, 5: 0, 6: 0, 7: 0, 8: 0, 9: 0, 10: 0}
+    for ring in AllRingsBond:
+        # saturated: every ring bond is a single bond
+        nonsingle = False
+        for bondIdx in ring:
+            if mol.GetBondWithIdx(bondIdx).GetBondType().name != 'SINGLE':
+                nonsingle = True
+                break
+        if not nonsingle:
+            ringSize.append(len(ring))
+            if len(ring) in temp:
+                temp[len(ring)] += 1
+        # aromatic heteroatom-containing: all bonds aromatic, any non-C/H atom
+        aromatic = True
+        heteroatom = False
+        for bondIdx in ring:
+            if mol.GetBondWithIdx(bondIdx).GetBondType().name != 'AROMATIC':
+                aromatic = False
+                break
+        for bondIdx in ring:
+            BeginAtom = mol.GetBondWithIdx(bondIdx).GetBeginAtom()
+            EndAtom = mol.GetBondWithIdx(bondIdx).GetEndAtom()
+            if BeginAtom.GetAtomicNum() not in [1, 6] or EndAtom.GetAtomicNum() not in [1, 6]:
+                heteroatom = True
+                break
+        if aromatic and heteroatom:
+            ringSize.append(len(ring))
+            if len(ring) in temp:
+                temp[len(ring)] += 1
+    _set_ring_count_bits(temp, bits, offset=1)
+    return ringSize, bits
+
+
+def func_5(mol, bits):
+    """ *Internal Use Only*
+    Ring counts: unsaturated, non-aromatic, carbon-only rings.
+    """
+    ringSize = []
+    AllRingsBond = mol.GetRingInfo().BondRings()
+    temp = {3: 0, 4: 0, 5: 0, 6: 0, 7: 0, 8: 0, 9: 0, 10: 0}
+    for ring in AllRingsBond:
+        unsaturated = False
+        nonaromatic = True
+        Allcarb = True
+        # unsaturated: at least one non-single bond
+        for bondIdx in ring:
+            if mol.GetBondWithIdx(bondIdx).GetBondType().name != 'SINGLE':
+                unsaturated = True
+                break
+        # non-aromatic: no aromatic bond
+        for bondIdx in ring:
+            if mol.GetBondWithIdx(bondIdx).GetBondType().name == 'AROMATIC':
+                nonaromatic = False
+                break
+        # carbon-only ring atoms
+        for bondIdx in ring:
+            BeginAtom = mol.GetBondWithIdx(bondIdx).GetBeginAtom()
+            EndAtom = mol.GetBondWithIdx(bondIdx).GetEndAtom()
+            if BeginAtom.GetAtomicNum() != 6 or EndAtom.GetAtomicNum() != 6:
+                Allcarb = False
+                break
+        if unsaturated and nonaromatic and Allcarb:
+            ringSize.append(len(ring))
+            if len(ring) in temp:
+                temp[len(ring)] += 1
+    _set_ring_count_bits(temp, bits, offset=2)
+    return ringSize, bits
+
+
+def func_6(mol, bits):
+    """ *Internal Use Only*
+    Ring counts: unsaturated, non-aromatic, nitrogen-containing rings.
+    """
+    ringSize = []
+    AllRingsBond = mol.GetRingInfo().BondRings()
+    temp = {3: 0, 4: 0, 5: 0, 6: 0, 7: 0, 8: 0, 9: 0, 10: 0}
+    for ring in AllRingsBond:
+        unsaturated = False
+        nonaromatic = True
+        ContainNitro = False
+        for bondIdx in ring:
+            if mol.GetBondWithIdx(bondIdx).GetBondType().name != 'SINGLE':
+                unsaturated = True
+                break
+        for bondIdx in ring:
+            if mol.GetBondWithIdx(bondIdx).GetBondType().name == 'AROMATIC':
+                nonaromatic = False
+                break
+        for bondIdx in ring:
+            BeginAtom = mol.GetBondWithIdx(bondIdx).GetBeginAtom()
+            EndAtom = mol.GetBondWithIdx(bondIdx).GetEndAtom()
+            if BeginAtom.GetAtomicNum() == 7 or EndAtom.GetAtomicNum() == 7:
+                ContainNitro = True
+                break
+        if unsaturated and nonaromatic and ContainNitro:
+            ringSize.append(len(ring))
+            if len(ring) in temp:
+                temp[len(ring)] += 1
+    _set_ring_count_bits(temp, bits, offset=3)
+    return ringSize, bits
+
+
+def func_7(mol, bits):
+    """ *Internal Use Only*
+    Ring counts: unsaturated, non-aromatic, heteroatom-containing rings.
+    """
+    ringSize = []
+    AllRingsBond = mol.GetRingInfo().BondRings()
+    temp = {3: 0, 4: 0, 5: 0, 6: 0, 7: 0, 8: 0, 9: 0, 10: 0}
+    for ring in AllRingsBond:
+        unsaturated = False
+        nonaromatic = True
+        heteroatom = False
+        for bondIdx in ring:
+            if mol.GetBondWithIdx(bondIdx).GetBondType().name != 'SINGLE':
+                unsaturated = True
+                break
+        for bondIdx in ring:
+            if mol.GetBondWithIdx(bondIdx).GetBondType().name == 'AROMATIC':
+                nonaromatic = False
+                break
+        for bondIdx in ring:
+            BeginAtom = mol.GetBondWithIdx(bondIdx).GetBeginAtom()
+            EndAtom = mol.GetBondWithIdx(bondIdx).GetEndAtom()
+            if BeginAtom.GetAtomicNum() not in [1, 6] or EndAtom.GetAtomicNum() not in [1, 6]:
+                heteroatom = True
+                break
+        if unsaturated and nonaromatic and heteroatom:
+            ringSize.append(len(ring))
+            if len(ring) in temp:
+                temp[len(ring)] += 1
+    _set_ring_count_bits(temp, bits, offset=4)
+    return ringSize, bits
+
+
+def func_8(mol, bits):
+    """ *Internal Use Only*
+    Counts of aromatic rings and of hetero-aromatic rings (bits 140-147).
+    """
+    AllRingsBond = mol.GetRingInfo().BondRings()
+    temp = {'aromatic': 0, 'heteroatom': 0}
+    for ring in AllRingsBond:
+        aromatic = True
+        heteroatom = False
+        for bondIdx in ring:
+            if mol.GetBondWithIdx(bondIdx).GetBondType().name != 'AROMATIC':
+                aromatic = False
+                break
+        if aromatic:
+            temp['aromatic'] += 1
+            for bondIdx in ring:
+                BeginAtom = mol.GetBondWithIdx(bondIdx).GetBeginAtom()
+                EndAtom = mol.GetBondWithIdx(bondIdx).GetEndAtom()
+                if BeginAtom.GetAtomicNum() not in [1, 6] or EndAtom.GetAtomicNum() not in [1, 6]:
+                    heteroatom = True
+                    break
+            if heteroatom:
+                temp['heteroatom'] += 1
+    # bits 140/142/144/146: number of aromatic rings (capped at 4)
+    for i in range(min(temp['aromatic'], 4)):
+        bits[140 + 2 * i] = 1
+    # bits 141/143/145/147: set only when the aromatic and hetero-aromatic
+    # counts agree (both >= 4, or exactly equal below 4), preserving the
+    # behaviour of the original if/elif cascade
+    if temp['aromatic'] >= 4 and temp['heteroatom'] >= 4:
+        pair_count = 4
+    elif temp['aromatic'] == temp['heteroatom']:
+        pair_count = min(temp['aromatic'], 3)
+    else:
+        pair_count = 0
+    for i in range(pair_count):
+        bits[141 + 2 * i] = 1
+    return bits
+
+
+def calcPubChemFingerPart2(mol):  # 116-263
+    """ *Internal Use Only*
+    Calculate the ring-count block of the PubChem fingerprint (keys 116-263,
+    i.e. positions 115-262 of the 0-based 881-bit vector).
+    """
+    bits = [0] * 148
+    bits = func_1(mol, bits)[1]
+    bits = func_2(mol, bits)[1]
+    bits = func_3(mol, bits)[1]
+    bits = func_4(mol, bits)[1]
+    bits = func_5(mol, bits)[1]
+    bits = func_6(mol, bits)[1]
+    bits = func_7(mol, bits)[1]
+    bits = func_8(mol, bits)
+
+    return bits
+
+
+def GetPubChemFPs(mol):
+    """*Internal Use Only*
+    Calculate the 881-bit PubChem fingerprint as a boolean NumPy array.
+    SMARTS-based keys from calcPubChemFingerPart1 fill positions 0-114 and
+    263-880; the ring-count keys from calcPubChemFingerPart2 fill 115-262.
+    """
+    mol = Chem.AddHs(mol)
+    AllBits = [0] * 881
+    res1 = list(calcPubChemFingerPart1(mol).ToBitString())
+    for index, item in enumerate(res1[1:116]):
+        if item == '1':
+            AllBits[index] = 1
+    for index2, item2 in enumerate(res1[116:734]):
+        if item2 == '1':
+            AllBits[index2 + 115 + 148] = 1
+    res2 = calcPubChemFingerPart2(mol)
+    for index3, item3 in enumerate(res2):
+        if item3 == 1:
+            AllBits[index3 + 115] = 1
+    AllBits = np.array(AllBits, dtype=np.bool_)
+
+    return AllBits
+
+
+# ------------------------------------
+
+
+file_path = os.path.dirname(__file__)
+
+
+def GetPubChemFPInfos():
+    return pd.read_excel(os.path.join(file_path, 'pubchemfp.xlsx'))
+
+
+if __name__ == '__main__':
+    print('-' * 10 + 'START' + '-' * 10)
+    SMILES = 'C1=NC2NC3=CNCC3=CC2CC1'
+    mol = Chem.MolFromSmiles(SMILES)
+    # GetPubChemFPs adds explicit hydrogens itself, so the bare molecule suffices
+    result = GetPubChemFPs(mol)
+    print('Molecule: %s' % SMILES)
+    print('-' * 25)
+    print('Results: %s' % result)
+    print('-' * 10 + 'END' + '-' * 10)
diff --git a/deepscreen/data/featurizers/fingerprint/pubchemfp.xlsx b/deepscreen/data/featurizers/fingerprint/pubchemfp.xlsx
new file mode 100644
index 0000000000000000000000000000000000000000..464009090fdbbd5cc156076a3a9abc973f0320e9
Binary files /dev/null and b/deepscreen/data/featurizers/fingerprint/pubchemfp.xlsx differ
diff --git a/deepscreen/data/featurizers/fingerprint/rdkitfp.py b/deepscreen/data/featurizers/fingerprint/rdkitfp.py
new file mode 100644
index 0000000000000000000000000000000000000000..a56fd8e267ba57cdc911edb9bf65f1ff030ff574
--- /dev/null
+++ b/deepscreen/data/featurizers/fingerprint/rdkitfp.py
@@ -0,0 +1,42 @@
+"""
+topological fingerprint
+"""
+
+import numpy as np
+from rdkit.Chem.rdmolops import RDKFingerprint
+from rdkit.Chem import DataStructs
+
+_type = 'topological-based'
+
+
+def GetRDkitFPs(mol, nBits=2048, return_bitInfo=False):
+    """
+    Calculate the RDKit (Daylight-like) topological fingerprint.
+
+    Input: mol is an RDKit molecule object; nBits is the fingerprint length
+    (2048 by default).
+
+    Output: a boolean NumPy array of length nBits; if return_bitInfo is True,
+    also return the bitInfo dict mapping each set bit to the bond paths that
+    produced it.
+    """
+    bitInfo = {}
+    fp = RDKFingerprint(mol, fpSize=nBits, bitInfo=bitInfo)
+    arr = np.zeros((0,), dtype=np.bool_)
+    DataStructs.ConvertToNumpyArray(fp, arr)
+    if return_bitInfo:
+        return arr, bitInfo
+    return arr
diff --git a/deepscreen/data/featurizers/fingerprint/smarts_maccskey.py b/deepscreen/data/featurizers/fingerprint/smarts_maccskey.py
new file mode 100644
index 0000000000000000000000000000000000000000..4546568d20e31efbe47f04b51d20e895c014327f
--- /dev/null
+++ b/deepscreen/data/featurizers/fingerprint/smarts_maccskey.py
@@ -0,0 +1,178 @@
+# MACCS keys as (SMARTS, count) pairs; in the usual RDKit-style consumer a key
+# fires when the molecule has strictly more than `count` matches, and '?' marks
+# keys with no SMARTS equivalent, which are typically skipped or special-cased.
+smartsPatts = {
+    'MACCSFP0': (None, 0),
+    # ignore, Bit 0 is a placeholder and should be ignored: https://github.com/rdkit/rdkit/issues/1726
+    'MACCSFP1': ('?', 0),
+    'MACCSFP2': ('[#104]', 0),
+    'MACCSFP3': ('[#32,#33,#34,#50,#51,#52,#82,#83,#84]', 0),
+    'MACCSFP4': ('[Ac,Th,Pa,U,Np,Pu,Am,Cm,Bk,Cf,Es,Fm,Md,No,Lr]', 0),
+    'MACCSFP5': ('[Sc,Ti,Y,Zr,Hf]', 0),
+    'MACCSFP6': ('[La,Ce,Pr,Nd,Pm,Sm,Eu,Gd,Tb,Dy,Ho,Er,Tm,Yb,Lu]', 0),
+    'MACCSFP7': ('[V,Cr,Mn,Nb,Mo,Tc,Ta,W,Re]', 0),
+    'MACCSFP8': ('[!#6;!#1]1~*~*~*~1', 0),
+    'MACCSFP9': ('[Fe,Co,Ni,Ru,Rh,Pd,Os,Ir,Pt]', 0),
+    'MACCSFP10': ('[Be,Mg,Ca,Sr,Ba,Ra]', 0),
+    'MACCSFP11': ('*1~*~*~*~1', 0),
+    'MACCSFP12': ('[Cu,Zn,Ag,Cd,Au,Hg]', 0),
+    'MACCSFP13': ('[#8]~[#7](~[#6])~[#6]', 0),
+    'MACCSFP14': ('[#16]-[#16]', 0),
+    'MACCSFP15': ('[#8]~[#6](~[#8])~[#8]', 0),
+    'MACCSFP16': ('[!#6;!#1]1~*~*~1', 0),
+    'MACCSFP17': ('[#6]#[#6]', 0),
+    'MACCSFP18': ('[#5,#13,#31,#49,#81]', 0),
+    'MACCSFP19': ('*1~*~*~*~*~*~*~1', 0),
+    'MACCSFP20': ('[#14]', 0),
+    'MACCSFP21': ('[#6]=[#6](~[!#6;!#1])~[!#6;!#1]', 0),
+    'MACCSFP22': ('*1~*~*~1', 0),
+    'MACCSFP23': ('[#7]~[#6](~[#8])~[#8]', 0),
+    'MACCSFP24': ('[#7]-[#8]', 0),
+    'MACCSFP25': ('[#7]~[#6](~[#7])~[#7]', 0),
+    'MACCSFP26': ('[#6]=;@[#6](@*)@*', 0),
+    'MACCSFP27': ('[I]', 0),
+    'MACCSFP28': ('[!#6;!#1]~[CH2]~[!#6;!#1]', 0),
+    'MACCSFP29': ('[#15]', 0),
+    'MACCSFP30': ('[#6]~[!#6;!#1](~[#6])(~[#6])~*', 0),
+    'MACCSFP31': ('[!#6;!#1]~[F,Cl,Br,I]', 0),
+    'MACCSFP32': ('[#6]~[#16]~[#7]', 0),
+    'MACCSFP33': ('[#7]~[#16]', 0),
+    'MACCSFP34': ('[CH2]=*', 0),
+    'MACCSFP35': ('[Li,Na,K,Rb,Cs,Fr]', 0),
+    'MACCSFP36': ('[#16R]', 0),
+    'MACCSFP37': ('[#7]~[#6](~[#8])~[#7]', 0),
+    'MACCSFP38': ('[#7]~[#6](~[#6])~[#7]', 0),
+    'MACCSFP39': ('[#8]~[#16](~[#8])~[#8]', 0),
+    'MACCSFP40': ('[#16]-[#8]', 0),
+    'MACCSFP41': ('[#6]#[#7]', 0),
+    'MACCSFP42': ('F', 0),
+    'MACCSFP43': ('[!#6;!#1;!H0]~*~[!#6;!#1;!H0]', 0),
+    'MACCSFP44': ('?', 0),
+    'MACCSFP45': ('[#6]=[#6]~[#7]', 0),
+    'MACCSFP46': ('Br', 0),
+    'MACCSFP47': ('[#16]~*~[#7]', 0),
+    'MACCSFP48': ('[#8]~[!#6;!#1](~[#8])(~[#8])', 0),
+    'MACCSFP49': ('[!+0]', 0),
+    'MACCSFP50': ('[#6]=[#6](~[#6])~[#6]', 0),
+    'MACCSFP51': ('[#6]~[#16]~[#8]', 0),
+    'MACCSFP52': 
('[#7]~[#7]', 0), + 'MACCSFP53': ('[!#6;!#1;!H0]~*~*~*~[!#6;!#1;!H0]', 0), + 'MACCSFP54': ('[!#6;!#1;!H0]~*~*~[!#6;!#1;!H0]', 0), + 'MACCSFP55': ('[#8]~[#16]~[#8]', 0), + 'MACCSFP56': ('[#8]~[#7](~[#8])~[#6]', 0), + 'MACCSFP57': ('[#8R]', 0), + 'MACCSFP58': ('[!#6;!#1]~[#16]~[!#6;!#1]', 0), + 'MACCSFP59': ('[#16]!:*:*', 0), + 'MACCSFP60': ('[#16]=[#8]', 0), + 'MACCSFP61': ('*~[#16](~*)~*', 0), + 'MACCSFP62': ('*@*!@*@*', 0), + 'MACCSFP63': ('[#7]=[#8]', 0), + 'MACCSFP64': ('*@*!@[#16]', 0), + 'MACCSFP65': ('c:n', 0), + 'MACCSFP66': ('[#6]~[#6](~[#6])(~[#6])~*', 0), + 'MACCSFP67': ('[!#6;!#1]~[#16]', 0), + 'MACCSFP68': ('[!#6;!#1;!H0]~[!#6;!#1;!H0]', 0), + 'MACCSFP69': ('[!#6;!#1]~[!#6;!#1;!H0]', 0), + 'MACCSFP70': ('[!#6;!#1]~[#7]~[!#6;!#1]', 0), + 'MACCSFP71': ('[#7]~[#8]', 0), + 'MACCSFP72': ('[#8]~*~*~[#8]', 0), + 'MACCSFP73': ('[#16]=*', 0), + 'MACCSFP74': ('[CH3]~*~[CH3]', 0), + 'MACCSFP75': ('*!@[#7]@*', 0), + 'MACCSFP76': ('[#6]=[#6](~*)~*', 0), + 'MACCSFP77': ('[#7]~*~[#7]', 0), + 'MACCSFP78': ('[#6]=[#7]', 0), + 'MACCSFP79': ('[#7]~*~*~[#7]', 0), + 'MACCSFP80': ('[#7]~*~*~*~[#7]', 0), + 'MACCSFP81': ('[#16]~*(~*)~*', 0), + 'MACCSFP82': ('*~[CH2]~[!#6;!#1;!H0]', 0), + 'MACCSFP83': ('[!#6;!#1]1~*~*~*~*~1', 0), + 'MACCSFP84': ('[NH2]', 0), + 'MACCSFP85': ('[#6]~[#7](~[#6])~[#6]', 0), + 'MACCSFP86': ('[C;H2,H3][!#6;!#1][C;H2,H3]', 0), + 'MACCSFP87': ('[F,Cl,Br,I]!@*@*', 0), + 'MACCSFP88': ('[#16]', 0), + 'MACCSFP89': ('[#8]~*~*~*~[#8]', 0), + 'MACCSFP90': ( + '[$([!#6;!#1;!H0]~*~*~[CH2]~*),$([!#6;!#1;!H0;R]1@[R]@[R]@[CH2;R]1),$([!#6;!#1;!H0]~[R]1@[R]@[CH2;R]1)]', + 0), + 'MACCSFP91': ( + '[$([!#6;!#1;!H0]~*~*~*~[CH2]~*),$([!#6;!#1;!H0;R]1@[R]@[R]@[R]@[CH2;R]1),$([!#6;!#1;!H0]~[R]1@[R]@[R]@[CH2;R]1),$([!#6;!#1;!H0]~*~[R]1@[R]@[CH2;R]1)]', + 0), + 'MACCSFP92': ('[#8]~[#6](~[#7])~[#6]', 0), + 'MACCSFP93': ('[!#6;!#1]~[CH3]', 0), + 'MACCSFP94': ('[!#6;!#1]~[#7]', 0), + 'MACCSFP95': ('[#7]~*~*~[#8]', 0), + 'MACCSFP96': ('*1~*~*~*~*~1', 0), + 'MACCSFP97': ('[#7]~*~*~*~[#8]', 0), + 'MACCSFP98': ('[!#6;!#1]1~*~*~*~*~*~1', 0), + 'MACCSFP99': ('[#6]=[#6]', 0), + 'MACCSFP100': ('*~[CH2]~[#7]', 0), + 'MACCSFP101': ( + '[$([R]@1@[R]@[R]@[R]@[R]@[R]@[R]@[R]1),$([R]@1@[R]@[R]@[R]@[R]@[R]@[R]@[R]@[R]1),$([R]@1@[R]@[R]@[R]@[R]@[R]@[R]@[R]@[R]@[R]1),$([R]@1@[R]@[R]@[R]@[R]@[R]@[R]@[R]@[R]@[R]@[R]1),$([R]@1@[R]@[R]@[R]@[R]@[R]@[R]@[R]@[R]@[R]@[R]@[R]1),$([R]@1@[R]@[R]@[R]@[R]@[R]@[R]@[R]@[R]@[R]@[R]@[R]@[R]1),$([R]@1@[R]@[R]@[R]@[R]@[R]@[R]@[R]@[R]@[R]@[R]@[R]@[R]@[R]1)]', + 0), + 'MACCSFP102': ('[!#6;!#1]~[#8]', 0), + 'MACCSFP103': ('Cl', 0), + 'MACCSFP104': ('[!#6;!#1;!H0]~*~[CH2]~*', 0), + 'MACCSFP105': ('*@*(@*)@*', 0), + 'MACCSFP106': ('[!#6;!#1]~*(~[!#6;!#1])~[!#6;!#1]', 0), + 'MACCSFP107': ('[F,Cl,Br,I]~*(~*)~*', 0), + 'MACCSFP108': ('[CH3]~*~*~*~[CH2]~*', 0), + 'MACCSFP109': ('*~[CH2]~[#8]', 0), + 'MACCSFP110': ('[#7]~[#6]~[#8]', 0), + 'MACCSFP111': ('[#7]~*~[CH2]~*', 0), + 'MACCSFP112': ('*~*(~*)(~*)~*', 0), + 'MACCSFP113': ('[#8]!:*:*', 0), + 'MACCSFP114': ('[CH3]~[CH2]~*', 0), + 'MACCSFP115': ('[CH3]~*~[CH2]~*', 0), + 'MACCSFP116': ('[$([CH3]~*~*~[CH2]~*),$([CH3]~*1~*~[CH2]1)]', 0), + 'MACCSFP117': ('[#7]~*~[#8]', 0), + 'MACCSFP118': ('[$(*~[CH2]~[CH2]~*),$(*1~[CH2]~[CH2]1)]', 1), + 'MACCSFP119': ('[#7]=*', 0), + 'MACCSFP120': ('[!#6;R]', 1), + 'MACCSFP121': ('[#7;R]', 0), + 'MACCSFP122': ('*~[#7](~*)~*', 0), + 'MACCSFP123': ('[#8]~[#6]~[#8]', 0), + 'MACCSFP124': ('[!#6;!#1]~[!#6;!#1]', 0), + 'MACCSFP125': ('?', 0), + 'MACCSFP126': ('*!@[#8]!@*', 0), + 'MACCSFP127': ('*@*!@[#8]', 1), + 
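+    # NOTE: several keys above carry a nonzero count (e.g. MACCSFP118,
+    # MACCSFP120 and MACCSFP127); in the usual RDKit-style consumer such a
+    # key is only set once the pattern matches strictly more than `count`
+    # times, not merely once.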
'MACCSFP128': ( + '[$(*~[CH2]~*~*~*~[CH2]~*),$([R]1@[CH2;R]@[R]@[R]@[R]@[CH2;R]1),$(*~[CH2]~[R]1@[R]@[R]@[CH2;R]1),$(*~[CH2]~*~[R]1@[R]@[CH2;R]1)]', + 0), + 'MACCSFP129': ('[$(*~[CH2]~*~*~[CH2]~*),$([R]1@[CH2]@[R]@[R]@[CH2;R]1),$(*~[CH2]~[R]1@[R]@[CH2;R]1)]', + 0), + 'MACCSFP130': ('[!#6;!#1]~[!#6;!#1]', 1), + 'MACCSFP131': ('[!#6;!#1;!H0]', 1), + 'MACCSFP132': ('[#8]~*~[CH2]~*', 0), + 'MACCSFP133': ('*@*!@[#7]', 0), + 'MACCSFP134': ('[F,Cl,Br,I]', 0), + 'MACCSFP135': ('[#7]!:*:*', 0), + 'MACCSFP136': ('[#8]=*', 1), + 'MACCSFP137': ('[!C;!c;R]', 0), + 'MACCSFP138': ('[!#6;!#1]~[CH2]~*', 1), + 'MACCSFP139': ('[O;!H0]', 0), + 'MACCSFP140': ('[#8]', 3), + 'MACCSFP141': ('[CH3]', 2), + 'MACCSFP142': ('[#7]', 1), + 'MACCSFP143': ('*@*!@[#8]', 0), + 'MACCSFP144': ('*!:*:*!:*', 0), + 'MACCSFP145': ('*1~*~*~*~*~*~1', 1), + 'MACCSFP146': ('[#8]', 2), + 'MACCSFP147': ('[$(*~[CH2]~[CH2]~*),$([R]1@[CH2;R]@[CH2;R]1)]', 0), + 'MACCSFP148': ('*~[!#6;!#1](~*)~*', 0), + 'MACCSFP149': ('[C;H3,H4]', 1), + 'MACCSFP150': ('*!@*@*!@*', 0), + 'MACCSFP151': ('[#7;!H0]', 0), + 'MACCSFP152': ('[#8]~[#6](~[#6])~[#6]', 0), + 'MACCSFP153': ('[!#6;!#1]~[CH2]~*', 0), + 'MACCSFP154': ('[#6]=[#8]', 0), + 'MACCSFP155': ('*!@[CH2]!@*', 0), + 'MACCSFP156': ('[#7]~*(~*)~*', 0), + 'MACCSFP157': ('[#6]-[#8]', 0), + 'MACCSFP158': ('[#6]-[#7]', 0), + 'MACCSFP159': ('[#8]', 1), + 'MACCSFP160': ('[C;H3,H4]', 0), + 'MACCSFP161': ('[#7]', 0), + 'MACCSFP162': ('a', 0), + 'MACCSFP163': ('*1~*~*~*~*~*~1', 0), + 'MACCSFP164': ('[#8]', 0), + 'MACCSFP165': ('[R]', 0), + 'MACCSFP166': ('?', 0)} diff --git a/deepscreen/data/featurizers/fingerprint/smarts_pharmacophore.py b/deepscreen/data/featurizers/fingerprint/smarts_pharmacophore.py new file mode 100644 index 0000000000000000000000000000000000000000..51a0a687890ca008811d7c1181e0b1f11735c582 --- /dev/null +++ b/deepscreen/data/featurizers/fingerprint/smarts_pharmacophore.py @@ -0,0 +1,21 @@ +Donor = ["[N;!H0;v3,v4&+1]", "[O,S;H1;+0]", "[n&H1&+0]"] + +Acceptor = ["[O,S;H1;v2;!$(*-*=[O,N,P,S])]", "[O;H0;v2]", "[O,S;v1;-]", + "[N;v3;!$(N-*=[O,N,P,S])]", "[n&H0&+0]", "[o;+0;!$([o]:n);!$([o]:c:n)]"] + +Positive = ["[#7;+]", "[N;H2&+0][$([C,a]);!$([C,a](=O))]", + "[N;H1&+0]([$([C,a]);!$([C,a](=O))])[$([C,a]);!$([C,a](=O))]", + "[N;H0&+0]([C;!$(C(=O))])([C;!$(C(=O))])[C;!$(C(=O))]"] + +Negative = ["[C,S](=[O,S,P])-[O;H1,-1]"] + +Hydrophobic = ["[C;D3,D4](-[CH3])-[CH3]", "[S;D2](-C)-C"] + +Aromatic = ["a"] + +pharmacophore_smarts = {"Donor": Donor, + "Acceptor": Acceptor, + "Positive": Positive, + "Negative": Negative, + "Hydrophobic": Hydrophobic, + "Aromatic": Aromatic} diff --git a/deepscreen/data/featurizers/fingerprint/smarts_pubchem.py b/deepscreen/data/featurizers/fingerprint/smarts_pubchem.py new file mode 100644 index 0000000000000000000000000000000000000000..64393ffc9f633d50f363108022099b33981a301d --- /dev/null +++ b/deepscreen/data/featurizers/fingerprint/smarts_pubchem.py @@ -0,0 +1,734 @@ +smartsPatts = { + 'PubChemFP0': ('[H]', 3), + 'PubChemFP1': ('[H]', 7), + 'PubChemFP2': ('[H]', 15), + 'PubChemFP3': ('[H]', 31), + 'PubChemFP4': ('[Li]', 0), + 'PubChemFP5': ('[Li]', 1), + 'PubChemFP6': ('[B]', 0), + 'PubChemFP7': ('[B]', 1), + 'PubChemFP8': ('[B]', 3), + 'PubChemFP9': ('[C]', 1), + 'PubChemFP10': ('[C]', 3), + 'PubChemFP11': ('[C]', 7), + 'PubChemFP12': ('[C]', 15), + 'PubChemFP13': ('[C]', 31), + 'PubChemFP14': ('[N]', 0), + 'PubChemFP15': ('[N]', 1), + 'PubChemFP16': ('[N]', 3), + 'PubChemFP17': ('[N]', 7), + 'PubChemFP18': ('[O]', 0), + 'PubChemFP19': ('[O]', 1), + 
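+    # The counts implement PubChem's thresholded keys: assuming the usual
+    # "strictly more than `count` matches" reading, ('[H]', 3) fires for four
+    # or more hydrogens, ('[H]', 7) for eight or more, and so on.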
'PubChemFP20': ('[O]', 3),
+    'PubChemFP21': ('[O]', 7),
+    'PubChemFP22': ('[O]', 15),
+    'PubChemFP23': ('[F]', 0),
+    'PubChemFP24': ('[F]', 1),
+    'PubChemFP25': ('[F]', 3),
+    'PubChemFP26': ('[Na]', 0),
+    'PubChemFP27': ('[Na]', 1),
+    'PubChemFP28': ('[Si]', 0),
+    'PubChemFP29': ('[Si]', 1),
+    'PubChemFP30': ('[P]', 0),
+    'PubChemFP31': ('[P]', 1),
+    'PubChemFP32': ('[P]', 3),
+    'PubChemFP33': ('[S]', 0),
+    'PubChemFP34': ('[S]', 1),
+    'PubChemFP35': ('[S]', 3),
+    'PubChemFP36': ('[S]', 7),
+    'PubChemFP37': ('[Cl]', 0),
+    'PubChemFP38': ('[Cl]', 1),
+    'PubChemFP39': ('[Cl]', 3),
+    'PubChemFP40': ('[Cl]', 7),
+    'PubChemFP41': ('[K]', 0),
+    'PubChemFP42': ('[K]', 1),
+    'PubChemFP43': ('[Br]', 0),
+    'PubChemFP44': ('[Br]', 1),
+    'PubChemFP45': ('[Br]', 3),
+    'PubChemFP46': ('[I]', 0),
+    'PubChemFP47': ('[I]', 1),
+    'PubChemFP48': ('[I]', 3),
+    'PubChemFP49': ('[Be]', 0),
+    'PubChemFP50': ('[Mg]', 0),
+    'PubChemFP51': ('[Al]', 0),
+    'PubChemFP52': ('[Ca]', 0),
+    'PubChemFP53': ('[Sc]', 0),
+    'PubChemFP54': ('[Ti]', 0),
+    'PubChemFP55': ('[V]', 0),
+    'PubChemFP56': ('[Cr]', 0),
+    'PubChemFP57': ('[Mn]', 0),
+    'PubChemFP58': ('[Fe]', 0),
+    # was '[CO]', which SMARTS reads as carbon AND oxygen on one atom (it can
+    # never match); this slot in the Sc-Ti-...-Fe-Co-Ni series is cobalt
+    'PubChemFP59': ('[Co]', 0),
+    'PubChemFP60': ('[Ni]', 0),
+    'PubChemFP61': ('[Cu]', 0),
+    'PubChemFP62': ('[Zn]', 0),
+    'PubChemFP63': ('[Ga]', 0),
+    'PubChemFP64': ('[Ge]', 0),
+    'PubChemFP65': ('[As]', 0),
+    'PubChemFP66': ('[Se]', 0),
+    'PubChemFP67': ('[Kr]', 0),
+    'PubChemFP68': ('[Rb]', 0),
+    'PubChemFP69': ('[Sr]', 0),
+    'PubChemFP70': ('[Y]', 0),
+    'PubChemFP71': ('[Zr]', 0),
+    'PubChemFP72': ('[Nb]', 0),
+    'PubChemFP73': ('[Mo]', 0),
+    'PubChemFP74': ('[Ru]', 0),
+    'PubChemFP75': ('[Rh]', 0),
+    'PubChemFP76': ('[Pd]', 0),
+    'PubChemFP77': ('[Ag]', 0),
+    'PubChemFP78': ('[Cd]', 0),
+    'PubChemFP79': ('[In]', 0),
+    'PubChemFP80': ('[Sn]', 0),
+    'PubChemFP81': ('[Sb]', 0),
+    'PubChemFP82': ('[Te]', 0),
+    'PubChemFP83': ('[Xe]', 0),
+    'PubChemFP84': ('[Cs]', 0),
+    'PubChemFP85': ('[Ba]', 0),
+    'PubChemFP86': ('[Lu]', 0),
+    'PubChemFP87': ('[Hf]', 0),
+    'PubChemFP88': ('[Ta]', 0),
+    'PubChemFP89': ('[W]', 0),
+    'PubChemFP90': ('[Re]', 0),
+    'PubChemFP91': ('[Os]', 0),
+    'PubChemFP92': ('[Ir]', 0),
+    'PubChemFP93': ('[Pt]', 0),
+    'PubChemFP94': ('[Au]', 0),
+    'PubChemFP95': ('[Hg]', 0),
+    'PubChemFP96': ('[Tl]', 0),
+    'PubChemFP97': ('[Pb]', 0),
+    'PubChemFP98': ('[Bi]', 0),
+    'PubChemFP99': ('[La]', 0),
+    'PubChemFP100': ('[Ce]', 0),
+    'PubChemFP101': ('[Pr]', 0),
+    'PubChemFP102': ('[Nd]', 0),
+    'PubChemFP103': ('[Pm]', 0),
+    'PubChemFP104': ('[Sm]', 0),
+    'PubChemFP105': ('[Eu]', 0),
+    'PubChemFP106': ('[Gd]', 0),
+    'PubChemFP107': ('[Tb]', 0),
+    'PubChemFP108': ('[Dy]', 0),
+    'PubChemFP109': ('[Ho]', 0),
+    'PubChemFP110': ('[Er]', 0),
+    'PubChemFP111': ('[Tm]', 0),
+    'PubChemFP112': ('[Yb]', 0),
+    'PubChemFP113': ('[Tc]', 0),
+    'PubChemFP114': ('[U]', 0),
+    # keys 115-262 (the ring-count block) are not SMARTS-based; they are
+    # computed procedurally by calcPubChemFingerPart2 in pubchemfp.py
+    'PubChemFP263': ('[Li&!H0]', 0),
+    'PubChemFP264': ('[Li]~[Li]', 0),
+    'PubChemFP265': ('[Li]~[#5]', 0),
+    'PubChemFP266': ('[Li]~[#6]', 0),
+    'PubChemFP267': ('[Li]~[#8]', 0),
+    'PubChemFP268': ('[Li]~[F]', 0),
+    'PubChemFP269': ('[Li]~[#15]', 0),
+    'PubChemFP270': ('[Li]~[#16]', 0),
+    'PubChemFP271': ('[Li]~[Cl]', 0),
+    'PubChemFP272': ('[#5&!H0]', 0),
+    'PubChemFP273': ('[#5]~[#5]', 0),
+    'PubChemFP274': ('[#5]~[#6]', 0),
+    'PubChemFP275': ('[#5]~[#7]', 0),
+    'PubChemFP276': ('[#5]~[#8]', 0),
+    'PubChemFP277': ('[#5]~[F]', 0),
+    'PubChemFP278': ('[#5]~[#14]', 0),
+    'PubChemFP279': ('[#5]~[#15]', 0),
+    'PubChemFP280': ('[#5]~[#16]', 0),
+    'PubChemFP281': ('[#5]~[Cl]', 0),
+    'PubChemFP282': 
('[#5]~[Br]', 0), + 'PubChemFP283': ('[#6&!H0]', 0), + 'PubChemFP284': ('[#6]~[#6]', 0), + 'PubChemFP285': ('[#6]~[#7]', 0), + 'PubChemFP286': ('[#6]~[#8]', 0), + 'PubChemFP287': ('[#6]~[F]', 0), + 'PubChemFP288': ('[#6]~[Na]', 0), + 'PubChemFP289': ('[#6]~[Mg]', 0), + 'PubChemFP290': ('[#6]~[Al]', 0), + 'PubChemFP291': ('[#6]~[#14]', 0), + 'PubChemFP292': ('[#6]~[#15]', 0), + 'PubChemFP293': ('[#6]~[#16]', 0), + 'PubChemFP294': ('[#6]~[Cl]', 0), + 'PubChemFP295': ('[#6]~[#33]', 0), + 'PubChemFP296': ('[#6]~[#34]', 0), + 'PubChemFP297': ('[#6]~[Br]', 0), + 'PubChemFP298': ('[#6]~[I]', 0), + 'PubChemFP299': ('[#7&!H0]', 0), + 'PubChemFP300': ('[#7]~[#7]', 0), + 'PubChemFP301': ('[#7]~[#8]', 0), + 'PubChemFP302': ('[#7]~[F]', 0), + 'PubChemFP303': ('[#7]~[#14]', 0), + 'PubChemFP304': ('[#7]~[#15]', 0), + 'PubChemFP305': ('[#7]~[#16]', 0), + 'PubChemFP306': ('[#7]~[Cl]', 0), + 'PubChemFP307': ('[#7]~[Br]', 0), + 'PubChemFP308': ('[#8&!H0]', 0), + 'PubChemFP309': ('[#8]~[#8]', 0), + 'PubChemFP310': ('[#8]~[Mg]', 0), + 'PubChemFP311': ('[#8]~[Na]', 0), + 'PubChemFP312': ('[#8]~[Al]', 0), + 'PubChemFP313': ('[#8]~[#14]', 0), + 'PubChemFP314': ('[#8]~[#15]', 0), + 'PubChemFP315': ('[#8]~[K]', 0), + 'PubChemFP316': ('[F]~[#15]', 0), + 'PubChemFP317': ('[F]~[#16]', 0), + 'PubChemFP318': ('[Al&!H0]', 0), + 'PubChemFP319': ('[Al]~[Cl]', 0), + 'PubChemFP320': ('[#14&!H0]', 0), + 'PubChemFP321': ('[#14]~[#14]', 0), + 'PubChemFP322': ('[#14]~[Cl]', 0), + 'PubChemFP323': ('[#15&!H0]', 0), + 'PubChemFP324': ('[#15]~[#15]', 0), + 'PubChemFP325': ('[#33&!H0]', 0), + 'PubChemFP326': ('[#33]~[#33]', 0), + 'PubChemFP327': ('[#6](~Br)(~[#6])', 0), + 'PubChemFP328': ('[#6](~Br)(~[#6])(~[#6])', 0), + 'PubChemFP329': ('[#6&!H0]~[Br]', 0), + 'PubChemFP330': ('[#6](~[Br])(:[c])', 0), + 'PubChemFP331': ('[#6](~[Br])(:[n])', 0), + 'PubChemFP332': ('[#6](~[#6])(~[#6])', 0), + 'PubChemFP333': ('[#6](~[#6])(~[#6])(~[#6])', 0), + 'PubChemFP334': ('[#6](~[#6])(~[#6])(~[#6])(~[#6])', 0), + 'PubChemFP335': ('[#6H1](~[#6])(~[#6])(~[#6])', 0), + 'PubChemFP336': ('[#6](~[#6])(~[#6])(~[#6])(~[#7])', 0), + 'PubChemFP337': ('[#6](~[#6])(~[#6])(~[#6])(~[#8])', 0), + 'PubChemFP338': ('[#6H1](~[#6])(~[#6])(~[#7])', 0), + 'PubChemFP339': ('[#6H1](~[#6])(~[#6])(~[#8])', 0), + 'PubChemFP340': ('[#6](~[#6])(~[#6])(~[#7])', 0), + 'PubChemFP341': ('[#6](~[#6])(~[#6])(~[#8])', 0), + 'PubChemFP342': ('[#6](~[#6])(~[Cl])', 0), + 'PubChemFP343': ('[#6&!H0](~[#6])(~[Cl])', 0), + 'PubChemFP344': ('[#6H,#6H2,#6H3,#6H4]~[#6]', 0), + 'PubChemFP345': ('[#6&!H0](~[#6])(~[#7])', 0), + 'PubChemFP346': ('[#6&!H0](~[#6])(~[#8])', 0), + 'PubChemFP347': ('[#6H1](~[#6])(~[#8])(~[#8])', 0), + 'PubChemFP348': ('[#6&!H0](~[#6])(~[#15])', 0), + 'PubChemFP349': ('[#6&!H0](~[#6])(~[#16])', 0), + 'PubChemFP350': ('[#6](~[#6])(~[I])', 0), + 'PubChemFP351': ('[#6](~[#6])(~[#7])', 0), + 'PubChemFP352': ('[#6](~[#6])(~[#8])', 0), + 'PubChemFP353': ('[#6](~[#6])(~[#16])', 0), + 'PubChemFP354': ('[#6](~[#6])(~[#14])', 0), + 'PubChemFP355': ('[#6](~[#6])(:c)', 0), + 'PubChemFP356': ('[#6](~[#6])(:c)(:c)', 0), + 'PubChemFP357': ('[#6](~[#6])(:c)(:n)', 0), + 'PubChemFP358': ('[#6](~[#6])(:n)', 0), + 'PubChemFP359': ('[#6](~[#6])(:n)(:n)', 0), + 'PubChemFP360': ('[#6](~[Cl])(~[Cl])', 0), + 'PubChemFP361': ('[#6&!H0](~[Cl])', 0), + 'PubChemFP362': ('[#6](~[Cl])(:c)', 0), + 'PubChemFP363': ('[#6](~[F])(~[F])', 0), + 'PubChemFP364': ('[#6](~[F])(:c)', 0), + 'PubChemFP365': ('[#6&!H0](~[#7])', 0), + 'PubChemFP366': ('[#6&!H0](~[#8])', 0), + 'PubChemFP367': 
('[#6&!H0](~[#8])(~[#8])', 0), + 'PubChemFP368': ('[#6&!H0](~[#16])', 0), + 'PubChemFP369': ('[#6&!H0](~[#14])', 0), + 'PubChemFP370': ('[#6&!H0]:c', 0), + 'PubChemFP371': ('[#6&!H0](:c)(:c)', 0), + 'PubChemFP372': ('[#6&!H0](:c)(:n)', 0), + 'PubChemFP373': ('[#6&!H0](:n)', 0), + 'PubChemFP374': ('[#6H3]', 0), + 'PubChemFP375': ('[#6](~[#7])(~[#7])', 0), + 'PubChemFP376': ('[#6](~[#7])(:c)', 0), + 'PubChemFP377': ('[#6](~[#7])(:c)(:c)', 0), + 'PubChemFP378': ('[#6](~[#7])(:c)(:n)', 0), + 'PubChemFP379': ('[#6](~[#7])(:n)', 0), + 'PubChemFP380': ('[#6](~[#8])(~[#8])', 0), + 'PubChemFP381': ('[#6](~[#8])(:c)', 0), + 'PubChemFP382': ('[#6](~[#8])(:c)(:c)', 0), + 'PubChemFP383': ('[#6](~[#16])(:c)', 0), + 'PubChemFP384': ('[#6](:c)(:c)', 0), + 'PubChemFP385': ('[#6](:c)(:c)(:c)', 0), + 'PubChemFP386': ('[#6](:c)(:c)(:n)', 0), + 'PubChemFP387': ('[#6](:c)(:n)', 0), + 'PubChemFP388': ('[#6](:c)(:n)(:n)', 0), + 'PubChemFP389': ('[#6](:n)(:n)', 0), + 'PubChemFP390': ('[#7](~[#6])(~[#6])', 0), + 'PubChemFP391': ('[#7](~[#6])(~[#6])(~[#6])', 0), + 'PubChemFP392': ('[#7&!H0](~[#6])(~[#6])', 0), + 'PubChemFP393': ('[#7&!H0](~[#6])', 0), + 'PubChemFP394': ('[#7&!H0](~[#6])(~[#7])', 0), + 'PubChemFP395': ('[#7](~[#6])(~[#8])', 0), + 'PubChemFP396': ('[#7](~[#6])(:c)', 0), + 'PubChemFP397': ('[#7](~[#6])(:c)(:c)', 0), + 'PubChemFP398': ('[#7&!H0](~[#7])', 0), + 'PubChemFP399': ('[#7&!H0](:c)', 0), + 'PubChemFP400': ('[#7&!H0](:c)(:c)', 0), + 'PubChemFP401': ('[#7](~[#8])(~[#8])', 0), + 'PubChemFP402': ('[#7](~[#8])(:o)', 0), + 'PubChemFP403': ('[#7](:c)(:c)', 0), + 'PubChemFP404': ('[#7](:c)(:c)(:c)', 0), + 'PubChemFP405': ('[#8](~[#6])(~[#6])', 0), + 'PubChemFP406': ('[#8&!H0](~[#6])', 0), + 'PubChemFP407': ('[#8](~[#6])(~[#15])', 0), + 'PubChemFP408': ('[#8&!H0](~[#16])', 0), + 'PubChemFP409': ('[#8](:c)(:c)', 0), + 'PubChemFP410': ('[#15](~[#6])(~[#6])', 0), + 'PubChemFP411': ('[#15](~[#8])(~[#8])', 0), + 'PubChemFP412': ('[#16](~[#6])(~[#6])', 0), + 'PubChemFP413': ('[#16&!H0](~[#6])', 0), + 'PubChemFP414': ('[#16](~[#6])(~[#8])', 0), + 'PubChemFP415': ('[#14](~[#6])(~[#6])', 0), + 'PubChemFP416': ('[#6]=,:[#6]', 0), + 'PubChemFP417': ('[#6]#[#6]', 0), + 'PubChemFP418': ('[#6]=,:[#7]', 0), + 'PubChemFP419': ('[#6]#[#7]', 0), + 'PubChemFP420': ('[#6]=,:[#8]', 0), + 'PubChemFP421': ('[#6]=,:[#16]', 0), + 'PubChemFP422': ('[#7]=,:[#7]', 0), + 'PubChemFP423': ('[#7]=,:[#8]', 0), + 'PubChemFP424': ('[#7]=,:[#15]', 0), + 'PubChemFP425': ('[#15]=,:[#8]', 0), + 'PubChemFP426': ('[#15]=,:[#15]', 0), + 'PubChemFP427': ('[#6](#[#6])(-,:[#6])', 0), + 'PubChemFP428': ('[#6&!H0](#[#6])', 0), + 'PubChemFP429': ('[#6](#[#7])(-,:[#6])', 0), + 'PubChemFP430': ('[#6](-,:[#6])(-,:[#6])(=,:[#6])', 0), + 'PubChemFP431': ('[#6](-,:[#6])(-,:[#6])(=,:[#7])', 0), + 'PubChemFP432': ('[#6](-,:[#6])(-,:[#6])(=,:[#8])', 0), + 'PubChemFP433': ('[#6](-,:[#6])([Cl])(=,:[#8])', 0), + 'PubChemFP434': ('[#6&!H0](-,:[#6])(=,:[#6])', 0), + 'PubChemFP435': ('[#6&!H0](-,:[#6])(=,:[#7])', 0), + 'PubChemFP436': ('[#6&!H0](-,:[#6])(=,:[#8])', 0), + 'PubChemFP437': ('[#6](-,:[#6])(-,:[#7])(=,:[#6])', 0), + 'PubChemFP438': ('[#6](-,:[#6])(-,:[#7])(=,:[#7])', 0), + 'PubChemFP439': ('[#6](-,:[#6])(-,:[#7])(=,:[#8])', 0), + 'PubChemFP440': ('[#6](-,:[#6])(-,:[#8])(=,:[#8])', 0), + 'PubChemFP441': ('[#6](-,:[#6])(=,:[#6])', 0), + 'PubChemFP442': ('[#6](-,:[#6])(=,:[#7])', 0), + 'PubChemFP443': ('[#6](-,:[#6])(=,:[#8])', 0), + 'PubChemFP444': ('[#6]([Cl])(=,:[#8])', 0), + 'PubChemFP445': ('[#6&!H0](-,:[#7])(=,:[#6])', 0), + 'PubChemFP446': 
('[#6&!H0](=,:[#6])', 0), + 'PubChemFP447': ('[#6&!H0](=,:[#7])', 0), + 'PubChemFP448': ('[#6&!H0](=,:[#8])', 0), + 'PubChemFP449': ('[#6](-,:[#7])(=,:[#6])', 0), + 'PubChemFP450': ('[#6](-,:[#7])(=,:[#7])', 0), + 'PubChemFP451': ('[#6](-,:[#7])(=,:[#8])', 0), + 'PubChemFP452': ('[#6](-,:[#8])(=,:[#8])', 0), + 'PubChemFP453': ('[#7](-,:[#6])(=,:[#6])', 0), + 'PubChemFP454': ('[#7](-,:[#6])(=,:[#8])', 0), + 'PubChemFP455': ('[#7](-,:[#8])(=,:[#8])', 0), + 'PubChemFP456': ('[#15](-,:[#8])(=,:[#8])', 0), + 'PubChemFP457': ('[#16](-,:[#6])(=,:[#8])', 0), + 'PubChemFP458': ('[#16](-,:[#8])(=,:[#8])', 0), + 'PubChemFP459': ('[#16](=,:[#8])(=,:[#8])', 0), + 'PubChemFP460': ('[#6]-,:[#6]-,:[#6]#[#6]', 0), + 'PubChemFP461': ('[#8]-,:[#6]-,:[#6]=,:[#7]', 0), + 'PubChemFP462': ('[#8]-,:[#6]-,:[#6]=,:[#8]', 0), + 'PubChemFP463': ('[#7]:[#6]-,:[#16&!H0]', 0), + 'PubChemFP464': ('[#7]-,:[#6]-,:[#6]=,:[#6]', 0), + 'PubChemFP465': ('[#8]=,:[#16]-,:[#6]-,:[#6]', 0), + 'PubChemFP466': ('[#7]#[#6]-,:[#6]=,:[#6]', 0), + 'PubChemFP467': ('[#6]=,:[#7]-,:[#7]-,:[#6]', 0), + 'PubChemFP468': ('[#8]=,:[#16]-,:[#6]-,:[#7]', 0), + 'PubChemFP469': ('[#16]-,:[#16]-,:[#6]:[#6]', 0), + 'PubChemFP470': ('[#6]:[#6]-,:[#6]=,:[#6]', 0), + 'PubChemFP471': ('[#16]:[#6]:[#6]:[#6]', 0), + 'PubChemFP472': ('[#6]:[#7]:[#6]-,:[#6]', 0), + 'PubChemFP473': ('[#16]-,:[#6]:[#7]:[#6]', 0), + 'PubChemFP474': ('[#16]:[#6]:[#6]:[#7]', 0), + 'PubChemFP475': ('[#16]-,:[#6]=,:[#7]-,:[#6]', 0), + 'PubChemFP476': ('[#6]-,:[#8]-,:[#6]=,:[#6]', 0), + 'PubChemFP477': ('[#7]-,:[#7]-,:[#6]:[#6]', 0), + 'PubChemFP478': ('[#16]-,:[#6]=,:[#7&!H0]', 0), + 'PubChemFP479': ('[#16]-,:[#6]-,:[#16]-,:[#6]', 0), + 'PubChemFP480': ('[#6]:[#16]:[#6]-,:[#6]', 0), + 'PubChemFP481': ('[#8]-,:[#16]-,:[#6]:[#6]', 0), + 'PubChemFP482': ('[#6]:[#7]-,:[#6]:[#6]', 0), + 'PubChemFP483': ('[#7]-,:[#16]-,:[#6]:[#6]', 0), + 'PubChemFP484': ('[#7]-,:[#6]:[#7]:[#6]', 0), + 'PubChemFP485': ('[#7]:[#6]:[#6]:[#7]', 0), + 'PubChemFP486': ('[#7]-,:[#6]:[#7]:[#7]', 0), + 'PubChemFP487': ('[#7]-,:[#6]=,:[#7]-,:[#6]', 0), + 'PubChemFP488': ('[#7]-,:[#6]=,:[#7&!H0]', 0), + 'PubChemFP489': ('[#7]-,:[#6]-,:[#16]-,:[#6]', 0), + 'PubChemFP490': ('[#6]-,:[#6]-,:[#6]=,:[#6]', 0), + 'PubChemFP491': ('[#6]-,:[#7]:[#6&!H0]', 0), + 'PubChemFP492': ('[#7]-,:[#6]:[#8]:[#6]', 0), + 'PubChemFP493': ('[#8]=,:[#6]-,:[#6]:[#6]', 0), + 'PubChemFP494': ('[#8]=,:[#6]-,:[#6]:[#7]', 0), + 'PubChemFP495': ('[#6]-,:[#7]-,:[#6]:[#6]', 0), + 'PubChemFP496': ('[#7]:[#7]-,:[#6&!H0]', 0), + 'PubChemFP497': ('[#8]-,:[#6]:[#6]:[#7]', 0), + 'PubChemFP498': ('[#8]-,:[#6]=,:[#6]-,:[#6]', 0), + 'PubChemFP499': ('[#7]-,:[#6]:[#6]:[#7]', 0), + 'PubChemFP500': ('[#6]-,:[#16]-,:[#6]:[#6]', 0), + 'PubChemFP501': ('[Cl]-,:[#6]:[#6]-,:[#6]', 0), + 'PubChemFP502': ('[#7]-,:[#6]=,:[#6&!H0]', 0), + 'PubChemFP503': ('[Cl]-,:[#6]:[#6&!H0]', 0), + 'PubChemFP504': ('[#7]:[#6]:[#7]-,:[#6]', 0), + 'PubChemFP505': ('[Cl]-,:[#6]:[#6]-,:[#8]', 0), + 'PubChemFP506': ('[#6]-,:[#6]:[#7]:[#6]', 0), + 'PubChemFP507': ('[#6]-,:[#6]-,:[#16]-,:[#6]', 0), + 'PubChemFP508': ('[#16]=,:[#6]-,:[#7]-,:[#6]', 0), + 'PubChemFP509': ('[Br]-,:[#6]:[#6]-,:[#6]', 0), + 'PubChemFP510': ('[#7&!H0]-,:[#7&!H0]', 0), + 'PubChemFP511': ('[#16]=,:[#6]-,:[#7&!H0]', 0), + 'PubChemFP512': ('[#6]-,:[#33]-[#8&!H0]', 0), + 'PubChemFP513': ('[#16]:[#6]:[#6&!H0]', 0), + 'PubChemFP514': ('[#8]-,:[#7]-,:[#6]-,:[#6]', 0), + 'PubChemFP515': ('[#7]-,:[#7]-,:[#6]-,:[#6]', 0), + 'PubChemFP516': ('[#6H,#6H2,#6H3]=,:[#6H,#6H2,#6H3]', 0), + 'PubChemFP517': 
('[#7]-,:[#7]-,:[#6]-,:[#7]', 0), + 'PubChemFP518': ('[#8]=,:[#6]-,:[#7]-,:[#7]', 0), + 'PubChemFP519': ('[#7]=,:[#6]-,:[#7]-,:[#6]', 0), + 'PubChemFP520': ('[#6]=,:[#6]-,:[#6]:[#6]', 0), + 'PubChemFP521': ('[#6]:[#7]-,:[#6&!H0]', 0), + 'PubChemFP522': ('[#6]-,:[#7]-,:[#7&!H0]', 0), + 'PubChemFP523': ('[#7]:[#6]:[#6]-,:[#6]', 0), + 'PubChemFP524': ('[#6]-,:[#6]=,:[#6]-,:[#6]', 0), + 'PubChemFP525': ('[#33]-,:[#6]:[#6&!H0]', 0), + 'PubChemFP526': ('[Cl]-,:[#6]:[#6]-,:[Cl]', 0), + 'PubChemFP527': ('[#6]:[#6]:[#7&!H0]', 0), + 'PubChemFP528': ('[#7&!H0]-,:[#6&!H0]', 0), + 'PubChemFP529': ('[Cl]-,:[#6]-,:[#6]-,:[Cl]', 0), + 'PubChemFP530': ('[#7]:[#6]-,:[#6]:[#6]', 0), + 'PubChemFP531': ('[#16]-,:[#6]:[#6]-,:[#6]', 0), + 'PubChemFP532': ('[#16]-,:[#6]:[#6&!H0]', 0), + 'PubChemFP533': ('[#16]-,:[#6]:[#6]-,:[#7]', 0), + 'PubChemFP534': ('[#16]-,:[#6]:[#6]-,:[#8]', 0), + 'PubChemFP535': ('[#8]=,:[#6]-,:[#6]-,:[#6]', 0), + 'PubChemFP536': ('[#8]=,:[#6]-,:[#6]-,:[#7]', 0), + 'PubChemFP537': ('[#8]=,:[#6]-,:[#6]-,:[#8]', 0), + 'PubChemFP538': ('[#7]=,:[#6]-,:[#6]-,:[#6]', 0), + 'PubChemFP539': ('[#7]=,:[#6]-,:[#6&!H0]', 0), + 'PubChemFP540': ('[#6]-,:[#7]-,:[#6&!H0]', 0), + 'PubChemFP541': ('[#8]-,:[#6]:[#6]-,:[#6]', 0), + 'PubChemFP542': ('[#8]-,:[#6]:[#6&!H0]', 0), + 'PubChemFP543': ('[#8]-,:[#6]:[#6]-,:[#7]', 0), + 'PubChemFP544': ('[#8]-,:[#6]:[#6]-,:[#8]', 0), + 'PubChemFP545': ('[#7]-,:[#6]:[#6]-,:[#6]', 0), + 'PubChemFP546': ('[#7]-,:[#6]:[#6&!H0]', 0), + 'PubChemFP547': ('[#7]-,:[#6]:[#6]-,:[#7]', 0), + 'PubChemFP548': ('[#8]-,:[#6]-,:[#6]:[#6]', 0), + 'PubChemFP549': ('[#7]-,:[#6]-,:[#6]:[#6]', 0), + 'PubChemFP550': ('[Cl]-,:[#6]-,:[#6]-,:[#6]', 0), + 'PubChemFP551': ('[Cl]-,:[#6]-,:[#6]-,:[#8]', 0), + 'PubChemFP552': ('[#6]:[#6]-,:[#6]:[#6]', 0), + 'PubChemFP553': ('[#8]=,:[#6]-,:[#6]=,:[#6]', 0), + 'PubChemFP554': ('[Br]-,:[#6]-,:[#6]-,:[#6]', 0), + 'PubChemFP555': ('[#7]=,:[#6]-,:[#6]=,:[#6]', 0), + 'PubChemFP556': ('[#6]=,:[#6]-,:[#6]-,:[#6]', 0), + 'PubChemFP557': ('[#7]:[#6]-,:[#8&!H0]', 0), + 'PubChemFP558': ('[#8]=,:[#7]-,:c:c', 0), + 'PubChemFP559': ('[#8]-,:[#6]-,:[#7&!H0]', 0), + 'PubChemFP560': ('[#7]-,:[#6]-,:[#7]-,:[#6]', 0), + 'PubChemFP561': ('[Cl]-,:[#6]-,:[#6]=,:[#8]', 0), + 'PubChemFP562': ('[Br]-,:[#6]-,:[#6]=,:[#8]', 0), + 'PubChemFP563': ('[#8]-,:[#6]-,:[#8]-,:[#6]', 0), + 'PubChemFP564': ('[#6]=,:[#6]-,:[#6]=,:[#6]', 0), + 'PubChemFP565': ('[#6]:[#6]-,:[#8]-,:[#6]', 0), + 'PubChemFP566': ('[#8]-,:[#6]-,:[#6]-,:[#7]', 0), + 'PubChemFP567': ('[#8]-,:[#6]-,:[#6]-,:[#8]', 0), + 'PubChemFP568': ('N#[#6]-,:[#6]-,:[#6]', 0), + 'PubChemFP569': ('[#7]-,:[#6]-,:[#6]-,:[#7]', 0), + 'PubChemFP570': ('[#6]:[#6]-,:[#6]-,:[#6]', 0), + 'PubChemFP571': ('[#6&!H0]-,:[#8&!H0]', 0), + 'PubChemFP572': ('n:c:n:c', 0), + 'PubChemFP573': ('[#8]-,:[#6]-,:[#6]=,:[#6]', 0), + 'PubChemFP574': ('[#8]-,:[#6]-,:[#6]:[#6]-,:[#6]', 0), + 'PubChemFP575': ('[#8]-,:[#6]-,:[#6]:[#6]-,:[#8]', 0), + 'PubChemFP576': ('[#7]=,:[#6]-,:[#6]:[#6&!H0]', 0), + 'PubChemFP577': ('c:c-,:[#7]-,:c:c', 0), + 'PubChemFP578': ('[#6]-,:[#6]:[#6]-,:c:c', 0), + 'PubChemFP579': ('[#8]=,:[#6]-,:[#6]-,:[#6]-,:[#6]', 0), + 'PubChemFP580': ('[#8]=,:[#6]-,:[#6]-,:[#6]-,:[#7]', 0), + 'PubChemFP581': ('[#8]=,:[#6]-,:[#6]-,:[#6]-,:[#8]', 0), + 'PubChemFP582': ('[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]', 0), + 'PubChemFP583': ('[Cl]-,:[#6]:[#6]-,:[#8]-,:[#6]', 0), + 'PubChemFP584': ('c:c-,:[#6]=,:[#6]-,:[#6]', 0), + 'PubChemFP585': ('[#6]-,:[#6]:[#6]-,:[#7]-,:[#6]', 0), + 'PubChemFP586': ('[#6]-,:[#16]-,:[#6]-,:[#6]-,:[#6]', 0), + 
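+    # In these multi-atom patterns '~' is any bond, '-,:' a single or aromatic
+    # bond and '=,:' a double or aromatic bond (standard SMARTS alternation),
+    # so each key covers the aliphatic and aromatic variants alike.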
'PubChemFP587': ('[#7]-,:[#6]:[#6]-,:[#8&!H0]', 0), + 'PubChemFP588': ('[#8]=,:[#6]-,:[#6]-,:[#6]=,:[#8]', 0), + 'PubChemFP589': ('[#6]-,:[#6]:[#6]-,:[#8]-,:[#6]', 0), + 'PubChemFP590': ('[#6]-,:[#6]:[#6]-,:[#8&!H0]', 0), + 'PubChemFP591': ('[Cl]-,:[#6]-,:[#6]-,:[#6]-,:[#6]', 0), + 'PubChemFP592': ('[#7]-,:[#6]-,:[#6]-,:[#6]-,:[#6]', 0), + 'PubChemFP593': ('[#7]-,:[#6]-,:[#6]-,:[#6]-,:[#7]', 0), + 'PubChemFP594': ('[#6]-,:[#8]-,:[#6]-,:[#6]=,:[#6]', 0), + 'PubChemFP595': ('c:c-,:[#6]-,:[#6]-,:[#6]', 0), + 'PubChemFP596': ('[#7]=,:[#6]-,:[#7]-,:[#6]-,:[#6]', 0), + 'PubChemFP597': ('[#8]=,:[#6]-,:[#6]-,:c:c', 0), + 'PubChemFP598': ('[Cl]-,:[#6]:[#6]:[#6]-,:[#6]', 0), + 'PubChemFP599': ('[#6H,#6H2,#6H3]-,:[#6]=,:[#6H,#6H2,#6H3]', 0), + 'PubChemFP600': ('[#7]-,:[#6]:[#6]:[#6]-,:[#6]', 0), + 'PubChemFP601': ('[#7]-,:[#6]:[#6]:[#6]-,:[#7]', 0), + 'PubChemFP602': ('[#8]=,:[#6]-,:[#6]-,:[#7]-,:[#6]', 0), + 'PubChemFP603': ('[#6]-,:c:c:[#6]-,:[#6]', 0), + 'PubChemFP604': ('[#6]-,:[#8]-,:[#6]-,:[#6]:c', 0), + 'PubChemFP605': ('[#8]=,:[#6]-,:[#6]-,:[#8]-,:[#6]', 0), + 'PubChemFP606': ('[#8]-,:[#6]:[#6]-,:[#6]-,:[#6]', 0), + 'PubChemFP607': ('[#7]-,:[#6]-,:[#6]-,:[#6]:c', 0), + 'PubChemFP608': ('[#6]-,:[#6]-,:[#6]-,:[#6]:c', 0), + 'PubChemFP609': ('[Cl]-,:[#6]-,:[#6]-,:[#7]-,:[#6]', 0), + 'PubChemFP610': ('[#6]-,:[#8]-,:[#6]-,:[#8]-,:[#6]', 0), + 'PubChemFP611': ('[#7]-,:[#6]-,:[#6]-,:[#7]-,:[#6]', 0), + 'PubChemFP612': ('[#7]-,:[#6]-,:[#8]-,:[#6]-,:[#6]', 0), + 'PubChemFP613': ('[#6]-,:[#7]-,:[#6]-,:[#6]-,:[#6]', 0), + 'PubChemFP614': ('[#6]-,:[#6]-,:[#8]-,:[#6]-,:[#6]', 0), + 'PubChemFP615': ('[#7]-,:[#6]-,:[#6]-,:[#8]-,:[#6]', 0), + 'PubChemFP616': ('c:c:n:n:c', 0), + 'PubChemFP617': ('[#6]-,:[#6]-,:[#6]-,:[#8&!H0]', 0), + 'PubChemFP618': ('c:[#6]-,:[#6]-,:[#6]:c', 0), + 'PubChemFP619': ('[#8]-,:[#6]-,:[#6]=,:[#6]-,:[#6]', 0), + 'PubChemFP620': ('c:c-,:[#8]-,:[#6]-,:[#6]', 0), + 'PubChemFP621': ('[#7]-,:[#6]:c:c:n', 0), + 'PubChemFP622': ('[#8]=,:[#6]-,:[#8]-,:[#6]:c', 0), + 'PubChemFP623': ('[#8]=,:[#6]-,:[#6]:[#6]-,:[#6]', 0), + 'PubChemFP624': ('[#8]=,:[#6]-,:[#6]:[#6]-,:[#7]', 0), + 'PubChemFP625': ('[#8]=,:[#6]-,:[#6]:[#6]-,:[#8]', 0), + 'PubChemFP626': ('[#6]-,:[#8]-,:[#6]:[#6]-,:[#6]', 0), + 'PubChemFP627': ('[#8]=,:[#33]-,:[#6]:c:c', 0), + 'PubChemFP628': ('[#6]-,:[#7]-,:[#6]-,:[#6]:c', 0), + 'PubChemFP629': ('[#16]-,:[#6]:c:c-,:[#7]', 0), + 'PubChemFP630': ('[#8]-,:[#6]:[#6]-,:[#8]-,:[#6]', 0), + 'PubChemFP631': ('[#8]-,:[#6]:[#6]-,:[#8&!H0]', 0), + 'PubChemFP632': ('[#6]-,:[#6]-,:[#8]-,:[#6]:c', 0), + 'PubChemFP633': ('[#7]-,:[#6]-,:[#6]:[#6]-,:[#6]', 0), + 'PubChemFP634': ('[#6]-,:[#6]-,:[#6]:[#6]-,:[#6]', 0), + 'PubChemFP635': ('[#7]-,:[#7]-,:[#6]-,:[#7&!H0]', 0), + 'PubChemFP636': ('[#6]-,:[#7]-,:[#6]-,:[#7]-,:[#6]', 0), + 'PubChemFP637': ('[#8]-,:[#6]-,:[#6]-,:[#6]-,:[#6]', 0), + 'PubChemFP638': ('[#8]-,:[#6]-,:[#6]-,:[#6]-,:[#7]', 0), + 'PubChemFP639': ('[#8]-,:[#6]-,:[#6]-,:[#6]-,:[#8]', 0), + 'PubChemFP640': ('[#6]=,:[#6]-,:[#6]-,:[#6]-,:[#6]', 0), + 'PubChemFP641': ('[#8]-,:[#6]-,:[#6]-,:[#6]=,:[#6]', 0), + 'PubChemFP642': ('[#8]-,:[#6]-,:[#6]-,:[#6]=,:[#8]', 0), + 'PubChemFP643': ('[#6&!H0]-,:[#6]-,:[#7&!H0]', 0), + 'PubChemFP644': ('[#6]-,:[#6]=,:[#7]-,:[#7]-,:[#6]', 0), + 'PubChemFP645': ('[#8]=,:[#6]-,:[#7]-,:[#6]-,:[#6]', 0), + 'PubChemFP646': ('[#8]=,:[#6]-,:[#7]-,:[#6&!H0]', 0), + 'PubChemFP647': ('[#8]=,:[#6]-,:[#7]-,:[#6]-,:[#7]', 0), + 'PubChemFP648': ('[#8]=,:[#7]-,:[#6]:[#6]-,:[#7]', 0), + 'PubChemFP649': ('[#8]=,:[#7]-,:c:c-,:[#8]', 0), + 'PubChemFP650': 
('[#8]=,:[#6]-,:[#7]-,:[#6]=,:[#8]', 0), + 'PubChemFP651': ('[#8]-,:[#6]:[#6]:[#6]-,:[#6]', 0), + 'PubChemFP652': ('[#8]-,:[#6]:[#6]:[#6]-,:[#7]', 0), + 'PubChemFP653': ('[#8]-,:[#6]:[#6]:[#6]-,:[#8]', 0), + 'PubChemFP654': ('[#7]-,:[#6]-,:[#7]-,:[#6]-,:[#6]', 0), + 'PubChemFP655': ('[#8]-,:[#6]-,:[#6]-,:[#6]:c', 0), + 'PubChemFP656': ('[#6]-,:[#6]-,:[#7]-,:[#6]-,:[#6]', 0), + 'PubChemFP657': ('[#6]-,:[#7]-,:[#6]:[#6]-,:[#6]', 0), + 'PubChemFP658': ('[#6]-,:[#6]-,:[#16]-,:[#6]-,:[#6]', 0), + 'PubChemFP659': ('[#8]-,:[#6]-,:[#6]-,:[#7]-,:[#6]', 0), + 'PubChemFP660': ('[#6]-,:[#6]=,:[#6]-,:[#6]-,:[#6]', 0), + 'PubChemFP661': ('[#8]-,:[#6]-,:[#8]-,:[#6]-,:[#6]', 0), + 'PubChemFP662': ('[#8]-,:[#6]-,:[#6]-,:[#8]-,:[#6]', 0), + 'PubChemFP663': ('[#8]-,:[#6]-,:[#6]-,:[#8&!H0]', 0), + 'PubChemFP664': ('[#6]-,:[#6]=,:[#6]-,:[#6]=,:[#6]', 0), + 'PubChemFP665': ('[#7]-,:[#6]:[#6]-,:[#6]-,:[#6]', 0), + 'PubChemFP666': ('[#6]=,:[#6]-,:[#6]-,:[#8]-,:[#6]', 0), + 'PubChemFP667': ('[#6]=,:[#6]-,:[#6]-,:[#8&!H0]', 0), + 'PubChemFP668': ('[#6]-,:[#6]:[#6]-,:[#6]-,:[#6]', 0), + 'PubChemFP669': ('[Cl]-,:[#6]:[#6]-,:[#6]=,:[#8]', 0), + 'PubChemFP670': ('[Br]-,:[#6]:c:c-,:[#6]', 0), + 'PubChemFP671': ('[#8]=,:[#6]-,:[#6]=,:[#6]-,:[#6]', 0), + 'PubChemFP672': ('[#8]=,:[#6]-,:[#6]=,:[#6&!H0]', 0), + 'PubChemFP673': ('[#8]=,:[#6]-,:[#6]=,:[#6]-,:[#7]', 0), + 'PubChemFP674': ('[#7]-,:[#6]-,:[#7]-,:[#6]:c', 0), + 'PubChemFP675': ('[Br]-,:[#6]-,:[#6]-,:[#6]:c', 0), + 'PubChemFP676': ('[#7]#[#6]-,:[#6]-,:[#6]-,:[#6]', 0), + 'PubChemFP677': ('[#6]-,:[#6]=,:[#6]-,:[#6]:c', 0), + 'PubChemFP678': ('[#6]-,:[#6]-,:[#6]=,:[#6]-,:[#6]', 0), + 'PubChemFP679': ('[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]', 0), + 'PubChemFP680': ('[#8]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]', 0), + 'PubChemFP681': ('[#8]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#8]', 0), + 'PubChemFP682': ('[#8]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#7]', 0), + 'PubChemFP683': ('[#7]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]', 0), + 'PubChemFP684': ('[#8]=,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]', 0), + 'PubChemFP685': ('[#8]=,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#7]', 0), + 'PubChemFP686': ('[#8]=,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#8]', 0), + 'PubChemFP687': ('[#8]=,:[#6]-,:[#6]-,:[#6]-,:[#6]=,:[#8]', 0), + 'PubChemFP688': ('[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]', 0), + 'PubChemFP689': ('[#8]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]', 0), + 'PubChemFP690': ('[#8]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#8]', 0), + 'PubChemFP691': ('[#8]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#7]', 0), + 'PubChemFP692': ('[#8]=,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]', 0), + 'PubChemFP693': ('[#8]=,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#8]', 0), + 'PubChemFP694': ('[#8]=,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]=,:[#8]', 0), + 'PubChemFP695': ('[#8]=,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#7]', 0), + 'PubChemFP696': ('[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]', 0), + 'PubChemFP697': ('[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6](-,:[#6])-,:[#6]', 0), + 'PubChemFP698': ('[#8]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]', 0), + 'PubChemFP699': ('[#8]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6](-,:[#6])-,:[#6]', 0), + 'PubChemFP700': ('[#8]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#8]-,:[#6]', 0), + 'PubChemFP701': ('[#8]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6](-,:[#8])-,:[#6]', 0), + 'PubChemFP702': ('[#8]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#7]-,:[#6]', 0), + 'PubChemFP703': ('[#8]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6](-,:[#7])-,:[#6]', 0), + 'PubChemFP704': ('[#8]=,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6]', 0), + 
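+    # After the last of the carbonyl chains (PubChemFP705-707), the keys turn
+    # to branched-carbon motifs (PubChemFP708-712) and then to doubly
+    # substituted ring templates (PubChemFP713 onward).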
'PubChemFP705': ('[#8]=,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6](-,:[#8])-,:[#6]', 0), + 'PubChemFP706': ('[#8]=,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6](=,:[#8])-,:[#6]', 0), + 'PubChemFP707': ('[#8]=,:[#6]-,:[#6]-,:[#6]-,:[#6]-,:[#6](-,:[#7])-,:[#6]', 0), + 'PubChemFP708': ('[#6]-,:[#6](-,:[#6])-,:[#6]-,:[#6]', 0), + 'PubChemFP709': ('[#6]-,:[#6](-,:[#6])-,:[#6]-,:[#6]-,:[#6]', 0), + 'PubChemFP710': ('[#6]-,:[#6]-,:[#6](-,:[#6])-,:[#6]-,:[#6]', 0), + 'PubChemFP711': ('[#6]-,:[#6](-,:[#6])(-,:[#6])-,:[#6]-,:[#6]', 0), + 'PubChemFP712': ('[#6]-,:[#6](-,:[#6])-,:[#6](-,:[#6])-,:[#6]', 0), + 'PubChemFP713': ('[#6]c1ccc([#6])cc1', 0), + 'PubChemFP714': ('[#6]c1ccc([#8])cc1', 0), + 'PubChemFP715': ('[#6]c1ccc([#16])cc1', 0), + 'PubChemFP716': ('[#6]c1ccc([#7])cc1', 0), + 'PubChemFP717': ('[#6]c1ccc(Cl)cc1', 0), + 'PubChemFP718': ('[#6]c1ccc(Br)cc1', 0), + 'PubChemFP719': ('[#8]c1ccc([#8])cc1', 0), + 'PubChemFP720': ('[#8]c1ccc([#16])cc1', 0), + 'PubChemFP721': ('[#8]c1ccc([#7])cc1', 0), + 'PubChemFP722': ('[#8]c1ccc(Cl)cc1', 0), + 'PubChemFP723': ('[#8]c1ccc(Br)cc1', 0), + 'PubChemFP724': ('[#16]c1ccc([#16])cc1', 0), + 'PubChemFP725': ('[#16]c1ccc([#7])cc1', 0), + 'PubChemFP726': ('[#16]c1ccc(Cl)cc1', 0), + 'PubChemFP727': ('[#16]c1ccc(Br)cc1', 0), + 'PubChemFP728': ('[#7]c1ccc([#7])cc1', 0), + 'PubChemFP729': ('[#7]c1ccc(Cl)cc1', 0), + 'PubChemFP730': ('[#7]c1ccc(Br)cc1', 0), + 'PubChemFP731': ('Clc1ccc(Cl)cc1', 0), + 'PubChemFP732': ('Clc1ccc(Br)cc1', 0), + 'PubChemFP733': ('Brc1ccc(Br)cc1', 0), + 'PubChemFP734': ('[#6]c1cc([#6])ccc1', 0), + 'PubChemFP735': ('[#6]c1cc([#8])ccc1', 0), + 'PubChemFP736': ('[#6]c1cc([#16])ccc1', 0), + 'PubChemFP737': ('[#6]c1cc([#7])ccc1', 0), + 'PubChemFP738': ('[#6]c1cc(Cl)ccc1', 0), + 'PubChemFP739': ('[#6]c1cc(Br)ccc1', 0), + 'PubChemFP740': ('[#8]c1cc([#8])ccc1', 0), + 'PubChemFP741': ('[#8]c1cc([#16])ccc1', 0), + 'PubChemFP742': ('[#8]c1cc([#7])ccc1', 0), + 'PubChemFP743': ('[#8]c1cc(Cl)ccc1', 0), + 'PubChemFP744': ('[#8]c1cc(Br)ccc1', 0), + 'PubChemFP745': ('[#16]c1cc([#16])ccc1', 0), + 'PubChemFP746': ('[#16]c1cc([#7])ccc1', 0), + 'PubChemFP747': ('[#16]c1cc(Cl)ccc1', 0), + 'PubChemFP748': ('[#16]c1cc(Br)ccc1', 0), + 'PubChemFP749': ('[#7]c1cc([#7])ccc1', 0), + 'PubChemFP750': ('[#7]c1cc(Cl)ccc1', 0), + 'PubChemFP751': ('[#7]c1cc(Br)ccc1', 0), + 'PubChemFP752': ('Clc1cc(Cl)ccc1', 0), + 'PubChemFP753': ('Clc1cc(Br)ccc1', 0), + 'PubChemFP754': ('Brc1cc(Br)ccc1', 0), + 'PubChemFP755': ('[#6]c1c([#6])cccc1', 0), + 'PubChemFP756': ('[#6]c1c([#8])cccc1', 0), + 'PubChemFP757': ('[#6]c1c([#16])cccc1', 0), + 'PubChemFP758': ('[#6]c1c([#7])cccc1', 0), + 'PubChemFP759': ('[#6]c1c(Cl)cccc1', 0), + 'PubChemFP760': ('[#6]c1c(Br)cccc1', 0), + 'PubChemFP761': ('[#8]c1c([#8])cccc1', 0), + 'PubChemFP762': ('[#8]c1c([#16])cccc1', 0), + 'PubChemFP763': ('[#8]c1c([#7])cccc1', 0), + 'PubChemFP764': ('[#8]c1c(Cl)cccc1', 0), + 'PubChemFP765': ('[#8]c1c(Br)cccc1', 0), + 'PubChemFP766': ('[#16]c1c([#16])cccc1', 0), + 'PubChemFP767': ('[#16]c1c([#7])cccc1', 0), + 'PubChemFP768': ('[#16]c1c(Cl)cccc1', 0), + 'PubChemFP769': ('[#16]c1c(Br)cccc1', 0), + 'PubChemFP770': ('[#7]c1c([#7])cccc1', 0), + 'PubChemFP771': ('[#7]c1c(Cl)cccc1', 0), + 'PubChemFP772': ('[#7]c1c(Br)cccc1', 0), + 'PubChemFP773': ('Clc1c(Cl)cccc1', 0), + 'PubChemFP774': ('Clc1c(Br)cccc1', 0), + 'PubChemFP775': ('Brc1c(Br)cccc1', 0), + 'PubChemFP776': ('[#6][#6]1[#6][#6][#6]([#6])[#6][#6]1', 0), + 'PubChemFP777': ('[#6][#6]1[#6][#6][#6]([#8])[#6][#6]1', 0), + 'PubChemFP778': ('[#6][#6]1[#6][#6][#6]([#16])[#6][#6]1', 0), + 
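+    # The [#6]1...[#6]1 ring templates spell the ring out atom by atom with
+    # the SMARTS default bond (single or aromatic), so each key matches both
+    # the saturated and the aromatic form of the same disubstituted ring.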
'PubChemFP779': ('[#6][#6]1[#6][#6][#6]([#7])[#6][#6]1', 0), + 'PubChemFP780': ('[#6][#6]1[#6][#6][#6](Cl)[#6][#6]1', 0), + 'PubChemFP781': ('[#6][#6]1[#6][#6][#6](Br)[#6][#6]1', 0), + 'PubChemFP782': ('[#8][#6]1[#6][#6][#6]([#8])[#6][#6]1', 0), + 'PubChemFP783': ('[#8][#6]1[#6][#6][#6]([#16])[#6][#6]1', 0), + 'PubChemFP784': ('[#8][#6]1[#6][#6][#6]([#7])[#6][#6]1', 0), + 'PubChemFP785': ('[#8][#6]1[#6][#6][#6](Cl)[#6][#6]1', 0), + 'PubChemFP786': ('[#8][#6]1[#6][#6][#6](Br)[#6][#6]1', 0), + 'PubChemFP787': ('[#16][#6]1[#6][#6][#6]([#16])[#6][#6]1', 0), + 'PubChemFP788': ('[#16][#6]1[#6][#6][#6]([#7])[#6][#6]1', 0), + 'PubChemFP789': ('[#16][#6]1[#6][#6][#6](Cl)[#6][#6]1', 0), + 'PubChemFP790': ('[#16][#6]1[#6][#6][#6](Br)[#6][#6]1', 0), + 'PubChemFP791': ('[#7][#6]1[#6][#6][#6]([#7])[#6][#6]1', 0), + 'PubChemFP792': ('[#7][#6]1[#6][#6][#6](Cl)[#6][#6]1', 0), + 'PubChemFP793': ('[#7][#6]1[#6][#6][#6](Br)[#6][#6]1', 0), + 'PubChemFP794': ('Cl[#6]1[#6][#6][#6](Cl)[#6][#6]1', 0), + 'PubChemFP795': ('Cl[#6]1[#6][#6][#6](Br)[#6][#6]1', 0), + 'PubChemFP796': ('Br[#6]1[#6][#6][#6](Br)[#6][#6]1', 0), + 'PubChemFP797': ('[#6][#6]1[#6][#6]([#6])[#6][#6][#6]1', 0), + 'PubChemFP798': ('[#6][#6]1[#6][#6]([#8])[#6][#6][#6]1', 0), + 'PubChemFP799': ('[#6][#6]1[#6][#6]([#16])[#6][#6][#6]1', 0), + 'PubChemFP800': ('[#6][#6]1[#6][#6]([#7])[#6][#6][#6]1', 0), + 'PubChemFP801': ('[#6][#6]1[#6][#6](Cl)[#6][#6][#6]1', 0), + 'PubChemFP802': ('[#6][#6]1[#6][#6](Br)[#6][#6][#6]1', 0), + 'PubChemFP803': ('[#8][#6]1[#6][#6]([#8])[#6][#6][#6]1', 0), + 'PubChemFP804': ('[#8][#6]1[#6][#6]([#16])[#6][#6][#6]1', 0), + 'PubChemFP805': ('[#8][#6]1[#6][#6]([#7])[#6][#6][#6]1', 0), + 'PubChemFP806': ('[#8][#6]1[#6][#6](Cl)[#6][#6][#6]1', 0), + 'PubChemFP807': ('[#8][#6]1[#6][#6](Br)[#6][#6][#6]1', 0), + 'PubChemFP808': ('[#16][#6]1[#6][#6]([#16])[#6][#6][#6]1', 0), + 'PubChemFP809': ('[#16][#6]1[#6][#6]([#7])[#6][#6][#6]1', 0), + 'PubChemFP810': ('[#16][#6]1[#6][#6](Cl)[#6][#6][#6]1', 0), + 'PubChemFP811': ('[#16][#6]1[#6][#6](Br)[#6][#6][#6]1', 0), + 'PubChemFP812': ('[#7][#6]1[#6][#6]([#7])[#6][#6][#6]1', 0), + 'PubChemFP813': ('[#7][#6]1[#6][#6](Cl)[#6][#6][#6]1', 0), + 'PubChemFP814': ('[#7][#6]1[#6][#6](Br)[#6][#6][#6]1', 0), + 'PubChemFP815': ('Cl[#6]1[#6][#6](Cl)[#6][#6][#6]1', 0), + 'PubChemFP816': ('Cl[#6]1[#6][#6](Br)[#6][#6][#6]1', 0), + 'PubChemFP817': ('Br[#6]1[#6][#6](Br)[#6][#6][#6]1', 0), + 'PubChemFP818': ('[#6][#6]1[#6]([#6])[#6][#6][#6][#6]1', 0), + 'PubChemFP819': ('[#6][#6]1[#6]([#8])[#6][#6][#6][#6]1', 0), + 'PubChemFP820': ('[#6][#6]1[#6]([#16])[#6][#6][#6][#6]1', 0), + 'PubChemFP821': ('[#6][#6]1[#6]([#7])[#6][#6][#6][#6]1', 0), + 'PubChemFP822': ('[#6][#6]1[#6](Cl)[#6][#6][#6][#6]1', 0), + 'PubChemFP823': ('[#6][#6]1[#6](Br)[#6][#6][#6][#6]1', 0), + 'PubChemFP824': ('[#8][#6]1[#6]([#8])[#6][#6][#6][#6]1', 0), + 'PubChemFP825': ('[#8][#6]1[#6]([#16])[#6][#6][#6][#6]1', 0), + 'PubChemFP826': ('[#8][#6]1[#6]([#7])[#6][#6][#6][#6]1', 0), + 'PubChemFP827': ('[#8][#6]1[#6](Cl)[#6][#6][#6][#6]1', 0), + 'PubChemFP828': ('[#8][#6]1[#6](Br)[#6][#6][#6][#6]1', 0), + 'PubChemFP829': ('[#16][#6]1[#6]([#16])[#6][#6][#6][#6]1', 0), + 'PubChemFP830': ('[#16][#6]1[#6]([#7])[#6][#6][#6][#6]1', 0), + 'PubChemFP831': ('[#16][#6]1[#6](Cl)[#6][#6][#6][#6]1', 0), + 'PubChemFP832': ('[#16][#6]1[#6](Br)[#6][#6][#6][#6]1', 0), + 'PubChemFP833': ('[#7][#6]1[#6]([#7])[#6][#6][#6][#6]1', 0), + 'PubChemFP834': ('[#7][#6]1[#6](Cl)[#6][#6][#6][#6]1', 0), + 'PubChemFP835': ('[#7][#6]1[#6](Br)[#6][#6][#6][#6]1', 0), + 'PubChemFP836': 
('Cl[#6]1[#6](Cl)[#6][#6][#6][#6]1', 0),
+    'PubChemFP837': ('Cl[#6]1[#6](Br)[#6][#6][#6][#6]1', 0),
+    'PubChemFP838': ('Br[#6]1[#6](Br)[#6][#6][#6][#6]1', 0),
+    'PubChemFP839': ('[#6][#6]1[#6][#6]([#6])[#6][#6]1', 0),
+    'PubChemFP840': ('[#6][#6]1[#6][#6]([#8])[#6][#6]1', 0),
+    'PubChemFP841': ('[#6][#6]1[#6][#6]([#16])[#6][#6]1', 0),
+    'PubChemFP842': ('[#6][#6]1[#6][#6]([#7])[#6][#6]1', 0),
+    'PubChemFP843': ('[#6][#6]1[#6][#6](Cl)[#6][#6]1', 0),
+    'PubChemFP844': ('[#6][#6]1[#6][#6](Br)[#6][#6]1', 0),
+    'PubChemFP845': ('[#8][#6]1[#6][#6]([#8])[#6][#6]1', 0),
+    'PubChemFP846': ('[#8][#6]1[#6][#6]([#16])[#6][#6]1', 0),
+    'PubChemFP847': ('[#8][#6]1[#6][#6]([#7])[#6][#6]1', 0),
+    'PubChemFP848': ('[#8][#6]1[#6][#6](Cl)[#6][#6]1', 0),
+    'PubChemFP849': ('[#8][#6]1[#6][#6](Br)[#6][#6]1', 0),
+    'PubChemFP850': ('[#16][#6]1[#6][#6]([#16])[#6][#6]1', 0),
+    'PubChemFP851': ('[#16][#6]1[#6][#6]([#7])[#6][#6]1', 0),
+    'PubChemFP852': ('[#16][#6]1[#6][#6](Cl)[#6][#6]1', 0),
+    'PubChemFP853': ('[#16][#6]1[#6][#6](Br)[#6][#6]1', 0),
+    'PubChemFP854': ('[#7][#6]1[#6][#6]([#7])[#6][#6]1', 0),
+    'PubChemFP855': ('[#7][#6]1[#6][#6](Cl)[#6][#6]1', 0),
+    'PubChemFP856': ('[#7][#6]1[#6][#6](Br)[#6][#6]1', 0),
+    'PubChemFP857': ('Cl[#6]1[#6][#6](Cl)[#6][#6]1', 0),
+    'PubChemFP858': ('Cl[#6]1[#6][#6](Br)[#6][#6]1', 0),
+    'PubChemFP859': ('Br[#6]1[#6][#6](Br)[#6][#6]1', 0),
+    'PubChemFP860': ('[#6][#6]1[#6]([#6])[#6][#6][#6]1', 0),
+    'PubChemFP861': ('[#6][#6]1[#6]([#8])[#6][#6][#6]1', 0),
+    'PubChemFP862': ('[#6][#6]1[#6]([#16])[#6][#6][#6]1', 0),
+    'PubChemFP863': ('[#6][#6]1[#6]([#7])[#6][#6][#6]1', 0),
+    'PubChemFP864': ('[#6][#6]1[#6](Cl)[#6][#6][#6]1', 0),
+    'PubChemFP865': ('[#6][#6]1[#6](Br)[#6][#6][#6]1', 0),
+    'PubChemFP866': ('[#8][#6]1[#6]([#8])[#6][#6][#6]1', 0),
+    'PubChemFP867': ('[#8][#6]1[#6]([#16])[#6][#6][#6]1', 0),
+    'PubChemFP868': ('[#8][#6]1[#6]([#7])[#6][#6][#6]1', 0),
+    'PubChemFP869': ('[#8][#6]1[#6](Cl)[#6][#6][#6]1', 0),
+    'PubChemFP870': ('[#8][#6]1[#6](Br)[#6][#6][#6]1', 0),
+    'PubChemFP871': ('[#16][#6]1[#6]([#16])[#6][#6][#6]1', 0),
+    'PubChemFP872': ('[#16][#6]1[#6]([#7])[#6][#6][#6]1', 0),
+    'PubChemFP873': ('[#16][#6]1[#6](Cl)[#6][#6][#6]1', 0),
+    'PubChemFP874': ('[#16][#6]1[#6](Br)[#6][#6][#6]1', 0),
+    'PubChemFP875': ('[#7][#6]1[#6]([#7])[#6][#6][#6]1', 0),
+    # was '[#7][#6]1[#6](Cl)[#6][#6]1' (a four-membered ring); every other key
+    # in this 1,2-disubstituted block is a five-membered ring
+    'PubChemFP876': ('[#7][#6]1[#6](Cl)[#6][#6][#6]1', 0),
+    'PubChemFP877': ('[#7][#6]1[#6](Br)[#6][#6][#6]1', 0),
+    'PubChemFP878': ('Cl[#6]1[#6](Cl)[#6][#6][#6]1', 0),
+    'PubChemFP879': ('Cl[#6]1[#6](Br)[#6][#6][#6]1', 0),
+    'PubChemFP880': ('Br[#6]1[#6](Br)[#6][#6][#6]1', 0)}
diff --git a/deepscreen/data/featurizers/fingerprint/torsions.py b/deepscreen/data/featurizers/fingerprint/torsions.py
new file mode 100644
index 0000000000000000000000000000000000000000..d594832f48b3cc172b7b48a348f4a4e680db8eeb
--- /dev/null
+++ b/deepscreen/data/featurizers/fingerprint/torsions.py
@@ -0,0 +1,18 @@
+# NOTE: rdkit.Chem.AtomPairs.Torsions is the legacy API; newer RDKit releases
+# expose the same fingerprint through rdFingerprintGenerator
+from rdkit.Chem.AtomPairs import Torsions
+from rdkit.Chem import DataStructs
+import numpy as np
+
+_type = 'topological-based'
+
+
+def GetTorsionFPs(mol, nBits=2048, binary=True):
+    '''
+    hashed topological torsion fingerprints
+    '''
+    fp = Torsions.GetHashedTopologicalTorsionFingerprint(mol, nBits=nBits)
+    if binary:
+        arr = np.zeros((0,), dtype=np.bool_)
+    else:
+        arr = np.zeros((0,), dtype=np.int8)
+    DataStructs.ConvertToNumpyArray(fp, arr)
+    return arr
diff --git a/deepscreen/data/featurizers/graph.py b/deepscreen/data/featurizers/graph.py
new file mode 100644
index 
0000000000000000000000000000000000000000..97ac6ffb9fdf06c276e0297d6bebeb2d7ae47260 --- /dev/null +++ b/deepscreen/data/featurizers/graph.py @@ -0,0 +1,133 @@ +import networkx as nx +import numpy as np +import torch +from rdkit import Chem +from torch_geometric.utils import from_smiles +from torch_geometric.data import Data + +from deepscreen.data.featurizers.categorical import one_of_k_encoding_unk, one_of_k_encoding +from deepscreen.utils import get_logger + +log = get_logger(__name__) + + +def atom_features(atom, explicit_H=False, use_chirality=True): + """ + Adapted from TransformerCPI 2.0 + """ + symbol = ['C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br', 'I', 'other'] # 10-dim + degree = [0, 1, 2, 3, 4, 5, 6] # 7-dim + hybridization_type = [Chem.rdchem.HybridizationType.SP, + Chem.rdchem.HybridizationType.SP2, + Chem.rdchem.HybridizationType.SP3, + Chem.rdchem.HybridizationType.SP3D, + Chem.rdchem.HybridizationType.SP3D2, + 'other'] # 6-dim + + # 10+7+2+6+1=26 + results = one_of_k_encoding_unk(atom.GetSymbol(), symbol) + \ + one_of_k_encoding(atom.GetDegree(), degree) + \ + [atom.GetFormalCharge(), atom.GetNumRadicalElectrons()] + \ + one_of_k_encoding_unk(atom.GetHybridization(), hybridization_type) + [atom.GetIsAromatic()] + + # In case of explicit hydrogen(QM8, QM9), avoid calling `GetTotalNumHs` + # 26+5=31 + if not explicit_H: + results = results + one_of_k_encoding_unk(atom.GetTotalNumHs(), + [0, 1, 2, 3, 4]) + # 31+3=34 + if use_chirality: + try: + results = results + one_of_k_encoding_unk( + atom.GetProp('_CIPCode'), + ['R', 'S']) + [atom.HasProp('_ChiralityPossible')] + except: + results = results + [False, False] + [atom.HasProp('_ChiralityPossible')] + + return np.array(results) + + +def bond_features(bond): + bt = bond.GetBondType() + return np.array( + [bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE, bt == Chem.rdchem.BondType.TRIPLE, + bt == Chem.rdchem.BondType.AROMATIC, bond.GetIsConjugated(), bond.IsInRing()]) + + +def smiles_to_graph_pyg(smiles): + """ + Convert SMILES to graph with the default method defined by PyTorch Geometric + """ + try: + return from_smiles(smiles) + except Exception as e: + log.warning(f"Failed to featurize the following SMILES to graph: {smiles} due to {str(e)}") + return None + + +def smiles_to_graph(smiles, atom_features: callable = atom_features): + """ + Convert SMILES to graph with custom atom_features + """ + try: + mol = Chem.MolFromSmiles(smiles) + + features = [] + for atom in mol.GetAtoms(): + feature = atom_features(atom) + features.append(feature / sum(feature)) + features = np.array(features) + + edges = [] + for bond in mol.GetBonds(): + edges.append([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()]) + g = nx.Graph(edges).to_directed() + + if len(edges) == 0: + edge_index = [[0, 0]] + else: + edge_index = [] + for e1, e2 in g.edges: + edge_index.append([e1, e2]) + + return Data(x=torch.Tensor(features), + edge_index=torch.LongTensor(edge_index).transpose(0, 1)) + + except Exception as e: + log.warning(f"Failed to convert SMILES ({smiles}) to graph due to {str(e)}") + return None + # features = [] + # for atom in mol.GetAtoms(): + # feature = atom_features(atom) + # features.append(feature / sum(feature)) + # + # edge_indices = [] + # for bond in mol.GetBonds(): + # i = bond.GetBeginAtomIdx() + # j = bond.GetEndAtomIdx() + # edge_indices += [[i, j], [j, i]] + # + # edge_index = torch.tensor(edge_indices) + # edge_index = edge_index.t().to(torch.long).view(2, -1) + # + # if edge_index.numel() > 0: # Sort indices. 
+    #         perm = (edge_index[0] * x.size(0) + edge_index[1]).argsort()
+    #         edge_index = edge_index[:, perm]
+    #
+
+
+def smiles_to_mol_features(smiles, atom_features: callable = atom_features):
+    """
+    Convert SMILES to an (atom feature matrix, adjacency matrix) pair.
+    """
+    try:
+        mol = Chem.MolFromSmiles(smiles)
+        num_atom_feat = len(atom_features(mol.GetAtomWithIdx(0)))
+        atom_feat = np.zeros((mol.GetNumAtoms(), num_atom_feat))
+        for atom in mol.GetAtoms():
+            atom_feat[atom.GetIdx(), :] = atom_features(atom)
+        adj = Chem.GetAdjacencyMatrix(mol)
+        adj_mat = np.array(adj)
+
+        return atom_feat, adj_mat
+
+    except Exception as e:
+        log.warning(f"Failed to featurize the following SMILES to molecular features: {smiles} due to {str(e)}")
+        return None
\ No newline at end of file
diff --git a/deepscreen/data/featurizers/monn.py b/deepscreen/data/featurizers/monn.py
new file mode 100644
index 0000000000000000000000000000000000000000..5160de2650ec04f4cf5e874ad4856b96045a71d9
--- /dev/null
+++ b/deepscreen/data/featurizers/monn.py
@@ -0,0 +1,106 @@
+import numpy as np
+from rdkit.Chem import MolFromSmiles
+
+from deepscreen.data.featurizers.categorical import FASTA_VOCAB, fasta_to_label
+from deepscreen.data.featurizers.graph import atom_features, bond_features
+
+
+def get_mask(arr):
+    # (1, len(arr)) mask with ones over the occupied positions
+    a = np.zeros((1, len(arr)))
+    a[0, :arr.shape[0]] = 1
+    return a
+
+
+def add_index(input_array, ebd_size):
+    # flatten a batched (batch_size, n_vertex, n_nbs) index array, offsetting
+    # each sample's indices by its position in the batch embedding
+    batch_size, n_vertex, n_nbs = np.shape(input_array)
+    add_idx = np.array(list(range(0, ebd_size * batch_size, ebd_size)) * (n_nbs * n_vertex))
+    add_idx = np.transpose(add_idx.reshape(-1, batch_size))
+    add_idx = add_idx.reshape(-1)
+    new_array = input_array.reshape(-1) + add_idx
+    return new_array
+
+
+# TODO: fix padding and masking
+def drug_featurizer(smiles, max_neighbors=6):
+    mol = MolFromSmiles(smiles)
+
+    # convert molecule to GNN input
+    n_atoms = mol.GetNumAtoms()
+    assert mol.GetNumBonds() >= 0
+
+    n_bonds = max(mol.GetNumBonds(), 1)
+    feat_atoms = np.zeros((n_atoms, len(atom_features(mol.GetAtomWithIdx(0)))))  # atom feature vectors
+    feat_bonds = np.zeros((n_bonds, 6))  # bond feature vectors (bond_features yields 6 flags)
+    atom_adj = np.zeros((n_atoms, max_neighbors), dtype=np.int64)
+    bond_adj = np.zeros((n_atoms, max_neighbors), dtype=np.int64)
+    n_neighbors = np.zeros((n_atoms,), dtype=np.int64)
+    neighbor_mask = np.zeros((n_atoms, max_neighbors))
+
+    for atom in mol.GetAtoms():
+        idx = atom.GetIdx()
+        feat_atoms[idx] = atom_features(atom)
+
+    for bond in mol.GetBonds():
+        a1 = bond.GetBeginAtom().GetIdx()
+        a2 = bond.GetEndAtom().GetIdx()
+        idx = bond.GetIdx()
+        feat_bonds[idx] = bond_features(bond)
+        try:
+            atom_adj[a1, n_neighbors[a1]] = a2
+            atom_adj[a2, n_neighbors[a2]] = a1
+        except IndexError:
+            # an atom has more than max_neighbors neighbours; give up
+            return [], [], [], [], [], []
+        bond_adj[a1, n_neighbors[a1]] = idx
+        bond_adj[a2, n_neighbors[a2]] = idx
+        n_neighbors[a1] += 1
+        n_neighbors[a2] += 1
+
+    for i in range(len(n_neighbors)):
+        neighbor_mask[i, :n_neighbors[i]] = 1
+
+    vertex_mask = get_mask(feat_atoms)
+    # vertex = pack_1d(feat_atoms)
+    # edge = pack_1d(feat_bonds)
+    # atom_adj = pack_2d(atom_adj)
+    # bond_adj = pack_2d(bond_adj)
+    # nbs_mask = pack_2d(n_neighbors_mat)
+
+    # NOTE: add_index expects a batched 3-D index array; wiring this up for a
+    # single molecule is part of the padding/batching TODO above
+    atom_adj = add_index(atom_adj, np.shape(atom_adj)[1])
+    bond_adj = add_index(bond_adj, np.shape(feat_bonds)[1])
+
+    return vertex_mask, feat_atoms, feat_bonds, atom_adj, bond_adj, neighbor_mask
+
+
+# TODO: WIP; the pairwise_label matrix should probably be generated beforehand
+# and stored as an extra label in the dataset
+def get_pairwise_label(pdbid, interaction_dict, mol):
+    if pdbid in interaction_dict:
+        sdf_element = np.array([atom.GetSymbol().upper() for atom in mol.GetAtoms()])
+        atom_element = np.array(interaction_dict[pdbid]['atom_element'], dtype=str)
+        atom_name_list = 
np.array(interaction_dict[pdbid]['atom_name'], dtype=str) + atom_interact = np.array(interaction_dict[pdbid]['atom_interact'], dtype=int) + nonH_position = np.where(atom_element != 'H')[0] + assert sum(atom_element[nonH_position] != sdf_element) == 0 + + atom_name_list = atom_name_list[nonH_position].tolist() + pairwise_mat = np.zeros((len(nonH_position), len(interaction_dict[pdbid]['uniprot_seq'])), dtype=np.int32) + for atom_name, bond_type in interaction_dict[pdbid]['atom_bond_type']: + atom_idx = atom_name_list.index(str(atom_name)) + assert atom_idx < len(nonH_position) + + seq_idx_list = [] + for seq_idx, bond_type_seq in interaction_dict[pdbid]['residue_bond_type']: + if bond_type == bond_type_seq: + seq_idx_list.append(seq_idx) + pairwise_mat[atom_idx, seq_idx] = 1 + if len(np.where(pairwise_mat != 0)[0]) != 0: + pairwise_mask = True + return True, pairwise_mat + return False, np.zeros((1, 1)) + + +def protein_featurizer(fasta): + sequence = fasta_to_label(fasta) + # pad proteins and make masks + seq_mask = get_mask(sequence) + + return seq_mask, sequence diff --git a/deepscreen/data/featurizers/token.py b/deepscreen/data/featurizers/token.py new file mode 100644 index 0000000000000000000000000000000000000000..ebc049ab2594fabbf8ab25de6fa042b58011cb12 --- /dev/null +++ b/deepscreen/data/featurizers/token.py @@ -0,0 +1,299 @@ +import collections +from importlib import resources +import os +import re +from typing import Optional, List + +import numpy as np +from transformers import BertTokenizer + +SMI_REGEX_PATTERN = r"""(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])""" +# \[[^\]]+\] # match anything inside square brackets +# |Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p # match elements +# |\(|\) # match parentheses +# |\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2} # match various symbols +# |[0-9] # match digits + + +def sequence_to_kmers(sequence, k=3): + """ Divide a string into a list of kmers strings. + + Parameters: + sequence (string) + k (int), default 3 + Returns: + List containing a list of kmers. + """ + return [sequence[i:i + k] for i in range(len(sequence) - k + 1)] + + +def sequence_to_word_embedding(sequence, model): + """Get protein embedding, infer a list of 3-mers to (num_word, 100) matrix""" + kmers = sequence_to_kmers(sequence) + vec = np.zeros((len(kmers), 100)) + i = 0 + for word in kmers: + try: + vec[i,] = model.wv[word] + except KeyError: + pass + i += 1 + return vec + + +def sequence_to_token_ids(sequence, tokenizer): + token_ids = tokenizer.encode(sequence) + return np.array(token_ids) + + +# def sequence_to_token_ids(sequence, tokenizer, max_length: int): +# token_ids = tokenizer.encode(sequence) +# length = min(max_length, len(token_ids)) +# +# token_ids_padded = np.zeros(max_length, dtype='int') +# token_ids_padded[:length] = token_ids[:length] +# +# return token_ids_padded + + +class SmilesTokenizer(BertTokenizer): + """ + Adapted from https://github.com/deepchem/deepchem/. + + Creates the SmilesTokenizer class. The tokenizer heavily inherits from the BertTokenizer + implementation found in Huggingface's transformers library. It runs a WordPiece tokenization + algorithm over SMILES strings using the tokenization SMILES regex developed by Schwaller et al. + + Please see https://github.com/huggingface/transformers + and https://github.com/rxn4chemistry/rxnfp for more details. 
+ + Examples + -------- + >>> tokenizer = SmilesTokenizer(vocab_path, regex_pattern) + >>> print(tokenizer.encode("CC(=O)OC1=CC=CC=C1C(=O)O")) + [12, 16, 16, 17, 22, 19, 18, 19, 16, 20, 22, 16, 16, 22, 16, 16, 22, 16, 20, 16, 17, 22, 19, 18, 19, 13] + + + References + ---------- + .. [1] Schwaller, Philippe; Probst, Daniel; Vaucher, Alain C.; Nair, Vishnu H; Kreutter, David; + Laino, Teodoro; et al. (2019): Mapping the Space of Chemical Reactions using Attention-Based Neural + Networks. ChemRxiv. Preprint. https://doi.org/10.26434/chemrxiv.9897365.v3 + + Note + ---- + This class requires huggingface's transformers and tokenizers libraries to be installed. + """ + + def __init__( + self, + vocab_file: str = 'resources/vocabs/smiles.txt', + regex_pattern: str = SMI_REGEX_PATTERN, + # unk_token="[UNK]", + # sep_token="[SEP]", + # pad_token="[PAD]", + # cls_token="[CLS]", + # mask_token="[MASK]", + **kwargs): + """Constructs a SmilesTokenizer. + + Parameters + ---------- + vocab_file: str + Path to a SMILES character per line vocabulary file. + Default vocab file is found in deepchem/feat/tests/data/vocab.txt + """ + + super().__init__(vocab_file, **kwargs) + + if not os.path.isfile(vocab_file): + raise ValueError( + "Can't find a vocab file at path '{}'.".format(vocab_file)) + self.vocab = load_vocab(vocab_file) + unused_indexes = [i for i, v in enumerate(self.vocab.keys()) if v.startswith("[unused")] + self.highest_unused_index = 0 if len(unused_indexes) == 0 else max(unused_indexes) + self.ids_to_tokens = collections.OrderedDict([ + (ids, tok) for tok, ids in self.vocab.items() + ]) + self.basic_tokenizer = BasicSmilesTokenizer(regex_pattern=regex_pattern) + + @property + def vocab_size(self): + return len(self.vocab) + + @property + def vocab_list(self): + return list(self.vocab.keys()) + + def _tokenize(self, text: str, max_seq_length: int = 512, **kwargs): + """Tokenize a string into a list of tokens. + + Parameters + ---------- + text: str + Input string sequence to be tokenized. + """ + + max_len_single_sentence = max_seq_length - 2 + split_tokens = [ + token for token in self.basic_tokenizer.tokenize(text) + [:max_len_single_sentence] + ] + return split_tokens + + def _convert_token_to_id(self, token: str): + """Converts a token (str/unicode) in an id using the vocab. + + Parameters + ---------- + token: str + String token from a larger sequence to be converted to a numerical id. + """ + + return self.vocab.get(token, self.vocab.get(self.unk_token)) + + def _convert_id_to_token(self, index: int): + """Converts an index (integer) in a token (string/unicode) using the vocab. + + Parameters + ---------- + index: int + Integer index to be converted back to a string-based token as part of a larger sequence. + """ + + return self.ids_to_tokens.get(index, self.unk_token) + + def convert_tokens_to_string(self, tokens: List[str]): + """Converts a sequence of tokens (string) in a single string. + + Parameters + ---------- + tokens: List[str] + List of tokens for a given string sequence. + + Returns + ------- + out_string: str + Single string from combined tokens. + """ + + out_string: str = " ".join(tokens).replace(" ##", "").strip() + return out_string + + def add_special_tokens_ids_single_sequence(self, + token_ids: List[Optional[int]]): + """Adds special tokens to a sequence for sequence classification tasks. + + A BERT sequence has the following format: [CLS] X [SEP] + + Parameters + ---------- + token_ids: list[int] + list of tokenized input ids. 
Can be obtained using the encode or encode_plus methods. + """ + + return [self.cls_token_id] + token_ids + [self.sep_token_id] + + def add_special_tokens_single_sequence(self, tokens: List[str]): + """Adds special tokens to the a sequence for sequence classification tasks. + A BERT sequence has the following format: [CLS] X [SEP] + + Parameters + ---------- + tokens: List[str] + List of tokens for a given string sequence. + """ + return [self.cls_token] + tokens + [self.sep_token] + + def add_special_tokens_ids_sequence_pair( + self, token_ids_0: List[Optional[int]], + token_ids_1: List[Optional[int]]) -> List[Optional[int]]: + """Adds special tokens to a sequence pair for sequence classification tasks. + A BERT sequence pair has the following format: [CLS] A [SEP] B [SEP] + + Parameters + ---------- + token_ids_0: List[int] + List of ids for the first string sequence in the sequence pair (A). + token_ids_1: List[int] + List of tokens for the second string sequence in the sequence pair (B). + """ + + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + return cls + token_ids_0 + sep + token_ids_1 + sep + + def add_padding_tokens(self, + token_ids: List[Optional[int]], + length: int, + right: bool = True) -> List[Optional[int]]: + """Adds padding tokens to return a sequence of length max_length. + By default padding tokens are added to the right of the sequence. + + Parameters + ---------- + token_ids: list[optional[int]] + list of tokenized input ids. Can be obtained using the encode or encode_plus methods. + length: int + right: bool, default True + + Returns + ------- + List[int] + """ + padding = [self.pad_token_id] * (length - len(token_ids)) + + if right: + return token_ids + padding + else: + return padding + token_ids + + +class BasicSmilesTokenizer(object): + """ + Adapted from https://github.com/deepchem/deepchem/. + Run basic SMILES tokenization using a regex pattern developed by Schwaller et. al. + This tokenizer is to be used when a tokenizer that does not require the transformers library by HuggingFace is required. + + Examples + -------- + >>> tokenizer = BasicSmilesTokenizer() + >>> print(tokenizer.tokenize("CC(=O)OC1=CC=CC=C1C(=O)O")) + ['C', 'C', '(', '=', 'O', ')', 'O', 'C', '1', '=', 'C', 'C', '=', 'C', 'C', '=', 'C', '1', 'C', '(', '=', 'O', ')', 'O'] + + + References + ---------- + .. [1] Philippe Schwaller, Teodoro Laino, Théophile Gaudin, Peter Bolgar, Christopher A. Hunter, Costas Bekas, and Alpha A. Lee + ACS Central Science 2019 5 (9): Molecular Transformer: A Model for Uncertainty-Calibrated Chemical Reaction Prediction + 1572-1583 DOI: 10.1021/acscentsci.9b00576 + """ + + def __init__(self, regex_pattern: str = SMI_REGEX_PATTERN): + """Constructs a BasicSMILESTokenizer. + + Parameters + ---------- + regex: string + SMILES token regex + """ + self.regex_pattern = regex_pattern + self.regex = re.compile(self.regex_pattern) + + def tokenize(self, text): + """Basic Tokenization of a SMILES. 
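+
+        Example (the Schwaller regex keeps bracket atoms and two-letter halogens
+        intact; a minimal check assuming only this class):
+
+        >>> BasicSmilesTokenizer().tokenize("C[NH3+]Cl")
+        ['C', '[NH3+]', 'Cl']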
+ """ + tokens = [token for token in self.regex.findall(text)] + return tokens + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + with open(vocab_file, "r", encoding="utf-8") as reader: + tokens = reader.readlines() + for index, token in enumerate(tokens): + token = token.rstrip("\n") + vocab[token] = index + return vocab diff --git a/deepscreen/data/single_entity.py b/deepscreen/data/single_entity.py new file mode 100644 index 0000000000000000000000000000000000000000..ef4195fdd988b44f9712b3a426b2bb4b2857a6d0 --- /dev/null +++ b/deepscreen/data/single_entity.py @@ -0,0 +1,195 @@ +# from itertools import product +from numbers import Number +from pathlib import Path +from typing import Any, Dict, Optional, Sequence, Union, Literal + +# import numpy as np +import pandas as pd +from lightning import LightningDataModule +from sklearn.base import TransformerMixin +from torch.utils.data import Dataset, DataLoader, random_split + +from deepscreen.data.utils.dataset import SingleEntitySingleTargetDataset, BaseEntityDataset +from deepscreen.data.utils.label import label_transform +from deepscreen.data.utils.collator import collate_fn +from deepscreen.data.utils.sampler import SafeBatchSampler + + +class EntityDataModule(LightningDataModule): + """ + DTI DataModule + + A DataModule implements 5 key methods: + + def prepare_data(self): + # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP) + # download data, pre-process, split, save to disk, etc. + def setup(self, stage): + # things to do on every process in DDP + # load data, set variables, etc. + def train_dataloader(self): + # return train dataloader + def val_dataloader(self): + # return validation dataloader + def test_dataloader(self): + # return test dataloader + def teardown(self): + # called on every process in DDP + # clean up after fit or test + + This allows you to share a full dataset without explaining how to download, + split, transform and process the data. 
+
+    Read the docs:
+    https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
+    """
+
+    def __init__(
+            self,
+            dataset: type[BaseEntityDataset],
+            task: Literal['regression', 'binary', 'multiclass'],
+            n_classes: Optional[int],
+            train: bool,
+            batch_size: int,
+            num_workers: int = 0,
+            thresholds: Optional[Union[Number, Sequence[Number]]] = None,
+            pin_memory: bool = False,
+            data_dir: str = "data/",
+            data_file: Optional[str] = None,
+            # either three numbers (fractions/counts) or three file names
+            train_val_test_split: Optional[Union[Sequence[Number], Sequence[str]]] = None,
+            split: Optional[callable] = random_split,
+    ):
+        super().__init__()
+        # data_file is optional, so only build the path when it is given
+        data_path = Path(data_dir) / data_file if data_file else None
+        # this line allows to access init params with 'self.hparams' attribute
+        # also ensures init params will be stored in ckpt
+        self.save_hyperparameters(logger=False)
+
+        # data processing
+        self.split = split
+
+        if train:
+            if all([data_file, split]):
+                if train_val_test_split is not None and all(isinstance(v, Number) for v in train_val_test_split):
+                    pass
+                else:
+                    raise ValueError('`train_val_test_split` must be a sequence of 3 numbers '
+                                     '(float for percentages and int for sample numbers) if '
+                                     '`data_file` and `split` have been specified.')
+            elif (train_val_test_split is not None
+                  and all(isinstance(v, str) for v in train_val_test_split)
+                  and not any([data_file, split])):
+                self.train_data = dataset(dataset_path=str(Path(data_dir) / train_val_test_split[0]))
+                self.val_data = dataset(dataset_path=str(Path(data_dir) / train_val_test_split[1]))
+                self.test_data = dataset(dataset_path=str(Path(data_dir) / train_val_test_split[2]))
+            else:
+                raise ValueError('For training (train=True), you must specify either '
+                                 '`data_file` and `split` with a `train_val_test_split` of 3 numbers, or '
+                                 'solely a `train_val_test_split` of 3 data file names (passing `split=None`).')
+        else:
+            if data_file and not any([split, train_val_test_split]):
+                self.test_data = self.predict_data = dataset(dataset_path=str(Path(data_dir) / data_file))
+            else:
+                raise ValueError("For testing/predicting (train=False), you must specify only `data_file` without "
+                                 "`train_val_test_split` or `split`")
+
+    def prepare_data(self):
+        """
+        Download data if needed.
+        Do not use it to assign state (e.g., self.x = x).
+        """
+
+    def setup(self, stage: Optional[str] = None, encoding: Optional[str] = None):
+        """
+        Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
+        This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be
+        careful not to execute data splitting twice.
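+        With train=True, the split is materialized here via `self.split` according to
+        `train_val_test_split`; with train=False, the single data file doubles as both
+        the test set and the predict set, so no splitting happens at all.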
+ """ + # load and split datasets only if not loaded in initialization + if not any([self.data_train, self.data_val, self.data_test, self.data_predict]): + dataset = SingleEntitySingleTargetDataset( + task=self.hparams.task, + n_classes=self.hparams.n_classes, + dataset_path=Path(self.hparams.data_dir) / self.hparams.dataset_name, + transformer=self.hparams.transformer, + featurizer=self.hparams.featurizer, + thresholds=self.hparams.thresholds, + ) + + if self.hparams.train: + self.data_train, self.data_val, self.data_test = self.split( + dataset=dataset, + lengths=self.hparams.train_val_test_split + ) + else: + self.data_test = self.data_predict = dataset + + def train_dataloader(self): + return DataLoader( + dataset=self.data_train, + batch_sampler=SafeBatchSampler( + data_source=self.data_train, + batch_size=self.hparams.batch_size, + shuffle=True), + # batch_size=self.hparams.batch_size, + # shuffle=True, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + collate_fn=collate_fn, + persistent_workers=True if self.hparams.num_workers > 0 else False + ) + + def val_dataloader(self): + return DataLoader( + dataset=self.data_val, + batch_sampler=SafeBatchSampler( + data_source=self.data_val, + batch_size=self.hparams.batch_size, + shuffle=False), + # batch_size=self.hparams.batch_size, + # shuffle=False, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + collate_fn=collate_fn, + persistent_workers=True if self.hparams.num_workers > 0 else False + ) + + def test_dataloader(self): + return DataLoader( + dataset=self.data_test, + batch_sampler=SafeBatchSampler( + data_source=self.data_test, + batch_size=self.hparams.batch_size, + shuffle=False), + # batch_size=self.hparams.batch_size, + # shuffle=False, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + collate_fn=collate_fn, + persistent_workers=True if self.hparams.num_workers > 0 else False + ) + + def predict_dataloader(self): + return DataLoader( + dataset=self.data_predict, + batch_sampler=SafeBatchSampler( + data_source=self.data_predict, + batch_size=self.hparams.batch_size, + shuffle=False), + # batch_size=self.hparams.batch_size, + # shuffle=False, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + collate_fn=collate_fn, + persistent_workers=True if self.hparams.num_workers > 0 else False + ) + + def teardown(self, stage: Optional[str] = None): + """Clean up after fit or test.""" + pass + + def state_dict(self): + """Extra things to save to checkpoint.""" + return {} + + def load_state_dict(self, state_dict: Dict[str, Any]): + """Things to do when loading checkpoint.""" + pass diff --git a/deepscreen/data/utils/__init__.py b/deepscreen/data/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..80edbecbbd91c4b36749186b2f72069c4918d949 --- /dev/null +++ b/deepscreen/data/utils/__init__.py @@ -0,0 +1,8 @@ +from typing import Dict, Sequence, TypeVar, Union + +from deepscreen.data.utils.collator import collate_fn +from deepscreen.data.utils.label import label_transform +from deepscreen.data.utils.sampler import SafeBatchSampler + +T = TypeVar('T') +FlexibleIterable = Union[T, Sequence[T], Dict[str, T]] diff --git a/deepscreen/data/utils/__pycache__/__init__.cpython-311.pyc b/deepscreen/data/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a06039ab6cbf8c6d7949afa9668d9c62bc356e16 Binary files /dev/null and 
b/deepscreen/data/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/deepscreen/data/utils/__pycache__/collator.cpython-311.pyc b/deepscreen/data/utils/__pycache__/collator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb17a8a4426c690b4a6c56a3abbe993b459bbcd4 Binary files /dev/null and b/deepscreen/data/utils/__pycache__/collator.cpython-311.pyc differ diff --git a/deepscreen/data/utils/__pycache__/label.cpython-311.pyc b/deepscreen/data/utils/__pycache__/label.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f866faaa1d09e2a3173e514924ecfcee4751ea28 Binary files /dev/null and b/deepscreen/data/utils/__pycache__/label.cpython-311.pyc differ diff --git a/deepscreen/data/utils/__pycache__/sampler.cpython-311.pyc b/deepscreen/data/utils/__pycache__/sampler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..724502a7093c85b8d03eb232cff1682881b5a39a Binary files /dev/null and b/deepscreen/data/utils/__pycache__/sampler.cpython-311.pyc differ diff --git a/deepscreen/data/utils/__pycache__/split.cpython-311.pyc b/deepscreen/data/utils/__pycache__/split.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc2484c2d6fc079bb396f19ac239768da834d54b Binary files /dev/null and b/deepscreen/data/utils/__pycache__/split.cpython-311.pyc differ diff --git a/deepscreen/data/utils/collator.py b/deepscreen/data/utils/collator.py new file mode 100644 index 0000000000000000000000000000000000000000..c4c0875eaa0a7a9b49fef10a07cb3af824a24fe1 --- /dev/null +++ b/deepscreen/data/utils/collator.py @@ -0,0 +1,168 @@ +""" +Define collate functions for new data types here +""" +from functools import partial +from itertools import chain + +import dgl +import torch +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data._utils.collate import default_collate_fn_map, collate_tensor_fn, collate +import torch_geometric + + +def collate_pyg_fn(batch, collate_fn_map=None): + """ + PyG graph collation + """ + return torch_geometric.data.Batch.from_data_list(batch) + + +def collate_dgl_fn(batch, collate_fn_map=None): + """ + DGL graph collation + """ + return dgl.batch(batch) + + +def pad_collate_tensor_fn(batch, padding_value=0.0, collate_fn_map=None): + """ + Similar to pad_packed_sequence(pack_sequence(batch, enforce_sorted=False), batch_first=True), + but additionally supports padding a list of square Tensors of size ``(L x L x ...)``. 
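+    Example: collating [torch.ones(2, 4), torch.ones(3, 4)] returns a zero-padded
+    (2, 3, 4) batch together with the lengths tensor([2, 3]).
+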
+ :param batch: + :param padding_value: + :param collate_fn_map: + :return: padded_batch, lengths + """ + lengths = [tensor.size(0) for tensor in batch] + if any(element != lengths[0] for element in lengths[1:]): + try: + # Tensors share at least one common dimension size, use pad_sequence + batch = pad_sequence(batch, batch_first=True, padding_value=padding_value) + except RuntimeError: + # Tensors do not share any common dimension size, find the max size of each dimension in the batch + max_sizes = [max([tensor.size(dim) for tensor in batch]) for dim in range(batch[0].dim())] + # Pad every dimension of all tensors in the batch to be the respective max size with the value + batch = collate_tensor_fn([ + torch.nn.functional.pad( + tensor, tuple(chain.from_iterable( + [(0, max_sizes[dim] - tensor.size(dim)) for dim in range(tensor.dim())][::-1]) + ), mode='constant', value=padding_value) for tensor in batch + ]) + else: + batch = collate_tensor_fn(batch) + + lengths = torch.as_tensor(lengths) + # Return the padded batch tensor and the lengths + return batch, lengths + + +# Join custom collate functions with the default collation map of PyTorch +COLLATE_FN_MAP = default_collate_fn_map | { + torch_geometric.data.data.BaseData: collate_pyg_fn, + dgl.DGLGraph: collate_dgl_fn, +} + + +def collate_fn(batch, automatic_padding=False, padding_value=0): + if automatic_padding: + COLLATE_FN_MAP.update({ + torch.Tensor: partial(pad_collate_tensor_fn, padding_value=padding_value), + }) + return collate(batch, collate_fn_map=COLLATE_FN_MAP) + + +# class VariableLengthSequence(torch.Tensor): +# """ +# A custom PyTorch Tensor class that is similar to PackedSequence, except it can be directly used as a batch tensor, +# and it has an attribute called lengths, which signifies the length of each original sequence in the batch. +# """ +# +# def __new__(cls, data, lengths): +# """ +# Creates a new VariableLengthSequence object from the given data and lengths. +# Args: +# data (torch.Tensor): The batch collated tensor of shape (batch_size, max_length, *). +# lengths (torch.Tensor): The lengths of each original sequence in the batch of shape (batch_size,). +# Returns: +# VariableLengthSequence: A new VariableLengthSequence object. 
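+#         Usage sketch for collate_fn above (plain torch; shapes worked out by hand):
+#             batch = [{'X': torch.ones(2, 4)}, {'X': torch.ones(3, 4)}]
+#             collate_fn(batch, automatic_padding=True)
+#             # -> {'X': (padded (2, 3, 4) tensor, lengths tensor([2, 3]))}
+#         Caveat: automatic_padding=True updates the module-level COLLATE_FN_MAP in
+#         place, so padding stays enabled for every later call in the process.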
+# """ +# # Check the validity of the inputs +# assert isinstance(data, torch.Tensor), "data must be a torch.Tensor" +# assert isinstance(lengths, torch.Tensor), "lengths must be a torch.Tensor" +# assert data.dim() >= 2, "data must have at least two dimensions" +# assert lengths.dim() == 1, "lengths must have one dimension" +# assert data.size(0) == lengths.size(0), "data and lengths must have the same batch size" +# assert lengths.min() > 0, "lengths must be positive" +# assert lengths.max() <= data.size(1), "lengths must not exceed the max length of data" +# +# # Create a new tensor object from data +# obj = super().__new__(cls, data) +# +# # Set the lengths attribute +# obj.lengths = lengths +# +# return obj + + +# class VariableLengthSequence(torch.Tensor): +# _lengths = torch.Tensor() +# +# def __new__(cls, data, lengths, *args, **kwargs): +# self = super().__new__(cls, data, *args, **kwargs) +# self.lengths = lengths +# return self +# +# def clone(self, *args, **kwargs): +# return VariableLengthSequence(super().clone(*args, **kwargs), self.lengths.clone()) +# +# def new_empty(self, *size): +# return VariableLengthSequence(super().new_empty(*size), self.lengths) +# +# def to(self, *args, **kwargs): +# return VariableLengthSequence(super().to(*args, **kwargs), self.lengths.to(*args, **kwargs)) +# +# def __format__(self, format_spec): +# # Convert self to a string or a number here, depending on what you need +# return self.item().__format__(format_spec) +# +# @property +# def lengths(self): +# return self._lengths +# +# @lengths.setter +# def lengths(self, lengths): +# self._lengths = lengths +# +# def cpu(self, *args, **kwargs): +# return VariableLengthSequence(super().cpu(*args, **kwargs), self.lengths.cpu(*args, **kwargs)) +# +# def cuda(self, *args, **kwargs): +# return VariableLengthSequence(super().cuda(*args, **kwargs), self.lengths.cuda(*args, **kwargs)) +# +# def pin_memory(self): +# return VariableLengthSequence(super().pin_memory(), self.lengths.pin_memory()) +# +# def share_memory_(self): +# super().share_memory_() +# self.lengths.share_memory_() +# return self +# +# def detach_(self, *args, **kwargs): +# super().detach_(*args, **kwargs) +# self.lengths.detach_(*args, **kwargs) +# return self +# +# def detach(self, *args, **kwargs): +# return VariableLengthSequence(super().detach(*args, **kwargs), self.lengths.detach(*args, **kwargs)) +# +# def record_stream(self, *args, **kwargs): +# super().record_stream(*args, **kwargs) +# self.lengths.record_stream(*args, **kwargs) +# return self + + + # @classmethod + # def __torch_function__(cls, func, types, args=(), kwargs=None): + # return super().__torch_function__(func, types, args, kwargs) \ + # if cls.lengths is not None else torch.Tensor.__torch_function__(func, types, args, kwargs) diff --git a/deepscreen/data/utils/dataset.py b/deepscreen/data/utils/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..12af30f77dfe9d9d21432d7f424895cd27bcc446 --- /dev/null +++ b/deepscreen/data/utils/dataset.py @@ -0,0 +1,216 @@ +from numbers import Number +from typing import Literal, Union, Sequence + +import pandas as pd +from sklearn.base import TransformerMixin +from sklearn.exceptions import NotFittedError +from sklearn.utils.validation import check_is_fitted +from torch.utils.data import Dataset + +from deepscreen.data.utils import label_transform, FlexibleIterable + + +class BaseEntityDataset(Dataset): + def __init__( + self, + dataset_path: str, + use_col_prefixes=('X', 'Y', 'ID', 'U') + ): + + # Read the 
data table header row first to filter columns and create column dtype dict + df = pd.read_csv( + dataset_path, + header=0, nrows=0, + usecols=lambda col: col.startswith(use_col_prefixes) + ) + # Read the whole data table + df = pd.read_csv( + dataset_path, + header=0, + usecols=df.columns, + dtype={col: 'float32' if col.startswith('Y') else 'string' for col in df.columns} + ) + + self.df = df + self.label_cols = [col for col in df.columns if col.startswith('Y')] + self.label_unit_cols = [col for col in df.columns if col.startswith('U')] + self.entity_id_cols = [col for col in df.columns if col.startswith('ID')] + self.entity_cols = [col for col in df.columns if col.startswith('X')] + + def __len__(self): + return len(self.df.index) + + def __getitem__(self, idx): + raise NotImplementedError + + +# TODO test transform +class SingleEntitySingleTargetDataset(BaseEntityDataset): + def __init__( + self, + dataset_path: str, + task: Literal['regression', 'binary', 'multiclass'], + n_classes: int, + featurizer: callable, + transformer: TransformerMixin = None, + thresholds: Union[Number, Sequence[Number]] = None, + discard_intermediate: bool = None, + forward_fill: bool = True + ): + super().__init__(dataset_path) + + assert len(self.entity_cols) == 1, 'The dataset contains more than 1 entity column (starting with `X`).' + if len(self.label_cols) >= 0: + assert len(self.label_cols) == 1, 'The dataset contains more than 1 label column (starting with `Y`).' + # Remove trailing `1`s in column names for flexibility + self.df.columns = self.df.columns.str.rstrip('1') + + # Forward-fill non-label columns + nonlabel_cols = self.label_unit_cols + self.entity_id_cols + self.entity_cols + if forward_fill: + self.df[nonlabel_cols] = self.df[nonlabel_cols].ffill(axis=0) + + # Process target labels for training/testing if exist + if self.label_cols: + # Transform target labels + self.df[self.label_cols] = self.df[self.label_cols].apply( + label_transform, + units=self.df.get('U', None), + thresholds=thresholds, + discard_intermediate=discard_intermediate).astype('float32') + + # Filter out rows with a NaN in Y (missing values); use inplace to save memory + self.df.dropna(subset=self.label_cols, inplace=True) + + # Validate target labels + # TODO: check sklearn.utils.multiclass.check_classification_targets + match task: + case 'regression': + assert all(self.df['Y'].apply(lambda x: isinstance(x, Number))), \ + f"Y for task `regression` must be numeric; got {set(self.df['Y'].apply(type))}." + case 'binary': + assert all(self.df['Y'].isin([0, 1])), \ + f"Y for task `binary` (classification) must be 0 or 1, but Y got {pd.unique(self.df['Y'])}." \ + "\nYou may set `thresholds` to discretize continuous labels." + case 'multiclass': + assert n_classes >= 3, f'n_classes for task `multiclass` (classification) must be at least 3.' + assert all(self.df['Y'].apply(lambda x: x.is_integer() and x >= 0)), \ + f"``Y` for task `multiclass` (classification) must be non-negative integers, " \ + f"but `Y` got {pd.unique(self.df['Y'])}." \ + "\nYou may set `thresholds` to discretize continuous labels." + target_n_unique = self.df['Y'].nunique() + assert target_n_unique == n_classes, \ + f"You have set n_classes for task `multiclass` (classification) task to {n_classes}, " \ + f"but `Y` has {target_n_unique} unique labels." 
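+
+        # Illustrative input table for this dataset (values hypothetical): columns are
+        # matched by prefix, so `X` holds the entity (e.g. SMILES), `Y` the label,
+        # `U` an optional unit, and `ID` an optional identifier:
+        #     X,Y,U,ID
+        #     CCO,5600,nM,mol-1
+        #     c1ccccc1O,0.8,uM,mol-2
+        # label_transform (below) first converts `U` units onto the p scale, then any
+        # `thresholds` discretize the converted values.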
+
+        if transformer:
+            self.df['X'] = self.df['X'].apply(featurizer)
+            try:
+                check_is_fitted(transformer)
+                self.df['X'] = list(transformer.transform(self.df['X']))
+            except NotFittedError:
+                self.df['X'] = list(transformer.fit_transform(self.df['X']))
+
+            # Skip sample-wise feature extraction because it has already been done dataset-wise
+            self.featurizer = lambda x: x
+        else:
+            # no dataset-wise transform: featurize sample-wise in __getitem__
+            self.featurizer = featurizer
+
+        self.n_classes = n_classes
+        self.df['ID'] = self.df.get('ID', self.df['X'])
+
+    def __getitem__(self, idx):
+        sample = self.df.loc[idx]
+        return {
+            'X': self.featurizer(sample['X']),
+            'ID': sample['ID'],
+            'Y': sample.get('Y')
+        }
+
+
+# TODO WIP
+class MultiEntityMultiTargetDataset(BaseEntityDataset):
+    def __init__(
+            self,
+            dataset_path: str,
+            task: FlexibleIterable[Literal['regression', 'binary', 'multiclass']],
+            n_class: FlexibleIterable[int],
+            featurizers: FlexibleIterable[callable],
+            thresholds: FlexibleIterable[Union[Number, Sequence[Number]]] = None,
+            discard_intermediate: FlexibleIterable[bool] = None,
+    ):
+        super().__init__(dataset_path)
+        label_col_prefix = ('Y',)
+        nonlabel_col_prefixes = ('X', 'ID', 'U')
+        allowed_col_prefixes = label_col_prefix + nonlabel_col_prefixes
+
+        # Read the headers first to filter columns and create column dtype dict
+        df = pd.read_csv(
+            dataset_path,
+            header=0, nrows=0,
+            usecols=lambda col: col.startswith(allowed_col_prefixes)
+        )
+
+        # Read the whole table
+        df = pd.read_csv(
+            dataset_path,
+            header=0,
+            usecols=df.columns,
+            dtype={col: 'float32' if col.startswith('Y') else 'string' for col in df.columns}
+        )
+        label_cols = [col for col in df.columns if col.startswith(label_col_prefix)]
+        nonlabel_cols = [col for col in df.columns if col.startswith(nonlabel_col_prefixes)]
+        self.entity_cols = [col for col in nonlabel_cols if col.startswith('X')]
+
+        # Forward-fill all non-label columns
+        df[nonlabel_cols] = df[nonlabel_cols].ffill(axis=0)
+
+        # Process target labels for training/testing
+        if label_cols:
+            # Transform target labels
+            df[label_cols] = df[label_cols].apply(label_transform, units=df.get('U', None), thresholds=thresholds,
+                                                  discard_intermediate=discard_intermediate).astype('float32')
+
+            # Filter out rows with a NaN in Y (missing values)
+            df.dropna(subset=label_cols, inplace=True)
+
+            # Validate target labels
+            # TODO: check sklearn.utils.multiclass.check_classification_targets
+            # WIP
+            match task:
+                case 'regression':
+                    assert all(df['Y'].apply(lambda x: isinstance(x, Number))), \
+                        f"Y for task `regression` must be numeric; got {set(df['Y'].apply(type))}."
+                case 'binary':
+                    assert all(df['Y'].isin([0, 1])), \
+                        f"Y for task `binary` must be 0 or 1, but Y got {pd.unique(df['Y'])}." \
+                        "\nYou may set `thresholds` to discretize continuous labels."
+                case 'multiclass':
+                    assert len(label_cols) == len(n_class), \
+                        (f'Data table has {len(label_cols)} label columns (`Y*`) but you have specified '
+                         f'n_class of length {len(n_class)} for task `multiclass`.')
+                    # iterate column names explicitly; iterating a DataFrame only yields its column labels
+                    for col, n in zip(label_cols, n_class):
+                        label = df[col]
+                        assert n >= 3, 'n_class for task `multiclass` must be at least 3.'
+                        assert all(label.apply(lambda x: x.is_integer() and x >= 0)), \
+                            f"Y for task `multiclass` must be non-negative integers, " \
+                            f"but Y got {pd.unique(label)}." \
+                            "\nYou may set `thresholds` to discretize continuous labels."
+                        target_n_unique = label.nunique()
+                        assert target_n_unique == n, \
+                            f"You have set n_class for task `multiclass` to {n}, " \
+                            f"but Y has {target_n_unique} unique labels."
+
+        self.df = df
+        self.featurizers = featurizers
+        self.n_class = n_class
+
+    def __len__(self):
+        return len(self.df.index)
+
+    # WIP
+    def __getitem__(self, idx):
+        sample = self.df.loc[idx]
+        return {
+            'X': [featurizer(x) for featurizer, x in zip(self.featurizers, sample[self.entity_cols])],
+            'ID': sample.get('ID', sample['X']),
+            'Y': sample.get('Y')
+        }
diff --git a/deepscreen/data/utils/label.py b/deepscreen/data/utils/label.py
new file mode 100644
index 0000000000000000000000000000000000000000..439b799b1026800f52817d9c184ab431c5a601d3
--- /dev/null
+++ b/deepscreen/data/utils/label.py
@@ -0,0 +1,93 @@
+from numbers import Number
+from typing import Optional, Union
+
+import numpy as np
+
+from deepscreen.utils import get_logger
+
+log = get_logger(__name__)
+
+MOLARITY_TO_POTENCY = {
+    'p': lambda x: x,
+    'M': lambda x: -np.log10(x),
+    'mM': lambda x: -np.log10(x) + 3,
+    'μM': lambda x: -np.log10(x) + 6,
+    'uM': lambda x: -np.log10(x) + 6,  # ASCII fallback for micromolar
+    'nM': lambda x: -np.log10(x) + 9,
+    'pM': lambda x: -np.log10(x) + 12,
+    'fM': lambda x: -np.log10(x) + 15,
+}
+
+
+# TODO rewrite for swifter.apply
+def molar_to_p(labels, units):
+    unknown_units = set(units) - set(MOLARITY_TO_POTENCY)
+    assert not unknown_units, \
+        f"Got unknown units {unknown_units}; allowed units: {', '.join(MOLARITY_TO_POTENCY)}."
+
+    unit_converted_labels = []
+    # pair each label with its unit; iterating the bare tuple `(labels, units)` was a bug
+    for label, unit in zip(labels, units):
+        unit_converted_labels.append(MOLARITY_TO_POTENCY[unit](label))
+    labels = np.array(unit_converted_labels)
+
+    return labels
+
+
+def label_discretize(labels, thresholds):
+    # if isinstance(threshold, Number):
+    #     labels = np.where(labels < threshold, 1, 0)
+    # else:
+    #     labels = np.where(labels < threshold[0], 1, np.where(labels > threshold[1], 0, np.nan))
+    if isinstance(thresholds, Number):
+        labels = 1 - np.digitize(labels, [thresholds])
+    else:
+        labels = np.digitize(labels, np.sort(thresholds)[::-1])
+
+    return labels
+
+
+def label_transform(
+        labels,
+        units: Optional[list[str]] = None,
+        thresholds: Optional[Union[float, list[Number]]] = None,
+        discard_intermediate: Optional[bool] = None
+):
+    """Convert labels of all units to the p scale (-log10[M]) and discretize them if specified.
+    :param labels: a sequence of labels, continuous or binary values
+    :type labels: array_like
+    :param units: a sequence of label units, each one of p, M, mM, μM, uM, nM, pM, fM
+    :type units: array_like, optional
+    :param thresholds: discretization threshold(s) for affinity labels, in p scale (-log10[M]).
+        A single number maps affinities below it to 1 and otherwise to 0.
+        A sequence of two or more thresholds maps affinities to multiple discrete levels in
+        descending order, assigning values below the lowest threshold to the highest level
+        (e.g. 2) and values above the greatest threshold to 0
+    :type thresholds: list, float, optional
+    :param discard_intermediate: whether to discard the intermediate (indeterminate) level if provided an odd
+        number of thresholds (>=3)
+    :type discard_intermediate: bool
+    :return: a numpy array of affinity labels in p scale (-log10[M]) or discrete labels
+    """
+    # # Check if labels are already discrete (ignoring NAs).
+    # discrete = labels.dropna().isin([0, 1]).all()
+    #
+    # if discrete:
+    #     assert discretize, "Cannot train a regression model with discrete labels."
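+    # Worked example of the p-scale table above (values hypothetical):
+    #     MOLARITY_TO_POTENCY['nM'](100.0) -> -log10(100) + 9 = 7.0
+    #     MOLARITY_TO_POTENCY['uM'](1.0)   -> -log10(1)  + 6 = 6.0
+    # so a threshold of 7 corresponds to the common 100 nM activity cutoff.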
+ # if thresholds: + # warn("Ignoring 'threshold' because 'Y' (labels) in the data table is already binary.") + # if units: + # warn("Ignoring 'units' because 'Y' (labels) in the data table is already binary.") + # labels = labels + if units: + labels = molar_to_p(labels, units) + + if thresholds: + labels = label_discretize(labels, thresholds) + if discard_intermediate: + assert len(thresholds) % 2 == 1 and len(thresholds) >= 3, \ + "Must give an odd number of (at least 3) thresholds to discard the intermediate level." + intermediate_level = len(thresholds) // 2 + # Make the intermediate-level labels NaN (which will be filtered out later) + labels[labels == intermediate_level] = np.nan + # Reduce all levels above the intermediate level by 1 + labels[labels > intermediate_level] -= 1 + + return labels + diff --git a/deepscreen/data/utils/sampler.py b/deepscreen/data/utils/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..0431c132fc8d8a54f877bb4b1e8fcade92717986 --- /dev/null +++ b/deepscreen/data/utils/sampler.py @@ -0,0 +1,90 @@ +from typing import Mapping, Iterable + +from torch.utils.data import BatchSampler, RandomSampler, SequentialSampler + + +class SafeBatchSampler(BatchSampler): + """ + A safe `batch_sampler` that skips samples with `None` values, supports shuffling, and keep a fixed batch size. + + Args: + data_source (Dataset): The dataset to sample from. + batch_size (int): The size of each batch. + drop_last (bool): Whether to drop the last batch if its size is smaller than `batch_size`. Defaults to `False`. + shuffle (bool, optional): Whether to shuffle the data before sampling. Defaults to `True`. + + Example: + >>> dataloader = DataLoader(dataset, batch_sampler=SafeBatchSampler(dataset, batch_size, drop_last, shuffle)) + """ + + def __init__(self, data_source, batch_size: int, drop_last: bool, shuffle: bool, sampler=None): + if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \ + batch_size <= 0: + raise ValueError(f"batch_size should be a positive integer value, but got batch_size={batch_size}") + if not isinstance(drop_last, bool): + raise ValueError(f"drop_last should be a boolean value, but got drop_last={drop_last}") + if sampler: + pass + elif shuffle: + sampler = RandomSampler(data_source) # type: ignore[arg-type] + else: + sampler = SequentialSampler(data_source) # type: ignore[arg-type] + + super().__init__(sampler, batch_size, drop_last) + self.data_source = data_source + + # def __iter__(self): + # batch = [] + # for idx in self.sampler: + # sample = self.data_source[idx] + # # if isinstance(sample, list | tuple): + # # pass + # # elif isinstance(sample, dict): + # # sample = sample.values() + # # elif isinstance(sample, Series): + # # sample = sample.values + # # else: + # # sample = [sample] + # if isinstance(sample, (Iterable, Mapping)) and not isinstance(sample, str): + # if isinstance(sample, Mapping): + # sample = list(sample.values()) + # else: + # sample = [sample] + # + # if all(v is not None for v in sample): + # batch.append(idx) + # if len(batch) == self.batch_size: + # yield batch + # batch = [] + # + # if len(batch) > 0 and not self.drop_last: + # yield batch + # + # if not batch: + # raise StopIteration + + def __iter__(self): + batch = [0] * self.batch_size + idx_in_batch = 0 + for idx in self.sampler: + sample = self.data_source[idx] + if isinstance(sample, (Iterable, Mapping)) and not isinstance(sample, str): + if isinstance(sample, Mapping): + sample = sample.values() + else: + sample = [sample] + 
+ if all(v is not None for v in sample): + batch[idx_in_batch] = idx + idx_in_batch += 1 + if idx_in_batch == self.batch_size: + yield batch + idx_in_batch = 0 + batch = [0] * self.batch_size + + if idx_in_batch > 0 and not self.drop_last: + yield batch[:idx_in_batch] + + if not any(batch): + # raise StopIteration + return diff --git a/deepscreen/data/utils/split.py b/deepscreen/data/utils/split.py new file mode 100644 index 0000000000000000000000000000000000000000..b8f6d643831af267d306035dff5e45b9930b23c9 --- /dev/null +++ b/deepscreen/data/utils/split.py @@ -0,0 +1,72 @@ +import torch +from typing import List, Union +from torch.utils import data + +# FIXME Test and fix these split functions later + + +def random_split(dataset: data.Dataset, lengths, seed): + return data.random_split(dataset, lengths, generator=torch.Generator().manual_seed(seed)) + + +def cold_start(dataset: data.Dataset, frac: List[float], entities: Union[str, List[str]]): + """Create cold-start splits for PyTorch datasets. + + Args: + dataset (Dataset): PyTorch dataset object. + frac (list): A list of train/valid/test fractions. + entities (Union[str, List[str]]): Either a single "cold" entity or a list of "cold" entities + on which the split is done. + + Returns: + dict: A dictionary of splitted datasets, where keys are 'train', 'valid', and 'test', + and values correspond to each dataset. + """ + if isinstance(entities, str): + entities = [entities] + + train_frac, val_frac, test_frac = frac + + # Collect unique instances for each entity + entity_instances = {} + for entity in entities: + entity_instances[entity] = list(set([getattr(sample, entity) for sample in dataset])) + + # Sample instances belonging to the test datasets + test_entity_instances = [ + torch.randperm(len(entity_instances[entity]))[:int(len(entity_instances[entity]) * test_frac)] + for entity in entities + ] + + # Select samples where all entities are in the test set + test_indices = [] + for i, sample in enumerate(dataset): + if all([getattr(sample, entity) in entity_instances[entity][test_entity_instances[j]] for j, entity in enumerate(entities)]): + test_indices.append(i) + + if len(test_indices) == 0: + raise ValueError('No test samples found. Try increasing the test frac or a less stringent splitting strategy.') + + # Proceed with validation data + train_val_indices = list(set(range(len(dataset))) - set(test_indices)) + + val_entity_instances = [ + torch.randperm(len(entity_instances[entity]))[:int(len(entity_instances[entity]) * val_frac / (1 - test_frac))] + for entity in entities + ] + + val_indices = [] + for i in train_val_indices: + if all([getattr(dataset[i], entity) in entity_instances[entity][val_entity_instances[j]] for j, entity in enumerate(entities)]): + val_indices.append(i) + + if len(val_indices) == 0: + raise ValueError('No validation samples found. 
Try increasing the test frac or a less stringent splitting strategy.') + + train_indices = list(set(train_val_indices) - set(val_indices)) + + train_dataset = torch.utils.data.Subset(dataset, train_indices) + val_dataset = torch.utils.data.Subset(dataset, val_indices) + test_dataset = torch.utils.data.Subset(dataset, test_indices) + + return {'train': train_dataset, 'valid': val_dataset, 'test': test_dataset} diff --git a/deepscreen/data/utils/transform.py b/deepscreen/data/utils/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..e146b880da2a30cd1abcc17490ec4f778506ed24 --- /dev/null +++ b/deepscreen/data/utils/transform.py @@ -0,0 +1,8 @@ +from sklearn.preprocessing import StandardScaler, MinMaxScaler, MaxAbsScaler, RobustScaler + + +def scale_transform(features, scaler: type[StandardScaler, MinMaxScaler, MaxAbsScaler, RobustScaler]): + scaler = scaler() + features = scaler.fit_transform(features) + features = scaler.transform(features) + return features diff --git a/deepscreen/gui/__init__.py b/deepscreen/gui/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/deepscreen/gui/test.py b/deepscreen/gui/test.py new file mode 100644 index 0000000000000000000000000000000000000000..39003516a29659cc37f7f97f09133e66f0711093 --- /dev/null +++ b/deepscreen/gui/test.py @@ -0,0 +1,114 @@ +from pathlib import Path + +import gradio as gr + +# Use this in a notebook +root = Path.cwd() + + +drug_encoder_list = [f.stem for f in root.parent.joinpath("configs/model/drug_encoder").iterdir() if f.suffix == ".yaml"] + +drug_featurizer_list = [f.stem for f in root.parent.joinpath("configs/model/drug_featurizer").iterdir() if f.suffix == ".yaml"] + +protein_encoder_list = [f.stem for f in root.parent.joinpath("configs/model/protein_encoder").iterdir() if f.suffix == ".yaml"] + +protein_featurizer_list = [f.stem for f in root.parent.joinpath("configs/model/protein_featurizer").iterdir() if f.suffix == ".yaml"] + +classifier_list = [f.stem for f in root.parent.joinpath("configs/model/classifier").iterdir() if f.suffix == ".yaml"] + +preset_list = [f.stem for f in root.parent.joinpath("configs/model/preset").iterdir() if f.suffix == ".yaml"] + + +from typing import Optional + +def drug_target_interaction( + binary: bool, + drug_encoder, + drug_featurizer, + protein_encoder, + protein_featurizer, + classifier, + preset,) -> Optional[float]: + + + return 1 + +def drug_encoder( + binary: bool, + drug_encoder, + drug_featurizer, + protein_encoder, + protein_featurizer, + classifier, + preset,): + + return + +def protein_encoder( + binary: bool, + drug_encoder, + drug_featurizer, + protein_encoder, + protein_featurizer, + classifier, + preset,): + + return + +# demo = gr.Interface( +# fn=drug_target_interaction, +# inputs=[ +# gr.Radio(["True", "False"]), +# gr.Dropdown(drug_encoder_list), +# gr.Dropdown(drug_featurizer_list), +# gr.Dropdown(protein_encoder_list), +# gr.Dropdown(protein_featurizer_list), +# gr.Dropdown(classifier_list), +# gr.Dropdown(preset_list), +# ], +# outputs=["number"], +# show_error=True, +# +# ) +# +# demo.launch() + + +from omegaconf import DictConfig, OmegaConf + +type_to_component_map = {list: gr.Text, int: gr.Number, float: gr.Number} + + +def get_config_choices(config_path: str): + return [f.stem for f in Path("../../configs/", config_path).iterdir() if f.suffix == ".yaml"] + + +def create_blocks_from_config(cfg: DictConfig): + with gr.Blocks() as blocks: + for key, value in 
cfg.items(): + if type(value) in [int, float]: + component = gr.Number(value=value, label=key, interactive=True) + if type(value) in [dict, DictConfig]: + with gr.Tab(label=key): + component = create_blocks_from_config(value) + else: + component = gr.Text(value=value, label=key, interactive=True) + return blocks + + +def create_interface_from_config(fn: callable, cfg: DictConfig): + inputs = [] + + for key, value in OmegaConf.to_object(cfg).items(): + component = type_to_component_map.get(type(value), gr.Text) + inputs.append(component(value=value, label=key, interactive=True)) + + interface = gr.Interface(fn=fn, inputs=inputs, outputs="label") + + return interface + + +import hydra + +with hydra.initialize(version_base=None, config_path="../../configs/"): + cfg = hydra.compose("train") \ No newline at end of file diff --git a/deepscreen/models/__init__.py b/deepscreen/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/deepscreen/models/__pycache__/__init__.cpython-311.pyc b/deepscreen/models/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b1d9db43c68b99c204ad7f1f685a1821ff20f5b Binary files /dev/null and b/deepscreen/models/__pycache__/__init__.cpython-311.pyc differ diff --git a/deepscreen/models/__pycache__/dti.cpython-311.pyc b/deepscreen/models/__pycache__/dti.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99cd5e5052e7efb9eb25683194d07ac7c9b825ac Binary files /dev/null and b/deepscreen/models/__pycache__/dti.cpython-311.pyc differ diff --git a/deepscreen/models/components/__init__.py b/deepscreen/models/components/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/deepscreen/models/components/__pycache__/__init__.cpython-311.pyc b/deepscreen/models/components/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..435a4b856aea9f911f80936aa927c55342eebfe6 Binary files /dev/null and b/deepscreen/models/components/__pycache__/__init__.cpython-311.pyc differ diff --git a/deepscreen/models/components/__pycache__/cnn.cpython-311.pyc b/deepscreen/models/components/__pycache__/cnn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5972fc842191d7cc2b1775fa18472c427269e7a Binary files /dev/null and b/deepscreen/models/components/__pycache__/cnn.cpython-311.pyc differ diff --git a/deepscreen/models/components/cnn.py b/deepscreen/models/components/cnn.py new file mode 100644 index 0000000000000000000000000000000000000000..dcaba9b0f18902b2d6503c4ec1e08170cbd2a1de --- /dev/null +++ b/deepscreen/models/components/cnn.py @@ -0,0 +1,42 @@ +from torch import nn, rand +from torch.autograd import Variable + + +class CNN(nn.Sequential): + def __init__( + self, + filters: list[int], + kernels: list[int], + max_sequence_length: int, + in_channels: int, + out_channels: int + ): + super().__init__() + num_layer = len(filters) + channels = [in_channels] + filters + self.conv = nn.ModuleList([nn.Conv1d(in_channels=channels[i], + out_channels=channels[i+1], + kernel_size=kernels[i]) + for i in range(num_layer)]) + n_size = self._get_conv_output((in_channels, max_sequence_length)) + self.fc1 = nn.Linear(n_size, out_channels) + + def _forward_features(self, x): + for layer in self.conv: + x = nn.functional.relu(layer(x)) + x = nn.functional.adaptive_max_pool1d(x, 
output_size=1) + return x + + def _get_conv_output(self, shape): + bs = 1 + input_feat = Variable(rand(bs, *shape)) + output_feat = self._forward_features(input_feat) + n_size = output_feat.data.view(bs, -1).size(1) + return n_size + + def forward(self, v): + v = self._forward_features(v.float()) + v = v.view(v.size(0), -1) + v = self.fc1(v) + return v + diff --git a/deepscreen/models/components/concat_mlp.py b/deepscreen/models/components/concat_mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..481d09dd8dc05210138f91cf07484fdba1c8e1da --- /dev/null +++ b/deepscreen/models/components/concat_mlp.py @@ -0,0 +1,11 @@ +from torch import cat + +from deepscreen.models.components.mlp import MLP + + +class ConcatMLP(MLP): + def forward(self, *inputs): + x = cat([*inputs], 1) + for module in self: + x = module(x) + return x diff --git a/deepscreen/models/components/gat.py b/deepscreen/models/components/gat.py new file mode 100644 index 0000000000000000000000000000000000000000..fa8b439601178f280e44f6914b881f56879bd128 --- /dev/null +++ b/deepscreen/models/components/gat.py @@ -0,0 +1,39 @@ +from torch import nn +import torch.nn.functional as F +from torch_geometric.nn import GATConv +from torch_geometric.nn import global_max_pool as gmp + + +class GAT(nn.Module): + r""" + From `GraphDTA `_ (Nguyen et al., 2020), + based on `Graph Attention Network `_ (Veličković et al., 2018). + """ + def __init__( + self, + num_features: int, + out_channels: int, + dropout: float + ): + super().__init__() + + self.dropout = dropout + self.gcn1 = GATConv(num_features, num_features, heads=10, dropout=dropout) + self.gcn2 = GATConv(num_features * 10, out_channels, dropout=dropout) + self.fc_g1 = nn.Linear(out_channels, out_channels) + self.relu = nn.ReLU() + + def forward(self, data): + # graph input feed-forward + x, edge_index, batch = data.x, data.edge_index, data.batch + + x = F.dropout(x, p=self.dropout, training=self.training) + x = F.elu(self.gcn1(x, edge_index)) + x = F.dropout(x, p=self.dropout, training=self.training) + x = self.gcn2(x, edge_index) + x = self.relu(x) + x = gmp(x, batch) # global max pooling + x = self.fc_g1(x) + x = self.relu(x) + + return x diff --git a/deepscreen/models/components/gat_gcn.py b/deepscreen/models/components/gat_gcn.py new file mode 100644 index 0000000000000000000000000000000000000000..9c11bd1f0c0830a40f040eb71d0764557fdc13cc --- /dev/null +++ b/deepscreen/models/components/gat_gcn.py @@ -0,0 +1,40 @@ +from torch import cat, nn +from torch_geometric.nn import GCNConv, GATConv +from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp + + +class GATGCN(nn.Module): + r""" + From `GraphDTA `_ (Nguyen et al., 2020), + based on `Graph Attention Network `_ (Veličković et al., 2018) + and `Graph Convolutional Network `_ (Kipf and Welling, 2017). 
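+
+    Example (a hedged shape sketch assuming torch_geometric; 34 matches the atom
+    feature width of deepscreen.data.featurizers.graph.atom_features):
+
+        >>> import torch
+        >>> from torch_geometric.data import Data, Batch
+        >>> g = Data(x=torch.rand(5, 34), edge_index=torch.randint(0, 5, (2, 8)))
+        >>> GATGCN(num_features=34, out_channels=128, dropout=0.2)(Batch.from_data_list([g])).shape
+        torch.Size([1, 128])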
+ """ + def __init__( + self, + num_features: int, + out_channels: int, + dropout: float + ): + super().__init__() + + self.conv1 = GATConv(num_features, num_features, heads=10) + self.conv2 = GCNConv(num_features*10, num_features*10) + self.fc_g1 = nn.Linear(num_features*10*2, 1500) + self.fc_g2 = nn.Linear(1500, out_channels) + self.relu = nn.ReLU() + self.dropout = nn.Dropout(dropout) + + def forward(self, data): + x, edge_index, batch = data.x, data.edge_index, data.batch + # print('x shape = ', x.shape) + x = self.conv1(x, edge_index) + x = self.relu(x) + x = self.conv2(x, edge_index) + x = self.relu(x) + # apply global max pooling (gmp) and global mean pooling (gap) + x = cat([gmp(x, batch), gap(x, batch)], dim=1) + x = self.relu(self.fc_g1(x)) + x = self.dropout(x) + x = self.fc_g2(x) + + return x diff --git a/deepscreen/models/components/gcn.py b/deepscreen/models/components/gcn.py new file mode 100644 index 0000000000000000000000000000000000000000..ec7b497179edae3c838a96489cb5db8355c35df0 --- /dev/null +++ b/deepscreen/models/components/gcn.py @@ -0,0 +1,47 @@ +from torch import nn +from torch_geometric.nn import GCNConv, global_max_pool + + +class GCN(nn.Module): + """ + From `GraphDTA `_ (Nguyen et al., 2020), + based on `Graph Convolutional Network `_ (Kipf and Welling, 2017). + """ + def __init__( + self, + num_features: int, + out_channels: int, + dropout: float + ): + super().__init__() + + self.conv1 = GCNConv(num_features, num_features) + self.conv2 = GCNConv(num_features, num_features*2) + self.conv3 = GCNConv(num_features*2, num_features * 4) + self.fc_g1 = nn.Linear(num_features*4, 1024) + self.fc_g2 = nn.Linear(1024, out_channels) + self.relu = nn.ReLU() + self.dropout = nn.Dropout(dropout) + + def forward(self, data): + # get graph input + x, edge_index, batch = data.x, data.edge_index, data.batch + + x = self.conv1(x, edge_index) + x = self.relu(x) + + x = self.conv2(x, edge_index) + x = self.relu(x) + + x = self.conv3(x, edge_index) + x = self.relu(x) + x = global_max_pool(x, batch) # global max pooling + + # flatten + x = self.relu(self.fc_g1(x)) + x = self.dropout(x) + x = self.fc_g2(x) + x = self.dropout(x) + + return x + \ No newline at end of file diff --git a/deepscreen/models/components/gcn_attn.py b/deepscreen/models/components/gcn_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..b4e44b360b3a314a1e83412961b62d93f14ea1e7 --- /dev/null +++ b/deepscreen/models/components/gcn_attn.py @@ -0,0 +1,66 @@ +import torch +from torch import nn +import torch.nn.functional as F +from torch_geometric.nn import GCNConv, global_max_pool as gmp + + +class AttentionGCN(nn.Module): + """ + From `GraphDTA `_ (Nguyen et al., 2020), + based on `Graph Convolutional Network `_ (Kipf and Welling, 2017). 
+ """ + def __init__( + self, + num_features: int, + out_channels: int, + dropout: float + ): + super().__init__() + + self.conv1 = GCNConv(num_features, num_features) + self.conv2 = GCNConv(num_features, num_features*2) + self.conv3 = GCNConv(num_features*2, num_features * 4) + self.fc_g1 = nn.Linear(num_features*4, 1024) + self.fc_g2 = nn.Linear(1024, out_channels) + self.relu = nn.ReLU() + self.dropout = nn.Dropout(dropout) + + def forward(self, data): + # get graph input + x, edge_index, batch = data.x, data.edge_index, data.batch + + x = self.conv1(x, edge_index) + x = self.relu(x) + + x = self.conv2(x, edge_index) + x = self.relu(x) + + x = self.conv3(x, edge_index) + x = self.relu(x) + x = gmp(x, batch) # global max pooling + + # flatten + x = self.relu(self.fc_g1(x)) + x = self.dropout(x) + x = self.fc_g2(x) + x = self.dropout(x) + + return x + + +class Pocket_BCELoss(nn.Module): + def __init__(self): + super().__init__() + self.criterion = nn.BCELoss(reduce=False) + + def forward(self, pred, label, seq_mask): + loss_all = self.criterion(pred, label) + loss = torch.sum(torch.masked_select(loss_all, seq_mask)) + return loss + + def protein_pred_module(self, prot_feature, seq_mask): + protein_emb = nn.Linear(self.hidden_size1, self.hidden_size1) + p_feature = F.leaky_relu(protein_emb(prot_feature), 0.1) + pocket_pred = torch.sigmoid(torch.masked_select(p_feature, seq_mask)) + + return pocket_pred diff --git a/deepscreen/models/components/gin.py b/deepscreen/models/components/gin.py new file mode 100644 index 0000000000000000000000000000000000000000..16754a60295e37e28ebadcdc7b7afa1f9884e930 --- /dev/null +++ b/deepscreen/models/components/gin.py @@ -0,0 +1,63 @@ +from torch import cat, nn +import torch.nn.functional as F +from torch.nn import Sequential, Linear, ReLU +from torch_geometric.nn import GINConv, global_add_pool + + +class GIN(nn.Module): + r""" + From `GraphDTA `_ (Nguyen et al., 2020), + based on `Graph Isomorphism Network `_ (Xu et al., 2019) + """ + def __init__( + self, + num_features: int, + out_channels: int, + dropout: float + ): + super().__init__() + + dim = 32 + self.dropout = dropout + self.relu = nn.ReLU() + + nn1 = Sequential(Linear(num_features, dim), ReLU(), Linear(dim, dim)) + self.conv1 = GINConv(nn1) + self.bn1 = nn.BatchNorm1d(dim) + + nn2 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)) + self.conv2 = GINConv(nn2) + self.bn2 = nn.BatchNorm1d(dim) + + nn3 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)) + self.conv3 = GINConv(nn3) + self.bn3 = nn.BatchNorm1d(dim) + + nn4 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)) + self.conv4 = GINConv(nn4) + self.bn4 = nn.BatchNorm1d(dim) + + nn5 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)) + self.conv5 = GINConv(nn5) + self.bn5 = nn.BatchNorm1d(dim) + + self.fc1_xd = Linear(dim, out_channels) + + def forward(self, data): + x, edge_index, batch = data.x, data.edge_index, data.batch + + x = F.relu(self.conv1(x, edge_index)) + x = self.bn1(x) + x = F.relu(self.conv2(x, edge_index)) + x = self.bn2(x) + x = F.relu(self.conv3(x, edge_index)) + x = self.bn3(x) + x = F.relu(self.conv4(x, edge_index)) + x = self.bn4(x) + x = F.relu(self.conv5(x, edge_index)) + x = self.bn5(x) + x = global_add_pool(x, batch) + x = F.relu(self.fc1_xd(x)) + x = F.dropout(x, p=self.dropout, training=self.training) + + return x diff --git a/deepscreen/models/components/lstm.py b/deepscreen/models/components/lstm.py new file mode 100644 index 
--- /dev/null
+++ b/deepscreen/models/components/lstm.py
@@ -0,0 +1,40 @@
+from torch import nn, zeros, cat
+
+
+class LSTM(nn.Module):
+    def __init__(
+            self,
+            n_samples: int,
+            hidden_layers: int = 64):
+        super().__init__()
+        self.hidden_layers = hidden_layers
+        # lstm1, lstm2, linear are all layers in the network
+        self.lstm1 = nn.LSTMCell(1, self.hidden_layers)
+        self.lstm2 = nn.LSTMCell(self.hidden_layers, self.hidden_layers)
+        self.linear = nn.Linear(self.hidden_layers, 1)
+        self.n_samples = n_samples  # kept for config compatibility; the forward pass uses the actual batch size
+
+    def forward(self, y, future_preds=0):
+        outputs, num_samples = [], y.size(0)
+        # initial hidden and cell states, created on the same device as the
+        # input and sized to the actual batch (a fixed self.n_samples would
+        # break on a smaller final batch)
+        h_t = zeros(num_samples, self.hidden_layers, device=y.device)
+        c_t = zeros(num_samples, self.hidden_layers, device=y.device)
+        h_t2 = zeros(num_samples, self.hidden_layers, device=y.device)
+        c_t2 = zeros(num_samples, self.hidden_layers, device=y.device)
+
+        for input_t in y.split(1, dim=1):
+            # input_t: (N, 1); the loop variable was previously named
+            # `time_step` while the body referenced an undefined `input_t`
+            h_t, c_t = self.lstm1(input_t, (h_t, c_t))  # update hidden and cell states
+            h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))  # new hidden and cell states
+            output = self.linear(h_t2)  # output from the last FC layer
+            outputs.append(output)
+
+        for i in range(future_preds):
+            # this only generates future predictions if we pass in future_preds>0
+            # mirrors the code above, using last output/prediction as input
+            h_t, c_t = self.lstm1(output, (h_t, c_t))
+            h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
+            output = self.linear(h_t2)
+            outputs.append(output)
+        # transform list to tensor
+        outputs = cat(outputs, dim=1)
+        return outputs
diff --git a/deepscreen/models/components/mlp.py b/deepscreen/models/components/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea3dcb2b502104a75525763e418475f627866f65
--- /dev/null
+++ b/deepscreen/models/components/mlp.py
@@ -0,0 +1,77 @@
+from torch import nn, cat
+
+
+class MLP1(nn.Sequential):
+    def __init__(self,
+                 input_channels,
+                 hidden_channels: list[int],
+                 out_channels: int,
+                 activation: type[nn.Module] = nn.ReLU,
+                 dropout: float = 0.0):
+        layers = []
+        num_layers = len(hidden_channels) + 1
+        dims = [input_channels] + hidden_channels + [out_channels]
+        for i in range(num_layers):
+            if i != (num_layers - 1):
+                layers.append(nn.Linear(dims[i], dims[i+1]))
+                layers.append(nn.Dropout(dropout))
+                layers.append(activation())
+            else:
+                layers.append(nn.Linear(dims[i], dims[i+1]))
+
+        super().__init__(*layers)
+
+
+class MLP2(nn.Module):
+    # Subclasses nn.Module rather than nn.Sequential: this class keeps its own
+    # ModuleList and a custom forward, so the Sequential machinery was unused.
+    def __init__(self,
+                 input_channels,
+                 hidden_channels: list[int],
+                 out_channels: int,
+                 dropout: float = 0.0):
+        super().__init__()
+        self.dropout = nn.Dropout(dropout)
+        num_layers = len(hidden_channels) + 1
+        dims = [input_channels] + hidden_channels + [out_channels]
+        self.layers = nn.ModuleList([nn.Linear(dims[i], dims[i+1]) for i in range(num_layers)])
+
+    def forward(self, x):
+        for i, layer in enumerate(self.layers):
+            if i == (len(self.layers) - 1):
+                x = layer(x)
+            else:
+                x = nn.functional.relu(self.dropout(layer(x)))
+        return x
+
+
+class LazyMLP(nn.Sequential):
+    def __init__(
+            self,
+            out_channels: int,
+            hidden_channels: list[int],
+            activation: type[nn.Module] = nn.ReLU,
+            dropout: float = 0.0
+    ):
+        layers = []
+        for hidden_dim in hidden_channels:
+            layers.append(nn.LazyLinear(out_features=hidden_dim))
+            layers.append(nn.Dropout(dropout))
+            layers.append(activation())
+        layers.append(nn.LazyLinear(out_features=out_channels))
+
+        super().__init__(*layers)
+
+
+class ConcatMLP(LazyMLP):
+    def forward(self, *inputs):
+        x = cat([*inputs], 1)
+        x = super().forward(x)
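+        # (added note) inputs are concatenated along the feature dimension and
+        # passed through the lazily built MLP; the first LazyLinear infers its
+        # input width on the first call, so no explicit in_features is needed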
return x + + +# class ConcatMLP(MLP1): +# def forward(self, *inputs): +# x = cat([*inputs], 1) +# for module in self: +# x = module(x) +# return x + diff --git a/deepscreen/models/components/transformer.py b/deepscreen/models/components/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..e335794b9f093d612e3991660d093f6386ce2234 --- /dev/null +++ b/deepscreen/models/components/transformer.py @@ -0,0 +1,211 @@ +import copy +import math + +import torch +from torch import nn + + +class Transformer(nn.Module): + def __init__(self, + input_dim, + emb_size, + max_position_size, + dropout, + n_layer, + intermediate_size, + num_attention_heads, + attention_probs_dropout, + hidden_dropout, + ): + super().__init__() + self.emb = Embeddings(input_dim, + emb_size, + max_position_size, + dropout) + self.encoder = MultiLayeredEncoder(n_layer, + emb_size, + intermediate_size, + num_attention_heads, + attention_probs_dropout, + hidden_dropout) + + def forward(self, v): + e = v[0].long() + e_mask = v[1].long() + ex_e_mask = e_mask.unsqueeze(1).unsqueeze(2) + ex_e_mask = (1.0 - ex_e_mask) * -10000.0 + + emb = self.emb(e) + encoded_layers = self.encoder(emb.float(), ex_e_mask.float()) + return encoded_layers[:, 0] + + +class LayerNorm(nn.Module): + def __init__(self, hidden_size, variance_epsilon=1e-12): + super(LayerNorm, self).__init__() + self.gamma = nn.Parameter(torch.ones(hidden_size)) + self.beta = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = variance_epsilon + + def forward(self, x): + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + return self.gamma * x + self.beta + + +class Embeddings(nn.Module): + """Construct the embeddings from protein/target, position embeddings. 
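+
+    A note on shapes (added): token ids of shape ``(batch, seq_len)`` pass
+    through a word embedding and a learned position embedding of the same
+    hidden size; the two are summed, layer-normalized and passed through
+    dropout, giving ``(batch, seq_len, hidden_size)``.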
+ """ + + def __init__(self, vocab_size, hidden_size, max_position_size, dropout_rate): + super(Embeddings, self).__init__() + self.word_embeddings = nn.Embedding(vocab_size, hidden_size) + self.position_embeddings = nn.Embedding(max_position_size, hidden_size) + + self.LayerNorm = LayerNorm(hidden_size) + self.dropout = nn.Dropout(dropout_rate) + + def forward(self, input_ids): + seq_length = input_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + + embeddings = words_embeddings + position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class SelfAttention(nn.Module): + def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob): + super(SelfAttention, self).__init__() + if hidden_size % num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (hidden_size, num_attention_heads)) + self.num_attention_heads = num_attention_heads + self.attention_head_size = int(hidden_size / num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(hidden_size, self.all_head_size) + self.key = nn.Linear(hidden_size, self.all_head_size) + self.value = nn.Linear(hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer + + +class SelfOutput(nn.Module): + def __init__(self, hidden_size, hidden_dropout_prob): + super(SelfOutput, self).__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.LayerNorm = LayerNorm(hidden_size) + self.dropout = nn.Dropout(hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class Attention(nn.Module): + def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob): + super(Attention, self).__init__() + self.self = SelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob) + self.output = SelfOutput(hidden_size, hidden_dropout_prob) + + def forward(self, input_tensor, attention_mask): + self_output = self.self(input_tensor, attention_mask) + attention_output = self.output(self_output, input_tensor) + return attention_output + + +class Intermediate(nn.Module): + def __init__(self, hidden_size, intermediate_size): + super(Intermediate, self).__init__() + self.dense = nn.Linear(hidden_size, intermediate_size) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = nn.functional.relu(hidden_states) + return hidden_states + + +class Output(nn.Module): + def __init__(self, intermediate_size, hidden_size, hidden_dropout_prob): + super(Output, self).__init__() + self.dense = nn.Linear(intermediate_size, hidden_size) + self.LayerNorm = LayerNorm(hidden_size) + self.dropout = nn.Dropout(hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class Encoder(nn.Module): + def __init__(self, hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob, + hidden_dropout_prob): + super(Encoder, self).__init__() + self.attention = Attention(hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob) + self.intermediate = Intermediate(hidden_size, intermediate_size) + self.output = Output(intermediate_size, hidden_size, hidden_dropout_prob) + + def forward(self, hidden_states, attention_mask): + attention_output = self.attention(hidden_states, attention_mask) + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class MultiLayeredEncoder(nn.Module): + def __init__(self, n_layer, hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob, + hidden_dropout_prob): + super(MultiLayeredEncoder, self).__init__() + layer = Encoder(hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob, + hidden_dropout_prob) + self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layer)]) + + def forward(self, hidden_states, attention_mask): + for layer_module in self.layer: + hidden_states = layer_module(hidden_states, attention_mask) + return hidden_states diff --git 
a/deepscreen/models/dti.py b/deepscreen/models/dti.py
new file mode 100644
index 0000000000000000000000000000000000000000..86496e51d8bd20c7b13413c223b7d98189ffcdd4
--- /dev/null
+++ b/deepscreen/models/dti.py
@@ -0,0 +1,167 @@
+from functools import partial
+from typing import Optional, Sequence, Dict
+
+from torch import nn, optim, Tensor
+from lightning import LightningModule
+from torchmetrics import Metric, MetricCollection
+
+
+class DTILightningModule(LightningModule):
+    """
+    Drug-target interaction prediction.
+
+    optimizer: a partially or fully initialized instance of class torch.optim.Optimizer
+    scheduler: a partially initialized LR scheduler or a dict of scheduler options
+    predictor: a fully initialized instance of class torch.nn.Module
+    out: output head applied to the predictor's main output
+    loss: a fully initialized instance of a torch.nn.Module loss
+    activation: activation applied to the output head to produce predictions
+    metrics: a dict of fully initialized instances of class torchmetrics.Metric
+    """
+    def __init__(
+            self,
+            optimizer: optim.Optimizer,
+            scheduler: Optional[optim.lr_scheduler.LRScheduler | Dict],
+            predictor: nn.Module,
+            metrics: Optional[Dict[str, Metric]] = (),
+            out: nn.Module = None,
+            loss: nn.Module = None,
+            activation: nn.Module = None,
+    ):
+        super().__init__()
+
+        self.predictor = predictor
+        self.out = out
+        self.loss = loss
+        self.activation = activation
+
+        # Automatically averaged over batches:
+        # Separate metric instances for train, val and test step to ensure a proper reduction over the epoch
+        metrics = MetricCollection(dict(metrics))
+        self.train_metrics = metrics.clone(prefix="train/")
+        self.val_metrics = metrics.clone(prefix="val/")
+        self.test_metrics = metrics.clone(prefix="test/")
+
+        # allows access to init params with 'self.hparams' attribute and ensures init params will be stored in ckpt
+        self.save_hyperparameters(logger=False,
+                                  ignore=['predictor', 'out', 'loss', 'activation', 'metrics'])
+
+    def setup(self, stage):
+        match stage:
+            case 'fit':
+                # Run a dummy forward pass on the first training batch so that
+                # lazy modules (e.g., nn.LazyLinear) materialize their
+                # parameters before optimizers are configured; the other
+                # stages need no special setup.
+                dataloader = self.trainer.datamodule.train_dataloader()
+                dummy_batch = next(iter(dataloader))
+                self.forward(dummy_batch)
+
+    def forward(self, batch):
+        output = self.predictor(batch['X1^'], batch['X2^'])
+        target = batch.get('Y')
+        indexes = batch.get('ID^')
+        preds = None
+        loss = None
+
+        if isinstance(output, Tensor):
+            output = self.out(output).squeeze(1)
+            preds = self.activation(output)
+
+        elif isinstance(output, Sequence):
+            output = list(output)
+            # If multi-objective, assume the zeroth element in `output` is main while the rest are auxiliary
+            output[0] = self.out(output[0]).squeeze(1)
+            # Downstream metrics evaluation only needs main-objective preds
+            preds = self.activation(output[0])
+
+        if target is not None:
+            loss = self.loss(output, target.float())
+
+        return preds, target, indexes, loss
+
+    def training_step(self, batch, batch_idx):
+        preds, target, indexes, loss = self.forward(batch)
+        self.log('train/loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
+        self.train_metrics(preds=preds, target=target, indexes=indexes.long())
+        self.log_dict(self.train_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
+
+        return {
+            'N': batch['N'],
+            'ID1':
batch['ID1'], 'X1': batch['X1'], + 'ID2': batch['ID2'], 'X2': batch['X2'], + 'Y^': preds, 'Y': target, 'loss': loss + } + + def on_train_epoch_end(self): + pass + + def validation_step(self, batch, batch_idx): + preds, target, indexes, loss = self.forward(batch) + + self.log('val/loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True) + self.val_metrics(preds=preds, target=target, indexes=indexes.long()) + self.log_dict(self.val_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True) + + return { + 'N': batch['N'], + 'ID1': batch['ID1'], 'X1': batch['X1'], + 'ID2': batch['ID2'], 'X2': batch['X2'], + 'Y^': preds, 'Y': target, 'loss': loss + } + + def on_validation_epoch_end(self): + pass + + def test_step(self, batch, batch_idx): + preds, target, indexes, loss = self.forward(batch) + + self.log('test/loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True) + self.test_metrics(preds=preds, target=target, indexes=indexes.long()) + self.log_dict(self.test_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True) + + # return a dictionary for callbacks like BasePredictionWriter + return { + 'N': batch['N'], + 'ID1': batch['ID1'], 'X1': batch['X1'], + 'ID2': batch['ID2'], 'X2': batch['X2'], + 'Y^': preds, 'Y': target, 'loss': loss + } + + def on_test_epoch_end(self): + pass + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + preds, _, _, _ = self.forward(batch) + # return a dictionary for callbacks like BasePredictionWriter + return { + 'N': batch['N'], + 'ID1': batch['ID1'], 'X1': batch['X1'], + 'ID2': batch['ID2'], 'X2': batch['X2'], + 'Y^': preds + } + + def configure_optimizers(self): + optimizers_config = {'optimizer': self.hparams.optimizer(params=self.parameters())} + if self.hparams.get('scheduler'): + if isinstance(self.hparams.scheduler, partial): + optimizers_config['lr_scheduler'] = { + "scheduler": self.hparams.scheduler(optimizer=optimizers_config['optimizer']), + "monitor": "val/loss", + "interval": "epoch", + "frequency": 1, + } + else: + self.hparams.scheduler['scheduler'] = self.hparams.scheduler['scheduler']( + optimizer=optimizers_config['optimizer'] + ) + optimizers_config['lr_scheduler'] = dict(self.hparams.scheduler) + return optimizers_config diff --git a/deepscreen/models/loss/__init__.py b/deepscreen/models/loss/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/deepscreen/models/loss/__pycache__/__init__.cpython-311.pyc b/deepscreen/models/loss/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..714800206038b4a3e2694157e3758817ee237311 Binary files /dev/null and b/deepscreen/models/loss/__pycache__/__init__.cpython-311.pyc differ diff --git a/deepscreen/models/loss/__pycache__/multitask_loss.cpython-311.pyc b/deepscreen/models/loss/__pycache__/multitask_loss.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b4bac6d83f54ab4d78c5b1c33aaae6dda906cc7 Binary files /dev/null and b/deepscreen/models/loss/__pycache__/multitask_loss.cpython-311.pyc differ diff --git a/deepscreen/models/loss/multitask_loss.py b/deepscreen/models/loss/multitask_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..2ef284e7c0758f71f88844838b47924050bb99f5 --- /dev/null +++ b/deepscreen/models/loss/multitask_loss.py @@ -0,0 +1,118 @@ +from itertools import zip_longest + +import torch + + +class MultitaskLoss(torch.nn.Module): + 
"""A generic multitask loss class that takes a tuple of loss functions as input""" + def __init__(self, loss_fns, reduction='sum'): + super().__init__() + self.n_tasks = len(loss_fns) # assuming the number of tasks is equal to the number of loss functions + self.loss_fns = loss_fns # store the tuple of loss functions + self.reduction = reduction + + def forward(self, preds, target): + if isinstance(preds, torch.Tensor): + preds = (preds,) + if isinstance(target, torch.Tensor): + target = (target,) + # compute the weighted losses for each task by applying the corresponding loss function and weight + # losses = [weight * loss_fn(p, t) + # for weight, loss_fn, p, t in zip_longest(self.weights, self.loss_fns, preds, target)] + losses = [] + for loss_fn, p, t in zip_longest(self.loss_fns, preds, target): + if t is not None: + loss = loss_fn(p, t) + else: + loss = loss_fn(p) + losses.append(loss) + + reduced_loss = None + # apply reduction if specified + if self.reduction == 'sum': + reduced_loss = sum(losses) + elif self.reduction == 'mean': + reduced_loss = sum(losses) / self.n_tasks + # return the tuple of losses or the reduced value + return reduced_loss + + +class MultitaskWeightedLoss(MultitaskLoss): + """A multitask loss class that takes a tuple of loss functions and weights as input""" + + def __init__(self, loss_fns, weights, reduction='sum'): + super().__init__(loss_fns, reduction) + self.weights = weights # store the tuple of weights + + def forward(self, preds, target): + if isinstance(preds, torch.Tensor): + preds = (preds,) + if isinstance(target, torch.Tensor): + target = (target,) + # compute the weighted losses for each task by applying the corresponding loss function and weight + # losses = [weight * loss_fn(p, t) + # for weight, loss_fn, p, t in zip_longest(self.weights, self.loss_fns, preds, target)] + losses = [] + for weight, loss_fn, p, t in zip_longest(self.weights, self.loss_fns, preds, target): + if t is not None: + loss = weight * loss_fn(p, t) + else: + loss = weight * loss_fn(p) + losses.append(loss) + + reduced_loss = None + # apply reduction if specified + if self.reduction == 'sum': + reduced_loss = sum(losses) + elif self.reduction == 'mean': + reduced_loss = sum(losses) / self.n_tasks + # return the tuple of losses or the reduced value + return reduced_loss + + +class MultitaskUncertaintyLoss(MultitaskLoss): + """ + Modified from https://arxiv.org/abs/1705.07115. + Removed task-specific scale factor for flexibility. 
+ """ + + def __init__(self, loss_fns): + # for loss_fn in loss_fns: + # loss_fn.reduction = 'none' + super().__init__(loss_fns, reduction='none') + self.log_vars = torch.nn.Parameter(torch.zeros(self.n_tasks, requires_grad=True)) + + def forward(self, preds, targets, rescale=True): + losses = super().forward(preds, targets) + stds = torch.exp(self.log_vars / 2) + coeffs = 1 / (stds ** 2) + loss = coeffs * losses + torch.log(stds) + + return loss + + +class MultitaskAutomaticWeightedLoss(MultitaskLoss): + """Automatically weighted multitask loss + + Params: + loss_fns: tuple of loss functions + num: int, the number of losses + x: multitask loss + Examples: + loss1 = 1 + loss2 = 2 + awl = AutomaticWeightedLoss(2) + loss_sum = awl(loss1, loss2) + """ + + def __init__(self, loss_fns): + super().__init__(loss_fns, reduction='none') + self.params = torch.nn.Parameter(torch.ones(self.n_tasks, requires_grad=True)) + + def forward(self, preds, target): + losses = super().forward(preds, target) + loss = sum( + 0.5 / (param ** 2) * loss + torch.log(1 + param ** 2) + for param, loss in zip(self.params, losses) + ) + return loss diff --git a/deepscreen/models/metrics/__init__.py b/deepscreen/models/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/deepscreen/models/metrics/bedroc.py b/deepscreen/models/metrics/bedroc.py new file mode 100644 index 0000000000000000000000000000000000000000..07b4fbc61b778407834f1df7b923a815562867a2 --- /dev/null +++ b/deepscreen/models/metrics/bedroc.py @@ -0,0 +1,45 @@ +import torch +from torch import Tensor +from torchmetrics.retrieval.base import RetrievalMetric +from torchmetrics.utilities.checks import _check_retrieval_functional_inputs + +from deepscreen.models.metrics.rie import calc_rie + + +class BEDROC(RetrievalMetric): + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False + + def __init__( + self, + alpha: float = 80.5, + ): + super().__init__() + self.alpha = alpha + + def _metric(self, preds: Tensor, target: Tensor) -> Tensor: + preds, target = _check_retrieval_functional_inputs(preds, target) + + n_total = target.size(0) + n_actives = target.sum() + + if n_actives == 0: + return torch.tensor(0.0, device=preds.device) + elif n_actives == n_total: + return torch.tensor(1.0, device=preds.device) + + r_a = n_actives / n_total + exp_a = torch.exp(torch.tensor(self.alpha)) + + idx = torch.argsort(preds, descending=True, stable=True) + active_ranks = torch.take(target, idx).nonzero() + 1 + + rie = calc_rie(n_total, active_ranks, r_a, exp_a) + rie_min = (1 - exp_a ** r_a) / (r_a * (1 - exp_a)) + rie_max = (1 - exp_a ** (-r_a)) / (r_a * (1 - exp_a ** (-1))) + + return (rie - rie_min) / (rie_max - rie_min) + + def plot(self, val=None, ax=None): + return self._plot(val, ax) diff --git a/deepscreen/models/metrics/ci.py b/deepscreen/models/metrics/ci.py new file mode 100644 index 0000000000000000000000000000000000000000..ce274d8380bd437fc9ea8397b64d8b93d4574650 --- /dev/null +++ b/deepscreen/models/metrics/ci.py @@ -0,0 +1,39 @@ +import torch +from torchmetrics import Metric +from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["ConcordanceIndex.plot"] + + +class ConcordanceIndex(Metric): + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False + plot_lower_bound: 
+    plot_upper_bound: float = 1.0
+
+    def __init__(self, dist_sync_on_step=False):
+        super().__init__(dist_sync_on_step=dist_sync_on_step)
+
+        # num_concordant accumulates 0.5 per tied prediction, so it must be a
+        # float state (the earlier cast to .long() silently dropped ties)
+        self.add_state("num_concordant", default=torch.tensor(0.0), dist_reduce_fx="sum")
+        self.add_state("num_valid", default=torch.tensor(0), dist_reduce_fx="sum")
+
+    def update(self, preds: torch.Tensor, target: torch.Tensor):
+        _check_same_shape(preds, target)
+
+        g = preds.unsqueeze(-1) - preds
+        g = (g == 0) * 0.5 + (g > 0)
+
+        f = (target.unsqueeze(-1) - target) > 0
+        f = torch.tril(f, diagonal=0)
+
+        self.num_concordant += torch.sum(torch.mul(g, f))
+        self.num_valid += torch.sum(f).long()
+
+    def compute(self):
+        return torch.where(self.num_valid == 0, 0.0, self.num_concordant / self.num_valid)
+
+    def plot(self, val=None, ax=None):
+        return self._plot(val, ax)
diff --git a/deepscreen/models/metrics/ef.py b/deepscreen/models/metrics/ef.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5e6dffd4e31bd1dfdd42b84bddd293d659fa360
--- /dev/null
+++ b/deepscreen/models/metrics/ef.py
@@ -0,0 +1,34 @@
+import math
+
+from torch import Tensor, topk
+from torchmetrics.retrieval.base import RetrievalMetric
+from torchmetrics.utilities.checks import _check_retrieval_functional_inputs
+
+
+class EnrichmentFactor(RetrievalMetric):
+    is_differentiable: bool = False
+    higher_is_better: bool = True
+    full_state_update: bool = False
+
+    def __init__(
+            self,
+            alpha: float,
+    ):
+        super().__init__()
+        if alpha <= 0 or alpha > 1:
+            raise ValueError(f"Argument ``alpha`` has to be in interval (0, 1] but got {alpha}")
+        self.alpha = alpha
+
+    def _metric(self, preds: Tensor, target: Tensor) -> Tensor:
+        preds, target = _check_retrieval_functional_inputs(preds, target)
+
+        n_total = target.size(0)
+        n_sampled = math.ceil(n_total * self.alpha)
+        _, idx = topk(preds, n_sampled)
+        hits_sampled = target[idx].sum()
+        hits_total = target.sum()
+
+        # EF_alpha = (hits_sampled / n_sampled) / (hits_total / n_total); the
+        # denominator here uses alpha * n_total in place of
+        # n_sampled = ceil(alpha * n_total), a common approximation
+        return hits_sampled / (hits_total * self.alpha)
+
+    def plot(self, val=None, ax=None):
+        return self._plot(val, ax)
diff --git a/deepscreen/models/metrics/hit_rate.py b/deepscreen/models/metrics/hit_rate.py
new file mode 100644
index 0000000000000000000000000000000000000000..942e2ee67be73f969ca2ee15978162451f632be0
--- /dev/null
+++ b/deepscreen/models/metrics/hit_rate.py
@@ -0,0 +1,36 @@
+import math
+
+from torch import Tensor, topk
+from torchmetrics.retrieval.base import RetrievalMetric
+from torchmetrics.utilities.checks import _check_retrieval_functional_inputs
+
+
+class HitRate(RetrievalMetric):
+    """
+    Computes hit rate for virtual screening.
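+
+    Hit rate here is the fraction of actives among the top ``ceil(alpha * N)``
+    compounds ranked by predicted score (hits_sampled / n_sampled below).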
+ """ + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False + + def __init__( + self, + alpha: float = 0.01, + ): + super().__init__() + if alpha <= 0 or alpha > 1: + raise ValueError(f"Argument ``alpha`` has to be in interval (0, 1] but got {alpha}") + self.alpha = alpha + + def _metric(self, preds: Tensor, target: Tensor) -> Tensor: + preds, target = _check_retrieval_functional_inputs(preds, target) + + n_total = target.size(0) + n_sampled = math.ceil(n_total * self.alpha) + _, idx = topk(preds, n_sampled) + hits_sampled = target[idx].sum() + + return hits_sampled / n_sampled + + def plot(self, val=None, ax=None): + return self._plot(val, ax) \ No newline at end of file diff --git a/deepscreen/models/metrics/rie.py b/deepscreen/models/metrics/rie.py new file mode 100644 index 0000000000000000000000000000000000000000..4eeb0c48d7c4a470a9b57a1e58017b0d11aaf35f --- /dev/null +++ b/deepscreen/models/metrics/rie.py @@ -0,0 +1,44 @@ +import torch +from torch import Tensor +from torchmetrics.retrieval.base import RetrievalMetric +from torchmetrics.utilities.checks import _check_retrieval_functional_inputs + + +def calc_rie(n_total, active_ranks, r_a, exp_a): + numerator = (exp_a ** (- active_ranks / n_total)).sum() + denominator = (1 - exp_a ** (-1)) / (exp_a ** (1 / n_total) - 1) + + return numerator / (r_a * denominator) + + +class RIE(RetrievalMetric): + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False + + def __init__( + self, + alpha: float = 80.5, + ): + super().__init__() + self.alpha = alpha + + def _metric(self, preds: Tensor, target: Tensor) -> Tensor: + preds, target = _check_retrieval_functional_inputs(preds, target) + + n_total = target.size(0) + n_actives = target.sum() + + if n_actives == 0: + return torch.tensor(0.0, device=preds.device) + + r_a = n_actives / n_total + exp_a = torch.exp(torch.tensor(-self.alpha)) + + idx = torch.argsort(preds, descending=True, stable=True) + active_ranks = torch.take(target, idx).nonzero() + 1 + + return calc_rie(n_total, active_ranks, r_a, exp_a) + + def plot(self, val=None, ax=None): + return self._plot(val, ax) diff --git a/deepscreen/models/metrics/sensitivity.py b/deepscreen/models/metrics/sensitivity.py new file mode 100644 index 0000000000000000000000000000000000000000..50797e3a3eae55f7ad3e4833fef6e0f132799479 --- /dev/null +++ b/deepscreen/models/metrics/sensitivity.py @@ -0,0 +1,337 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Any, Optional, Sequence, Union
+
+from torch import Tensor
+from torchmetrics.utilities.compute import _safe_divide, _adjust_weights_safe_divide
+from typing_extensions import Literal
+
+from torchmetrics.classification.base import _ClassificationTaskWrapper
+from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores
+from torchmetrics.metric import Metric
+from torchmetrics.utilities.enums import ClassificationTask
+from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
+from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
+
+if not _MATPLOTLIB_AVAILABLE:
+    __doctest_skip__ = ["BinarySensitivity.plot", "MulticlassSensitivity.plot", "MultilabelSensitivity.plot"]
+
+
+class BinarySensitivity(BinaryStatScores):
+    r"""Compute `Sensitivity`_ for binary tasks.
+
+    .. math:: \text{Sensitivity} = \frac{\text{TP}}{\text{TP} + \text{FN}}
+
+    Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and false negatives
+    respectively. The metric is only properly defined when :math:`\text{TP} + \text{FN} \neq 0`. If this case is
+    encountered a score of 0 is returned.
+
+    As input to ``forward`` and ``update`` the metric accepts the following input:
+
+    - ``preds`` (:class:`~torch.Tensor`): An int or float tensor of shape ``(N, ...)``. If preds is a floating point
+      tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per
+      element. Additionally, we convert to int tensor with thresholding using the value in ``threshold``.
+    - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``
+
+    As output to ``forward`` and ``compute`` the metric returns the following output:
+
+    - ``bs`` (:class:`~torch.Tensor`): If ``multidim_average`` is set to ``global``, the metric returns a scalar value.
+      If ``multidim_average`` is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value
+      per sample.
+
+    Args:
+        threshold: Threshold for transforming probability to binary {0,1} predictions
+        multidim_average:
+            Defines how additionally dimensions ``...`` should be handled. Should be one of the following:
+
+            - ``global``: Additional dimensions are flatted along the batch dimension
+            - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
+              The statistics in this case are calculated over the additional dimensions.
+
+        ignore_index:
+            Specifies a target value that is ignored and does not contribute to the metric calculation
+        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
+            Set to ``False`` for faster computations.
+    """
+    plot_lower_bound: float = 0.0
+    plot_upper_bound: float = 1.0
+
+    def compute(self) -> Tensor:
+        """Compute metric."""
+        tp, fp, tn, fn = self._final_state()
+        return _sensitivity_reduce(tp, fp, tn, fn, average="binary", multidim_average=self.multidim_average)
+
+    def plot(
+        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
+    ) -> _PLOT_OUT_TYPE:
+        """Plot a single or multiple values from the metric.
+
+        Args:
+            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
+                If no value is provided, will automatically call `metric.compute` and plot that result.
+            ax: A matplotlib axis object. If provided, will add the plot to that axis.
+
+        Returns:
+            Figure object and Axes object
+
+        Raises:
+            ModuleNotFoundError:
+                If `matplotlib` is not installed
+        """
+        return self._plot(val, ax)
+
+
+class MulticlassSensitivity(MulticlassStatScores):
+    r"""Compute `Sensitivity`_ for multiclass tasks.
+
+    .. math:: \text{Sensitivity} = \frac{\text{TP}}{\text{TP} + \text{FN}}
+
+    Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and false negatives
+    respectively. The metric is only properly defined when :math:`\text{TP} + \text{FN} \neq 0`. If this case is
+    encountered for any class, the metric for that class will be set to 0 and the overall metric may therefore be
+    affected in turn.
+
+    As input to ``forward`` and ``update`` the metric accepts the following input:
+
+    - ``preds`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)`` or float tensor of shape ``(N, C, ...)``.
+      If preds is a floating point tensor we apply ``torch.argmax`` along the ``C`` dimension to automatically convert
+      probabilities/logits into an int tensor.
+    - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``
+
+    As output to ``forward`` and ``compute`` the metric returns the following output:
+
+    - ``mcs`` (:class:`~torch.Tensor`): The returned shape depends on the ``average`` and ``multidim_average``
+      arguments:
+
+        - If ``multidim_average`` is set to ``global``:
+
+          - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor
+          - If ``average=None/'none'``, the shape will be ``(C,)``
+
+        - If ``multidim_average`` is set to ``samplewise``:
+
+          - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
+          - If ``average=None/'none'``, the shape will be ``(N, C)``
+
+    Args:
+        num_classes: Integer specifying the number of classes
+        average:
+            Defines the reduction that is applied over labels. Should be one of the following:
+
+            - ``micro``: Sum statistics over all labels
+            - ``macro``: Calculate statistics for each label and average them
+            - ``weighted``: calculates statistics for each label and computes weighted average using their support
+            - ``"none"`` or ``None``: calculates statistic for each label and applies no reduction
+
+        top_k:
+            Number of highest probability or logit score predictions considered to find the correct label.
+            Only works when ``preds`` contain probabilities/logits.
+        multidim_average:
+            Defines how additionally dimensions ``...`` should be handled. Should be one of the following:
+
+            - ``global``: Additional dimensions are flatted along the batch dimension
+            - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
+              The statistics in this case are calculated over the additional dimensions.
+
+        ignore_index:
+            Specifies a target value that is ignored and does not contribute to the metric calculation
+        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
+            Set to ``False`` for faster computations.
+    """
+    plot_lower_bound: float = 0.0
+    plot_upper_bound: float = 1.0
+    plot_legend_name: str = "Class"
+
+    def compute(self) -> Tensor:
+        """Compute metric."""
+        tp, fp, tn, fn = self._final_state()
+        return _sensitivity_reduce(tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average)
+
+    def plot(
+        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
+    ) -> _PLOT_OUT_TYPE:
+        """Plot a single or multiple values from the metric.
+
+        Args:
+            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
+                If no value is provided, will automatically call `metric.compute` and plot that result.
+            ax: A matplotlib axis object. If provided, will add the plot to that axis.
+
+        Returns:
+            Figure object and Axes object
+
+        Raises:
+            ModuleNotFoundError:
+                If `matplotlib` is not installed
+        """
+        return self._plot(val, ax)
+
+
+class MultilabelSensitivity(MultilabelStatScores):
+    r"""Compute `Sensitivity`_ for multilabel tasks.
+
+    .. math:: \text{Sensitivity} = \frac{\text{TP}}{\text{TP} + \text{FN}}
+
+    Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and false negatives
+    respectively. The metric is only properly defined when :math:`\text{TP} + \text{FN} \neq 0`. If this case is
+    encountered for any label, the metric for that label will be set to 0 and the overall metric may therefore be
+    affected in turn.
+
+    As input to ``forward`` and ``update`` the metric accepts the following input:
+
+    - ``preds`` (:class:`~torch.Tensor`): An int or float tensor of shape ``(N, C, ...)``. If preds is a floating
+      point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid
+      per element. Additionally, we convert to int tensor with thresholding using the value in ``threshold``.
+    - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)``
+
+    As output to ``forward`` and ``compute`` the metric returns the following output:
+
+    - ``mls`` (:class:`~torch.Tensor`): The returned shape depends on the ``average`` and ``multidim_average``
+      arguments:
+
+        - If ``multidim_average`` is set to ``global``
+
+          - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor
+          - If ``average=None/'none'``, the shape will be ``(C,)``
+
+        - If ``multidim_average`` is set to ``samplewise``
+
+          - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
+          - If ``average=None/'none'``, the shape will be ``(N, C)``
+
+    Args:
+        num_labels: Integer specifying the number of labels
+        threshold: Threshold for transforming probability to binary (0,1) predictions
+        average:
+            Defines the reduction that is applied over labels. Should be one of the following:
+
+            - ``micro``: Sum statistics over all labels
+            - ``macro``: Calculate statistics for each label and average them
+            - ``weighted``: calculates statistics for each label and computes weighted average using their support
+            - ``"none"`` or ``None``: calculates statistic for each label and applies no reduction
+
+        multidim_average: Defines how additionally dimensions ``...`` should be handled. Should be one of the following:
+
+            - ``global``: Additional dimensions are flatted along the batch dimension
+            - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
+              The statistics in this case are calculated over the additional dimensions.
+
+        ignore_index:
+            Specifies a target value that is ignored and does not contribute to the metric calculation
+        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
+            Set to ``False`` for faster computations.
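+
+    Example (added; illustrative and hand-computed for this adaptation):
+        >>> from torch import tensor
+        >>> preds = tensor([[0.75, 0.05], [0.05, 0.75], [0.05, 0.05], [0.75, 0.05]])
+        >>> target = tensor([[1, 0], [0, 1], [0, 0], [1, 1]])
+        >>> metric = MultilabelSensitivity(num_labels=2)
+        >>> metric(preds, target)
+        tensor(0.7500)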
+ """ + plot_lower_bound: float = 0.0 + plot_upper_bound: float = 1.0 + plot_legend_name: str = "Label" + + def compute(self) -> Tensor: + """Compute metric.""" + tp, fp, tn, fn = self._final_state() + return _sensitivity_reduce( + tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, multilabel=True + ) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure object and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + """ + return self._plot(val, ax) + + +class Sensitivity(_ClassificationTaskWrapper): + r"""Compute `Sensitivity`_. + + .. math:: \text{Sensitivity} = \frac{\text{TN}}{\text{TN} + \text{FP}} + + Where :math:`\text{TN}` and :math:`\text{FP}` represent the number of true negatives and false positives + respectively. The metric is only proper defined when :math:`\text{TP} + \text{FP} \neq 0`. If this case is + encountered for any class/label, the metric for that class/label will be set to 0 and the overall metric may + therefore be affected in turn. + + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``multilabel``. See the documentation of + :class:`~torchmetrics.classification.BinarySensitivity`, :class:`~torchmetrics.classification.MulticlassSensitivity` + and :class:`~torchmetrics.classification.MultilabelSensitivity` for the specific details of each argument influence + and examples. 
+    """
+
+    def __new__(  # type: ignore[misc]
+        cls,
+        task: Literal["binary", "multiclass", "multilabel"],
+        threshold: float = 0.5,
+        num_classes: Optional[int] = None,
+        num_labels: Optional[int] = None,
+        average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
+        multidim_average: Optional[Literal["global", "samplewise"]] = "global",
+        top_k: Optional[int] = 1,
+        ignore_index: Optional[int] = None,
+        validate_args: bool = True,
+        **kwargs: Any,
+    ) -> Metric:
+        """Initialize task metric."""
+        task = ClassificationTask.from_str(task)
+        assert multidim_average is not None  # noqa: S101  # needed for mypy
+        kwargs.update(
+            {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args}
+        )
+        if task == ClassificationTask.BINARY:
+            return BinarySensitivity(threshold, **kwargs)
+        if task == ClassificationTask.MULTICLASS:
+            if not isinstance(num_classes, int):
+                raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)}` was passed.")
+            if not isinstance(top_k, int):
+                raise ValueError(f"`top_k` is expected to be `int` but `{type(top_k)}` was passed.")
+            return MulticlassSensitivity(num_classes, top_k, average, **kwargs)
+        if task == ClassificationTask.MULTILABEL:
+            if not isinstance(num_labels, int):
+                raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)}` was passed.")
+            return MultilabelSensitivity(num_labels, threshold, average, **kwargs)
+        raise ValueError(f"Task {task} not supported!")
+
+
+def _sensitivity_reduce(
+    tp: Tensor,
+    fp: Tensor,
+    tn: Tensor,
+    fn: Tensor,
+    average: Optional[Literal["binary", "micro", "macro", "weighted", "none"]],
+    multidim_average: Literal["global", "samplewise"] = "global",
+    multilabel: bool = False,
+) -> Tensor:
+    # sensitivity (= recall) is TP / (TP + FN)
+    if average == "binary":
+        return _safe_divide(tp, tp + fn)
+    if average == "micro":
+        tp = tp.sum(dim=0 if multidim_average == "global" else 1)
+        fn = fn.sum(dim=0 if multidim_average == "global" else 1)
+        return _safe_divide(tp, tp + fn)
+
+    sensitivity_score = _safe_divide(tp, tp + fn)
+    return _adjust_weights_safe_divide(sensitivity_score, average, multilabel, tp, fp, fn)
diff --git a/deepscreen/models/predictors/__init__.py b/deepscreen/models/predictors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/deepscreen/models/predictors/__pycache__/__init__.cpython-311.pyc b/deepscreen/models/predictors/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dea185f266123f5e3a864b8a5afcd28cda1a5461
Binary files /dev/null and b/deepscreen/models/predictors/__pycache__/__init__.cpython-311.pyc differ
diff --git a/deepscreen/models/predictors/__pycache__/deep_dta.cpython-311.pyc b/deepscreen/models/predictors/__pycache__/deep_dta.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..92668cd7a4ec3ef9e064cedef83991555658c034
Binary files /dev/null and b/deepscreen/models/predictors/__pycache__/deep_dta.cpython-311.pyc differ
diff --git a/deepscreen/models/predictors/__pycache__/drug_vqa.cpython-311.pyc b/deepscreen/models/predictors/__pycache__/drug_vqa.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4bf397d6219cf9ae3a2c4397a539f58af649716b
Binary files /dev/null and b/deepscreen/models/predictors/__pycache__/drug_vqa.cpython-311.pyc differ
diff --git a/deepscreen/models/predictors/deep_conv_dti.py
b/deepscreen/models/predictors/deep_conv_dti.py new file mode 100644 index 0000000000000000000000000000000000000000..cbc150e180a21ed0a515878cfddf27b6f2ddd106 --- /dev/null +++ b/deepscreen/models/predictors/deep_conv_dti.py @@ -0,0 +1,95 @@ +import torch +import torch.nn as nn + + +class DeepConvDTI(nn.Module): + def __init__(self, dropout=0.2, drug_layers=(1024, 512), protein_windows=(10, 15, 20, 25), n_filters=64, + decay=0.0, fc_layers=None, convolution=True, activation=nn.ReLU(), protein_layers=None): + super().__init__() + self.dropout = dropout + self.drug_layers = drug_layers + self.protein_windows = protein_windows + self.filters = n_filters + self.decay = decay + self.fc_layers = fc_layers + self.convolution = convolution + self.activation = activation # Use any nn.Module as the activation function + self.protein_layers = protein_layers + + # Define the drug branch of the model + self.drug_branch = [] + for layer_size in drug_layers: + self.drug_branch += [ + nn.LazyLinear(layer_size), + nn.BatchNorm1d(layer_size), + activation, + nn.Dropout(dropout) + ] + self.drug_branch = nn.Sequential(*self.drug_branch) + + # Define the protein branch of the model + if convolution: + # Use embedding and convolution layers for protein sequences + self.protein_embedding = nn.Embedding(26, 20) + # Use a list of parallel convolution and pooling layers with different window sizes + self.protein_convs = nn.ModuleList() + for window_size in protein_windows: + conv = nn.Sequential( + nn.Conv1d(20, n_filters, window_size, padding="same"), + nn.BatchNorm1d(n_filters), + activation, + nn.AdaptiveMaxPool1d(1) + ) + self.protein_convs.append(conv) + + if protein_layers: + self.protein_branch = [] + for layer_size in protein_layers: + self.protein_branch += [ + nn.LazyLinear(layer_size), + nn.BatchNorm1d(layer_size), + activation, + nn.Dropout(dropout) + ] + self.protein_branch = nn.Sequential(*self.protein_branch) + + # Define the final branch of the model that combines the drug and protein branches + self.final_branch = [] + if fc_layers: + # Add additional dense layers for the final branch + for layer_size in fc_layers: + self.final_branch += [ + nn.LazyLinear(layer_size), + nn.BatchNorm1d(layer_size), + activation + ] + self.final_branch = nn.Sequential(*self.final_branch) + + def forward(self, input_d, input_p): + # Forward pass of the drug branch + output_d = self.drug_branch(input_d.float()) + + # Forward pass of the protein branch + if self.convolution: + # Embed the protein sequence and transpose the dimensions + output_p = self.protein_embedding(input_p) + output_p = output_p.transpose(1, 2) + # Apply the parallel convolution and pooling layers + conv_outputs = [] + for conv in self.protein_convs: + conv_output = conv(output_p).squeeze(-1) + conv_outputs.append(conv_output) + # Concatenate the convolution outputs + output_p = torch.cat(conv_outputs, dim=1) + else: + output_p = input_p + + if self.protein_layers: + # Apply the additional dense layers to the protein branch + output_p = self.protein_branch(output_p) + + # Concatenate the drug and protein outputs + output_t = torch.cat([output_d, output_p], dim=1) + # Apply the final dense layers + output_t = self.final_branch(output_t) + return output_t diff --git a/deepscreen/models/predictors/deep_dta.py b/deepscreen/models/predictors/deep_dta.py new file mode 100644 index 0000000000000000000000000000000000000000..ecc663ec3cdf3d18b01269ee161fbe32e6374391 --- /dev/null +++ b/deepscreen/models/predictors/deep_dta.py @@ -0,0 +1,39 @@ +import torch 
+import torch.nn as nn + +# TODO this is an easy model; refactor it to be customized by config file only + + +class DeepDTA(nn.Module): + """ + From DeepDTA + """ + def __init__( + self, + drug_cnn: nn.Module, + protein_cnn: nn.Module, + num_features_drug: int, + num_features_protein: int, + embed_dim: int, + ): + super().__init__() + self.drug_cnn = drug_cnn + self.protein_cnn = protein_cnn + self.fc = nn.Sequential(nn.LazyLinear(1024), nn.ReLU(), nn.Dropout(0.1), + nn.Linear(1024, 512), nn.ReLU(), nn.Dropout(0.1)) + + # protein sequence encoder (1d conv) + self.drug_embedding = nn.Embedding(num_features_drug, embed_dim) + self.protein_embedding = nn.Embedding(num_features_protein, embed_dim) + + def forward(self, v_d, v_p): + v_d = self.drug_embedding(v_d.long()) + v_d = self.drug_cnn(v_d) + + v_p = self.protein_embedding(v_p.long()) + v_p = self.protein_cnn(v_p) + + v_f = torch.cat([v_d, v_p], 1) + v_f = self.fc(v_f) + + return v_f diff --git a/deepscreen/models/predictors/drug_ban.py b/deepscreen/models/predictors/drug_ban.py new file mode 100644 index 0000000000000000000000000000000000000000..3d639940752f06b1922b3004fd5c4e0732dfcc44 --- /dev/null +++ b/deepscreen/models/predictors/drug_ban.py @@ -0,0 +1,328 @@ +import torch.nn as nn +import torch.nn.functional as F +import torch +import math +from dgllife.model.gnn import GCN +from torch.nn.utils.weight_norm import weight_norm + + +class DrugBAN(nn.Module): + def __init__( + self, + drug_in_feats, + drug_embedding, + drug_hidden_feats, + protein_emb_dim, + num_filters, + kernel_size, + mlp_in_dim, + mlp_hidden_dim, + mlp_out_dim, + drug_padding, + protein_padding, + ban_heads, + ): + super().__init__() + self.drug_extractor = MolecularGCN(in_feats=drug_in_feats, dim_embedding=drug_embedding, + padding=drug_padding, + hidden_feats=drug_hidden_feats) + self.protein_extractor = ProteinCNN(protein_emb_dim, num_filters, kernel_size, protein_padding) + + self.bcn = weight_norm( + BANLayer(v_dim=drug_hidden_feats[-1], q_dim=num_filters[-1], h_dim=mlp_in_dim, h_out=ban_heads), + name='h_mat', dim=None) + self.mlp_classifier = MLPDecoder(mlp_in_dim, mlp_hidden_dim, mlp_out_dim) + + def forward(self, bg_d, v_p): + v_d = self.drug_extractor(bg_d) + v_p = self.protein_extractor(v_p) + f, att = self.bcn(v_d, v_p) + score = self.mlp_classifier(f) + # if mode == "train": + # return v_d, v_p, f, score + # elif mode == "eval": + # return v_d, v_p, score, att + return score + + +class MolecularGCN(nn.Module): + def __init__(self, in_feats, dim_embedding=128, padding=True, hidden_feats=None, activation=None): + super().__init__() + self.init_transform = nn.Linear(in_feats, dim_embedding, bias=False) + if padding: + with torch.no_grad(): + self.init_transform.weight[-1].fill_(0) + self.gnn = GCN(in_feats=dim_embedding, hidden_feats=hidden_feats, activation=activation) + self.output_feats = hidden_feats[-1] + + def forward(self, batch_graph): + node_feats = batch_graph.ndata.pop('h') + node_feats = self.init_transform(node_feats) + node_feats = self.gnn(batch_graph, node_feats) + batch_size = batch_graph.batch_size + node_feats = node_feats.view(batch_size, -1, self.output_feats) + return node_feats + + +class ProteinCNN(nn.Module): + def __init__(self, embedding_dim, num_filters, kernel_size, padding=True): + super().__init__() + if padding: + self.embedding = nn.Embedding(26, embedding_dim, padding_idx=0) + else: + self.embedding = nn.Embedding(26, embedding_dim) + in_ch = [embedding_dim] + num_filters + self.in_ch = in_ch[-1] + kernels = kernel_size + 
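+        # (added note) three stacked un-padded 1-D convolutions below
+        # progressively shrink the sequence dimension; channel sizes follow
+        # in_ch = [embedding_dim] + num_filters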
self.conv1 = nn.Conv1d(in_channels=in_ch[0], out_channels=in_ch[1], kernel_size=kernels[0]) + self.bn1 = nn.BatchNorm1d(in_ch[1]) + self.conv2 = nn.Conv1d(in_channels=in_ch[1], out_channels=in_ch[2], kernel_size=kernels[1]) + self.bn2 = nn.BatchNorm1d(in_ch[2]) + self.conv3 = nn.Conv1d(in_channels=in_ch[2], out_channels=in_ch[3], kernel_size=kernels[2]) + self.bn3 = nn.BatchNorm1d(in_ch[3]) + + def forward(self, v): + v = self.embedding(v.long()) + v = v.transpose(2, 1) + v = self.bn1(F.relu(self.conv1(v))) + v = self.bn2(F.relu(self.conv2(v))) + v = self.bn3(F.relu(self.conv3(v))) + v = v.view(v.size(0), v.size(2), -1) + return v + + +class MLPDecoder(nn.Module): + def __init__(self, in_dim, hidden_dim, out_dim): + super().__init__() + self.fc1 = nn.Linear(in_dim, hidden_dim) + self.bn1 = nn.BatchNorm1d(hidden_dim) + self.fc2 = nn.Linear(hidden_dim, hidden_dim) + self.bn2 = nn.BatchNorm1d(hidden_dim) + self.fc3 = nn.Linear(hidden_dim, out_dim) + self.bn3 = nn.BatchNorm1d(out_dim) + # self.fc4 = nn.Linear(out_dim, binary) + + def forward(self, x): + x = self.bn1(F.relu(self.fc1(x))) + x = self.bn2(F.relu(self.fc2(x))) + x = self.bn3(F.relu(self.fc3(x))) + # x = self.fc4(x) + return x + + +# noinspection PyTypeChecker +class SimpleClassifier(nn.Module): + def __init__(self, in_dim, hid_dim, out_dim, dropout): + super().__init__() + layers = [ + weight_norm(nn.Linear(in_dim, hid_dim), dim=None), + nn.ReLU(), + nn.Dropout(dropout, inplace=True), + weight_norm(nn.Linear(hid_dim, out_dim), dim=None) + ] + self.main = nn.Sequential(*layers) + + def forward(self, x): + logits = self.main(x) + return logits + + +class RandomLayer(nn.Module): + def __init__(self, input_dim_list, output_dim=256): + super().__init__() + self.input_num = len(input_dim_list) + self.output_dim = output_dim + self.random_matrix = [torch.randn(input_dim_list[i], output_dim) for i in range(self.input_num)] + + def forward(self, input_list): + return_list = [torch.mm(input_list[i], self.random_matrix[i]) for i in range(self.input_num)] + return_tensor = return_list[0] / math.pow(float(self.output_dim), 1.0 / len(return_list)) + for single in return_list[1:]: + return_tensor = torch.mul(return_tensor, single) + return return_tensor + + def cuda(self, *args): + super(RandomLayer, self).cuda(*args) + self.random_matrix = [val.cuda(*args) for val in self.random_matrix] + + +# noinspection PyTypeChecker +class BANLayer(nn.Module): + def __init__(self, v_dim, q_dim, h_dim, h_out, act='ReLU', dropout=0.2, k=3): + super().__init__() + + self.c = 32 + self.k = k + self.v_dim = v_dim + self.q_dim = q_dim + self.h_dim = h_dim + self.h_out = h_out + + self.v_net = FCNet([v_dim, h_dim * self.k], act=act, dropout=dropout) + self.q_net = FCNet([q_dim, h_dim * self.k], act=act, dropout=dropout) + # self.dropout = nn.Dropout(dropout[1]) + if 1 < k: + self.p_net = nn.AvgPool1d(self.k, stride=self.k) + + if h_out <= self.c: + self.h_mat = nn.Parameter(torch.Tensor(1, h_out, 1, h_dim * self.k).normal_()) + self.h_bias = nn.Parameter(torch.Tensor(1, h_out, 1, 1).normal_()) + else: + self.h_net = weight_norm(nn.Linear(h_dim * self.k, h_out), dim=None) + + self.bn = nn.BatchNorm1d(h_dim) + + def attention_pooling(self, v, q, att_map): + fusion_logits = torch.einsum('bvk,bvq,bqk->bk', (v, att_map, q)) + if 1 < self.k: + fusion_logits = fusion_logits.unsqueeze(1) # b x 1 x d + fusion_logits = self.p_net(fusion_logits).squeeze(1) * self.k # sum-pooling + return fusion_logits + + def forward(self, v, q, softmax=False): + v_num = v.size(1) + q_num = 
q.size(1) + if self.h_out <= self.c: + v_ = self.v_net(v) + q_ = self.q_net(q) + att_maps = torch.einsum('xhyk,bvk,bqk->bhvq', (self.h_mat, v_, q_)) + self.h_bias + else: + v_ = self.v_net(v).transpose(1, 2).unsqueeze(3) + q_ = self.q_net(q).transpose(1, 2).unsqueeze(2) + d_ = torch.matmul(v_, q_) # b x h_dim x v x q + att_maps = self.h_net(d_.transpose(1, 2).transpose(2, 3)) # b x v x q x h_out + att_maps = att_maps.transpose(2, 3).transpose(1, 2) # b x h_out x v x q + if softmax: + p = nn.functional.softmax(att_maps.view(-1, self.h_out, v_num * q_num), 2) + att_maps = p.view(-1, self.h_out, v_num, q_num) + logits = self.attention_pooling(v_, q_, att_maps[:, 0, :, :]) + for i in range(1, self.h_out): + logits_i = self.attention_pooling(v_, q_, att_maps[:, i, :, :]) + logits += logits_i + logits = self.bn(logits) + return logits, att_maps + + +# noinspection PyTypeChecker +class FCNet(nn.Module): + """Simple class for non-linear fully connect network + Modified from https://github.com/jnhwkim/ban-vqa/blob/master/fc.py + """ + + def __init__(self, dims, act='ReLU', dropout=0.0): + super().__init__() + + layers = [] + for i in range(len(dims) - 2): + in_dim = dims[i] + out_dim = dims[i + 1] + if 0 < dropout: + layers.append(nn.Dropout(dropout)) + layers.append(weight_norm(nn.Linear(in_dim, out_dim), dim=None)) + if '' != act: + layers.append(getattr(nn, act)()) + if 0 < dropout: + layers.append(nn.Dropout(dropout)) + layers.append(weight_norm(nn.Linear(dims[-2], dims[-1]), dim=None)) + if '' != act: + layers.append(getattr(nn, act)()) + + self.main = nn.Sequential(*layers) + + def forward(self, x): + return self.main(x) + + +class BCNet(nn.Module): + """Simple class for non-linear bilinear connect network + Modified from https://github.com/jnhwkim/ban-vqa/blob/master/bc.py + """ + + # noinspection PyTypeChecker + def __init__(self, v_dim, q_dim, h_dim, h_out, act='ReLU', dropout=(0.2, 0.5), k=3): + super().__init__() + + self.c = 32 + self.k = k + self.v_dim = v_dim + self.q_dim = q_dim + self.h_dim = h_dim + self.h_out = h_out + + self.v_net = FCNet([v_dim, h_dim * self.k], act=act, dropout=dropout[0]) + self.q_net = FCNet([q_dim, h_dim * self.k], act=act, dropout=dropout[0]) + self.dropout = nn.Dropout(dropout[1]) # attention + if 1 < k: + self.p_net = nn.AvgPool1d(self.k, stride=self.k) + + if h_out is None: + pass + elif h_out <= self.c: + self.h_mat = nn.Parameter(torch.Tensor(1, h_out, 1, h_dim * self.k).normal_()) + self.h_bias = nn.Parameter(torch.Tensor(1, h_out, 1, 1).normal_()) + else: + self.h_net = weight_norm(nn.Linear(h_dim * self.k, h_out), dim=None) + + def forward(self, v, q): + if self.h_out is None: + v_ = self.v_net(v) + q_ = self.q_net(q) + logits = torch.einsum('bvk,bqk->bvqk', (v_, q_)) + return logits + + # low-rank bilinear pooling using einsum + elif self.h_out <= self.c: + v_ = self.dropout(self.v_net(v)) + q_ = self.q_net(q) + logits = torch.einsum('xhyk,bvk,bqk->bhvq', (self.h_mat, v_, q_)) + self.h_bias + return logits # b x h_out x v x q + + # batch outer product, linear projection + # memory efficient but slow computation + else: + v_ = self.dropout(self.v_net(v)).transpose(1, 2).unsqueeze(3) + q_ = self.q_net(q).transpose(1, 2).unsqueeze(2) + d_ = torch.matmul(v_, q_) # b x h_dim x v x q + logits = self.h_net(d_.transpose(1, 2).transpose(2, 3)) # b x v x q x h_out + return logits.transpose(2, 3).transpose(1, 2) # b x h_out x v x q + + def forward_with_weights(self, v, q, w): + v_ = self.v_net(v) # b x v x d + q_ = self.q_net(q) # b x q x d + logits = 
torch.einsum('bvk,bvq,bqk->bk', (v_, w, q_)) + if 1 < self.k: + logits = logits.unsqueeze(1) # b x 1 x d + logits = self.p_net(logits).squeeze(1) * self.k # sum-pooling + return logits + + +def drug_featurizer(smiles, max_drug_nodes=290): + from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer, CanonicalBondFeaturizer + + from deepscreen.utils import get_logger + log = get_logger(__name__) + + try: + v_d = smiles_to_bigraph(smiles=smiles, + node_featurizer=CanonicalAtomFeaturizer(), + edge_featurizer=CanonicalBondFeaturizer(self_loop=True), + add_self_loop=True) + if v_d is None: + return None + actual_node_feats = v_d.ndata.pop('h') + num_actual_nodes = actual_node_feats.shape[0] + num_virtual_nodes = max_drug_nodes - num_actual_nodes + virtual_node_bit = torch.zeros([num_actual_nodes, 1]) + actual_node_feats = torch.cat((actual_node_feats, virtual_node_bit), 1) + v_d.ndata['h'] = actual_node_feats + virtual_node_feat = torch.cat( + (torch.zeros(num_virtual_nodes, 74), torch.ones(num_virtual_nodes, 1)), 1 + ) + v_d.add_nodes(num_virtual_nodes, {"h": virtual_node_feat}) + v_d = v_d.add_self_loop() + return v_d + + except Exception as e: + log.warning(f"Failed to featurize SMILES ({smiles}) to graph due to {str(e)}") + return None + diff --git a/deepscreen/models/predictors/drug_vqa.py b/deepscreen/models/predictors/drug_vqa.py new file mode 100644 index 0000000000000000000000000000000000000000..3d4781ff28038000cbe376eb43a40cdee3808ede --- /dev/null +++ b/deepscreen/models/predictors/drug_vqa.py @@ -0,0 +1,232 @@ +from math import floor +import re +from typing import Literal + +import numpy as np +import torch.nn as nn +import torch +import torch.nn.functional as F + + +def conv(in_channels, out_channels, kernel_size, conv_dim, stride=1): + conv_layer = None + match conv_dim: + case 1: + conv_layer = nn.Conv1d + case 2: + conv_layer = nn.Conv2d + case 3: + conv_layer = nn.Conv3d + return conv_layer(in_channels, out_channels, + kernel_size=kernel_size, stride=stride, padding=floor(kernel_size / 2), bias=False) + + +def batch_norm(out_channels, conv_dim): + bn_layer = None + match conv_dim: + case 1: + bn_layer = nn.BatchNorm1d + case 2: + bn_layer = nn.BatchNorm2d + case 3: + bn_layer = nn.BatchNorm3d + return bn_layer(out_channels) + + +def conv3x3(in_channels, out_channels, stride=1): + return nn.Conv2d(in_channels, out_channels, kernel_size=3, + stride=stride, padding=1, bias=False) + + +def conv5x5(in_channels, out_channels, stride=1): + return nn.Conv2d(in_channels, out_channels, kernel_size=5, + stride=stride, padding=2, bias=False) + + +def conv1x1(in_channels, out_channels, stride=1): + return nn.Conv2d(in_channels, out_channels, kernel_size=1, + stride=stride, padding=0, bias=False) + + +# Residual block +class ResidualBlock(nn.Module): + def __init__(self, in_channels, out_channels, conv_dim, stride=1, downsample=None): + super().__init__() + # self.conv1 = conv5x5(in_channels, out_channels, stride) + self.conv1 = conv(in_channels, out_channels, kernel_size=5, conv_dim=conv_dim, stride=stride) + self.bn1 = batch_norm(out_channels, conv_dim=conv_dim) + self.elu = nn.ELU(inplace=True) + # self.conv2 = conv3x3(out_channels, out_channels) + self.conv2 = conv(out_channels, out_channels, kernel_size=3, conv_dim=conv_dim, stride=stride) + self.bn2 = batch_norm(out_channels, conv_dim=conv_dim) + self.downsample = downsample + + def forward(self, x): + residual = x + out = self.conv1(x) + out = self.bn1(out) + out = self.elu(out) + out = self.conv2(out) + out = self.bn2(out) 
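+        # The downsample branch (set by the caller when stride or channel count changes)
+        # projects the input so the residual addition below is shape-compatible.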
+ if self.downsample: + residual = self.downsample(x) + out += residual + out = self.elu(out) + return out + + +class DrugVQA(nn.Module): + """ + The class is an implementation of the DrugVQA model including regularization and without pruning. + Slight modifications have been done for speedup + """ + + def __init__( + self, + conv_dim: Literal[1, 2, 3], + lstm_hid_dim: int, + d_a: int, + r: int, + n_chars_smi: int, + n_chars_seq: int, + dropout: float, + in_channels: int, + cnn_channels: int, + cnn_layers: int, + emb_dim: int, + dense_hid: int, + ): + """ + lstm_hid_dim: {int} hidden dimension for lstm + d_a : {int} hidden dimension for the dense layer + r : {int} attention-hops or attention heads + n_chars_smi : {int} voc size of smiles + n_chars_seq : {int} voc size of protein sequence + dropout : {float} + in_channels : {int} channels of CNN block input + cnn_channels: {int} channels of CNN block + cnn_layers : {int} num of layers of each CNN block + emb_dim : {int} embeddings dimension + dense_hid : {int} hidden dim for the output dense + """ + super().__init__() + self.conv_dim = conv_dim + self.lstm_hid_dim = lstm_hid_dim + self.r = r + self.in_channels = in_channels + # rnn + self.embeddings = nn.Embedding(n_chars_smi, emb_dim) + # self.seq_embed = nn.Embedding(n_chars_seq, emb_dim) + self.lstm = nn.LSTM(emb_dim, self.lstm_hid_dim, 2, batch_first=True, bidirectional=True, + dropout=dropout) + self.linear_first = nn.Linear(2 * self.lstm_hid_dim, d_a) + self.linear_second = nn.Linear(d_a, r) + self.linear_first_seq = nn.Linear(cnn_channels, d_a) + self.linear_second_seq = nn.Linear(d_a, self.r) + + # cnn + # self.conv = conv3x3(1, self.in_channels) + self.conv = conv(1, self.in_channels, kernel_size=3, conv_dim=conv_dim) + self.bn = batch_norm(in_channels, conv_dim=conv_dim) + self.elu = nn.ELU(inplace=False) + self.layer1 = self.make_layer(cnn_channels, cnn_layers) + self.layer2 = self.make_layer(cnn_channels, cnn_layers) + + self.linear_final_step = nn.Linear(self.lstm_hid_dim * 2 + d_a, dense_hid) + # self.linear_final = nn.Linear(dense_hid, n_classes) + self.softmax = nn.Softmax(dim=1) + + # @staticmethod + # def softmax(input, axis=1): + # """ + # Softmax applied to axis=n + # Args: + # input: {Tensor,Variable} input on which softmax is to be applied + # axis : {int} axis on which softmax is to be applied + # + # Returns: + # softmaxed tensors + # """ + # input_size = input.size() + # trans_input = input.transpose(axis, len(input_size) - 1) + # trans_size = trans_input.size() + # input_2d = trans_input.contiguous().view(-1, trans_size[-1]) + # soft_max_2d = F.softmax(input_2d) + # soft_max_nd = soft_max_2d.view(*trans_size) + # return soft_max_nd.transpose(axis, len(input_size) - 1) + + def make_layer(self, out_channels, blocks, stride=1): + downsample = None + if (stride != 1) or (self.in_channels != out_channels): + downsample = nn.Sequential( + # conv3x3(self.in_channels, out_channels, stride=stride), + conv(self.in_channels, out_channels, kernel_size=3, conv_dim=self.conv_dim, stride=stride), + batch_norm(out_channels, conv_dim=self.conv_dim) + ) + layers = [ResidualBlock(self.in_channels, out_channels, + conv_dim=self.conv_dim, stride=stride, downsample=downsample)] + self.in_channels = out_channels + for i in range(1, blocks): + layers.append(ResidualBlock(out_channels, out_channels, conv_dim=self.conv_dim)) + return nn.Sequential(*layers) + + def forward(self, enc_drug, enc_protein): + enc_drug, _ = enc_drug + enc_protein, _ = enc_protein + smile_embed = 
self.embeddings(enc_drug.long()) + # self.hidden_state = tuple(hidden_state.to(smile_embed).detach() for hidden_state in self.hidden_state) + outputs, hidden_state = self.lstm(smile_embed) + sentence_att = F.tanh(self.linear_first(outputs)) + sentence_att = self.linear_second(sentence_att) + sentence_att = self.softmax(sentence_att) + sentence_att = sentence_att.transpose(1, 2) + sentence_embed = sentence_att @ outputs + avg_sentence_embed = torch.sum(sentence_embed, 1) / self.r # multi head + + pic = self.conv(enc_protein.float().unsqueeze(1)) + pic = self.bn(pic) + pic = self.elu(pic) + pic = self.layer1(pic) + pic = self.layer2(pic) + pic_emb = torch.mean(pic, 2).unsqueeze(2) + pic_emb = pic_emb.permute(0, 2, 1) + seq_att = F.tanh(self.linear_first_seq(pic_emb)) + seq_att = self.linear_second_seq(seq_att) + seq_att = self.softmax(seq_att) + seq_att = seq_att.transpose(1, 2) + seq_embed = seq_att @ pic_emb + avg_seq_embed = torch.sum(seq_embed, 1) / self.r + + sscomplex = torch.cat([avg_sentence_embed, avg_seq_embed], dim=1) + sscomplex = F.relu(self.linear_final_step(sscomplex)) + + # if not bool(self.type): + # output = F.sigmoid(self.linear_final(sscomplex)) + # return output, seq_att + # else: + # return F.log_softmax(self.linear_final(sscomplex)), seq_att + + return sscomplex, seq_att + + +class AttentionL2Regularization(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, seq_att): + batch_size = seq_att.size(0) + identity = torch.eye(seq_att.size(1), device=seq_att.device) + identity = identity.unsqueeze(0).expand(batch_size, seq_att.size(1), seq_att.size(1)) + loss = torch.mean(self.l2_matrix_norm(seq_att @ seq_att.transpose(1, 2) - identity)) + return loss + + @staticmethod + def l2_matrix_norm(m): + """ + m = ||A * A_T - I|| + Missing from the original DrugVQA GitHub source code. + Opting to use the faster Frobenius norm rather than the induced L2 matrix norm (spectral norm) + proposed in the original research, because the goal is to minimize the difference between + the attention matrix and the identity matrix. + """ + return torch.linalg.norm(m, ord='fro', dim=(1, 2)) diff --git a/deepscreen/models/predictors/graph_dta.py b/deepscreen/models/predictors/graph_dta.py new file mode 100644 index 0000000000000000000000000000000000000000..ebfd3887456167e8789c62310867eb32fb2083a3 --- /dev/null +++ b/deepscreen/models/predictors/graph_dta.py @@ -0,0 +1,59 @@ +import torch +import torch.nn as nn + +from lightning import LightningModule + + +class GraphDTA(LightningModule): + """ + From GraphDTA (Nguyen et al., 2020; https://doi.org/10.1093/bioinformatics/btaa921). 
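+
+    Note: only the encoders and the fused dense layers live here; fc2's 512-dim output is
+    meant to be consumed by a separate output head elsewhere in the pipeline (note the
+    commented-out `self.out` call), and the drug GNN is assumed to emit features sized so
+    that, concatenated with the `output_dim`-sized protein features, they match fc1's
+    256 inputs.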
+ """ + def __init__( + self, + gnn: nn.Module, + num_features_protein: int, + n_filters: int, + embed_dim: int, + output_dim: int, + dropout: float + ): + super().__init__() + self.gnn = gnn + + # protein sequence encoder (1d conv) + self.embedding_xt = nn.Embedding(num_features_protein, embed_dim) + self.conv_xt = nn.Conv1d(in_channels=1000, out_channels=n_filters, kernel_size=8) + self.fc1_xt = nn.Linear(32 * 121, output_dim) + + # combined layers + self.fc1 = nn.Linear(256, 1024) + self.fc2 = nn.Linear(1024, 512) + + # activation and regularization + self.relu = nn.ReLU() + self.dropout = nn.Dropout(dropout) + + # protein input feedforward + def conv_forward_xt(self, v_p): + v_p = self.embedding_xt(v_p.long()) + v_p = self.conv_xt(v_p) + # flatten + v_p = v_p.view(-1, 32 * 121) + v_p = self.fc1_xt(v_p) + return v_p + + def forward(self, v_d, v_p): + v_d = self.gnn(v_d) + v_p = self.conv_forward_xt(v_p) + + # concat + v_f = torch.cat((v_d, v_p), 1) + # dense layers + v_f = self.fc1(v_f) + v_f = self.relu(v_f) + v_f = self.dropout(v_f) + v_f = self.fc2(v_f) + v_f = self.relu(v_f) + v_f = self.dropout(v_f) + # v_f = self.out(v_f) + return v_f diff --git a/deepscreen/models/predictors/hyper_attention_dti.py b/deepscreen/models/predictors/hyper_attention_dti.py new file mode 100644 index 0000000000000000000000000000000000000000..6b16665cfab263a8273c303166e020b92460bc1c --- /dev/null +++ b/deepscreen/models/predictors/hyper_attention_dti.py @@ -0,0 +1,94 @@ +import torch +import torch.nn as nn + +# TODO this is an easy model; refactor it to be customized by config file only + + +class HyperAttentionDTI(nn.Module): + def __init__( + self, + protein_kernel=(4, 8, 12), + drug_kernel=(4, 6, 8), + conv=40, + char_dim=64, + protein_max_len=1000, + drug_max_len=100 + ): + super().__init__() + + self.drug_embed = nn.Embedding(63, char_dim, padding_idx=0) + self.drug_cnn = nn.Sequential( + nn.Conv1d(in_channels=char_dim, out_channels=conv, kernel_size=drug_kernel[0]), + nn.ReLU(), + nn.Conv1d(in_channels=conv, out_channels=conv * 2, kernel_size=drug_kernel[1]), + nn.ReLU(), + nn.Conv1d(in_channels=conv * 2, out_channels=conv * 4, kernel_size=drug_kernel[2]), + nn.ReLU(), + ) + self.drug_max_pool = nn.MaxPool1d( + drug_max_len - drug_kernel[0] - drug_kernel[1] - drug_kernel[2] + 3) + + self.protein_embed = nn.Embedding(26, char_dim, padding_idx=0) + self.protein_cnn = nn.Sequential( + nn.Conv1d(in_channels=char_dim, out_channels=conv, kernel_size=protein_kernel[0]), + nn.ReLU(), + nn.Conv1d(in_channels=conv, out_channels=conv * 2, kernel_size=protein_kernel[1]), + nn.ReLU(), + nn.Conv1d(in_channels=conv * 2, out_channels=conv * 4, kernel_size=protein_kernel[2]), + nn.ReLU(), + ) + self.protein_max_pool = nn.MaxPool1d( + protein_max_len - protein_kernel[0] - protein_kernel[1] - protein_kernel[2] + 3) + + self.attention_layer = nn.Linear(conv * 4, conv * 4) + self.protein_attention_layer = nn.Linear(conv * 4, conv * 4) + self.drug_attention_layer = nn.Linear(conv * 4, conv * 4) + + self.dropout1 = nn.Dropout(0.1) + self.dropout2 = nn.Dropout(0.1) + self.dropout3 = nn.Dropout(0.1) + self.relu = nn.ReLU() + self.tanh = nn.Tanh() + self.sigmoid = nn.Sigmoid() + self.leaky_relu = nn.LeakyReLU() + self.fc1 = nn.Linear(conv * 8, 1024) + self.fc2 = nn.Linear(1024, 1024) + self.fc3 = nn.Linear(1024, 512) + # self.out = nn.Linear(512, 1) + + def forward(self, drug, protein): + drugembed = self.drug_embed(drug.long()) + proteinembed = self.protein_embed(protein.long()) + drugembed = drugembed.permute(0, 2, 1) + 
proteinembed = proteinembed.permute(0, 2, 1) + + drug_conv = self.drug_cnn(drugembed) + protein_conv = self.protein_cnn(proteinembed) + + drug_att = self.drug_attention_layer(drug_conv.permute(0, 2, 1)) + protein_att = self.protein_attention_layer(protein_conv.permute(0, 2, 1)) + + d_att_layers = torch.unsqueeze(drug_att, 2).repeat(1, 1, protein_conv.shape[-1], 1) # repeat along protein size + p_att_layers = torch.unsqueeze(protein_att, 1).repeat(1, drug_conv.shape[-1], 1, 1) # repeat along drug size + atten_matrix = self.attention_layer(self.relu(d_att_layers + p_att_layers)) + compound_atte = torch.mean(atten_matrix, 2) + protein_atte = torch.mean(atten_matrix, 1) + compound_atte = self.sigmoid(compound_atte.permute(0, 2, 1)) + protein_atte = self.sigmoid(protein_atte.permute(0, 2, 1)) + + drug_conv = drug_conv * 0.5 + drug_conv * compound_atte + protein_conv = protein_conv * 0.5 + protein_conv * protein_atte + + drug_conv = self.drug_max_pool(drug_conv).squeeze(2) + protein_conv = self.protein_max_pool(protein_conv).squeeze(2) + + preds = torch.cat([drug_conv, protein_conv], dim=1) + preds = self.dropout1(preds) + preds = self.leaky_relu(self.fc1(preds)) + preds = self.dropout2(preds) + preds = self.leaky_relu(self.fc2(preds)) + preds = self.dropout3(preds) + preds = self.leaky_relu(self.fc3(preds)) + # preds = self.out(preds) + + return preds diff --git a/deepscreen/models/predictors/m_graph_dta.py b/deepscreen/models/predictors/m_graph_dta.py new file mode 100644 index 0000000000000000000000000000000000000000..ebe849647f7536748e8c8352d29ea20931d7ea1c --- /dev/null +++ b/deepscreen/models/predictors/m_graph_dta.py @@ -0,0 +1,266 @@ +""" +MGraphDTA: Deep Multiscale Graph Neural Network for Explainable Drug-target binding affinity Prediction +""" +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from rdkit import Chem +from torch.nn.modules.batchnorm import _BatchNorm +import torch_geometric.nn as gnn +from torch import Tensor +from collections import OrderedDict + +from deepscreen.data.featurizers.categorical import one_of_k_encoding, one_of_k_encoding_unk + + +class MGraphDTA(nn.Module): + def __init__(self, block_num, vocab_protein_size, embedding_size=128, filter_num=32): + super().__init__() + self.protein_encoder = TargetRepresentation(block_num, vocab_protein_size, embedding_size) + self.ligand_encoder = GraphDenseNet(num_input_features=87, + out_dim=filter_num * 3, + block_config=[8, 8, 8], + bn_sizes=[2, 2, 2]) + + self.classifier = nn.Sequential( + nn.Linear(filter_num * 3 * 2, 1024), + nn.ReLU(), + nn.Dropout(0.1), + nn.Linear(1024, 1024), + nn.ReLU(), + nn.Dropout(0.1), + nn.Linear(1024, 256), + nn.ReLU(), + nn.Dropout(0.1) + ) + + def forward(self, emb_drug, emb_protein): + protein_x = self.protein_encoder(emb_protein) + ligand_x = self.ligand_encoder(emb_drug) + + x = torch.cat([protein_x, ligand_x], dim=-1) + x = self.classifier(x) + + return x + + +class Conv1dReLU(nn.Module): + """ + kernel_size=3, stride=1, padding=1 + kernel_size=5, stride=1, padding=2 + kernel_size=7, stride=1, padding=3 + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0): + super().__init__() + self.inc = nn.Sequential( + nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, + padding=padding), + nn.ReLU() + ) + + def forward(self, x): + return self.inc(x) + + +class LinearReLU(nn.Module): + def __init__(self, in_features, out_features, bias=True): + super().__init__() + 
self.inc = nn.Sequential( + nn.Linear(in_features=in_features, out_features=out_features, bias=bias), + nn.ReLU() + ) + + def forward(self, x): + return self.inc(x) + + +class StackCNN(nn.Module): + def __init__(self, layer_num, in_channels, out_channels, kernel_size, stride=1, padding=0): + super().__init__() + + self.inc = nn.Sequential(OrderedDict([('conv_layer0', + Conv1dReLU(in_channels, out_channels, kernel_size=kernel_size, + stride=stride, padding=padding))])) + for layer_idx in range(layer_num - 1): + self.inc.add_module('conv_layer%d' % (layer_idx + 1), + Conv1dReLU(out_channels, out_channels, kernel_size=kernel_size, stride=stride, + padding=padding)) + + self.inc.add_module('pool_layer', nn.AdaptiveMaxPool1d(1)) + + def forward(self, x): + return self.inc(x).squeeze(-1) + + +class TargetRepresentation(nn.Module): + def __init__(self, block_num, vocab_size, embedding_num): + super().__init__() + self.embed = nn.Embedding(vocab_size, embedding_num, padding_idx=0) + self.block_list = nn.ModuleList() + for block_idx in range(block_num): + self.block_list.append( + StackCNN(block_idx + 1, embedding_num, 96, 3) + ) + + self.linear = nn.Linear(block_num * 96, 96) + + def forward(self, x): + x = self.embed(x).permute(0, 2, 1) + feats = [block(x) for block in self.block_list] + x = torch.cat(feats, -1) + x = self.linear(x) + + return x + + +class NodeLevelBatchNorm(_BatchNorm): + r""" + Applies Batch Normalization over a batch of graph data. + Shape: + - Input: [batch_nodes_dim, node_feature_dim] + - Output: [batch_nodes_dim, node_feature_dim] + batch_nodes_dim: all nodes of a batch graph + """ + + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, + track_running_stats=True): + super(NodeLevelBatchNorm, self).__init__( + num_features, eps, momentum, affine, track_running_stats) + + def _check_input_dim(self, input): + if input.dim() != 2: + raise ValueError('expected 2D input (got {}D input)' + .format(input.dim())) + + def forward(self, input): + self._check_input_dim(input) + if self.momentum is None: + exponential_average_factor = 0.0 + else: + exponential_average_factor = self.momentum + if self.training and self.track_running_stats: + if self.num_batches_tracked is not None: + self.num_batches_tracked = self.num_batches_tracked + 1 + if self.momentum is None: + exponential_average_factor = 1.0 / float(self.num_batches_tracked) + else: + exponential_average_factor = self.momentum + + return torch.functional.F.batch_norm( + input, self.running_mean, self.running_var, self.weight, self.bias, + self.training or not self.track_running_stats, + exponential_average_factor, self.eps) + + def extra_repr(self): + return 'num_features={num_features}, eps={eps}, ' \ + 'affine={affine}'.format(**self.__dict__) + + +class GraphConvBn(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv = gnn.GraphConv(in_channels, out_channels) + self.norm = NodeLevelBatchNorm(out_channels) + + def forward(self, data): + x, edge_index, batch = data.x, data.edge_index, data.batch + data.x = F.relu(self.norm(self.conv(x, edge_index))) + + return data + + +class DenseLayer(nn.Module): + def __init__(self, num_input_features, growth_rate=32, bn_size=4): + super().__init__() + self.conv1 = GraphConvBn(num_input_features, int(growth_rate * bn_size)) + self.conv2 = GraphConvBn(int(growth_rate * bn_size), growth_rate) + + def bn_function(self, data): + concated_features = torch.cat(data.x, 1) + data.x = concated_features + + data = self.conv1(data) + + return 
+        data
+
+    def forward(self, data):
+        if isinstance(data.x, Tensor):
+            data.x = [data.x]
+
+        data = self.bn_function(data)
+        data = self.conv2(data)
+
+        return data
+
+
+class DenseBlock(nn.ModuleDict):
+    def __init__(self, num_layers, num_input_features, growth_rate=32, bn_size=4):
+        super().__init__()
+        for i in range(num_layers):
+            layer = DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size)
+            self.add_module('layer%d' % (i + 1), layer)
+
+    def forward(self, data):
+        features = [data.x]
+        for name, layer in self.items():
+            data = layer(data)
+            features.append(data.x)
+            data.x = features
+
+        data.x = torch.cat(data.x, 1)
+
+        return data
+
+
+class GraphDenseNet(nn.Module):
+    def __init__(self, num_input_features, out_dim, growth_rate=32, block_config=(3, 3, 3, 3), bn_sizes=(2, 3, 4, 4)):
+        super().__init__()
+        self.features = nn.Sequential(OrderedDict([('conv0', GraphConvBn(num_input_features, 32))]))
+        num_input_features = 32
+
+        for i, num_layers in enumerate(block_config):
+            block = DenseBlock(
+                num_layers, num_input_features, growth_rate=growth_rate, bn_size=bn_sizes[i]
+            )
+            self.features.add_module('block%d' % (i + 1), block)
+            num_input_features += int(num_layers * growth_rate)
+
+            trans = GraphConvBn(num_input_features, num_input_features // 2)
+            self.features.add_module("transition%d" % (i + 1), trans)
+            num_input_features = num_input_features // 2
+
+        self.classifier = nn.Linear(num_input_features, out_dim)
+
+    def forward(self, data):
+        data = self.features(data)
+        x = gnn.global_mean_pool(data.x, data.batch)
+        x = self.classifier(x)
+
+        return x
+
+
+def atom_features(atom):
+    encoding = one_of_k_encoding_unk(atom.GetSymbol(),
+                                     ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'As',
+                                      'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti',
+                                      'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg',
+                                      'Pb', 'Unknown'])
+    encoding += one_of_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + one_of_k_encoding_unk(
+        atom.GetTotalNumHs(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
+    encoding += one_of_k_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
+    encoding += one_of_k_encoding_unk(atom.GetHybridization(), [
+        Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
+        Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D,
+        Chem.rdchem.HybridizationType.SP3D2, 'other'])
+    encoding += [atom.GetIsAromatic()]
+
+    try:
+        encoding += one_of_k_encoding_unk(
+            atom.GetProp('_CIPCode'),
+            ['R', 'S']) + [atom.HasProp('_ChiralityPossible')]
+    except Exception:
+        # no CIP code assigned; pad the stereo one-hot with zeros
+        encoding += [0, 0] + [atom.HasProp('_ChiralityPossible')]
+
+    return np.array(encoding)
diff --git a/deepscreen/models/predictors/mol_trans.py b/deepscreen/models/predictors/mol_trans.py
new file mode 100644
index 0000000000000000000000000000000000000000..224b24a7234a469513fb798173adc048ace6ad05
--- /dev/null
+++ b/deepscreen/models/predictors/mol_trans.py
@@ -0,0 +1,301 @@
+import math
+import copy
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+class MolTrans(nn.Module):
+    """
+    Interaction Network with 2D interaction map
+    """
+    def __init__(
+            self,
+            # parameters without defaults come first
+            max_drug_seq: int,
+            max_protein_seq: int,
+            input_dim_drug: int = 23532,
+            input_dim_target: int = 16693,
+            emb_size: int = 384,
+            dropout_rate: float = 0.1,
+            # DenseNet
+            scale_down_ratio: float = 0.25,
+            growth_rate: int = 20,
+            transition_rate: float = 0.5,
+            num_dense_blocks: int = 4,
+            kernal_dense_size: int = 3,
+            # Encoder
+            intermediate_size: int = 1536,
+            num_attention_heads: int = 12,
+            attention_probs_dropout_prob: float = 0.1,
+            hidden_dropout_prob: float = 0.1,
+            # flatten_dim: 78192,
+            # batch_size
+    ):
+        super().__init__()
+        self.max_d = max_drug_seq
+        self.max_p = max_protein_seq
+        self.emb_size = emb_size
+        self.dropout_rate = dropout_rate
+
+        # densenet
+        self.scale_down_ratio = scale_down_ratio
+        self.growth_rate = growth_rate
+        self.transition_rate = transition_rate
+        self.num_dense_blocks = num_dense_blocks
+        self.kernal_dense_size = kernal_dense_size
+        # self.batch_size = batch_size
+        self.input_dim_drug = input_dim_drug
+        self.input_dim_target = input_dim_target
+        self.n_layer = 2
+
+        # encoder
+        self.hidden_size = emb_size
+        self.intermediate_size = intermediate_size
+        self.num_attention_heads = num_attention_heads
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.hidden_dropout_prob = hidden_dropout_prob
+
+        # self.flatten_dim = flatten_dim
+
+        # specialized embedding with positional one
+        self.demb = Embeddings(self.input_dim_drug, self.emb_size, self.max_d, self.dropout_rate)
+        self.pemb = Embeddings(self.input_dim_target, self.emb_size, self.max_p, self.dropout_rate)
+
+        self.d_encoder = EncoderMultipleLayers(self.n_layer, self.hidden_size, self.intermediate_size,
+                                               self.num_attention_heads, self.attention_probs_dropout_prob,
+                                               self.hidden_dropout_prob)
+        self.p_encoder = EncoderMultipleLayers(self.n_layer, self.hidden_size, self.intermediate_size,
+                                               self.num_attention_heads, self.attention_probs_dropout_prob,
+                                               self.hidden_dropout_prob)
+
+        self.icnn = nn.Conv2d(1, 3, 3, padding=0)
+
+        self.decoder = nn.Sequential(
+            # nn.Linear(self.flatten_dim, 512),
+            nn.LazyLinear(512),
+            nn.ReLU(True),
+
+            nn.BatchNorm1d(512),
+            nn.Linear(512, 64),
+            nn.ReLU(True),
+
+            nn.BatchNorm1d(64),
+            nn.Linear(64, 32),
+            nn.ReLU(True),
+
+            # # output layer
+            # nn.Linear(32, 1)
+        )
+
+    def forward(self, v_d, v_p):
+        d, d_mask = v_d
+        p, p_mask = v_p
+        ex_d_mask = d_mask.unsqueeze(1).unsqueeze(2)
+        ex_p_mask = p_mask.unsqueeze(1).unsqueeze(2)
+
+        ex_d_mask = (1.0 - ex_d_mask) * -10000.0
+        ex_p_mask = (1.0 - ex_p_mask) * -10000.0
+
+        d_emb = self.demb(d)  # batch_size x seq_length x embed_size
+        p_emb = self.pemb(p)
+
+        batch_size = d_emb.size(0)
+
+        # The encoders below return only the last layer's hidden states
+        # (intermediate layers are not collected).
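+        # Shape sketch (B = batch size): d_encoded_layers is [B, max_d, emb_size] and
+        # p_encoded_layers is [B, max_p, emb_size]; their broadcasted product forms a
+        # [B, max_d, max_p, emb_size] interaction tensor, which is collapsed below into a
+        # single-channel [B, 1, max_d, max_p] map for the interaction CNN (self.icnn).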
+ d_encoded_layers = self.d_encoder(d_emb.float(), ex_d_mask.float()) + # print(d_encoded_layers.shape) + p_encoded_layers = self.p_encoder(p_emb.float(), ex_p_mask.float()) + # print(p_encoded_layers.shape) + + # repeat to have the same tensor size for aggregation + d_aug = torch.unsqueeze(d_encoded_layers, 2).repeat(1, 1, self.max_p, 1) # repeat along protein size + p_aug = torch.unsqueeze(p_encoded_layers, 1).repeat(1, self.max_d, 1, 1) # repeat along drug size + + i = d_aug * p_aug # interaction + i_v = i.view(batch_size, -1, self.max_d, self.max_p) + # batch_size x embed size x max_drug_seq_len x max_protein_seq_len + i_v = torch.sum(i_v, dim=1) + i_v = torch.unsqueeze(i_v, 1) + + i_v = F.dropout(i_v, p=self.dropout_rate) + + i = self.icnn(i_v).view(batch_size, -1) + score = self.decoder(i) + return score + + +class LayerNorm(nn.Module): + def __init__(self, hidden_size, variance_epsilon=1e-12): + super(LayerNorm, self).__init__() + self.gamma = nn.Parameter(torch.ones(hidden_size)) + self.beta = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = variance_epsilon + + def forward(self, x): + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + return self.gamma * x + self.beta + + +class Embeddings(nn.Module): + """Construct the embeddings from protein/target, position embeddings. + """ + + def __init__(self, vocab_size, hidden_size, max_position_size, dropout_rate): + super(Embeddings, self).__init__() + self.word_embeddings = nn.Embedding(vocab_size, hidden_size) + self.position_embeddings = nn.Embedding(max_position_size, hidden_size) + + self.LayerNorm = LayerNorm(hidden_size) + self.dropout = nn.Dropout(dropout_rate) + + def forward(self, input_ids): + seq_length = input_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + + embeddings = words_embeddings + position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class SelfAttention(nn.Module): + def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob): + super(SelfAttention, self).__init__() + if hidden_size % num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (hidden_size, num_attention_heads)) + self.num_attention_heads = num_attention_heads + self.attention_head_size = int(hidden_size / num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(hidden_size, self.all_head_size) + self.key = nn.Linear(hidden_size, self.all_head_size) + self.value = nn.Linear(hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = 
self.transpose_for_scores(mixed_value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer + + +class SelfOutput(nn.Module): + def __init__(self, hidden_size, hidden_dropout_prob): + super(SelfOutput, self).__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.LayerNorm = LayerNorm(hidden_size) + self.dropout = nn.Dropout(hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class Attention(nn.Module): + def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob): + super(Attention, self).__init__() + self.self = SelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob) + self.output = SelfOutput(hidden_size, hidden_dropout_prob) + + def forward(self, input_tensor, attention_mask): + self_output = self.self(input_tensor, attention_mask) + attention_output = self.output(self_output, input_tensor) + return attention_output + + +class Intermediate(nn.Module): + def __init__(self, hidden_size, intermediate_size): + super(Intermediate, self).__init__() + self.dense = nn.Linear(hidden_size, intermediate_size) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = F.relu(hidden_states) + return hidden_states + + +class Output(nn.Module): + def __init__(self, intermediate_size, hidden_size, hidden_dropout_prob): + super(Output, self).__init__() + self.dense = nn.Linear(intermediate_size, hidden_size) + self.LayerNorm = LayerNorm(hidden_size) + self.dropout = nn.Dropout(hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class Encoder(nn.Module): + def __init__(self, hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob, + hidden_dropout_prob): + super(Encoder, self).__init__() + self.attention = Attention(hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob) + self.intermediate = Intermediate(hidden_size, intermediate_size) + self.output = Output(intermediate_size, hidden_size, hidden_dropout_prob) + + def forward(self, hidden_states, attention_mask): + attention_output = self.attention(hidden_states, attention_mask) + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class 
EncoderMultipleLayers(nn.Module): + def __init__(self, n_layer, hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob, + hidden_dropout_prob): + super().__init__() + layer = Encoder(hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob, + hidden_dropout_prob) + self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layer)]) + + def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): + all_encoder_layers = [] + for layer_module in self.layer: + hidden_states = layer_module(hidden_states, attention_mask) + # if output_all_encoded_layers: + # all_encoder_layers.append(hidden_states) + # if not output_all_encoded_layers: + # all_encoder_layers.append(hidden_states) + return hidden_states diff --git a/deepscreen/models/predictors/transformer_cpi.py b/deepscreen/models/predictors/transformer_cpi.py new file mode 100644 index 0000000000000000000000000000000000000000..e4fbeacfd643004342a1a1e6af23f5a685af12f6 --- /dev/null +++ b/deepscreen/models/predictors/transformer_cpi.py @@ -0,0 +1,268 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + + +class TransformerCPI(nn.Module): + def __init__(self, protein_dim, hidden_dim, n_layers, kernel_size, dropout, n_heads, pf_dim, atom_dim=34): + super().__init__() + + self.encoder = Encoder(protein_dim, hidden_dim, n_layers, kernel_size, dropout) + self.decoder = Decoder(atom_dim, hidden_dim, n_layers, n_heads, pf_dim, dropout) + self.weight = nn.Parameter(torch.FloatTensor(atom_dim, atom_dim)) + self.init_weight() + + def init_weight(self): + stdv = 1. / math.sqrt(self.weight.size(1)) + self.weight.data.uniform_(-stdv, stdv) + + def gcn(self, input, adj): + # input =[batch,num_node, atom_dim] + # adj = [batch,num_node, num_node] + support = torch.matmul(input, self.weight) + # support =[batch,num_node,atom_dim] + output = torch.bmm(adj.float(), support.float()) + # output = [batch,num_node,atom_dim] + return output + + def forward(self, compound, protein): + compound, adj = compound + compound, compound_lengths = compound + adj, _ = adj + protein, protein_lengths = protein + # compound = [batch,atom_num, atom_dim] + # adj = [batch,atom_num, atom_num] + # protein = [batch,protein len, 100] + compound_mask = torch.arange(compound.size(1), device=compound.device) >= compound_lengths.unsqueeze(1) + protein_mask = torch.arange(protein.size(1), device=protein.device) >= protein_lengths.unsqueeze(1) + compound_mask = compound_mask.unsqueeze(1).unsqueeze(3) + protein_mask = protein_mask.unsqueeze(1).unsqueeze(2) + + compound = self.gcn(compound.float(), adj) + # compound = torch.unsqueeze(compound, dim=0) + # compound = [batch size=1 ,atom_num, atom_dim] + + # protein = torch.unsqueeze(protein, dim=0) + # protein =[ batch size=1,protein len, protein_dim] + enc_src = self.encoder(protein) + # enc_src = [batch size, protein len, hid dim] + + out = self.decoder(compound, enc_src, compound_mask, protein_mask) + # out = [batch size, 2] + # out = torch.squeeze(out, dim=0) + return out + + +class SelfAttention(nn.Module): + def __init__(self, hidden_dim, n_heads, dropout): + super().__init__() + + self.hidden_dim = hidden_dim + self.n_heads = n_heads + + assert hidden_dim % n_heads == 0 + + self.w_q = nn.Linear(hidden_dim, hidden_dim) + self.w_k = nn.Linear(hidden_dim, hidden_dim) + self.w_v = nn.Linear(hidden_dim, hidden_dim) + + self.fc = nn.Linear(hidden_dim, hidden_dim) + + self.do = nn.Dropout(dropout) + + self.scale = (hidden_dim // 
n_heads) ** 0.5 + + def forward(self, query, key, value, mask=None): + bsz = query.shape[0] + + # query = key = value [batch size, sent len, hid dim] + + q = self.w_q(query) + k = self.w_k(key) + v = self.w_v(value) + + # q, k, v = [batch size, sent len, hid dim] + + q = q.view(bsz, -1, self.n_heads, self.hidden_dim // self.n_heads).permute(0, 2, 1, 3) + k = k.view(bsz, -1, self.n_heads, self.hidden_dim // self.n_heads).permute(0, 2, 1, 3) + v = v.view(bsz, -1, self.n_heads, self.hidden_dim // self.n_heads).permute(0, 2, 1, 3) + + # k, v = [batch size, n heads, sent len_K, hid dim // n heads] + # q = [batch size, n heads, sent len_q, hid dim // n heads] + energy = torch.matmul(q, k.permute(0, 1, 3, 2)) / self.scale + + # energy = [batch size, n heads, sent len_Q, sent len_K] + if mask is not None: + energy = energy.masked_fill(mask == 0, -1e10) + + attention = self.do(F.softmax(energy, dim=-1)) + + # attention = [batch size, n heads, sent len_Q, sent len_K] + + x = torch.matmul(attention, v) + + # x = [batch size, n heads, sent len_Q, hid dim // n heads] + + x = x.permute(0, 2, 1, 3).contiguous() + + # x = [batch size, sent len_Q, n heads, hid dim // n heads] + + x = x.view(bsz, -1, self.n_heads * (self.hidden_dim // self.n_heads)) + + # x = [batch size, src sent len_Q, hid dim] + + x = self.fc(x) + + # x = [batch size, sent len_Q, hid dim] + + return x + + +class Encoder(nn.Module): + """protein feature extraction.""" + def __init__(self, protein_dim, hidden_dim, n_layers, kernel_size, dropout): + super().__init__() + + assert kernel_size % 2 == 1, "Kernel size must be odd (for now)" + + self.input_dim = protein_dim + self.hidden_dim = hidden_dim + self.kernel_size = kernel_size + self.dropout = dropout + self.n_layers = n_layers + # self.pos_embedding = nn.Embedding(1000, hidden_dim) + self.scale = 0.5 ** 0.5 + self.convs = nn.ModuleList( + [nn.Conv1d(hidden_dim, 2 * hidden_dim, kernel_size, padding=(kernel_size - 1) // 2) for _ in + range(self.n_layers)]) # convolutional layers + self.dropout = nn.Dropout(dropout) + self.fc = nn.Linear(self.input_dim, self.hidden_dim) + self.gn = nn.GroupNorm(8, hidden_dim * 2) + self.ln = nn.LayerNorm(hidden_dim) + + def forward(self, protein): + # pos = torch.arange(0, protein.shape[1]).unsqueeze(0).repeat(protein.shape[0], 1) + # protein = protein + self.pos_embedding(pos) + # protein = [batch size, protein len,protein_dim] + conv_input = self.fc(protein.float()) + # conv_input=[batch size,protein len,hid dim] + # permute for convolutional layer + conv_input = conv_input.permute(0, 2, 1) + # conv_input = [batch size, hid dim, protein len] + for i, conv in enumerate(self.convs): + # pass through convolutional layer + conved = conv(self.dropout(conv_input)) + # conved = [batch size, 2*hid dim, protein len] + + # pass through GLU activation function + conved = F.glu(conved, dim=1) + # conved = [batch size, hid dim, protein len] + + # apply residual connection / high way + conved = (conved + conv_input) * self.scale + # conved = [batch size, hid dim, protein len] + + # set conv_input to conved for next loop iteration + conv_input = conved + + conved = conved.permute(0, 2, 1) + # conved = [batch size,protein len,hid dim] + conved = self.ln(conved) + return conved + + +class PositionwiseFeedforward(nn.Module): + def __init__(self, hidden_dim, pf_dim, dropout): + super().__init__() + + self.hidden_dim = hidden_dim + self.pf_dim = pf_dim + + self.fc_1 = nn.Conv1d(hidden_dim, pf_dim, 1) # convolution neural units + self.fc_2 = nn.Conv1d(pf_dim, hidden_dim, 1) 
# convolution neural units + + self.do = nn.Dropout(dropout) + + def forward(self, x): + # x = [batch size, sent len, hid dim] + x = x.permute(0, 2, 1) # x = [batch size, hid dim, sent len] + x = self.do(F.relu(self.fc_1(x))) # x = [batch size, pf dim, sent len] + x = self.fc_2(x) # x = [batch size, hid dim, sent len] + x = x.permute(0, 2, 1) # x = [batch size, sent len, hid dim] + + return x + + +class DecoderLayer(nn.Module): + def __init__(self, hidden_dim, n_heads, pf_dim, dropout, + self_attention=SelfAttention, + positionwise_feedforward=PositionwiseFeedforward): + super().__init__() + self.ln = nn.LayerNorm(hidden_dim) + self.sa = self_attention(hidden_dim, n_heads, dropout) + self.ea = self_attention(hidden_dim, n_heads, dropout) + self.pf = positionwise_feedforward(hidden_dim, pf_dim, dropout) + self.do = nn.Dropout(dropout) + + def forward(self, trg, src, trg_mask=None, src_mask=None): + # trg = [batch_size, compound len, atom_dim] + # src = [batch_size, protein len, hidden_dim] # encoder output + # trg_mask = [batch size, compound sent len] + # src_mask = [batch size, protein len] + trg = self.ln(trg + self.do(self.sa(trg, trg, trg, trg_mask))) + trg = self.ln(trg + self.do(self.ea(trg, src, src, src_mask))) + trg = self.ln(trg + self.do(self.pf(trg))) + + return trg + + +class Decoder(nn.Module): + """ compound feature extraction.""" + + def __init__(self, atom_dim, hidden_dim, n_layers, n_heads, pf_dim, dropout, + decoder_layer=DecoderLayer, + self_attention=SelfAttention, + positionwise_feedforward=PositionwiseFeedforward): + super().__init__() + self.ln = nn.LayerNorm(hidden_dim) + self.output_dim = atom_dim + self.hidden_dim = hidden_dim + self.n_layers = n_layers + self.n_heads = n_heads + self.pf_dim = pf_dim + self.decoder_layer = decoder_layer + self.self_attention = self_attention + self.positionwise_feedforward = positionwise_feedforward + self.dropout = dropout + self.sa = self_attention(hidden_dim, n_heads, dropout) + self.layers = nn.ModuleList( + [decoder_layer(hidden_dim, n_heads, pf_dim, dropout, self_attention, positionwise_feedforward) + for _ in range(n_layers)]) + self.ft = nn.Linear(atom_dim, hidden_dim) + self.do = nn.Dropout(dropout) + self.fc_1 = nn.Linear(hidden_dim, 256) + # self.fc_2 = nn.Linear(256, 2) + self.gn = nn.GroupNorm(8, 256) + + def forward(self, trg, src, trg_mask=None, src_mask=None): + # trg = [batch_size, compound len, atom_dim] + # src = [batch_size, protein len, hidden_dim] # encoder output + trg = self.ft(trg) # trg = [batch size, compound len, hid dim] + + for layer in self.layers: + trg = layer(trg, src, trg_mask, src_mask) # trg = [batch size, compound len, hid dim] + """Use norm to determine which atom is significant. 
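+        Each atom vector's L2 norm is softmax-normalized over the compound dimension and
+        used as the weight for sum-pooling the atom features into one compound vector.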
""" + norm = torch.norm(trg, dim=2) # norm = [batch size,compound len] + norm = F.softmax(norm, dim=1) # norm = [batch size,compound len] + # trg = torch.squeeze(trg,dim=0) + # norm = torch.squeeze(norm,dim=0) + sum = torch.zeros((trg.shape[0], self.hidden_dim), device=trg.device) + for i in range(norm.shape[0]): + for j in range(norm.shape[1]): + v = trg[i, j,] + v = v * norm[i, j] + sum[i,] += v # sum = [batch size,hidden_dim] + label = F.relu(self.fc_1(sum)) + # label = self.fc_2(label) + return label diff --git a/deepscreen/models/predictors/transformer_cpi_2.py b/deepscreen/models/predictors/transformer_cpi_2.py new file mode 100644 index 0000000000000000000000000000000000000000..fd4f0f15483a3ee65f5e444eb870ca6f1f0dd198 --- /dev/null +++ b/deepscreen/models/predictors/transformer_cpi_2.py @@ -0,0 +1,102 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class TransformerCPI2(nn.Module): + def __init__(self, encoder, decoder, atom_dim=34): + super().__init__() + self.encoder = encoder + self.decoder = decoder + self.fc_1 = nn.Linear(atom_dim, atom_dim) + self.fc_2 = nn.Linear(atom_dim, 768) + + # def gcn(self, compound, adj): + # # input = [batch, num_node, atom_dim] + # # adj = [batch, num_node, num_node] + # support = self.fc_1(compound) # support = [batch, num_node, atom_dim] + # output = torch.bmm(adj, support) # output = [batch, num_node, atom_dim] + # return output + + def forward(self, compound, protein): + # atom_feat = [batch_size, atom_num, atom_dim] + # adj_mat = [batch_size, atom_num, atom_num] + # enc_protein = [batch_size, protein_len, 768] + compound, adj = compound + adj, _ = adj + compound, compound_lengths = compound + protein, protein_lengths = protein + + # Add a global/master node to the compound + batch_size, num_node, _ = compound.shape + # Add the global node + compound = F.pad(compound, (0, 0, 1, 0), value=0) + # Add an identity matrix to each adjacency matrix to represent self-connections + adj = adj + torch.eye(num_node, device=compound.device).unsqueeze(0).expand(batch_size, -1, -1) + # Add global edges + adj = F.pad(adj, (1, 0, 1, 0), value=1) + + compound_mask = torch.arange(compound.size(1), device=compound.device) >= (compound_lengths + 1).unsqueeze(1) + protein_mask = torch.arange(protein.size(1), device=protein.device) >= protein_lengths.unsqueeze(1) + + compound = self.gcn(compound.float(), adj) # compound = [batch_size, atom_num, atom_dim] + compound = F.relu(self.fc_2(compound)) # compound = [batch, compound_len, 768] + enc_src, src_mask = self.encoder(protein, protein_mask) # enc_src = [protein_len,batch , hid_dim] + out = self.decoder(compound, enc_src, compound_mask, src_mask) # out = [batch_size, 2] + + return out + + def gcn(self, compound, adj): + support = self.fc_1(compound) # support = [batch, num_node, atom_dim] + output = torch.bmm(adj, support) # output = [batch, num_node, atom_dim] + return output + + +class Encoder(nn.Module): + """protein feature extraction""" + def __init__(self, pretrain, n_layers): + super().__init__() + self.pretrain = pretrain + self.hid_dim = 768 + self.n_layers = n_layers + self.encoder_layer = nn.TransformerEncoderLayer( + d_model=self.hid_dim, nhead=8, dim_feedforward=self.hid_dim * 4, dropout=0.1 + ) + self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=self.n_layers) + + def forward(self, protein, mask): + # protein = [batch_size, protein_len] + # mask = [batch_size, protein_len] 0 for true positions, 1 for mask positions + with torch.no_grad(): + protein = 
self.pretrain(protein.long(), mask.long())[0]
+        protein = protein.permute(1, 0, 2).contiguous()  # protein = [protein_len, batch_size, 768]
+        protein = self.encoder(protein, src_key_padding_mask=mask)  # protein = [protein_len, batch_size, 768]
+        return protein, mask
+
+
+class Decoder(nn.Module):
+    """compound feature extraction"""
+    def __init__(self, n_layers, dropout):
+        super().__init__()
+        self.hid_dim = 768
+        self.n_layers = n_layers
+        self.decoder_layer = nn.TransformerDecoderLayer(
+            d_model=self.hid_dim, nhead=8, dim_feedforward=self.hid_dim * 4, dropout=0.1
+        )
+        self.decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=self.n_layers)
+        self.fc_1 = nn.Linear(768, 256)
+        self.fc_2 = nn.Linear(256, 2)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, tgt, src, tgt_mask=None, src_mask=None):
+        # tgt = [batch_size, compound_len, hid_dim]
+        # src = [protein_len, batch_size, hid_dim]  # encoder output
+        tgt = tgt.permute(1, 0, 2).contiguous()  # tgt = [compound_len, batch_size, hid_dim]
+        # tgt_mask = tgt_mask == 1
+        tgt = self.decoder(tgt, src, tgt_key_padding_mask=tgt_mask, memory_key_padding_mask=src_mask)
+        # tgt = [compound_len, batch_size, hid_dim]
+        tgt = tgt.permute(1, 0, 2).contiguous()  # tgt = [batch_size, compound_len, hid_dim]
+        x = tgt[:, 0, :]  # the global-node embedding represents the whole compound
+        label = F.relu(self.fc_1(x))
+        # label = self.fc_2(label)
+        return label
diff --git a/deepscreen/predict.py b/deepscreen/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..1669c1d343017bb686483dc052d24f8a59a96508
--- /dev/null
+++ b/deepscreen/predict.py
@@ -0,0 +1,74 @@
+from typing import List, Tuple
+
+import hydra
+from omegaconf import DictConfig
+from lightning import LightningDataModule, LightningModule, Trainer, Callback
+
+from deepscreen.utils.hydra import checkpoint_rerun_config
+from deepscreen.utils import get_logger, job_wrapper, instantiate_callbacks
+
+log = get_logger(__name__)
+
+
+# def fix_dict_config(cfg: DictConfig):
+#     """fix all vars in the cfg config
+#     this is an in-place operation"""
+#     keys = list(cfg.keys())
+#     for k in keys:
+#         if type(cfg[k]) is DictConfig:
+#             fix_dict_config(cfg[k])
+#         else:
+#             setattr(cfg, k, getattr(cfg, k))
+
+
+@job_wrapper(extra_utils=True)
+def predict(cfg: DictConfig) -> Tuple[list, dict]:
+    """Run prediction with a given checkpoint on a prediction dataset.
+
+    This function is wrapped by the optional @job_wrapper decorator, which applies extra utilities
+    before and after the call.
+
+    Args:
+        cfg (DictConfig): Configuration composed by Hydra.
+
+    Returns:
+        Tuple[list, dict]: List of batch predictions and dict with all instantiated objects.
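+
+    Example (CLI sketch; the checkpoint path is hypothetical):
+        python deepscreen/predict.py ckpt_path=logs/train/runs/checkpoints/epoch_042.ckpt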
+ """ + log.info(f"Instantiating data <{cfg.data._target_}>") + datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) + + log.info(f"Instantiating model <{cfg.model._target_}>") + model: LightningModule = hydra.utils.instantiate(cfg.model) + + log.info("Instantiating callbacks.") + callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks")) + + log.info(f"Instantiating trainer <{cfg.trainer._target_}>") + trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=False, callbacks=callbacks) + + object_dict = { + "cfg": cfg, + "datamodule": datamodule, + "model": model, + "callbacks": callbacks, + "trainer": trainer, + } + + log.info("Start predicting.") + + predictions = trainer.predict(model=model, datamodule=datamodule, + ckpt_path=cfg.ckpt_path, return_predictions=True) + + return predictions, object_dict + + +@hydra.main(version_base="1.3", config_path="../configs", config_name="predict.yaml") +def main(cfg: DictConfig): + assert cfg.ckpt_path, "Checkpoint path (`ckpt_path`) must be specified for predicting." + cfg = checkpoint_rerun_config(cfg) + predictions, _ = predict(cfg) + return predictions + + +if __name__ == "__main__": + main() diff --git a/deepscreen/utils/__init__.py b/deepscreen/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4f105b4f422de89a9ce402c650fae4aa262e4838 --- /dev/null +++ b/deepscreen/utils/__init__.py @@ -0,0 +1,8 @@ +from deepscreen.utils.logging import get_logger, log_hyperparameters +from deepscreen.utils.instantiators import instantiate_callbacks, instantiate_loggers +from deepscreen.utils.rich import enforce_tags, print_config_tree +from deepscreen.utils.utils import extras, job_wrapper + + +def passthrough(x): + return x diff --git a/deepscreen/utils/__pycache__/__init__.cpython-311.pyc b/deepscreen/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..239dc7ee2dcfbe59ce2de4ee73a8c5194cce0a1b Binary files /dev/null and b/deepscreen/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/deepscreen/utils/__pycache__/hydra.cpython-311.pyc b/deepscreen/utils/__pycache__/hydra.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64cab8f0a3c32ac7b85595e8ad5cc24bebf7349d Binary files /dev/null and b/deepscreen/utils/__pycache__/hydra.cpython-311.pyc differ diff --git a/deepscreen/utils/__pycache__/instantiators.cpython-311.pyc b/deepscreen/utils/__pycache__/instantiators.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9cf3d20399e00565dfe5eb8d8805728b69985f45 Binary files /dev/null and b/deepscreen/utils/__pycache__/instantiators.cpython-311.pyc differ diff --git a/deepscreen/utils/__pycache__/logging.cpython-311.pyc b/deepscreen/utils/__pycache__/logging.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbe194675c0f37bddec1318e5db8d883c354bd78 Binary files /dev/null and b/deepscreen/utils/__pycache__/logging.cpython-311.pyc differ diff --git a/deepscreen/utils/__pycache__/rich.cpython-311.pyc b/deepscreen/utils/__pycache__/rich.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b33367488a99a56c57b65f042720f7ca7789bf23 Binary files /dev/null and b/deepscreen/utils/__pycache__/rich.cpython-311.pyc differ diff --git a/deepscreen/utils/__pycache__/utils.cpython-311.pyc b/deepscreen/utils/__pycache__/utils.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..1d568c9ce81de084efabdba8f1bb6850564cba25 Binary files /dev/null and b/deepscreen/utils/__pycache__/utils.cpython-311.pyc differ
diff --git a/deepscreen/utils/hydra.py b/deepscreen/utils/hydra.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a33602de5267a02db3c30101dd287799ba99943
--- /dev/null
+++ b/deepscreen/utils/hydra.py
@@ -0,0 +1,195 @@
+from datetime import timedelta
+from pathlib import Path
+import re
+from time import time
+from typing import Any, Tuple
+
+import pandas as pd
+from hydra import TaskFunction
+from hydra.core.hydra_config import HydraConfig
+from hydra.core.utils import _save_config
+from hydra.experimental.callbacks import Callback
+from hydra.types import RunMode
+from omegaconf import DictConfig, OmegaConf
+
+from deepscreen.utils import get_logger
+
+log = get_logger(__name__)
+
+
+class CSVExperimentSummary(Callback):
+    """Aggregate each job's metrics.csv into a single experiment summary CSV
+    (`experiment_summary.csv` by default)."""
+
+    def __init__(self, filename: str = 'experiment_summary.csv', prefix: str | Tuple[str] = 'test/'):
+        self.filename = filename
+        self.prefix = prefix if isinstance(prefix, str) else tuple(prefix)
+        self.input_experiment_summary = None
+        self.time = {}
+
+    def on_multirun_start(self, config: DictConfig, **kwargs: Any) -> None:
+        if config.hydra.get('overrides') and config.hydra.overrides.get('task'):
+            for i, override in enumerate(config.hydra.overrides.task):
+                if override.startswith("ckpt_path"):
+                    ckpt_path = override.split('=', 1)[1]
+                    if ckpt_path.endswith(('.csv', '.txt', '.tsv', '.ssv', '.psv')):
+                        # keep the `ckpt_path=` key when rewriting the override value
+                        config.hydra.overrides.task[i] = \
+                            f"ckpt_path={self.parse_ckpt_path_from_experiment_summary(ckpt_path)}"
+                    break
+        if config.hydra.sweeper.get('params'):
+            if config.hydra.sweeper.params.get('ckpt_path'):
+                ckpt_path = str(config.hydra.sweeper.params.ckpt_path).strip("'\"")
+                if ckpt_path.endswith(('.csv', '.txt', '.tsv', '.ssv', '.psv')):
+                    config.hydra.sweeper.params.ckpt_path = self.parse_ckpt_path_from_experiment_summary(ckpt_path)
+
+    def on_job_start(self, config: DictConfig, *, task_function: TaskFunction, **kwargs: Any) -> None:
+        self.time['start'] = time()
+
+    def on_job_end(self, config: DictConfig, job_return, **kwargs: Any) -> None:
+        # Skip callback if job is DDP subprocess
+        if "ddp" in job_return.hydra_cfg.hydra.job.name:
+            return
+
+        try:
+            self.time['end'] = time()
+            if config.hydra.mode == RunMode.RUN:
+                summary_file_path = Path(config.hydra.run.dir) / self.filename
+            elif config.hydra.mode == RunMode.MULTIRUN:
+                summary_file_path = Path(config.hydra.sweep.dir) / self.filename
+            else:
+                raise RuntimeError('Invalid Hydra `RunMode`.')
+
+            if summary_file_path.is_file():
+                summary_df = pd.read_csv(summary_file_path)
+            else:
+                summary_df = pd.DataFrame()
+
+            # Add job and override info
+            info_dict = {}
+            if job_return.overrides:
+                info_dict = dict(override.split('=', 1) for override in job_return.overrides)
+            info_dict['job_status'] = job_return.status.name
+            info_dict['job_id'] = job_return.hydra_cfg.hydra.job.id
+            # timedelta interprets a bare number as days, so pass seconds explicitly
+            info_dict['wall_time'] = str(timedelta(seconds=self.time['end'] - self.time['start']))
+
+            # Add checkpoint info
+            if info_dict.get('ckpt_path'):
+                info_dict['ckpt_path'] = str(info_dict['ckpt_path']).strip("'\"")
+
+            ckpt_path = str(job_return.cfg.ckpt_path).strip("'\"")
+            if Path(ckpt_path).is_file():
+                if info_dict.get('ckpt_path') and ckpt_path != info_dict['ckpt_path']:
+                    info_dict['previous_ckpt_path'] = info_dict['ckpt_path']
+                info_dict['ckpt_path'] = ckpt_path
+                # Only record the best epoch if the checkpoint filename actually encodes one
+                epoch_match = re.search(r'epoch_(\d+)', info_dict['ckpt_path'])
+                if epoch_match:
+                    info_dict['best_epoch'] = int(epoch_match.group(1))
+
+            # Add metrics info
+            metrics_df = pd.DataFrame()
+            if config.get('logger'):
+                output_dir = Path(config.hydra.runtime.output_dir).resolve()
+                csv_metrics_path = output_dir / config.logger.csv.name / "metrics.csv"
+                if csv_metrics_path.is_file():
+                    log.info(f"Summarizing metrics with prefix `{self.prefix}` from {csv_metrics_path}")
+                    metrics_df = pd.read_csv(csv_metrics_path)
+                    best_epoch = info_dict.get('best_epoch')
+                    # Find rows where prefixed columns are not null and reset their epoch
+                    # to the best model epoch
+                    test_columns = [col for col in metrics_df.columns if col.startswith(self.prefix)]
+                    if test_columns and best_epoch is not None:
+                        mask = metrics_df[test_columns].notna().any(axis=1)
+                        metrics_df.loc[mask, 'epoch'] = best_epoch
+                    # Group and filter by best epoch
+                    metrics_df = metrics_df.groupby('epoch').first()
+                    if best_epoch is not None:
+                        metrics_df = metrics_df[metrics_df.index == best_epoch]
+                else:
+                    log.info(f"No metrics.csv found in {output_dir}")
+
+            if metrics_df.empty:
+                metrics_df = pd.DataFrame(data=info_dict, index=[0])
+            else:
+                metrics_df = metrics_df.assign(**info_dict)
+                metrics_df.index = [0]
+
+            # Add extra info from the input batch experiment summary
+            if self.input_experiment_summary is not None and 'ckpt_path' in metrics_df.columns:
+                orig_meta = self.input_experiment_summary[
+                    self.input_experiment_summary['ckpt_path'] == metrics_df['ckpt_path'][0]
+                ].head(1)
+                if not orig_meta.empty:
+                    orig_meta.index = [0]
+                    metrics_df = metrics_df.combine_first(orig_meta)
+
+            summary_df = pd.concat([summary_df, metrics_df])
+
+            # Drop empty columns
+            summary_df.dropna(inplace=True, axis=1, how='all')
+            summary_df.to_csv(summary_file_path, index=False, mode='w')
+            log.info(f"Experiment summary saved to {summary_file_path}")
+        except Exception as e:
+            log.exception("Unable to save the experiment summary due to an error.", exc_info=e)
+
+    def parse_ckpt_path_from_experiment_summary(self, ckpt_path):
+        try:
+            self.input_experiment_summary = pd.read_csv(
+                ckpt_path, usecols=lambda col: not col.startswith(self.prefix)
+            )
+            self.input_experiment_summary['ckpt_path'] = self.input_experiment_summary['ckpt_path'].apply(
+                lambda x: x.strip("'\"")
+            )
+            ckpt_list = list(set(self.input_experiment_summary['ckpt_path']))
+            parsed_ckpt_path = ','.join([f"'{ckpt}'" for ckpt in ckpt_list])
+            return parsed_ckpt_path
+
+        except Exception as e:
+            log.exception(
+                f'Error in parsing checkpoint paths from experiment_summary file ({ckpt_path}).',
+                exc_info=e
+            )
+
+
+def checkpoint_rerun_config(config: DictConfig):
+    hydra_cfg = HydraConfig.get()
+
+    if hydra_cfg.output_subdir is not None:
+        ckpt_cfg_path = Path(config.ckpt_path).parents[1] / hydra_cfg.output_subdir / 'config.yaml'
+        hydra_output = Path(hydra_cfg.runtime.output_dir) / hydra_cfg.output_subdir
+
+        if ckpt_cfg_path.is_file():
+            log.info(f"Found config file for the checkpoint at {str(ckpt_cfg_path)}; "
+                     f"merging config overrides with checkpoint config...")
+            ckpt_cfg = OmegaConf.load(ckpt_cfg_path)
+
+            # Merge checkpoint config with test config by overriding specified nodes.
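+            # Intended precedence (a summary of the active code below, after the
+            # commented-out variants): the checkpoint's model/data/task/seed nodes are
+            # kept, its featurizer/collator/predictor/output/loss/metrics nodes are
+            # force-written into the current config so they always match the trained
+            # weights, and any remaining keys can still be overridden by this run.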
+
+
+def checkpoint_rerun_config(config: DictConfig):
+    hydra_cfg = HydraConfig.get()
+
+    if hydra_cfg.output_subdir is not None:
+        ckpt_cfg_path = Path(config.ckpt_path).parents[1] / hydra_cfg.output_subdir / 'config.yaml'
+        hydra_output = Path(hydra_cfg.runtime.output_dir) / hydra_cfg.output_subdir
+
+        if ckpt_cfg_path.is_file():
+            log.info(f"Found config file for the checkpoint at {str(ckpt_cfg_path)}; "
+                     f"merging config overrides with checkpoint config...")
+            ckpt_cfg = OmegaConf.load(ckpt_cfg_path)
+
+            # Merge checkpoint config with test config by overriding specified nodes.
+            # ckpt_cfg = OmegaConf.masked_copy(ckpt_cfg, ['model', 'data', 'trainer', 'task'])
+            # ckpt_cfg.data = OmegaConf.masked_copy(ckpt_cfg.data, [
+            #     key for key in ckpt_cfg.data.keys() if key not in ['data_file', 'split', 'train_val_test_split']
+            # ])
+            #
+            # config = OmegaConf.merge(ckpt_cfg, config)
+
+            # config = OmegaConf.masked_copy(config,
+            #                                [key for key in config if key not in
+            #                                 ['task']])
+            # config.data = OmegaConf.masked_copy(config.data,
+            #                                     [key for key in config.data if key not in
+            #                                      ['drug_featurizer', 'protein_featurizer', 'collator']])
+            # config.model = OmegaConf.masked_copy(config.model,
+            #                                      [key for key in config.model if key not in
+            #                                       ['predictor']])
+            #
+            # config = OmegaConf.merge(ckpt_cfg, config)
+
+            ckpt_cfg = OmegaConf.masked_copy(ckpt_cfg, ['model', 'data', 'task', 'seed'])
+            ckpt_cfg.data = OmegaConf.masked_copy(ckpt_cfg.data, [
+                key for key in ckpt_cfg.data.keys() if key not in ['data_file', 'split', 'train_val_test_split']
+            ])
+            ckpt_override_keys = ['task', 'data.drug_featurizer', 'data.protein_featurizer', 'data.collator',
+                                  'model.predictor', 'model.out', 'model.loss', 'model.activation', 'model.metrics']
+
+            for key in ckpt_override_keys:
+                OmegaConf.update(config, key, OmegaConf.select(ckpt_cfg, key), force_add=True)
+
+            config = OmegaConf.merge(ckpt_cfg, config)
+
+            # OmegaConf.set_readonly(hydra_cfg, False)
+            # hydra_cfg.job.override_dirname += f"ckpt={str(Path(*Path(config.ckpt_path).parts[-4:]))}"
+            _save_config(config, "config.yaml", hydra_output)
+
+    return config
diff --git a/deepscreen/utils/instantiators.py b/deepscreen/utils/instantiators.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8afaade147cd034730ae24c2f84a1348a6a73ac
--- /dev/null
+++ b/deepscreen/utils/instantiators.py
@@ -0,0 +1,50 @@
+from typing import List
+
+import hydra
+from lightning import Callback
+from lightning.pytorch.loggers import Logger
+from omegaconf import DictConfig
+
+from deepscreen.utils import get_logger
+
+log = get_logger(__name__)
+
+
+def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
+    """Instantiates callbacks from config."""
+
+    callbacks: List[Callback] = []
+
+    if not callbacks_cfg:
+        log.warning("No callback configs found! Skipping...")
+        return callbacks
+
+    if not isinstance(callbacks_cfg, DictConfig):
+        raise TypeError("Callbacks config must be a DictConfig!")
+
+    for _, cb_conf in callbacks_cfg.items():
+        if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
+            log.info(f"Instantiating callback <{cb_conf._target_}>")
+            callbacks.append(hydra.utils.instantiate(cb_conf))
+
+    return callbacks
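+
+
+# A sketch of the config shape this helper consumes: a dict of nodes, each with a
+# `_target_` key. The two callback entries below are hypothetical examples (standard
+# Lightning callbacks), not part of this repository:
+#
+# callbacks:
+#   model_checkpoint:
+#     _target_: lightning.pytorch.callbacks.ModelCheckpoint
+#     monitor: val/loss
+#   early_stopping:
+#     _target_: lightning.pytorch.callbacks.EarlyStopping
+#     monitor: val/loss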
Skipping.") + return logger + + if not isinstance(logger_cfg, DictConfig): + raise TypeError("Logger config must be a DictConfig!") + + for _, lg_conf in logger_cfg.items(): + if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: + log.info(f"Instantiating logger <{lg_conf._target_}>") + logger.append(hydra.utils.instantiate(lg_conf)) + + return logger diff --git a/deepscreen/utils/lightning.py b/deepscreen/utils/lightning.py new file mode 100644 index 0000000000000000000000000000000000000000..5b226aeab10cc2e741a29929a401f7071e245113 --- /dev/null +++ b/deepscreen/utils/lightning.py @@ -0,0 +1,49 @@ +from pathlib import Path +from typing import Literal + +from lightning.pytorch.callbacks import BasePredictionWriter +import pandas as pd +import torch + +from deepscreen.utils import get_logger + +log = get_logger(__name__) + + +class CSVPredictionWriter(BasePredictionWriter): + def __init__(self, output_dir, write_interval: Literal["batch", "epoch"] = "batch"): + super().__init__(write_interval) + self.output_file = Path(output_dir, "predictions.csv") + + def setup(self, trainer, pl_module, stage: str): + log.info(f"Saving predictions every {self.interval.value} for job `{stage}`.") + + def write_on_batch_end(self, trainer, pl_module, outputs, batch_indices, batch, batch_idx, dataloader_idx): + output_df = self.outputs_to_dataframe(outputs) + output_df.to_csv(self.output_file, + mode='a', + index=False, + header=not self.output_file.is_file()) + + def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices): + output_df = pd.concat([self.outputs_to_dataframe(outputs) for outputs in predictions]) + output_df.to_csv(self.output_file, + mode='w', + index=False, + header=True) + + def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int, dataloader_idx: int = 0): + self.write_on_batch_end(trainer, pl_module, outputs, None, batch, batch_idx, dataloader_idx) + + def teardown(self, trainer, pl_module, stage: str): + log.info(f'Predictions saved to {self.output_file}') + + @staticmethod + def outputs_to_dataframe(prediction): + for key, value in prediction.items(): + if isinstance(value, torch.Tensor): + prediction[key] = value.tolist() + else: + prediction[key] = list(value) + prediction_df = pd.DataFrame(prediction) + return prediction_df diff --git a/deepscreen/utils/logging.py b/deepscreen/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..e71bf747a8a41331eb71ef6b72f06a7d147d0e66 --- /dev/null +++ b/deepscreen/utils/logging.py @@ -0,0 +1,62 @@ +import logging + +from lightning.pytorch.utilities import rank_zero_only +from lightning.pytorch.utilities.model_summary import ModelSummary + + +def get_logger(name=__name__) -> logging.Logger: + """Initializes multi-GPU-friendly python command line logger.""" + logger = logging.getLogger(name) + + # this ensures all logging levels get marked with the rank zero decorator + # otherwise logs would get multiplied for each GPU process in multi-GPU setup + logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical") + for level in logging_levels: + setattr(logger, level, rank_zero_only(getattr(logger, level))) + + return logger + + +log = get_logger(__name__) + + +@rank_zero_only +def log_hyperparameters(object_dict: dict) -> None: + """Controls which config parts are saved by lightning loggers. 
diff --git a/deepscreen/utils/logging.py b/deepscreen/utils/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..e71bf747a8a41331eb71ef6b72f06a7d147d0e66
--- /dev/null
+++ b/deepscreen/utils/logging.py
@@ -0,0 +1,62 @@
+import logging
+
+from lightning.pytorch.utilities import rank_zero_only
+from lightning.pytorch.utilities.model_summary import ModelSummary
+
+
+def get_logger(name=__name__) -> logging.Logger:
+    """Initializes multi-GPU-friendly python command line logger."""
+    logger = logging.getLogger(name)
+
+    # this ensures all logging levels get marked with the rank zero decorator
+    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
+    logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
+    for level in logging_levels:
+        setattr(logger, level, rank_zero_only(getattr(logger, level)))
+
+    return logger
+
+
+log = get_logger(__name__)
+
+
+@rank_zero_only
+def log_hyperparameters(object_dict: dict) -> None:
+    """Controls which config parts are saved by lightning loggers.
+
+    Additionally saves:
+        - Number of model parameters
+    """
+
+    hparams = {}
+
+    cfg = object_dict["cfg"]
+    model = object_dict["model"]
+    trainer = object_dict["trainer"]
+
+    if not trainer.logger:
+        log.warning("Logger not found! Skipping hyperparameter logging.")
+        return
+
+    hparams["model"] = cfg["model"]
+
+    # save number of model parameters
+    model_summary = ModelSummary(model)
+    hparams["model/params/total"] = model_summary.total_parameters
+    hparams["model/params/trainable"] = model_summary.trainable_parameters
+    hparams["model/params/non_trainable"] = model_summary.total_parameters - model_summary.trainable_parameters
+
+    hparams["data"] = cfg["data"]
+    hparams["trainer"] = cfg["trainer"]
+
+    hparams["callbacks"] = cfg.get("callbacks")
+    hparams["extras"] = cfg.get("extras")
+
+    hparams["job_name"] = cfg.get("job_name")
+    hparams["tags"] = cfg.get("tags")
+    hparams["ckpt_path"] = cfg.get("ckpt_path")
+    hparams["seed"] = cfg.get("seed")
+
+    # send hparams to all loggers
+    for logger in trainer.loggers:
+        logger.log_hyperparams(hparams)
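+
+
+# A minimal usage sketch, assuming this is called from a train/test entry point once the
+# config, model, and trainer objects exist (`cfg`, `model`, `trainer` are placeholders):
+#
+# object_dict = {"cfg": cfg, "model": model, "trainer": trainer}
+# log_hyperparameters(object_dict)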
diff --git a/deepscreen/utils/rich.py b/deepscreen/utils/rich.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d57ee33a917ebd95133a21665234f19fb4d854b
--- /dev/null
+++ b/deepscreen/utils/rich.py
@@ -0,0 +1,105 @@
+from pathlib import Path
+from typing import Sequence
+
+import rich
+import rich.syntax
+import rich.tree
+from hydra.core.hydra_config import HydraConfig
+from omegaconf import DictConfig, OmegaConf, open_dict
+from lightning.pytorch.utilities import rank_zero_only
+from rich.prompt import Prompt
+
+from deepscreen.utils import get_logger
+
+log = get_logger(__name__)
+
+
+@rank_zero_only
+def print_config_tree(
+    cfg: DictConfig,
+    print_order: Sequence[str] = (
+        "data",
+        "model",
+        "callbacks",
+        "logger",
+        "trainer",
+        "paths",
+        "extras",
+    ),
+    resolve: bool = False,
+    save_to_file: bool = False,
+) -> None:
+    """Prints content of DictConfig using Rich library and its tree structure.
+
+    Args:
+        cfg (DictConfig): Configuration composed by Hydra.
+        print_order (Sequence[str], optional): Determines in what order config components are printed.
+        resolve (bool, optional): Whether to resolve reference fields of DictConfig.
+        save_to_file (bool, optional): Whether to export config to the hydra output folder.
+    """
+
+    style = "dim"
+    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
+
+    queue = []
+
+    # add fields from `print_order` to queue
+    for field in print_order:
+        if field in cfg:
+            queue.append(field)
+        else:
+            log.warning(f"Field '{field}' not found in config. Skipping '{field}' config printing.")
+
+    # add all the other fields to queue (not specified in `print_order`)
+    for field in cfg:
+        if field not in queue:
+            queue.append(field)
+
+    # generate config tree from queue
+    for field in queue:
+        branch = tree.add(field, style=style, guide_style=style)
+
+        config_group = cfg[field]
+        if isinstance(config_group, DictConfig):
+            branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
+        else:
+            branch_content = str(config_group)
+
+        branch.add(rich.syntax.Syntax(branch_content, "yaml"))
+
+    # print config tree
+    rich.print(tree)
+
+    # save config tree to file
+    if save_to_file:
+        with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
+            rich.print(tree, file=file)
+
+
+@rank_zero_only
+def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
+    """Prompts user to input tags from command line if no tags are provided in config."""
+
+    if not cfg.get("tags"):
+        if "id" in HydraConfig().cfg.hydra.job:
+            raise ValueError("Specify tags before launching a multirun!")
+
+        log.warning("No tags provided in config. Prompting user to input tags.")
+        tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
+        tags = [t.strip() for t in tags.split(",") if t != ""]
+
+        with open_dict(cfg):
+            cfg.tags = tags
+
+        log.info(f"Tags: {cfg.tags}")
+
+    if save_to_file:
+        with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
+            rich.print(cfg.tags, file=file)
+
+
+# if __name__ == "__main__":
+#     from hydra import compose, initialize
+#
+#     with initialize(version_base="1.2", config_path="../../configs"):
+#         cfg = compose(config_name="train.yaml", return_hydra_config=False, overrides=[])
+#         print_config_tree(cfg, resolve=False, save_to_file=False)
diff --git a/deepscreen/utils/utils.py b/deepscreen/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6855160479ddf99316ee7fd8e51d2bb4e6ffb860
--- /dev/null
+++ b/deepscreen/utils/utils.py
@@ -0,0 +1,206 @@
+# import time
+# from pathlib import Path
+# from typing import Any, Dict, List
+#
+# import hydra
+# from pytorch_lightning import Callback
+# from pytorch_lightning.loggers import Logger
+# from pytorch_lightning.utilities import rank_zero_only
+
+import warnings
+from importlib.util import find_spec
+from typing import Callable
+
+from omegaconf import DictConfig
+
+from deepscreen.utils import get_logger, enforce_tags, print_config_tree
+
+log = get_logger(__name__)
+
+
+def extras(cfg: DictConfig) -> None:
+    """Applies optional utilities before a job is started.
+
+    Utilities:
+        - Ignoring python warnings
+        - Setting tags from command line
+        - Rich config printing
+    """
+
+    # return if no `extras` config
+    if not cfg.get("extras"):
+        log.warning("Extras config not found!")
+        return
+
+    # disable python warnings
+    if cfg.extras.get("ignore_warnings"):
+        log.info("Disabling python warnings!")
+        warnings.filterwarnings("ignore")
+
+    # prompt user to input tags from command line if none are provided in the config
+    if cfg.extras.get("enforce_tags"):
+        log.info("Enforcing tags!")
+        enforce_tags(cfg, save_to_file=True)
+
+    # pretty print config tree using Rich library
+    if cfg.extras.get("print_config"):
+        log.info("Printing config tree with Rich!")
+        print_config_tree(cfg, resolve=True, save_to_file=True)
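+
+
+# A sketch of the `extras` config node consumed above; the key names follow the
+# `cfg.extras` lookups in `extras()`, while the values shown are hypothetical defaults:
+#
+# extras:
+#   ignore_warnings: false
+#   enforce_tags: true
+#   print_config: true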
+
+
+def job_wrapper(extra_utils: bool) -> Callable:
+    """Optional decorator that controls the failure behavior and extra utilities when executing a job function.
+
+    This wrapper can be used to:
+        - make sure loggers are closed even if the job function raises an exception (prevents multirun failure)
+        - save the exception to a `.log` file
+        - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
+        - etc. (adjust depending on your needs)
+
+    Example:
+    ```
+    @utils.job_wrapper(extra_utils)
+    def train(cfg: DictConfig) -> Tuple[dict, dict]:
+
+        .
+
+        return metric_dict, object_dict
+    ```
+    """
+    def decorator(job_func):
+        def wrapped_func(cfg: DictConfig):
+            # execute the job
+            try:
+                # apply extra utilities
+                if extra_utils:
+                    extras(cfg)
+                metric_dict, object_dict = job_func(cfg=cfg)
+
+            # things to do if exception occurs
+            except Exception as ex:
+                # save exception to `.log` file
+                log.exception("")
+
+                # some hyperparameter combinations might be invalid or cause out-of-memory errors,
+                # so when using hparam search plugins like Optuna you might want to disable
+                # raising the below exception to avoid multirun failure
+                raise ex
+
+            # things to always do after either success or exception
+            finally:
+                # display output dir path in terminal
+                log.info(f"Output dir: {cfg.paths.output_dir}")
+
+                # always close wandb run (even if exception occurs, so multirun won't fail)
+                if find_spec("wandb"):  # check if wandb is installed
+                    import wandb
+
+                    if wandb.run:
+                        log.info("Closing wandb!")
+                        wandb.finish()
+
+            return metric_dict, object_dict
+        return wrapped_func
+    return decorator
+
+
+# @rank_zero_only
+# def save_file(path, content) -> None:
+#     """Save file in rank zero mode (only on one process in multi-GPU setup)."""
+#     with open(path, "w+") as file:
+#         file.write(content)
+#
+#
+# def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
+#     """Instantiates callbacks from config."""
+#     callbacks: List[Callback] = []
+#
+#     if not callbacks_cfg:
+#         log.warning("Callbacks config is empty.")
+#         return callbacks
+#
+#     if not isinstance(callbacks_cfg, DictConfig):
+#         raise TypeError("Callbacks config must be a DictConfig!")
+#
+#     for _, cb_conf in callbacks_cfg.items():
+#         if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
+#             log.info(f"Instantiating callback <{cb_conf._target_}>")
+#             callbacks.append(hydra.utils.instantiate(cb_conf))
+#
+#     return callbacks
+#
+#
+# def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
+#     """Instantiates loggers from config."""
+#     logger: List[Logger] = []
+#
+#     if not logger_cfg:
+#         log.warning("Logger config is empty.")
+#         return logger
+#
+#     if not isinstance(logger_cfg, DictConfig):
+#         raise TypeError("Logger config must be a DictConfig!")
+#
+#     for _, lg_conf in logger_cfg.items():
+#         if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
+#             log.info(f"Instantiating logger <{lg_conf._target_}>")
+#             logger.append(hydra.utils.instantiate(lg_conf))
+#
+#     return logger
+#
+#
+# @rank_zero_only
+# def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
+#     """Controls which config parts are saved by lightning loggers.
+#
+#     Additionally saves:
+#         - Number of model parameters
+#     """
+#
+#     hparams = {}
+#
+#     cfg = object_dict["cfg"]
+#     model = object_dict["model"]
+#     trainer = object_dict["trainer"]
+#
+#     if not trainer.logger:
+#         log.warning("Logger not found! Skipping hyperparameter logging.")
+#         return
+#
+#     hparams["model"] = cfg["model"]
+#
+#     # TODO Accommodation for LazyModule
+#     # save number of model parameters
+#     hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
+#     hparams["model/params/trainable"] = sum(
+#         p.numel() for p in model.parameters() if p.requires_grad
+#     )
+#     hparams["model/params/non_trainable"] = sum(
+#         p.numel() for p in model.parameters() if not p.requires_grad
+#     )
+#
+#     hparams["data"] = cfg["data"]
+#     hparams["trainer"] = cfg["trainer"]
+#
+#     hparams["callbacks"] = cfg.get("callbacks")
+#     hparams["extras"] = cfg.get("extras")
+#
+#     hparams["job_name"] = cfg.get("job_name")
+#     hparams["tags"] = cfg.get("tags")
+#     hparams["ckpt_path"] = cfg.get("ckpt_path")
+#     hparams["seed"] = cfg.get("seed")
+#
+#     # send hparams to all loggers
+#     trainer.logger.log_hyperparams(hparams)
+
+
+# def close_loggers() -> None:
+#     """Makes sure all loggers closed properly (prevents logging failure during multirun)."""
+#
+#     log.info("Closing loggers.")
+#
+#     if find_spec("wandb"):  # if wandb is installed
+#         import wandb
+#
+#         if wandb.run:
+#             log.info("Closing wandb!")
+#             wandb.finish()