atomlenz / app.py
moldenhof's picture
implementing app
09fa344
raw
history blame
4.83 kB
import streamlit as st
import os
import torch
#import math
import numpy as np
#import matplotlib.pyplot as plt
#import pathlib
#from AtomLenz import *
#from utils_graph import *
from Object_Smiles import Objects_Smiles
#from robust_detection import wandb_config
from robust_detection import utils
from robust_detection.models.rcnn import RCNN
from robust_detection.data_utils.rcnn_data_utils import Objects_RCNN, COCO_RCNN
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import LearningRateMonitor
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import DataStructs
from PIL import Image
import matplotlib.pyplot as plt
# Box-outline colours looked up by class label in plot_bbox; the 10-colour
# palette is repeated so that label indices 0-19 all map to an entry.
colors = [
    "magenta", "green", "blue", "red", "orange",
    "magenta", "peru", "azure", "slateblue", "plum",
] * 2
def plot_bbox(bbox_XYXY, label):
    """Draw one bounding-box outline on the current matplotlib axes.

    Args:
        bbox_XYXY: the four box coordinates, unpacked as
            ``(xmin, ymin, xmax, ymax)``.
        label: class index convertible with ``int()`` (e.g. a Python int or
            a scalar element of a prediction tensor). It selects the outline
            colour from the module-level ``colors`` palette and is used as the
            line's legend label.
    """
    xmin, ymin, xmax, ymax = bbox_XYXY
    # Normalise tensor/numpy scalars so the legend label reads "3" rather
    # than "tensor(3)".
    label = int(label)
    # Wrap around the palette so labels beyond its length reuse colours
    # instead of raising IndexError (identical to plain indexing for -20..19).
    color = colors[label % len(colors)]
    plt.plot(
        [xmin, xmin, xmax, xmax, xmin],
        [ymin, ymax, ymax, ymin, ymin],
        color=color,
        label=str(label),
    )
model_cls = RCNN

# Value assigned to every detector's ``model.roi_heads.score_thresh``.
# NOTE(review): presumably the minimum detection confidence kept by the ROI
# heads — confirm against the RCNN model definition.
DETECTION_SCORE_THRESH = 0.65


def _latest_checkpoint(experiment_path):
    """Return the most recently created checkpoint path in *experiment_path*.

    An entry counts as a checkpoint when its file name contains ``"ckpt"``.
    The newest is chosen by ``os.path.getctime``.

    Raises:
        FileNotFoundError: if the directory holds no checkpoint entry.
    """
    candidates = [
        os.path.join(experiment_path, name)
        for name in os.listdir(experiment_path)
        if "ckpt" in name
    ]
    if not candidates:
        raise FileNotFoundError(
            f"No checkpoint ('ckpt') file found in {experiment_path!r}"
        )
    return max(candidates, key=os.path.getctime)


def _load_detector(experiment_path):
    """Load the newest checkpoint under *experiment_path* with ``model_cls``.

    Applies ``DETECTION_SCORE_THRESH`` to the loaded model and returns a
    ``(checkpoint_file, model)`` tuple.
    """
    checkpoint_file = _latest_checkpoint(experiment_path)
    model = model_cls.load_from_checkpoint(checkpoint_file)
    model.model.roi_heads.score_thresh = DETECTION_SCORE_THRESH
    return checkpoint_file, model


# NOTE(review): Streamlit re-executes this whole script on every interaction,
# so these four models are reloaded each time; consider wrapping the loading
# in st.cache_resource (Streamlit >= 1.18 — confirm installed version).
experiment_path_atoms = "./models/atoms_model/"
checkpoint_file_atoms, model_atom = _load_detector(experiment_path_atoms)

experiment_path_bonds = "./models/bonds_model/"
checkpoint_file_bonds, model_bond = _load_detector(experiment_path_bonds)

experiment_path_stereo = "./models/stereos_model/"
checkpoint_file_stereo, model_stereo = _load_detector(experiment_path_stereo)

experiment_path_charges = "./models/charges_model/"
checkpoint_file_charges, model_charge = _load_detector(experiment_path_charges)
# Single-image dataset rooted at ./uploads/. prepare_data() is only called
# once an image has been uploaded and saved (see the upload handler below).
data_cls = Objects_Smiles
dataset = data_cls(batch_size=1, data_path="./uploads/")

st.title("Atom Level Entity Detector")
image_file = st.file_uploader(
    "Upload a chemical structure candidate image",
    type=['png'],
)
def _render_detections(image, preds):
    """Overlay predicted boxes on *image* and return the rendering as a PIL image.

    *preds* is what ``trainer.predict`` returned for the single-image test
    dataloader. Boxes and labels are read from ``preds[0]['boxes'][0]`` and
    ``preds[0]['preds'][0]``, and each pair is drawn with ``plot_bbox``.
    """
    import io

    # A dedicated figure per render that is always closed afterwards, so boxes
    # drawn here cannot leak into a later render or Streamlit rerun.
    fig = plt.figure()
    try:
        plt.imshow(image, cmap="gray")
        for bbox, label in zip(preds[0]['boxes'][0], preds[0]['preds'][0]):
            plot_bbox(bbox, label)  # draws on the current (i.e. this) figure
        plt.axis('off')
        # Render in memory rather than through a shared file on disk.
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


if image_file is not None:
    image = Image.open(image_file)
    st.image(image, use_column_width=True)
    col1, col2 = st.columns(2)

    # Persist the upload where the dataset (data_path="./uploads/") reads it.
    # NOTE(review): this is one fixed path shared by every session, so
    # concurrent users can overwrite each other's upload.
    os.makedirs("uploads/images", exist_ok=True)
    with open(os.path.join("uploads/images/", "0.png"), "wb") as f:
        f.write(image_file.getbuffer())
    st.success("Saved File")

    dataset.prepare_data()
    trainer = pl.Trainer(logger=False)
    atom_preds = trainer.predict(model_atom, dataset.test_dataloader())
    bond_preds = trainer.predict(model_bond, dataset.test_dataloader())
    # NOTE(review): stereo and charge predictions are computed but not yet
    # displayed anywhere.
    stereo_preds = trainer.predict(model_stereo, dataset.test_dataloader())
    charge_preds = trainer.predict(model_charge, dataset.test_dataloader())

    col1.image(_render_detections(image, atom_preds), use_column_width=True)
    col2.image(_render_detections(image, bond_preds), use_column_width=True)