libokj's picture
Upload 358 files
05ca42f
raw
history blame
5.77 kB
from typing import Optional, Sequence, Literal
from torch import nn, optim
from lightning import LightningModule
from torchmetrics import Metric, MetricCollection, MeanMetric
class DTILightningModule(LightningModule):
    """Lightning wrapper for drug-target interaction (DTI) prediction.

    The module chains ``predictor -> out -> activation`` to produce
    predictions, computes ``loss`` on the un-activated output, and tracks a
    running mean of the loss plus a prefixed copy of ``metrics`` for each of
    the train / val / test stages.

    Args:
        optimizer: Optimizer factory. ``configure_optimizers`` calls it as
            ``optimizer(params=self.parameters())``, so it must be callable
            with only ``params`` left to fill in (e.g. a partially
            initialized ``torch.optim.Optimizer``).
        scheduler: Learning-rate scheduler factory, or ``None`` to train
            without one. When given, it is called as
            ``scheduler(optimizer=optimizer)``.
        predictor: Module called as ``predictor(enc_drug, enc_protein)``.
            It may return a single output, or a ``Sequence`` of outputs whose
            first element is the main-task output and the rest auxiliary.
        metrics: Mapping of name to ``torchmetrics.Metric``. It is wrapped
            in a ``MetricCollection`` and cloned once per stage with the
            prefixes ``train/``, ``val/`` and ``test/``. ``None`` or an empty
            value means no metrics are tracked.
        out: Output head applied to the main-task predictor output.
        loss: Loss module called as ``loss(preds, target)``.
        activation: Module applied to the output-head result to obtain the
            predictions that are passed to the metrics and returned.

    Note:
        ``out``, ``loss`` and ``activation`` default to ``None`` but are
        called unconditionally by the step methods, so they must be provided
        before any training / evaluation / prediction step runs.
    """

    def __init__(
        self,
        optimizer: optim.Optimizer,
        scheduler: optim.lr_scheduler,
        predictor: Optional[nn.Module],
        metrics: Optional[dict[str, Metric]] = (),
        out: nn.Module = None,
        loss: nn.Module = None,
        activation: nn.Module = None,
    ):
        super().__init__()

        self.predictor = predictor
        self.out = out
        self.loss = loss
        self.activation = activation

        # Average the loss over batches. Separate metric instances are used
        # for train, val and test so each stage is reduced over its own epoch.
        self.train_loss = MeanMetric()
        self.val_loss = MeanMetric()
        self.test_loss = MeanMetric()

        # `metrics` is annotated Optional, so treat None as "no metrics"
        # instead of letting dict(None) raise a TypeError.
        collection = MetricCollection(dict(metrics) if metrics is not None else {})
        self.train_metrics = collection.clone(prefix="train/")
        self.val_metrics = collection.clone(prefix="val/")
        self.test_metrics = collection.clone(prefix="test/")

        # Allows access to init params through `self.hparams` and ensures
        # they are stored in the checkpoint. Module-valued arguments are
        # excluded; `optimizer` and `scheduler` are kept because
        # `configure_optimizers` reads them back from `self.hparams`.
        self.save_hyperparameters(
            logger=False,
            ignore=['predictor', 'out', 'loss', 'activation', 'metrics'],
        )

    def forward(self, enc_drug, enc_protein):
        """Return the raw predictor output for an encoded drug/protein pair."""
        return self.predictor(enc_drug, enc_protein)

    def model_step(self, batch):
        """Run the step shared by training, validation and testing.

        Reads ``batch['X1']`` (drug input), ``batch['X2']`` (protein input),
        ``batch['Y']`` (target) and ``batch['ID']`` (indexes forwarded to the
        metrics).

        Returns:
            A tuple ``(loss, preds, target, indexes)`` where ``preds`` is the
            activated main-task prediction.
        """
        enc_drug = batch['X1']
        enc_protein = batch['X2']
        target = batch['Y']
        indexes = batch['ID']

        raw = self.forward(enc_drug, enc_protein)

        # NOTE(review): `.squeeze()` without a dim removes every size-1 axis,
        # including the batch axis for a single-sample batch -- confirm the
        # loss and metrics tolerate a 0-d prediction in that case.
        if isinstance(raw, Sequence):
            # The first output is the main-task output; the others are
            # auxiliary and are handed to the loss untouched.
            outputs = list(raw)
            outputs[0] = self.out(outputs[0]).squeeze()
            loss = self.loss(outputs, target)
            # Metrics only need the main-task prediction.
            preds = self.activation(outputs[0])
        else:
            head_out = self.out(raw).squeeze()
            loss = self.loss(head_out, target)
            preds = self.activation(head_out)

        return loss, preds, target, indexes

    def training_step(self, batch, batch_idx):
        """Compute the training loss, update train metrics, return the loss."""
        loss, preds, target, indexes = self.model_step(batch)

        self.train_loss(loss)
        self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True)

        self.train_metrics(preds=preds, target=target.long(), indexes=indexes.long())
        self.log_dict(self.train_metrics, on_step=False, on_epoch=True, prog_bar=True)

        return loss

    def on_train_epoch_end(self):
        """Hook called at the end of a training epoch; intentionally a no-op."""
        pass

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and update validation metrics."""
        loss, preds, target, indexes = self.model_step(batch)

        self.val_loss(loss)
        self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)

        self.val_metrics(preds=preds, target=target.long(), indexes=indexes.long())
        self.log_dict(self.val_metrics, on_step=False, on_epoch=True, prog_bar=True)

    def on_validation_epoch_end(self):
        """Hook called at the end of a validation epoch; intentionally a no-op."""
        pass

    def test_step(self, batch, batch_idx):
        """Compute the test loss, update test metrics and return predictions.

        Besides the keys read by ``model_step``, this also reads
        ``batch['ID1']`` and ``batch['ID2']`` and passes them through.

        Returns:
            A dict with keys ``Y_hat``, ``Y``, ``ID1`` and ``ID2`` for
            callbacks like ``BasePredictionWriter``.
        """
        loss, preds, target, indexes = self.model_step(batch)

        self.test_loss(loss)
        self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True)

        self.test_metrics(preds=preds, target=target.long(), indexes=indexes.long())
        self.log_dict(self.test_metrics, on_step=False, on_epoch=True, prog_bar=True)

        return {"Y_hat": preds, "Y": target, 'ID1': batch['ID1'], 'ID2': batch['ID2']}

    def on_test_epoch_end(self):
        """Hook called at the end of a test epoch; intentionally a no-op."""
        pass

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Predict for a batch without computing loss or metrics.

        Reads ``batch['X1']``, ``batch['X2']``, ``batch['ID1']`` and
        ``batch['ID2']``; no target is required.

        Returns:
            A dict with keys ``Y_hat``, ``ID1`` and ``ID2`` for callbacks like
            ``BasePredictionWriter``.
        """
        enc_drug = batch['X1']
        enc_protein = batch['X2']

        raw = self.forward(enc_drug, enc_protein)
        # For multi-output predictors only the first (main-task) output is
        # used; auxiliary outputs are discarded at prediction time.
        main_out = raw[0] if isinstance(raw, Sequence) else raw
        preds = self.activation(self.out(main_out).squeeze())

        return {"Y_hat": preds, 'ID1': batch['ID1'], 'ID2': batch['ID2']}

    def configure_optimizers(self):
        """Build the optimizer and, if configured, the LR scheduler.

        Returns:
            ``{"optimizer": optimizer}`` when no scheduler was given;
            otherwise the same dict plus an ``"lr_scheduler"`` entry that
            monitors ``val/loss`` and steps once per epoch.
        """
        optimizer = self.hparams.optimizer(params=self.parameters())

        if self.hparams.scheduler is None:
            return {"optimizer": optimizer}

        scheduler = self.hparams.scheduler(optimizer=optimizer)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val/loss",
                "interval": "epoch",
                "frequency": 1,
            },
        }