MooseML commited on
Commit
e3eae4d
·
1 Parent(s): 6568ddd

Initial Streamlit Docker app

Browse files
.dockerignore ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__
2
+ *.pyc
3
+ *.pkl
4
+ *.sqlite
5
+ .git
6
+ *.csv
7
+ *.db
8
+ *.log
9
+ venv/
Dockerfile ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Dockerfile for Hugging Face Space: Streamlit + RDKit + PyG
2
+
3
+ FROM python:3.10-slim
4
+
5
+ # system libraries (needed by RDKit / Pillow)
6
+ RUN apt-get update && \
7
+ apt-get install -y --no-install-recommends \
8
+ build-essential \
9
+ libxrender1 \
10
+ libxext6 \
11
+ libsm6 \
12
+ libx11-6 \
13
+ libglib2.0-0 \
14
+ libfreetype6 \
15
+ libpng-dev \
16
+ wget && \
17
+ rm -rf /var/lib/apt/lists/*
18
+
19
+ # Python packages
20
+ RUN pip install --no-cache-dir --upgrade pip && \
21
+ pip install --no-cache-dir \
22
+ streamlit==1.45.0 \
23
+ rdkit-pypi==2022.9.5 \
24
+ pandas==2.2.3 \
25
+ numpy==1.26.4 \
26
+ torch==2.2.0 \
27
+ torch-geometric==2.5.2 \
28
+ ogb==1.3.6 \
29
+ pillow==10.3.0
30
+
31
+ # working directory & app code
32
+ WORKDIR /app
33
+ COPY . .
34
+
35
+ # Streamlit configuration for Spaces
36
+ ENV \
37
+ STREAMLIT_SERVER_HEADLESS=true \
38
+ STREAMLIT_SERVER_ADDRESS=0.0.0.0 \
39
+ STREAMLIT_SERVER_PORT=7860 \
40
+ STREAMLIT_TELEMETRY_DISABLED=true
41
+
42
+ EXPOSE 7860
43
+
44
+ # launch
45
+ CMD ["streamlit", "run", "app.py"]
46
+
README.md CHANGED
@@ -1,10 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
- title: Homo Lumo Gap Predictor
3
- emoji: 🐨
4
- colorFrom: green
5
- colorTo: gray
6
- sdk: docker
7
- pinned: false
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # HOMO–LUMO Gap Predictor
2
+
3
+ This web app uses a trained Graph Neural Network (GNN) to predict HOMO–LUMO energy gaps from molecular SMILES strings. Built with [Streamlit](https://streamlit.io), it enables fast single or batch predictions with visualization.
4
+
5
+ ### Live App
6
+
7
+ [Click here to launch the app](https://www.willfillinoncedeployed.com)
8
+
9
+
10
+ ---
11
+
12
+ ## Features
13
+
14
+ - Predict HOMO–LUMO gap for one or many molecules
15
+ - Accepts comma-separated SMILES or CSV uploads
16
+ - RDKit rendering of molecule structures
17
+ - Downloadable CSV of predictions
18
+ - Powered by a trained hybrid GNN model with RDKit descriptors
19
+
20
  ---
21
+
22
+ ## Usage
23
+
24
+ 1. **Input Options**:
25
+ - Type one or more SMILES strings separated by commas
26
+ - OR upload a `.csv` file with a single column of SMILES
27
+
28
+ 2. **Example SMILES**: CC(=O)Oc1ccccc1C(=O)O, C1=CC=CC=C1
29
+
30
+ 3. **CSV Format**:
31
+ - One column
32
+ - No header
33
+ - Each row contains a SMILES string
34
+
35
+ 4. **Output**:
36
+ - Predictions displayed in-browser (up to 10 molecules shown)
37
+ - Full results available for download as CSV
38
+
39
+ ---
40
+
41
+ ## Project Structure
42
+
43
+ streamlit-app/
44
+
45
+ ├── app.py # Main Streamlit app
46
+ ├── model.py # Hybrid GNN architecture and model loader
47
+ ├── utils.py # RDKit and SMILES processing
48
+ ├── requirements.txt # Python dependencies
49
+ └── predictions.db # SQLite log of predictions
50
+
51
  ---
52
 
53
+ ## Requirements
54
+
55
+ To run locally:
56
+ ```
57
+ pip install -r requirements.txt
58
+ streamlit run app.py
59
+
60
+ ```
61
+
62
+
63
+ ## Model Info
64
+
65
+ The app uses a trained hybrid GNN model combining:
66
+
67
+ * AtomEncoder and BondEncoder from OGB
68
+ * GINEConv layers from PyTorch Geometric
69
+ * Global mean pooling
70
+ * RDKit-based physicochemical descriptors
71
+
72
+ Trained on the [OGB PCQM4Mv2 dataset](https://ogb.stanford.edu/docs/lsc/pcqm4mv2/), optimized using Optuna
73
+
74
+
75
+ ## Author
76
+
77
+ Developed by [Matthew Graham](https://github.com/MooseML)
78
+ For inquiries, collaborations, or ideas, feel free to reach out!
79
+
80
+
81
+
82
+
83
+
84
+
85
+
__pycache__/model.cpython-38.pyc ADDED
Binary file (2.16 kB). View file
 
__pycache__/utils.cpython-38.pyc ADDED
Binary file (1.62 kB). View file
 
app.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import torch
4
+ import sqlite3
5
+ from datetime import datetime
6
+ from rdkit import Chem
7
+ from rdkit.Chem import Draw
8
+
9
+ from model import load_model
10
+ from utils import smiles_to_data
11
+ from torch_geometric.loader import DataLoader
12
+
13
+ # Config
14
+ DEVICE = "cpu"
15
+ RDKIT_DIM = 6
16
+ MODEL_PATH = "best_hybridgnn.pt"
17
+ MAX_DISPLAY = 10
18
+
19
+ # Load Model
20
+ model = load_model(rdkit_dim=RDKIT_DIM, path=MODEL_PATH, device=DEVICE)
21
+
22
+ # SQLite Setup
23
+ @st.cache_resource
24
+ def init_db():
25
+ conn = sqlite3.connect("predictions.db", check_same_thread=False)
26
+ c = conn.cursor()
27
+ c.execute("""
28
+ CREATE TABLE IF NOT EXISTS predictions (
29
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
30
+ smiles TEXT,
31
+ prediction REAL,
32
+ timestamp TEXT
33
+ )
34
+ """)
35
+ conn.commit()
36
+ return conn
37
+
38
+ conn = init_db()
39
+ cursor = conn.cursor()
40
+
41
+ # Streamlit UI
42
+ st.title("HOMO-LUMO Gap Predictor")
43
+ st.markdown("""
44
+ This app predicts the HOMO-LUMO energy gap for molecules using a trained Graph Neural Network (GNN).
45
+
46
+ **Instructions:**
47
+ - Enter a **single SMILES** string or **comma-separated list** in the box below.
48
+ - Or **upload a CSV file** containing a single column of SMILES strings.
49
+ - **Note**: If you've uploaded a CSV and want to switch to typing SMILES, please click the “X” next to the uploaded file to clear it.
50
+ - SMILES format should look like: `CC(=O)Oc1ccccc1C(=O)O` (for aspirin).
51
+ - The app will display predictions and molecule images (up to 10 shown at once).
52
+ """)
53
+
54
+ # Text Input
55
+ smiles_input = st.text_area("Enter SMILES string(s)", placeholder="C1=CC=CC=C1, CC(=O)Oc1ccccc1C(=O)O")
56
+
57
+ # File Upload
58
+ uploaded_file = st.file_uploader("...or upload a CSV file", type=["csv"])
59
+
60
+ smiles_list = []
61
+
62
+ if uploaded_file:
63
+ try:
64
+ df = pd.read_csv(uploaded_file)
65
+ if df.shape[1] != 1:
66
+ st.error("CSV should have only one column with SMILES strings.")
67
+ else:
68
+ smiles_list = df.iloc[:, 0].dropna().astype(str).tolist()
69
+ st.success(f"{len(smiles_list)} SMILES loaded from file.")
70
+ except Exception as e:
71
+ st.error(f"Error reading CSV: {e}")
72
+
73
+ elif smiles_input:
74
+ raw_input = smiles_input.strip().replace("\n", ",")
75
+ smiles_list = [smi.strip() for smi in raw_input.split(",") if smi.strip()]
76
+ st.success(f"{len(smiles_list)} SMILES parsed from input.")
77
+
78
+ # Run Inference
79
+ if smiles_list:
80
+ with st.spinner("Processing molecules..."):
81
+ data_list = smiles_to_data(smiles_list, device=DEVICE)
82
+
83
+ # Filter only valid molecules and keep aligned SMILES
84
+ valid_pairs = [(smi, data) for smi, data in zip(smiles_list, data_list) if data is not None]
85
+
86
+ if not valid_pairs:
87
+ st.warning("No valid molecules found.")
88
+ else:
89
+ valid_smiles, valid_data = zip(*valid_pairs)
90
+ loader = DataLoader(valid_data, batch_size=64)
91
+ predictions = []
92
+
93
+ for batch in loader:
94
+ batch = batch.to(DEVICE)
95
+ with torch.no_grad():
96
+ pred = model(batch).view(-1).cpu().numpy()
97
+ predictions.extend(pred.tolist())
98
+
99
+ # Display Results
100
+ st.subheader(f"Predictions (showing up to {MAX_DISPLAY} molecules):")
101
+
102
+ for i, (smi, pred) in enumerate(zip(valid_smiles, predictions)):
103
+ if i >= MAX_DISPLAY:
104
+ st.info(f"...only showing the first {MAX_DISPLAY} molecules")
105
+ break
106
+
107
+ mol = Chem.MolFromSmiles(smi)
108
+ if mol:
109
+ st.image(Draw.MolToImage(mol, size=(250, 250)))
110
+ st.write(f"**SMILES**: `{smi}`")
111
+ st.write(f"**Predicted HOMO-LUMO Gap**: `{pred:.4f} eV`")
112
+
113
+ # Log to SQLite
114
+ cursor.execute("INSERT INTO predictions (smiles, prediction, timestamp) VALUES (?, ?, ?)",
115
+ (smi, pred, str(datetime.now())))
116
+ conn.commit()
117
+
118
+ # Download Results
119
+ result_df = pd.DataFrame({
120
+ "SMILES": valid_smiles,
121
+ "Predicted HOMO-LUMO Gap (eV)": [round(p, 4) for p in predictions]
122
+ })
123
+
124
+ st.download_button(
125
+ label="Download Predictions as CSV",
126
+ data=result_df.to_csv(index=False).encode('utf-8'),
127
+ file_name="homolumo_predictions.csv",
128
+ mime="text/csv"
129
+ )
best_hybridgnn.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e3cd6a7f4297f6451cf159ac6a5745ae0edde7c6c481308ff407e065eec2828c
3
+ size 5259810
model.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.nn import Linear, Dropout, Module, Sequential
3
+ from torch_geometric.nn import GINEConv, global_mean_pool
4
+ from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
5
+
6
+ class HybridGNN(Module):
7
+ def __init__(self, gnn_dim, rdkit_dim, hidden_dim, dropout_rate=0.2, activation='ReLU'):
8
+ super().__init__()
9
+ act_map = {'Swish': torch.nn.SiLU(), 'ReLU': torch.nn.ReLU()}
10
+ act_fn = act_map[activation]
11
+ self.gnn_dim = gnn_dim
12
+ self.rdkit_dim = rdkit_dim
13
+
14
+ self.atom_encoder = AtomEncoder(emb_dim=gnn_dim)
15
+ self.bond_encoder = BondEncoder(emb_dim=gnn_dim)
16
+
17
+ self.conv1 = GINEConv(Sequential(Linear(gnn_dim, gnn_dim), act_fn, Linear(gnn_dim, gnn_dim)))
18
+ self.conv2 = GINEConv(Sequential(Linear(gnn_dim, gnn_dim), act_fn, Linear(gnn_dim, gnn_dim)))
19
+ self.pool = global_mean_pool
20
+
21
+ self.mlp = Sequential(Linear(gnn_dim + rdkit_dim, hidden_dim), act_fn,
22
+ Dropout(dropout_rate),
23
+ Linear(hidden_dim, hidden_dim // 2), act_fn,
24
+ Dropout(dropout_rate),
25
+ Linear(hidden_dim // 2, 1))
26
+
27
+ def forward(self, data):
28
+ x = self.atom_encoder(data.x)
29
+ edge_attr = self.bond_encoder(data.edge_attr)
30
+
31
+ x = self.conv1(x, data.edge_index, edge_attr)
32
+ x = self.conv2(x, data.edge_index, edge_attr)
33
+ x = self.pool(x, data.batch)
34
+
35
+ rdkit_feats = getattr(data, 'rdkit_feats', None)
36
+ if rdkit_feats is not None:
37
+ if x.shape[0] != rdkit_feats.shape[0]:
38
+ raise ValueError(f"Shape mismatch: GNN output ({x.shape}) vs rdkit_feats ({rdkit_feats.shape})")
39
+ x = torch.cat([x, rdkit_feats], dim=1)
40
+ else:
41
+ raise ValueError("RDKit features not found in the data object.")
42
+
43
+ return self.mlp(x)
44
+
45
+ def load_model(rdkit_dim: int, path: str = "best_hybridgnn.pt", device: str = "cpu"):
46
+ model = HybridGNN(gnn_dim=512, rdkit_dim=rdkit_dim, hidden_dim=256, dropout_rate=0.29, activation='Swish')
47
+ model.load_state_dict(torch.load(path, map_location=device))
48
+ model.to(device)
49
+ model.eval()
50
+ return model
predictions.db ADDED
Binary file (24.6 kB). View file
 
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ streamlit==1.45.0
2
+ rdkit-pypi==2022.9.5
3
+ pandas==2.2.3
4
+ numpy==1.26.4
5
+ torch==2.2.0
6
+ torch-geometric==2.5.2
7
+ ogb==1.3.6
8
+ pillow==10.3.0
utils.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from rdkit import Chem
4
+ from rdkit.Chem import Descriptors
5
+ from torch_geometric.data import Data
6
+ from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims
7
+ from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
8
+ from ogb.lsc import PCQM4Mv2Evaluator
9
+ from ogb.utils import smiles2graph
10
+ from torch_geometric.loader import DataLoader
11
+
12
+ def compute_rdkit_features(smiles):
13
+ mol = Chem.MolFromSmiles(smiles)
14
+ if mol is None:
15
+ raise ValueError("Invalid SMILES")
16
+ return [
17
+ Descriptors.MolWt(mol),
18
+ Descriptors.NumRotatableBonds(mol),
19
+ Descriptors.TPSA(mol),
20
+ Descriptors.NumHAcceptors(mol),
21
+ Descriptors.NumHDonors(mol),
22
+ Descriptors.RingCount(mol)
23
+ ]
24
+
25
+ def smiles_to_data(smiles_list, device="cpu"):
26
+ graph_list = []
27
+ rdkit_list = []
28
+
29
+ for smi in smiles_list:
30
+ try:
31
+ graph = smiles2graph(smi)
32
+ rdkit_feats = compute_rdkit_features(smi)
33
+
34
+ data = Data(
35
+ x=torch.tensor(graph['node_feat'], dtype=torch.long),
36
+ edge_index=torch.tensor(graph['edge_index'], dtype=torch.long),
37
+ edge_attr=torch.tensor(graph['edge_feat'], dtype=torch.long),
38
+ rdkit_feats=torch.tensor(rdkit_feats, dtype=torch.float32).unsqueeze(0),
39
+ num_nodes=graph['num_nodes']
40
+ )
41
+ graph_list.append(data)
42
+ except Exception as e:
43
+ print(f"Error with SMILES '{smi}': {e}")
44
+ continue
45
+
46
+ return graph_list
47
+