File size: 2,568 Bytes
7a223e7
6adb0d1
 
 
 
 
 
9acb06c
6adb0d1
d0f68bc
6adb0d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67828bb
648d955
 
 
 
 
 
 
 
 
7a223e7
ebf43f7
 
 
 
 
 
 
 
d0f68bc
7124c40
d0f68bc
67828bb
 
d337d91
03c79df
 
67828bb
 
03c79df
67828bb
685777c
 
d337d91
d0f68bc
03c79df
d0f68bc
 
 
692229b
648d955
 
 
 
 
692229b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import streamlit as st
import os
import torch
#import math
import numpy as np
#import matplotlib.pyplot as plt
#import pathlib
#from AtomLenz import *
#from utils_graph import *
from Object_Smiles import Objects_Smiles 

#from robust_detection import wandb_config
from robust_detection import utils
from robust_detection.models.rcnn import RCNN
from robust_detection.data_utils.rcnn_data_utils import Objects_RCNN, COCO_RCNN

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import LearningRateMonitor
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import DataStructs
from PIL import Image
import matplotlib.pyplot as plt

def plot_bbox(bbox_XYXY, label, colors=None):
    """Draw the outline of one bounding box on the current matplotlib axes.

    Args:
        bbox_XYXY: Four values ``(xmin, ymin, xmax, ymax)`` giving the box
            corners in data coordinates. Each value is converted with
            ``float()``, so 0-d tensors / NumPy scalars are accepted.
        label: Class label of the box. Objects exposing ``.item()`` (tensors,
            NumPy scalars) are unwrapped first, so the legend shows ``"3"``
            instead of ``"tensor(3)"`` and the label can index ``colors``.
        colors: Optional sequence or mapping indexed by the (unwrapped) label
            that gives the line colour. When omitted, integer labels map onto
            matplotlib's default colour cycle (``"C0"`` .. ``"C9"``); any
            other label lets matplotlib pick the colour.
    """
    # float() also works for single-element tensors (including non-CPU ones),
    # which matplotlib cannot always convert on its own.
    xmin, ymin, xmax, ymax = (float(v) for v in bbox_XYXY)
    label_value = label.item() if hasattr(label, "item") else label

    # The original indexed a module-level `colors` that is defined nowhere in
    # this file, so every call raised NameError; fall back to the colour cycle.
    if colors is not None:
        color = colors[label_value]
    elif isinstance(label_value, int):
        color = f"C{label_value % 10}"
    else:
        color = None

    # Closed polyline: (xmin,ymin) -> (xmin,ymax) -> (xmax,ymax) -> (xmax,ymin) -> start.
    plt.plot(
        [xmin, xmin, xmax, xmax, xmin],
        [ymin, ymax, ymax, ymin, ymin],
        color=color,
        label=str(label_value),
    )

def _latest_checkpoint(directory):
    """Return the path of the newest checkpoint file in *directory*.

    An entry counts as a checkpoint when its name contains "ckpt"; among
    those, the one with the greatest ``os.path.getctime`` is returned.

    Raises:
        FileNotFoundError: if *directory* contains no checkpoint. This replaces
            the bare ``IndexError`` from indexing an empty list.
    """
    candidates = [
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if "ckpt" in name
    ]
    if not candidates:
        raise FileNotFoundError(f"No checkpoint file found in {directory!r}")
    return max(candidates, key=os.path.getctime)


# --- Model and dataset setup ------------------------------------------------
# NOTE(review): Streamlit re-executes this script on every interaction, so the
# checkpoint is reloaded each time; consider st.cache_resource if available.
model_cls = RCNN
experiment_path_atoms = "./models/atoms_model/"
checkpoint_file_atoms = _latest_checkpoint(experiment_path_atoms)
model_atom = model_cls.load_from_checkpoint(checkpoint_file_atoms)
# Detection score threshold of the RoI heads (presumably torchvision's
# RoIHeads.score_thresh, which drops lower-scoring boxes; confirm in RCNN).
model_atom.model.roi_heads.score_thresh = 0.65
data_cls = Objects_Smiles
dataset = data_cls(data_path="./uploads/", batch_size=1)

# --- User interface ---------------------------------------------------------
st.title("Atom Level Entity Detector")

image_file = st.file_uploader(
    "Upload a chemical structure candidate image", type=["png"]
)
if image_file is not None:
    col1, col2 = st.columns(2)

    image = Image.open(image_file)
    col1.image(image, use_column_width=True)

    # Store the upload as uploads/images/0.png. `dataset` was built with
    # data_path="./uploads/" and presumably picks the image up from there in
    # prepare_data(); confirm against Objects_Smiles.
    # NOTE(review): the fixed file name is shared by all sessions, so two
    # concurrent uploads can overwrite each other.
    os.makedirs("uploads/images", exist_ok=True)
    with open(os.path.join("uploads/images/", "0.png"), "wb") as f:
        f.write(image_file.getbuffer())
    st.success("Saved File")

    dataset.prepare_data()
    trainer = pl.Trainer(logger=False)
    atom_preds = trainer.predict(model_atom, dataset.test_dataloader())
    st.write(atom_preds)

    # Draw on a dedicated figure. Pyplot's implicit global figure outlives a
    # Streamlit rerun, so boxes from earlier uploads would accumulate. Also,
    # plt.show() returns None, so the former col2.image(plt.show()) never
    # displayed the annotated image.
    fig, ax = plt.subplots()
    ax.imshow(image, cmap="gray")
    for bbox, label in zip(atom_preds[0]["boxes"], atom_preds[0]["pred"]):
        plot_bbox(bbox, label)  # plt.plot draws on the current axes, i.e. `ax`
    col2.pyplot(fig)
    plt.close(fig)  # pyplot keeps figures alive until they are closed