File size: 4,834 Bytes
7a223e7 6adb0d1 9acb06c 6adb0d1 d0f68bc 6adb0d1 67828bb 648d955 4ea7e69 648d955 7a223e7 ebf43f7 09fa344 d0f68bc 7124c40 d0f68bc 67828bb d337d91 03c79df c4f48f9 67828bb 03c79df c4f48f9 685777c d337d91 d0f68bc 03c79df d0f68bc 09fa344 c4f48f9 648d955 c4521a7 c4f48f9 648d955 c4f48f9 ff8787b c4f48f9 09fa344 692229b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import streamlit as st
import os
import torch
#import math
import numpy as np
#import matplotlib.pyplot as plt
#import pathlib
#from AtomLenz import *
#from utils_graph import *
from Object_Smiles import Objects_Smiles
#from robust_detection import wandb_config
from robust_detection import utils
from robust_detection.models.rcnn import RCNN
from robust_detection.data_utils.rcnn_data_utils import Objects_RCNN, COCO_RCNN
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import LearningRateMonitor
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import DataStructs
from PIL import Image
import matplotlib.pyplot as plt
# Outline colours looked up by class label in plot_bbox. The same ten-colour
# sequence is laid out twice, giving twenty addressable entries.
colors = [
    "magenta",
    "green",
    "blue",
    "red",
    "orange",
    "magenta",
    "peru",
    "azure",
    "slateblue",
    "plum",
] * 2
def plot_bbox(bbox_XYXY, label):
    """Draw one bounding-box outline on the current matplotlib axes.

    Args:
        bbox_XYXY: Iterable of four values, unpacked as
            ``(xmin, ymin, xmax, ymax)``.
        label: Class label of the box. It selects the outline colour from
            ``colors`` and, converted with ``str``, becomes the legend label
            of the plotted line.
    """
    xmin, ymin, xmax, ymax = bbox_XYXY
    # Wrap around the palette so that a label at or beyond its length reuses
    # a colour instead of raising IndexError.
    outline_color = colors[int(label) % len(colors)]
    plt.plot(
        [xmin, xmin, xmax, xmax, xmin],
        [ymin, ymax, ymax, ymin, ymin],
        color=outline_color,
        label=str(label))
model_cls = RCNN

# Score threshold written to ``roi_heads.score_thresh`` of every loaded model.
SCORE_THRESHOLD = 0.65


def _latest_checkpoint(experiment_path):
    """Return the path of the most recently created checkpoint in a directory.

    Args:
        experiment_path: Directory that is searched (non-recursively) for
            entries whose name contains ``"ckpt"``.

    Returns:
        The joined path of the matching entry with the largest ctime.

    Raises:
        FileNotFoundError: If the directory holds no entry containing
            ``"ckpt"`` (or the directory itself is missing).
    """
    candidates = [
        os.path.join(experiment_path, entry)
        for entry in os.listdir(experiment_path)
        if "ckpt" in entry
    ]
    if not candidates:
        # Fail with a message naming the directory rather than a bare
        # IndexError from indexing an empty list.
        raise FileNotFoundError(
            f"No checkpoint file found in {experiment_path!r}")
    return max(candidates, key=os.path.getctime)


def _load_detector(checkpoint_file):
    """Load a model from ``checkpoint_file`` and apply ``SCORE_THRESHOLD``."""
    detector = model_cls.load_from_checkpoint(checkpoint_file)
    detector.model.roi_heads.score_thresh = SCORE_THRESHOLD
    return detector


experiment_path_atoms = "./models/atoms_model/"
checkpoint_file_atoms = _latest_checkpoint(experiment_path_atoms)
model_atom = _load_detector(checkpoint_file_atoms)

experiment_path_bonds = "./models/bonds_model/"
checkpoint_file_bonds = _latest_checkpoint(experiment_path_bonds)
model_bond = _load_detector(checkpoint_file_bonds)

experiment_path_stereo = "./models/stereos_model/"
checkpoint_file_stereo = _latest_checkpoint(experiment_path_stereo)
model_stereo = _load_detector(checkpoint_file_stereo)

experiment_path_charges = "./models/charges_model/"
checkpoint_file_charges = _latest_checkpoint(experiment_path_charges)
model_charge = _load_detector(checkpoint_file_charges)
data_cls = Objects_Smiles
dataset = data_cls(data_path="./uploads/", batch_size=1)


def _show_detections(image, predictions, column):
    """Overlay predicted boxes on ``image`` and display the result.

    Args:
        image: The uploaded picture, drawn as the background.
        predictions: Output of ``trainer.predict``; only the first element is
            used, reading its ``'boxes'`` and ``'preds'`` entries at index 0.
        column: Streamlit container whose ``image`` method shows the overlay.
    """
    import io  # local import: only this helper needs an in-memory buffer

    # Use a dedicated figure for every rendering. The implicit pyplot figure
    # is module-level state, so reusing it can leave boxes from a previous
    # rendering (or a previous script rerun) in the picture.
    fig, ax = plt.subplots()
    try:
        ax.imshow(image, cmap="gray")
        first_batch = predictions[0]
        for bbox, label in zip(first_batch['boxes'][0], first_batch['preds'][0]):
            # plot_bbox draws on the current axes, which is ``ax`` here
            # because plt.subplots() makes the new figure current.
            plot_bbox(bbox, label)
        ax.axis('off')
        # Render to memory instead of a shared "example_image.png" on disk,
        # so concurrent sessions cannot overwrite each other's picture.
        rendered = io.BytesIO()
        fig.savefig(rendered, format="png", bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure; pyplot keeps figures alive until closed.
        plt.close(fig)
    rendered.seek(0)
    column.image(Image.open(rendered), use_column_width=True)


st.title("Atom Level Entity Detector")
image_file = st.file_uploader("Upload a chemical structure candidate image",type=['png'])
if image_file is not None:
    image = Image.open(image_file)
    st.image(image, use_column_width=True)
    col1, col2 = st.columns(2)

    # Store the upload under the dataset's data path as "0.png".
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs("uploads/images", exist_ok=True)
    with open(os.path.join("uploads/images/","0.png"),"wb") as f:
        f.write(image_file.getbuffer())
    st.success("Saved File")

    dataset.prepare_data()
    trainer = pl.Trainer(logger=False)
    atom_preds = trainer.predict(model_atom, dataset.test_dataloader())
    bond_preds = trainer.predict(model_bond, dataset.test_dataloader())
    # NOTE(review): the stereo and charge predictions are computed but not
    # displayed anywhere in this script — confirm whether they are still needed.
    stereo_preds = trainer.predict(model_stereo, dataset.test_dataloader())
    charge_preds = trainer.predict(model_charge, dataset.test_dataloader())

    # Atoms in the left column, bonds in the right column.
    _show_detections(image, atom_preds, col1)
    _show_detections(image, bond_preds, col2)
#x = st.slider('Select a value')
#st.write(x, 'squared is', x * x)
|