Spaces:
Running
Running
| """Inference utilities.""" | |
| import logging | |
| import torch | |
| import numpy as np | |
| from paccmann_predictor.models.paccmann import MCA | |
| from pytoda.transforms import Compose | |
| from pytoda.smiles.transforms import ToTensor | |
| from configuration import ( | |
| MODEL_WEIGHTS_URI, | |
| MODEL_PARAMS, | |
| SMILES_LANGUAGE, | |
| SMILES_TRANSFORMS, | |
| ) | |
# Module-level logger; uses a fixed service-specific name rather than __name__.
logger = logging.getLogger(name="openapi_server:inference")

# NOTE: PyTorch is restricted to a single intra-op thread to avoid segfaults.
torch.set_num_threads(1)
def _load_model(device: torch.device, estimate_confidence: bool) -> torch.nn.Module:
    """Build the MCA model, load its weights, move it to ``device`` and set eval mode."""
    logger.debug("loading model for prediction.")
    model = MCA(MODEL_PARAMS)
    model.load_state_dict(torch.load(MODEL_WEIGHTS_URI, map_location=device))
    # load_state_dict copies values into the parameters the model was built
    # with (on the CPU), so map_location alone does not place the model on
    # ``device``; move it explicitly so it matches the input tensors.
    model.to(device)
    model.eval()
    if estimate_confidence:
        logger.debug("associating SMILES language for confidence estimates.")
        model._associate_language(SMILES_LANGUAGE)
    logger.debug("model loaded.")
    return model


def predict(
    smiles: str, gene_expression: np.ndarray, estimate_confidence: bool = False
) -> dict:
    """
    Run PaccMann prediction.

    The SMILES is paired with every row of ``gene_expression``, so the model
    is evaluated once per gene-expression sample.

    Args:
        smiles (str): SMILES representing a compound.
        gene_expression (np.ndarray): gene expression data with one row per
            sample. A 1-D array is treated as a single sample.
        estimate_confidence (bool, optional): estimate confidence of the
            prediction. Defaults to False.

    Returns:
        dict: the prediction dictionary from the model.
    """
    logger.debug("running predict.")
    logger.debug("gene expression shape: %s.", gene_expression.shape)
    # Promote a single profile to a batch of one; otherwise shape[0] would be
    # the number of genes and the SMILES would be repeated once per gene.
    gene_expression = np.atleast_2d(gene_expression)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.debug("device selected: %s.", device)
    model = _load_model(device, estimate_confidence)
    logger.debug("set up the transformation.")
    smiles_transform_fn = Compose(SMILES_TRANSFORMS + [ToTensor(device=device)])
    logger.debug("starting the prediction.")
    batch_size = gene_expression.shape[0]
    with torch.no_grad():
        _, prediction_dict = model(
            # One copy of the encoded SMILES per gene-expression row.
            smiles_transform_fn(smiles).view(1, -1).repeat(batch_size, 1),
            # Build the tensor on the same device as the model and the SMILES.
            torch.tensor(gene_expression, dtype=torch.float32, device=device),
            confidence=estimate_confidence,
        )
    logger.debug("successful prediction.")
    return prediction_dict