# atomlenz / app.py — author: moldenhof — commit 5417a1b ("test with menu"), 12 kB
import streamlit as st
import os
import torch
#import math
import numpy as np
#import matplotlib.pyplot as plt
#import pathlib
from AtomLenz import *
#from utils_graph import *
from Object_Smiles import Objects_Smiles
#from robust_detection import wandb_config
from robust_detection import utils
from robust_detection.models.rcnn import RCNN
from robust_detection.data_utils.rcnn_data_utils import Objects_RCNN, COCO_RCNN
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import LearningRateMonitor
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import DataStructs
from PIL import Image
import matplotlib.pyplot as plt
def main_page(top_n, model_path):
    """Render the "About AtomLenz" page.

    The page currently shows placeholder text only. ``top_n`` and ``model_path``
    are accepted but not used; presumably they exist so every page shares one
    call signature (see the disabled dispatch call in the sidebar menu).
    """
    placeholder_text = "test "
    st.markdown(placeholder_text)
# ---------------------------------------------------------------------------
# Sidebar menu (work in progress)
# ---------------------------------------------------------------------------
# Pages the user can pick from. Two retrieval pages are planned but not wired up:
#   "Microscopy images from a molecule": images_from_molecule
#   "Molecules from a microscopy image": molecules_from_image
page_names_to_funcs = {"About AtomLenz": main_page}
selected_page = st.sidebar.selectbox(
    "What would you like to retrieve?",
    page_names_to_funcs.keys(),
)
st.sidebar.markdown('')

# Sidebar label -> checkpoint file name. The selectbox options are taken from
# this mapping (insertion order) so labels and files cannot drift apart.
# NOTE(review): the ChemExpert entry falls back to the default weights.
model_dict = {
    "AtomLenz trained on synthetic data (default)": "atomlenz_default.pt",
    "AtomLenz for hand-drawn images": "atomlenz_handdrawn.pt",
    "ChemExpert (not available yet)": "atomlenz_default.pt",
}
selected_model = st.sidebar.selectbox(
    "Select a AtomLenz model to load",
    tuple(model_dict),
)

model_file = model_dict[selected_model]
# NOTE(review): `datapath` is not defined in this file; presumably it is
# provided by `from AtomLenz import *` — TODO confirm.
model_path = os.path.join(datapath, model_file)
# NOTE(review): no file in model_dict ends with "320).pt" (the ")" looks like a
# typo), so image_resolution is currently always 520; nothing in the visible
# code reads it yet.
image_resolution = 320 if model_path.endswith("320).pt") else 520

# Page dispatch is disabled for now:
# page_names_to_funcs[selected_page](n_objects, model_path)
# Matplotlib colour per class label (a 10-colour palette listed twice).
colors = ["magenta", "green", "blue", "red", "orange", "magenta", "peru", "azure", "slateblue", "plum","magenta", "green", "blue", "red", "orange", "magenta", "peru", "azure", "slateblue", "plum"]


def bbox_color(label):
    """Return the plotting colour for class *label*, cycling through ``colors``.

    *label* may be a plain int or anything ``int()`` accepts (e.g. an element of
    a prediction tensor). Labels past the end of the palette wrap around instead
    of raising ``IndexError``; labels 0..19 and negative labels map exactly as
    plain ``colors[label]`` indexing would.
    """
    return colors[int(label) % len(colors)]


def plot_bbox(bbox_XYXY, label):
    """Draw one bounding-box outline on the current matplotlib axes.

    Args:
        bbox_XYXY: box corners as ``(xmin, ymin, xmax, ymax)``.
        label: integer-like class label; selects the outline colour and is used
            as the line's legend label.
    """
    xmin, ymin, xmax, ymax = bbox_XYXY
    plt.plot(
        [xmin, xmin, xmax, xmax, xmin],
        [ymin, ymax, ymax, ymin, ymin],
        color=bbox_color(label),
        label=str(label),
    )
model_cls = RCNN
# ROI-head score threshold applied to every detector (presumably torchvision's
# RoIHeads, which discards detections scoring below it — TODO confirm).
SCORE_THRESH = 0.65


def _latest_checkpoint(experiment_path):
    """Return the path of the newest checkpoint file in *experiment_path*.

    A checkpoint is any directory entry whose name contains "ckpt"; the newest
    is the one with the largest ``os.path.getctime``. On ties the first entry in
    ``os.listdir`` order wins.

    Raises:
        FileNotFoundError: if the directory contains no checkpoint file.
    """
    candidates = [
        os.path.join(experiment_path, name)
        for name in os.listdir(experiment_path)
        if "ckpt" in name
    ]
    if not candidates:
        raise FileNotFoundError(
            f"no checkpoint file (name containing 'ckpt') in {experiment_path!r}"
        )
    return max(candidates, key=os.path.getctime)


