Spaces:
Sleeping
Sleeping
File size: 5,174 Bytes
a1efe76 b12a328 a1efe76 0ef141b e3eae4d a1efe76 e3eae4d d072ab4 5331c58 a1efe76 d072ab4 0ef141b a1efe76 0ef141b a1efe76 e3eae4d d072ab4 a1efe76 3c3852b e3eae4d a1efe76 d072ab4 a1efe76 d072ab4 e3eae4d b12a328 5331c58 b12a328 5331c58 b12a328 d072ab4 a1efe76 d072ab4 e3eae4d d072ab4 d96a023 a1efe76 d072ab4 a1efe76 d072ab4 a1efe76 d072ab4 10c6717 d072ab4 a1efe76 d072ab4 10c6717 a1efe76 10c6717 a1efe76 d072ab4 a1efe76 d072ab4 a1efe76 d072ab4 a1efe76 d072ab4 a1efe76 5331c58 a1efe76 d072ab4 a1efe76 d072ab4 a1efe76 d072ab4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import os, pathlib, sqlite3, sys, tempfile
from datetime import datetime
from io import StringIO
import pandas as pd
import streamlit as st
import torch
from rdkit import Chem
from rdkit.Chem import Draw
from torch_geometric.loader import DataLoader
from model import load_model
from utils import smiles_to_data
# configuration
DEVICE, RDKIT_DIM, MODEL_PATH, MAX_DISPLAY = "cpu", 6, "best_hybridgnn.pt", 20
# heavy imports already done above; now Streamlit starts
@st.cache_resource
def get_model():
return load_model(rdkit_dim=RDKIT_DIM, path=MODEL_PATH, device=DEVICE)
model = get_model()
# SQLite (cached) — DB stored in /data or /tmp
DB_DIR = pathlib.Path(os.getenv("DB_DIR", "/tmp"))
DB_DIR.mkdir(parents=True, exist_ok=True)
@st.cache_resource
def init_db():
conn = sqlite3.connect(DB_DIR / "predictions.db", check_same_thread=False)
conn.execute(
"""CREATE TABLE IF NOT EXISTS predictions(
id INTEGER PRIMARY KEY AUTOINCREMENT,
smiles TEXT, prediction REAL, timestamp TEXT)"""
)
conn.commit()
return conn
conn = init_db()
cursor = conn.cursor()
# compact info panel
with st.sidebar.expander("Info & Env", expanded=False):
st.write(f"Python {sys.version.split()[0]}")
st.write(f"Temp dir: `{tempfile.gettempdir()}` "
f"({'writable' if os.access(tempfile.gettempdir(), os.W_OK) else 'read-only'})")
if "csv_bytes" in st.session_state:
st.write(f"Last upload: **{len(st.session_state['csv_bytes'])/1024:.1f} KB**")
# header and instructions (unchanged)
st.title("HOMO-LUMO Gap Predictor")
st.markdown("""
This app predicts the HOMO-LUMO energy gap for molecules using a trained Graph Neural Network (GNN).
**Instructions:**
- Enter a **single SMILES** string or **comma/newline separated list** in the box below.
- Or **upload a CSV file** containing a single column of SMILES strings.
- **Note**: If you've uploaded a CSV and want to switch to typing SMILES, please click the "X" next to the uploaded file to clear it.
- SMILES format should look like: `O=C(C)Oc1ccccc1C(=O)O` (for aspirin).
- The app will display predictions and molecule images (up to 20 shown at once).
""")
# uploader (outside the form)
csv_file = st.file_uploader("CSV with SMILES", type=["csv"])
if csv_file is not None:
st.session_state["csv_bytes"] = csv_file.getvalue() # cache raw bytes
# textarea and button
smiles_list = []
with st.form("main_form"):
smiles_text = st.text_area(
"…or paste SMILES (comma/newline separated)",
placeholder="CC(=O)Oc1ccccc1C(=O)O",
height=120,
)
run = st.form_submit_button("Run Prediction")
# decide which input to use
if run:
if smiles_text.strip(): # user typed → override CSV
smiles_list = [
s.strip() for s in smiles_text.replace("\n", ",").split(",") if s.strip()
]
st.session_state.pop("csv_bytes", None) # forget previous upload
st.success(f"{len(smiles_list)} SMILES parsed from textbox")
elif "csv_bytes" in st.session_state: # CSV path
try:
df = pd.read_csv(
StringIO(st.session_state["csv_bytes"].decode("utf-8")),
comment="#",
)
col = df.columns[0] if df.shape[1] == 1 else next(
(c for c in df.columns if c.lower() == "smiles"), None
)
if col is None:
st.error("CSV needs one column or a 'SMILES' column.")
else:
smiles_list = df[col].dropna().astype(str).tolist()
st.success(f"{len(smiles_list)} SMILES loaded from CSV")
except Exception as e:
st.error(f"CSV error: {e}")
else:
st.warning("No input provided.")
# inference & display
if smiles_list:
data_list = smiles_to_data(smiles_list, device=DEVICE)
valid = [(s, d) for s, d in zip(smiles_list, data_list) if d is not None]
if not valid:
st.warning("No valid molecules.")
else:
vsmi, vdata = zip(*valid)
preds = []
for batch in DataLoader(vdata, batch_size=64):
with torch.no_grad():
preds.extend(model(batch.to(DEVICE)).view(-1).cpu().numpy().tolist())
st.subheader(f"Results (first {MAX_DISPLAY})")
for i, (smi, pred) in enumerate(zip(vsmi, preds)):
if i >= MAX_DISPLAY:
st.info("...Only Displaying 20 Compounds")
break
mol = Chem.MolFromSmiles(smi)
if mol:
st.image(Draw.MolToImage(mol, size=(250, 250)))
st.write(f"`{smi}` → **{pred:.4f} eV**")
cursor.execute(
"INSERT INTO predictions(smiles, prediction, timestamp) VALUES (?,?,?)",
(smi, float(pred), datetime.utcnow().isoformat()),
)
conn.commit()
st.download_button(
"Download CSV",
pd.DataFrame({"SMILES": vsmi, "Gap (eV)": [round(p, 4) for p in preds]})
.to_csv(index=False).encode(),
"homolumo_predictions.csv",
"text/csv",
)
|