import os

import hydra
import lightning as L
import numpy as np
import omegaconf
import pandas as pd
import rdkit
import rich.syntax
import rich.tree
import torch
from tqdm.auto import tqdm
import esm
import pdb

import dataloader
import diffusion
from models.classifier import muPPIt

rdkit.rdBase.DisableLog('rdApp.error')

# Custom OmegaConf resolvers used by the Hydra configs.
omegaconf.OmegaConf.register_new_resolver(
  'cwd', os.getcwd)
omegaconf.OmegaConf.register_new_resolver(
  'device_count', torch.cuda.device_count)
omegaconf.OmegaConf.register_new_resolver(
  'eval', eval)
omegaconf.OmegaConf.register_new_resolver(
  'div_up', lambda x, y: (x + y - 1) // y)
omegaconf.OmegaConf.register_new_resolver(
  'if_then_else', lambda condition, x, y: x if condition else y)

# VHSE8 physicochemical descriptors for the 20 standard amino acids.
vhse8_values = {
  'A': [0.15, -1.11, -1.35, -0.92, 0.02, -0.91, 0.36, -0.48],
  'R': [-1.47, 1.45, 1.24, 1.27, 1.55, 1.47, 1.30, 0.83],
  'N': [-0.99, 0.00, 0.69, -0.37, -0.55, 0.85, 0.73, -0.80],
  'D': [-1.15, 0.67, -0.41, -0.01, -2.68, 1.31, 0.03, 0.56],
  'C': [0.18, -1.67, -0.21, 0.00, 1.20, -1.61, -0.19, -0.41],
  'Q': [-0.96, 0.12, 0.18, 0.16, 0.09, 0.42, -0.20, -0.41],
  'E': [-1.18, 0.40, 0.10, 0.36, -2.16, -0.17, 0.91, 0.36],
  'G': [-0.20, -1.53, -2.63, 2.28, -0.53, -1.18, -1.34, 1.10],
  'H': [-0.43, -0.25, 0.37, 0.19, 0.51, 1.28, 0.93, 0.65],
  'I': [1.27, 0.14, 0.30, -1.80, 0.30, -1.61, -0.16, -0.13],
  'L': [1.36, 0.07, 0.26, -0.80, 0.22, -1.37, 0.08, -0.62],
  'K': [-1.17, 0.70, 0.80, 1.64, 0.67, 1.63, 0.13, -0.01],
  'M': [1.01, -0.53, 0.43, 0.00, 0.23, 0.10, -0.86, -0.68],
  'F': [1.52, 0.61, 0.95, -0.16, 0.25, 0.28, -1.33, -0.65],
  'P': [0.22, -0.17, -0.50, -0.05, 0.01, -1.34, 0.19, 3.56],
  'S': [-0.67, -0.86, -1.07, -0.41, -0.32, 0.27, -0.64, 0.11],
  'T': [-0.34, -0.51, -0.55, -1.06, 0.01, -0.01, -0.79, 0.39],
  'W': [1.50, 2.06, 1.79, 0.75, 0.75, 0.13, -1.06, -0.85],
  'Y': [0.61, 1.60, 1.17, 0.73, 0.53, 0.25, -0.96, -0.52],
  'V': [0.76, -0.92, 0.17, -1.91, 0.22, -1.40, -0.24, -0.03],
}

# Amino acid -> ESM alphabet token index.
aa_to_idx = {'A': 5, 'R': 10, 'N': 17, 'D': 13, 'C': 23, 'Q': 16, 'E': 9,
             'G': 6, 'H': 21, 'I': 12, 'L': 4, 'K': 15, 'M': 20, 'F': 18,
             'P': 14, 'S': 8, 'T': 11, 'W': 22, 'Y': 19, 'V': 7}

# Lookup table from token index to VHSE8 descriptor (zeros for special tokens).
vhse8_tensor = torch.zeros(24, 8)
for aa, values in vhse8_values.items():
  vhse8_tensor[aa_to_idx[aa]] = torch.tensor(values)
vhse8_tensor.requires_grad = False

# Frozen ESM-2 650M model used to embed the wild-type and mutant sequences.
esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
esm_model.eval()


def precompute_embedding(sequence, tokenizer):
  """Concatenate per-residue ESM-2 (layer 33) and VHSE8 embeddings for a sequence."""
  tokens = tokenizer(sequence, return_tensors='pt')['input_ids']
  with torch.no_grad():
    embed = esm_model(tokens, repr_layers=[33],
                      return_contacts=False)["representations"][33]
  vhse8_embed = vhse8_tensor[tokens]
  return torch.concat([embed, vhse8_embed], dim=-1)


@hydra.main(version_base=None, config_path='./configs', config_name='config')
def main(config: omegaconf.DictConfig) -> None:
  # Reproducibility
  L.seed_everything(config.seed)
  os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
  torch.use_deterministic_algorithms(True)
  torch.backends.cudnn.benchmark = False
  # _print_config(config, resolve=True)

  print(f"Checkpoint: {config.eval.checkpoint_path}")

  # Load the pretrained diffusion model from its checkpoint.
  tokenizer = dataloader.get_tokenizer(config)
  pretrained = diffusion.Diffusion.load_from_checkpoint(
    config.eval.checkpoint_path,
    tokenizer=tokenizer,
    config=config,
    logger=False)
  pretrained.eval()

  # muPPIt classifier used for guidance; d_node = 1280 (ESM-2 650M) + 8 (VHSE8).
  muppit = muPPIt(d_node=1288, d_k=32, d_v=32, n_heads=4, lr=None)
  muppit.load_state_dict(torch.load(config.guidance.classifier_checkpoint_path))
  muppit.eval()

  # Precompute conditioning embeddings for the mutant and wild-type sequences.
  mut_embed = precompute_embedding(config.eval.mutant, tokenizer)
  wt_embed = precompute_embedding(config.eval.wildtype, tokenizer)

  samples = []
  for _ in tqdm(range(config.sampling.num_sample_batches),
                desc='Gen. batches', leave=False):
    sample = pretrained.sample(
      wt_embed=wt_embed,
      mut_embed=mut_embed,
      classifier_model=muppit)
    samples.extend(pretrained.tokenizer.batch_decode(sample))

  # Remove whitespace and the leading/trailing special tokens from each decoded sample.
  samples = [sample.replace(' ', '')[5:-5] for sample in samples]
  print('\n')
  print(samples)


if __name__ == '__main__':
  main()
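# Usage sketch (hypothetical paths and sequences; the override keys mirror the
# config fields read above and are assumed to exist in configs/config.yaml):
#   python <this_script>.py \
#     eval.checkpoint_path=/path/to/diffusion.ckpt \
#     guidance.classifier_checkpoint_path=/path/to/muppit.pt \
#     eval.wildtype=<wild-type sequence> \
#     eval.mutant=<mutant sequence> \
#     sampling.num_sample_batches=4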