ShivamKum4r commited on
Commit
e6f8bfb
Β·
verified Β·
1 Parent(s): 03ba321

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +114 -148
src/streamlit_app.py CHANGED
@@ -1,3 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # ------------------- Imports -------------------
2
  import streamlit as st
3
  import numpy as np
@@ -9,8 +24,6 @@ from sklearn.metrics import accuracy_score, roc_auc_score
9
  from rdkit import Chem
10
  from rdkit.Chem import rdMolDescriptors
11
  from rdkit.Chem.rdFingerprintGenerator import GetMorganGenerator
12
- from rdkit import RDLogger
13
- RDLogger.DisableLog('rdApp.*') # Silence all RDKit warnings
14
  from torch_geometric.data import Data
15
  from torch_geometric.nn import GCNConv, global_mean_pool
16
  from torch_geometric.loader import DataLoader
@@ -18,9 +31,9 @@ import plotly.express as px
18
  from rdkit.Chem import Draw
19
  from torch_geometric.data import Batch
20
  from rdkit.Chem import Descriptors
21
- import time
22
 
23
 
 
24
 
25
 
26
  # ------------------- Models -------------------
@@ -60,88 +73,42 @@ class RichGCNModel(nn.Module):
60
  st.set_page_config(layout="wide", page_title="Drug Toxicity Predictor")
61
  st.title("πŸ§ͺ Drug Toxicity Prediction Dashboard")
62
 
63
- # ------------------- Cache: Load Models and Threshold -------------------
64
- @st.cache_resource
65
- def load_fp_model():
66
- model = ToxicityNet()
67
- model.load_state_dict(torch.load("tox_model.pt", map_location="cpu"))
68
- model.eval()
69
- return model
70
-
71
- @st.cache_resource
72
- def load_gcn_model():
73
- model = RichGCNModel()
74
- model.load_state_dict(torch.load("gcn_model.pt", map_location="cpu"))
75
- model.eval()
76
- return model
77
-
78
- @st.cache_data
79
- def load_threshold():
80
- try:
81
- return float(np.load("gcn_best_threshold.npy"))
82
- except:
83
- return 0.5
84
-
85
- # ------------------- Cache: Load Dataset -------------------
86
- @st.cache_data
87
- def load_data():
88
- df = pd.read_csv("tox21.csv")[['smiles', 'SR-HSE']].dropna()
89
- df = df[df['SR-HSE'].isin([0, 1])]
90
- df = df[df['smiles'].apply(is_valid_graph)].reset_index(drop=True)
91
- return df
92
-
93
- # ------------------- Utility Functions -------------------
94
- fp_model = load_fp_model()
95
- gcn_model = load_gcn_model()
96
- best_threshold = load_threshold()
97
- fp_gen = GetMorganGenerator(radius=2, fpSize=1024)
98
 
99
  # Load Fingerprint Model
100
- msg_fp = st.empty()
101
- with msg_fp.container():
102
- with st.spinner("πŸ“¦ Loading fingerprint model..."):
103
- time.sleep(6)
104
- try:
105
- fp_model.load_state_dict(torch.load("tox_model.pt", map_location=torch.device("cpu")))
106
- fp_model.eval()
107
- fp_loaded = True
108
- st.success("βœ… Fingerprint model loaded.")
109
- except Exception as e:
110
- st.warning(f"⚠️ Fingerprint model not loaded: {e}")
111
- time.sleep(1)
112
- msg_fp.empty()
113
 
114
  # Load GCN Model
115
- msg_gcn = st.empty()
116
- with msg_gcn.container():
117
- with st.spinner("πŸ“¦ Loading GCN model..."):
118
- time.sleep(2)
119
- try:
120
- gcn_model.load_state_dict(torch.load("gcn_model.pt", map_location=torch.device("cpu")))
121
- gcn_model.eval()
122
- gcn_loaded = True
123
- st.success("βœ… GCN model loaded.")
124
- except Exception as e:
125
- st.warning(f"⚠️ GCN model not loaded: {e}")
126
- time.sleep(1)
127
- msg_gcn.empty()
128
 
129
  # Load Best Threshold
130
- msg_threshold = st.empty()
131
- with msg_threshold.container():
132
- with st.spinner("πŸ“Š Loading best threshold..."):
133
- time.sleep(2)
134
- try:
135
- best_threshold = float(np.load("gcn_best_threshold.npy"))
136
- except Exception as e:
137
- best_threshold = 0.5
138
- st.warning(f"⚠️ Using default threshold (0.5) for GCN model. Reason: {e}")
139
- st.success("βœ… All models loaded. Dashboard is ready!")
140
- time.sleep(2)
141
- msg_threshold.empty()
142
 
