# File size: 12,025 Bytes
# (File-viewer residue removed here: the per-line commit-hash column and the
# 1-301 line-number gutter were not part of the program.)
import streamlit as st
import os
import torch
#import math
import numpy as np
#import matplotlib.pyplot as plt
#import pathlib
from AtomLenz import *
#from utils_graph import *
from Object_Smiles import Objects_Smiles
#from robust_detection import wandb_config
from robust_detection import utils
from robust_detection.models.rcnn import RCNN
from robust_detection.data_utils.rcnn_data_utils import Objects_RCNN, COCO_RCNN
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import LearningRateMonitor
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import DataStructs
from PIL import Image
import matplotlib.pyplot as plt
def main_page(top_n, model_path):
    """Render the "About AtomLenz" page.

    The page currently shows placeholder text only; ``top_n`` and
    ``model_path`` are accepted but not read.
    """
    about_text = """test """
    st.markdown(about_text)
#### TRYOUT MENU #####
# Sidebar page registry: label shown to the user -> function rendering that page.
# NOTE: the selected page is only read from the widget; nothing dispatches to
# it yet (the dispatching call is still disabled).
page_names_to_funcs = {
    "About AtomLenz": main_page,
}
selected_page = st.sidebar.selectbox("What would you like to retrieve?", page_names_to_funcs.keys())
st.sidebar.markdown('')

# Sidebar model choice: label shown to the user -> weight file name.
# The "ChemExpert" entry maps to the default weights.
model_dict = {
    "AtomLenz trained on synthetic data (default)": "atomlenz_default.pt",
    "AtomLenz for hand-drawn images": "atomlenz_handdrawn.pt",
    "ChemExpert (not available yet)": "atomlenz_default.pt",
}
# The options are the dictionary keys, in insertion order.
selected_model = st.sidebar.selectbox(
    "Select a AtomLenz model to load",
    tuple(model_dict),
)
model_file = model_dict[selected_model]
# NOTE(review): `datapath` is not defined anywhere in this file; presumably it
# is provided by `from AtomLenz import *` -- confirm, otherwise this raises
# NameError.
model_path = os.path.join(datapath, model_file)
# Weight files whose name ends in "320).pt" use the 320 resolution, all others 520.
image_resolution = 320 if model_path.endswith("320).pt") else 520
######################
# Colour palette indexed by predicted class label: ten colour names listed
# twice, i.e. 20 entries.
colors = ["magenta", "green", "blue", "red", "orange", "magenta", "peru", "azure", "slateblue", "plum","magenta", "green", "blue", "red", "orange", "magenta", "peru", "azure", "slateblue", "plum"]


def plot_bbox(bbox_XYXY, label):
    """Outline one bounding box with ``plt.plot``.

    Args:
        bbox_XYXY: Four values unpacked as ``(xmin, ymin, xmax, ymax)``.
        label: Integer-like class label. It selects the line colour from
            ``colors`` and, as ``str(label)``, becomes the line's legend label.
    """
    xmin, ymin, xmax, ymax = bbox_XYXY
    # Wrap around the palette so a label >= len(colors) reuses a colour
    # instead of raising IndexError in the middle of rendering.
    color = colors[int(label) % len(colors)]
    plt.plot(
        [xmin, xmin, xmax, xmax, xmin],
        [ymin, ymax, ymax, ymin, ymin],
        color=color,
        label=str(label),
    )
# Detector class used for all four sub-models (atoms, bonds, stereo, charges).
model_cls = RCNN

# Value written to `roi_heads.score_thresh` of every loaded detector.
_SCORE_THRESHOLD = 0.65


def _latest_checkpoint(experiment_path):
    """Return the checkpoint in ``experiment_path`` with the largest ctime.

    An entry counts as a checkpoint when its joined path contains "ckpt".

    Raises:
        FileNotFoundError: If the directory contains no such entry (the
            previous code failed here with a bare IndexError).
    """
    candidates = [
        path
        for path in (
            os.path.join(experiment_path, entry)
            for entry in os.listdir(experiment_path)
        )
        if "ckpt" in path
    ]
    if not candidates:
        raise FileNotFoundError(
            f"No checkpoint ('ckpt') file found in {experiment_path!r}"
        )
    return max(candidates, key=os.path.getctime)


def _load_detector(experiment_path):
    """Load the newest checkpoint of one sub-model and set its score threshold.

    Returns:
        Tuple ``(checkpoint_file, model)``.
    """
    checkpoint_file = _latest_checkpoint(experiment_path)
    model = model_cls.load_from_checkpoint(checkpoint_file)
    model.model.roi_heads.score_thresh = _SCORE_THRESHOLD
    return checkpoint_file, model


