Spaces:
Running
Running
Initial Streamlit Docker app
Browse files- .dockerignore +9 -0
- Dockerfile +46 -0
- README.md +82 -7
- __pycache__/model.cpython-38.pyc +0 -0
- __pycache__/utils.cpython-38.pyc +0 -0
- app.py +129 -0
- best_hybridgnn.pt +3 -0
- model.py +50 -0
- predictions.db +0 -0
- requirements.txt +8 -0
- utils.py +47 -0
.dockerignore
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__
|
2 |
+
*.pyc
|
3 |
+
*.pkl
|
4 |
+
*.sqlite
|
5 |
+
.git
|
6 |
+
*.csv
|
7 |
+
*.db
|
8 |
+
*.log
|
9 |
+
venv/
|
Dockerfile
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Dockerfile for Hugging Face Space: Streamlit + RDKit + PyG
|
2 |
+
|
3 |
+
FROM python:3.10-slim
|
4 |
+
|
5 |
+
# system libraries (needed by RDKit / Pillow)
|
6 |
+
RUN apt-get update && \
|
7 |
+
apt-get install -y --no-install-recommends \
|
8 |
+
build-essential \
|
9 |
+
libxrender1 \
|
10 |
+
libxext6 \
|
11 |
+
libsm6 \
|
12 |
+
libx11-6 \
|
13 |
+
libglib2.0-0 \
|
14 |
+
libfreetype6 \
|
15 |
+
libpng-dev \
|
16 |
+
wget && \
|
17 |
+
rm -rf /var/lib/apt/lists/*
|
18 |
+
|
19 |
+
# Python packages
|
20 |
+
RUN pip install --no-cache-dir --upgrade pip && \
|
21 |
+
pip install --no-cache-dir \
|
22 |
+
streamlit==1.45.0 \
|
23 |
+
rdkit-pypi==2022.9.5 \
|
24 |
+
pandas==2.2.3 \
|
25 |
+
numpy==1.26.4 \
|
26 |
+
torch==2.2.0 \
|
27 |
+
torch-geometric==2.5.2 \
|
28 |
+
ogb==1.3.6 \
|
29 |
+
pillow==10.3.0
|
30 |
+
|
31 |
+
# working directory & app code
|
32 |
+
WORKDIR /app
|
33 |
+
COPY . .
|
34 |
+
|
35 |
+
# Streamlit configuration for Spaces
|
36 |
+
ENV \
|
37 |
+
STREAMLIT_SERVER_HEADLESS=true \
|
38 |
+
STREAMLIT_SERVER_ADDRESS=0.0.0.0 \
|
39 |
+
STREAMLIT_SERVER_PORT=7860 \
|
40 |
+
STREAMLIT_TELEMETRY_DISABLED=true
|
41 |
+
|
42 |
+
EXPOSE 7860
|
43 |
+
|
44 |
+
# launch
|
45 |
+
CMD ["streamlit", "run", "app.py"]
|
46 |
+
|
README.md
CHANGED
@@ -1,10 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
---
|
9 |
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# HOMO–LUMO Gap Predictor
|
2 |
+
|
3 |
+
This web app uses a trained Graph Neural Network (GNN) to predict HOMO–LUMO energy gaps from molecular SMILES strings. Built with [Streamlit](https://streamlit.io), it enables fast single or batch predictions with visualization.
|
4 |
+
|
5 |
+
### Live App
|
6 |
+
|
7 |
+
[Click here to launch the app](https://www.willfillinoncedeployed.com)
|
8 |
+
|
9 |
+
|
10 |
+
---
|
11 |
+
|
12 |
+
## Features
|
13 |
+
|
14 |
+
- Predict HOMO–LUMO gap for one or many molecules
|
15 |
+
- Accepts comma-separated SMILES or CSV uploads
|
16 |
+
- RDKit rendering of molecule structures
|
17 |
+
- Downloadable CSV of predictions
|
18 |
+
- Powered by a trained hybrid GNN model with RDKit descriptors
|
19 |
+
|
20 |
---
|
21 |
+
|
22 |
+
## Usage
|
23 |
+
|
24 |
+
1. **Input Options**:
|
25 |
+
- Type one or more SMILES strings separated by commas
|
26 |
+
- OR upload a `.csv` file with a single column of SMILES
|
27 |
+
|
28 |
+
2. **Example SMILES**: CC(=O)Oc1ccccc1C(=O)O, C1=CC=CC=C1
|
29 |
+
|
30 |
+
3. **CSV Format**:
|
31 |
+
- One column
|
32 |
+
- No header
|
33 |
+
- Each row contains a SMILES string
|
34 |
+
|
35 |
+
4. **Output**:
|
36 |
+
- Predictions displayed in-browser (up to 10 molecules shown)
|
37 |
+
- Full results available for download as CSV
|
38 |
+
|
39 |
+
---
|
40 |
+
|
41 |
+
## Project Structure
|
42 |
+
|
43 |
+
streamlit-app/
|
44 |
+
│
|
45 |
+
├── app.py # Main Streamlit app
|
46 |
+
├── model.py # Hybrid GNN architecture and model loader
|
47 |
+
├── utils.py # RDKit and SMILES processing
|
48 |
+
├── requirements.txt # Python dependencies
|
49 |
+
└── predictions.db # SQLite log of predictions
|
50 |
+
|
51 |
---
|
52 |
|
53 |
+
## Requirements
|
54 |
+
|
55 |
+
To run locally:
|
56 |
+
```
|
57 |
+
pip install -r requirements.txt
|
58 |
+
streamlit run app.py
|
59 |
+
|
60 |
+
```
|
61 |
+
|
62 |
+
|
63 |
+
## Model Info
|
64 |
+
|
65 |
+
The app uses a trained hybrid GNN model combining:
|
66 |
+
|
67 |
+
* AtomEncoder and BondEncoder from OGB
|
68 |
+
* GINEConv layers from PyTorch Geometric
|
69 |
+
* Global mean pooling
|
70 |
+
* RDKit-based physicochemical descriptors
|
71 |
+
|
72 |
+
Trained on the [OGB PCQM4Mv2 dataset](https://ogb.stanford.edu/docs/lsc/pcqm4mv2/), optimized using Optuna
|
73 |
+
|
74 |
+
|
75 |
+
## Author
|
76 |
+
|
77 |
+
Developed by [Matthew Graham](https://github.com/MooseML)
|
78 |
+
For inquiries, collaborations, or ideas, feel free to reach out!
|
79 |
+
|
80 |
+
|
81 |
+
|
82 |
+
|
83 |
+
|
84 |
+
|
85 |
+
|
__pycache__/model.cpython-38.pyc
ADDED
Binary file (2.16 kB). View file
|
|
__pycache__/utils.cpython-38.pyc
ADDED
Binary file (1.62 kB). View file
|
|
app.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
import torch
|
4 |
+
import sqlite3
|
5 |
+
from datetime import datetime
|
6 |
+
from rdkit import Chem
|
7 |
+
from rdkit.Chem import Draw
|
8 |
+
|
9 |
+
from model import load_model
|
10 |
+
from utils import smiles_to_data
|
11 |
+
from torch_geometric.loader import DataLoader
|
12 |
+
|
13 |
+
# Config
|
14 |
+
DEVICE = "cpu"
|
15 |
+
RDKIT_DIM = 6
|
16 |
+
MODEL_PATH = "best_hybridgnn.pt"
|
17 |
+
MAX_DISPLAY = 10
|
18 |
+
|
19 |
+
# Load Model
|
20 |
+
model = load_model(rdkit_dim=RDKIT_DIM, path=MODEL_PATH, device=DEVICE)
|
21 |
+
|
22 |
+
# SQLite Setup
|
23 |
+
@st.cache_resource
|
24 |
+
def init_db():
|
25 |
+
conn = sqlite3.connect("predictions.db", check_same_thread=False)
|
26 |
+
c = conn.cursor()
|
27 |
+
c.execute("""
|
28 |
+
CREATE TABLE IF NOT EXISTS predictions (
|
29 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
30 |
+
smiles TEXT,
|
31 |
+
prediction REAL,
|
32 |
+
timestamp TEXT
|
33 |
+
)
|
34 |
+
""")
|
35 |
+
conn.commit()
|
36 |
+
return conn
|
37 |
+
|
38 |
+
conn = init_db()
|
39 |
+
cursor = conn.cursor()
|
40 |
+
|
41 |
+
# Streamlit UI
|
42 |
+
st.title("HOMO-LUMO Gap Predictor")
|
43 |
+
st.markdown("""
|
44 |
+
This app predicts the HOMO-LUMO energy gap for molecules using a trained Graph Neural Network (GNN).
|
45 |
+
|
46 |
+
**Instructions:**
|
47 |
+
- Enter a **single SMILES** string or **comma-separated list** in the box below.
|
48 |
+
- Or **upload a CSV file** containing a single column of SMILES strings.
|
49 |
+
- **Note**: If you've uploaded a CSV and want to switch to typing SMILES, please click the “X” next to the uploaded file to clear it.
|
50 |
+
- SMILES format should look like: `CC(=O)Oc1ccccc1C(=O)O` (for aspirin).
|
51 |
+
- The app will display predictions and molecule images (up to 10 shown at once).
|
52 |
+
""")
|
53 |
+
|
54 |
+
# Text Input
|
55 |
+
smiles_input = st.text_area("Enter SMILES string(s)", placeholder="C1=CC=CC=C1, CC(=O)Oc1ccccc1C(=O)O")
|
56 |
+
|
57 |
+
# File Upload
|
58 |
+
uploaded_file = st.file_uploader("...or upload a CSV file", type=["csv"])
|
59 |
+
|
60 |
+
smiles_list = []
|
61 |
+
|
62 |
+
if uploaded_file:
|
63 |
+
try:
|
64 |
+
df = pd.read_csv(uploaded_file)
|
65 |
+
if df.shape[1] != 1:
|
66 |
+
st.error("CSV should have only one column with SMILES strings.")
|
67 |
+
else:
|
68 |
+
smiles_list = df.iloc[:, 0].dropna().astype(str).tolist()
|
69 |
+
st.success(f"{len(smiles_list)} SMILES loaded from file.")
|
70 |
+
except Exception as e:
|
71 |
+
st.error(f"Error reading CSV: {e}")
|
72 |
+
|
73 |
+
elif smiles_input:
|
74 |
+
raw_input = smiles_input.strip().replace("\n", ",")
|
75 |
+
smiles_list = [smi.strip() for smi in raw_input.split(",") if smi.strip()]
|
76 |
+
st.success(f"{len(smiles_list)} SMILES parsed from input.")
|
77 |
+
|
78 |
+
# Run Inference
|
79 |
+
if smiles_list:
|
80 |
+
with st.spinner("Processing molecules..."):
|
81 |
+
data_list = smiles_to_data(smiles_list, device=DEVICE)
|
82 |
+
|
83 |
+
# Filter only valid molecules and keep aligned SMILES
|
84 |
+
valid_pairs = [(smi, data) for smi, data in zip(smiles_list, data_list) if data is not None]
|
85 |
+
|
86 |
+
if not valid_pairs:
|
87 |
+
st.warning("No valid molecules found.")
|
88 |
+
else:
|
89 |
+
valid_smiles, valid_data = zip(*valid_pairs)
|
90 |
+
loader = DataLoader(valid_data, batch_size=64)
|
91 |
+
predictions = []
|
92 |
+
|
93 |
+
for batch in loader:
|
94 |
+
batch = batch.to(DEVICE)
|
95 |
+
with torch.no_grad():
|
96 |
+
pred = model(batch).view(-1).cpu().numpy()
|
97 |
+
predictions.extend(pred.tolist())
|
98 |
+
|
99 |
+
# Display Results
|
100 |
+
st.subheader(f"Predictions (showing up to {MAX_DISPLAY} molecules):")
|
101 |
+
|
102 |
+
for i, (smi, pred) in enumerate(zip(valid_smiles, predictions)):
|
103 |
+
if i >= MAX_DISPLAY:
|
104 |
+
st.info(f"...only showing the first {MAX_DISPLAY} molecules")
|
105 |
+
break
|
106 |
+
|
107 |
+
mol = Chem.MolFromSmiles(smi)
|
108 |
+
if mol:
|
109 |
+
st.image(Draw.MolToImage(mol, size=(250, 250)))
|
110 |
+
st.write(f"**SMILES**: `{smi}`")
|
111 |
+
st.write(f"**Predicted HOMO-LUMO Gap**: `{pred:.4f} eV`")
|
112 |
+
|
113 |
+
# Log to SQLite
|
114 |
+
cursor.execute("INSERT INTO predictions (smiles, prediction, timestamp) VALUES (?, ?, ?)",
|
115 |
+
(smi, pred, str(datetime.now())))
|
116 |
+
conn.commit()
|
117 |
+
|
118 |
+
# Download Results
|
119 |
+
result_df = pd.DataFrame({
|
120 |
+
"SMILES": valid_smiles,
|
121 |
+
"Predicted HOMO-LUMO Gap (eV)": [round(p, 4) for p in predictions]
|
122 |
+
})
|
123 |
+
|
124 |
+
st.download_button(
|
125 |
+
label="Download Predictions as CSV",
|
126 |
+
data=result_df.to_csv(index=False).encode('utf-8'),
|
127 |
+
file_name="homolumo_predictions.csv",
|
128 |
+
mime="text/csv"
|
129 |
+
)
|
best_hybridgnn.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e3cd6a7f4297f6451cf159ac6a5745ae0edde7c6c481308ff407e065eec2828c
|
3 |
+
size 5259810
|
model.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.nn import Linear, Dropout, Module, Sequential
|
3 |
+
from torch_geometric.nn import GINEConv, global_mean_pool
|
4 |
+
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
|
5 |
+
|
6 |
+
class HybridGNN(Module):
|
7 |
+
def __init__(self, gnn_dim, rdkit_dim, hidden_dim, dropout_rate=0.2, activation='ReLU'):
|
8 |
+
super().__init__()
|
9 |
+
act_map = {'Swish': torch.nn.SiLU(), 'ReLU': torch.nn.ReLU()}
|
10 |
+
act_fn = act_map[activation]
|
11 |
+
self.gnn_dim = gnn_dim
|
12 |
+
self.rdkit_dim = rdkit_dim
|
13 |
+
|
14 |
+
self.atom_encoder = AtomEncoder(emb_dim=gnn_dim)
|
15 |
+
self.bond_encoder = BondEncoder(emb_dim=gnn_dim)
|
16 |
+
|
17 |
+
self.conv1 = GINEConv(Sequential(Linear(gnn_dim, gnn_dim), act_fn, Linear(gnn_dim, gnn_dim)))
|
18 |
+
self.conv2 = GINEConv(Sequential(Linear(gnn_dim, gnn_dim), act_fn, Linear(gnn_dim, gnn_dim)))
|
19 |
+
self.pool = global_mean_pool
|
20 |
+
|
21 |
+
self.mlp = Sequential(Linear(gnn_dim + rdkit_dim, hidden_dim), act_fn,
|
22 |
+
Dropout(dropout_rate),
|
23 |
+
Linear(hidden_dim, hidden_dim // 2), act_fn,
|
24 |
+
Dropout(dropout_rate),
|
25 |
+
Linear(hidden_dim // 2, 1))
|
26 |
+
|
27 |
+
def forward(self, data):
|
28 |
+
x = self.atom_encoder(data.x)
|
29 |
+
edge_attr = self.bond_encoder(data.edge_attr)
|
30 |
+
|
31 |
+
x = self.conv1(x, data.edge_index, edge_attr)
|
32 |
+
x = self.conv2(x, data.edge_index, edge_attr)
|
33 |
+
x = self.pool(x, data.batch)
|
34 |
+
|
35 |
+
rdkit_feats = getattr(data, 'rdkit_feats', None)
|
36 |
+
if rdkit_feats is not None:
|
37 |
+
if x.shape[0] != rdkit_feats.shape[0]:
|
38 |
+
raise ValueError(f"Shape mismatch: GNN output ({x.shape}) vs rdkit_feats ({rdkit_feats.shape})")
|
39 |
+
x = torch.cat([x, rdkit_feats], dim=1)
|
40 |
+
else:
|
41 |
+
raise ValueError("RDKit features not found in the data object.")
|
42 |
+
|
43 |
+
return self.mlp(x)
|
44 |
+
|
45 |
+
def load_model(rdkit_dim: int, path: str = "best_hybridgnn.pt", device: str = "cpu"):
|
46 |
+
model = HybridGNN(gnn_dim=512, rdkit_dim=rdkit_dim, hidden_dim=256, dropout_rate=0.29, activation='Swish')
|
47 |
+
model.load_state_dict(torch.load(path, map_location=device))
|
48 |
+
model.to(device)
|
49 |
+
model.eval()
|
50 |
+
return model
|
predictions.db
ADDED
Binary file (24.6 kB). View file
|
|
requirements.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit==1.45.0
|
2 |
+
rdkit-pypi==2022.9.5
|
3 |
+
pandas==2.2.3
|
4 |
+
numpy==1.26.4
|
5 |
+
torch==2.2.0
|
6 |
+
torch-geometric==2.5.2
|
7 |
+
ogb==1.3.6
|
8 |
+
pillow==10.3.0
|
utils.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from rdkit import Chem
|
4 |
+
from rdkit.Chem import Descriptors
|
5 |
+
from torch_geometric.data import Data
|
6 |
+
from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims
|
7 |
+
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
|
8 |
+
from ogb.lsc import PCQM4Mv2Evaluator
|
9 |
+
from ogb.utils import smiles2graph
|
10 |
+
from torch_geometric.loader import DataLoader
|
11 |
+
|
12 |
+
def compute_rdkit_features(smiles):
|
13 |
+
mol = Chem.MolFromSmiles(smiles)
|
14 |
+
if mol is None:
|
15 |
+
raise ValueError("Invalid SMILES")
|
16 |
+
return [
|
17 |
+
Descriptors.MolWt(mol),
|
18 |
+
Descriptors.NumRotatableBonds(mol),
|
19 |
+
Descriptors.TPSA(mol),
|
20 |
+
Descriptors.NumHAcceptors(mol),
|
21 |
+
Descriptors.NumHDonors(mol),
|
22 |
+
Descriptors.RingCount(mol)
|
23 |
+
]
|
24 |
+
|
25 |
+
def smiles_to_data(smiles_list, device="cpu"):
|
26 |
+
graph_list = []
|
27 |
+
rdkit_list = []
|
28 |
+
|
29 |
+
for smi in smiles_list:
|
30 |
+
try:
|
31 |
+
graph = smiles2graph(smi)
|
32 |
+
rdkit_feats = compute_rdkit_features(smi)
|
33 |
+
|
34 |
+
data = Data(
|
35 |
+
x=torch.tensor(graph['node_feat'], dtype=torch.long),
|
36 |
+
edge_index=torch.tensor(graph['edge_index'], dtype=torch.long),
|
37 |
+
edge_attr=torch.tensor(graph['edge_feat'], dtype=torch.long),
|
38 |
+
rdkit_feats=torch.tensor(rdkit_feats, dtype=torch.float32).unsqueeze(0),
|
39 |
+
num_nodes=graph['num_nodes']
|
40 |
+
)
|
41 |
+
graph_list.append(data)
|
42 |
+
except Exception as e:
|
43 |
+
print(f"Error with SMILES '{smi}': {e}")
|
44 |
+
continue
|
45 |
+
|
46 |
+
return graph_list
|
47 |
+
|