def _load_detector(checkpoint_file, score_thresh=SCORE_THRESH):
    """Load a ``model_cls`` detector from *checkpoint_file* and set its score threshold."""
    model = model_cls.load_from_checkpoint(checkpoint_file)
    model.model.roi_heads.score_thresh = score_thresh
    return model


experiment_path_atoms = "./models/atoms_model/"
checkpoint_file_atoms = _latest_checkpoint(experiment_path_atoms)
model_atom = _load_detector(checkpoint_file_atoms)

experiment_path_bonds = "./models/bonds_model/"
checkpoint_file_bonds = _latest_checkpoint(experiment_path_bonds)
model_bond = _load_detector(checkpoint_file_bonds)

experiment_path_stereo = "./models/stereos_model/"
checkpoint_file_stereo = _latest_checkpoint(experiment_path_stereo)
model_stereo = _load_detector(checkpoint_file_stereo)

experiment_path_charges = "./models/charges_model/"
checkpoint_file_charges = _latest_checkpoint(experiment_path_charges)
model_charge = _load_detector(checkpoint_file_charges)
# Inference dataset rooted at ./uploads/. The uploaded image is written under
# uploads/images/ further down, and prepare_data() is called only after that.
data_cls = Objects_Smiles
dataset = data_cls(batch_size=1, data_path="./uploads/")

st.title("Atom Level Entity Detector")
image_file = st.file_uploader(
    "Upload a chemical structure candidate image",
    type=['png'],
)
def _render_detections(image, preds, column):
    """Overlay the first image's predicted boxes on *image* and show it in *column*.

    A fresh figure is used and closed afterwards, so boxes from one overlay (or
    from a previous Streamlit rerun) never leak into the next one. The PNG is
    rendered in memory rather than through a shared ``example_image.png`` that
    concurrent sessions would overwrite.
    """
    import io

    fig = plt.figure()
    try:
        plt.imshow(image, cmap="gray")
        for bbox, label in zip(preds[0]['boxes'][0], preds[0]['preds'][0]):
            plot_bbox(bbox, label)
        plt.axis('off')
        buffer = io.BytesIO()
        fig.savefig(buffer, format="png", bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)
    buffer.seek(0)
    column.image(Image.open(buffer), use_column_width=True)


def _assign_charges(atom_boxes, charge_pred):
    """Return one charge label per atom box.

    Every atom starts at label 1. Each charge detection with label > 1 overrides
    that value for every atom box it intersects; the last intersecting
    detection wins.
    """
    charge_labels = charge_pred['preds'][0]
    keep = torch.where(charge_labels > 1)
    kept_boxes = charge_pred['boxes'][0][keep]
    kept_labels = charge_labels[keep]
    charge_atoms = np.ones(len(atom_boxes))
    for index, atom_box in enumerate(atom_boxes):
        for charge_box, charge_label in zip(kept_boxes, kept_labels):
            if bb_box_intersects(atom_box, charge_box) == 1:
                charge_atoms[index] = charge_label
    return charge_atoms


def _atoms_touching_bond(atom_boxes, bond_box, max_margin=80, margin_step=5):
    """Return indices of atom boxes intersecting *bond_box*, enlarging it if needed.

    The bond box is grown by ``margin`` on every side, for margin = 0,
    margin_step, ... (below max_margin), until at least two atoms touch it. The
    hits of the last attempt are returned, which may be fewer than two.
    """
    # TODO: max_margin (80) and margin_step (5) should depend on the mean atom-box size.
    hits = []
    for margin in range(0, max_margin, margin_step):
        grown_box = [bond_box[0] - margin, bond_box[1] - margin,
                     bond_box[2] + margin, bond_box[3] + margin]
        hits = [i for i, atom_box in enumerate(atom_boxes)
                if bb_box_intersects(atom_box, grown_box) == 1]
        if len(hits) >= 2:
            break
    return hits


def _build_bond_graph(atom_boxes, bond_pred):
    """Build a symmetric atom-by-atom matrix holding predicted bond labels.

    Only bonds with label > 1 are placed. A bond links the two atoms its
    (possibly enlarged) box touches. When more than two atoms are touched,
    ``dist_filter_bboxes`` chooses the pair; bonds touching fewer than two atoms
    are dropped.
    """
    n_atoms = len(atom_boxes)
    mol_graph = np.zeros((n_atoms, n_atoms))
    for bond_box, label_bond in zip(bond_pred['boxes'][0], bond_pred['preds'][0]):
        if not label_bond > 1:
            continue
        indices = _atoms_touching_bond(atom_boxes, bond_box)
        if len(indices) == 2:
            first, second = indices
        elif len(indices) > 2:
            cand_indices = dist_filter_bboxes(atom_boxes[indices, :])
            first, second = indices[cand_indices[0]], indices[cand_indices[1]]
        else:
            continue
        mol_graph[first, second] = label_bond
        mol_graph[second, first] = label_bond
    return mol_graph


