MooseML committed on
Commit
0ef141b
·
1 Parent(s): 0714060

store the UploadedFile in st.session_state

Browse files
Files changed (1) hide show
  1. app.py +106 -88
app.py CHANGED
@@ -1,149 +1,167 @@
1
- import streamlit as st
2
- import pandas as pd
3
- import torch
4
  import sqlite3
5
  from datetime import datetime
 
 
 
 
 
6
  from rdkit import Chem
7
  from rdkit.Chem import Draw
8
- import os, pathlib
 
9
  from model import load_model
10
  from utils import smiles_to_data
11
- from torch_geometric.loader import DataLoader
12
 
13
- # Config
14
- DEVICE = "cpu"
15
- RDKIT_DIM = 6
16
- MODEL_PATH = "best_hybridgnn.pt"
17
- MAX_DISPLAY = 10
18
 
19
- # Load Model
20
- model = load_model(rdkit_dim=RDKIT_DIM, path=MODEL_PATH, device=DEVICE)
 
 
 
 
21
 
22
- # SQLite Setup
23
- DB_DIR = os.getenv("DB_DIR", "/tmp")
24
- pathlib.Path(DB_DIR).mkdir(parents=True, exist_ok=True)
25
 
26
  @st.cache_resource
27
  def init_db():
28
- db_file = os.path.join(DB_DIR, "predictions.db")
29
- conn = sqlite3.connect(db_file, check_same_thread=False)
30
  c = conn.cursor()
31
- c.execute("""
 
32
  CREATE TABLE IF NOT EXISTS predictions (
33
  id INTEGER PRIMARY KEY AUTOINCREMENT,
34
  smiles TEXT,
35
  prediction REAL,
36
  timestamp TEXT
37
  )
38
- """)
 
39
  conn.commit()
40
  return conn
41
 
42
- conn = init_db()
43
  cursor = conn.cursor()
44
 
45
- # Streamlit UI
46
  st.title("HOMO-LUMO Gap Predictor")
47
- st.markdown("""
48
- This app predicts the HOMO-LUMO energy gap for molecules using a trained Graph Neural Network (GNN).
49
-
50
- **Instructions:**
51
- - Enter a **single SMILES** string or **comma-separated list** in the box below.
52
- - Or **upload a CSV file** containing a single column of SMILES strings.
53
- - **Note**: If you've uploaded a CSV and want to switch to typing SMILES, please click the “X” next to the uploaded file to clear it.
54
- - SMILES format should look like: `CC(=O)Oc1ccccc1C(=O)O` (for aspirin).
55
- - The app will display predictions and molecule images (up to 10 shown at once).
56
- """)
 
57
 
58
  smiles_list = []
59
 
60
  with st.form("smiles_or_csv"):
61
  smiles_text = st.text_area(
62
- "SMILES (comma or line-separated)",
63
- placeholder="C1=CC=CC=C1\nCC(=O)Oc1ccccc1C(=O)O",
64
  height=120,
65
  )
66
- csv_file = st.file_uploader(
67
- "…or upload a one-column CSV",
68
- type=["csv"],
69
- )
70
  run = st.form_submit_button("Run Prediction")
71
 
 
72
  if run:
73
- if csv_file is not None:
 
 
 
74
  try:
75
- csv_file.seek(0)
76
- df = pd.read_csv(csv_file, comment="#")
77
 
78
  if df.shape[1] == 1:
79
  smiles_col = df.iloc[:, 0]
80
  elif "smiles" in [c.lower() for c in df.columns]:
81
- smiles_col = df[[c for c in df.columns if c.lower() == "smiles"][0]]
 
 
82
  else:
83
  st.error(
84
- "CSV must have a single column **or** a column named 'SMILES'. "
85
- f"Found columns: {', '.join(df.columns)}"
86
  )
87
  smiles_col = None
88
 
89
  if smiles_col is not None:
90
  smiles_list = smiles_col.dropna().astype(str).tolist()
91
- st.success(f"{len(smiles_list)} SMILES loaded from CSV.")
92
  except Exception as e:
93
  st.error(f"CSV read error: {e}")
94
 
 
95
  elif smiles_text.strip():
96
  raw = smiles_text.replace("\n", ",")
97
  smiles_list = [s.strip() for s in raw.split(",") if s.strip()]
98
- st.success(f"{len(smiles_list)} SMILES parsed from textbox.")
99
  else:
100
- st.warning("Please paste SMILES or upload a CSV before pressing *Run*.")
101
 
102
- # Run Inference
103
  if smiles_list:
104
- with st.spinner("Processing molecules..."):
105
  data_list = smiles_to_data(smiles_list, device=DEVICE)
106
 