# NOTE(review): Streamlit re-executes this script on every interaction, so all
# four checkpoints are reloaded each time; consider caching the loaded models.
experiment_path_atoms = "./models/atoms_model/"
checkpoint_file_atoms, model_atom = _load_detector(experiment_path_atoms)

experiment_path_bonds = "./models/bonds_model/"
checkpoint_file_bonds, model_bond = _load_detector(experiment_path_bonds)

experiment_path_stereo = "./models/stereos_model/"
checkpoint_file_stereo, model_stereo = _load_detector(experiment_path_stereo)

experiment_path_charges = "./models/charges_model/"
checkpoint_file_charges, model_charge = _load_detector(experiment_path_charges)

# Data module that serves the uploaded image(s) found under ./uploads/.
data_cls = Objects_Smiles
dataset = data_cls(data_path="./uploads/", batch_size=1)
# A bond box is grown by a margin of 0, 5, 10, ... (always below
# _MAX_BOND_MARGIN) until it intersects at least two atom boxes.
# TODO: both values should be made dependent on the mean size of atom_boxes.
_BOND_MARGIN_STEP = 5
_MAX_BOND_MARGIN = 80


def _show_detections(image, boxes, labels, column):
    """Draw the labelled boxes over ``image`` and display the result in ``column``.

    The rendering goes through the file "example_image.png". The pyplot figure
    is always cleared afterwards: pyplot state is global to the process, so
    boxes left on the figure would show up again in the next rendering.
    """
    try:
        plt.imshow(image, cmap="gray")
        for bbox, label in zip(boxes, labels):
            plot_bbox(bbox, label)
        plt.axis('off')
        plt.savefig("example_image.png", bbox_inches='tight', pad_inches=0)
        with Image.open("example_image.png") as image_vis:
            column.image(image_vis, use_column_width=True)
    finally:
        plt.clf()


def _intersecting_atoms(box, atom_boxes):
    """Return the indices of the atom boxes for which ``bb_box_intersects`` is 1."""
    return [
        idx
        for idx, atom_box in enumerate(atom_boxes)
        if bb_box_intersects(atom_box, box) == 1
    ]


def _bond_endpoint_candidates(bond_box, atom_boxes):
    """Return indices of atom boxes hit by ``bond_box``, growing it if needed.

    The box is enlarged step by step and the search stops as soon as two or
    more atoms are hit; the hits of the last attempt are returned, which may
    be fewer than two.
    """
    candidates = []
    for margin in range(0, _MAX_BOND_MARGIN, _BOND_MARGIN_STEP):
        grown_box = [
            bond_box[0] - margin,
            bond_box[1] - margin,
            bond_box[2] + margin,
            bond_box[3] + margin,
        ]
        candidates = _intersecting_atoms(grown_box, atom_boxes)
        if len(candidates) >= 2:
            break
    return candidates


def _build_molecule(image_idx, bonds, atom_preds, charges_preds, stereo_preds):
    """Combine the detections of one image into a molecule description.

    Args:
        image_idx: Index of the image inside the prediction lists.
        bonds: Bond prediction of that image (``bonds['boxes'][0]`` and
            ``bonds['preds'][0]`` are read).
        atom_preds, charges_preds, stereo_preds: Prediction lists of the other
            detectors, indexed by ``image_idx``.

    Returns:
        dict with keys 'graph' (square matrix holding the bond label between
        two atoms), 'atom_labels', 'atom_boxes', 'stereo_atoms' and
        'charge_atoms'.
    """
    atom_boxes = atom_preds[image_idx]['boxes'][0]
    atom_labels = atom_preds[image_idx]['preds'][0]
    atom_scores = atom_preds[image_idx]['scores'][0]
    charge_boxes = charges_preds[image_idx]['boxes'][0]
    charge_labels = charges_preds[image_idx]['preds'][0]

    # Only charge detections with a label above 1 are used
    # (presumably labels <= 1 mean "no charge" -- confirm).
    charge_mask = torch.where(charge_labels > 1)
    filtered_ch_labels = charge_labels[charge_mask]
    filtered_ch_boxes = charge_boxes[charge_mask]

    filtered_bboxes, filtered_labels = iou_filter_bboxes(atom_boxes, atom_labels, atom_scores)

    n_atoms = len(filtered_bboxes)
    mol_graph = np.zeros((n_atoms, n_atoms))
    stereo_atoms = np.zeros(n_atoms)
    charge_atoms = np.ones(n_atoms)  # default charge label is 1

    # An atom takes the label of the last charge box that intersects it.
    for index, box_atom in enumerate(filtered_bboxes):
        for box_charge, label_charge in zip(filtered_ch_boxes, filtered_ch_labels):
            if bb_box_intersects(box_atom, box_charge) == 1:
                charge_atoms[index] = label_charge

    for bond_box, label_bond in zip(bonds['boxes'][0], bonds['preds'][0]):
        if label_bond > 1:
            indices = _bond_endpoint_candidates(bond_box, filtered_bboxes)
            if len(indices) == 2:
                first, second = indices
            elif len(indices) > 2:
                # More than two candidate atoms for one bond: let
                # dist_filter_bboxes pick the pair.
                cand_bboxes = filtered_bboxes[indices, :]
                cand_indices = dist_filter_bboxes(cand_bboxes)
                first = indices[cand_indices[0]]
                second = indices[cand_indices[1]]
            else:
                # Fewer than two atoms found: the bond is dropped.
                continue
            mol_graph[first, second] = label_bond
            mol_graph[second, first] = label_bond

    # Stereo detections are only consulted when some bond label exceeds 4
    # (presumably the stereo bond classes -- confirm).
    if np.any(mol_graph > 4):
        for stereo_box in stereo_preds[image_idx]['boxes'][0]:
            indices = _intersecting_atoms(stereo_box, filtered_bboxes)
            # A stereo box marks an atom only when it hits exactly one.
            if len(indices) == 1:
                stereo_atoms[indices[0]] = 1

    return {
        'graph': mol_graph,
        'atom_labels': filtered_labels,
        'atom_boxes': filtered_bboxes,
        'stereo_atoms': stereo_atoms,
        'charge_atoms': charge_atoms,
    }


