File size: 5,174 Bytes
a1efe76
 
 
 
b12a328
a1efe76
0ef141b
e3eae4d
 
a1efe76
 
e3eae4d
 
 
d072ab4
5331c58
a1efe76
d072ab4
0ef141b
a1efe76
 
0ef141b
a1efe76
e3eae4d
d072ab4
a1efe76
 
3c3852b
e3eae4d
 
a1efe76
 
 
 
 
 
 
 
 
 
 
 
d072ab4
a1efe76
 
 
 
 
 
 
d072ab4
e3eae4d
b12a328
 
 
 
5331c58
b12a328
 
5331c58
 
b12a328
 
d072ab4
a1efe76
 
d072ab4
e3eae4d
d072ab4
d96a023
a1efe76
d072ab4
 
 
 
 
a1efe76
 
d072ab4
a1efe76
d072ab4
 
 
 
 
 
 
 
10c6717
d072ab4
 
 
 
 
 
 
a1efe76
d072ab4
10c6717
a1efe76
 
10c6717
a1efe76
 
 
d072ab4
a1efe76
d072ab4
a1efe76
 
 
 
 
d072ab4
a1efe76
 
 
 
 
d072ab4
a1efe76
 
 
 
5331c58
a1efe76
 
 
 
d072ab4
a1efe76
 
 
d072ab4
a1efe76
 
 
d072ab4
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import os, pathlib, sqlite3, sys, tempfile
from datetime import datetime, timezone
from io import StringIO

import pandas as pd
import streamlit as st
import torch
from rdkit import Chem
from rdkit.Chem import Draw
from torch_geometric.loader import DataLoader

from model import load_model
from utils import smiles_to_data

#  configuration 
DEVICE = "cpu"                     # device passed to load_model / smiles_to_data / batch.to
RDKIT_DIM = 6                      # rdkit_dim argument for load_model (presumably descriptor width — confirm in model.py)
MODEL_PATH = "best_hybridgnn.pt"   # checkpoint path handed to load_model
MAX_DISPLAY = 20                   # maximum number of results rendered in the UI

#  heavy imports already done above; now Streamlit starts 
@st.cache_resource
def get_model():
    """Load the trained GNN via ``load_model`` and cache it with ``st.cache_resource``.

    Caching means Streamlit reruns (triggered by widget interaction) reuse the
    already-loaded object instead of reading the checkpoint again.
    """
    load_kwargs = {"rdkit_dim": RDKIT_DIM, "path": MODEL_PATH, "device": DEVICE}
    return load_model(**load_kwargs)


model = get_model()

# SQLite database directory: taken from the DB_DIR environment variable,
# falling back to /tmp; created up front so init_db can open a file in it.
DB_DIR = pathlib.Path(os.environ.get("DB_DIR", "/tmp"))
DB_DIR.mkdir(exist_ok=True, parents=True)

@st.cache_resource
def init_db():
    """Open ``DB_DIR/predictions.db`` and ensure the ``predictions`` table exists.

    The connection is cached by ``st.cache_resource`` and created with
    ``check_same_thread=False`` so the single cached connection can be used
    from whichever thread Streamlit executes the script on.

    Returns:
        sqlite3.Connection: the shared connection.
    """
    db_path = DB_DIR / "predictions.db"
    connection = sqlite3.connect(db_path, check_same_thread=False)
    schema = """CREATE TABLE IF NOT EXISTS predictions(
               id INTEGER PRIMARY KEY AUTOINCREMENT,
               smiles TEXT, prediction REAL, timestamp TEXT)"""
    with connection:  # commits on successful exit, same as an explicit commit()
        connection.execute(schema)
    return connection


conn = init_db()
cursor = conn.cursor()

#  compact info panel 
# Probe the temp directory once; both values are only used for display.
_tmp_dir = tempfile.gettempdir()
_tmp_state = "writable" if os.access(_tmp_dir, os.W_OK) else "read-only"

with st.sidebar.expander("Info & Env", expanded=False):
    st.write(f"Python {sys.version.split()[0]}")
    st.write(f"Temp dir: `{_tmp_dir}` ({_tmp_state})")
    if "csv_bytes" in st.session_state:
        st.write(f"Last upload: **{len(st.session_state['csv_bytes'])/1024:.1f} KB**")

#  header and instructions 
st.title("HOMO-LUMO Gap Predictor")
# f-string so the advertised display cap always matches MAX_DISPLAY rather
# than a hand-copied literal that can drift when the constant changes.
st.markdown(f"""
This app predicts the HOMO-LUMO energy gap for molecules using a trained Graph Neural Network (GNN).

**Instructions:**
- Enter a **single SMILES** string or **comma/newline separated list** in the box below.
- Or **upload a CSV file** containing a single column of SMILES strings.
- **Note**: If you've uploaded a CSV and want to switch to typing SMILES, please click the "X" next to the uploaded file to clear it.
- SMILES format should look like: `O=C(C)Oc1ccccc1C(=O)O` (for aspirin).
- The app will display predictions and molecule images (up to {MAX_DISPLAY} shown at once).
""")

#  uploader (outside the form) 
csv_file = st.file_uploader("CSV with SMILES", type=["csv"])
if csv_file is not None:
    # Cache the raw bytes in session state; the input-selection step reads them.
    st.session_state["csv_bytes"] = csv_file.getvalue()
