File size: 2,571 Bytes
7a223e7 6adb0d1 9acb06c 6adb0d1 d0f68bc 6adb0d1 67828bb 648d955 7a223e7 ebf43f7 d0f68bc 7124c40 d0f68bc 67828bb d337d91 03c79df 67828bb 03c79df 67828bb 685777c d337d91 d0f68bc 03c79df d0f68bc ffed2e4 648d955 5195bea 648d955 692229b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import streamlit as st
import os
import torch
#import math
import numpy as np
#import matplotlib.pyplot as plt
#import pathlib
#from AtomLenz import *
#from utils_graph import *
from Object_Smiles import Objects_Smiles
#from robust_detection import wandb_config
from robust_detection import utils
from robust_detection.models.rcnn import RCNN
from robust_detection.data_utils.rcnn_data_utils import Objects_RCNN, COCO_RCNN
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import LearningRateMonitor
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import DataStructs
from PIL import Image
import matplotlib.pyplot as plt
def plot_bbox(bbox_XYXY, label, colors=None):
    """Draw one bounding box as a closed outline on the current pyplot axes.

    Args:
        bbox_XYXY: the box corners as ``(xmin, ymin, xmax, ymax)``. Each value
            may be a plain number or a 0-d tensor; it is converted with ``float``.
        label: class id of the box, assumed to be an int or a 0-d integer
            tensor (it is converted with ``int``). It selects the outline
            colour and becomes the legend label.
        colors: optional mapping or sequence from class id to a matplotlib
            colour. When omitted, matplotlib's default colour cycle is used,
            indexed by ``label`` modulo the cycle length.
    """
    xmin, ymin, xmax, ymax = (float(v) for v in bbox_XYXY)
    # Tensor labels do not hash/compare like ints as dict keys and would print
    # as "tensor(N)" in the legend, so normalise to a plain int first.
    label = int(label)
    if colors is None:
        # No palette is defined in this module; fall back to matplotlib's
        # default property cycle so every class still gets a distinct colour.
        palette = plt.rcParams["axes.prop_cycle"].by_key()["color"]
        color = palette[label % len(palette)]
    else:
        color = colors[label]
    plt.plot(
        [xmin, xmin, xmax, xmax, xmin],
        [ymin, ymax, ymax, ymin, ymin],
        color=color,
        label=str(label),
    )
def _latest_checkpoint(directory):
    """Return the path of the most recently changed checkpoint file in *directory*.

    A regular file counts as a checkpoint when "ckpt" appears in its file name
    (e.g. the ``*.ckpt`` files PyTorch Lightning writes). Candidates are ranked
    by ``os.path.getctime``; on ties the first one in ``os.listdir`` order wins.

    Raises:
        FileNotFoundError: if *directory* holds no matching checkpoint file.
    """
    candidates = [
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if "ckpt" in name and os.path.isfile(os.path.join(directory, name))
    ]
    if not candidates:
        raise FileNotFoundError(f"No checkpoint file found in {directory!r}")
    return max(candidates, key=os.path.getctime)


model_cls = RCNN
experiment_path_atoms = "./models/atoms_model/"
checkpoint_file_atoms = _latest_checkpoint(experiment_path_atoms)
# NOTE(review): Streamlit re-executes this whole script on every interaction,
# so the checkpoint is reloaded each time; consider a Streamlit resource cache
# (e.g. st.cache_resource, if the installed version provides it).
model_atom = model_cls.load_from_checkpoint(checkpoint_file_atoms)
# Detections scoring below 0.65 are discarded by the detector's ROI heads
# (presumably torchvision's RoIHeads.score_thresh — confirm against RCNN).
model_atom.model.roi_heads.score_thresh = 0.65
# Dataset wrapper rooted at ./uploads/, yielding one item per batch.
# prepare_data() is deliberately deferred until an uploaded image has been
# written to disk by the UI code below.
data_cls = Objects_Smiles
dataset = data_cls(batch_size=1, data_path="./uploads/")
st.title("Atom Level Entity Detector")
image_file = st.file_uploader("Upload a chemical structure candidate image", type=['png'])
if image_file is not None:
    col1, col2 = st.columns(2)
    image = Image.open(image_file)
    col1.image(image, use_column_width=True)

    # The dataset reads from ./uploads/; every upload is written to the same
    # path (images/0.png), replacing the previous one.
    os.makedirs("uploads/images", exist_ok=True)
    with open(os.path.join("uploads/images/", "0.png"), "wb") as f:
        f.write(image_file.getbuffer())
    st.success("Saved File")

    dataset.prepare_data()
    trainer = pl.Trainer(logger=False)
    atom_preds = trainer.predict(model_atom, dataset.test_dataloader())

    if not atom_preds:
        st.warning("No predictions were produced for the uploaded image.")
    else:
        # Draw on a fresh figure: pyplot's global current figure outlives a
        # single script run, so reusing it would stack boxes from earlier
        # uploads onto this one.
        fig = plt.figure()
        plt.imshow(image, cmap="gray")
        for bbox, label in zip(atom_preds[0]['boxes'], atom_preds[0]['preds']):
            plot_bbox(bbox, label)
        # plt.show() returns None, so it cannot be handed to st.image;
        # render the Figure object directly and release it afterwards.
        col2.pyplot(fig)
        plt.close(fig)
|