File size: 2,571 Bytes
7a223e7
6adb0d1
 
 
 
 
 
9acb06c
6adb0d1
d0f68bc
6adb0d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67828bb
648d955
 
 
 
 
 
 
 
 
7a223e7
ebf43f7
 
 
 
 
 
 
 
d0f68bc
7124c40
d0f68bc
67828bb
 
d337d91
03c79df
 
67828bb
 
03c79df
67828bb
685777c
 
d337d91
d0f68bc
03c79df
d0f68bc
 
 
ffed2e4
648d955
5195bea
648d955
 
 
692229b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import streamlit as st
import os
import torch
#import math
import numpy as np
#import matplotlib.pyplot as plt
#import pathlib
#from AtomLenz import *
#from utils_graph import *
from Object_Smiles import Objects_Smiles 

#from robust_detection import wandb_config
from robust_detection import utils
from robust_detection.models.rcnn import RCNN
from robust_detection.data_utils.rcnn_data_utils import Objects_RCNN, COCO_RCNN

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import LearningRateMonitor
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import DataStructs
from PIL import Image
import matplotlib.pyplot as plt

def plot_bbox(bbox_XYXY, label, colors=None):
    """Draw one axis-aligned bounding box on the current matplotlib axes.

    Args:
        bbox_XYXY: Four numbers ``(xmin, ymin, xmax, ymax)``; elements may be
            Python numbers, NumPy scalars or 0-d tensors.
        label: Class label of the box. It is used for the colour lookup and
            as the legend text.
        colors: Optional mapping/sequence indexed by ``label`` giving a
            matplotlib colour. When omitted, matplotlib's default colour
            cycle is used, indexed by ``int(label)`` modulo its length.
            (The original code referenced an undefined global ``colors``.)
    """
    xmin, ymin, xmax, ymax = (float(v) for v in bbox_XYXY)
    # Tensor / NumPy scalars -> plain Python scalar, so they can index a
    # list and give a readable legend label ("3" instead of "tensor(3)").
    if hasattr(label, "item"):
        label = label.item()
    if colors is None:
        palette = plt.rcParams["axes.prop_cycle"].by_key()["color"]
        color = palette[int(label) % len(palette)]
    else:
        color = colors[label]
    # Closed polyline: (xmin,ymin)->(xmin,ymax)->(xmax,ymax)->(xmax,ymin)->start.
    plt.plot(
        [xmin, xmin, xmax, xmax, xmin],
        [ymin, ymax, ymax, ymin, ymin],
        color=color,
        label=str(label))

model_cls = RCNN
experiment_path_atoms = "./models/atoms_model/"
dir_list = os.listdir(experiment_path_atoms)
dir_list = [os.path.join(experiment_path_atoms, f) for f in dir_list]
# Newest first. NOTE: on Unix os.path.getctime is the inode *change* time,
# not the creation time, so "newest" means most recently modified metadata.
dir_list.sort(key=os.path.getctime, reverse=True)
# Match "ckpt" against the file name only, not the directory part of the path.
checkpoints_atoms = [f for f in dir_list if "ckpt" in os.path.basename(f)]
if not checkpoints_atoms:
    raise FileNotFoundError(
        f"No checkpoint file (name containing 'ckpt') found in "
        f"{experiment_path_atoms!r}")
checkpoint_file_atoms = checkpoints_atoms[0]
model_atom = model_cls.load_from_checkpoint(checkpoint_file_atoms)
# Presumably the RoI-heads detection score threshold (detections scoring below
# 0.65 are dropped) — TODO confirm against the RCNN wrapper.
model_atom.model.roi_heads.score_thresh = 0.65
data_cls = Objects_Smiles
# The dataset is only prepared once an image has been uploaded (see below).
dataset = data_cls(data_path="./uploads/", batch_size=1)
st.title("Atom Level Entity Detector")

image_file = st.file_uploader("Upload a chemical structure candidate image", type=['png'])
if image_file is not None:
    col1, col2 = st.columns(2)

    image = Image.open(image_file)
    col1.image(image, use_column_width=True)

    # Write the upload where the dataset (data_path="./uploads/") presumably
    # looks for images — TODO confirm Objects_Smiles reads uploads/images/.
    # Any previous upload is overwritten.
    os.makedirs("uploads/images", exist_ok=True)
    with open(os.path.join("uploads/images/", "0.png"), "wb") as f:
        f.write(image_file.getbuffer())
    st.success("Saved File")

    # Prepare the data only now that the image file exists on disk.
    dataset.prepare_data()
    trainer = pl.Trainer(logger=False)
    atom_preds = trainer.predict(model_atom, dataset.test_dataloader())

    # Use a dedicated figure: Streamlit reruns this script in the same
    # process, so the implicit global pyplot figure would keep boxes drawn
    # for earlier uploads.
    fig, ax = plt.subplots()
    ax.imshow(image, cmap="gray")
    # plot_bbox draws with plt.plot, i.e. onto the current axes (ax).
    for bbox, label in zip(atom_preds[0]['boxes'], atom_preds[0]['preds']):
        plot_bbox(bbox, label)
    # plt.show() returns None and cannot be passed to st.image; render the
    # figure object itself, then free it.
    col2.pyplot(fig)
    plt.close(fig)