else:
    # The uploader was cleared (the "X" button): drop the cached bytes too,
    # otherwise "Run Prediction" with an empty textbox would keep using the
    # file the user just removed.
    st.session_state.pop("csv_bytes", None)

#  textarea and button 
# Populated by the input-selection step below; remains empty when nothing
# was submitted, which skips inference entirely.
smiles_list = []

with st.form("main_form"):
    smiles_text = st.text_area(
        "…or paste SMILES (comma/newline separated)",
        height=120,
        placeholder="CC(=O)Oc1ccccc1C(=O)O",
    )
    run = st.form_submit_button("Run Prediction")

#  decide which input to use 
def _parse_smiles_csv(raw: bytes) -> "list[str] | None":
    """Extract SMILES strings from uploaded CSV bytes.

    Accepts a single-column file (with or without a header row) or a
    multi-column file containing a column named ``SMILES`` (case-insensitive).
    Whole lines starting with ``#`` are ignored as comments.

    Args:
        raw: The uploaded file's bytes, decoded as UTF-8 (a leading BOM is ignored).

    Returns:
        Non-empty, whitespace-stripped SMILES strings in file order, or
        ``None`` when a multi-column file has no ``SMILES`` column.

    Raises:
        UnicodeDecodeError: if ``raw`` is not valid UTF-8.
        pandas.errors.EmptyDataError / ParserError: on empty or malformed CSV.
    """
    # Drop comment lines here instead of read_csv(comment="#"): pandas cuts a
    # line at the first '#' anywhere, which would truncate SMILES triple bonds
    # ("C#N" -> "C"). A valid SMILES string never starts with '#'.
    text = "\n".join(
        line
        for line in raw.decode("utf-8-sig").splitlines()
        if not line.lstrip().startswith("#")
    )
    df = pd.read_csv(StringIO(text), dtype=str)  # keep cells as raw text

    if df.shape[1] == 1:
        col = df.columns[0]
        header = str(col).strip()
        values = df[col].dropna().str.strip().tolist()
        # read_csv treats the first row as a header, so a headerless file would
        # lose its first molecule; restore it when it parses as a SMILES.
        if (
            header
            and header.lower() != "smiles"
            and Chem.MolFromSmiles(header) is not None
        ):
            values.insert(0, header)
    else:
        col = next(
            (c for c in df.columns if str(c).strip().lower() == "smiles"), None
        )
        if col is None:
            return None
        values = df[col].dropna().str.strip().tolist()

    return [v for v in values if v]


if run:
    if smiles_text.strip():                      # user typed → override CSV
        smiles_list = [
            s.strip() for s in smiles_text.replace("\n", ",").split(",") if s.strip()
        ]
        st.session_state.pop("csv_bytes", None)  # forget previous upload
        st.success(f"{len(smiles_list)} SMILES parsed from textbox")

    elif "csv_bytes" in st.session_state:        # CSV path
        try:
            parsed = _parse_smiles_csv(st.session_state["csv_bytes"])
        except Exception as e:  # UI boundary: report any decode/parse failure
            st.error(f"CSV error: {e}")
        else:
            if parsed is None:
                st.error("CSV needs one column or a 'SMILES' column.")
            else:
                smiles_list = parsed
                st.success(f"{len(smiles_list)} SMILES loaded from CSV")

    else:
        st.warning("No input provided.")

#  inference & display 
if smiles_list:
    # smiles_to_data yields None for inputs it cannot featurize; keep the rest.
    data_list = smiles_to_data(smiles_list, device=DEVICE)
    valid = [(s, d) for s, d in zip(smiles_list, data_list) if d is not None]

    if not valid:
        st.warning("No valid molecules.")
    else:
        n_skipped = len(smiles_list) - len(valid)
        if n_skipped:
            st.warning(f"{n_skipped} invalid SMILES skipped.")

        vsmi, vdata = zip(*valid)
        preds = []
        with torch.no_grad():  # inference only; no autograd graph needed
            for batch in DataLoader(vdata, batch_size=64):
                preds.extend(model(batch.to(DEVICE)).view(-1).cpu().numpy().tolist())

        st.subheader(f"Results (first {MAX_DISPLAY})")
        for smi, pred in zip(vsmi[:MAX_DISPLAY], preds):
            mol = Chem.MolFromSmiles(smi)
            if mol:
                st.image(Draw.MolToImage(mol, size=(250, 250)))
            st.write(f"`{smi}` → **{pred:.4f} eV**")
        if len(vsmi) > MAX_DISPLAY:
            st.info(f"...Only Displaying {MAX_DISPLAY} Compounds")

        # Persist every prediction, not only the rows shown above. One naive-UTC
        # timestamp per run (same ISO format as the former utcnow() values)
        # without the deprecated datetime.utcnow().
        run_ts = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
        cursor.executemany(
            "INSERT INTO predictions(smiles, prediction, timestamp) VALUES (?,?,?)",
            [(smi, float(pred), run_ts) for smi, pred in zip(vsmi, preds)],
        )
        conn.commit()

        st.download_button(
            "Download CSV",
            pd.DataFrame({"SMILES": vsmi, "Gap (eV)": [round(p, 4) for p in preds]})
              .to_csv(index=False).encode(),
            "homolumo_predictions.csv",
            "text/csv",
        )