from collections.abc import Sequence
from typing import Callable, Optional

from torch import nn, optim
from lightning import LightningModule
from torchmetrics import Metric, MetricCollection, MeanMetric

class DTILightningModule(LightningModule):
    """
    Drug-target interaction (DTI) prediction.

    optimizer: a partially initialized callable returning a torch.optim.Optimizer;
        configure_optimizers() completes it with this module's parameters
    scheduler: an optional partially initialized learning-rate scheduler callable,
        completed with the optimizer in configure_optimizers(); may be None
    predictor: a fully initialized instance of class torch.nn.Module that maps
        encoded drug/protein pairs to raw predictions
    metrics: a dict mapping names to fully initialized instances of class
        torchmetrics.Metric
    out: a fully initialized instance of class torch.nn.Module, the output head
        applied to the main-task predictions before the loss
    loss: a fully initialized instance of class torch.nn.Module, the loss criterion
    activation: a fully initialized instance of class torch.nn.Module applied to
        the main-task predictions for reporting (e.g. nn.Sigmoid)

    A usage sketch is provided at the bottom of this file.
    """
    def __init__(
        self,
        optimizer: Callable[..., optim.Optimizer],
        scheduler: Optional[Callable[..., optim.lr_scheduler.LRScheduler]],
        predictor: nn.Module,
        metrics: Optional[dict[str, Metric]] = None,
        out: Optional[nn.Module] = None,
        loss: Optional[nn.Module] = None,
        activation: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.predictor = predictor
        self.out = out
        self.loss = loss
        self.activation = activation
        # Average the loss over batches. Use separate metric instances for the
        # train, val and test steps to ensure a proper reduction over each epoch.
        self.train_loss = MeanMetric()
        self.val_loss = MeanMetric()
        self.test_loss = MeanMetric()
        metrics = MetricCollection(dict(metrics or {}))
        self.train_metrics = metrics.clone(prefix="train/")
        self.val_metrics = metrics.clone(prefix="val/")
        self.test_metrics = metrics.clone(prefix="test/")
        # Allows access to init params via `self.hparams` and stores them in the
        # checkpoint; the nn.Module/Metric arguments are excluded to avoid
        # duplicating their weights there.
        self.save_hyperparameters(logger=False,
                                  ignore=['predictor', 'out', 'loss', 'activation', 'metrics'])

    def forward(self, enc_drug, enc_protein):
        return self.predictor(enc_drug, enc_protein)

    def model_step(self, batch):
        # Common step shared by the training, validation and test loops.
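        # Expected batch layout (a sketch; the exact keys are defined by the
        # accompanying DataModule, which is not shown in this file):
        #   X1: encoded drug features             (Tensor [batch, drug_dim])
        #   X2: encoded protein features          (Tensor [batch, protein_dim])
        #   Y:  interaction labels or affinities  (Tensor [batch])
        #   ID: query indexes consumed by retrieval-style torchmetrics (Tensor [batch])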
        enc_drug = batch['X1']
        enc_protein = batch['X2']
        target = batch['Y']
        indexes = batch['ID']
        preds = self.forward(enc_drug, enc_protein)
        if isinstance(preds, Sequence):
            # The first element holds the main-task predictions; the rest are
            # auxiliary-task predictions consumed only by the loss.
            preds = list(preds)
            # Metric calculation only needs the main-task predictions.
            preds[0] = self.out(preds[0]).squeeze()
            loss = self.loss(preds, target)
            preds = self.activation(preds[0])
        else:
            preds = self.out(preds).squeeze()
            loss = self.loss(preds, target)
            preds = self.activation(preds)
        return loss, preds, target, indexes

    def training_step(self, batch, batch_idx):
        loss, preds, target, indexes = self.model_step(batch)
        self.train_loss(loss)
        self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.train_metrics(preds=preds, target=target.long(), indexes=indexes.long())
        self.log_dict(self.train_metrics, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def on_train_epoch_end(self):
        # Epoch-level training logic goes here. Since Lightning 2.0 the step
        # outputs are no longer passed in; cache them in training_step if needed.
        pass

    def validation_step(self, batch, batch_idx):
        loss, preds, target, indexes = self.model_step(batch)
        self.val_loss(loss)
        self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.val_metrics(preds=preds, target=target.long(), indexes=indexes.long())
        self.log_dict(self.val_metrics, on_step=False, on_epoch=True, prog_bar=True)

    def on_validation_epoch_end(self):
        pass

    def test_step(self, batch, batch_idx):
        loss, preds, target, indexes = self.model_step(batch)
        self.test_loss(loss)
        self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.test_metrics(preds=preds, target=target.long(), indexes=indexes.long())
        self.log_dict(self.test_metrics, on_step=False, on_epoch=True, prog_bar=True)
        # Return a dictionary for callbacks like BasePredictionWriter.
        return {"Y_hat": preds, "Y": target, "ID1": batch['ID1'], "ID2": batch['ID2']}

    def on_test_epoch_end(self):
        pass

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        enc_drug = batch['X1']
        enc_protein = batch['X2']
        preds = self.forward(enc_drug, enc_protein)
        if isinstance(preds, Sequence):
            # The first element holds the main-task predictions; auxiliary
            # predictions are discarded at inference time.
            preds = self.out(preds[0]).squeeze()
        else:
            preds = self.out(preds).squeeze()
        preds = self.activation(preds)
        # Return a dictionary for callbacks like BasePredictionWriter.
        return {"Y_hat": preds, "ID1": batch['ID1'], "ID2": batch['ID2']}

    def configure_optimizers(self):
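        # `optimizer` and `scheduler` arrive as partially initialized callables
        # (e.g. functools.partial, or Hydra's `_partial_: true`), so the module
        # can bind its own parameters and optimizer here.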
        optimizer = self.hparams.optimizer(params=self.parameters())
        if self.hparams.scheduler is not None:
            scheduler = self.hparams.scheduler(optimizer=optimizer)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "monitor": "val/loss",
                    "interval": "epoch",
                    "frequency": 1,
                },
            }
        return {"optimizer": optimizer}