from collections.abc import Sequence
from typing import Optional

from torch import nn, optim
from lightning import LightningModule
from torchmetrics import Metric, MetricCollection, MeanMetric


class DTILightningModule(LightningModule):
    """
        Drug Target Interaction Prediction

        optimizer: a partially or fully initialized instance of class torch.optim.Optimizer
        drug_encoder: a fully initialized instance of class torch.nn.Module
        protein_encoder: a fully initialized instance of class torch.nn.Module
        classifier: a fully initialized instance of class torch.nn.Module
        model: a fully initialized instance of class torch.nn.Module
        metrics: a list of fully initialized instances of class torchmetrics.Metric
    """
    def __init__(
            self,
            optimizer: optim.Optimizer,
            scheduler: Optional[optim.lr_scheduler.LRScheduler],
            predictor: Optional[nn.Module],
            metrics: Optional[dict[str, Metric]] = None,
            out: Optional[nn.Module] = None,
            loss: Optional[nn.Module] = None,
            activation: Optional[nn.Module] = None,
            ):
        super().__init__()

        self.predictor = predictor
        self.out = out
        self.loss = loss
        self.activation = activation

        # averaging loss over batches
        # use separate metric instances for train, val and test step to ensure a proper reduction over the epoch
        self.train_loss = MeanMetric()
        self.val_loss = MeanMetric()
        self.test_loss = MeanMetric()

        metrics = MetricCollection(dict(metrics or {}))
        self.train_metrics = metrics.clone(prefix="train/")
        self.val_metrics = metrics.clone(prefix="val/")
        self.test_metrics = metrics.clone(prefix="test/")

        # allows access to init params with 'self.hparams' attribute and ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False,
                                  ignore=['predictor', 'out', 'loss', 'activation', 'metrics'])

    def forward(self, enc_drug, enc_protein):
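        # the predictor may return a single tensor or a sequence of tensors
        # whose first element is the main-task prediction (see `model_step`)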
        return self.predictor(enc_drug, enc_protein)

    def model_step(self, batch):
        # common step for training/validation/test
        enc_drug = batch['X1']
        enc_protein = batch['X2']
        target = batch['Y']
        indexes = batch['ID']

        preds = self.forward(enc_drug, enc_protein)

        if isinstance(preds, Sequence):
            # the first preds is the main task preds, with the others being auxiliary
            preds = list(preds)
            # metrics calculation only needs main task preds
            preds[0] = self.out(preds[0]).squeeze()
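            # the loss receives the full list so auxiliary heads can contribute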
            loss = self.loss(preds, target)
            preds = self.activation(preds[0])
        else:
            preds = self.out(preds).squeeze()
            loss = self.loss(preds, target)
            preds = self.activation(preds)

        return loss, preds, target, indexes

    def training_step(self, batch, batch_idx):
        loss, preds, target, indexes = self.model_step(batch)

        self.train_loss(loss)
        self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.train_metrics(preds=preds, target=target.long(), indexes=indexes.long())
        self.log_dict(self.train_metrics, on_step=False, on_epoch=True, prog_bar=True)

        # Lightning backpropagates the returned loss; a dict such as
        # {"loss": loss, "preds": preds, "target": target} could be returned instead
        return loss

    def on_train_epoch_end(self):
        # metrics logged with on_epoch=True are reduced and reset by Lightning
        pass

    def validation_step(self, batch, batch_idx):
        loss, preds, target, indexes = self.model_step(batch)

        self.val_loss(loss)
        self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.val_metrics(preds=preds, target=target.long(), indexes=indexes.long())
        self.log_dict(self.val_metrics, on_step=False, on_epoch=True, prog_bar=True)

    def on_validation_epoch_end(self):
        pass

    def test_step(self, batch, batch_idx):
        loss, preds, target, indexes = self.model_step(batch)

        self.test_loss(loss)
        self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.test_metrics(preds=preds, target=target.long(), indexes=indexes.long())
        self.log_dict(self.test_metrics, on_step=False, on_epoch=True, prog_bar=True)

        # return a dictionary for callbacks like BasePredictionWriter
        return {"Y_hat": preds, "Y": target, 'ID1': batch['ID1'], 'ID2': batch['ID2']}

    def on_test_epoch_end(self):
        pass

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        enc_drug = batch['X1']
        enc_protein = batch['X2']

        preds = self.forward(enc_drug, enc_protein)

        if isinstance(preds, Sequence):
            # the first prediction is the main task prediction, with the others being auxiliary
            preds = self.out(preds[0]).squeeze()
        else:
            preds = self.out(preds).squeeze()

        preds = self.activation(preds)

        # return a dictionary for callbacks like BasePredictionWriter
        return {"Y_hat": preds, 'ID1': batch['ID1'], 'ID2': batch['ID2']}

    def configure_optimizers(self):
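        # `self.hparams.optimizer` and `self.hparams.scheduler` hold partially
        # initialized callables (see the class docstring); they are completed
        # here, once the model parameters are available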
        optimizer = self.hparams.optimizer(params=self.parameters())
        if self.hparams.scheduler is not None:
            scheduler = self.hparams.scheduler(optimizer=optimizer)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "monitor": "val/loss",
                    "interval": "epoch",
                    "frequency": 1,
                },
            }
        return {"optimizer": optimizer}