def _molecule_to_smiles(molecule):
    """Convert a molecule description to SMILES via the mol file 'molfile'.

    Returns:
        Tuple ``(pred_smiles, problematic)``. ``problematic`` is 1 when
        chemistry problems were detected and repaired, or when detection or
        the SMILES conversion failed; ``pred_smiles`` is "" when the
        conversion failed.
    """
    save_mol_to_file(molecule, 'molfile')
    mol = Chem.MolFromMolFile('molfile', sanitize=False)
    problematic = 0
    # `except Exception` instead of a bare `except` so that KeyboardInterrupt
    # and other BaseException subclasses are no longer swallowed.
    try:
        problems = Chem.DetectChemistryProblems(mol)
        if len(problems) > 0:
            mol = solve_mol_problems(mol, problems)
            problematic = 1
        try:
            Chem.SanitizeMol(mol)
        except Exception:
            # Second repair attempt after a failed sanitization.
            problems = Chem.DetectChemistryProblems(mol)
            if len(problems) > 0:
                mol = solve_mol_problems(mol, problems)
            try:
                Chem.SanitizeMol(mol)
            except Exception:
                # Deliberate best effort: continue with the unsanitized molecule.
                pass
    except Exception:
        problematic = 1
    try:
        pred_smiles = Chem.MolToSmiles(mol)
    except Exception:
        pred_smiles = ""
        problematic = 1
    return pred_smiles, problematic


st.title("Atom Level Entity Detector")
image_file = st.file_uploader("Upload a chemical structure candidate image",type=['png'])
if image_file is not None:
    image = Image.open(image_file)
    st.image(image, use_column_width=True)
    col1, col2 = st.columns(2)

    # Store the upload where the data module reads its images from.
    os.makedirs("uploads/images", exist_ok=True)
    with open(os.path.join("uploads/images/","0.png"),"wb") as f:
        f.write(image_file.getbuffer())
    dataset.prepare_data()

    trainer = pl.Trainer(logger=False)
    st.toast('Predicting atoms,bonds,charges,..., please wait')
    atom_preds = trainer.predict(model_atom, dataset.test_dataloader())
    bond_preds = trainer.predict(model_bond, dataset.test_dataloader())
    stereo_preds = trainer.predict(model_stereo, dataset.test_dataloader())
    charges_preds = trainer.predict(model_charge, dataset.test_dataloader())
    st.toast('Done')

    # Left column: atom detections; right column: bond detections.
    _show_detections(image, atom_preds[0]['boxes'][0], atom_preds[0]['preds'][0], col1)
    _show_detections(image, bond_preds[0]['boxes'][0], bond_preds[0]['preds'][0], col2)

    mol_graphs = []
    predictions_list = []
    for image_idx, bonds in enumerate(bond_preds):
        molecule = _build_molecule(image_idx, bonds, atom_preds, charges_preds, stereo_preds)
        mol_graphs.append(molecule)
        pred_smiles, problematic = _molecule_to_smiles(molecule)
        predictions_list.append([image_idx, pred_smiles, problematic])

    # The file used to be opened for writing but was never written to nor
    # closed; each prediction is now stored in it as well as printed.
    with open('preds_atomlenz','w') as file_preds:
        for pred in predictions_list:
            print(pred)
            print(pred, file=file_preds)
|