Spaces:
Sleeping
Sleeping
| import torch | |
| import argparse | |
| import selfies as sf | |
| from tqdm import tqdm | |
| from transformers import T5EncoderModel | |
| from transformers import set_seed | |
| from src.scripts.mytokenizers import Tokenizer | |
| from src.improved_diffusion import gaussian_diffusion as gd | |
| from src.improved_diffusion import dist_util, logger | |
| from src.improved_diffusion.respace import SpacedDiffusion | |
| from src.improved_diffusion.transformer_model import TransformerNetModel | |
| from src.improved_diffusion.script_util import ( | |
| model_and_diffusion_defaults, | |
| add_dict_to_argparser, | |
| ) | |
| from src.scripts.mydatasets import Lang2molDataset_submission | |
| import streamlit as st | |
| import os | |
def get_encoder():
    """Load the text encoder used to embed molecule descriptions.

    Returns:
        The ``QizhiPei/biot5-base-text2mol`` ``T5EncoderModel`` in eval mode.
    """
    # ``eval()`` returns the module itself, so the call can be chained.
    return T5EncoderModel.from_pretrained("QizhiPei/biot5-base-text2mol").eval()
def get_tokenizer():
    """Create the project tokenizer with its default constructor arguments.

    Returns:
        A new ``Tokenizer`` instance from ``src.scripts.mytokenizers``.
    """
    text_tokenizer = Tokenizer()
    return text_tokenizer
def get_model():
    """Build the denoising transformer and restore its trained weights.

    The weights are read from ``checkpoints/PLAIN_ema_0.9999_360000.pt`` and
    mapped onto the CPU, so no GPU is required to load them.

    Returns:
        A ``TransformerNetModel`` with the checkpoint loaded, in eval mode.
    """
    checkpoint_path = os.path.join("checkpoints", "PLAIN_ema_0.9999_360000.pt")

    # Architecture hyper-parameters; they have to match the saved state dict.
    hyperparams = {
        "in_channels": 32,
        "model_channels": 128,
        "dropout": 0.1,
        "vocab_size": 35073,
        "hidden_size": 1024,
        "num_attention_heads": 16,
        "num_hidden_layers": 12,
    }
    network = TransformerNetModel(**hyperparams)

    state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    network.load_state_dict(state_dict)

    network.eval()
    return network
def get_diffusion():
    """Create the respaced diffusion process used for sampling.

    The process is defined on a 2000-step ``"sqrt"`` beta schedule, of which
    only every 10th timestep (200 steps in total) is kept.

    Returns:
        A configured ``SpacedDiffusion`` instance.
    """
    total_steps = 2000
    stride = 10

    kept_timesteps = list(range(0, total_steps, stride))
    beta_schedule = gd.get_named_beta_schedule("sqrt", total_steps)

    return SpacedDiffusion(
        use_timesteps=kept_timesteps,
        betas=beta_schedule,
        model_mean_type=gd.ModelMeanType.START_X,
        model_var_type=gd.ModelVarType.FIXED_LARGE,
        loss_type=gd.LossType.E2E_MSE,
        rescale_timesteps=True,
        model_arch="transformer",
        training_mode="e2e",
    )
@st.cache_resource
def _load_pipeline():
    """Build the tokenizer, text encoder, denoiser and diffusion process once.

    Streamlit re-executes this script on every widget interaction. Caching the
    result keeps the pretrained weights in memory instead of reloading them
    from disk on each rerun. Requires Streamlit >= 1.18 (``st.cache_resource``).

    Returns:
        Tuple ``(tokenizer, encoder, model, diffusion)``.
    """
    return get_tokenizer(), get_encoder(), get_model(), get_diffusion()


tokenizer, encoder, model, diffusion = _load_pipeline()

st.title("Lang2mol-Diff")
text_input = st.text_area("Enter molecule description")
button = st.button("Submit")

if button and not text_input.strip():
    # Nothing to condition the sampler on, so skip the expensive diffusion run.
    st.warning("Please enter a molecule description first.")
elif button:
    # Inference only: without autograd the encoder and logit passes do not
    # keep activations around for a backward pass that never happens.
    with st.spinner("Please wait..."), torch.no_grad():
        # Tokenize the description, padded/truncated to exactly 256 tokens.
        output = tokenizer(
            text_input,
            max_length=256,
            truncation=True,
            padding="max_length",
            add_special_tokens=True,
            return_tensors="pt",
            return_attention_mask=True,
        )
        caption_mask = output["attention_mask"]
        caption_state = encoder(
            input_ids=output["input_ids"],
            attention_mask=caption_mask,
        ).last_hidden_state

        # Sample a single (1, 256, 32) array conditioned on the caption.
        # NOTE(review): 32 presumably corresponds to the model's in_channels
        # -- confirm against TransformerNetModel.
        outputs = diffusion.p_sample_loop(
            model,
            (1, 256, 32),
            clip_denoised=False,
            denoised_fn=None,
            model_kwargs={},
            top_p=1.0,
            progress=True,
            caption=(caption_state, caption_mask),
        )

        # as_tensor accepts both tensors and array-likes without the extra
        # copy (and warning) torch.tensor() produces for tensor input.
        logits = model.get_logits(torch.as_tensor(outputs))

        # Greedy decoding: keep the highest-scoring vocabulary id per position.
        cands = torch.topk(logits, k=1, dim=-1)
        outputs = cands.indices.squeeze(-1)
        outputs = tokenizer.decode(outputs)

        # Strip special tokens and tabs, then translate SELFIES via sf.decoder.
        selfies_string = (
            outputs[0].replace("<pad>", "").replace("</s>", "").replace("\t", "")
        )
        result = sf.decoder(selfies_string).replace("\t", "")

    st.write(result)