|
import streamlit as st |
|
import os |
|
import torch |
|
|
|
import numpy as np |
|
|
|
|
|
from AtomLenz import * |
|
|
|
from Object_Smiles import Objects_Smiles |
|
|
|
|
|
from robust_detection import utils |
|
from robust_detection.models.rcnn import RCNN |
|
from robust_detection.data_utils.rcnn_data_utils import Objects_RCNN, COCO_RCNN |
|
|
|
import pytorch_lightning as pl |
|
from pytorch_lightning.loggers import WandbLogger |
|
from pytorch_lightning.loggers import CSVLogger |
|
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint |
|
from pytorch_lightning.callbacks.early_stopping import EarlyStopping |
|
from pytorch_lightning.callbacks import LearningRateMonitor |
|
from rdkit import Chem |
|
from rdkit.Chem import AllChem |
|
from rdkit import DataStructs |
|
from PIL import Image |
|
import matplotlib.pyplot as plt |
|
|
|
def main_page(top_n, model_path):
    """Render the placeholder "About AtomLenz" page.

    Args:
        top_n: Accepted for interface compatibility; not read by the body.
        model_path: Accepted for interface compatibility; not read by the body.
    """
    about_text = """test """
    st.markdown(about_text)
|
|
|
|
|
|
|
# Sidebar page registry: display name -> function that renders the page.
page_names_to_funcs = {"About AtomLenz": main_page}

# Display name of each selectable model -> weight file name.
# The "ChemExpert" entry currently points at the default weights.
model_dict = {
    "AtomLenz trained on synthetic data (default)": "atomlenz_default.pt",
    "AtomLenz for hand-drawn images": "atomlenz_handdrawn.pt",
    "ChemExpert (not available yet)": "atomlenz_default.pt",
}

selected_page = st.sidebar.selectbox(
    "What would you like to retrieve?",
    page_names_to_funcs.keys(),
)
st.sidebar.markdown('')

# The options are the keys of model_dict, in insertion order.
selected_model = st.sidebar.selectbox(
    "Select a AtomLenz model to load",
    tuple(model_dict),
)

model_file = model_dict[selected_model]
# NOTE(review): `datapath` is not defined in the visible part of this file;
# presumably it comes from `from AtomLenz import *` - confirm.
model_path = os.path.join(datapath, model_file)

# Weight files whose name ends in "320).pt" use the 320 resolution; all others 520.
image_resolution = 320 if model_path.endswith("320).pt") else 520
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Line colour per predicted class label: a 10-colour palette repeated twice,
# so labels 0-19 can be used as an index.
colors = [
    "magenta", "green", "blue", "red", "orange",
    "magenta", "peru", "azure", "slateblue", "plum",
] * 2


def plot_bbox(bbox_XYXY, label):
    """Outline one bounding box on the current pyplot axes.

    Args:
        bbox_XYXY: Four values unpacked as (xmin, ymin, xmax, ymax).
        label: Class label; used as an index into ``colors`` and, via
            ``str(label)``, as the line's legend label.
    """
    x0, y0, x1, y1 = bbox_XYXY
    # Closed rectangle: visit the four corners and return to the first one.
    xs = [x0, x0, x1, x1, x0]
    ys = [y0, y1, y1, y0, y0]
    plt.plot(xs, ys, color=colors[label], label=str(label))
|
|
|
model_cls = RCNN


@st.cache_resource
def _load_latest_model(experiment_path, score_thresh=0.65):
    """Load the newest checkpoint found in ``experiment_path``.

    The directory is listed, entries whose joined path contains "ckpt" are
    kept, and the one with the greatest ``os.path.getctime`` is loaded with
    ``model_cls.load_from_checkpoint``.

    The result is cached with ``st.cache_resource`` so the checkpoint is not
    reloaded on every Streamlit rerun. As a consequence, a checkpoint added to
    the directory later is only picked up after the cache is cleared or the
    app restarts.

    Args:
        experiment_path: Directory to search for checkpoint files.
        score_thresh: Value assigned to ``model.model.roi_heads.score_thresh``.

    Returns:
        The loaded model with its score threshold set.

    Raises:
        FileNotFoundError: If no entry containing "ckpt" exists in the
            directory (previously this surfaced as a bare IndexError).
    """
    candidates = [
        os.path.join(experiment_path, entry)
        for entry in os.listdir(experiment_path)
    ]
    checkpoints = [path for path in candidates if "ckpt" in path]
    if not checkpoints:
        raise FileNotFoundError(
            f"No checkpoint ('ckpt') file found in {experiment_path!r}"
        )
    checkpoint_file = max(checkpoints, key=os.path.getctime)
    model = model_cls.load_from_checkpoint(checkpoint_file)
    model.model.roi_heads.score_thresh = score_thresh
    return model