107
- valid_pairs = [(smi, data) for smi, data in zip(smiles_list, data_list) if data is not None]
108
-
109
- if not valid_pairs:
110
- st.warning("No valid molecules found")
111
- else:
112
- valid_smiles, valid_data = zip(*valid_pairs)
113
- loader = DataLoader(valid_data, batch_size=64)
114
- predictions = []
115
-
116
- for batch in loader:
117
- batch = batch.to(DEVICE)
118
- with torch.no_grad():
119
- pred = model(batch).view(-1).cpu().numpy()
120
- predictions.extend(pred.tolist())
121
-
122
- st.subheader(f"Predictions (showing up to {MAX_DISPLAY} molecules):")
123
-
124
- for i, (smi, pred) in enumerate(zip(valid_smiles, predictions)):
125
- if i >= MAX_DISPLAY:
126
- st.info(f"...only showing the first {MAX_DISPLAY} molecules")
127
- break
128
 
129
- mol = Chem.MolFromSmiles(smi)
130
- if mol:
131
- st.image(Draw.MolToImage(mol, size=(250, 250)))
132
- st.write(f"**SMILES**: `{smi}`")
133
- st.write(f"**Predicted HOMO-LUMO Gap**: `{pred:.4f} eV`")
134
-
135
- cursor.execute("INSERT INTO predictions (smiles, prediction, timestamp) VALUES (?, ?, ?)",
136
- (smi, pred, str(datetime.now())))
137
- conn.commit()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
- result_df = pd.DataFrame({
 
 
140
  "SMILES": valid_smiles,
141
- "Predicted HOMO-LUMO Gap (eV)": [round(p, 4) for p in predictions]
142
- })
143
-
144
- st.download_button(
145
- label="Download Predictions as CSV",
146
- data=result_df.to_csv(index=False).encode('utf-8'),
147
- file_name="homolumo_predictions.csv",
148
- mime="text/csv"
149
- )
 
1
+ import os
2
+ import pathlib
 
3
  import sqlite3
4
  from datetime import datetime
5
+ from io import StringIO
6
+
7
+ import pandas as pd
8
+ import streamlit as st
9
+ import torch
10
  from rdkit import Chem
11
  from rdkit.Chem import Draw
12
+ from torch_geometric.loader import DataLoader
13
+
14
  from model import load_model
15
  from utils import smiles_to_data
 
16
 
17
# ───────────────────────── Configuration ─────────────────────────
# Module-level settings read by the model loader, the featurizer and the UI.
DEVICE = "cpu"  # passed to load_model, smiles_to_data and batch.to(...)
RDKIT_DIM = 6  # forwarded to load_model as rdkit_dim
MODEL_PATH = "best_hybridgnn.pt"  # checkpoint path handed to load_model
MAX_DISPLAY = 10  # upper bound on molecules rendered per prediction run
22
 
23
+ # ─────────────────────── Cached model & database ─────────────────
24
@st.cache_resource
def get_model():
    """Load the predictor described by the module-level configuration.

    Forwards RDKIT_DIM, MODEL_PATH and DEVICE to ``load_model`` and returns
    whatever it produces. The ``st.cache_resource`` decorator is what keeps
    the result from being rebuilt on every script rerun.
    """
    load_kwargs = {
        "rdkit_dim": RDKIT_DIM,
        "path": MODEL_PATH,
        "device": DEVICE,
    }
    return load_model(**load_kwargs)
27
+
28
+ model = get_model()
29
 
30
# Directory holding the SQLite file; taken from $DB_DIR, defaulting to /tmp.
DB_DIR = pathlib.Path(os.getenv("DB_DIR", "/tmp"))
DB_DIR.mkdir(parents=True, exist_ok=True)


@st.cache_resource
def init_db():
    """Open ``predictions.db`` inside DB_DIR and ensure its table exists.

    Returns:
        The open ``sqlite3`` connection. It is created with
        ``check_same_thread=False`` so it may be used from a thread other
        than the one that opened it.
    """
    db_path = DB_DIR / "predictions.db"
    connection = sqlite3.connect(db_path, check_same_thread=False)

    create_table_sql = """
        CREATE TABLE IF NOT EXISTS predictions (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            smiles TEXT,
            prediction REAL,
            timestamp TEXT
        )
        """
    setup_cursor = connection.cursor()
    setup_cursor.execute(create_table_sql)
    connection.commit()
    return connection


# Shared connection plus the cursor used later to log predictions.
conn = init_db()
cursor = conn.cursor()
52
 
53
# ───────────────────────────── UI header ─────────────────────────────
st.title("HOMO-LUMO Gap Predictor")
st.markdown(
    """
Paste SMILES **or** upload a one-column CSV, then click **Run Prediction**.
The app draws each molecule and shows the predicted HOMO-LUMO gap (eV).
"""
)