143
 
144
 
 
 
 
 
145
  def get_molecule_info(mol):
146
  return {
147
  "Formula": Chem.rdMolDescriptors.CalcMolFormula(mol),
@@ -234,10 +201,12 @@ df = df[df['SR-HSE'].isin([0, 1])].reset_index(drop=True)
234
  def is_valid_graph(smi):
235
  mol = Chem.MolFromSmiles(smi)
236
  return mol is not None and smiles_to_graph(smi) is not None
 
237
  df = df[df['smiles'].apply(is_valid_graph)].reset_index(drop=True)
238
 
239
 
240
 
 
241
  def create_graph_dataset(smiles_list, labels):
242
  data_list = []
243
  for smi, label in zip(smiles_list, labels):
@@ -330,36 +299,35 @@ with tab1:
330
  predict_btn = st.form_submit_button("πŸ” Predict")
331
 
332
  if predict_btn:
333
- with st.spinner("Predicting..."):
334
- mol = Chem.MolFromSmiles(smiles_fp)
335
- if mol:
336
- fp = fp_gen.GetFingerprint(mol)
337
- arr = np.array(fp).reshape(1, -1)
338
- tensor = torch.tensor(arr).float()
339
- with torch.no_grad():
340
- output = fp_model(tensor)
341
- prob = torch.sigmoid(output).item()
342
- raw_score = output.item()
343
- label = "Toxic" if prob > 0.5 else "Non-toxic"
344
- color = "red" if label == "Toxic" else "green"
345
-
346
- st.markdown(f"<h4>🧾 Prediction: <span style='color:{color}'>{label}</span> β€” <code>{prob:.3f}</code></h4>", unsafe_allow_html=True)
347
-
348
- if show_debug_fp:
349
- st.code(f"πŸ“‰ Raw Logit: {raw_score:.4f}", language='text')
350
- st.markdown("#### Fingerprint Vector (First 20 bits)")
351
- st.code(str(arr[0][:20]) + " ...", language="text")
352
-
353
- st.image(Draw.MolToImage(mol), caption="Molecular Structure", width=250)
354
-
355
- info = get_molecule_info(mol)
356
- st.markdown("### Molecule Info:")
357
- for k, v in info.items():
358
- st.markdown(f"**{k}:** {v}")
359
-
360
- st.plotly_chart(plot_distribution(df, 'fp', prob), use_container_width=True)
361
- else:
362
- st.error("❌ Invalid SMILES input. Please check your string.")
363
 
364
  with st.expander("πŸ“Œ Example SMILES to Try"):
365
  st.markdown("""
@@ -401,49 +369,47 @@ with tab2:
401
  gcn_btn = st.form_submit_button("πŸ” Predict")
402
 
403
  if gcn_btn:
404
- with st.spinner("Predicting..."):
405
- mol = Chem.MolFromSmiles(smiles_gcn)
406
-
407
- if mol is None:
408
- st.error("❌ Invalid SMILES: could not parse molecule.")
409
- elif not is_supported(mol):
410
- st.error("⚠️ This molecule contains unsupported atoms (e.g. Sn, P, etc.). GCN model only supports common organic elements.")
 
 
411
  else:
412
- graph = smiles_to_graph(smiles_gcn)
413
- if graph is None:
414
- st.error("❌ SMILES is valid but could not be converted to graph. Possibly malformed structure.")
415
- else:
416
- batch = Batch.from_data_list([graph])
417
- with torch.no_grad():
418
- out = gcn_model(batch)
419
- prob = torch.sigmoid(out).item()
420
- raw_score = out.item()
421
- label = "Toxic" if prob > best_threshold else "Non-toxic"
422
- color = "red" if label == "Toxic" else "green"
423
-
424
- st.markdown(f"<h4>🧾 GCN Prediction: <span style='color:{color}'>{label}</span> β€” <code>{prob:.3f}</code></h4>", unsafe_allow_html=True)
425
-
426
- if show_debug:
427
- st.code(f"πŸ“‰ Raw Logit: {raw_score:.4f}", language='text')
428
-
429
- st.image(Draw.MolToImage(mol), caption="Molecular Structure", width=250)
430
-
431
- def get_molecule_info(mol):
432
- return {
433
- "Molecular Weight": round(Chem.Descriptors.MolWt(mol), 2),
434
- "LogP": round(Chem.Crippen.MolLogP(mol), 2),
435
- "Num H-Bond Donors": Chem.Lipinski.NumHDonors(mol),
436
- "Num H-Bond Acceptors": Chem.Lipinski.NumHAcceptors(mol),
437
- "TPSA": round(Chem.rdMolDescriptors.CalcTPSA(mol), 2),
438
- "Num Rotatable Bonds": Chem.Lipinski.NumRotatableBonds(mol)
439
- }
440
-
441
- info = get_molecule_info(mol)
442
- st.markdown("### Molecule Info:")
443
- for k, v in info.items():
444
- st.markdown(f"**{k}:** {v}")
445
-
446
- st.plotly_chart(plot_distribution(df, 'gcn', prob), use_container_width=True)
447
 
448
  with st.expander("πŸ“Œ Example SMILES to Try"):
449
  st.markdown("""
 
1
+ # Safe monkey patch to fix Streamlit reloader crash due to torch.classes bug
2
+ import types
3
+ import torch
4
+
5
+ try:
6
+ import torch.classes
7
+ if not hasattr(torch.classes, "__path__"):
8
+ torch.classes.__path__ = types.SimpleNamespace(_path=[])
9
+ except Exception:
10
+ pass # Safe fallback if torch.classes doesn't exist
11
+
12
+
13
+
14
+
15
+
16
  # ------------------- Imports -------------------
17
  import streamlit as st
18
  import numpy as np
 
24
  from rdkit import Chem
25
  from rdkit.Chem import rdMolDescriptors
26
  from rdkit.Chem.rdFingerprintGenerator import GetMorganGenerator
 
 
27
  from torch_geometric.data import Data
28
  from torch_geometric.nn import GCNConv, global_mean_pool
29
  from torch_geometric.loader import DataLoader
 
31
  from rdkit.Chem import Draw
32
  from torch_geometric.data import Batch
33
  from rdkit.Chem import Descriptors
 
34
 
35
 
36
+ import time
37
 
38
 
39
  # ------------------- Models -------------------
 
73
  st.set_page_config(layout="wide", page_title="Drug Toxicity Predictor")
74
  st.title("πŸ§ͺ Drug Toxicity Prediction Dashboard")
75
 
76
+ # ------------------- Load Models with Spinner -------------------
77
+ # ------------------- Load Models with Temporary Messages -------------------
78
+ fp_model = ToxicityNet()
79
+ gcn_model = RichGCNModel()
80
+ fp_loaded = gcn_loaded = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  # Load Fingerprint Model
83
+ try:
84
+ fp_model.load_state_dict(torch.load("tox_model.pt", map_location=torch.device("cpu")))
85
+ fp_model.eval()
86
+ fp_loaded = True
87
+ except Exception as e:
88
+ st.warning(f"⚠️ Fingerprint model not loaded: {e}")
 
 
 
 
 
 
 
89
 
90
  # Load GCN Model
91
+ try:
92
+ gcn_model.load_state_dict(torch.load("gcn_model.pt", map_location=torch.device("cpu")))
93
+ gcn_model.eval()
94
+ gcn_loaded = True
95
+ except Exception as e:
96
+ st.warning(f"⚠️ GCN model not loaded: {e}")
97
+
 
 
 
 
 
 
98
 
99
  # Load Best Threshold
100
+ try:
101
+ best_threshold = float(np.load("gcn_best_threshold.npy"))
102
+ except Exception as e:
103
+ best_threshold = 0.5
104
+ st.warning(f"⚠️ Using default threshold (0.5) for GCN model. Reason: {e}")
 
 
 
 
 
 
 
105
 
106
 
107
 
108
+
109
+ # ------------------- Utility Functions -------------------
110
+ fp_gen = GetMorganGenerator(radius=2, fpSize=1024)
111
+
112
  def get_molecule_info(mol):
113
  return {
114
  "Formula": Chem.rdMolDescriptors.CalcMolFormula(mol),
 
201
  def is_valid_graph(smi):
202
  mol = Chem.MolFromSmiles(smi)
203
  return mol is not None and smiles_to_graph(smi) is not None
204
+
205
  df = df[df['smiles'].apply(is_valid_graph)].reset_index(drop=True)
206
 
207
 
208
 
209
+
210
  def create_graph_dataset(smiles_list, labels):
211
  data_list = []
212
  for smi, label in zip(smiles_list, labels):
 
299
  predict_btn = st.form_submit_button("πŸ” Predict")
300
 
301
  if predict_btn:
302
+ mol = Chem.MolFromSmiles(smiles_fp)
303
+ if mol:
304
+ fp = fp_gen.GetFingerprint(mol)
305
+ arr = np.array(fp).reshape(1, -1)
306
+ tensor = torch.tensor(arr).float()
307
+ with torch.no_grad():
308
+ output = fp_model(tensor)
309
+ prob = torch.sigmoid(output).item()
310
+ raw_score = output.item()
311
+ label = "Toxic" if prob > 0.5 else "Non-toxic"
312
+ color = "red" if label == "Toxic" else "green"
313
+
314
+ st.markdown(f"<h4>🧾 Prediction: <span style='color:{color}'>{label}</span> β€” <code>{prob:.3f}</code></h4>", unsafe_allow_html=True)
315
+
316
+ if show_debug_fp:
317
+ st.code(f"πŸ“‰ Raw Logit: {raw_score:.4f}", language='text')
318
+ st.markdown("#### Fingerprint Vector (First 20 bits)")
319
+ st.code(str(arr[0][:20]) + " ...", language="text")
320
+
321
+ st.image(Draw.MolToImage(mol), caption="Molecular Structure", width=250)
322
+
323
+ info = get_molecule_info(mol)
324
+ st.markdown("### Molecule Info:")
325
+ for k, v in info.items():
326
+ st.markdown(f"**{k}:** {v}")
327
+
328
+ st.plotly_chart(plot_distribution(df, 'fp', prob), use_container_width=True)
329
+ else:
330
+ st.error("❌ Invalid SMILES input. Please check your string.")
 
331
 
332
  with st.expander("πŸ“Œ Example SMILES to Try"):
333
  st.markdown("""
 
369
  gcn_btn = st.form_submit_button("πŸ” Predict")
370
 
371
  if gcn_btn:
372
+ mol = Chem.MolFromSmiles(smiles_gcn)
373
+ if mol is None:
374
+ st.error("❌ Invalid SMILES: could not parse molecule.")
375
+ elif not is_supported(mol):
376
+ st.error("⚠️ This molecule contains unsupported atoms (e.g. Sn, P, etc.). GCN model only supports common organic elements.")
377
+ else:
378
+ graph = smiles_to_graph(smiles_gcn)
379
+ if graph is None:
380
+ st.error("❌ SMILES is valid but could not be converted to graph. Possibly malformed structure.")
381
  else:
382
+ batch = Batch.from_data_list([graph])
383
+ with torch.no_grad():
384
+ out = gcn_model(batch)
385
+ prob = torch.sigmoid(out).item()
386
+ raw_score = out.item()
387
+ label = "Toxic" if prob > best_threshold else "Non-toxic"
388
+ color = "red" if label == "Toxic" else "green"
389
+
390
+ st.markdown(f"<h4>🧾 GCN Prediction: <span style='color:{color}'>{label}</span> β€” <code>{prob:.3f}</code></h4>", unsafe_allow_html=True)
391
+
392
+ if show_debug:
393
+ st.code(f"πŸ“‰ Raw Logit: {raw_score:.4f}", language='text')
394
+
395
+ st.image(Draw.MolToImage(mol), caption="Molecular Structure", width=250)
396
+
397
+ def get_molecule_info(mol):
398
+ return {
399
+ "Molecular Weight": round(Chem.Descriptors.MolWt(mol), 2),
400
+ "LogP": round(Chem.Crippen.MolLogP(mol), 2),
401
+ "Num H-Bond Donors": Chem.Lipinski.NumHDonors(mol),
402
+ "Num H-Bond Acceptors": Chem.Lipinski.NumHAcceptors(mol),
403
+ "TPSA": round(Chem.rdMolDescriptors.CalcTPSA(mol), 2),
404
+ "Num Rotatable Bonds": Chem.Lipinski.NumRotatableBonds(mol)
405
+ }
406
+
407
+ info = get_molecule_info(mol)
408
+ st.markdown("### Molecule Info:")
409
+ for k, v in info.items():
410
+ st.markdown(f"**{k}:** {v}")
411
+
412
+ st.plotly_chart(plot_distribution(df, 'gcn', prob), use_container_width=True)
 
 
 
 
413
 
414
  with st.expander("πŸ“Œ Example SMILES to Try"):
415
  st.markdown("""