from functools import partial
from typing import Optional, Sequence, Dict
from torch import nn, optim, Tensor
from lightning import LightningModule
from torchmetrics import Metric, MetricCollection
class DTILightningModule(LightningModule):
"""
    Drug-target interaction (DTI) prediction module.

    optimizer: a partially initialized torch.optim.Optimizer (e.g. a functools.partial), bound to the model parameters in configure_optimizers
    scheduler: an optional partially initialized learning-rate scheduler, or a dict containing one under the 'scheduler' key along with Lightning lr_scheduler options
    predictor: a fully initialized instance of class torch.nn.Module that maps a (drug, protein) batch to raw outputs
    metrics: a dict of fully initialized instances of class torchmetrics.Metric, keyed by metric name
    out: a fully initialized instance of class torch.nn.Module applied as the output head to the predictor's (main) output
    loss: a fully initialized instance of class torch.nn.Module used as the training criterion
    activation: a fully initialized instance of class torch.nn.Module that turns raw outputs into predictions
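
    Example (an illustrative sketch only; the predictor class and the hyperparameter values
    below are hypothetical placeholders, not actual deepscreen components):

        module = DTILightningModule(
            optimizer=partial(optim.AdamW, lr=1e-4),
            scheduler=partial(optim.lr_scheduler.ReduceLROnPlateau, patience=5),
            predictor=MyDrugTargetPredictor(),       # hypothetical nn.Module
            metrics={'auroc': RetrievalAUROC()},     # from torchmetrics.retrieval
            out=nn.Linear(256, 1),                   # hypothetical head width
            loss=nn.BCEWithLogitsLoss(),
            activation=nn.Sigmoid(),
        )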
"""
extra_return_keys = ['ID1', 'X1', 'ID2', 'X2', 'N']
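    # The keys above are copied from the batch into every step's output dict (see the loops in the
    # *_step methods); presumably entity IDs, raw inputs and sample indices from the datamodule.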
def __init__(
self,
optimizer: optim.Optimizer,
        scheduler: Optional[partial | Dict],
predictor: nn.Module,
metrics: Optional[Dict[str, Metric]] = (),
        out: Optional[nn.Module] = None,
        loss: Optional[nn.Module] = None,
        activation: Optional[nn.Module] = None,
):
super().__init__()
self.predictor = predictor
self.out = out
self.loss = loss
self.activation = activation
# Automatically averaged over batches:
# Separate metric instances for train, val and test step to ensure a proper reduction over the epoch
metrics = MetricCollection(dict(metrics))
self.train_metrics = metrics.clone(prefix="train/")
self.val_metrics = metrics.clone(prefix="val/")
self.test_metrics = metrics.clone(prefix="test/")
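        # The metric objects are updated with (preds, target, indexes) in the step methods, so
        # retrieval-style metrics are expected here, e.g. (illustrative, not a project default):
        # metrics={'auroc': torchmetrics.retrieval.RetrievalAUROC()}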
# allows access to init params with 'self.hparams' attribute and ensures init params will be stored in ckpt
self.save_hyperparameters(logger=False,
ignore=['predictor', 'out', 'loss', 'activation', 'metrics'])
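        # The nn.Module arguments are excluded from hparams, presumably because their weights are
        # already persisted in the checkpoint state_dict and duplicating them in hparams is redundant.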
def setup(self, stage):
match stage:
case 'fit':
dataloader = self.trainer.datamodule.train_dataloader()
dummy_batch = next(iter(dataloader))
self.forward(dummy_batch)
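                # Presumably this dummy forward pass materializes any lazily initialized layers
                # (e.g. torch.nn.LazyLinear) before optimizers are configured (assumption about intent).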
# case 'validate':
# dataloader = self.trainer.datamodule.val_dataloader()
# case 'test':
# dataloader = self.trainer.datamodule.test_dataloader()
# case 'predict':
# dataloader = self.trainer.datamodule.predict_dataloader()
# for key, value in dummy_batch.items():
# if isinstance(value, Tensor):
# dummy_batch[key] = value.to(self.device)
def forward(self, batch):
output = self.predictor(batch['X1^'], batch['X2^'])
target = batch.get('Y')
indexes = batch.get('ID^')
preds = None
loss = None
if isinstance(output, Tensor):
output = self.out(output).squeeze(1)
preds = self.activation(output)
elif isinstance(output, Sequence):
output = list(output)
# If multi-objective, assume the zeroth element in `output` is main while the rest are auxiliary
output[0] = self.out(output[0]).squeeze(1)
# Downstream metrics evaluation only needs main-objective preds
preds = self.activation(output[0])
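            # e.g. (illustrative) a predictor returning (main_logits, aux_output): only main_logits
            # goes through `out`/`activation`, while `self.loss` below receives the full list.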
if target is not None:
loss = self.loss(output, target.float())
return preds, target, indexes, loss
def training_step(self, batch, batch_idx):
preds, target, indexes, loss = self.forward(batch)
self.log('train/loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
self.train_metrics(preds=preds, target=target, indexes=indexes.long())
self.log_dict(self.train_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
return_dict = {
'Y^': preds,
'Y': target,
'loss': loss
}
for key in self.extra_return_keys:
if key in batch:
return_dict[key] = batch[key]
return return_dict
def on_train_epoch_end(self):
pass
def validation_step(self, batch, batch_idx):
preds, target, indexes, loss = self.forward(batch)
self.log('val/loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
self.val_metrics(preds=preds, target=target, indexes=indexes.long())
self.log_dict(self.val_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
return_dict = {
'Y^': preds,
'Y': target,
'loss': loss
}
for key in self.extra_return_keys:
if key in batch:
return_dict[key] = batch[key]
return return_dict
def on_validation_epoch_end(self):
pass
def test_step(self, batch, batch_idx):
preds, target, indexes, loss = self.forward(batch)
self.log('test/loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
self.test_metrics(preds=preds, target=target, indexes=indexes.long())
self.log_dict(self.test_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
return_dict = {
'Y^': preds,
'Y': target,
'loss': loss
}
for key in self.extra_return_keys:
if key in batch:
return_dict[key] = batch[key]
return return_dict
def on_test_epoch_end(self):
pass
def predict_step(self, batch, batch_idx, dataloader_idx=0):
preds, _, _, _ = self.forward(batch)
# return a dictionary for callbacks like BasePredictionWriter
return_dict = {
'Y^': preds,
}
for key in self.extra_return_keys:
if key in batch:
return_dict[key] = batch[key]
return return_dict
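    # Illustrative note (assumption): when a lightning.pytorch.callbacks.BasePredictionWriter
    # subclass is attached to the Trainer, these per-batch dicts are what its
    # write_on_batch_end() receives as `prediction`.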
def configure_optimizers(self):
optimizers_config = {'optimizer': self.hparams.optimizer(params=self.parameters())}
if self.hparams.get('scheduler'):
if isinstance(self.hparams.scheduler, partial):
optimizers_config['lr_scheduler'] = {
"scheduler": self.hparams.scheduler(optimizer=optimizers_config['optimizer']),
"monitor": "val/loss",
"interval": "epoch",
"frequency": 1,
}
else:
self.hparams.scheduler['scheduler'] = self.hparams.scheduler['scheduler'](
optimizer=optimizers_config['optimizer']
)
optimizers_config['lr_scheduler'] = dict(self.hparams.scheduler)
return optimizers_config
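    # Illustrative scheduler configurations this method can handle (assumptions, not shipped defaults):
    #   scheduler=partial(optim.lr_scheduler.ReduceLROnPlateau, patience=5)
    #       -> wrapped above with monitor='val/loss', interval='epoch', frequency=1
    #   scheduler={'scheduler': partial(optim.lr_scheduler.CosineAnnealingLR, T_max=100),
    #              'interval': 'step', 'frequency': 1}
    #       -> passed through after the inner partial is instantiated with the optimizer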