# One detector per entity type, each loaded from its own experiment directory.
experiment_path_atoms = "./models/atoms_model/"
model_atom = _load_latest_model(experiment_path_atoms)

experiment_path_bonds = "./models/bonds_model/"
model_bond = _load_latest_model(experiment_path_bonds)

experiment_path_stereo = "./models/stereos_model/"
model_stereo = _load_latest_model(experiment_path_stereo)

experiment_path_charges = "./models/charges_model/"
model_charge = _load_latest_model(experiment_path_charges)
|
|
|
data_cls = Objects_Smiles
dataset = data_cls(data_path="./uploads/", batch_size=1)

st.title("Atom Level Entity Detector")

image_file = st.file_uploader("Upload a chemical structure candidate image", type=['png'])


def _show_detections(column, image, boxes, labels):
    """Draw ``boxes`` over ``image`` and display the rendering in ``column``.

    The figure is rendered through pyplot, saved to "example_image.png" and
    re-opened with PIL before being handed to Streamlit. The pyplot figure is
    cleared before and after drawing so that artists from an earlier rendering
    (pyplot keeps its current figure in process-global state) cannot show up
    in this one or the next.

    Args:
        column: Streamlit container whose ``image`` method receives the result.
        image: Image passed to ``plt.imshow`` as the background.
        boxes: Iterable of boxes, each accepted by ``plot_bbox``.
        labels: Iterable of class labels, parallel to ``boxes``.
    """
    plt.clf()
    plt.imshow(image, cmap="gray")
    for bbox, label in zip(boxes, labels):
        plot_bbox(bbox, label)
    plt.axis('off')
    plt.savefig("example_image.png", bbox_inches='tight', pad_inches=0)
    column.image(Image.open("example_image.png"), use_column_width=True)
    plt.clf()


