from functools import partial
from typing import Optional, Sequence, Dict

from torch import nn, optim, Tensor
from lightning import LightningModule
from torchmetrics import Metric, MetricCollection


class DTILightningModule(LightningModule):
    """
        Drug Target Interaction Prediction

        optimizer: a partially or fully initialized instance of class torch.optim.Optimizer
        drug_encoder: a fully initialized instance of class torch.nn.Module
        protein_encoder: a fully initialized instance of class torch.nn.Module
        classifier: a fully initialized instance of class torch.nn.Module
        model: a fully initialized instance of class torch.nn.Module
        metrics: a list of fully initialized instances of class torchmetrics.Metric
    """
    def __init__(
            self,
            optimizer: optim.Optimizer,
            scheduler: Optional[partial | Dict],
            predictor: nn.Module,
            metrics: Optional[Dict[str, Metric]] = None,
            out: Optional[nn.Module] = None,
            loss: Optional[nn.Module] = None,
            activation: Optional[nn.Module] = None,
    ):
        super().__init__()

        self.predictor = predictor
        self.out = out
        self.loss = loss
        self.activation = activation

        # Metrics are averaged over batches automatically; cloning separate instances for the
        # train, val and test steps ensures a proper reduction over each epoch.
        metrics = MetricCollection(dict(metrics or {}))
        self.train_metrics = metrics.clone(prefix="train/")
        self.val_metrics = metrics.clone(prefix="val/")
        self.test_metrics = metrics.clone(prefix="test/")

        # allows access to init params with 'self.hparams' attribute and ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False,
                                  ignore=['predictor', 'out', 'loss', 'activation', 'metrics'])

    def setup(self, stage):
        # Run one dummy batch through the model at setup time (e.g. to materialize
        # any lazily initialized layers before training/evaluation starts).
        match stage:
            case 'fit':
                dataloader = self.trainer.datamodule.train_dataloader()
            case 'validate':
                dataloader = self.trainer.datamodule.val_dataloader()
            case 'test':
                dataloader = self.trainer.datamodule.test_dataloader()
            case 'predict':
                dataloader = self.trainer.datamodule.predict_dataloader()
            case _:
                return
        dummy_batch = next(iter(dataloader))

        # for key, value in dummy_batch.items():
        #     if isinstance(value, Tensor):
        #         dummy_batch[key] = value.to(self.device)

        self.forward(dummy_batch)

    def forward(self, batch):
        output = self.predictor(batch['X1'], batch['X2'])
        target = batch.get('Y')
        indexes = batch.get('IDX')
        preds = None
        loss = None

        if isinstance(output, Tensor):
            output = self.out(output).squeeze(1)
            preds = self.activation(output)

        elif isinstance(output, Sequence):
            output = list(output)
            # If multi-objective, assume the zeroth element in `output` is main while the rest are auxiliary
            output[0] = self.out(output[0]).squeeze(1)
            # Downstream metrics evaluation only needs main-objective preds
            preds = self.activation(output[0])

        if target is not None:
            loss = self.loss(output, target.float())

        return preds, target, indexes, loss

    def training_step(self, batch, batch_idx):
        preds, target, indexes, loss = self.forward(batch)
        self.log('train/loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.train_metrics(preds=preds, target=target, indexes=indexes.long())
        self.log_dict(self.train_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)

        return {'loss': loss, 'N': batch['N'], 'ID1': batch['ID1'], 'ID2': batch['ID2'], 'Y^': preds, 'Y': target}

    def on_train_epoch_end(self):
        pass

    def validation_step(self, batch, batch_idx):
        preds, target, indexes, loss = self.forward(batch)

        self.log('val/loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.val_metrics(preds=preds, target=target, indexes=indexes.long())
        self.log_dict(self.val_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)

    def on_validation_epoch_end(self):
        pass

    def test_step(self, batch, batch_idx):
        preds, target, indexes, loss = self.forward(batch)

        self.log('test/loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.test_metrics(preds=preds, target=target, indexes=indexes.long())
        self.log_dict(self.test_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)

        # return a dictionary for callbacks like BasePredictionWriter
        return {'N': batch['N'], 'ID1': batch['ID1'], 'ID2': batch['ID2'], 'Y^': preds, 'Y': target}

    def on_test_epoch_end(self):
        pass

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        preds, target, indexes, loss = self.forward(batch)
        # return a dictionary for callbacks like BasePredictionWriter
        return {'N': batch['N'], 'ID1': batch['ID1'], 'ID2': batch['ID2'], 'Y^': preds}

    def configure_optimizers(self):
        # `self.hparams.optimizer` is partially initialized; bind the model parameters here.
        optimizers_config = {'optimizer': self.hparams.optimizer(params=self.parameters())}
        if self.hparams.get('scheduler'):
            if isinstance(self.hparams.scheduler, partial):
                # Bare partial scheduler: wrap it in a default Lightning lr_scheduler config.
                optimizers_config['lr_scheduler'] = {
                    "scheduler": self.hparams.scheduler(optimizer=optimizers_config['optimizer']),
                    "monitor": "val/loss",
                    "interval": "epoch",
                    "frequency": 1,
                }
            else:
                # Full lr_scheduler config dict: only its 'scheduler' entry needs the optimizer bound.
                self.hparams.scheduler['scheduler'] = self.hparams.scheduler['scheduler'](
                    optimizer=optimizers_config['optimizer']
                )
                optimizers_config['lr_scheduler'] = dict(self.hparams.scheduler)
        return optimizers_config
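

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training pipeline):
# it shows how the partially initialized optimizer/scheduler and the fully
# initialized sub-modules are expected to be wired together. The toy predictor,
# feature sizes, metric choice and hyperparameters below are assumptions.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch
    from torchmetrics.retrieval import RetrievalMRR

    class ToyPredictor(nn.Module):
        """Hypothetical predictor: concatenates drug (X1) and protein (X2) features."""
        def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
            return torch.cat([x1, x2], dim=-1)

    module = DTILightningModule(
        optimizer=partial(optim.AdamW, lr=1e-3),                  # params bound in configure_optimizers
        scheduler=partial(optim.lr_scheduler.ReduceLROnPlateau),  # optimizer bound in configure_optimizers
        predictor=ToyPredictor(),
        metrics={'mrr': RetrievalMRR()},                          # retrieval metrics consume preds/target/indexes
        out=nn.Linear(16, 1),                                     # assumes a combined feature size of 16
        loss=nn.BCEWithLogitsLoss(),
        activation=nn.Sigmoid(),
    )

    # A toy batch mirroring the keys consumed by `forward`.
    batch = {
        'X1': torch.randn(4, 8),
        'X2': torch.randn(4, 8),
        'Y': torch.randint(0, 2, (4,)),
        'IDX': torch.zeros(4),
    }
    preds, target, indexes, loss = module(batch)
    print(preds.shape, loss)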