# ─────────────────────────── Input widgets ───────────────────────────
csv_file = st.file_uploader("Upload CSV (one SMILES column)", type=["csv"])
if csv_file is not None:
    st.session_state["uploaded_csv"] = csv_file  # persist across reruns
elif "uploaded_csv" in st.session_state:
    # The uploader is empty again (the user removed the file). Drop the stale
    # copy, otherwise the CSV branch below would keep winning and pasted
    # SMILES would be ignored for the rest of the session.
    del st.session_state["uploaded_csv"]

smiles_list = []

with st.form("smiles_or_csv"):
    smiles_text = st.text_area(
        "…or paste SMILES (comma or newline separated)",
        placeholder="CC(=O)Oc1ccccc1C(=O)O",
        height=120,
    )
    run = st.form_submit_button("Run Prediction")

# ───────────────────── Parse input after submit ──────────────────────
if run:
    csv_obj = st.session_state.get("uploaded_csv")

    # CSV branch: an uploaded file takes precedence over the text box.
    if csv_obj is not None:
        try:
            # getvalue() returns the whole buffer regardless of the current
            # read position, so no seek(0) is needed. "utf-8-sig" also strips
            # a leading byte-order mark so it cannot end up glued to the
            # first column name.
            csv_text = csv_obj.getvalue().decode("utf-8-sig")
            df = pd.read_csv(StringIO(csv_text), comment="#")

            # Accept a single column, or the first column whose name is
            # "smiles" in any letter case.
            named_smiles = [c for c in df.columns if c.lower() == "smiles"]
            if df.shape[1] == 1:
                smiles_col = df.iloc[:, 0]
            elif named_smiles:
                smiles_col = df[named_smiles[0]]
            else:
                st.error(
                    "CSV must have one column **or** a column named 'SMILES'. "
                    f"Found: {', '.join(df.columns)}"
                )
                smiles_col = None

            if smiles_col is not None:
                smiles_list = smiles_col.dropna().astype(str).tolist()
                st.success(f"{len(smiles_list)} SMILES loaded from CSV")
        except Exception as e:
            # UI boundary: report any read/parse/decode failure to the user
            # instead of crashing the script run.
            st.error(f"CSV read error: {e}")

    # Textarea branch: newlines and commas both separate entries.
    elif smiles_text.strip():
        raw = smiles_text.replace("\n", ",")
        smiles_list = [s.strip() for s in raw.split(",") if s.strip()]
        st.success(f"{len(smiles_list)} SMILES parsed from textbox")
    else:
        st.warning("Paste SMILES or upload a CSV before pressing **Run**")

# ───────────────────────────── Inference ─────────────────────────────
if smiles_list:
    with st.spinner("Running model…"):
        data_list = smiles_to_data(smiles_list, device=DEVICE)

    # Keep only the inputs for which smiles_to_data returned a data object.
    valid_pairs = [
        (smi, data)
        for smi, data in zip(smiles_list, data_list)
        if data is not None
    ]

    if not valid_pairs:
        st.warning("No valid molecules found")
    else:
        valid_smiles, valid_data = zip(*valid_pairs)
        loader = DataLoader(valid_data, batch_size=64)
        preds = []

        # Gradients are not needed for prediction.
        with torch.no_grad():
            for batch in loader:
                batch = batch.to(DEVICE)
                preds.extend(model(batch).view(-1).cpu().numpy().tolist())

        # Display results (and log each displayed prediction to SQLite).
        st.subheader(f"Predictions (showing up to {MAX_DISPLAY})")
        for i, (smi, pred) in enumerate(zip(valid_smiles, preds)):
            if i >= MAX_DISPLAY:
                st.info(f"…only first {MAX_DISPLAY} molecules shown")
                break
            mol = Chem.MolFromSmiles(smi)
            if mol:
                st.image(Draw.MolToImage(mol, size=(250, 250)))
            st.write(f"**SMILES:** `{smi}`")
            st.write(f"**Predicted Gap:** `{pred:.4f} eV`")

            cursor.execute(
                "INSERT INTO predictions (smiles, prediction, timestamp) VALUES (?, ?, ?)",
                (smi, float(pred), datetime.now().isoformat()),
            )
        # One commit covers every row inserted by the loop above.
        conn.commit()

        # Download results: the CSV contains every valid molecule, not only
        # the ones rendered above.
        res_df = pd.DataFrame(
            {
                "SMILES": valid_smiles,
                "Predicted HOMO-LUMO Gap (eV)": [round(p, 4) for p in preds],
            }
        )
        st.download_button(
            "Download results as CSV",
            res_df.to_csv(index=False).encode("utf-8"),
            "homolumo_predictions.csv",
            "text/csv",
        )