File size: 4,834 Bytes
7a223e7
6adb0d1
 
 
 
 
 
9acb06c
6adb0d1
d0f68bc
6adb0d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67828bb
648d955
 
4ea7e69
648d955
 
 
 
 
 
 
7a223e7
ebf43f7
 
 
 
 
 
 
 
09fa344
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d0f68bc
7124c40
d0f68bc
67828bb
 
d337d91
03c79df
 
c4f48f9
67828bb
03c79df
c4f48f9
 
 
685777c
 
d337d91
d0f68bc
03c79df
d0f68bc
 
 
09fa344
 
 
c4f48f9
648d955
c4521a7
c4f48f9
 
648d955
c4f48f9
ff8787b
 
c4f48f9
09fa344
 
 
 
 
 
 
 
 
 
692229b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import streamlit as st
import os
import torch
#import math
import numpy as np
#import matplotlib.pyplot as plt
#import pathlib
#from AtomLenz import *
#from utils_graph import *
from Object_Smiles import Objects_Smiles 

#from robust_detection import wandb_config
from robust_detection import utils
from robust_detection.models.rcnn import RCNN
from robust_detection.data_utils.rcnn_data_utils import Objects_RCNN, COCO_RCNN

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import LearningRateMonitor
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import DataStructs
from PIL import Image
import matplotlib.pyplot as plt

# Palette used to colour bounding boxes by class index (ten colours, listed twice).
colors = [
    "magenta", "green", "blue", "red", "orange",
    "magenta", "peru", "azure", "slateblue", "plum",
    "magenta", "green", "blue", "red", "orange",
    "magenta", "peru", "azure", "slateblue", "plum",
]


def plot_bbox(bbox_XYXY, label):
    """Draw the outline of one bounding box on the current matplotlib axes.

    Args:
        bbox_XYXY: Four values unpacked as ``(xmin, ymin, xmax, ymax)``.
        label: Integer-like class index. It selects the outline colour from
            ``colors`` and its ``str()`` form is attached as the line label.
    """
    xmin, ymin, xmax, ymax = bbox_XYXY
    # Wrap around the palette so a class index beyond its length picks a
    # colour instead of raising IndexError; indices the palette already
    # covers (including negative ones) map to the same colour as before.
    color = colors[int(label) % len(colors)]
    # Five points: walk the four corners and return to the first to close
    # the rectangle.
    plt.plot(
        [xmin, xmin, xmax, xmax, xmin],
        [ymin, ymax, ymax, ymin, ymin],
        color=color,
        label=str(label))

model_cls = RCNN


def _load_latest_model(experiment_path, score_thresh=0.65):
    """Load the most recently created checkpoint found in a directory.

    Args:
        experiment_path: Directory that is listed for checkpoint files.
        score_thresh: Value assigned to ``model.model.roi_heads.score_thresh``
            (presumably the minimum detection score that is kept; confirm
            against the RCNN implementation).

    Returns:
        A ``(checkpoint_file, model)`` tuple: the path that was loaded and the
        model returned by ``model_cls.load_from_checkpoint``.

    Raises:
        FileNotFoundError: If no entry containing ``"ckpt"`` is found.
    """
    candidates = [
        os.path.join(experiment_path, entry)
        for entry in os.listdir(experiment_path)
    ]
    checkpoints = [path for path in candidates if "ckpt" in path]
    if not checkpoints:
        raise FileNotFoundError(
            f"No checkpoint ('ckpt') file found in {experiment_path!r}"
        )
    # Newest by ctime; on ties max() keeps the first entry in listing order,
    # which is what a stable reverse sort followed by [0] would select.
    checkpoint_file = max(checkpoints, key=os.path.getctime)
    model = model_cls.load_from_checkpoint(checkpoint_file)
    model.model.roi_heads.score_thresh = score_thresh
    return checkpoint_file, model


# One detector per entity type, each restored from the newest checkpoint in
# its own directory.
experiment_path_atoms = "./models/atoms_model/"
checkpoint_file_atoms, model_atom = _load_latest_model(experiment_path_atoms)

experiment_path_bonds = "./models/bonds_model/"
checkpoint_file_bonds, model_bond = _load_latest_model(experiment_path_bonds)

experiment_path_stereo = "./models/stereos_model/"
checkpoint_file_stereo, model_stereo = _load_latest_model(experiment_path_stereo)

experiment_path_charges = "./models/charges_model/"
checkpoint_file_charges, model_charge = _load_latest_model(experiment_path_charges)

data_cls = Objects_Smiles
dataset = data_cls(data_path="./uploads/", batch_size=1)
st.title("Atom Level Entity Detector")


def _show_detections(column, image, predictions, out_path="example_image.png"):
    """Render predicted boxes over the image and display the result in a column.

    Args:
        column: Streamlit container whose ``image`` method shows the overlay.
        image: The uploaded picture, drawn as the plot background.
        predictions: Output of ``trainer.predict``; boxes are read from
            ``predictions[0]['boxes'][0]`` and class indices from
            ``predictions[0]['preds'][0]``.
        out_path: File the rendered overlay is written to before being shown.
    """
    # Use a dedicated figure and always close it. Streamlit re-runs this
    # script inside the same process, so drawing on the implicit pyplot
    # figure would carry boxes from a previous run into the next plot.
    fig = plt.figure()
    try:
        plt.imshow(image, cmap="gray")
        boxes = predictions[0]['boxes'][0]
        labels = predictions[0]['preds'][0]
        for bbox, label in zip(boxes, labels):
            plot_bbox(bbox, label)
        plt.axis('off')
        fig.savefig(out_path, bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)
    column.image(Image.open(out_path), use_column_width=True)


image_file = st.file_uploader("Upload a chemical structure candidate image",type=['png'])
if image_file is not None:
    image = Image.open(image_file)
    st.image(image, use_column_width=True)
    col1, col2 = st.columns(2)

    # Store the upload under the directory the dataset was pointed at
    # ("./uploads/"), then let the dataset prepare itself.
    os.makedirs("uploads/images", exist_ok=True)
    with open(os.path.join("uploads/images/","0.png"),"wb") as f:
        f.write(image_file.getbuffer())
    st.success("Saved File")
    dataset.prepare_data()

    trainer = pl.Trainer(logger=False)
    atom_preds = trainer.predict(model_atom, dataset.test_dataloader())
    bond_preds = trainer.predict(model_bond, dataset.test_dataloader())
    # NOTE(review): stereo and charge predictions are computed but not
    # displayed anywhere in this script; confirm whether they are still needed.
    stereo_preds = trainer.predict(model_stereo, dataset.test_dataloader())
    charge_preds = trainer.predict(model_charge, dataset.test_dataloader())

    # Atom detections in the left column, bond detections in the right one.
    _show_detections(col1, image, atom_preds)
    _show_detections(col2, image, bond_preds)
#x = st.slider('Select a value')
#st.write(x, 'squared is', x * x)