def _mark_stereo_atoms(atom_boxes, mol_graph, stereo_pred):
    """Return a 0/1 flag per atom marking atoms hit by exactly one stereo box.

    Stereo detections are consulted only if the graph holds a bond label > 4.
    """
    stereo_atoms = np.zeros(len(atom_boxes))
    if np.any(mol_graph > 4):
        for stereo_box in stereo_pred['boxes'][0]:
            hits = [i for i, atom_box in enumerate(atom_boxes)
                    if bb_box_intersects(atom_box, stereo_box) == 1]
            if len(hits) == 1:
                stereo_atoms[hits[0]] = 1
    return stereo_atoms


def _molecule_to_smiles(molecule, molfile='molfile'):
    """Convert a predicted molecule dict to SMILES via an intermediate mol file.

    Detected chemistry problems are passed to ``solve_mol_problems`` and
    sanitisation is attempted (best effort).

    Returns:
        ``(pred_smiles, problematic)``. ``problematic`` is 1 if problems were
        found or any step failed; ``pred_smiles`` is "" if SMILES could not be
        produced.
    """
    save_mol_to_file(molecule, molfile)
    mol = Chem.MolFromMolFile(molfile, sanitize=False)
    if mol is None:
        # RDKit could not parse the file; nothing to repair or serialise.
        return "", 1
    problematic = 0
    try:
        problems = Chem.DetectChemistryProblems(mol)
        if len(problems) > 0:
            mol = solve_mol_problems(mol, problems)
            problematic = 1
        try:
            Chem.SanitizeMol(mol)
        except Exception:
            problems = Chem.DetectChemistryProblems(mol)
            if len(problems) > 0:
                mol = solve_mol_problems(mol, problems)
            try:
                Chem.SanitizeMol(mol)
            except Exception:
                pass  # best effort: keep the partially repaired molecule
    except Exception:
        problematic = 1
    try:
        pred_smiles = Chem.MolToSmiles(mol)
    except Exception:
        pred_smiles = ""
        problematic = 1
    return pred_smiles, problematic


if image_file is not None:
    image = Image.open(image_file)
    st.image(image, use_column_width=True)
    col1, col2 = st.columns(2)

    # Save the upload under ./uploads/, which `dataset` was constructed with.
    os.makedirs("uploads/images", exist_ok=True)
    with open(os.path.join("uploads/images/", "0.png"), "wb") as f:
        f.write(image_file.getbuffer())
    dataset.prepare_data()

    trainer = pl.Trainer(logger=False)
    st.toast('Predicting atoms,bonds,charges,..., please wait')
    atom_preds = trainer.predict(model_atom, dataset.test_dataloader())
    bond_preds = trainer.predict(model_bond, dataset.test_dataloader())
    stereo_preds = trainer.predict(model_stereo, dataset.test_dataloader())
    charges_preds = trainer.predict(model_charge, dataset.test_dataloader())
    st.toast('Done')

    _render_detections(image, atom_preds, col1)
    _render_detections(image, bond_preds, col2)

    predictions_list = []
    for image_idx, bonds in enumerate(bond_preds):
        atom_pred = atom_preds[image_idx]
        filtered_bboxes, filtered_labels = iou_filter_bboxes(
            atom_pred['boxes'][0], atom_pred['preds'][0], atom_pred['scores'][0])
        charge_atoms = _assign_charges(filtered_bboxes, charges_preds[image_idx])
        mol_graph = _build_bond_graph(filtered_bboxes, bonds)
        stereo_atoms = _mark_stereo_atoms(filtered_bboxes, mol_graph, stereo_preds[image_idx])
        molecule = {
            'graph': mol_graph,
            'atom_labels': filtered_labels,
            'atom_boxes': filtered_bboxes,
            'stereo_atoms': stereo_atoms,
            'charge_atoms': charge_atoms,
        }
        pred_smiles, problematic = _molecule_to_smiles(molecule)
        predictions_list.append([image_idx, pred_smiles, problematic])

    # Each prediction is [image_idx, pred_smiles, problematic]; it is logged to
    # stdout and written to preds_atomlenz (the handle used to leak, and the
    # file was left empty).
    with open('preds_atomlenz', 'w') as file_preds:
        for pred in predictions_list:
            print(pred)
            print(pred, file=file_preds)