diff --git a/configs/data/dti_data.yaml b/configs/data/dti_data.yaml index 5b65659c15f8fac1a229e16586ffb226293b4f14..11c7e42238e7ca608436851ccf5f934a583f995f 100644 --- a/configs/data/dti_data.yaml +++ b/configs/data/dti_data.yaml @@ -1,7 +1,7 @@ _target_: deepscreen.data.dti.DTIDataModule defaults: - - split: null + - split: none - drug_featurizer: none # ??? - protein_featurizer: none # ??? - collator: default @@ -13,8 +13,8 @@ data_dir: ${paths.data_dir} data_file: null train_val_test_split: null -batch_size: ??? +batch_size: 2 num_workers: 0 pin_memory: false - -#train: ${train} \ No newline at end of file +query: X2 +#train: ${train} diff --git a/configs/data/protein_featurizer/word2vec.yaml b/configs/data/protein_featurizer/word2vec.yaml index 7330385dec21adfb415ef0b67f21bcd97dd6152f..1054e7bf765bb08afc3daaa0692a14fc1e97126d 100644 --- a/configs/data/protein_featurizer/word2vec.yaml +++ b/configs/data/protein_featurizer/word2vec.yaml @@ -3,4 +3,4 @@ _partial_: true model: _target_: gensim.models.Word2Vec.load - fname: ${paths.resource_dir}/models/word2vec_30.model \ No newline at end of file + fname: ${paths.resource_dir}/models/word2vec_30.model diff --git a/configs/model/decoder/concat_mlp.yaml b/configs/model/decoder/concat_mlp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..17b92c9493a45d51b3463dc069a73295fc22eb4f --- /dev/null +++ b/configs/model/decoder/concat_mlp.yaml @@ -0,0 +1,6 @@ +_target_: deepscreen.models.components.mlp.ConcatMLP + +input_channels: ${eval:${model.drug_encoder.out_channels}+${model.protein_encoder.out_channels}} +out_channels: 512 +hidden_channels: [1024,1024] +dropout: 0.1 \ No newline at end of file diff --git a/configs/model/decoder/mlp_deepdta.yaml b/configs/model/decoder/mlp_deepdta.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b6ee28076224a373a6879ece9296491dccfd280c --- /dev/null +++ b/configs/model/decoder/mlp_deepdta.yaml @@ -0,0 +1,6 @@ +_target_: deepscreen.models.components.mlp.MLP2 + +input_channels: ${eval:${model.drug_encoder.out_channels}+${model.protein_encoder.out_channels}} +out_channels: 1 +hidden_channels: [1024,1024,512] +dropout: 0.1 \ No newline at end of file diff --git a/configs/model/decoder/mlp_lazy.yaml b/configs/model/decoder/mlp_lazy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..832863817e37fed6e5ef54eba4de99341310c4dc --- /dev/null +++ b/configs/model/decoder/mlp_lazy.yaml @@ -0,0 +1,5 @@ +_target_: deepscreen.models.components.mlp.LazyMLP + +out_channels: 1 +hidden_channels: [1024,1024,512] +dropout: 0.1 \ No newline at end of file diff --git a/configs/model/drug_encoder/cnn.yaml b/configs/model/drug_encoder/cnn.yaml new file mode 100644 index 0000000000000000000000000000000000000000..453ef99dec6c2a5db821181e72eef24c9faab966 --- /dev/null +++ b/configs/model/drug_encoder/cnn.yaml @@ -0,0 +1,9 @@ +_target_: deepscreen.models.components.cnn.CNN + +max_sequence_length: ${data.drug_featurizer.max_sequence_length} +filters: [32, 64, 96] +kernels: [4, 6, 8] +in_channels: ${data.drug_featurizer.in_channels} +out_channels: 256 + +# TODO refactor the in_channels argument pipeline to be more reasonable \ No newline at end of file diff --git a/configs/model/drug_encoder/cnn_deepdta.yaml b/configs/model/drug_encoder/cnn_deepdta.yaml new file mode 100644 index 0000000000000000000000000000000000000000..97bf3a4870e224b2bc7a5eab946d38b57c279d26 --- /dev/null +++ b/configs/model/drug_encoder/cnn_deepdta.yaml @@ -0,0 +1,7 @@ +_target_: 
deepscreen.models.components.cnn_deepdta.CNN_DeepDTA + +max_sequence_length: ${data.drug_featurizer.max_sequence_length} +filters: [32, 64, 96] +kernels: [4, 6, 8] +in_channels: ${data.drug_featurizer.in_channels} +out_channels: 128 \ No newline at end of file diff --git a/configs/model/drug_encoder/gat.yaml b/configs/model/drug_encoder/gat.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3b8dddd5ab5fac2b3cefff8191c01d3ab393c8a5 --- /dev/null +++ b/configs/model/drug_encoder/gat.yaml @@ -0,0 +1,5 @@ +_target_: deepscreen.models.components.gat.GAT + +num_features: 78 +out_channels: 128 +dropout: 0.2 \ No newline at end of file diff --git a/configs/model/drug_encoder/gcn.yaml b/configs/model/drug_encoder/gcn.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5e2da337b324610297bd76b6783340d03bd681a8 --- /dev/null +++ b/configs/model/drug_encoder/gcn.yaml @@ -0,0 +1,5 @@ +_target_: deepscreen.models.components.gcn.GCN + +num_features: 78 +out_channels: 128 +dropout: 0.2 \ No newline at end of file diff --git a/configs/model/drug_encoder/gin.yaml b/configs/model/drug_encoder/gin.yaml new file mode 100644 index 0000000000000000000000000000000000000000..caf5820c158ff7956b2440d25f7b5f901936f683 --- /dev/null +++ b/configs/model/drug_encoder/gin.yaml @@ -0,0 +1,5 @@ +_target_: deepscreen.models.components.gin.GIN + +num_features: 78 +out_channels: 128 +dropout: 0.2 \ No newline at end of file diff --git a/configs/model/drug_encoder/lstm.yaml b/configs/model/drug_encoder/lstm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/configs/model/drug_encoder/transformer.yaml b/configs/model/drug_encoder/transformer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5eee1571675bcb2d2d539dbaa69ea0268c7b3908 --- /dev/null +++ b/configs/model/drug_encoder/transformer.yaml @@ -0,0 +1,11 @@ +_target_: deepscreen.models.components.transformer + +input_dim: 1024 +emb_size: 128 +max_position_size: 50 +dropout: 0.1 +n_layer: 8 +intermediate_size: 512 +num_attention_heads: 8 +attention_probs_dropout: 0.1 +hidden_dropout: 0.1 \ No newline at end of file diff --git a/configs/model/dti_model.yaml b/configs/model/dti_model.yaml index 72b1c9ee2b75a2f6ffce9a33b8d42cafd9cc67f0..77c21ca5b43e97f08ea2db523d8dd28fe9f4e757 100644 --- a/configs/model/dti_model.yaml +++ b/configs/model/dti_model.yaml @@ -5,7 +5,7 @@ defaults: - optimizer: adam - scheduler: default - predictor: none - - metrics: dti_metrics + - metrics: null out: ${task.out} loss: ${task.loss} diff --git a/configs/model/metrics/accuracy.yaml b/configs/model/metrics/accuracy.yaml index 80a9d3f8a6837571b70d0cfb70ba7c6f59ee1c3e..ba1ec9ceca32ded612fd2fec9dc33da44f00e0da 100644 --- a/configs/model/metrics/accuracy.yaml +++ b/configs/model/metrics/accuracy.yaml @@ -1,4 +1,4 @@ -accuracy: +Accuracy: _target_: torchmetrics.Accuracy task: ${task.task} num_classes: ${task.num_classes} \ No newline at end of file diff --git a/configs/model/metrics/auprc.yaml b/configs/model/metrics/auprc.yaml index b9de03c65fdeec91e8b39f38dd773d9956e2e8ce..45ade5a5d5527f8e54ce36b7312b65acc0ccee06 100644 --- a/configs/model/metrics/auprc.yaml +++ b/configs/model/metrics/auprc.yaml @@ -1,4 +1,4 @@ -auprc: +AUPRC: _target_: torchmetrics.AveragePrecision task: ${task.task} num_classes: ${task.num_classes} \ No newline at end of file diff --git a/configs/model/metrics/auroc.yaml b/configs/model/metrics/auroc.yaml index 
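
The decoder configs above size their input with a custom `${eval:...}` interpolation. This is not a built-in OmegaConf resolver, so it must be registered before Hydra composes the config; a minimal sketch of the usual registration (illustrative only, not deepscreen's actual code):

```python
# Hypothetical sketch: registering the `eval` resolver that interpolations like
# ${eval:${model.drug_encoder.out_channels}+${model.protein_encoder.out_channels}}
# depend on. Typically done once at package import time.
from omegaconf import OmegaConf

OmegaConf.register_new_resolver("eval", eval)

cfg = OmegaConf.create({
    "drug_out": 256,
    "protein_out": 256,
    # mirrors the input_channels arithmetic in concat_mlp.yaml / mlp_deepdta.yaml
    "input_channels": "${eval:'${drug_out}+${protein_out}'}",
})
print(cfg.input_channels)  # 512 -- inner interpolations resolve first, then eval runs
```
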
a4bcdbbd885ae5ba05120e5568ca7d3a323f213f..cef0d8dc6727da3e7f697ae77564ad090c67dc0c 100644 --- a/configs/model/metrics/auroc.yaml +++ b/configs/model/metrics/auroc.yaml @@ -1,4 +1,4 @@ -auroc: +AUROC: _target_: torchmetrics.AUROC task: ${task.task} num_classes: ${task.num_classes} \ No newline at end of file diff --git a/configs/model/metrics/bedroc.yaml b/configs/model/metrics/bedroc.yaml index 86c68585882faefde5ee711dfe647d406a06dd09..11da21880310df7f5cf8e39fa325a4454eaa96dc 100644 --- a/configs/model/metrics/bedroc.yaml +++ b/configs/model/metrics/bedroc.yaml @@ -1,3 +1,3 @@ -bedroc: +BEDROC: _target_: deepscreen.models.metrics.bedroc.BEDROC alpha: 80.5 \ No newline at end of file diff --git a/configs/model/metrics/ci.yaml b/configs/model/metrics/ci.yaml index 634ea1906b5d0307e2e0342beec04c28e66af308..8ed012f5e90f08ba9d0b4a15dca486335422253b 100644 --- a/configs/model/metrics/ci.yaml +++ b/configs/model/metrics/ci.yaml @@ -1,2 +1,2 @@ -# FIXME: implement concordance index -_target_: \ No newline at end of file +CI: + _target_: deepscreen.models.metrics.ci.ConcordanceIndex diff --git a/configs/model/metrics/concordance_index.yaml b/configs/model/metrics/concordance_index.yaml new file mode 100644 index 0000000000000000000000000000000000000000..634ea1906b5d0307e2e0342beec04c28e66af308 --- /dev/null +++ b/configs/model/metrics/concordance_index.yaml @@ -0,0 +1,2 @@ +# FIXME: implement concordance index +_target_: \ No newline at end of file diff --git a/configs/model/metrics/dta_metrics.yaml b/configs/model/metrics/dta_metrics.yaml index 2b0f108cb9df18dfcceae6af9fd94789685a54b2..1dd0be02dd08d3bd1458ce3e69511313a7429e47 100644 --- a/configs/model/metrics/dta_metrics.yaml +++ b/configs/model/metrics/dta_metrics.yaml @@ -1,2 +1,4 @@ defaults: - - mean_squared_error + - mse + - pearson + - ci \ No newline at end of file diff --git a/configs/model/metrics/dti_case_study.yaml b/configs/model/metrics/dti_case_study.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4ee25f2a9cdb1607268111f0f9b5e06ef8c5c714 --- /dev/null +++ b/configs/model/metrics/dti_case_study.yaml @@ -0,0 +1,18 @@ +# train/test with many metrics at once + +defaults: + - auroc + - auprc + - specificity + - sensitivity + - precision + - recall + - f1_score + - ef + - bedroc + - hit_rate + +# Common virtual screening metrics: +# - ef +# - bedroc +# - hit_rate diff --git a/configs/model/metrics/dti_metrics.yaml b/configs/model/metrics/dti_metrics.yaml index 93ca0c40a0d77ffb095ab91d52863667bf5746cc..3e1d1460b08f522449a8de91bef568ec0ae78490 100644 --- a/configs/model/metrics/dti_metrics.yaml +++ b/configs/model/metrics/dti_metrics.yaml @@ -1,4 +1,4 @@ -# train with many loggers at once +# train/test with many metrics at once defaults: - auroc @@ -8,6 +8,7 @@ defaults: - precision - recall - f1_score + # Common virtual screening metrics: # - ef # - bedroc diff --git a/configs/model/metrics/ef.yaml b/configs/model/metrics/ef.yaml index 82553b414da98c55508db9830d82b12af3db786d..8f7ca459b275f9bbd36e560b9973c16ff6547def 100644 --- a/configs/model/metrics/ef.yaml +++ b/configs/model/metrics/ef.yaml @@ -1,7 +1,23 @@ -ef1: - _target_: deepscreen.models.metrics.ef.EF +EF1: + _target_: deepscreen.models.metrics.ef.EnrichmentFactor alpha: 0.01 -ef5: - _target_: deepscreen.models.metrics.ef.EF - alpha: 0.05 \ No newline at end of file +EF2: + _target_: deepscreen.models.metrics.ef.EnrichmentFactor + alpha: 0.02 + +EF5: + _target_: deepscreen.models.metrics.ef.EnrichmentFactor + alpha: 0.05 + +EF10: + _target_: 
deepscreen.models.metrics.ef.EnrichmentFactor + alpha: 0.10 + +EF15: + _target_: deepscreen.models.metrics.ef.EnrichmentFactor + alpha: 0.15 + +EF20: + _target_: deepscreen.models.metrics.ef.EnrichmentFactor + alpha: 0.20 diff --git a/configs/model/metrics/f1_score.yaml b/configs/model/metrics/f1_score.yaml index abfb6e4ca37a9dad399aeb3b1244d958d542e238..e80737bb56fcc7627a469a79fd6e81b91dd0bed5 100644 --- a/configs/model/metrics/f1_score.yaml +++ b/configs/model/metrics/f1_score.yaml @@ -1,4 +1,4 @@ -f1_score: +F1: _target_: torchmetrics.F1Score task: ${task.task} num_classes: ${task.num_classes} \ No newline at end of file diff --git a/configs/model/metrics/hit_rate.yaml b/configs/model/metrics/hit_rate.yaml index 70976774eb365fdf8b8c1f97bd3ad19d6cb64cb8..05776b14828dd5805f9811cfb985ed7c60315de2 100644 --- a/configs/model/metrics/hit_rate.yaml +++ b/configs/model/metrics/hit_rate.yaml @@ -1,3 +1,24 @@ -hit_rate: +HR0_01: + _target_: deepscreen.models.metrics.hit_rate.HitRate + alpha: 0.01 + +HR0_02: + _target_: deepscreen.models.metrics.hit_rate.HitRate + alpha: 0.02 + +HR0_05: _target_: deepscreen.models.metrics.hit_rate.HitRate alpha: 0.05 + +HR0_10: + _target_: deepscreen.models.metrics.hit_rate.HitRate + alpha: 0.10 + +HR0_15: + _target_: deepscreen.models.metrics.hit_rate.HitRate + alpha: 0.15 + +HR0_20: + _target_: deepscreen.models.metrics.hit_rate.HitRate + alpha: 0.20 + diff --git a/configs/model/metrics/ir_hit_rate.yaml b/configs/model/metrics/ir_hit_rate.yaml new file mode 100644 index 0000000000000000000000000000000000000000..77ee1db11a3fe8aa48952c714a4a43ebb3fee8bf --- /dev/null +++ b/configs/model/metrics/ir_hit_rate.yaml @@ -0,0 +1,3 @@ +RetrievalHitRate: + _target_: torchmetrics.retrieval.RetrievalHitRate + top_k: 100 diff --git a/configs/model/metrics/mean_squared_error.yaml b/configs/model/metrics/mean_squared_error.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0d9a18c60d43210b16878ff479b1cfa3168788cf --- /dev/null +++ b/configs/model/metrics/mean_squared_error.yaml @@ -0,0 +1,2 @@ +mean_squared_error: + _target_: torchmetrics.MeanSquaredError \ No newline at end of file diff --git a/configs/model/metrics/mse.yaml b/configs/model/metrics/mse.yaml index 0d9a18c60d43210b16878ff479b1cfa3168788cf..d60197f52d2f301d8a050c56f6f1673cd2355d69 100644 --- a/configs/model/metrics/mse.yaml +++ b/configs/model/metrics/mse.yaml @@ -1,2 +1,2 @@ -mean_squared_error: +Mean squared error: _target_: torchmetrics.MeanSquaredError \ No newline at end of file diff --git a/configs/model/metrics/prc.yaml b/configs/model/metrics/prc.yaml index 75e3ee320d5b9a32a9acdb55d564fbacab975088..14ad48ec598db6a439dec9e805be0dce7f60ef77 100644 --- a/configs/model/metrics/prc.yaml +++ b/configs/model/metrics/prc.yaml @@ -1,4 +1,4 @@ -prc: +PR curve: _target_: torchmetrics.PrecisionRecallCurve task: ${task.task} num_classes: ${task.num_classes} \ No newline at end of file diff --git a/configs/model/metrics/precision.yaml b/configs/model/metrics/precision.yaml index 4b8212b1999c10b627de5859e034260a39022608..428a33e3ab16d756bd5da98f50561f136058f90e 100644 --- a/configs/model/metrics/precision.yaml +++ b/configs/model/metrics/precision.yaml @@ -1,4 +1,4 @@ -precision: +Precision: _target_: torchmetrics.Precision task: ${task.task} num_classes: ${task.num_classes} diff --git a/configs/model/metrics/recall.yaml b/configs/model/metrics/recall.yaml index eadad752ad1c4e1ac137580a289d4d32916ca975..06b783015eda77ab2c8ff5889d269f5998ff2d19 100644 --- a/configs/model/metrics/recall.yaml +++ 
b/configs/model/metrics/recall.yaml @@ -1,4 +1,4 @@ -recall: +Recall: _target_: torchmetrics.Recall task: ${task.task} num_classes: ${task.num_classes} \ No newline at end of file diff --git a/configs/model/metrics/roc.yaml b/configs/model/metrics/roc.yaml index 91968a6f42e3d399f39e587785eaddabda523a23..5c661013ad0d584527585b792a6011f4bcc44e5b 100644 --- a/configs/model/metrics/roc.yaml +++ b/configs/model/metrics/roc.yaml @@ -1,4 +1,4 @@ -roc: +ROC curve: _target_: torchmetrics.ROC task: ${task.task} num_classes: ${task.num_classes} \ No newline at end of file diff --git a/configs/model/metrics/sensitivity.yaml b/configs/model/metrics/sensitivity.yaml index 49568b4512c2b75ebfce98d8ee03e2b1148966cc..f0b3a87d140a848770c7f8654f12336eb9f3481e 100644 --- a/configs/model/metrics/sensitivity.yaml +++ b/configs/model/metrics/sensitivity.yaml @@ -1,4 +1,4 @@ -sensitivity: +Sensitivity: _target_: deepscreen.models.metrics.sensitivity.Sensitivity task: ${task.task} num_classes: ${task.num_classes} \ No newline at end of file diff --git a/configs/model/metrics/specificity.yaml b/configs/model/metrics/specificity.yaml index 5b161be947876081ce9b6b070ed7201874f5cbc6..ea9ece52149cdc82ee180cbbc3bdc6960124556e 100644 --- a/configs/model/metrics/specificity.yaml +++ b/configs/model/metrics/specificity.yaml @@ -1,4 +1,4 @@ -specificity: +Specificity: _target_: torchmetrics.Specificity task: ${task.task} num_classes: ${task.num_classes} \ No newline at end of file diff --git a/configs/model/metrics/ww_dti_metrics.yaml b/configs/model/metrics/ww_dti_metrics.yaml index c2685e108dfa2476c8eb67ef7f41d41dd9910390..a7073849baec744cad4182ce8be3187edcdd3140 100644 --- a/configs/model/metrics/ww_dti_metrics.yaml +++ b/configs/model/metrics/ww_dti_metrics.yaml @@ -190,4 +190,4 @@ F1Score0_95: _target_: torchmetrics.F1Score task: ${task.task} num_classes: ${task.num_classes} - threshold: 0.95 \ No newline at end of file + threshold: 0.95 diff --git a/configs/model/predictor/drug_vqa.yaml b/configs/model/predictor/drug_vqa.yaml index 3976e440232c5bf123c21d48b499b9ac459fdf87..019eca2058a1867fdbd5493ebbab8b42c74b82e2 100644 --- a/configs/model/predictor/drug_vqa.yaml +++ b/configs/model/predictor/drug_vqa.yaml @@ -5,7 +5,7 @@ lstm_hid_dim: 64 d_a: 32 r: 10 n_chars_smi: 577 -n_chars_seq: 21 +n_chars_seq: 26 dropout: 0.2 in_channels: 8 cnn_channels: 32 @@ -13,3 +13,4 @@ cnn_layers: 4 emb_dim: 30 dense_hid: 64 + diff --git a/configs/model/protein_encoder/cnn.yaml b/configs/model/protein_encoder/cnn.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9363e5e6130a5292d173b90eba8430c12edef3d1 --- /dev/null +++ b/configs/model/protein_encoder/cnn.yaml @@ -0,0 +1,7 @@ +_target_: deepscreen.models.components.cnn.CNN + +max_sequence_length: ${data.protein_featurizer.max_sequence_length} +filters: [32, 64, 96] +kernels: [4, 8, 12] +in_channels: ${data.protein_featurizer.in_channels} +out_channels: 256 \ No newline at end of file diff --git a/configs/model/protein_encoder/cnn_deepdta.yaml b/configs/model/protein_encoder/cnn_deepdta.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8ac5f6d9064695cdfe2739946c1e55d55aa588d5 --- /dev/null +++ b/configs/model/protein_encoder/cnn_deepdta.yaml @@ -0,0 +1,7 @@ +_target_: deepscreen.models.components.cnn_deepdta.CNN_DeepDTA + +max_sequence_length: ${data.protein_featurizer.max_sequence_length} +filters: [32, 64, 96] +kernels: [4, 8, 12] +in_channels: ${data.protein_featurizer.in_channels} +out_channels: 128 \ No newline at end of file diff --git 
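
These metric configs are all renamed from snake_case keys to display-style keys (AUROC, Precision, "Mean squared error", ...). If the keys are used as metric names in a `torchmetrics.MetricCollection`, the rename changes what shows up in logs; a sketch of that presumed wiring (assumed, not taken from deepscreen itself):

```python
# Hypothetical sketch of how a metrics config group like the ones above is
# typically consumed: each top-level key becomes the logged metric name and
# each body is instantiated by Hydra from its _target_.
import hydra
import torchmetrics
from omegaconf import DictConfig


def build_metrics(metrics_cfg: DictConfig) -> torchmetrics.MetricCollection:
    return torchmetrics.MetricCollection(
        {name: hydra.utils.instantiate(cfg) for name, cfg in metrics_cfg.items()}
    )
```

Cloned with `.clone(prefix="train/")` and friends, this is what would make the renamed keys appear verbatim in progress bars and loggers.
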
a/configs/model/protein_encoder/lstm.yaml b/configs/model/protein_encoder/lstm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/configs/model/protein_encoder/tape_bert.yaml b/configs/model/protein_encoder/tape_bert.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6a64b46c70850a5d796a19615ba553e69cb26a39 --- /dev/null +++ b/configs/model/protein_encoder/tape_bert.yaml @@ -0,0 +1,3 @@ +_target_: tape.ProteinBertModel.from_pretrained + +pretrained_model_name_or_path: bert-base \ No newline at end of file diff --git a/configs/model/protein_encoder/transformer.yaml b/configs/model/protein_encoder/transformer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4fb7c9761bc098da84773f56bc949ce7c7d34c5f --- /dev/null +++ b/configs/model/protein_encoder/transformer.yaml @@ -0,0 +1,12 @@ +_target_: deepscreen.models.components.transformer + +input_dim: 8420 +emb_size: 64 +max_position_size: 545 +dropout: 0.1 +n_layer: 2 +intermediate_size: 256 +num_attention_heads: 4 +attention_probs_dropout: 0.1 +hidden_dropout: 0.1 + diff --git a/configs/preset/bacpi.yaml b/configs/preset/bacpi.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9dbea130fd0f6b6fe07f009c340f33f6f8f6c558 --- /dev/null +++ b/configs/preset/bacpi.yaml @@ -0,0 +1,37 @@ +# @package _global_ +model: + predictor: + _target_: deepscreen.models.predictors.bacpi.BACPI + + n_atom: 20480 + n_amino: 8448 + comp_dim: 80 + prot_dim: 80 + latent_dim: 80 + gat_dim: 50 + num_head: 3 + dropout: 0.1 + alpha: 0.1 + window: 5 + layer_cnn: 3 + optimizer: + lr: 5e-4 + +data: + batch_size: 16 + + collator: + automatic_padding: True + + drug_featurizer: + _target_: deepscreen.models.predictors.bacpi.drug_featurizer + _partial_: true + radius: 2 + + protein_featurizer: + _target_: deepscreen.models.predictors.bacpi.split_sequence + _partial_: true + ngram: 3 +# collator: +# _target_: deepscreen.models.predictors.transformer_cpi_2.pack +# _partial_: true diff --git a/configs/preset/coa_dti_pro.yaml b/configs/preset/coa_dti_pro.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9ad501a6176a630b8edb97653607ef223946367d --- /dev/null +++ b/configs/preset/coa_dti_pro.yaml @@ -0,0 +1,28 @@ +# @package _global_ +defaults: + - override /data/protein_featurizer: none + +model: + predictor: + _target_: deepscreen.models.predictors.coa_dti_pro.CoaDTIPro + + n_fingerprint: 20480 + n_word: 26 + dim: 512 + layer_output: 3 + layer_coa: 1 + nhead: 8 + dropout: 0.1 + co_attention: 'inter' + gcn_pooling: False + + esm_model_and_alphabet: + _target_: esm.pretrained.load_model_and_alphabet + model_name: resources/models/esm/esm1_t6_43M_UR50S.pt + +data: + drug_featurizer: + _target_: deepscreen.models.predictors.coa_dti_pro.drug_featurizer + _partial_: true + radius: 2 + batch_size: 1 diff --git a/configs/preset/deep_dtaf.yaml b/configs/preset/deep_dtaf.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a2f2a8d143d28834b4c94e300c463bc776e23adf --- /dev/null +++ b/configs/preset/deep_dtaf.yaml @@ -0,0 +1,19 @@ +# @package _global_ +defaults: + - override /data/drug_featurizer: label + - override /data/protein_featurizer: label + - override /model/predictor: deep_dta + +data: + drug_featurizer: + charset: ['Z', 'Y', 'H', '[', 'O', ']', '5', 'M', 'K', '.', '9', 'e', + '(', 'l', 'U', 'V', 'L', 'B', 'y', 'm', 'd', 'h', 'T', 'A', + 'W', 'b', 'i', 'D', 'R', '8', '/', 's', '#', 'u', '+', '@', + 'n', '%', 'F', 'r', 't', 'I', 'S', '6', 'P', 'G', 'f', ')', + '-', '\\', 'C', 'E', 'o', '3', '2', '1', '=', 'g', 'c', 'N', + '7', '4', 'a', '0'] + batch_size: 512 + +model: + predictor: + smi_charset_len: ${eval:'len(${data.drug_featurizer.charset})+1'} diff --git a/configs/preset/drug_ban.yaml b/configs/preset/drug_ban.yaml index 537c9d79276383542e956ff4ef27178c4c358fc4..3d6b00d0824b34e3f602892fb07ae789ced8fc41 100644 --- a/configs/preset/drug_ban.yaml +++ b/configs/preset/drug_ban.yaml @@ -25,4 +25,4 @@ data: _partial_: true max_drug_nodes: 330 - batch_size: 512 + batch_size: 256 diff --git a/configs/preset/m_graph_dta.yaml b/configs/preset/m_graph_dta.yaml index 6f6a43a78d98ac5721efd99f7ca27c0b594f8f26..8609b499b94168c7c5da2b5779b310be5c769020 100644 --- a/configs/preset/m_graph_dta.yaml +++ b/configs/preset/m_graph_dta.yaml @@ -16,4 +16,7 @@ data: atom_features: _target_: deepscreen.models.predictors.m_graph_dta.atom_features _partial_: true - batch_size: 512 \ No newline at end of file + batch_size: 512 + +trainer: + precision: 'bf16' \ No newline at end of file diff --git a/configs/preset/mol_trans.yaml b/configs/preset/mol_trans.yaml index ed7fcf5ce9d43b0e047b9720cb544faf707fb64f..a4701094a69ef03e6a54d56dec51fbd83b211d91 100644 --- a/configs/preset/mol_trans.yaml +++ b/configs/preset/mol_trans.yaml @@ -36,4 +36,4 @@ model: #flatten_dim: 293412 optimizer: - lr: 1e-6 \ No newline at end of file + lr: 1e-6 diff --git a/configs/preset/monn.yaml b/configs/preset/monn.yaml new file mode 100644 index 0000000000000000000000000000000000000000..117031aaf9f51f3f935c0a679af40883474b7880 --- /dev/null +++ b/configs/preset/monn.yaml @@ -0,0 +1,17 @@ +# @package _global_ +defaults: + - dti_experiment + # TODO MONN featurizers not fully implemented yet + - override /data/drug_featurizer: label + - override /data/protein_featurizer: label + - override /model/predictor: monn + - override /task: binary + - _self_ + +model: + loss: + _target_: deepscreen.models.loss.multitask_loss.MultitaskWeightedLoss + loss_fns: + - _target_: ${model.loss} + - _target_: deepscreen.models.predictors.monn.MaskedBCELoss + weights: [1, 0.1] diff --git a/configs/preset/transformer_cpi.yaml b/configs/preset/transformer_cpi.yaml index c218ea5a4c51454e9fa4eeb45f97150ac2b99aae..c37697c20ba6d83b55d7e7ee197b369180d47bf4 100644 --- a/configs/preset/transformer_cpi.yaml +++ b/configs/preset/transformer_cpi.yaml @@ -16,6 +16,7 @@ model: atom_dim: 34 data: - batch_size: 16 + batch_size: 128 collator: automatic_padding: True + diff --git a/configs/preset/transformer_cpi_2.yaml b/configs/preset/transformer_cpi_2.yaml index e1c224f88a6de1266ee4d836c15ff5e88b06f119..db33a01104222d8b220c6783a15919eb1baea7c7 100644 --- a/configs/preset/transformer_cpi_2.yaml +++ b/configs/preset/transformer_cpi_2.yaml @@ -4,7 +4,11 @@ defaults: - override /data/protein_featurizer: tokenizer model: + optimizer: + lr: 0.00001 + predictor: + _target_: deepscreen.models.predictors.transformer_cpi_2.TransformerCPI2 encoder: diff --git a/configs/task/default.yaml b/configs/task/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2e280bad0a70512b7be7e73ea731821b50904cdb --- /dev/null +++ b/configs/task/default.yaml @@ -0,0 +1,9 @@ +task: null +num_classes: null + +out: + _target_: torch.nn.Identity +loss: + _target_: torch.nn.Identity +activation: + _target_: torch.nn.Identity diff --git a/configs/trainer/ddp.yaml b/configs/trainer/ddp.yaml new file mode 100644 index
0000000000000000000000000000000000000000..3c9778447171ba5f5282e1087aa61be2b4763622 --- /dev/null +++ b/configs/trainer/ddp.yaml @@ -0,0 +1,10 @@ +defaults: + - default + +strategy: ddp + +accelerator: gpu +devices: 4 +num_nodes: 1 +sync_batchnorm: True +precision: 16-mixed diff --git a/configs/trainer/ddp_sim.yaml b/configs/trainer/ddp_sim.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8404419e5c295654967d0dfb73a7366e75be2f1f --- /dev/null +++ b/configs/trainer/ddp_sim.yaml @@ -0,0 +1,7 @@ +defaults: + - default + +# simulate DDP on CPU, useful for debugging +accelerator: cpu +devices: 2 +strategy: ddp_spawn diff --git a/configs/trainer/default.yaml b/configs/trainer/default.yaml index b53d8f577f8d1895d9ce12003c1a6593a2bce06d..5a00b4536dec28a739772e2cc798ab46b2dc2824 100644 --- a/configs/trainer/default.yaml +++ b/configs/trainer/default.yaml @@ -5,7 +5,6 @@ default_root_dir: ${paths.output_dir} min_epochs: 1 max_epochs: 50 -accelerator: auto precision: 32 gradient_clip_val: 0.5 diff --git a/configs/trainer/mps.yaml b/configs/trainer/mps.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1ecf6d5cc3a34ca127c5510f4a18e989561e38e4 --- /dev/null +++ b/configs/trainer/mps.yaml @@ -0,0 +1,5 @@ +defaults: + - default + +accelerator: mps +devices: 1 diff --git a/configs/webserver_inference.yaml b/configs/webserver_inference.yaml index 5a7e0e012abff2c0a0612f1a092846aec7c149f4..85ab0acd3dd0d1ecdcd48f31ff6c5e1b734aeb2c 100644 --- a/configs/webserver_inference.yaml +++ b/configs/webserver_inference.yaml @@ -28,7 +28,4 @@ paths: work_dir: null data: - num_workers: 8 - -trainer: - precision: 32 \ No newline at end of file + num_workers: 8 \ No newline at end of file diff --git a/deepscreen/__pycache__/__init__.cpython-311.pyc b/deepscreen/__pycache__/__init__.cpython-311.pyc index f716a0fe682e10bbd19b657e449841053938c805..ee671c098ed59a04efe355d0772dfcd7606b811f 100644 Binary files a/deepscreen/__pycache__/__init__.cpython-311.pyc and b/deepscreen/__pycache__/__init__.cpython-311.pyc differ diff --git a/deepscreen/__pycache__/__init__.cpython-39.pyc b/deepscreen/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02a981c48c8d2d301461ee7e340000abccac5a95 Binary files /dev/null and b/deepscreen/__pycache__/__init__.cpython-39.pyc differ diff --git a/deepscreen/__pycache__/predict.cpython-311.pyc b/deepscreen/__pycache__/predict.cpython-311.pyc index 36ecc3f738005c62fa82d2d50947c96604afb506..51c1c7b3ed1a1ae7fa6f3121ac094ae24e4d9072 100644 Binary files a/deepscreen/__pycache__/predict.cpython-311.pyc and b/deepscreen/__pycache__/predict.cpython-311.pyc differ diff --git a/deepscreen/__pycache__/test.cpython-311.pyc b/deepscreen/__pycache__/test.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45585613ecb563341da0d50d184b0acc3ab640cb Binary files /dev/null and b/deepscreen/__pycache__/test.cpython-311.pyc differ diff --git a/deepscreen/__pycache__/train.cpython-311.pyc b/deepscreen/__pycache__/train.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91cea93e08cc598498f9439730e3e50ae10903c3 Binary files /dev/null and b/deepscreen/__pycache__/train.cpython-311.pyc differ diff --git a/deepscreen/__pycache__/train.cpython-39.pyc b/deepscreen/__pycache__/train.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..041a052af12765f27cac7794c9226efff9ef5dfc Binary files /dev/null and 
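
`ddp.yaml` layers on top of `default.yaml`, so the composed Trainer is roughly the following (a hedged reconstruction of what Hydra ends up instantiating, not code from the repo):

```python
# Hypothetical equivalent of configs/trainer/ddp.yaml after it inherits from
# configs/trainer/default.yaml (ddp.yaml's precision: 16-mixed overrides
# default.yaml's precision: 32).
from lightning import Trainer

trainer = Trainer(
    min_epochs=1,
    max_epochs=50,
    gradient_clip_val=0.5,
    strategy="ddp",
    accelerator="gpu",
    devices=4,
    num_nodes=1,
    sync_batchnorm=True,
    precision="16-mixed",
)
```

Note that `default.yaml` no longer pins `accelerator: auto`, so configs that do not set an accelerator fall back to Lightning's own default.
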
b/deepscreen/__pycache__/train.cpython-39.pyc differ diff --git a/deepscreen/data/__pycache__/__init__.cpython-311.pyc b/deepscreen/data/__pycache__/__init__.cpython-311.pyc index e0c09d97afa3c2a4034a852cc250d2fa20bdc490..b71e96c96119cf4998e731ade778d0572911ba2a 100644 Binary files a/deepscreen/data/__pycache__/__init__.cpython-311.pyc and b/deepscreen/data/__pycache__/__init__.cpython-311.pyc differ diff --git a/deepscreen/data/__pycache__/__init__.cpython-39.pyc b/deepscreen/data/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d60e7b76e8d6ca276bf0853c0f937ddd83ae15d3 Binary files /dev/null and b/deepscreen/data/__pycache__/__init__.cpython-39.pyc differ diff --git a/deepscreen/data/__pycache__/dti.cpython-311.pyc b/deepscreen/data/__pycache__/dti.cpython-311.pyc index a91110329d8403a13aa5195248050b47c3e6bba1..f59119302483207f348da69fc051d41e97e80db5 100644 Binary files a/deepscreen/data/__pycache__/dti.cpython-311.pyc and b/deepscreen/data/__pycache__/dti.cpython-311.pyc differ diff --git a/deepscreen/data/__pycache__/dti_datamodule.cpython-311.pyc b/deepscreen/data/__pycache__/dti_datamodule.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..868b4f30b00539c0675f19112a5dbd11c3381f43 Binary files /dev/null and b/deepscreen/data/__pycache__/dti_datamodule.cpython-311.pyc differ diff --git a/deepscreen/data/dti.py b/deepscreen/data/dti.py index 38aacc5bffe7906e77feb570998d3943e3ee6453..71142a1bef25196b4364718e918bf9d40168cee0 100644 --- a/deepscreen/data/dti.py +++ b/deepscreen/data/dti.py @@ -170,7 +170,7 @@ class DTIDataset(Dataset): desc="Validating SMILES...").apply(validate_seq_str, regex=SMILES_PAT) if not df['X1_ERR'].isna().all(): raise Exception(f"Encountered invalid SMILES:\n{df[~df['X1_ERR'].isna()][['X1', 'X1_ERR']]}") - df['X1^'] = df['X1'].apply(rdkit_canonicalize) # swifter + df['X1^'] = df['X1'].swifter.apply(rdkit_canonicalize) # swifter log.info("Validating FASTA (`X2`)...") df['X2'] = df['X2'].str.upper() @@ -252,6 +252,7 @@ class DTIDataModule(LightningDataModule): split: Optional[callable] = None, thresholds: Optional[Union[Number, Sequence[Number]]] = None, discard_intermediate: Optional[bool] = False, + query: Optional[str] = 'X2', num_workers: int = 0, pin_memory: bool = False, ): @@ -270,7 +271,8 @@ class DTIDataModule(LightningDataModule): drug_featurizer=drug_featurizer, protein_featurizer=protein_featurizer, thresholds=thresholds, - discard_intermediate=discard_intermediate + discard_intermediate=discard_intermediate, + query=query ) # this line allows to access init params with 'self.hparams' ensures init params will be stored in ckpt @@ -423,3 +425,4 @@ class DTIDataModule(LightningDataModule): def load_state_dict(self, state_dict: Dict[str, Any]): """Things to do when loading checkpoint.""" pass + diff --git a/deepscreen/data/featurizers/__pycache__/__init__.cpython-311.pyc b/deepscreen/data/featurizers/__pycache__/__init__.cpython-311.pyc index 02522f17ba285baea39eddce0b53111f20d16a8e..1f63fdd1e08a1cb2ded973b38c2e8ee8aac6860f 100644 Binary files a/deepscreen/data/featurizers/__pycache__/__init__.cpython-311.pyc and b/deepscreen/data/featurizers/__pycache__/__init__.cpython-311.pyc differ diff --git a/deepscreen/data/featurizers/__pycache__/categorical.cpython-311.pyc b/deepscreen/data/featurizers/__pycache__/categorical.cpython-311.pyc index 2f5e8e6f2243f6bfd75e95d81ed3d5547a1662ba..3e567bcef77e1d7400b0d31c69b493a111d0dcbd 100644 Binary files 
a/deepscreen/data/featurizers/__pycache__/categorical.cpython-311.pyc and b/deepscreen/data/featurizers/__pycache__/categorical.cpython-311.pyc differ diff --git a/deepscreen/data/featurizers/__pycache__/fcs.cpython-311.pyc b/deepscreen/data/featurizers/__pycache__/fcs.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2e63c8b199633eba1a680d1f859de393f6eaba9 Binary files /dev/null and b/deepscreen/data/featurizers/__pycache__/fcs.cpython-311.pyc differ diff --git a/deepscreen/data/featurizers/__pycache__/graph.cpython-311.pyc b/deepscreen/data/featurizers/__pycache__/graph.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99e04ca03b10a3354af0f5b9c23c4ea246a800e2 Binary files /dev/null and b/deepscreen/data/featurizers/__pycache__/graph.cpython-311.pyc differ diff --git a/deepscreen/data/featurizers/__pycache__/token.cpython-311.pyc b/deepscreen/data/featurizers/__pycache__/token.cpython-311.pyc index 77f1a92dead63d52aa40a7e2a6add8f6ed5dcf66..b8f92e88b7bc00a3ce6f08815c10f2182bd24012 100644 Binary files a/deepscreen/data/featurizers/__pycache__/token.cpython-311.pyc and b/deepscreen/data/featurizers/__pycache__/token.cpython-311.pyc differ diff --git a/deepscreen/data/featurizers/fingerprint/__pycache__/__init__.cpython-311.pyc b/deepscreen/data/featurizers/fingerprint/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..336627658c14827c0018be74975cce400d966059 Binary files /dev/null and b/deepscreen/data/featurizers/fingerprint/__pycache__/__init__.cpython-311.pyc differ diff --git a/deepscreen/data/featurizers/fingerprint/__pycache__/atompairs.cpython-311.pyc b/deepscreen/data/featurizers/fingerprint/__pycache__/atompairs.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ad5c8444560c3083ad01b53121f120504d4b657 Binary files /dev/null and b/deepscreen/data/featurizers/fingerprint/__pycache__/atompairs.cpython-311.pyc differ diff --git a/deepscreen/data/featurizers/fingerprint/__pycache__/avalonfp.cpython-311.pyc b/deepscreen/data/featurizers/fingerprint/__pycache__/avalonfp.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c7ad090e65bb9c20de83ca780f307b26a15b1de Binary files /dev/null and b/deepscreen/data/featurizers/fingerprint/__pycache__/avalonfp.cpython-311.pyc differ diff --git a/deepscreen/data/featurizers/fingerprint/__pycache__/estatefp.cpython-311.pyc b/deepscreen/data/featurizers/fingerprint/__pycache__/estatefp.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3aa4edc7768b6fca28bce76630542768cfb21ee1 Binary files /dev/null and b/deepscreen/data/featurizers/fingerprint/__pycache__/estatefp.cpython-311.pyc differ diff --git a/deepscreen/data/featurizers/fingerprint/__pycache__/maccskeys.cpython-311.pyc b/deepscreen/data/featurizers/fingerprint/__pycache__/maccskeys.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e12d3befc7f50a7ca65e2ee9940938fc8faebcc Binary files /dev/null and b/deepscreen/data/featurizers/fingerprint/__pycache__/maccskeys.cpython-311.pyc differ diff --git a/deepscreen/data/featurizers/fingerprint/__pycache__/map4.cpython-311.pyc b/deepscreen/data/featurizers/fingerprint/__pycache__/map4.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd4e82b3d7e56a804c6f3a8b7593e6ddcfca9fae Binary files /dev/null and 
b/deepscreen/data/featurizers/fingerprint/__pycache__/map4.cpython-311.pyc differ diff --git a/deepscreen/data/featurizers/fingerprint/__pycache__/mhfp6.cpython-311.pyc b/deepscreen/data/featurizers/fingerprint/__pycache__/mhfp6.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1049912fa1b3561c0707d25b87a9427c160bfc1a Binary files /dev/null and b/deepscreen/data/featurizers/fingerprint/__pycache__/mhfp6.cpython-311.pyc differ diff --git a/deepscreen/data/featurizers/fingerprint/__pycache__/morganfp.cpython-311.pyc b/deepscreen/data/featurizers/fingerprint/__pycache__/morganfp.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fc32cfd5465e85e4fa6528159bb1582c14bea3a Binary files /dev/null and b/deepscreen/data/featurizers/fingerprint/__pycache__/morganfp.cpython-311.pyc differ diff --git a/deepscreen/data/featurizers/fingerprint/__pycache__/pharmErGfp.cpython-311.pyc b/deepscreen/data/featurizers/fingerprint/__pycache__/pharmErGfp.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0da340c2af5698b668f11523c46dcbe03f761310 Binary files /dev/null and b/deepscreen/data/featurizers/fingerprint/__pycache__/pharmErGfp.cpython-311.pyc differ diff --git a/deepscreen/data/featurizers/fingerprint/__pycache__/pharmPointfp.cpython-311.pyc b/deepscreen/data/featurizers/fingerprint/__pycache__/pharmPointfp.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32ca23511a45d306e0f8c88c5d1dadaf92b43ceb Binary files /dev/null and b/deepscreen/data/featurizers/fingerprint/__pycache__/pharmPointfp.cpython-311.pyc differ diff --git a/deepscreen/data/featurizers/fingerprint/__pycache__/pubchemfp.cpython-311.pyc b/deepscreen/data/featurizers/fingerprint/__pycache__/pubchemfp.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f28f5c7c1f652586a22ad48396a49cf085e81ad Binary files /dev/null and b/deepscreen/data/featurizers/fingerprint/__pycache__/pubchemfp.cpython-311.pyc differ diff --git a/deepscreen/data/featurizers/fingerprint/__pycache__/rdkitfp.cpython-311.pyc b/deepscreen/data/featurizers/fingerprint/__pycache__/rdkitfp.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ff678d66c4c4dcbf08546593bc13edce40881db Binary files /dev/null and b/deepscreen/data/featurizers/fingerprint/__pycache__/rdkitfp.cpython-311.pyc differ diff --git a/deepscreen/data/featurizers/fingerprint/__pycache__/torsions.cpython-311.pyc b/deepscreen/data/featurizers/fingerprint/__pycache__/torsions.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12fdfd9a6dc3741436369652f8577b7b2277a93b Binary files /dev/null and b/deepscreen/data/featurizers/fingerprint/__pycache__/torsions.cpython-311.pyc differ diff --git a/deepscreen/data/featurizers/token.py b/deepscreen/data/featurizers/token.py index ebc049ab2594fabbf8ab25de6fa042b58011cb12..6d28ad98be28d5e502d861d62968404d894f3338 100644 --- a/deepscreen/data/featurizers/token.py +++ b/deepscreen/data/featurizers/token.py @@ -297,3 +297,4 @@ def load_vocab(vocab_file): token = token.rstrip("\n") vocab[token] = index return vocab + diff --git a/deepscreen/data/utils/__pycache__/__init__.cpython-311.pyc b/deepscreen/data/utils/__pycache__/__init__.cpython-311.pyc index a06039ab6cbf8c6d7949afa9668d9c62bc356e16..f6be68dc842e23768e317b6d17571d4fed2a2acd 100644 Binary files a/deepscreen/data/utils/__pycache__/__init__.cpython-311.pyc and 
b/deepscreen/data/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/deepscreen/data/utils/__pycache__/__init__.cpython-39.pyc b/deepscreen/data/utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb299012746ca8f73ad627eb37c987105258a97c Binary files /dev/null and b/deepscreen/data/utils/__pycache__/__init__.cpython-39.pyc differ diff --git a/deepscreen/data/utils/__pycache__/collator.cpython-311.pyc b/deepscreen/data/utils/__pycache__/collator.cpython-311.pyc index cb17a8a4426c690b4a6c56a3abbe993b459bbcd4..fd147e35c555bfc4ff3be6708e1aee9e9b003ca6 100644 Binary files a/deepscreen/data/utils/__pycache__/collator.cpython-311.pyc and b/deepscreen/data/utils/__pycache__/collator.cpython-311.pyc differ diff --git a/deepscreen/data/utils/__pycache__/collator.cpython-39.pyc b/deepscreen/data/utils/__pycache__/collator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3b4061d4419d0002ab4981b1ed0a307e0b0ea7f Binary files /dev/null and b/deepscreen/data/utils/__pycache__/collator.cpython-39.pyc differ diff --git a/deepscreen/data/utils/__pycache__/label.cpython-311.pyc b/deepscreen/data/utils/__pycache__/label.cpython-311.pyc index f866faaa1d09e2a3173e514924ecfcee4751ea28..4cb4f8e1993066ff91340f3f3ad109ee367f2e0f 100644 Binary files a/deepscreen/data/utils/__pycache__/label.cpython-311.pyc and b/deepscreen/data/utils/__pycache__/label.cpython-311.pyc differ diff --git a/deepscreen/data/utils/__pycache__/sampler.cpython-311.pyc b/deepscreen/data/utils/__pycache__/sampler.cpython-311.pyc index 724502a7093c85b8d03eb232cff1682881b5a39a..092161f64dfe2d1a4d3ed2a5420adffc15c9973a 100644 Binary files a/deepscreen/data/utils/__pycache__/sampler.cpython-311.pyc and b/deepscreen/data/utils/__pycache__/sampler.cpython-311.pyc differ diff --git a/deepscreen/data/utils/__pycache__/split.cpython-311.pyc b/deepscreen/data/utils/__pycache__/split.cpython-311.pyc index dc2484c2d6fc079bb396f19ac239768da834d54b..7ab833e2ae76e67c740596d5d93f0009fe7c6b93 100644 Binary files a/deepscreen/data/utils/__pycache__/split.cpython-311.pyc and b/deepscreen/data/utils/__pycache__/split.cpython-311.pyc differ diff --git a/deepscreen/data/utils/sampler.py b/deepscreen/data/utils/sampler.py index 0431c132fc8d8a54f877bb4b1e8fcade92717986..90e28aa84763e8cefd334237a9af785ccc6b6d5d 100644 --- a/deepscreen/data/utils/sampler.py +++ b/deepscreen/data/utils/sampler.py @@ -16,7 +16,6 @@ class SafeBatchSampler(BatchSampler): Example: >>> dataloader = DataLoader(dataset, batch_sampler=SafeBatchSampler(dataset, batch_size, drop_last, shuffle)) """ - def __init__(self, data_source, batch_size: int, drop_last: bool, shuffle: bool, sampler=None): if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \ batch_size <= 0: @@ -85,6 +84,8 @@ class SafeBatchSampler(BatchSampler): if idx_in_batch > 0 and not self.drop_last: yield batch[:idx_in_batch] - if not any(batch): +# if not any(batch): # raise StopIteration - return +# return + def __len__(self): + return float("inf") diff --git a/deepscreen/models/__pycache__/__init__.cpython-311.pyc b/deepscreen/models/__pycache__/__init__.cpython-311.pyc index 1b1d9db43c68b99c204ad7f1f685a1821ff20f5b..4b54ec8753bd519d60297627b83221c1ee85f44f 100644 Binary files a/deepscreen/models/__pycache__/__init__.cpython-311.pyc and b/deepscreen/models/__pycache__/__init__.cpython-311.pyc differ diff --git a/deepscreen/models/__pycache__/__init__.cpython-39.pyc
b/deepscreen/models/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d12a658003d6c602a700eb04c6425d38cbf99adb Binary files /dev/null and b/deepscreen/models/__pycache__/__init__.cpython-39.pyc differ diff --git a/deepscreen/models/__pycache__/dti.cpython-311.pyc b/deepscreen/models/__pycache__/dti.cpython-311.pyc index 99cd5e5052e7efb9eb25683194d07ac7c9b825ac..bc405379496dbaf7ad2131561b1182059a43f123 100644 Binary files a/deepscreen/models/__pycache__/dti.cpython-311.pyc and b/deepscreen/models/__pycache__/dti.cpython-311.pyc differ diff --git a/deepscreen/models/components/__pycache__/__init__.cpython-311.pyc b/deepscreen/models/components/__pycache__/__init__.cpython-311.pyc index 435a4b856aea9f911f80936aa927c55342eebfe6..679bae723b9904659c5ad7a2aadf3c6837c297ab 100644 Binary files a/deepscreen/models/components/__pycache__/__init__.cpython-311.pyc and b/deepscreen/models/components/__pycache__/__init__.cpython-311.pyc differ diff --git a/deepscreen/models/components/__pycache__/cnn.cpython-311.pyc b/deepscreen/models/components/__pycache__/cnn.cpython-311.pyc index d5972fc842191d7cc2b1775fa18472c427269e7a..b5fdd837d84a853f256ffdf0b9693ddbe5a7db94 100644 Binary files a/deepscreen/models/components/__pycache__/cnn.cpython-311.pyc and b/deepscreen/models/components/__pycache__/cnn.cpython-311.pyc differ diff --git a/deepscreen/models/components/__pycache__/gat.cpython-311.pyc b/deepscreen/models/components/__pycache__/gat.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f38c7d62d0ac7a404a0f4229fd1b4aac1e717a99 Binary files /dev/null and b/deepscreen/models/components/__pycache__/gat.cpython-311.pyc differ diff --git a/deepscreen/models/dti.py b/deepscreen/models/dti.py index 65ef25841820b614c022561b5ec1effdf731aaef..a692208a0ccd9c2f79646d8b729d216cbfd2082e 100644 --- a/deepscreen/models/dti.py +++ b/deepscreen/models/dti.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Optional, Sequence, Dict, Any +from typing import Optional, Sequence, Dict from torch import nn, optim, Tensor from lightning import LightningModule @@ -17,8 +17,6 @@ class DTILightningModule(LightningModule): model: a fully initialized instance of class torch.nn.Module metrics: a list of fully initialized instances of class torchmetrics.Metric """ - extra_return_keys = ['ID1', 'X1', 'ID2', 'X2', 'N'] - def __init__( self, optimizer: optim.Optimizer, @@ -51,19 +49,20 @@ class DTILightningModule(LightningModule): match stage: case 'fit': dataloader = self.trainer.datamodule.train_dataloader() - dummy_batch = next(iter(dataloader)) - self.forward(dummy_batch) - # case 'validate': - # dataloader = self.trainer.datamodule.val_dataloader() - # case 'test': - # dataloader = self.trainer.datamodule.test_dataloader() - # case 'predict': - # dataloader = self.trainer.datamodule.predict_dataloader() + case 'validate': + dataloader = self.trainer.datamodule.val_dataloader() + case 'test': + dataloader = self.trainer.datamodule.test_dataloader() + case 'predict': + dataloader = self.trainer.datamodule.predict_dataloader() + dummy_batch = next(iter(dataloader)) # for key, value in dummy_batch.items(): # if isinstance(value, Tensor): # dummy_batch[key] = value.to(self.device) + self.forward(dummy_batch) + def forward(self, batch): output = self.predictor(batch['X1^'], batch['X2^']) target = batch.get('Y') @@ -93,18 +92,13 @@ class DTILightningModule(LightningModule): self.train_metrics(preds=preds, target=target, 
indexes=indexes.long()) self.log_dict(self.train_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True) - return_dict = { - 'Y^': preds, - 'Y': target, - 'loss': loss + return { + 'N': batch['N'], + 'ID1': batch['ID1'], 'X1': batch['X1'], + 'ID2': batch['ID2'], 'X2': batch['X2'], + 'Y^': preds, 'Y': target, 'loss': loss } - for key in self.extra_return_keys: - if key in batch: - return_dict[key] = batch[key] - - return return_dict - def on_train_epoch_end(self): pass @@ -115,18 +109,13 @@ class DTILightningModule(LightningModule): self.val_metrics(preds=preds, target=target, indexes=indexes.long()) self.log_dict(self.val_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True) - return_dict = { - 'Y^': preds, - 'Y': target, - 'loss': loss + return { + 'N': batch['N'], + 'ID1': batch['ID1'], 'X1': batch['X1'], + 'ID2': batch['ID2'], 'X2': batch['X2'], + 'Y^': preds, 'Y': target, 'loss': loss } - for key in self.extra_return_keys: - if key in batch: - return_dict[key] = batch[key] - - return return_dict - def on_validation_epoch_end(self): pass @@ -137,34 +126,27 @@ class DTILightningModule(LightningModule): self.test_metrics(preds=preds, target=target, indexes=indexes.long()) self.log_dict(self.test_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True) - return_dict = { - 'Y^': preds, - 'Y': target, - 'loss': loss + # return a dictionary for callbacks like BasePredictionWriter + return { + 'N': batch['N'], + 'ID1': batch['ID1'], 'X1': batch['X1'], + 'ID2': batch['ID2'], 'X2': batch['X2'], + 'Y^': preds, 'Y': target, 'loss': loss } - for key in self.extra_return_keys: - if key in batch: - return_dict[key] = batch[key] - - return return_dict - def on_test_epoch_end(self): pass def predict_step(self, batch, batch_idx, dataloader_idx=0): preds, _, _, _ = self.forward(batch) # return a dictionary for callbacks like BasePredictionWriter - return_dict = { - 'Y^': preds, + return { + 'N': batch['N'], + 'ID1': batch['ID1'], 'X1': batch['X1'], + 'ID2': batch['ID2'], 'X2': batch['X2'], + 'Y^': preds } - for key in self.extra_return_keys: - if key in batch: - return_dict[key] = batch[key] - - return return_dict - def configure_optimizers(self): optimizers_config = {'optimizer': self.hparams.optimizer(params=self.parameters())} if self.hparams.get('scheduler'): diff --git a/deepscreen/models/loss/__pycache__/__init__.cpython-311.pyc b/deepscreen/models/loss/__pycache__/__init__.cpython-311.pyc index 714800206038b4a3e2694157e3758817ee237311..91fbac487b1a4c0ce09722f3bd0d4b662dad91a3 100644 Binary files a/deepscreen/models/loss/__pycache__/__init__.cpython-311.pyc and b/deepscreen/models/loss/__pycache__/__init__.cpython-311.pyc differ diff --git a/deepscreen/models/loss/__pycache__/multitask_loss.cpython-311.pyc b/deepscreen/models/loss/__pycache__/multitask_loss.cpython-311.pyc index 9b4bac6d83f54ab4d78c5b1c33aaae6dda906cc7..57e6d7c76f1d55bc52cf52d7e1650921036cd7b7 100644 Binary files a/deepscreen/models/loss/__pycache__/multitask_loss.cpython-311.pyc and b/deepscreen/models/loss/__pycache__/multitask_loss.cpython-311.pyc differ diff --git a/deepscreen/models/metrics/__pycache__/__init__.cpython-311.pyc b/deepscreen/models/metrics/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..467c4244298a9fd777b38c94d2b34f796795bd64 Binary files /dev/null and b/deepscreen/models/metrics/__pycache__/__init__.cpython-311.pyc differ diff --git a/deepscreen/models/metrics/__pycache__/__init__.cpython-39.pyc 
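
Each `*_step` now returns a fixed dictionary of batch identifiers and predictions instead of filtering through `extra_return_keys`, so downstream callbacks can rely on those keys being present. A sketch of a `BasePredictionWriter` consuming `predict_step`'s output (illustrative; the actual writer callback is not shown in this diff):

```python
# Hypothetical consumer of the dicts returned by predict_step above
# (keys 'N', 'ID1', 'X1', 'ID2', 'X2', 'Y^').
import csv

from lightning.pytorch.callbacks import BasePredictionWriter


class CSVPredictionWriter(BasePredictionWriter):
    def __init__(self, out_path: str):
        super().__init__(write_interval="batch")
        self.out_path = out_path

    def write_on_batch_end(self, trainer, pl_module, prediction,
                           batch_indices, batch, batch_idx, dataloader_idx):
        # one row per drug-protein pair, with the predicted score
        with open(self.out_path, "a", newline="") as f:
            writer = csv.writer(f)
            for id1, id2, score in zip(prediction["ID1"], prediction["ID2"],
                                       prediction["Y^"].flatten().tolist()):
                writer.writerow([id1, id2, score])
```

The trade-off versus the old `extra_return_keys` loop: the new version fails loudly with a `KeyError` if a dataset omits `ID1`/`ID2`, instead of silently dropping the column.
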
b/deepscreen/models/metrics/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9752eb38be03873b4573f658cbcf13a7aae83b2a Binary files /dev/null and b/deepscreen/models/metrics/__pycache__/__init__.cpython-39.pyc differ diff --git a/deepscreen/models/metrics/__pycache__/bedroc.cpython-311.pyc b/deepscreen/models/metrics/__pycache__/bedroc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4f1f02387658916c4aba609b9b86dc3f636ff58 Binary files /dev/null and b/deepscreen/models/metrics/__pycache__/bedroc.cpython-311.pyc differ diff --git a/deepscreen/models/metrics/__pycache__/bedroc.cpython-39.pyc b/deepscreen/models/metrics/__pycache__/bedroc.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb91be739ff900518d316d09ac5b805ca7fbfc08 Binary files /dev/null and b/deepscreen/models/metrics/__pycache__/bedroc.cpython-39.pyc differ diff --git a/deepscreen/models/metrics/__pycache__/ci.cpython-311.pyc b/deepscreen/models/metrics/__pycache__/ci.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40d7afc81e34abde0d7e8c1c197918767855461a Binary files /dev/null and b/deepscreen/models/metrics/__pycache__/ci.cpython-311.pyc differ diff --git a/deepscreen/models/metrics/__pycache__/ef.cpython-311.pyc b/deepscreen/models/metrics/__pycache__/ef.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a98a5cc05047841369158bab896ef1ba37861fab Binary files /dev/null and b/deepscreen/models/metrics/__pycache__/ef.cpython-311.pyc differ diff --git a/deepscreen/models/metrics/__pycache__/hit_rate.cpython-311.pyc b/deepscreen/models/metrics/__pycache__/hit_rate.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2820ba83889a1db862e5c2335e063d7e3ba1f74a Binary files /dev/null and b/deepscreen/models/metrics/__pycache__/hit_rate.cpython-311.pyc differ diff --git a/deepscreen/models/metrics/__pycache__/hit_rate.cpython-39.pyc b/deepscreen/models/metrics/__pycache__/hit_rate.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8f1ee441c6f730bdb23ca7067778eb06273ec2b Binary files /dev/null and b/deepscreen/models/metrics/__pycache__/hit_rate.cpython-39.pyc differ diff --git a/deepscreen/models/metrics/__pycache__/rie.cpython-311.pyc b/deepscreen/models/metrics/__pycache__/rie.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f84edf5790e52fae86a291391d9d237a18a18831 Binary files /dev/null and b/deepscreen/models/metrics/__pycache__/rie.cpython-311.pyc differ diff --git a/deepscreen/models/metrics/__pycache__/rie.cpython-39.pyc b/deepscreen/models/metrics/__pycache__/rie.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51b89362c40985fb4f7cc01d7505fe1906254218 Binary files /dev/null and b/deepscreen/models/metrics/__pycache__/rie.cpython-39.pyc differ diff --git a/deepscreen/models/metrics/__pycache__/sensitivity.cpython-311.pyc b/deepscreen/models/metrics/__pycache__/sensitivity.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec9ee954b398dd3d7cd81d48e00d947f81656885 Binary files /dev/null and b/deepscreen/models/metrics/__pycache__/sensitivity.cpython-311.pyc differ diff --git a/deepscreen/models/predictors/__pycache__/__init__.cpython-311.pyc b/deepscreen/models/predictors/__pycache__/__init__.cpython-311.pyc index 
dea185f266123f5e3a864b8a5afcd28cda1a5461..259f8f58c054ab0589dfe920ffbba4c125d78643 100644 Binary files a/deepscreen/models/predictors/__pycache__/__init__.cpython-311.pyc and b/deepscreen/models/predictors/__pycache__/__init__.cpython-311.pyc differ diff --git a/deepscreen/models/predictors/__pycache__/__init__.cpython-39.pyc b/deepscreen/models/predictors/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1af73737ba1fdbf2380e64358a7e2e451ec2b05b Binary files /dev/null and b/deepscreen/models/predictors/__pycache__/__init__.cpython-39.pyc differ diff --git a/deepscreen/models/predictors/__pycache__/bacpi.cpython-311.pyc b/deepscreen/models/predictors/__pycache__/bacpi.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14d40371c8bbb6ce4b5e68fb2fce2c69df1651ed Binary files /dev/null and b/deepscreen/models/predictors/__pycache__/bacpi.cpython-311.pyc differ diff --git a/deepscreen/models/predictors/__pycache__/coa_dti_pro.cpython-311.pyc b/deepscreen/models/predictors/__pycache__/coa_dti_pro.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b702c800945cb1f93a392c3c5b18cb51c03f0c8 Binary files /dev/null and b/deepscreen/models/predictors/__pycache__/coa_dti_pro.cpython-311.pyc differ diff --git a/deepscreen/models/predictors/__pycache__/deep_conv_dti.cpython-311.pyc b/deepscreen/models/predictors/__pycache__/deep_conv_dti.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dcf5de9268c84d289b6f552f37a93190dd31c41d Binary files /dev/null and b/deepscreen/models/predictors/__pycache__/deep_conv_dti.cpython-311.pyc differ diff --git a/deepscreen/models/predictors/__pycache__/deep_dta.cpython-311.pyc b/deepscreen/models/predictors/__pycache__/deep_dta.cpython-311.pyc index 92668cd7a4ec3ef9e064cedef83991555658c034..bb85827d74b2c8bc67d7ec01a152153c1a53b0ad 100644 Binary files a/deepscreen/models/predictors/__pycache__/deep_dta.cpython-311.pyc and b/deepscreen/models/predictors/__pycache__/deep_dta.cpython-311.pyc differ diff --git a/deepscreen/models/predictors/__pycache__/deep_dta.cpython-39.pyc b/deepscreen/models/predictors/__pycache__/deep_dta.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3340d70417964f7af79aae204ec5d00bcb8ddd25 Binary files /dev/null and b/deepscreen/models/predictors/__pycache__/deep_dta.cpython-39.pyc differ diff --git a/deepscreen/models/predictors/__pycache__/drug_ban.cpython-311.pyc b/deepscreen/models/predictors/__pycache__/drug_ban.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..157c10059b33a563794123e5ad1f63af0a0eba34 Binary files /dev/null and b/deepscreen/models/predictors/__pycache__/drug_ban.cpython-311.pyc differ diff --git a/deepscreen/models/predictors/__pycache__/drug_vqa.cpython-311.pyc b/deepscreen/models/predictors/__pycache__/drug_vqa.cpython-311.pyc index 4bf397d6219cf9ae3a2c4397a539f58af649716b..3f8fc2f03a46a699c45b453156ce52ee199c4a93 100644 Binary files a/deepscreen/models/predictors/__pycache__/drug_vqa.cpython-311.pyc and b/deepscreen/models/predictors/__pycache__/drug_vqa.cpython-311.pyc differ diff --git a/deepscreen/models/predictors/__pycache__/graph_dta.cpython-311.pyc b/deepscreen/models/predictors/__pycache__/graph_dta.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa4cebae544b174b329833f51c46b5f6121d965d Binary files /dev/null and 
b/deepscreen/models/predictors/__pycache__/graph_dta.cpython-311.pyc differ diff --git a/deepscreen/models/predictors/__pycache__/graph_dta.cpython-39.pyc b/deepscreen/models/predictors/__pycache__/graph_dta.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6606729143cd3104865df360d7e493ab4bb33938 Binary files /dev/null and b/deepscreen/models/predictors/__pycache__/graph_dta.cpython-39.pyc differ diff --git a/deepscreen/models/predictors/__pycache__/hyper_attention_dti.cpython-311.pyc b/deepscreen/models/predictors/__pycache__/hyper_attention_dti.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc8ddbf9e700c199a39807cd11135e6506521f6c Binary files /dev/null and b/deepscreen/models/predictors/__pycache__/hyper_attention_dti.cpython-311.pyc differ diff --git a/deepscreen/models/predictors/__pycache__/hyper_attention_dti.cpython-39.pyc b/deepscreen/models/predictors/__pycache__/hyper_attention_dti.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..870c32e9526d2e5e2ab1655f743837d5af47b1f7 Binary files /dev/null and b/deepscreen/models/predictors/__pycache__/hyper_attention_dti.cpython-39.pyc differ diff --git a/deepscreen/models/predictors/__pycache__/m_graph_dta.cpython-311.pyc b/deepscreen/models/predictors/__pycache__/m_graph_dta.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9aacea2164b6b7a381c3eb7ae11c7bb504f53c30 Binary files /dev/null and b/deepscreen/models/predictors/__pycache__/m_graph_dta.cpython-311.pyc differ diff --git a/deepscreen/models/predictors/__pycache__/mol_trans.cpython-311.pyc b/deepscreen/models/predictors/__pycache__/mol_trans.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d8272e9ebe90a3bff55ab090f2535c128ccfeda Binary files /dev/null and b/deepscreen/models/predictors/__pycache__/mol_trans.cpython-311.pyc differ diff --git a/deepscreen/models/predictors/__pycache__/mol_trans.cpython-39.pyc b/deepscreen/models/predictors/__pycache__/mol_trans.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dde5393f979e8a5550a18f69eea7e14859367afe Binary files /dev/null and b/deepscreen/models/predictors/__pycache__/mol_trans.cpython-39.pyc differ diff --git a/deepscreen/models/predictors/__pycache__/transformer_cpi.cpython-311.pyc b/deepscreen/models/predictors/__pycache__/transformer_cpi.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..987efb023cdfc4b631340996f2a268db6b2e3da3 Binary files /dev/null and b/deepscreen/models/predictors/__pycache__/transformer_cpi.cpython-311.pyc differ diff --git a/deepscreen/models/predictors/__pycache__/transformer_cpi_2.cpython-311.pyc b/deepscreen/models/predictors/__pycache__/transformer_cpi_2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66e5e2b12de5c8071662083ebe8e45687852522c Binary files /dev/null and b/deepscreen/models/predictors/__pycache__/transformer_cpi_2.cpython-311.pyc differ diff --git a/deepscreen/models/predictors/ark_dta.py b/deepscreen/models/predictors/ark_dta.py new file mode 100644 index 0000000000000000000000000000000000000000..5894a1e0be3375efe7e0e690970eac8c2bf3f7f1 --- /dev/null +++ b/deepscreen/models/predictors/ark_dta.py @@ -0,0 +1,829 @@ +import torch +from torch import nn + + +class ArkDTA(nn.Module): + def __init__(self, args): + super(ArkDTA, self).__init__() + self.layer = nn.ModuleDict() + analysis_mode = args.analysis_mode + h =
args.arkdta_hidden_dim + d = args.hp_dropout_rate + esm = args.arkdta_esm_model + esm_freeze = args.arkdta_esm_freeze + E = args.arkdta_ecfpvec_dim + L = args.arkdta_sab_depth + A = args.arkdta_attention_option + K = args.arkdta_num_heads + assert 'ARKMAB' in args.arkdta_residue_addon + + self.layer['prot_encoder'] = FastaESM(h, esm, esm_freeze, analysis_mode) + self.layer['comp_encoder'] = EcfpConverter(h, L, E, analysis_mode) + self.layer['intg_arkmab'] = load_residue_addon(args) + self.layer['intg_pooling'] = load_complex_decoder(args) + self.layer['ba_predictor'] = AffinityMLP(h) + self.layer['dt_predictor'] = InteractionMLP(h) + + def load_auxiliary_materials(self, **kwargs): + return_batch = kwargs['return_batch'] + + b = kwargs['atomresi_adj'].size(0) + x, y, z = kwargs['encoder_attention'].size() + logits0 = kwargs['encoder_attention'].view(x // b, b, y, z).mean(0)[:, :, :-1].sum(2).unsqueeze( + 2) # actual compsub + logits1 = kwargs['encoder_attention'].view(x // b, b, y, z).mean(0)[:, :, -1].unsqueeze(2) # inactive site + return_batch['task/es_pred'] = torch.cat([logits1, logits0], 2) + return_batch['task/es_true'] = (kwargs['atomresi_adj'].sum(1) > 0.).long().squeeze(1) + return_batch['mask/es_resi'] = (kwargs['atomresi_masks'].sum(1) > 0.).float().squeeze(1) + + return return_batch + + def forward(self, batch): + return_batch = dict() + residue_features, residue_masks, residue_fastas = batch[0], batch[1], batch[2] + ecfp_words, ecfp_masks = batch[3], batch[4] + atomresi_adj, atomresi_masks = batch[5], batch[6] + bav, dti, cids = batch[7], batch[8], batch[-1] + + # Protein Encoder Module + residue_features = self.layer['prot_encoder'](X=residue_features, + fastas=residue_fastas, + masks=residue_masks) + residue_masks = residue_features[1] + residue_temps = residue_features[2] + protein_features = residue_features[3] + residue_features = residue_features[0] + return_batch['temp/lm_related'] = residue_temps * 0. 
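+        # `prot_encoder` (FastaESM below) returns a positional list, unpacked above as:
+        #   [0] per-residue embeddings, shape (B, L, h)
+        #   [1] residue masks: 1. for real residues, 0. for padding, shape (B, L)
+        #   [2] summed ESM logits/contacts/attentions, zeroed out above (* 0.) so the
+        #       unused language-model outputs cannot affect the loss
+        #   [3] masked mean-pooled whole-protein embedding, shape (B, h)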
+ + # Ligand Encoder Module + cstruct_features = self.layer['comp_encoder'](ecfp_words=ecfp_words, + ecfp_masks=ecfp_masks) + cstruct_masks = cstruct_features[1] + cstruct_features = cstruct_features[0] + + # Protein-Ligand Integration Module (ARK-MAB) + residue_results = self.layer['intg_arkmab'](residue_features=residue_features, residue_masks=residue_masks, + ligelem_features=cstruct_features, ligelem_masks=cstruct_masks) + residue_features, residue_masks, attention_weights = residue_results + del residue_results; + torch.cuda.empty_cache() + + # Protein-Ligand Integration Module (Pooling Layer) + complex_results = self.layer['intg_pooling'](residue_features=residue_features, + residue_masks=residue_masks, + attention_weights=attention_weights, + protein_features=protein_features) + binding_complex, _, _, _ = complex_results + del complex_results; + torch.cuda.empty_cache() + + # Drug-Target Outcome Predictor + bav_predicted = self.layer['ba_predictor'](binding_complex=binding_complex) + dti_predicted = self.layer['dt_predictor'](binding_complex=binding_complex) + + return_batch['task/ba_pred'] = bav_predicted.view(-1) + return_batch['task/dt_pred'] = dti_predicted.view(-1) + return_batch['task/ba_true'] = bav.view(-1) + return_batch['task/dt_true'] = dti.view(-1) + return_batch['meta/cid'] = cids + + # Additional Materials for Calculating Auxiliary Loss + return_batch = self.load_auxiliary_materials(return_batch=return_batch, + atomresi_adj=atomresi_adj, + atomresi_masks=atomresi_masks, + encoder_attention=attention_weights) + + return return_batch + + @torch.no_grad() + def infer(self, batch): + return_batch = dict() + residue_features, residue_masks, residue_fastas = batch[0], batch[1], batch[2] + ecfp_words, ecfp_masks = batch[3], batch[4] + bav, dti, cids = batch[7], batch[8], batch[-1] + + # Protein Encoder Module + residue_features = self.layer['prot_encoder'](X=residue_features, + fastas=residue_fastas, + masks=residue_masks) + residue_masks = residue_features[1] + residue_temps = residue_features[2] + protein_features = residue_features[3] + residue_features = residue_features[0] + return_batch['temp/lm_related'] = residue_temps * 0. 
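+        # infer() mirrors forward() but runs without gradients and skips the
+        # auxiliary-loss materials, so atomresi_adj / atomresi_masks (batch[5:7])
+        # are never consumed here.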
+
+        # Ligand Encoder Module
+        cstruct_features = self.layer['comp_encoder'](ecfp_words=ecfp_words,
+                                                      ecfp_masks=ecfp_masks)
+        cstruct_masks = cstruct_features[1]
+        cstruct_features = cstruct_features[0]
+
+        # Protein-Ligand Integration Module (ARK-MAB)
+        residue_results = self.layer['intg_arkmab'](residue_features=residue_features, residue_masks=residue_masks,
+                                                    ligelem_features=cstruct_features, ligelem_masks=cstruct_masks)
+        residue_features, residue_masks, attention_weights = residue_results
+        del residue_results;
+        torch.cuda.empty_cache()
+
+        # Protein-Ligand Integration Module (Pooling Layer)
+        complex_results = self.layer['intg_pooling'](residue_features=residue_features,
+                                                     residue_masks=residue_masks,
+                                                     attention_weights=attention_weights,
+                                                     protein_features=protein_features)
+        binding_complex, _, _, _ = complex_results
+        del complex_results;
+        torch.cuda.empty_cache()
+
+        # Drug-Target Outcome Predictor
+        bav_predicted = self.layer['ba_predictor'](binding_complex=binding_complex)
+        dti_predicted = self.layer['dt_predictor'](binding_complex=binding_complex)
+
+        return_batch['task/ba_pred'] = bav_predicted.view(-1)
+        return_batch['task/dt_pred'] = dti_predicted.view(-1)
+        return_batch['task/ba_true'] = bav.view(-1)
+        return_batch['task/dt_true'] = dti.view(-1)
+        return_batch['meta/cid'] = cids
+
+        return return_batch
+
+
+class GraphDenseSequential(nn.Sequential):
+    def __init__(self, *args):
+        super(GraphDenseSequential, self).__init__(*args)
+
+    def forward(self, X, adj, mask):
+        for module in self._modules.values():
+            try:
+                X = module(X, adj, mask)
+            except TypeError:
+                X = module(X)
+
+        return X
+
+
+class MaskedGlobalPooling(nn.Module):
+    def __init__(self, pooling='max'):
+        super(MaskedGlobalPooling, self).__init__()
+        self.pooling = pooling
+
+    def forward(self, x, adj, masks):
+        if x.dim() == 2:
+            x = x.unsqueeze(0)
+        # print(x, adj, masks)
+        masks = masks.unsqueeze(2).repeat(1, 1, x.size(2))
+        if self.pooling == 'max':
+            x[masks == 0] = -99999.99999
+            x = x.max(1)[0]
+        elif self.pooling == 'add':
+            x = x.sum(1)
+        else:
+            raise NotImplementedError(f"Unsupported pooling: {self.pooling}")
+
+        return x
+
+
+class MaskedMean(nn.Module):
+    def __init__(self):
+        super(MaskedMean, self).__init__()
+
+    def forward(self, X, m):
+        if isinstance(m, torch.Tensor):
+            X = X * m.unsqueeze(2)
+
+        return X.mean(1)
+
+
+class MaskedMax(nn.Module):
+    def __init__(self):
+        super(MaskedMax, self).__init__()
+
+    def forward(self, X, m):
+        if isinstance(m, torch.Tensor):
+            X = X * m.unsqueeze(2)
+
+        return torch.max(X, 1)[0]
+
+
+class MaskedSum(nn.Module):
+    def __init__(self):
+        super(MaskedSum, self).__init__()
+
+    def forward(self, X, m):
+        if isinstance(m, torch.Tensor):
+            X = X * m.unsqueeze(2)
+
+        return X.sum(1)
+
+
+class MaskedScaledAverage(nn.Module):
+    def __init__(self):
+        super(MaskedScaledAverage, self).__init__()
+
+    def forward(self, X, m):
+        if isinstance(m, torch.Tensor):
+            X = X * m.unsqueeze(2)
+
+        return X.sum(1) / (m.sum(1) ** 0.5).unsqueeze(1)
+
+
+class Decoder(nn.Module):
+    def __init__(self, analysis_mode):
+        super(Decoder, self).__init__()
+        self.representations = []
+        self.output_representations = []
+        self.query_representations = []
+        self.kvpair_representations = []
+        self.attention_weights = []
+
+        if analysis_mode: self.register_forward_hook(store_decoder_representations)
+
+    def show(self):
+        print("Number of Saved Numpy Arrays: ", len(self.representations))
+        for i, representation in enumerate(self.representations):
+            print(f"Shape of {i}th Numpy Array: ", representation.shape)
+
+        return self.representations
+
+    def 
flush(self): + del self.representations + self.representations = [] + + def release_qk(self): + + return None + + def forward(self, **kwargs): + + return kwargs['X'], kwargs['X'], kwargs['residue_features'], None + + +class DecoderPMA_Residue(Decoder): + def __init__(self, h: int, num_heads: int, num_seeds: int, attn_option: str, analysis_mode: bool): + super(DecoderPMA_Residue, self).__init__(analysis_mode) + # Aggregate the Residues into Residue Regions + pma_args = (h, num_seeds, num_heads, RFF(h), attn_option, False, analysis_mode, False) + self.decoder = PoolingMultiheadAttention(*pma_args) + # Model Region-Region Interaction through Set Attention + sab_depth = 0 if num_seeds < 4 else int((num_seeds // 2) ** 0.5) + sab_args = (h, num_heads, RFF(h), attn_option, False, analysis_mode, True) + self.pairwise = nn.ModuleList([SetAttentionBlock(*sab_args) for _ in range(sab_depth)]) + # Concat, then reduce into h-dimensional Set Representation + self.aggregate = nn.Linear(h * num_seeds, h) + + self.apply(initialization) + + def forward(self, **kwargs): + residue_features = kwargs['residue_features'] + residue_masks = kwargs['residue_masks'] + + output, attention = self.decoder(residue_features, residue_masks) + for sab in self.pairwise: output, _ = sab(output) + b, n, d = output.shape + output = self.aggregate(output.view(b, n * d)) + + return output, None, residue_features, attention + + +class AffinityMLP(nn.Module): + def __init__(self, h: int): + super(AffinityMLP, self).__init__() + self.mlp = nn.Sequential(nn.Linear(h, h), nn.Dropout(0.1), nn.LeakyReLU(), nn.Linear(h, 1)) + + self.apply(initialization) + + def forward(self, **kwargs): + ''' + X: batch size x 1 x H + ''' + X = kwargs['binding_complex'] + X = X.squeeze(1) if X.dim() == 3 else X + yhat = self.mlp(X) + + return yhat + + +class InteractionMLP(nn.Module): + def __init__(self, h: int): + super(InteractionMLP, self).__init__() + self.mlp = nn.Sequential(nn.Linear(h, h), nn.Dropout(0.1), nn.LeakyReLU(), nn.Linear(h, 1), nn.Sigmoid()) + + self.apply(initialization) + + def forward(self, **kwargs): + ''' + X: batch size x 1 x H + ''' + X = kwargs['binding_complex'] + X = X.squeeze(1) if X.dim() == 3 else X + yhat = self.mlp(X) + + return yhat + + +class LigelemEncoder(nn.Module): + def __init__(self): + super(LigelemEncoder, self).__init__() + self.representations = [] + + def show(self): + print("Number of Saved Numpy Arrays: ", len(self.representations)) + for i, representation in enumerate(self.representations): + print(f"Shape of {i}th Numpy Array: ", representation.shape) + + return self.representations + + def flush(self): + del self.representations + self.representations = [] + + +class EcfpConverter(LigelemEncoder): + def __init__(self, h: int, sab_depth: int, ecfp_dim: int, analysis_mode: bool): + super(EcfpConverter, self).__init__() + K = 4 # number of attention heads + self.ecfp_embeddings = nn.Embedding(ecfp_dim + 1, h, padding_idx=ecfp_dim) + self.encoder = nn.ModuleList([]) + sab_args = (h, K, RFF(h), 'general_dot', False, analysis_mode, True) + self.encoder = nn.ModuleList([SetAttentionBlock(*sab_args) for _ in range(sab_depth)]) + + self.representations = [] + if analysis_mode: self.register_forward_hook(store_elemwise_representations) + self.apply(initialization) + + def forward(self, **kwargs): + ''' + X : (b x d) + ''' + ecfp_words = kwargs['ecfp_words'] + ecfp_masks = kwargs['ecfp_masks'] + ecfp_words = self.ecfp_embeddings(ecfp_words) + + for sab in self.encoder: + ecfp_words, _ = sab(ecfp_words, 
ecfp_masks)
+
+        return [ecfp_words, ecfp_masks]
+
+
+class ResidueAddOn(nn.Module):
+    def __init__(self):
+        super(ResidueAddOn, self).__init__()
+
+        self.representations = []
+
+    def show(self):
+        print("Number of Saved Numpy Arrays: ", len(self.representations))
+        for i, representation in enumerate(self.representations):
+            print(f"Shape of {i}th Numpy Array: ", representation.shape)
+
+        return self.representations
+
+    def flush(self):
+        del self.representations
+        self.representations = []
+
+    def forward(self, **kwargs):
+        X, Xm = kwargs['X'], kwargs['Xm']
+
+        return X, Xm
+
+
+class ARKMAB(ResidueAddOn):
+    def __init__(self, h: int, num_heads: int, attn_option: str, analysis_mode: bool):
+        super(ARKMAB, self).__init__()
+        pmx_args = (h, num_heads, RFF(h), attn_option, False, analysis_mode, False)
+        self.pmx = PoolingMultiheadCrossAttention(*pmx_args)
+        self.inactive = nn.Parameter(torch.randn(1, 1, h))
+        self.fillmask = nn.Parameter(torch.ones(1, 1), requires_grad=False)
+
+        self.representations = []
+        if analysis_mode: pass
+        self.apply(initialization)
+
+    def forward(self, **kwargs):
+        '''
+            X:  batch size x residues x H
+            Xm: batch size x residues
+            Y:  batch size x ecfpsubs x H
+            Ym: batch size x ecfpsubs
+        '''
+        X, Xm = kwargs['residue_features'], kwargs['residue_masks']
+        Y, Ym = kwargs['ligelem_features'], kwargs['ligelem_masks']
+        pseudo_substructure = self.inactive.repeat(X.size(0), 1, 1)
+        pseudo_masks = self.fillmask.repeat(X.size(0), 1)
+
+        Y = torch.cat([Y, pseudo_substructure], 1)
+        Ym = torch.cat([Ym, pseudo_masks], 1)
+
+        X, attention = self.pmx(Y=Y, Ym=Ym, X=X, Xm=Xm)
+
+        return X, Xm, attention
+
+
+class ResidueEncoder(nn.Module):
+    def __init__(self):
+        super(ResidueEncoder, self).__init__()
+        self.representations = []
+
+    def show(self):
+        print("Number of Saved Numpy Arrays: ", len(self.representations))
+        for i, representation in enumerate(self.representations):
+            print(f"Shape of {i}th Numpy Array: ", representation.shape)
+
+        return self.representations
+
+    def flush(self):
+        del self.representations
+        self.representations = []
+
+
+class AminoAcidSeqCNN(ResidueEncoder):
+    def __init__(self, h: int, d: float, cnn_depth: int, kernel_size: int, analysis_mode: bool):
+        super(AminoAcidSeqCNN, self).__init__()
+        self.encoder = nn.ModuleList([nn.Sequential(nn.Linear(21, h),  # Warning
+                                                    nn.Dropout(d),
+                                                    nn.LeakyReLU(),
+                                                    nn.Linear(h, h))])
+        for _ in range(cnn_depth):
+            self.encoder.append(nn.Conv1d(h, h, kernel_size, 1, (kernel_size - 1) // 2))
+
+        self.representations = []
+        if analysis_mode: self.register_forward_hook(store_representations)
+        self.apply(initialization)
+
+    def forward(self, **kwargs):
+        X = kwargs['aaseqs']
+        for i, module in enumerate(self.encoder):
+            if i == 1: X = X.transpose(1, 2)
+            X = module(X)
+        X = X.transpose(1, 2)
+
+        return X
+
+
+class FastaESM(ResidueEncoder):
+    def __init__(self, h: int, esm_model: str, esm_freeze: bool, analysis_mode: bool):
+        super(FastaESM, self).__init__()
+        self.esm_version = 2 if 'esm2' in esm_model else 1
+        if esm_model == 'esm1b_t33_650M_UR50S':
+            self.esm, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
+            self.layer_idx, self.emb_dim = 33, 1280
+        elif esm_model == 'esm1_t12_85M_UR50S':
+            self.esm, alphabet = esm.pretrained.esm1_t12_85M_UR50S()
+            self.layer_idx, self.emb_dim = 12, 768
+        elif esm_model == 'esm2_t6_8M_UR50D':
+            self.esm, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
+            self.layer_idx, self.emb_dim = 6, 320
+        elif esm_model == 'esm2_t12_35M_UR50D':
+            self.esm, alphabet = 
esm.pretrained.esm2_t12_35M_UR50D() + self.layer_idx, self.emb_dim = 12, 480 + elif esm_model == 'esm2_t30_150M_UR50D': + self.esm, alphabet = esm.pretrained.esm2_t30_150M_UR50D() + self.layer_idx, self.emb_dim = 30, 640 + else: + raise + self.batch_converter = alphabet.get_batch_converter() + if esm_freeze == 'True': + for p in self.esm.parameters(): + p.requires_grad = False + assert h == self.emb_dim, f"The hidden dimension should be set to {self.emb_dim}, not {h}" + self.representations = [] + if analysis_mode: self.register_forward_hook(store_elemwise_representations) + + def esm1_pooling(self, embeddings, masks): + + return embeddings[:, 1:, :].sum(1) / masks[:, 1:].sum(1).view(-1, 1) + + def esm2_pooling(self, embeddings, masks): + + return embeddings[:, 1:-1, :].sum(1) / masks[:, 1:-1].sum(1).view(-1, 1) + + def forward(self, **kwargs): + fastas = kwargs['fastas'] + _, _, tokenized = self.batch_converter(fastas) + tokenized = tokenized.cuda() + if self.esm_version == 2: + masks = torch.where(tokenized > 1, 1, 0).float() + else: + masks = torch.where((tokenized > 1) & (tokenized != 32), 1, 0).float() + + embeddings = self.esm(tokenized, repr_layers=[self.layer_idx], return_contacts=True) + logits = embeddings["logits"].sum() + contacts = embeddings["contacts"].sum() + attentions = embeddings["attentions"].sum() + embeddings = embeddings["representations"][self.layer_idx] + + assert masks.size(0) == embeddings.size( + 0), f"Batch sizes of masks {masks.size(0)} and {embeddings.size(0)} do not match." + assert masks.size(1) == embeddings.size( + 1), f"Lengths of masks {masks.size(1)} and {embeddings.size(1)} do not match." + + if self.esm_version == 2: + return [embeddings[:, 1:-1, :], masks[:, 1:-1], logits + contacts + attentions, + self.esm2_pooling(embeddings, masks)] + else: + return [embeddings[:, 1:, :], masks[:, 1:], logits + contacts + attentions, + self.esm1_pooling(embeddings, masks)] + + +class DotProduct(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, queries, keys): + return torch.bmm(queries, keys.transpose(1, 2)) + + +class ScaledDotProduct(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, queries, keys): + return torch.bmm(queries, keys.transpose(1, 2)) / (queries.size(2) ** 0.5) + + +class GeneralDotProduct(nn.Module): + def __init__(self, hidden_dim): + super().__init__() + self.W = nn.Parameter(torch.randn(hidden_dim, hidden_dim)) + torch.nn.init.orthogonal_(self.W) + + def forward(self, queries, keys): + return torch.bmm(queries @ self.W, keys.transpose(1, 2)) + + +class ConcatDotProduct(nn.Module): + def __init__(self, hidden_dim): + super().__init__() + raise + + def forward(self, queries, keys): + return + + +class Additive(nn.Module): + def __init__(self, hidden_dim): + super().__init__() + self.U = nn.Parameter(torch.randn(hidden_dim, hidden_dim)) + self.T = nn.Parameter(torch.randn(hidden_dim, hidden_dim)) + self.b = nn.Parameter(torch.rand(hidden_dim).uniform_(-0.1, 0.1)) + self.W = nn.Sequential(nn.Tanh(), nn.Linear(hidden_dim, 1)) + torch.nn.init.orthogonal_(self.U) + torch.nn.init.orthogonal_(self.T) + + def forward(self, queries, keys): + return self.W(queries.unsqueeze(1) @ self.U + keys.unsqueeze(2) @ self.T + self.b).squeeze(-1).transpose(1, 2) + + +class Attention(nn.Module): + def __init__(self, similarity, hidden_dim=1024, store_qk=False): + super().__init__() + self.softmax = nn.Softmax(dim=2) + self.attention_maps = [] + self.store_qk = store_qk + self.query_vectors, self.key_vectors = None, 
None + + assert similarity in ['dot', 'scaled_dot', 'general_dot', 'concat_dot', 'additive'] + if similarity == 'dot': + self.similarity = DotProduct() + elif similarity == 'scaled_dot': + self.similarity = ScaledDotProduct() + elif similarity == 'general_dot': + self.similarity = GeneralDotProduct(hidden_dim) + elif similarity == 'concat_dot': + self.similarity = ConcatDotProduct(hidden_dim) + elif similarity == 'additive': + self.similarity = Additive(hidden_dim) + else: + raise + + def release_qk(self): + Q, K = self.query_vectors, self.key_vectors + self.query_vectors, self.key_vectors = None, None + torch.cuda.empty_cache() + + return Q, K + + def forward(self, queries, keys, qmasks=None, kmasks=None): + if self.store_qk: + self.query_vectors = queries + self.key_vectors = keys + + if torch.is_tensor(qmasks) and not torch.is_tensor(kmasks): + dim0, dim1 = qmasks.size(0), keys.size(1) + kmasks = torch.ones(dim0, dim1).cuda() + + elif not torch.is_tensor(qmasks) and torch.is_tensor(kmasks): + dim0, dim1 = kmasks.size(0), queries.size(1) + qmasks = torch.ones(dim0, dim1).cuda() + else: + pass + + attention = self.similarity(queries, keys) + if torch.is_tensor(qmasks) and torch.is_tensor(kmasks): + qmasks = qmasks.repeat(queries.size(0) // qmasks.size(0), 1).unsqueeze(2) + kmasks = kmasks.repeat(keys.size(0) // kmasks.size(0), 1).unsqueeze(2) + attnmasks = torch.bmm(qmasks, kmasks.transpose(1, 2)) + attention = torch.clip(attention, min=-10, max=10) + attention = attention.exp() + attention = attention * attnmasks + attention = attention / (attention.sum(2).unsqueeze(2) + 1e-5) + else: + attention = self.softmax(attention) + + return attention + + +@torch.no_grad() +def save_attention_maps(self, input, output): + self.attention_maps.append(output.data.detach().cpu().numpy()) + + +class MultiheadAttention(nn.Module): + def __init__(self, d, h, sim='dot', analysis=False, store_qk=False): + super().__init__() + assert d % h == 0, f"{d} dimension, {h} heads" + self.h = h + p = d // h + + self.project_queries = nn.Linear(d, d) + self.project_keys = nn.Linear(d, d) + self.project_values = nn.Linear(d, d) + self.concatenation = nn.Linear(d, d) + self.attention = Attention(sim, p, store_qk) + + if analysis: + self.attention.register_forward_hook(save_attention_maps) + + def release_qk(self): + Q, K = self.attention.release_qk() + + Qb = Q.size(0) // self.h + Qn, Qd = Q.size(1), Q.size(2) + + Kb = K.size(0) // self.h + Kn, Kd = K.size(1), K.size(2) + + Q = Q.view(self.h, Qb, Qn, Qd) + K = K.view(self.h, Kb, Kn, Kd) + + Q = Q.permute(1, 2, 0, 3).contiguous().view(Qb, Qn, Qd * self.h) + K = K.permute(1, 2, 0, 3).contiguous().view(Kb, Kn, Kd * self.h) + + return Q, K + + def forward(self, queries, keys, values, qmasks=None, kmasks=None): + h = self.h + b, n, d = queries.size() + _, m, _ = keys.size() + p = d // h + + queries = self.project_queries(queries) # shape [b, n, d] + keys = self.project_keys(keys) # shape [b, m, d] + values = self.project_values(values) # shape [b, m, d] + + queries = queries.view(b, n, h, p) + keys = keys.view(b, m, h, p) + values = values.view(b, m, h, p) + + queries = queries.permute(2, 0, 1, 3).contiguous().view(h * b, n, p) + keys = keys.permute(2, 0, 1, 3).contiguous().view(h * b, m, p) + values = values.permute(2, 0, 1, 3).contiguous().view(h * b, m, p) + + attn_w = self.attention(queries, keys, qmasks, kmasks) # shape [h * b, n, p] + output = torch.bmm(attn_w, values) + output = output.view(h, b, n, p) + output = output.permute(1, 2, 0, 3).contiguous().view(b, n, d) 
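+        # The split above stacks heads head-major along the batch axis: [h * b, n, p].
+        # The permute/view pair that produced `output` is its exact inverse, and
+        # Attention.forward relies on this layout when it repeats qmasks/kmasks by
+        # queries.size(0) // qmasks.size(0) == h.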
+ output = self.concatenation(output) # shape [b, n, d] + + return output, attn_w + + +class MultiheadAttentionExpanded(nn.Module): + def __init__(self, d, h, sim='dot', analysis=False): + super().__init__() + self.project_queries = nn.ModuleList([nn.Linear(d, d) for _ in range(h)]) + self.project_keys = nn.ModuleList([nn.Linear(d, d) for _ in range(h)]) + self.project_values = nn.ModuleList([nn.Linear(d, d) for _ in range(h)]) + self.concatenation = nn.Linear(h * d, d) + self.attention = Attention(sim, d) + + if analysis: + self.attention.register_forward_hook(save_attention_maps) + + def forward(self, queries, keys, values, qmasks=None, kmasks=None): + output = [] + for Wq, Wk, Wv in zip(self.project_queries, self.project_keys, self.project_values): + Pq, Pk, Pv = Wq(queries), Wk(keys), Wv(values) + output.append(torch.bmm(self.attention(Pq, Pk, qmasks, kmasks), Pv)) + + output = self.concatenation(torch.cat(output, 1)) + + return output + + +class EmptyModule(nn.Module): + def __init__(self, args): + super().__init__() + + def forward(self, x): + return 0. + + +class RFF(nn.Module): + def __init__(self, h): + super().__init__() + self.rff = nn.Sequential(nn.Linear(h, h), nn.ReLU(), nn.Linear(h, h), nn.ReLU(), nn.Linear(h, h), nn.ReLU()) + + def forward(self, x): + return self.rff(x) + + +class MultiheadAttentionBlock(nn.Module): + def __init__(self, d, h, rff, similarity='dot', full_head=False, analysis=False, store_qk=False): + super().__init__() + self.multihead = MultiheadAttention(d, h, similarity, analysis, + store_qk) if not full_head else MultiheadAttentionExpanded(d, h, similarity, + analysis) + self.layer_norm1 = nn.LayerNorm(d) + self.layer_norm2 = nn.LayerNorm(d) + self.rff = rff + + def release_qk(self): + Q, K = self.multihead.release_qk() + + return Q, K + + def forward(self, x, y, xm=None, ym=None, layer_norm=True): + h, a = self.multihead(x, y, y, xm, ym) + if layer_norm: + h = self.layer_norm1(x + h) + return self.layer_norm2(h + self.rff(h)), a + else: + h = x + h + return h + self.rff(h), a + + +class SetAttentionBlock(nn.Module): + def __init__(self, d, h, rff, similarity='dot', full_head=False, analysis=False, store_qk=False): + super().__init__() + self.mab = MultiheadAttentionBlock(d, h, rff, similarity, full_head, analysis, store_qk) + + def release_qk(self): + Q, K = self.mab.release_qk() + + return Q, K + + def forward(self, x, m=None, ln=True): + return self.mab(x, x, m, m, ln) + + +class InducedSetAttentionBlock(nn.Module): + def __init__(self, d, m, h, rff1, rff2, similarity='dot', full_head=False, analysis=False, store_qk=False): + super().__init__() + self.mab1 = MultiheadAttentionBlock(d, h, rff1, similarity, full_head, analysis, store_qk) + self.mab2 = MultiheadAttentionBlock(d, h, rff2, similarity, full_head, analysis, store_qk) + self.inducing_points = nn.Parameter(torch.randn(1, m, d)) + + def release_qk(self): + raise NotImplemented + + def forward(self, x, m=None, ln=True): + b = x.size(0) + p = self.inducing_points + p = p.repeat([b, 1, 1]) # shape [b, m, d] + h = self.mab1(p, x, None, m, ln) # shape [b, m, d] + + return self.mab2(x, h, m, None, ln) + + +class PoolingMultiheadAttention(nn.Module): + def __init__(self, d, k, h, rff, similarity='dot', full_head=False, analysis=False, store_qk=False): + super().__init__() + self.mab = MultiheadAttentionBlock(d, h, rff, similarity, full_head, analysis, store_qk) + self.seed_vectors = nn.Parameter(torch.randn(1, k, d)) + torch.nn.init.xavier_uniform_(self.seed_vectors) + + def release_qk(self): + Q, K = 
self.mab.release_qk() + + return Q, K + + def forward(self, z, m=None, ln=True): + b = z.size(0) + s = self.seed_vectors + s = s.repeat([b, 1, 1]) # random seed vector: shape [b, k, d] + + return self.mab(s, z, None, m, ln) + + +class PoolingMultiheadCrossAttention(nn.Module): + def __init__(self, d, h, rff, similarity='dot', full_head=False, analysis=False, store_qk=False): + super().__init__() + self.mab = MultiheadAttentionBlock(d, h, rff, similarity, full_head, analysis, store_qk) + + def release_qk(self): + Q, K = self.mab.release_qk() + + return Q, K + + def forward(self, X, Y, Xm=None, Ym=None, ln=True): + return self.mab(X, Y, Xm, Ym, ln) diff --git a/deepscreen/models/predictors/bacpi.py b/deepscreen/models/predictors/bacpi.py new file mode 100644 index 0000000000000000000000000000000000000000..ff732fa9a784048e09aba8fba2723c3832dd5f8c --- /dev/null +++ b/deepscreen/models/predictors/bacpi.py @@ -0,0 +1,284 @@ +from collections import defaultdict + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from rdkit import Chem +from rdkit.Chem import AllChem + + +class BACPI(nn.Module): + def __init__( + self, + n_atom, + n_amino, + comp_dim, + prot_dim, + gat_dim, + num_head, + dropout, + alpha, + window, + layer_cnn, + latent_dim, + ): + super().__init__() + self.embedding_layer_atom = nn.Embedding(n_atom + 1, comp_dim) + self.embedding_layer_amino = nn.Embedding(n_amino + 1, prot_dim) + + self.dropout = dropout + self.alpha = alpha + self.layer_cnn = layer_cnn + + self.gat_layers = [GATLayer(comp_dim, gat_dim, dropout=dropout, alpha=alpha, concat=True) + for _ in range(num_head)] + for i, layer in enumerate(self.gat_layers): + self.add_module('gat_layer_{}'.format(i), layer) + self.gat_out = GATLayer(gat_dim * num_head, comp_dim, dropout=dropout, alpha=alpha, concat=False) + self.W_comp = nn.Linear(comp_dim, latent_dim) + + self.conv_layers = nn.ModuleList([nn.Conv2d(in_channels=1, out_channels=1, kernel_size=2 * window + 1, + stride=1, padding=window) for _ in range(layer_cnn)]) + self.W_prot = nn.Linear(prot_dim, latent_dim) + + self.fp0 = nn.Parameter(torch.empty(size=(1024, latent_dim))) + nn.init.xavier_uniform_(self.fp0, gain=1.414) + self.fp1 = nn.Parameter(torch.empty(size=(latent_dim, latent_dim))) + nn.init.xavier_uniform_(self.fp1, gain=1.414) + + self.bidat_num = 4 + + self.U = nn.ParameterList([ + nn.Parameter(torch.empty(size=(latent_dim, latent_dim))) for _ in range(self.bidat_num) + ]) + for i in range(self.bidat_num): + nn.init.xavier_uniform_(self.U[i], gain=1.414) + + self.transform_c2p = nn.ModuleList([nn.Linear(latent_dim, latent_dim) for _ in range(self.bidat_num)]) + self.transform_p2c = nn.ModuleList([nn.Linear(latent_dim, latent_dim) for _ in range(self.bidat_num)]) + + self.bihidden_c = nn.ModuleList([nn.Linear(latent_dim, latent_dim) for _ in range(self.bidat_num)]) + self.bihidden_p = nn.ModuleList([nn.Linear(latent_dim, latent_dim) for _ in range(self.bidat_num)]) + self.biatt_c = nn.ModuleList([nn.Linear(latent_dim * 2, 1) for _ in range(self.bidat_num)]) + self.biatt_p = nn.ModuleList([nn.Linear(latent_dim * 2, 1) for _ in range(self.bidat_num)]) + + self.comb_c = nn.Linear(latent_dim * self.bidat_num, latent_dim) + self.comb_p = nn.Linear(latent_dim * self.bidat_num, latent_dim) + + def comp_gat(self, atoms, atoms_mask, adj): + atoms_vector = self.embedding_layer_atom(atoms) + atoms_multi_head = torch.cat([gat(atoms_vector, adj) for gat in self.gat_layers], dim=2) + atoms_vector = 
F.elu(self.gat_out(atoms_multi_head, adj))
+        atoms_vector = F.leaky_relu(self.W_comp(atoms_vector), self.alpha)
+        return atoms_vector
+
+    def prot_cnn(self, amino, amino_mask):
+        amino_vector = self.embedding_layer_amino(amino)
+        amino_vector = torch.unsqueeze(amino_vector, 1)
+        for i in range(self.layer_cnn):
+            amino_vector = F.leaky_relu(self.conv_layers[i](amino_vector), self.alpha)
+        amino_vector = torch.squeeze(amino_vector, 1)
+        amino_vector = F.leaky_relu(self.W_prot(amino_vector), self.alpha)
+        return amino_vector
+
+    def mask_softmax(self, a, mask, dim=-1):
+        a_max = torch.max(a, dim, keepdim=True)[0]
+        a_exp = torch.exp(a - a_max)
+        a_exp = a_exp * mask
+        a_softmax = a_exp / (torch.sum(a_exp, dim, keepdim=True) + 1e-6)
+        return a_softmax
+
+    def bidirectional_attention_prediction(self, atoms_vector, atoms_mask, fps, amino_vector, amino_mask):
+        b = atoms_vector.shape[0]
+        for i in range(self.bidat_num):
+            A = torch.tanh(torch.matmul(torch.matmul(atoms_vector, self.U[i]), amino_vector.transpose(1, 2)))
+            A = A * torch.matmul(atoms_mask.view(b, -1, 1).float(), amino_mask.view(b, 1, -1).float())
+
+            atoms_trans = torch.matmul(A, torch.tanh(self.transform_p2c[i](amino_vector)))
+            amino_trans = torch.matmul(A.transpose(1, 2), torch.tanh(self.transform_c2p[i](atoms_vector)))
+
+            atoms_tmp = torch.cat([torch.tanh(self.bihidden_c[i](atoms_vector)), atoms_trans], dim=2)
+            amino_tmp = torch.cat([torch.tanh(self.bihidden_p[i](amino_vector)), amino_trans], dim=2)
+
+            atoms_att = self.mask_softmax(self.biatt_c[i](atoms_tmp).view(b, -1), atoms_mask.view(b, -1).float())
+            amino_att = self.mask_softmax(self.biatt_p[i](amino_tmp).view(b, -1), amino_mask.view(b, -1).float())
+
+            cf = torch.sum(atoms_vector * atoms_att.view(b, -1, 1), dim=1)
+            pf = torch.sum(amino_vector * amino_att.view(b, -1, 1), dim=1)
+
+            if i == 0:
+                cat_cf = cf
+                cat_pf = pf
+            else:
+                cat_cf = torch.cat([cat_cf.view(b, -1), cf.view(b, -1)], dim=1)
+                cat_pf = torch.cat([cat_pf.view(b, -1), pf.view(b, -1)], dim=1)
+
+        cf_final = torch.cat([self.comb_c(cat_cf).view(b, -1), fps.view(b, -1)], dim=1)
+        pf_final = self.comb_p(cat_pf)
+        cf_pf = F.leaky_relu(
+            torch.matmul(
+                cf_final.view(b, -1, 1), pf_final.view(b, 1, -1)
+            ).view(b, -1), 0.1
+        )
+
+        return cf_pf
+
+    def forward(self, compound, protein):
+        atom, adj, fp = compound
+
+        atom, atom_lengths = atom
+        adj, _ = adj
+        fp, _ = fp
+        amino, amino_lengths = protein
+
+        # Masks are 1 for real tokens and 0 for padding; mask_softmax and the
+        # bi-attention products expect valid positions to be nonzero.
+        atom_mask = torch.arange(atom.size(1), device=atom.device) < atom_lengths.unsqueeze(1)
+        amino_mask = torch.arange(amino.size(1), device=amino.device) < amino_lengths.unsqueeze(1)
+
+        atoms_vector = self.comp_gat(atom, atom_mask, adj)
+        amino_vector = self.prot_cnn(amino, amino_mask)
+
+        super_feature = F.leaky_relu(torch.matmul(fp.float(), self.fp0), 0.1)
+        super_feature = F.leaky_relu(torch.matmul(super_feature, self.fp1), 0.1)
+
+        prediction = self.bidirectional_attention_prediction(
+            atoms_vector, atom_mask, super_feature, amino_vector, amino_mask)
+
+        return prediction
+
+
+class GATLayer(nn.Module):
+    def __init__(self, in_features, out_features, dropout=0.5, alpha=0.2, concat=True):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.dropout = dropout
+        self.alpha = alpha
+        self.concat = concat
+
+        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
+        nn.init.xavier_uniform_(self.W.data, gain=1.414)
+        self.a = nn.Parameter(torch.empty(size=(2 * out_features, 1)))
+        nn.init.xavier_uniform_(self.a.data, gain=1.414)
+
+    def forward(self, h, 
adj): + Wh = torch.matmul(h, self.W) + a_input = self._prepare_attentional_mechanism_input(Wh) + e = F.leaky_relu(torch.matmul(a_input, self.a).squeeze(3), self.alpha) + + zero_vec = -9e15 * torch.ones_like(e) + attention = torch.where(adj > 0, e, zero_vec) + attention = F.softmax(attention, dim=2) + # attention = F.dropout(attention, self.dropout, training=self.training) + h_prime = torch.bmm(attention, Wh) + + return F.elu(h_prime) if self.concat else h_prime + + def _prepare_attentional_mechanism_input(self, Wh): + b = Wh.size()[0] + N = Wh.size()[1] + + Wh_repeated_in_chunks = Wh.repeat_interleave(N, dim=1) + Wh_repeated_alternating = Wh.repeat_interleave(N, dim=0).view(b, N * N, self.out_features) + all_combinations_matrix = torch.cat([Wh_repeated_in_chunks, Wh_repeated_alternating], dim=2) + + return all_combinations_matrix.view(b, N, N, 2 * self.out_features) + + +atom_dict = defaultdict(lambda: len(atom_dict)) +bond_dict = defaultdict(lambda: len(bond_dict)) +fingerprint_dict = defaultdict(lambda: len(fingerprint_dict)) +edge_dict = defaultdict(lambda: len(edge_dict)) +word_dict = defaultdict(lambda: len(word_dict)) + + +def create_atoms(mol): + atoms = [a.GetSymbol() for a in mol.GetAtoms()] + for a in mol.GetAromaticAtoms(): + i = a.GetIdx() + atoms[i] = (atoms[i], 'aromatic') + atoms = [atom_dict[a] for a in atoms] + return np.array(atoms) + + +def create_ijbonddict(mol): + i_jbond_dict = defaultdict(lambda: []) + for b in mol.GetBonds(): + i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx() + bond = bond_dict[str(b.GetBondType())] + i_jbond_dict[i].append((j, bond)) + i_jbond_dict[j].append((i, bond)) + + atoms_set = set(range(mol.GetNumAtoms())) + isolate_atoms = atoms_set - set(i_jbond_dict.keys()) + bond = bond_dict['nan'] + for a in isolate_atoms: + i_jbond_dict[a].append((a, bond)) + + return i_jbond_dict + + +def atom_features(atoms, i_jbond_dict, radius): + if (len(atoms) == 1) or (radius == 0): + fingerprints = [fingerprint_dict[a] for a in atoms] + else: + nodes = atoms + i_jedge_dict = i_jbond_dict + for _ in range(radius): + fingerprints = [] + for i, j_edge in i_jedge_dict.items(): + neighbors = [(nodes[j], edge) for j, edge in j_edge] + fingerprint = (nodes[i], tuple(sorted(neighbors))) + fingerprints.append(fingerprint_dict[fingerprint]) + + nodes = fingerprints + _i_jedge_dict = defaultdict(lambda: []) + for i, j_edge in i_jedge_dict.items(): + for j, edge in j_edge: + both_side = tuple(sorted((nodes[i], nodes[j]))) + edge = edge_dict[(both_side, edge)] + _i_jedge_dict[i].append((j, edge)) + i_jedge_dict = _i_jedge_dict + + return np.array(fingerprints) + + +def create_adjacency(mol): + adjacency = Chem.GetAdjacencyMatrix(mol) + adjacency = np.array(adjacency) + adjacency += np.eye(adjacency.shape[0], dtype=int) + return adjacency + + +def get_fingerprints(mol): + fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=1024, useChirality=True) + return np.array(fp) + + +def split_sequence(sequence, ngram=3): + sequence = '-' + sequence + '=' + words = [word_dict[sequence[i:i + ngram]] + for i in range(len(sequence) - ngram + 1)] + return np.array(words) + + +def drug_featurizer(smiles, radius=2): + from deepscreen.utils import get_logger + log = get_logger(__name__) + try: + mol = Chem.MolFromSmiles(smiles) + if not mol: + return None + mol = Chem.AddHs(mol) + atoms = create_atoms(mol) + i_jbond_dict = create_ijbonddict(mol) + + compound = atom_features(atoms, i_jbond_dict, radius) + adjacency = create_adjacency(mol) + fp = get_fingerprints(mol) + + return compound, 
adjacency, fp + + except Exception as e: + log.warning(f"Failed to featurize SMILES ({smiles}) to graph due to {str(e)}") + return None diff --git a/deepscreen/models/predictors/coa_dti_pro.py b/deepscreen/models/predictors/coa_dti_pro.py new file mode 100644 index 0000000000000000000000000000000000000000..343f0eb3b485392deb5230b0712d80ab2c2961fe --- /dev/null +++ b/deepscreen/models/predictors/coa_dti_pro.py @@ -0,0 +1,386 @@ +import math +from collections import defaultdict +from typing import Literal + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from rdkit import Chem +from scipy.sparse import coo_matrix +from torch_geometric.data import Data +from torch_geometric.nn.pool.topk_pool import TopKPooling +from torch_geometric.nn.glob import global_mean_pool as gap, global_max_pool as gmp +from torch_geometric.utils import add_self_loops, remove_self_loops +from torch_geometric.nn.conv.message_passing import MessagePassing + + +class CoaDTIPro(nn.Module): + def __init__(self, + esm_model_and_alphabet, n_fingerprint, dim, n_word, layer_output, layer_coa, nhead=8, dropout=0.1, + co_attention: Literal['stack', 'encoder', 'inter'] = 'inter', gcn_pooling=False, ): + super().__init__() + self.co_attention = co_attention + self.layer_output = layer_output + self.layer_coa = layer_coa + self.embed_word = nn.Embedding(n_word, dim) + self.gnn = GNN(n_fingerprint, gcn_pooling) + self.esm_model, self.alphabet = esm_model_and_alphabet + self.batch_converter = self.alphabet.get_batch_converter() + + self.W_attention = nn.Linear(dim, dim) + + self.W_out = nn.Sequential( + nn.Linear(2 * dim, dim), + nn.Linear(dim, 128), + nn.Linear(128, 64) + ) + + self.coa_layers = CoAttention(dim, nhead, dropout, layer_coa, co_attention) + self.lin = nn.Linear(768, 512) # bert1024 esm768 + self.W_interaction = nn.Linear(64, 2) + + def attention_cnn(self, x, xs, layer): + """The attention mechanism is applied to the last layer of CNN.""" + xs = torch.unsqueeze(torch.unsqueeze(xs, 0), 0) + for i in range(layer): + xs = torch.relu(self.W_cnn[i](xs)) + xs = torch.squeeze(torch.squeeze(xs, 0), 0) + + h = torch.relu(self.W_attention(x)) + hs = torch.relu(self.W_attention(xs)) + weights = torch.tanh(F.linear(h, hs)) + ys = torch.t(weights) * hs + + return torch.unsqueeze(torch.mean(ys, 0), 0) + + def forward(self, inputs, proteins): + """Compound vector with GNN.""" + compound_vector = self.gnn(inputs) + compound_vector = torch.unsqueeze(compound_vector, 0) # sequence-like GNN ouput + + _, _, proteins = self.batch_converter([(None, protein) for protein in proteins]) + with torch.no_grad(): + results = self.esm_model(proteins.to(compound_vector.device), repr_layers=[6]) + token_representations = results["representations"][6] + + protein_vector = token_representations[:, 1:, :] + protein_vector = self.lin(torch.squeeze(protein_vector, 1)) + + protein_vector, compound_vector = self.coa_layers(protein_vector, compound_vector) + + protein_vector = protein_vector.mean(dim=1) + compound_vector = compound_vector.mean(dim=1) + """Concatenate the above two vectors and output the interaction.""" + cat_vector = torch.cat((compound_vector, protein_vector), 1) + cat_vector = torch.tanh(self.W_out(cat_vector)) + interaction = self.W_interaction(cat_vector) + return interaction + + +class CoAttention(nn.Module): + def __init__(self, dim, nhead, dropout, layer_coa, co_attention): + super().__init__() + self.co_attention = co_attention + if self.co_attention == 'encoder': + self.coa_layers = 
EncoderCrossAtt(dim, nhead, dropout, layer_coa) + elif self.co_attention == 'stack': + self.coa_layers = nn.ModuleList([StackCrossAtt(dim, nhead, dropout) for _ in range(layer_coa)]) + elif self.co_attention == 'inter': + self.coa_layers = nn.ModuleList([InterCrossAtt(dim, nhead, dropout) for _ in range(layer_coa)]) + + def forward(self, protein_vector, compound_vector): + # x and y are the input tensors for the two modalities + # edge_index_x and edge_index_y are the edge indices for the graph data + if self.co_attention == 'encoder': + return self.coa_layers(protein_vector, compound_vector) + else: + # loop over the sequential layers and pass the arguments + for layer in self.coa_layers: + protein_vector, compound_vector = layer(protein_vector, compound_vector) + return protein_vector, compound_vector + + +class EncoderCrossAtt(nn.Module): + def __init__(self, dim, nhead, dropout, layers): + super().__init__() + # self.encoder_layers = nn.ModuleList([SEA(dim, dropout) for _ in range(layers)]) + self.encoder_layers = nn.ModuleList([SA(dim, nhead, dropout) for _ in range(layers)]) + self.decoder_sa = nn.ModuleList([SA(dim, nhead, dropout) for _ in range(layers)]) + self.decoder_coa = nn.ModuleList([DPA(dim, nhead, dropout) for _ in range(layers)]) + self.layer_coa = layers + + def forward(self, protein_vector, compound_vector): + for i in range(self.layer_coa): + compound_vector = self.encoder_layers[i](compound_vector, None) # self-attention + for i in range(self.layer_coa): + protein_vector = self.decoder_sa[i](protein_vector, None) + protein_vector = self.decoder_coa[i](protein_vector, compound_vector, None)# co-attention + + return protein_vector, compound_vector + + +class InterCrossAtt(nn.Module): + def __init__(self, dim, nhead, dropout): + super().__init__() + self.sca = SA(dim, nhead, dropout) + self.spa = SA(dim, nhead, dropout) + self.coa_pc = DPA(dim, nhead, dropout) + self.coa_cp = DPA(dim, nhead, dropout) + + def forward(self, protein_vector, compound_vector): + compound_vector = self.sca(compound_vector, None) # self-attention + protein_vector = self.spa(protein_vector, None) # self-attention + compound_covector = self.coa_pc(compound_vector, protein_vector, None) # co-attention + protein_covector = self.coa_cp(protein_vector, compound_vector, None) # co-attention + + return protein_covector, compound_covector + + +class StackCrossAtt(nn.Module): + def __init__(self, dim, nhead, dropout): + super().__init__() + self.sca = SA(dim, nhead, dropout) + self.spa = SA(dim, nhead, dropout) + self.coa_cp = DPA(dim, nhead, dropout) + + def forward(self, protein_vector, compound_vector): + compound_vector = self.sca(compound_vector, None) # self-attention + protein_vector = self.spa(protein_vector, None) # self-attention + protein_covector = self.coa_cp(protein_vector, compound_vector, None) # co-attention + + return protein_covector, compound_vector + + +class MHAtt(nn.Module): + def __init__(self, hid_dim, n_heads, dropout): + super().__init__() + + self.linear_v = nn.Linear(hid_dim, hid_dim) + self.linear_k = nn.Linear(hid_dim, hid_dim) + self.linear_q = nn.Linear(hid_dim, hid_dim) + self.linear_merge = nn.Linear(hid_dim, hid_dim) + self.hid_dim = hid_dim + self.dropout = dropout + self.nhead = n_heads + + self.dropout = nn.Dropout(dropout) + self.hidden_size_head = int(self.hid_dim / self.nhead) + + def forward(self, v, k, q, mask): + n_batches = q.size(0) + v = self.linear_v(v).view(n_batches, -1, self.nhead, self.hidden_size_head).transpose(1, 2) + k = 
self.linear_k(k).view(n_batches, -1, self.nhead, self.hidden_size_head).transpose(1, 2) + q = self.linear_q(q).view(n_batches, -1, self.nhead, self.hidden_size_head).transpose(1, 2) + + atted = self.att(v, k, q, mask) + atted = atted.transpose(1, 2).contiguous().view(n_batches, -1, self.hid_dim) + + atted = self.linear_merge(atted) + + return atted + + def att(self, value, key, query, mask): + d_k = query.size(-1) + + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) + + if mask is not None: + scores = scores.masked_fill(mask, -1e9) + + att_map = F.softmax(scores, dim=-1) + att_map = self.dropout(att_map) + + return torch.matmul(att_map, value) + + +class DPA(nn.Module): + def __init__(self, hid_dim, n_heads, dropout): + super().__init__() + + self.mhatt1 = MHAtt(hid_dim, n_heads, dropout) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(hid_dim) + + def forward(self, x, y, y_mask=None): + x = self.norm1(x + self.dropout1(self.mhatt1(y, y, x, y_mask))) + return x + + +class SA(nn.Module): + def __init__(self, hid_dim, n_heads, dropout): + super().__init__() + + self.mhatt1 = MHAtt(hid_dim, n_heads, dropout) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(hid_dim) + + def forward(self, x, mask=None): + x = self.norm1(x + self.dropout1(self.mhatt1(x, x, x, mask))) + return x + + +class SAGEConv(MessagePassing): + def __init__(self, in_channels, out_channels): + super().__init__(aggr='max') # "Max" aggregation. + self.lin = torch.nn.Linear(in_channels, out_channels) + self.act = torch.nn.ReLU() + self.update_lin = torch.nn.Linear(in_channels + out_channels, in_channels, bias=False) + self.update_act = torch.nn.ReLU() + + def forward(self, x, edge_index): + # x has shape [N, in_channels] + # edge_index has shape [2, E] + edge_index, _ = remove_self_loops(edge_index) + edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) + + return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x) + + def message(self, x_j): + # x_j has shape [E, in_channels] + x_j = self.lin(x_j) + x_j = self.act(x_j) + + return x_j + + def update(self, aggr_out, x): + # aggr_out has shape [N, out_channels] + new_embedding = torch.cat([aggr_out, x], dim=1) + + new_embedding = self.update_lin(new_embedding) + new_embedding = self.update_act(new_embedding) + + return new_embedding + + +class GNN(nn.Module): + def __init__(self, n_fingerprint, pooling, embed_dim=128): + super().__init__() + self.pooling = pooling + self.embed_fingerprint = nn.Embedding(num_embeddings=n_fingerprint, embedding_dim=embed_dim) + self.conv1 = SAGEConv(embed_dim, 128) + self.pool1 = TopKPooling(128, ratio=0.8) + self.conv2 = SAGEConv(128, 128) + self.pool2 = TopKPooling(128, ratio=0.8) + self.conv3 = SAGEConv(128, 128) + self.pool3 = TopKPooling(128, ratio=0.8) + self.linp1 = torch.nn.Linear(256, 128) + self.linp2 = torch.nn.Linear(128, 512) + + self.lin = torch.nn.Linear(128, 512) + self.bn1 = torch.nn.BatchNorm1d(128) + self.bn2 = torch.nn.BatchNorm1d(64) + self.act1 = torch.nn.ReLU() + self.act2 = torch.nn.ReLU() + + def forward(self, data): + # x, edge_index, batch = data.x, data.edge_index, data.batch + x, edge_index, batch = data.x, data.edge_index, data.batch + x = self.embed_fingerprint(x) + x = x.squeeze(1) + x = F.relu(self.conv1(x, edge_index)) + + if self.pooling: + x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, None, batch) + x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1) + + x = F.relu(self.conv2(x, edge_index)) + + x, edge_index, _, batch, _, _ 
= self.pool2(x, edge_index, None, batch) + x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1) + + x, edge_index, _, batch, _, _ = self.pool3(x, edge_index, None, batch) + x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1) + + x = x1 + x2 + x3 + x = self.linp1(x) + x = self.act1(x) + x = self.linp2(x) + + else: + x = F.relu(self.conv2(x, edge_index)) + x = self.lin(x) + + return x + + +atom_dict = defaultdict(lambda: len(atom_dict)) # 51 bindingdb: 26 +bond_dict = defaultdict(lambda: len(bond_dict)) # 4 bindingdb: 4 +fingerprint_dict = defaultdict(lambda: len(fingerprint_dict)) # 6341 bindingdb: 20366 +edge_dict = defaultdict(lambda: len(edge_dict)) # 17536 bindingdb: 77916 +word_dict = defaultdict(lambda: len(word_dict)) # 22 bindingdb: 21 + + +def drug_featurizer(smiles, radius=2): + mol = Chem.AddHs(Chem.MolFromSmiles(smiles)) + atoms = create_atoms(mol) + i_jbond_dict = create_ijbonddict(mol) + fingerprints = extract_fingerprints(atoms, i_jbond_dict, radius) + adjacency = coo_matrix(Chem.GetAdjacencyMatrix(mol)) + adjacency = coo_matrix(adjacency) + edge_index = np.array([adjacency.row, adjacency.col]) + + return Data(x=torch.LongTensor(fingerprints).unsqueeze(1), edge_index=torch.LongTensor(edge_index)) + + +def create_atoms(mol): + """Create a list of atom (e.g., hydrogen and oxygen) IDs + considering the aromaticity.""" + # GetSymbol: obtain the symbol of the atom + atoms = [a.GetSymbol() for a in mol.GetAtoms()] + for a in mol.GetAromaticAtoms(): + i = a.GetIdx() + atoms[i] = (atoms[i], 'aromatic') + # turn it into index + atoms = [atom_dict[a] for a in atoms] + + return np.array(atoms) + + +def create_ijbonddict(mol): + """Create a dictionary, which each key is a node ID + and each value is the tuples of its neighboring node + and bond (e.g., single and double) IDs.""" + i_jbond_dict = defaultdict(lambda: []) + for b in mol.GetBonds(): + i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx() + bond = bond_dict[str(b.GetBondType())] + i_jbond_dict[i].append((j, bond)) + i_jbond_dict[j].append((i, bond)) + return i_jbond_dict + + +def extract_fingerprints(atoms, i_jbond_dict, radius=2): + """Extract the r-radius subgraphs (i.e., fingerprints) + from a molecular graph using Weisfeiler-Lehman algorithm.""" + fingerprints = None + + if (len(atoms) == 1) or (radius == 0): + fingerprints = [fingerprint_dict[a] for a in atoms] + + else: + nodes = atoms + i_jedge_dict = i_jbond_dict + + for _ in range(radius): + + """Update each node ID considering its neighboring nodes and edges + (i.e., r-radius subgraphs or fingerprints).""" + fingerprints = [] + for i, j_edge in i_jedge_dict.items(): + neighbors = [(nodes[j], edge) for j, edge in j_edge] + fingerprint = (nodes[i], tuple(sorted(neighbors))) + fingerprints.append(fingerprint_dict[fingerprint]) + nodes = fingerprints + + """Also update each edge ID considering two nodes + on its both sides.""" + _i_jedge_dict = defaultdict(lambda: []) + for i, j_edge in i_jedge_dict.items(): + for j, edge in j_edge: + both_side = tuple(sorted((nodes[i], nodes[j]))) + edge = edge_dict[(both_side, edge)] + _i_jedge_dict[i].append((j, edge)) + i_jedge_dict = _i_jedge_dict + + return np.array(fingerprints) diff --git a/deepscreen/models/predictors/custom.py b/deepscreen/models/predictors/custom.py new file mode 100644 index 0000000000000000000000000000000000000000..7f5d4d714c420588e858dd3f9f0f94c40d8eed2c --- /dev/null +++ b/deepscreen/models/predictors/custom.py @@ -0,0 +1,21 @@ +from torch import nn + + +class CustomPredictor(nn.Module): + def __init__( + self, 
+ drug_encoder: nn.Module, + protein_encoder: nn.Module, + decoder: nn.Module, + ): + super().__init__() + self.drug_encoder = drug_encoder + self.protein_encoder = protein_encoder + self.decoder = decoder + + def forward(self, enc_drug, enc_protein): + enc_drug = self.drug_encoder(enc_drug) + enc_protein = self.protein_encoder(enc_protein) + preds = self.decoder(enc_drug, enc_protein) + + return preds diff --git a/deepscreen/models/predictors/deep_dtaf.py b/deepscreen/models/predictors/deep_dtaf.py new file mode 100644 index 0000000000000000000000000000000000000000..eb30b040cba4f0b3054734790eb43bec49d9e9f2 --- /dev/null +++ b/deepscreen/models/predictors/deep_dtaf.py @@ -0,0 +1,190 @@ +import torch +import torch.nn as nn + +PT_FEATURE_SIZE = 40 + + +class DeepDTAF(nn.Module): + def __init__(self, smi_charset_len): + super().__init__() + + smi_embed_size = 128 + seq_embed_size = 128 + + seq_oc = 128 + pkt_oc = 128 + smi_oc = 128 + + self.smi_embed = nn.Embedding(smi_charset_len, smi_embed_size) + + self.seq_embed = nn.Linear(PT_FEATURE_SIZE, seq_embed_size) # (N, *, H_{in}) -> (N, *, H_{out}) + + conv_seq = [] + ic = seq_embed_size + for oc in [32, 64, 64, seq_oc]: + conv_seq.append(DilatedParllelResidualBlockA(ic, oc)) + ic = oc + conv_seq.append(nn.AdaptiveMaxPool1d(1)) # (N, oc) + conv_seq.append(Squeeze()) + self.conv_seq = nn.Sequential(*conv_seq) + + # (N, H=32, L) + conv_pkt = [] + ic = seq_embed_size + for oc in [32, 64, pkt_oc]: + conv_pkt.append(nn.Conv1d(ic, oc, 3)) # (N,C,L) + conv_pkt.append(nn.BatchNorm1d(oc)) + conv_pkt.append(nn.PReLU()) + ic = oc + conv_pkt.append(nn.AdaptiveMaxPool1d(1)) + conv_pkt.append(Squeeze()) + self.conv_pkt = nn.Sequential(*conv_pkt) # (N,oc) + + conv_smi = [] + ic = smi_embed_size + for oc in [32, 64, smi_oc]: + conv_smi.append(DilatedParllelResidualBlockB(ic, oc)) + ic = oc + conv_smi.append(nn.AdaptiveMaxPool1d(1)) + conv_smi.append(Squeeze()) + self.conv_smi = nn.Sequential(*conv_smi) # (N,128) + + self.cat_dropout = nn.Dropout(0.2) + + self.classifier = nn.Sequential( + nn.Linear(seq_oc + pkt_oc + smi_oc, 128), + nn.Dropout(0.5), + nn.PReLU(), + nn.Linear(128, 64), + nn.Dropout(0.5), + nn.PReLU(), + # nn.Linear(64, 1), + # nn.PReLU() + ) + + def forward(self, seq, pkt, smi): + # assert seq.shape == (N,L,43) + seq_embed = self.seq_embed(seq) # (N,L,32) + seq_embed = torch.transpose(seq_embed, 1, 2) # (N,32,L) + seq_conv = self.conv_seq(seq_embed) # (N,128) + + # assert pkt.shape == (N,L,43) + pkt_embed = self.seq_embed(pkt) # (N,L,32) + pkt_embed = torch.transpose(pkt_embed, 1, 2) + pkt_conv = self.conv_pkt(pkt_embed) # (N,128) + + # assert smi.shape == (N, L) + smi_embed = self.smi_embed(smi) # (N,L,32) + smi_embed = torch.transpose(smi_embed, 1, 2) + smi_conv = self.conv_smi(smi_embed) # (N,128) + + cat = torch.cat([seq_conv, pkt_conv, smi_conv], dim=1) # (N,128*3) + cat = self.cat_dropout(cat) + + output = self.classifier(cat) + return output + + +class Squeeze(nn.Module): + def forward(self, input: torch.Tensor): + return input.squeeze() + + +class CDilated(nn.Module): + def __init__(self, nIn, nOut, kSize, stride=1, d=1): + super().__init__() + padding = int((kSize - 1) / 2) * d + self.conv = nn.Conv1d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False, dilation=d) + + def forward(self, input): + output = self.conv(input) + return output + + +class DilatedParllelResidualBlockA(nn.Module): + def __init__(self, nIn, nOut, add=True): + super().__init__() + n = int(nOut / 5) + n1 = nOut - 4 * n + self.c1 = nn.Conv1d(nIn, n, 1, 
padding=0)
+        self.br1 = nn.Sequential(nn.BatchNorm1d(n), nn.PReLU())
+        self.d1 = CDilated(n, n1, 3, 1, 1)  # dilation rate of 2^0
+        self.d2 = CDilated(n, n, 3, 1, 2)  # dilation rate of 2^1
+        self.d4 = CDilated(n, n, 3, 1, 4)  # dilation rate of 2^2
+        self.d8 = CDilated(n, n, 3, 1, 8)  # dilation rate of 2^3
+        self.d16 = CDilated(n, n, 3, 1, 16)  # dilation rate of 2^4
+        self.br2 = nn.Sequential(nn.BatchNorm1d(nOut), nn.PReLU())
+
+        if nIn != nOut:
+            add = False
+        self.add = add
+
+    def forward(self, input):
+        # reduce
+        output1 = self.c1(input)
+        output1 = self.br1(output1)
+        # split and transform
+        d1 = self.d1(output1)
+        d2 = self.d2(output1)
+        d4 = self.d4(output1)
+        d8 = self.d8(output1)
+        d16 = self.d16(output1)
+
+        # hierarchical fusion for de-gridding
+        add1 = d2
+        add2 = add1 + d4
+        add3 = add2 + d8
+        add4 = add3 + d16
+
+        # merge
+        combine = torch.cat([d1, add1, add2, add3, add4], 1)
+
+        # if residual version
+        if self.add:
+            combine = input + combine
+        output = self.br2(combine)
+        return output
+
+
+class DilatedParllelResidualBlockB(nn.Module):
+    def __init__(self, nIn, nOut, add=True):
+        super().__init__()
+        n = int(nOut / 4)
+        n1 = nOut - 3 * n
+        self.c1 = nn.Conv1d(nIn, n, 1, padding=0)
+        self.br1 = nn.Sequential(nn.BatchNorm1d(n), nn.PReLU())
+        self.d1 = CDilated(n, n1, 3, 1, 1)  # dilation rate of 2^0
+        self.d2 = CDilated(n, n, 3, 1, 2)  # dilation rate of 2^1
+        self.d4 = CDilated(n, n, 3, 1, 4)  # dilation rate of 2^2
+        self.d8 = CDilated(n, n, 3, 1, 8)  # dilation rate of 2^3
+        self.br2 = nn.Sequential(nn.BatchNorm1d(nOut), nn.PReLU())
+
+        if nIn != nOut:
+            add = False
+        self.add = add
+
+    def forward(self, input):
+        # reduce
+        output1 = self.c1(input)
+        output1 = self.br1(output1)
+        # split and transform
+        d1 = self.d1(output1)
+        d2 = self.d2(output1)
+        d4 = self.d4(output1)
+        d8 = self.d8(output1)
+
+        # hierarchical fusion for de-gridding
+        add1 = d2
+        add2 = add1 + d4
+        add3 = add2 + d8
+
+        # merge
+        combine = torch.cat([d1, add1, add2, add3], 1)
+
+        # if residual version
+        if self.add:
+            combine = input + combine
+        output = self.br2(combine)
+        return output
diff --git a/deepscreen/models/predictors/deep_scams.py b/deepscreen/models/predictors/deep_scams.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba7f1f597a0b453d205f3d59b129a9d0962fca5c
--- /dev/null
+++ b/deepscreen/models/predictors/deep_scams.py
@@ -0,0 +1,38 @@
+import numpy as np
+from torch import nn
+from rdkit.Chem import Descriptors, AllChem, MolFromSmiles
+
+from deepscreen.models.components.mlp import LazyMLP
+from deepscreen.utils import get_logger
+
+log = get_logger(__name__)
+
+DeepSCAMs = LazyMLP(
+    out_channels=1,
+    hidden_channels=[100, 1000, 1000],
+    activation=nn.Tanh,
+    dropout=0.0
+)
+
+
+def featurizer(smiles, radius=2, n_bits=1024):
+    descr = Descriptors._descList[0:2] + Descriptors._descList[3:]
+    calc = [x[1] for x in descr]
+    try:
+        mol = MolFromSmiles(smiles)
+        features = []
+        fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=radius, nBits=n_bits)
+        fp_list = []
+        fp_list.extend(fp.ToBitString())
+        fp_expl = [float(x) for x in fp_list]
+        ds_n = []
+        for d in calc:
+            v = d(mol)
+            if v > np.finfo(np.float32).max:
+                ds_n.append(np.finfo(np.float32).max)
+            else:
+                ds_n.append(np.float32(v))
+
+        features += [fp_expl + list(ds_n)]
+    except Exception as e:
+        log.warning(f'RDKit could not process SMILES {smiles} due to {str(e)}; converted to all-zero features')
+        features = np.zeros((n_bits,))
+
+    return features
diff --git a/deepscreen/models/predictors/electra_dta.py b/deepscreen/models/predictors/electra_dta.py
diff --git a/deepscreen/models/predictors/electra_dta.py b/deepscreen/models/predictors/electra_dta.py
new file mode 100644
index 0000000000000000000000000000000000000000..b06fe6ab33516a1d866cf74f8ce2c52a88c00991
--- /dev/null
+++ b/deepscreen/models/predictors/electra_dta.py
@@ -0,0 +1,109 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ElectraDTA(nn.Module):
+    def __init__(self, smilen, seq_len, hidden_dim):
+        super().__init__()
+        self.drug_input = nn.Linear(smilen, hidden_dim)
+        self.prot_input = nn.Linear(seq_len, hidden_dim)
+        self.num_filters = hidden_dim
+        self.filter_length = 3
+        self.seblock = True
+        self.encode_smiles = ConvBlock(hidden_dim, self.seblock, self.num_filters, self.filter_length)
+        self.encode_prot = ConvBlock(hidden_dim, self.seblock, self.num_filters, self.filter_length)
+        self.global_pool = nn.AdaptiveMaxPool1d(1)
+        self.concat = nn.Linear(hidden_dim * 4, hidden_dim * 4)
+        self.predictions = FCNet(hidden_dim * 4)
+
+    def forward(self, enc_drug, enc_protein):
+        # NOTE: rewired from a placeholder nn.Sequential that chained unrelated branches; assumes
+        # enc_drug is (N, L_drug, smilen) and enc_protein is (N, L_prot, seq_len) one-hot/embedded input
+        drug = self.drug_input(enc_drug).transpose(1, 2)      # (N, hidden_dim, L_drug)
+        prot = self.prot_input(enc_protein).transpose(1, 2)   # (N, hidden_dim, L_prot)
+        drug = self.global_pool(self.encode_smiles(drug)).squeeze(-1)  # (N, hidden_dim * 2)
+        prot = self.global_pool(self.encode_prot(prot)).squeeze(-1)    # (N, hidden_dim * 2)
+        x = torch.cat([drug, prot], dim=1)                    # (N, hidden_dim * 4)
+        x = F.relu(self.concat(x))
+        return self.predictions(x)
+
+
+class SEBlock(nn.Module):
+    def __init__(self, channels, r=8):
+        super().__init__()
+        self.squeeze = nn.AdaptiveAvgPool1d(1)
+        self.excitation = nn.Sequential(
+            nn.Linear(channels, channels // r),
+            nn.ReLU(),
+            nn.Linear(channels // r, channels),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        # squeeze (N, C, L) -> (N, C), excite to channel gates in (0, 1), then rescale the input
+        out = self.squeeze(x).squeeze(-1)
+        out = self.excitation(out)
+        return x * out.unsqueeze(-1)
+
+
+class ConvBlock(nn.Module):
+    def __init__(self, input_channels, seblock, num_filters, filter_length):
+        super().__init__()
+        self.conv1 = nn.Conv1d(input_channels, num_filters, filter_length, padding='valid', stride=1)
+        self.conv2 = nn.Conv1d(num_filters, num_filters * 2, filter_length, padding='valid', stride=1)
+        self.seblock = seblock
+        if seblock:
+            self.se1 = SEBlock(num_filters)
+            self.se2 = SEBlock(num_filters * 2)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        if self.seblock:
+            x = self.se1(x)
+        x = F.relu(self.conv2(x))
+        if self.seblock:
+            x = self.se2(x)
+        return x
+
+
+class Highway(nn.Module):
+    def __init__(self, dim, n_layers, activation=nn.Tanh(), gate_bias=0):
+        super(Highway, self).__init__()
+        self.n_layers = n_layers
+        self.activation = activation
+        self.T_gates = nn.ModuleList([nn.Linear(dim, dim) for _ in range(n_layers)])
+        self.transforms = nn.ModuleList([nn.Linear(dim, dim) for _ in range(n_layers)])
+        self.sigmoid = nn.Sigmoid()
+        for gate in self.T_gates:
+            nn.init.constant_(gate.bias, gate_bias)
+
+    def forward(self, x):
+        for i in range(self.n_layers):
+            T = self.sigmoid(self.T_gates[i](x))
+            C = 1 - T
+            H = self.activation(self.transforms[i](x))
+            x = T * H + C * x
+        return x
+
+
+class FCNet(nn.Module):
+    def __init__(self, input_dim):
+        super().__init__()
+        self.n_layers = 4
+        self.highway = Highway(input_dim, self.n_layers, gate_bias=-2)
+        self.fc1 = nn.Linear(input_dim, 1024)
+        self.fc2 = nn.Linear(1024, 1024)
+        self.fc3 = nn.Linear(1024, 512)
+        # self.fc4 = nn.Linear(512, 1)
+        self.dropout = nn.Dropout(0.4)
+
+    def forward(self, x):
+        x = self.highway(x)
+        x = F.relu(self.fc1(x))
+        x = self.dropout(x)
+        x = F.relu(self.fc2(x))
+        x = self.dropout(x)
+        x = F.relu(self.fc3(x))
+        # x = self.fc4(x)
+        return x
diff --git a/deepscreen/models/predictors/graph_dta.py b/deepscreen/models/predictors/graph_dta.py
index 7e6151279b07a661d04198e55aeca7c93d9743e1..9932c2415842607cf40d17d37803ab4870b406d4 100644
---
a/deepscreen/models/predictors/graph_dta.py +++ b/deepscreen/models/predictors/graph_dta.py @@ -22,7 +22,7 @@ class GraphDTA(LightningModule): # protein sequence encoder (1d conv) self.embedding_xt = nn.Embedding(num_features_protein, embed_dim) - self.conv_xt = nn.LazyConv1d(in_channels=1000, out_channels=n_filters, kernel_size=8) + self.conv_xt = nn.LazyConv1d(out_channels=n_filters, kernel_size=8) self.fc1_xt = nn.Linear(32 * 121, output_dim) # combined layers diff --git a/deepscreen/models/predictors/monn.py b/deepscreen/models/predictors/monn.py new file mode 100644 index 0000000000000000000000000000000000000000..491af435c89ba8ae3f099a57f30d7aa7108d8ee7 --- /dev/null +++ b/deepscreen/models/predictors/monn.py @@ -0,0 +1,283 @@ +import torch +from torch import nn +import torch.nn.functional as F + +# some predefined parameters +elem_list = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', + 'Tl', 'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', + 'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb', 'W', 'Ru', 'Nb', 'Re', 'Te', 'Rh', 'Tc', 'Ba', 'Bi', 'Hf', 'Mo', 'U', + 'Sm', 'Os', 'Ir', 'Ce', 'Gd', 'Ga', 'Cs', 'unknown'] +atom_fdim = len(elem_list) + 6 + 6 + 6 + 1 +bond_fdim = 6 +max_nb = 6 + + +class MONN(nn.Module): + # init_A, init_B, init_W = loading_emb(measure) + # net = Net(init_A, init_B, init_W, params) + def __init__(self, init_atom_features, init_bond_features, init_word_features, params): + super().__init__() + self.init_atom_features = init_atom_features + self.init_bond_features = init_bond_features + self.init_word_features = init_word_features + + """hyper part""" + GNN_depth, inner_CNN_depth, DMA_depth, k_head, kernel_size, hidden_size1, hidden_size2 = params + self.GNN_depth = GNN_depth + self.inner_CNN_depth = inner_CNN_depth + self.DMA_depth = DMA_depth + self.k_head = k_head + self.kernel_size = kernel_size + self.hidden_size1 = hidden_size1 + self.hidden_size2 = hidden_size2 + + """GraphConv Module""" + self.vertex_embedding = nn.Linear(atom_fdim, + self.hidden_size1) # first transform vertex features into hidden representations + + # GWM parameters + self.W_a_main = nn.ModuleList( + [nn.ModuleList([nn.Linear(self.hidden_size1, self.hidden_size1) for i in range(self.k_head)]) for i in + range(self.GNN_depth)]) + self.W_a_super = nn.ModuleList( + [nn.ModuleList([nn.Linear(self.hidden_size1, self.hidden_size1) for i in range(self.k_head)]) for i in + range(self.GNN_depth)]) + self.W_main = nn.ModuleList( + [nn.ModuleList([nn.Linear(self.hidden_size1, self.hidden_size1) for i in range(self.k_head)]) for i in + range(self.GNN_depth)]) + self.W_bmm = nn.ModuleList( + [nn.ModuleList([nn.Linear(self.hidden_size1, 1) for i in range(self.k_head)]) for i in + range(self.GNN_depth)]) + + self.W_super = nn.ModuleList([nn.Linear(self.hidden_size1, self.hidden_size1) for i in range(self.GNN_depth)]) + self.W_main_to_super = nn.ModuleList( + [nn.Linear(self.hidden_size1 * self.k_head, self.hidden_size1) for i in range(self.GNN_depth)]) + self.W_super_to_main = nn.ModuleList( + [nn.Linear(self.hidden_size1, self.hidden_size1) for i in range(self.GNN_depth)]) + + self.W_zm1 = nn.ModuleList([nn.Linear(self.hidden_size1, self.hidden_size1) for i in range(self.GNN_depth)]) + self.W_zm2 = nn.ModuleList([nn.Linear(self.hidden_size1, self.hidden_size1) for i in range(self.GNN_depth)]) + self.W_zs1 = nn.ModuleList([nn.Linear(self.hidden_size1, self.hidden_size1) for i in range(self.GNN_depth)]) + 
self.W_zs2 = nn.ModuleList([nn.Linear(self.hidden_size1, self.hidden_size1) for i in range(self.GNN_depth)])
+        self.GRU_main = nn.GRUCell(self.hidden_size1, self.hidden_size1)
+        self.GRU_super = nn.GRUCell(self.hidden_size1, self.hidden_size1)
+
+        # WLN parameters
+        self.label_U2 = nn.ModuleList([nn.Linear(self.hidden_size1 + bond_fdim, self.hidden_size1) for i in
+                                       range(self.GNN_depth)])  # assume no edge feature transformation
+        self.label_U1 = nn.ModuleList(
+            [nn.Linear(self.hidden_size1 * 2, self.hidden_size1) for i in range(self.GNN_depth)])
+
+        """CNN-RNN Module"""
+        # CNN parameters
+        self.embed_seq = nn.Embedding(len(self.init_word_features), 20, padding_idx=0)
+        self.embed_seq.weight = nn.Parameter(self.init_word_features)
+        self.embed_seq.weight.requires_grad = False
+
+        # padding must be an int, so use floor division (kernel_size is assumed to be odd)
+        self.conv_first = nn.Conv1d(20, self.hidden_size1, kernel_size=self.kernel_size,
+                                    padding=(self.kernel_size - 1) // 2)
+        self.conv_last = nn.Conv1d(self.hidden_size1, self.hidden_size1, kernel_size=self.kernel_size,
+                                   padding=(self.kernel_size - 1) // 2)
+
+        self.plain_CNN = nn.ModuleList([])
+        for i in range(self.inner_CNN_depth):
+            self.plain_CNN.append(nn.Conv1d(self.hidden_size1, self.hidden_size1, kernel_size=self.kernel_size,
+                                            padding=(self.kernel_size - 1) // 2))
+
+        """Affinity Prediction Module"""
+        self.super_final = nn.Linear(self.hidden_size1, self.hidden_size2)
+        self.c_final = nn.Linear(self.hidden_size1, self.hidden_size2)
+        self.p_final = nn.Linear(self.hidden_size1, self.hidden_size2)
+
+        # DMA parameters
+        self.mc0 = nn.Linear(hidden_size2, hidden_size2)
+        self.mp0 = nn.Linear(hidden_size2, hidden_size2)
+
+        self.mc1 = nn.ModuleList([nn.Linear(self.hidden_size2, self.hidden_size2) for i in range(self.DMA_depth)])
+        self.mp1 = nn.ModuleList([nn.Linear(self.hidden_size2, self.hidden_size2) for i in range(self.DMA_depth)])
+
+        self.hc0 = nn.ModuleList([nn.Linear(self.hidden_size2, self.hidden_size2) for i in range(self.DMA_depth)])
+        self.hp0 = nn.ModuleList([nn.Linear(self.hidden_size2, self.hidden_size2) for i in range(self.DMA_depth)])
+        self.hc1 = nn.ModuleList([nn.Linear(self.hidden_size2, 1) for i in range(self.DMA_depth)])
+        self.hp1 = nn.ModuleList([nn.Linear(self.hidden_size2, 1) for i in range(self.DMA_depth)])
+
+        self.c_to_p_transform = nn.ModuleList(
+            [nn.Linear(self.hidden_size2, self.hidden_size2) for i in range(self.DMA_depth)])
+        self.p_to_c_transform = nn.ModuleList(
+            [nn.Linear(self.hidden_size2, self.hidden_size2) for i in range(self.DMA_depth)])
+
+        self.GRU_dma = nn.GRUCell(self.hidden_size2, self.hidden_size2)
+        # Output layer
+        self.W_out = nn.Linear(self.hidden_size2 * self.hidden_size2 * 2, 1)
+
+        """Pairwise Interaction Prediction Module"""
+        self.pairwise_compound = nn.Linear(self.hidden_size1, self.hidden_size1)
+        self.pairwise_protein = nn.Linear(self.hidden_size1, self.hidden_size1)
+
+    def mask_softmax(self, a, mask, dim=-1):
+        a_max = torch.max(a, dim, keepdim=True)[0]
+        a_exp = torch.exp(a - a_max)
+        a_exp = a_exp * mask
+        a_softmax = a_exp / (torch.sum(a_exp, dim, keepdim=True) + 1e-6)
+        return a_softmax
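A quick standalone illustration of the masked softmax above, which all of MONN's attention steps rely on (tensor values are illustrative):

    import torch

    scores = torch.tensor([[1.0, 2.0, 3.0, 0.0]])   # last position is padding
    mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])     # 1 = real token, 0 = padding

    a_max = scores.max(-1, keepdim=True)[0]          # subtract the max for numerical stability
    a_exp = torch.exp(scores - a_max) * mask         # zero out padded positions
    attn = a_exp / (a_exp.sum(-1, keepdim=True) + 1e-6)
    # attn sums to ~1 over the three real positions; the padded position gets ~0 weight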
+    def graph_conv_module(self, batch_size, vertex_mask, vertex, edge, atom_adj, bond_adj, nbs_mask):
+        n_vertex = vertex_mask.size(1)
+
+        # initial features
+        vertex_initial = torch.index_select(self.init_atom_features, 0, vertex.view(-1))
+        vertex_initial = vertex_initial.view(batch_size, -1, atom_fdim)
+        edge_initial = torch.index_select(self.init_bond_features, 0, edge.view(-1))
+        edge_initial = edge_initial.view(batch_size, -1, bond_fdim)
+
+        vertex_feature = F.leaky_relu(self.vertex_embedding(vertex_initial), 0.1)
+        super_feature = torch.sum(vertex_feature * vertex_mask.view(batch_size, -1, 1), dim=1, keepdim=True)
+
+        for GWM_iter in range(self.GNN_depth):
+            # prepare main node features
+            for k in range(self.k_head):
+                a_main = torch.tanh(self.W_a_main[GWM_iter][k](vertex_feature))
+                a_super = torch.tanh(self.W_a_super[GWM_iter][k](super_feature))
+                a = self.W_bmm[GWM_iter][k](a_main * super_feature)
+                attn = self.mask_softmax(a.view(batch_size, -1), vertex_mask).view(batch_size, -1, 1)
+                k_main_to_super = torch.bmm(attn.transpose(1, 2), self.W_main[GWM_iter][k](vertex_feature))
+                if k == 0:
+                    m_main_to_super = k_main_to_super
+                else:
+                    m_main_to_super = torch.cat([m_main_to_super, k_main_to_super], dim=-1)  # concat k-head
+            main_to_super = torch.tanh(self.W_main_to_super[GWM_iter](m_main_to_super))
+            main_self = self.wln_unit(batch_size, vertex_mask, vertex_feature, edge_initial, atom_adj, bond_adj,
+                                      nbs_mask, GWM_iter)
+
+            super_to_main = torch.tanh(self.W_super_to_main[GWM_iter](super_feature))
+            super_self = torch.tanh(self.W_super[GWM_iter](super_feature))
+            # warp gate and GRU for updating the main node features, using main_self and super_to_main
+            z_main = torch.sigmoid(self.W_zm1[GWM_iter](main_self) + self.W_zm2[GWM_iter](super_to_main))
+            hidden_main = (1 - z_main) * main_self + z_main * super_to_main
+            vertex_feature = self.GRU_main(hidden_main.view(-1, self.hidden_size1),
+                                           vertex_feature.view(-1, self.hidden_size1))
+            vertex_feature = vertex_feature.view(batch_size, n_vertex, self.hidden_size1)
+            # warp gate and GRU for updating the super node features
+            z_super = torch.sigmoid(self.W_zs1[GWM_iter](super_self) + self.W_zs2[GWM_iter](main_to_super))
+            hidden_super = (1 - z_super) * super_self + z_super * main_to_super
+            super_feature = self.GRU_super(hidden_super.view(batch_size, self.hidden_size1),
+                                           super_feature.view(batch_size, self.hidden_size1))
+            super_feature = super_feature.view(batch_size, 1, self.hidden_size1)
+
+        return vertex_feature, super_feature
+
+    def wln_unit(self, batch_size, vertex_mask, vertex_features, edge_initial, atom_adj, bond_adj, nbs_mask, GNN_iter):
+        n_vertex = vertex_mask.size(1)
+        n_nbs = nbs_mask.size(2)
+
+        vertex_mask = vertex_mask.view(batch_size, n_vertex, 1)
+        nbs_mask = nbs_mask.view(batch_size, n_vertex, n_nbs, 1)
+
+        vertex_nei = torch.index_select(vertex_features.view(-1, self.hidden_size1), 0, atom_adj).view(
+            batch_size, n_vertex, n_nbs, self.hidden_size1)
+        edge_nei = torch.index_select(edge_initial.view(-1, bond_fdim), 0, bond_adj).view(
+            batch_size, n_vertex, n_nbs, bond_fdim)
+
+        # Weisfeiler-Lehman relabelling
+        l_nei = torch.cat((vertex_nei, edge_nei), -1)
+        nei_label = F.leaky_relu(self.label_U2[GNN_iter](l_nei), 0.1)
+        nei_label = torch.sum(nei_label * nbs_mask, dim=-2)
+        new_label = torch.cat((vertex_features, nei_label), 2)
+        new_label = self.label_U1[GNN_iter](new_label)
+        vertex_features = F.leaky_relu(new_label, 0.1)
+
+        return vertex_features
+
+    def cnn_module(self, sequence):
+        ebd = self.embed_seq(sequence)
+        ebd = ebd.transpose(1, 2)
+        x = F.leaky_relu(self.conv_first(ebd), 0.1)
+
+        for i in range(self.inner_CNN_depth):
+            x = self.plain_CNN[i](x)
+            x = F.leaky_relu(x, 0.1)
+
+        x = F.leaky_relu(self.conv_last(x), 0.1)
+        H = x.transpose(1, 2)
+        # H, hidden = self.rnn(H)
+
+        return H
+    def pairwise_pred_module(self, batch_size, comp_feature, prot_feature, vertex_mask, seq_mask):
+        pairwise_c_feature = F.leaky_relu(self.pairwise_compound(comp_feature), 0.1)
+        pairwise_p_feature = F.leaky_relu(self.pairwise_protein(prot_feature), 0.1)
+        pairwise_pred = torch.matmul(pairwise_c_feature, pairwise_p_feature.transpose(1, 2))
+        # TODO: difference between the pairwise_mask here and in the data?
+        pairwise_mask = torch.matmul(vertex_mask.view(batch_size, -1, 1), seq_mask.view(batch_size, 1, -1))
+        pairwise_pred = pairwise_pred * pairwise_mask
+        return pairwise_pred
+
+    def affinity_pred_module(self, batch_size, comp_feature, prot_feature, super_feature, vertex_mask, seq_mask,
+                             pairwise_pred):
+        comp_feature = F.leaky_relu(self.c_final(comp_feature), 0.1)
+        prot_feature = F.leaky_relu(self.p_final(prot_feature), 0.1)
+        super_feature = F.leaky_relu(self.super_final(super_feature.view(batch_size, -1)), 0.1)
+
+        cf, pf = self.dma_gru(batch_size, comp_feature, vertex_mask, prot_feature, seq_mask, pairwise_pred)
+
+        cf = torch.cat([cf.view(batch_size, -1), super_feature.view(batch_size, -1)], dim=1)
+        kroneck = F.leaky_relu(
+            torch.matmul(cf.view(batch_size, -1, 1), pf.view(batch_size, 1, -1)).view(batch_size, -1), 0.1)
+
+        affinity_pred = self.W_out(kroneck)
+        return affinity_pred
+
+    def dma_gru(self, batch_size, comp_feats, vertex_mask, prot_feats, seq_mask, pairwise_pred):
+        vertex_mask = vertex_mask.view(batch_size, -1, 1)
+        seq_mask = seq_mask.view(batch_size, -1, 1)
+
+        cf = torch.Tensor()
+        pf = torch.Tensor()
+
+        c0 = torch.sum(comp_feats * vertex_mask, dim=1) / torch.sum(vertex_mask, dim=1)
+        p0 = torch.sum(prot_feats * seq_mask, dim=1) / torch.sum(seq_mask, dim=1)
+
+        m = c0 * p0
+        for DMA_iter in range(self.DMA_depth):
+            c_to_p = torch.matmul(pairwise_pred.transpose(1, 2),
+                                  torch.tanh(self.c_to_p_transform[DMA_iter](comp_feats)))  # batch * n_residue * hidden
+            p_to_c = torch.matmul(pairwise_pred,
+                                  torch.tanh(self.p_to_c_transform[DMA_iter](prot_feats)))  # batch * n_vertex * hidden
+
+            c_tmp = torch.tanh(self.hc0[DMA_iter](comp_feats)) * torch.tanh(self.mc1[DMA_iter](m)).view(
+                batch_size, 1, -1) * p_to_c
+            p_tmp = torch.tanh(self.hp0[DMA_iter](prot_feats)) * torch.tanh(self.mp1[DMA_iter](m)).view(
+                batch_size, 1, -1) * c_to_p
+
+            c_att = self.mask_softmax(self.hc1[DMA_iter](c_tmp).view(batch_size, -1), vertex_mask.view(batch_size, -1))
+            p_att = self.mask_softmax(self.hp1[DMA_iter](p_tmp).view(batch_size, -1), seq_mask.view(batch_size, -1))
+
+            cf = torch.sum(comp_feats * c_att.view(batch_size, -1, 1), dim=1)
+            pf = torch.sum(prot_feats * p_att.view(batch_size, -1, 1), dim=1)
+
+            m = self.GRU_dma(m, cf * pf)
+
+        return cf, pf
+
+    def forward(self, enc_drug, enc_protein):
+        # each drug-side component arrives as a (padded_tensor, mask) pair from the collator;
+        # the leading vertex_mask is superseded by the mask unpacked from `vertex`
+        vertex_mask, vertex, edge, atom_adj, bond_adj, nbs_mask = enc_drug
+        vertex, vertex_mask = vertex
+        edge, _ = edge
+        atom_adj, _ = atom_adj
+        bond_adj, _ = bond_adj
+        nbs_mask, _ = nbs_mask
+
+        seq_mask, sequence = enc_protein
+
+        batch_size = vertex.size(0)
+
+        atom_feature, super_feature = self.graph_conv_module(batch_size, vertex_mask, vertex, edge, atom_adj, bond_adj,
+                                                             nbs_mask)
+        prot_feature = self.cnn_module(sequence)
+
+        pairwise_pred = self.pairwise_pred_module(batch_size, atom_feature, prot_feature, vertex_mask, seq_mask)
+        affinity_pred = self.affinity_pred_module(batch_size, atom_feature, prot_feature, super_feature, vertex_mask,
+                                                  seq_mask, pairwise_pred)
+
+        return affinity_pred  # , pairwise_pred
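For clarity, a purely hypothetical sketch of the drug-side input layout that the forward pass above assumes (names and shapes are illustrative, not the repo's actual collator output):

    import torch

    n_vertex, n_nbs = 30, 6                                      # illustrative sizes
    vertex = torch.zeros(2, n_vertex, dtype=torch.long)          # padded atom feature indices
    vertex_mask = torch.ones(2, n_vertex)                        # 1 = real atom, 0 = padding
    edge = torch.zeros(2, n_vertex * n_nbs, dtype=torch.long)    # padded bond feature indices
    atom_adj = torch.zeros(2 * n_vertex * n_nbs, dtype=torch.long)
    bond_adj = torch.zeros(2 * n_vertex * n_nbs, dtype=torch.long)
    nbs_mask = torch.ones(2, n_vertex, n_nbs)

    enc_drug = (
        vertex_mask,              # superseded by the mask paired with `vertex` below
        (vertex, vertex_mask),
        (edge, None),
        (atom_adj, None),
        (bond_adj, None),
        (nbs_mask, None),
    )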
diff --git a/deepscreen/models/predictors/multi_entity.py b/deepscreen/models/predictors/multi_entity.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ced64056001bc776aeb836036b7c72ef5546ed0
--- /dev/null
+++ b/deepscreen/models/predictors/multi_entity.py
@@ -0,0 +1,48 @@
+from itertools import zip_longest
+from typing import Sequence, Dict, Union
+
+from torch import nn
+
+
+class MultiEntityInteraction(nn.Module):
+    def __init__(
+        self,
+        encoders: Union[nn.Module, Sequence[nn.Module], Dict[str, nn.Module]],
+        decoders: Union[nn.Module, Sequence[nn.Module], Dict[str, nn.Module]],
+    ):
+        super().__init__()
+        # Validate the encoders/decoders and register them as submodules so their
+        # parameters are visible to the optimizer.
+        self.encoders = self._validate_and_register(encoders, 'encoders')
+        self.decoders = self._validate_and_register(decoders, 'decoders')
+
+    @staticmethod
+    def _validate_and_register(modules, name):
+        if isinstance(modules, nn.Module):
+            return nn.ModuleList([modules])
+        elif isinstance(modules, Sequence):
+            for i, module in enumerate(modules):
+                if not isinstance(module, nn.Module):
+                    raise ValueError(
+                        f"Value {module} at index {i} of `{name}` is not an instance of `nn.Module`."
+                    )
+            return nn.ModuleList(modules)
+        elif isinstance(modules, dict):
+            for k, module in modules.items():
+                if not isinstance(module, nn.Module):
+                    raise ValueError(
+                        f"Value {module} at key {k} of `{name}` is not an instance of `nn.Module`."
+                    )
+            return nn.ModuleDict(modules)
+        else:
+            raise ValueError(
+                f"Unknown input to MultiEntityInteraction: expected `nn.Module` or a `dict`/`sequence`"
+                f" of `nn.Module`, but got {modules}"
+            )
+
+    def forward(self, inputs):
+        # Encode each entity with its matching encoder; zip_longest surfaces a length mismatch
+        # between inputs and encoders as an error instead of silently dropping entities.
+        embeddings = [encoder(x) for encoder, x in zip_longest(self.encoders, inputs)]
+        # TODO: route the embeddings through self.decoders once the decoder interface is settled.
+        return embeddings
diff --git a/deepscreen/models/predictors/transformer_cpi.py b/deepscreen/models/predictors/transformer_cpi.py
index e4fbeacfd643004342a1a1e6af23f5a685af12f6..c2e715ac90787834df859019cdbf1cf0f21d9f0d 100644
--- a/deepscreen/models/predictors/transformer_cpi.py
+++ b/deepscreen/models/predictors/transformer_cpi.py
@@ -266,3 +266,4 @@ class Decoder(nn.Module):
         label = F.relu(self.fc_1(sum))
         # label = self.fc_2(label)
         return label
+
diff --git a/deepscreen/models/predictors/transformer_cpi_2.py b/deepscreen/models/predictors/transformer_cpi_2.py
index fd4f0f15483a3ee65f5e444eb870ca6f1f0dd198..223483ba7a6f6a302b6f2a8779b314f784dbe23b 100644
--- a/deepscreen/models/predictors/transformer_cpi_2.py
+++ b/deepscreen/models/predictors/transformer_cpi_2.py
@@ -23,8 +23,9 @@ class TransformerCPI2(nn.Module):
         # adj_mat = [batch_size, atom_num, atom_num]
         # enc_protein = [batch_size, protein_len, 768]
         compound, adj = compound
-        adj, _ = adj
+        compound, compound_lengths = compound
+        adj, _ = adj
         protein, protein_lengths = protein
 
         # Add a global/master node to the compound
@@ -85,7 +86,7 @@ class Decoder(nn.Module):
         )
         self.decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=self.n_layers)
         self.fc_1 = nn.Linear(768, 256)
-        self.fc_2 = nn.Linear(256, 2)
+        # self.fc_2 = nn.Linear(256, 2)
         self.dropout = nn.Dropout(dropout)
 
     def forward(self, tgt, src, tgt_mask=None, src_mask=None):
diff --git a/deepscreen/models/predictors/transformer_cpi_2.py.bak b/deepscreen/models/predictors/transformer_cpi_2.py.bak
new file mode 100644
index 0000000000000000000000000000000000000000..05026247f968b8034e8ae481418c8809e5045425
--- /dev/null
+++ b/deepscreen/models/predictors/transformer_cpi_2.py.bak
@@ -0,0 +1,168 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Encoder(nn.Module):
+    """protein feature extraction"""
+
+    def __init__(self, pretrain, n_layers):
+        super().__init__()
+        self.pretrain = pretrain
+        self.hid_dim = 768
+        self.n_layers = n_layers
+        self.encoder_layer = nn.TransformerEncoderLayer(
+            d_model=self.hid_dim, nhead=8, dim_feedforward=self.hid_dim * 4, dropout=0.1
+        )
+
self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=self.n_layers) + + def forward(self, protein, mask): + # protein = [batch_size, protein_len] + # mask = [batch_size, protein_len] 0 for true positions, 1 for mask positions + with torch.no_grad(): + protein = self.pretrain(protein.long(), mask.long())[0] + protein = protein.permute(1, 0, 2).contiguous() # protein = [protein_len, batch_size, 768] + protein = self.encoder(protein, src_key_padding_mask=mask) # protein = [protein_len, batch_size, 768] + return protein, mask + + +class Decoder(nn.Module): + """compound feature extraction""" + + def __init__(self, n_layers, dropout): + super().__init__() + self.hid_dim = 768 + self.n_layers = n_layers + self.decoder_layer = nn.TransformerDecoderLayer( + d_model=self.hid_dim, nhead=8, dim_feedforward=self.hid_dim * 4, dropout=0.1 + ) + self.decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=self.n_layers) + self.fc_1 = nn.Linear(768, 256) + self.fc_2 = nn.Linear(256, 2) + self.dropout = nn.Dropout(dropout) + + def forward(self, tgt, src, tgt_mask=None, src_mask=None): + # tgt = [batch_size, compound len, hid_dim] + # src = [protein_len, batch_size, hid_dim] # encoder output + tgt = tgt.permute(1, 0, 2).contiguous() # tgt = [compound_len, batch_size, hid_dim] + # tgt_mask = tgt_mask == 1 + tgt = self.decoder( + tgt, src, tgt_key_padding_mask=tgt_mask, memory_key_padding_mask=src_mask + ) # tgt = [compound_len, batch_size, hid_dim] + tgt = tgt.permute(1, 0, 2).contiguous() # tgt = [batch_size, compound_len, hid_dim] + x = tgt[:, 0, :] + label = F.relu(self.fc_1(x)) + label = self.fc_2(label) + return label + + +class TransformerCPI2(nn.Module): + def __init__(self, encoder, decoder, atom_dim=34): + super().__init__() + self.encoder = encoder + self.decoder = decoder + self.fc_1 = nn.Linear(atom_dim, atom_dim) + self.fc_2 = nn.Linear(atom_dim, 768) + + # def gcn(self, compound, adj): + # # input = [batch, num_node, atom_dim] + # # adj = [batch, num_node, num_node] + # support = self.fc_1(compound) # support = [batch, num_node, atom_dim] + # output = torch.bmm(adj, support) # output = [batch, num_node, atom_dim] + # return output + + def gcn(self, compound, adj): + batch_size, num_node, atom_dim = compound.shape + # Add a global node with padding at top + compound = F.pad(compound, (0, 0, 1, 0), value=0) + # Add an identity matrix to each adjacency matrix to represent self-connections + adj = adj + torch.eye(num_node, device=compound.device).unsqueeze(0).expand(batch_size, -1, -1) + # Add global edges with padding at left and top + adj = F.pad(adj, (1, 0, 1, 0), value=1) + + support = self.fc_1(compound) # support = [batch, num_node, atom_dim] + output = torch.bmm(adj, support) # output = [batch, num_node, atom_dim] + + # TODO support RNN packing in collate function + + return output + + def forward(self, enc_drug, enc_protein): + # atom_feat = [batch_size, atom_num, atom_dim] + # adj_mat = [batch_size, atom_num, atom_num] + # enc_protein = [batch_size, protein_len, 768] + compound, adj, atom_num = enc_drug + protein, protein_num = enc_protein + compound_max_len = compound.shape[1] + protein_max_len = protein.shape[1] + device = compound.device + + compound_mask, protein_mask = self.make_masks(atom_num, protein_num, compound_max_len, protein_max_len, device) + + compound = self.gcn(compound, adj) # compound = [batch_size, atom_num, atom_dim] + compound = F.relu(self.fc_2(compound)) # compound = [batch, compound_len, 768] + enc_src, src_mask = self.encoder(protein, 
protein_mask) # enc_src = [protein_len,batch , hid_dim] + out = self.decoder(compound, enc_src, compound_mask, src_mask) # out = [batch_size, 2] + + return out + + @staticmethod + def make_masks(atom_num, protein_num, compound_max_len, protein_max_len, device): + N = len(atom_num) # batch size + compound_mask = torch.ones((N, compound_max_len), device=device) + protein_mask = torch.ones((N, protein_max_len), device=device) + for i in range(N): + compound_mask[i, :atom_num[i]] = 0 + protein_mask[i, :protein_num[i]] = 0 + return compound_mask, protein_mask + + +def pack(batch): + N = len(batch) + atoms, adjs, proteins = zip(*[(torch.Tensor(sample['X1'][0]), + torch.Tensor(sample['X1'][1]), + torch.Tensor(sample['X2'])) + for sample in batch]) + + atoms_len = 0 + atom_num = [] + for atom in atoms: + atom_num.append(atom.shape[0] + 1) + if atom.shape[0] >= atoms_len: + atoms_len = atom.shape[0] + atoms_len += 1 + + proteins_len = 0 + protein_num = [] + for protein in proteins: + protein_num.append(protein.shape[0]) + if protein.shape[0] >= proteins_len: + proteins_len = protein.shape[0] + + atoms_padded = torch.zeros((N, atoms_len, 34)) + for i, atom in enumerate(atoms): + a_len = atom.shape[0] + atoms_padded[i, 1:a_len + 1, :] = atom + + adjs_padded = torch.zeros((N, atoms_len, atoms_len)) + for i, adj in enumerate(adjs): + adjs_padded[i, 0, :] = 1 + adjs_padded[i, :, 0] = 1 + a_len = adj.shape[0] + adj = adj + torch.eye(a_len) + adjs_padded[i, 1:a_len + 1, 1:a_len + 1] = adj + + proteins_padded = torch.zeros((N, proteins_len)) + for i, protein in enumerate(proteins): + a_len = protein.shape[0] + proteins_padded[i, :a_len] = protein + + return { + 'N': torch.Tensor([sample['N'] for sample in batch]), + 'X1': (atoms_padded, adjs_padded, atom_num), + 'ID1': [sample['ID1'] for sample in batch], + 'X2': (proteins_padded, protein_num), + 'ID2': [sample['ID2'] for sample in batch], + 'Y': torch.Tensor([sample['Y'] for sample in batch]), + 'IDX': torch.Tensor([sample['IDX'] for sample in batch]), + } diff --git a/deepscreen/predict.py b/deepscreen/predict.py index 1669c1d343017bb686483dc052d24f8a59a96508..0820c55a40166983d1120807217983f34d126bef 100644 --- a/deepscreen/predict.py +++ b/deepscreen/predict.py @@ -55,9 +55,7 @@ def predict(cfg: DictConfig) -> Tuple[list, dict]: } log.info("Start predicting.") - - predictions = trainer.predict(model=model, datamodule=datamodule, - ckpt_path=cfg.ckpt_path, return_predictions=True) + predictions = trainer.predict(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path, return_predictions=True) return predictions, object_dict diff --git a/deepscreen/utils/__pycache__/__init__.cpython-311.pyc b/deepscreen/utils/__pycache__/__init__.cpython-311.pyc index 239dc7ee2dcfbe59ce2de4ee73a8c5194cce0a1b..e2c1aea4e93f598409187cc41804f9e797571bcf 100644 Binary files a/deepscreen/utils/__pycache__/__init__.cpython-311.pyc and b/deepscreen/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/deepscreen/utils/__pycache__/__init__.cpython-39.pyc b/deepscreen/utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb48a5f8dca848a15291e601f1d9d1ed4037dae6 Binary files /dev/null and b/deepscreen/utils/__pycache__/__init__.cpython-39.pyc differ diff --git a/deepscreen/utils/__pycache__/hydra.cpython-311.pyc b/deepscreen/utils/__pycache__/hydra.cpython-311.pyc index 64cab8f0a3c32ac7b85595e8ad5cc24bebf7349d..8bab331cb43c5ad9c302d3484a6a51084deaec75 100644 Binary files 
a/deepscreen/utils/__pycache__/hydra.cpython-311.pyc and b/deepscreen/utils/__pycache__/hydra.cpython-311.pyc differ diff --git a/deepscreen/utils/__pycache__/hydra.cpython-39.pyc b/deepscreen/utils/__pycache__/hydra.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a35ff660206a9268068f0f9b73477d2367750e51 Binary files /dev/null and b/deepscreen/utils/__pycache__/hydra.cpython-39.pyc differ diff --git a/deepscreen/utils/__pycache__/instantiators.cpython-311.pyc b/deepscreen/utils/__pycache__/instantiators.cpython-311.pyc index 9cf3d20399e00565dfe5eb8d8805728b69985f45..3eb9521c5eae99f18716f5d6165e0abb72f1262c 100644 Binary files a/deepscreen/utils/__pycache__/instantiators.cpython-311.pyc and b/deepscreen/utils/__pycache__/instantiators.cpython-311.pyc differ diff --git a/deepscreen/utils/__pycache__/instantiators.cpython-39.pyc b/deepscreen/utils/__pycache__/instantiators.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3fc3d8c894aafd90e1813bde563c381987841657 Binary files /dev/null and b/deepscreen/utils/__pycache__/instantiators.cpython-39.pyc differ diff --git a/deepscreen/utils/__pycache__/lightning.cpython-311.pyc b/deepscreen/utils/__pycache__/lightning.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c913f4a2f5e9da8cf1b011f2c9ea66b0d9d1c345 Binary files /dev/null and b/deepscreen/utils/__pycache__/lightning.cpython-311.pyc differ diff --git a/deepscreen/utils/__pycache__/logging.cpython-311.pyc b/deepscreen/utils/__pycache__/logging.cpython-311.pyc index dbe194675c0f37bddec1318e5db8d883c354bd78..1bf265f9e74c20ed7c1200f8ff6424a22ff5b8de 100644 Binary files a/deepscreen/utils/__pycache__/logging.cpython-311.pyc and b/deepscreen/utils/__pycache__/logging.cpython-311.pyc differ diff --git a/deepscreen/utils/__pycache__/logging.cpython-39.pyc b/deepscreen/utils/__pycache__/logging.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..598d362b41c76551ea19ce0547d69681710c1a74 Binary files /dev/null and b/deepscreen/utils/__pycache__/logging.cpython-39.pyc differ diff --git a/deepscreen/utils/__pycache__/rich.cpython-311.pyc b/deepscreen/utils/__pycache__/rich.cpython-311.pyc index b33367488a99a56c57b65f042720f7ca7789bf23..4ad255de98093ce7db7e917893ce395c85d28281 100644 Binary files a/deepscreen/utils/__pycache__/rich.cpython-311.pyc and b/deepscreen/utils/__pycache__/rich.cpython-311.pyc differ diff --git a/deepscreen/utils/__pycache__/rich.cpython-39.pyc b/deepscreen/utils/__pycache__/rich.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1803eed05521743da1af12c3b4900362cbce4390 Binary files /dev/null and b/deepscreen/utils/__pycache__/rich.cpython-39.pyc differ diff --git a/deepscreen/utils/__pycache__/utils.cpython-311.pyc b/deepscreen/utils/__pycache__/utils.cpython-311.pyc index 1d568c9ce81de084efabdba8f1bb6850564cba25..38baa912f1adac98765b633efecff777ff49efb2 100644 Binary files a/deepscreen/utils/__pycache__/utils.cpython-311.pyc and b/deepscreen/utils/__pycache__/utils.cpython-311.pyc differ diff --git a/deepscreen/utils/__pycache__/utils.cpython-39.pyc b/deepscreen/utils/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9179d50866f9f021365876487b506ec05cf840df Binary files /dev/null and b/deepscreen/utils/__pycache__/utils.cpython-39.pyc differ diff --git a/deepscreen/utils/hydra.py b/deepscreen/utils/hydra.py index 
1a33602de5267a02db3c30101dd287799ba99943..9d32f967a1e10199ff308fba8b6a5f7db24c7f4a 100644
--- a/deepscreen/utils/hydra.py
+++ b/deepscreen/utils/hydra.py
@@ -1,16 +1,18 @@
-from datetime import timedelta
+from datetime import datetime
 from pathlib import Path
 import re
-from time import time
 from typing import Any, Tuple
 
 import pandas as pd
 from hydra import TaskFunction
 from hydra.core.hydra_config import HydraConfig
+from hydra.core.override_parser.overrides_parser import OverridesParser
 from hydra.core.utils import _save_config
 from hydra.experimental.callbacks import Callback
 from hydra.types import RunMode
+from hydra._internal.config_loader_impl import ConfigLoaderImpl
 from omegaconf import DictConfig, OmegaConf
+from omegaconf.errors import MissingMandatoryValue
 
 from deepscreen.utils import get_logger
@@ -33,15 +35,16 @@ class CSVExperimentSummary(Callback):
                     ckpt_path = override.split('=', 1)[1]
                     if ckpt_path.endswith(('.csv', '.txt', '.tsv', '.ssv', '.psv')):
                         config.hydra.overrides.task[i] = self.parse_ckpt_path_from_experiment_summary(ckpt_path)
+                        log.debug(f"Parsed ckpt_path override: {ckpt_path}")
                     break
         if config.hydra.sweeper.get('params'):
             if config.hydra.sweeper.params.get('ckpt_path'):
                 ckpt_path = str(config.hydra.sweeper.params.ckpt_path).strip("'\"")
                 if ckpt_path.endswith(('.csv', '.txt', '.tsv', '.ssv', '.psv')):
                     config.hydra.sweeper.params.ckpt_path = self.parse_ckpt_path_from_experiment_summary(ckpt_path)
+                    log.debug(f"Parsed sweeper ckpt_path: {ckpt_path}")
 
     def on_job_start(self, config: DictConfig, *, task_function: TaskFunction, **kwargs: Any) -> None:
-        self.time['start'] = time()
+        self.time['start'] = datetime.now()
 
     def on_job_end(self, config: DictConfig, job_return, **kwargs: Any) -> None:
         # Skip callback if job is DDP subprocess
@@ -49,7 +52,7 @@
             return
 
         try:
-            self.time['end'] = time()
+            self.time['end'] = datetime.now()
             if config.hydra.mode == RunMode.RUN:
                 summary_file_path = Path(config.hydra.run.dir) / self.filename
             elif config.hydra.mode == RunMode.MULTIRUN:
@@ -68,7 +71,7 @@
             info_dict = dict(override.split('=', 1) for override in job_return.overrides)
             info_dict['job_status'] = job_return.status.name
             info_dict['job_id'] = job_return.hydra_cfg.hydra.job.id
-            info_dict['wall_time'] = str(timedelta(self.time['end'] - self.time['start']))
+            info_dict['wall_time'] = str(self.time['end'] - self.time['start'])
 
             # Add checkpoint info
             if info_dict.get('ckpt_path'):
@@ -79,7 +82,9 @@
             if info_dict.get('ckpt_path') and ckpt_path != info_dict['ckpt_path']:
                 info_dict['previous_ckpt_path'] = info_dict['ckpt_path']
                 info_dict['ckpt_path'] = ckpt_path
-            info_dict['best_epoch'] = int(re.search(r'epoch_(\d+)', info_dict['ckpt_path']).group(1))
+            if info_dict.get('ckpt_path'):
+                if (epoch := re.search(r'epoch_(\d+)', info_dict['ckpt_path'])) is not None:
+                    info_dict['best_epoch'] = int(epoch.group(1))
 
             # Add metrics info
             metrics_df = pd.DataFrame()
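To see why the wall_time change above is needed: subtracting two `datetime` objects yields a `timedelta` directly, whereas the old code passed a float difference of `time()` values to `timedelta()`, which interprets its first positional argument as days. A minimal sketch:

    from datetime import datetime, timedelta
    from time import time, sleep

    start = datetime.now()
    sleep(1)
    print(str(datetime.now() - start))   # e.g. '0:00:01.001234' -- already a timedelta

    t0 = time()
    sleep(1)
    print(str(timedelta(time() - t0)))   # wrong: ~'1 day, 0:00:01' because the argument is taken as days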
@@ -90,13 +95,14 @@
                     log.info(f"Summarizing metrics with prefix `{self.prefix}` from {csv_metrics_path}")
                     metrics_df = pd.read_csv(csv_metrics_path)
                     # Find rows where 'test/' columns are not null and reset their epoch to the best model epoch
-                    test_columns = [col for col in metrics_df.columns if col.startswith('test/')]
-                    if test_columns:
-                        mask = metrics_df[test_columns].notna().any(axis=1)
-                        metrics_df.loc[mask, 'epoch'] = info_dict['best_epoch']
-                    # Group and filter by best epoch
-                    metrics_df = metrics_df.groupby('epoch').first()
-                    metrics_df = metrics_df[metrics_df.index == info_dict['best_epoch']]
+                    if info_dict.get('best_epoch'):
+                        test_columns = [col for col in metrics_df.columns if col.startswith('test/')]
+                        if test_columns:
+                            mask = metrics_df[test_columns].notna().any(axis=1)
+                            metrics_df.loc[mask, 'epoch'] = info_dict['best_epoch']
+                        # Group and filter by best epoch
+                        metrics_df = metrics_df.groupby('epoch').first()
+                        metrics_df = metrics_df[metrics_df.index == info_dict['best_epoch']]
                 else:
                     log.info(f"No metrics.csv found in {output_dir}")
@@ -108,12 +114,14 @@
             # Add extra info from the input batch experiment summary
             if self.input_experiment_summary is not None and 'ckpt_path' in metrics_df.columns:
+                log.debug(self.input_experiment_summary['ckpt_path'])
+                log.debug(metrics_df['ckpt_path'])
                 orig_meta = self.input_experiment_summary[
                     self.input_experiment_summary['ckpt_path'] == metrics_df['ckpt_path'][0]
                 ].head(1)
                 if not orig_meta.empty:
                     orig_meta.index = [0]
-                    metrics_df = metrics_df.combine_first(orig_meta)
+                    metrics_df = metrics_df.astype('O').combine_first(orig_meta.astype('O'))
 
             summary_df = pd.concat([summary_df, metrics_df])
@@ -125,6 +133,7 @@
             log.exception("Unable to save the experiment summary due to an error.", exc_info=e)
 
     def parse_ckpt_path_from_experiment_summary(self, ckpt_path):
+        log.debug(f"Parsing checkpoint paths from experiment summary: {ckpt_path}")
         try:
             self.input_experiment_summary = pd.read_csv(
                 ckpt_path, usecols=lambda col: not col.startswith(self.prefix)
@@ -145,8 +154,9 @@
 
 def checkpoint_rerun_config(config: DictConfig):
     hydra_cfg = HydraConfig.get()
-
-    if hydra_cfg.output_subdir is not None:
+    if not Path(config.ckpt_path).is_file():
+        raise FileNotFoundError(f'Not a valid checkpoint file: {config.ckpt_path}')
+    if hydra_cfg.get('output_subdir'):
         ckpt_cfg_path = Path(config.ckpt_path).parents[1] / hydra_cfg.output_subdir / 'config.yaml'
         hydra_output = Path(hydra_cfg.runtime.output_dir) / hydra_cfg.output_subdir
 
@@ -155,41 +165,21 @@
                      f"merging config overrides with checkpoint config...")
             ckpt_cfg = OmegaConf.load(ckpt_cfg_path)
 
-            # Merge checkpoint config with test config by overriding specified nodes.
- # ckpt_cfg = OmegaConf.masked_copy(ckpt_cfg, ['model', 'data', 'trainer', 'task']) - # ckpt_cfg.data = OmegaConf.masked_copy(ckpt_cfg.data, [ - # key for key in ckpt_cfg.data.keys() if key not in ['data_file', 'split', 'train_val_test_split'] - # ]) - # - # config = OmegaConf.merge(ckpt_cfg, config) - - # config = OmegaConf.masked_copy(config, - # [key for key in config if key not in - # ['task']]) - # config.data = OmegaConf.masked_copy(config.data, - # [key for key in config.data if key not in - # ['drug_featurizer', 'protein_featurizer', 'collator']]) - # config.model = OmegaConf.masked_copy(config.model, - # [key for key in config.model if key not in - # ['predictor']]) - # - # config = OmegaConf.merge(ckpt_cfg, config) - - ckpt_cfg = OmegaConf.masked_copy(ckpt_cfg, ['model', 'data', 'task', 'seed']) - ckpt_cfg.data = OmegaConf.masked_copy(ckpt_cfg.data, [ - key for key in ckpt_cfg.data.keys() if key not in ['data_file', 'split', 'train_val_test_split'] - ]) - ckpt_override_keys = ['task', 'data.drug_featurizer', 'data.protein_featurizer', 'data.collator', - 'model.predictor', 'model.out', 'model.loss', 'model.activation', 'model.metrics'] - - for key in ckpt_override_keys: - OmegaConf.update(config, key, OmegaConf.select(ckpt_cfg, key), force_add=True) - - config = OmegaConf.merge(ckpt_cfg, config) - - # OmegaConf.set_readonly(hydra_cfg, False) - # hydra_cfg.job.override_dirname += f"ckpt={str(Path(*Path(config.ckpt_path).parts[-4:]))}" + for key, value in ckpt_cfg.items(): + OmegaConf.update(config, key, value, merge=False, force_add=True) + + # Recompose merged config with overrides + if hydra_cfg.overrides.get('task'): + parser = OverridesParser.create() + parsed_overrides = parser.parse_overrides(overrides=hydra_cfg.overrides.task) + + filtered_overrides = [] + for override in parsed_overrides: + if override.is_force_add() or override.key_or_group.split('.')[0] in config: + filtered_overrides.append(override) + + ConfigLoaderImpl._apply_overrides_to_config(filtered_overrides, config) + _save_config(config, "config.yaml", hydra_output) return config - diff --git a/deepscreen/utils/hydra.py.bak b/deepscreen/utils/hydra.py.bak new file mode 100644 index 0000000000000000000000000000000000000000..5c1ec3b24b5128e2755086279f7ed5b89381132f --- /dev/null +++ b/deepscreen/utils/hydra.py.bak @@ -0,0 +1,182 @@ +from datetime import datetime +from pathlib import Path +import re +from typing import Any, Tuple + +import pandas as pd +from hydra import TaskFunction +from hydra.core.hydra_config import HydraConfig +from hydra.core.override_parser.overrides_parser import OverridesParser +from hydra.core.utils import _save_config +from hydra.experimental.callbacks import Callback +from hydra.types import RunMode +from hydra._internal.config_loader_impl import ConfigLoaderImpl +from omegaconf import DictConfig, OmegaConf +from omegaconf.errors import MissingMandatoryValue + +from deepscreen.utils import get_logger + +log = get_logger(__name__) + + +class CSVExperimentSummary(Callback): + """On multirun end, aggregate the results from each job's metrics.csv and save them in metrics_summary.csv.""" + + def __init__(self, filename: str = 'experiment_summary.csv', prefix: str | Tuple[str] = 'test/'): + self.filename = filename + self.prefix = prefix if isinstance(prefix, str) else tuple(prefix) + self.input_experiment_summary = None + self.time = {} + + def on_multirun_start(self, config: DictConfig, **kwargs: Any) -> None: + if config.hydra.get('overrides') and config.hydra.overrides.get('task'): + 
for i, override in enumerate(config.hydra.overrides.task): + if override.startswith("ckpt_path"): + ckpt_path = override.split('=', 1)[1] + if ckpt_path.endswith(('.csv', '.txt', '.tsv', '.ssv', '.psv')): + config.hydra.overrides.task[i] = self.parse_ckpt_path_from_experiment_summary(ckpt_path) + log.info(ckpt_path) + break + if config.hydra.sweeper.get('params'): + if config.hydra.sweeper.params.get('ckpt_path'): + ckpt_path = str(config.hydra.sweeper.params.ckpt_path).strip("'\"") + if ckpt_path.endswith(('.csv', '.txt', '.tsv', '.ssv', '.psv')): + config.hydra.sweeper.params.ckpt_path = self.parse_ckpt_path_from_experiment_summary(ckpt_path) + log.info(ckpt_path) + def on_job_start(self, config: DictConfig, *, task_function: TaskFunction, **kwargs: Any) -> None: + self.time['start'] = datetime.now() + + def on_job_end(self, config: DictConfig, job_return, **kwargs: Any) -> None: + # Skip callback if job is DDP subprocess + if "ddp" in job_return.hydra_cfg.hydra.job.name: + return + + try: + self.time['end'] = datetime.now() + if config.hydra.mode == RunMode.RUN: + summary_file_path = Path(config.hydra.run.dir) / self.filename + elif config.hydra.mode == RunMode.MULTIRUN: + summary_file_path = Path(config.hydra.sweep.dir) / self.filename + else: + raise RuntimeError('Invalid Hydra `RunMode`.') + + if summary_file_path.is_file(): + summary_df = pd.read_csv(summary_file_path) + else: + summary_df = pd.DataFrame() + + # Add job and override info + info_dict = {} + if job_return.overrides: + info_dict = dict(override.split('=', 1) for override in job_return.overrides) + info_dict['job_status'] = job_return.status.name + info_dict['job_id'] = job_return.hydra_cfg.hydra.job.id + info_dict['wall_time'] = str(self.time['end'] - self.time['start']) + + # Add checkpoint info + if info_dict.get('ckpt_path'): + info_dict['ckpt_path'] = str(info_dict['ckpt_path']).strip("'\"") + + ckpt_path = str(job_return.cfg.ckpt_path).strip("'\"") + if Path(ckpt_path).is_file(): + if info_dict.get('ckpt_path') and ckpt_path != info_dict['ckpt_path']: + info_dict['previous_ckpt_path'] = info_dict['ckpt_path'] + info_dict['ckpt_path'] = ckpt_path + if info_dict.get('ckpt_path'): + info_dict['best_epoch'] = int(re.search(r'epoch_(\d+)', info_dict['ckpt_path']).group(1)) + + # Add metrics info + metrics_df = pd.DataFrame() + if config.get('logger'): + output_dir = Path(config.hydra.runtime.output_dir).resolve() + csv_metrics_path = output_dir / config.logger.csv.name / "metrics.csv" + if csv_metrics_path.is_file(): + log.info(f"Summarizing metrics with prefix `{self.prefix}` from {csv_metrics_path}") + metrics_df = pd.read_csv(csv_metrics_path) + # Find rows where 'test/' columns are not null and reset its epoch to the best model epoch + test_columns = [col for col in metrics_df.columns if col.startswith('test/')] + if test_columns: + mask = metrics_df[test_columns].notna().any(axis=1) + metrics_df.loc[mask, 'epoch'] = info_dict['best_epoch'] + # Group and filter by best epoch + metrics_df = metrics_df.groupby('epoch').first() + metrics_df = metrics_df[metrics_df.index == info_dict['best_epoch']] + else: + log.info(f"No metrics.csv found in {output_dir}") + + if metrics_df.empty: + metrics_df = pd.DataFrame(data=info_dict, index=[0]) + else: + metrics_df = metrics_df.assign(**info_dict) + metrics_df.index = [0] + + # Add extra info from the input batch experiment summary + if self.input_experiment_summary is not None and 'ckpt_path' in metrics_df.columns: + log.info(self.input_experiment_summary['ckpt_path']) + 
log.info(metrics_df['ckpt_path']) + orig_meta = self.input_experiment_summary[ + self.input_experiment_summary['ckpt_path'] == metrics_df['ckpt_path'][0] + ].head(1) + if not orig_meta.empty: + orig_meta.index = [0] + metrics_df = metrics_df.astype('O').combine_first(orig_meta.astype('O')) + + summary_df = pd.concat([summary_df, metrics_df]) + + # Drop empty columns + summary_df.dropna(inplace=True, axis=1, how='all') + summary_df.to_csv(summary_file_path, index=False, mode='w') + log.info(f"Experiment summary saved to {summary_file_path}") + except Exception as e: + log.exception("Unable to save the experiment summary due to an error.", exc_info=e) + + def parse_ckpt_path_from_experiment_summary(self, ckpt_path): + log.info(ckpt_path) + try: + self.input_experiment_summary = pd.read_csv( + ckpt_path, usecols=lambda col: not col.startswith(self.prefix) + ) + self.input_experiment_summary['ckpt_path'] = self.input_experiment_summary['ckpt_path'].apply( + lambda x: x.strip("'\"") + ) + ckpt_list = list(set(self.input_experiment_summary['ckpt_path'])) + parsed_ckpt_path = ','.join([f"'{ckpt}'" for ckpt in ckpt_list]) + return parsed_ckpt_path + + except Exception as e: + log.exception( + f'Error in parsing checkpoint paths from experiment_summary file ({ckpt_path}).', + exc_info=e + ) + + +def checkpoint_rerun_config(config: DictConfig): + hydra_cfg = HydraConfig.get() + + if hydra_cfg.get('output_subdir'): + ckpt_cfg_path = Path(config.ckpt_path).parents[1] / hydra_cfg.output_subdir / 'config.yaml' + hydra_output = Path(hydra_cfg.runtime.output_dir) / hydra_cfg.output_subdir + + if ckpt_cfg_path.is_file(): + log.info(f"Found config file for the checkpoint at {str(ckpt_cfg_path)}; " + f"merging config overrides with checkpoint config...") + ckpt_cfg = OmegaConf.load(ckpt_cfg_path) + + # Recompose checkpoint config with overrides + + if hydra_cfg.overrides.get('task'): + parser = OverridesParser.create() + parsed_overrides = parser.parse_overrides(overrides=hydra_cfg.overrides.task) + filtered_overrides = [] + for override in parsed_overrides: + if not override.is_force_add(): + OmegaConf.update(ckpt_cfg, override.key_or_group, override.value()) + filtered_overrides.append(override) + log.info(filtered_overrides) + ConfigLoaderImpl._apply_overrides_to_config(filtered_overrides, config) + + _save_config(config, "config.yaml", hydra_output) + + return config + +
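Finally, for context on the override handling in `checkpoint_rerun_config` above, a standalone sketch of the Hydra `OverridesParser` calls used there (the override strings are illustrative):

    from hydra.core.override_parser.overrides_parser import OverridesParser

    # Parse command-line style overrides the same way checkpoint_rerun_config does.
    parser = OverridesParser.create()
    overrides = parser.parse_overrides(overrides=["model.optimizer.lr=0.001", "++trainer.max_epochs=10"])

    for override in overrides:
        # key_or_group is the dotted config path; value() returns the parsed value;
        # is_force_add() flags '++key=value' style overrides, which bypass the existing-key filter
        print(override.key_or_group, override.value(), override.is_force_add())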