# NOTE(review): indentation was lost in the reviewed copy of this file; the
# nesting below is reconstructed from data flow - confirm against the original.
if image_file is not None:
    image = Image.open(image_file)
    st.image(image, use_column_width=True)
    col1, col2 = st.columns(2)

    # Store the upload as uploads/images/0.png; `dataset` is configured with
    # data_path="./uploads/".
    os.makedirs("uploads/images", exist_ok=True)
    with open(os.path.join("uploads/images/", "0.png"), "wb") as f:
        f.write(image_file.getbuffer())

    dataset.prepare_data()
    trainer = pl.Trainer(logger=False)
    st.toast('Predicting atoms,bonds,charges,..., please wait')
    atom_preds = trainer.predict(model_atom, dataset.test_dataloader())
    bond_preds = trainer.predict(model_bond, dataset.test_dataloader())
    stereo_preds = trainer.predict(model_stereo, dataset.test_dataloader())
    charges_preds = trainer.predict(model_charge, dataset.test_dataloader())
    st.toast('Done')

    # Left column: atom detections; right column: bond detections.
    _show_detections(col1, image, atom_preds[0]['boxes'][0], atom_preds[0]['preds'][0])
    _show_detections(col2, image, bond_preds[0]['boxes'][0], bond_preds[0]['preds'][0])

    mol_graphs = []
    predictions = 0
    predictions_list = []

    for image_idx, bonds in enumerate(bond_preds):
        # Per-image tally of bond labels; it is updated below but not read in
        # the visible part of the file.
        count_bonds_preds = np.zeros(8)

        atom_boxes = atom_preds[image_idx]['boxes'][0]
        atom_labels = atom_preds[image_idx]['preds'][0]
        atom_scores = atom_preds[image_idx]['scores'][0]
        charge_boxes = charges_preds[image_idx]['boxes'][0]
        charge_labels = charges_preds[image_idx]['preds'][0]

        # Keep only charge detections whose label is greater than 1
        # (presumably labels <= 1 mean "no charge" - confirm).
        charge_mask = torch.where(charge_labels > 1)
        filtered_ch_labels = charge_labels[charge_mask]
        filtered_ch_boxes = charge_boxes[charge_mask]

        filtered_bboxes, filtered_labels = iou_filter_bboxes(atom_boxes, atom_labels, atom_scores)

        n_atoms = len(filtered_bboxes)
        mol_graph = np.zeros((n_atoms, n_atoms))
        stereo_atoms = np.zeros(n_atoms)
        charge_atoms = np.ones(n_atoms)

        # An atom takes the label of a charge box that intersects it; when
        # several intersect, the last one in iteration order wins.
        for index, box_atom in enumerate(filtered_bboxes):
            for box_charge, label_charge in zip(filtered_ch_boxes, filtered_ch_labels):
                if bb_box_intersects(box_atom, box_charge) == 1:
                    charge_atoms[index] = label_charge

        for bond_box, label_bond in zip(bonds['boxes'][0], bonds['preds'][0]):
            # NOTE(review): everything below is assumed to apply only to bonds
            # with label > 1 - confirm against the original nesting.
            if not label_bond > 1:
                continue

            try:
                count_bonds_preds[label_bond] += 1
            except IndexError:
                # Labels outside the tally's range are simply not counted.
                pass

            # Grow the bond box by 5 per side each round (margins 0, 5, ..., 75)
            # until it intersects at least two atom boxes.
            result = []
            limit = 0
            while result.count(1) < 2 and limit < 80:
                bigger_bond_box = [
                    bond_box[0] - limit,
                    bond_box[1] - limit,
                    bond_box[2] + limit,
                    bond_box[3] + limit,
                ]
                result = [
                    bb_box_intersects(atom_box, bigger_bond_box)
                    for atom_box in filtered_bboxes
                ]
                limit += 5

            indices = [i for i, hit in enumerate(result) if hit == 1]
            if len(indices) == 2:
                mol_graph[indices[0], indices[1]] = label_bond
                mol_graph[indices[1], indices[0]] = label_bond
            elif len(indices) > 2:
                # More than two candidates: dist_filter_bboxes chooses the
                # pair (returned as positions within cand_bboxes).
                cand_bboxes = filtered_bboxes[indices, :]
                cand_indices = dist_filter_bboxes(cand_bboxes)
                first = indices[cand_indices[0]]
                second = indices[cand_indices[1]]
                mol_graph[first, second] = label_bond
                mol_graph[second, first] = label_bond

        # Stereo detections are consulted only when some bond label exceeds 4
        # (presumably the stereo bond classes - confirm).
        if np.any(mol_graph > 4):
            stereo_boxes = stereo_preds[image_idx]['boxes'][0]
            for stereo_box in stereo_boxes:
                hits = [
                    i
                    for i, atom_box in enumerate(filtered_bboxes)
                    if bb_box_intersects(atom_box, stereo_box) == 1
                ]
                # Mark the atom only when the stereo box hits exactly one atom.
                if len(hits) == 1:
                    stereo_atoms[hits[0]] = 1

        molecule = {
            'graph': mol_graph,
            'atom_labels': filtered_labels,
            'atom_boxes': filtered_bboxes,
            'stereo_atoms': stereo_atoms,
            'charge_atoms': charge_atoms,
        }
        mol_graphs.append(molecule)

        save_mol_to_file(molecule, 'molfile')
        mol = Chem.MolFromMolFile('molfile', sanitize=False)
        problematic = 0
        # `except Exception` replaces the bare excepts so KeyboardInterrupt and
        # SystemExit are no longer swallowed; the best-effort behaviour for
        # chemistry errors is kept.
        try:
            problems = Chem.DetectChemistryProblems(mol)
            if len(problems) > 0:
                mol = solve_mol_problems(mol, problems)
                problematic = 1
            try:
                Chem.SanitizeMol(mol)
            except Exception:
                # Second repair attempt after a failed sanitization.
                problems = Chem.DetectChemistryProblems(mol)
                if len(problems) > 0:
                    mol = solve_mol_problems(mol, problems)
                # NOTE(review): the retry is placed at this level; it may
                # originally have been nested under the `if` above - confirm.
                try:
                    Chem.SanitizeMol(mol)
                except Exception:
                    pass
        except Exception:
            problematic = 1

        try:
            pred_smiles = Chem.MolToSmiles(mol)
        except Exception:
            pred_smiles = ""
            problematic = 1

        predictions += 1
        predictions_list.append([image_idx, pred_smiles, problematic])

    # NOTE(review): the original opened 'preds_atomlenz' for writing but never
    # wrote to it or closed it. The create/truncate side effect is kept and the
    # handle is now closed; confirm whether predictions were meant to be
    # written here.
    with open('preds_atomlenz', 'w') as file_preds:
        pass
    for pred in predictions_list:
        print(pred)
|
|
|
|
|
|