ShivamKum4r committed on
Commit
6eb1141
·
verified ·
1 Parent(s): 8174384

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +456 -37
src/streamlit_app.py CHANGED
@@ -1,40 +1,459 @@
1
- import altair as alt
 
2
  import numpy as np
3
  import pandas as pd
4
- import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ # ------------------- Imports -------------------
2
+ import streamlit as st
3
  import numpy as np
4
  import pandas as pd
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from sklearn.metrics import accuracy_score, roc_auc_score
9
+ from rdkit import Chem
10
+ from rdkit.Chem import rdMolDescriptors
11
+ from rdkit.Chem.rdFingerprintGenerator import GetMorganGenerator
12
+ from torch_geometric.data import Data
13
+ from torch_geometric.nn import GCNConv, global_mean_pool
14
+ from torch_geometric.loader import DataLoader
15
+ import plotly.express as px
16
+ from rdkit.Chem import Draw
17
+ from torch_geometric.data import Batch
18
+ from rdkit.Chem import Descriptors
19
+
20
+
21
+ import time
22
+
23
+
24
+ # ------------------- Models -------------------
25
class ToxicityNet(nn.Module):
    """Feed-forward classifier over a 1024-wide input vector.

    Each input row is mapped to a single raw logit. Callers in this file
    apply ``torch.sigmoid`` to the output to obtain a probability.
    """

    def __init__(self):
        super().__init__()
        # The layer order is part of the checkpoint layout ("model.0.weight",
        # "model.3.weight", ...), so it is kept exactly as before.
        layers = [
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the stacked layers and return the logits."""
        logits = self.model(x)
        return logits
36
+
37
class RichGCNModel(nn.Module):
    """Two graph-convolution stages, mean pooling, then a two-layer head.

    ``forward`` expects an object exposing ``x``, ``edge_index`` and
    ``batch`` and returns one raw logit per pooled graph.
    """

    def __init__(self):
        super().__init__()
        # Attribute names double as checkpoint keys, so they stay unchanged.
        self.conv1 = GCNConv(10, 64)
        self.bn1 = nn.BatchNorm1d(64)
        self.conv2 = GCNConv(64, 128)
        self.bn2 = nn.BatchNorm1d(128)
        self.dropout = nn.Dropout(0.2)
        self.fc1 = nn.Linear(128, 64)
        self.fc2 = nn.Linear(64, 1)

    def forward(self, data):
        """Return the logits for the graphs described by ``data``."""
        hidden = data.x
        edge_index = data.edge_index
        # conv -> batch norm -> ReLU, applied once per stage
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            hidden = F.relu(norm(conv(hidden, edge_index)))
        pooled = self.dropout(global_mean_pool(hidden, data.batch))
        head = F.relu(self.fc1(pooled))
        return self.fc2(head)
56
+
57
# ------------------- UI Setup -------------------
# set_page_config has to stay the first Streamlit call of the script.
st.set_page_config(layout="wide", page_title="Drug Toxicity Predictor")
st.title("🧪 Drug Toxicity Prediction Dashboard")

# ------------------- Load Models with Temporary Messages -------------------
# Each loader writes its status into a placeholder that is cleared again after
# a short pause, so the messages are visible but do not stay on the page.
# The script is re-executed on every interaction, so no artificial delay is
# added in front of the actual loading work.
fp_model = ToxicityNet()
gcn_model = RichGCNModel()
fp_loaded = gcn_loaded = False

# Load Fingerprint Model
msg_fp = st.empty()
with msg_fp.container():
    with st.spinner("📦 Loading fingerprint model..."):
        try:
            fp_model.load_state_dict(torch.load("tox_model.pt", map_location=torch.device("cpu")))
            fp_model.eval()
            fp_loaded = True
            st.success("✅ Fingerprint model loaded.")
        except Exception as e:
            # Best effort: the dashboard still starts, fp_loaded stays False.
            st.warning(f"⚠️ Fingerprint model not loaded: {e}")
    time.sleep(1)  # keep the status readable for a moment
msg_fp.empty()

# Load GCN Model
msg_gcn = st.empty()
with msg_gcn.container():
    with st.spinner("📦 Loading GCN model..."):
        try:
            gcn_model.load_state_dict(torch.load("gcn_model.pt", map_location=torch.device("cpu")))
            gcn_model.eval()
            gcn_loaded = True
            st.success("✅ GCN model loaded.")
        except Exception as e:
            # Best effort: the dashboard still starts, gcn_loaded stays False.
            st.warning(f"⚠️ GCN model not loaded: {e}")
    time.sleep(1)  # keep the status readable for a moment
msg_gcn.empty()

# Load Best Threshold
msg_threshold = st.empty()
with msg_threshold.container():
    with st.spinner("📊 Loading best threshold..."):
        try:
            # .item() accepts a 0-d or single-element array and raises for
            # anything else, which then falls back to the default below.
            best_threshold = float(np.load("gcn_best_threshold.npy").item())
        except Exception as e:
            best_threshold = 0.5
            st.warning(f"⚠️ Using default threshold (0.5) for GCN model. Reason: {e}")
    # Only announce success when both checkpoints were really loaded.
    if fp_loaded and gcn_loaded:
        st.success("✅ All models loaded. Dashboard is ready!")
    else:
        st.warning("⚠️ Dashboard is ready, but at least one model could not be loaded.")
    time.sleep(2)  # keep the status readable for a moment
msg_threshold.empty()
110
+
111
+
112
+
113
+
114
+ # ------------------- Utility Functions -------------------
115
# Shared Morgan fingerprint generator (radius 2, 1024 bits).
fp_gen = GetMorganGenerator(fpSize=1024, radius=2)


def get_molecule_info(mol):
    """Summarise ``mol`` as a dict with formula, weight, atom and bond counts.

    The weight is rounded to two decimals; the counts come straight from
    ``mol.GetNumAtoms()`` and ``mol.GetNumBonds()``.
    """
    formula = Chem.rdMolDescriptors.CalcMolFormula(mol)
    weight = round(Descriptors.MolWt(mol), 2)
    summary = {"Formula": formula, "Weight": weight}
    summary["Atoms"] = mol.GetNumAtoms()
    summary["Bonds"] = mol.GetNumBonds()
    return summary
124
+
125
+
126
+
127
+ def predict_gcn(smiles):
128
+ graph = smiles_to_graph(smiles)
129
+ if graph is None:
130
+ return None, None
131
+ batch = Batch.from_data_list([graph])
132
+ with torch.no_grad():
133
+ out = gcn_model(batch)
134
+ prob = torch.sigmoid(out).item()
135
+ return ("Toxic" if prob > best_threshold else "Non-toxic"), prob
136
+
137
+
138
def atom_feats(atom):
    """Build the 10-entry feature list for one atom.

    The first seven entries are the raw results of the listed getters; the
    last three are the results of ``IsInRing``, ``GetChiralTag`` and
    ``GetHybridization`` converted with ``int``.
    """
    passthrough = (
        "GetAtomicNum",
        "GetDegree",
        "GetFormalCharge",
        "GetNumExplicitHs",
        "GetNumImplicitHs",
        "GetIsAromatic",
        "GetMass",
    )
    as_int = ("IsInRing", "GetChiralTag", "GetHybridization")

    features = [getattr(atom, getter)() for getter in passthrough]
    features.extend(int(getattr(atom, getter)()) for getter in as_int)
    return features
151
+
152
def smiles_to_graph(smiles, label=None):
    """Convert a SMILES string into a graph ``Data`` object.

    Nodes carry the ``atom_feats`` vector of each atom and every bond is
    stored as two directed edges. Returns ``None`` when the SMILES cannot be
    parsed or the molecule has no atoms. When ``label`` is given it is
    attached as ``y``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None or mol.GetNumAtoms() == 0:
        return None

    node_rows = [atom_feats(atom) for atom in mol.GetAtoms()]
    if len(node_rows) == 0:
        return None  # no atoms present

    pairs = []
    for bond in mol.GetBonds():
        src = bond.GetBeginAtomIdx()
        dst = bond.GetEndAtomIdx()
        pairs.append([src, dst])
        pairs.append([dst, src])

    # Bond-free molecules (e.g. a single atom) get a self-loop on atom 0.
    if not pairs:
        pairs = [[0, 0]]

    x = torch.tensor(node_rows, dtype=torch.float)
    graph = Data(
        x=x,
        edge_index=torch.tensor(pairs, dtype=torch.long).t().contiguous(),
        batch=torch.zeros(x.size(0), dtype=torch.long),
    )
    if label is not None:
        graph.y = torch.tensor([label], dtype=torch.float)
    return graph
178
+
179
+
180
+ # def predict_gcn(smiles):
181
+ # graph = smiles_to_graph(smiles)
182
+ # if graph is None or graph.x.size(0) == 0:
183
+ # return None, None
184
+ # batch = Batch.from_data_list([graph])
185
+ # with torch.no_grad():
186
+ # out = gcn_model(batch)
187
+ # raw = out.item()
188
+ # prob = torch.sigmoid(out).item()
189
+ # print(f"Raw logit: {raw:.4f}, Prob: {prob:.4f}")
190
+ # return ("Toxic" if prob > best_threshold else "Non-toxic"), prob
191
+
192
+
193
+
194
+ # ------------------- Load Dataset -------------------
195
+ # df = pd.read_csv("tox21.csv")[['smiles', 'SR-HSE']].dropna()
196
+ # df = df[df['SR-HSE'].isin([0, 1])]
197
+
198
+ # # 🧼 Filter out invalid SMILES
199
+ # df['mol'] = df['smiles'].apply(Chem.MolFromSmiles)
200
+ # df = df[df['mol'].notna()].reset_index(drop=True)
201
+
202
# Read the two columns used by the dashboard and keep rows with a 0/1 label.
df = pd.read_csv("tox21.csv")
df = df[['smiles', 'SR-HSE']].dropna()
df = df.loc[df['SR-HSE'].isin([0, 1])].reset_index(drop=True)


# Filter invalid or unprocessable SMILES
def is_valid_graph(smi):
    """Return True when ``smi`` parses and ``smiles_to_graph`` yields a graph."""
    if Chem.MolFromSmiles(smi) is None:
        return False
    return smiles_to_graph(smi) is not None


df = df.loc[df['smiles'].apply(is_valid_graph)].reset_index(drop=True)
211
+
212
+
213
+
214
+
215
def create_graph_dataset(smiles_list, labels):
    """Convert paired SMILES strings and labels into a list of graphs.

    Args:
        smiles_list: iterable of SMILES strings.
        labels: iterable of labels, consumed in lockstep with ``smiles_list``.

    Returns:
        The graphs produced by ``smiles_to_graph``; entries for which it
        returned ``None`` are left out.
    """
    graphs = (smiles_to_graph(smi, label) for smi, label in zip(smiles_list, labels))
    # Compare against None explicitly: truthiness of a graph object depends on
    # its container protocol, not on whether the conversion succeeded.
    return [graph for graph in graphs if graph is not None]


graph_data = create_graph_dataset(df['smiles'], df['SR-HSE'])
test_loader = DataLoader(graph_data, batch_size=32)
225
+
226
+ # ------------------- Plot Function -------------------
227
def plot_distribution(df, model_type, input_prob=None):
    """Build an overlaid histogram of predicted probabilities per true label.

    Args:
        df: frame with an ``SR-HSE`` column and the probability column of the
            chosen model (``fp_prob`` or ``gcn_prob``).
        model_type: ``'fp'`` selects ``fp_prob``; any other value selects
            ``gcn_prob``.
        input_prob: optional probability of the user's molecule; when given,
            a dashed vertical marker is drawn at that position.

    Returns:
        The plotly figure.
    """
    col = 'fp_prob' if model_type == 'fp' else 'gcn_prob'
    df_plot = df[df[col].notna()].copy()
    df_plot["Label"] = df_plot["SR-HSE"].map({0: "Non-toxic", 1: "Toxic"})
    fig = px.histogram(
        df_plot,
        x=col,
        color="Label",
        nbins=30,
        barmode="overlay",
        color_discrete_map={"Non-toxic": "green", "Toxic": "red"},
        title=f"{model_type.upper()} Model - Test Set Distribution",
    )
    # A probability of exactly 0.0 is a valid input, so test for None rather
    # than truthiness.
    if input_prob is not None:
        fig.add_vline(x=input_prob, line_dash="dash", line_color="yellow", annotation_text="Your Input")
    return fig
237
+
238
+ # ------------------- Prediction Cache -------------------
239
@st.cache_data(show_spinner="Generating predictions...")
def predict_fp(smiles):
    """Classify a SMILES string with the fingerprint model.

    Returns a ``(label, probability)`` pair. The label is "Toxic" when the
    probability exceeds 0.5 and "Non-toxic" otherwise. Unparseable input
    gives ``("Invalid SMILES", 0.0)``; any exception raised along the way
    gives ``("Error: <message>", 0.0)``.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return "Invalid SMILES", 0.0
        bits = np.array(fp_gen.GetFingerprint(mol)).reshape(1, -1)
        features = torch.tensor(bits).float()
        with torch.no_grad():
            prob = torch.sigmoid(fp_model(features)).item()
        verdict = "Toxic" if prob > 0.5 else "Non-toxic"
        return verdict, prob
    except Exception as e:
        return f"Error: {str(e)}", 0.0
254
+
255
def get_predictions(model_type='fp'):
    """Predict a probability for every SMILES in the module-level ``df``.

    Args:
        model_type: ``'fp'`` uses ``predict_fp``; any other value uses
            ``predict_gcn``.

    Returns:
        One entry per row of ``df['smiles']``: the predicted probability, or
        ``None`` when the predictor raised an exception for that row.
    """
    predictor = predict_fp if model_type == 'fp' else predict_gcn
    preds = []
    for smi in df['smiles']:
        try:
            preds.append(predictor(smi)[1])
        except Exception:
            # Catch ordinary failures only; a bare ``except`` would also
            # swallow BaseException-derived control-flow exceptions.
            preds.append(None)
    return preds


# NaN (not None) is the placeholder when a model is unavailable: the column
# stays numeric, so later comparisons such as ``df['fp_prob'] > 0.5`` are
# simply False instead of raising on an object column of None values.
df['fp_prob'] = get_predictions('fp') if fp_loaded else np.nan
df['gcn_prob'] = get_predictions('gcn') if gcn_loaded else np.nan
267
+
268
+ # ------------------- Evaluation Function -------------------
269
def evaluate_gcn_test_set(model, test_loader, threshold=0.5):
    """Evaluate ``model`` on ``test_loader`` and render the result in the app.

    Shows accuracy and ROC-AUC in a success message and plots the predicted
    probabilities as an overlaid histogram split by true label.

    Args:
        model: module called once per batch; its output is passed through
            ``torch.sigmoid``.
        test_loader: iterable of batches exposing ``to`` and ``y``.
        threshold: probability above which a sample counts as toxic for the
            accuracy figure. Defaults to 0.5, the value used before this
            parameter existed.
    """
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to("cpu")  # Ensure on CPU
            probs = torch.sigmoid(model(batch))
            # Flatten before collecting: extending with the un-flattened
            # output would store one single-element array per sample instead
            # of plain floats.
            all_preds.extend(probs.view(-1).cpu().tolist())
            all_labels.extend(batch.y.view(-1).cpu().tolist())

    if not all_labels:
        # Nothing to score; the metric functions would raise on empty input.
        st.warning("⚠️ No samples available for evaluation.")
        return

    scores = np.asarray(all_preds)
    acc = accuracy_score(all_labels, (scores > threshold).astype(int))
    roc = roc_auc_score(all_labels, scores)

    df_eval = pd.DataFrame({
        "Predicted Probability": scores,
        "Label": ["Non-toxic" if i == 0 else "Toxic" for i in all_labels],
    })

    fig = px.histogram(
        df_eval,
        x="Predicted Probability",
        color="Label",
        nbins=30,
        barmode="overlay",
        color_discrete_map={"Non-toxic": "green", "Toxic": "red"},
        title="GCN Test Set - Probability Distribution",
    )
    fig.update_layout(bargap=0.1)

    st.success(f"✅ Accuracy: `{acc:.4f}`, ROC-AUC: `{roc:.4f}`")
    st.plotly_chart(fig, use_container_width=True)
295
+
296
+ # ------------------- Tabs -------------------
297
tab1, tab2 = st.tabs(["🔬 Fingerprint Model", "🧬 GCN Model"])

# Fingerprint tab: single-molecule prediction form, example inputs, and the
# highest-scoring molecules from the loaded dataset.
with tab1:
    st.subheader("Fingerprint-based Prediction")
    with st.form("fp_form"):
        smiles_fp = st.text_input("Enter SMILES", "CCO")
        show_debug_fp = st.checkbox("🐞 Show Debug Info (raw score/logit)", key="fp_debug")
        predict_btn = st.form_submit_button("🔍 Predict")

    if predict_btn and not fp_loaded:
        # Without the checkpoint the network still holds its initial weights,
        # so any score it produced would be meaningless.
        st.error("❌ Fingerprint model weights are not loaded, so no prediction can be made.")
    elif predict_btn:
        with st.spinner("Predicting..."):
            mol = Chem.MolFromSmiles(smiles_fp)
            # An empty input parses to a molecule without atoms; treat it as
            # invalid instead of scoring an all-zero fingerprint.
            if mol is not None and mol.GetNumAtoms() > 0:
                fp = fp_gen.GetFingerprint(mol)
                arr = np.array(fp).reshape(1, -1)
                tensor = torch.tensor(arr).float()
                with torch.no_grad():
                    output = fp_model(tensor)
                    prob = torch.sigmoid(output).item()
                    raw_score = output.item()
                label = "Toxic" if prob > 0.5 else "Non-toxic"
                color = "red" if label == "Toxic" else "green"

                # label and color are fixed strings chosen above, so the raw
                # HTML below contains no user-supplied text.
                st.markdown(f"<h4>🧾 Prediction: <span style='color:{color}'>{label}</span> — <code>{prob:.3f}</code></h4>", unsafe_allow_html=True)

                if show_debug_fp:
                    st.code(f"📉 Raw Logit: {raw_score:.4f}", language='text')
                    st.markdown("#### Fingerprint Vector (First 20 bits)")
                    st.code(str(arr[0][:20]) + " ...", language="text")

                st.image(Draw.MolToImage(mol), caption="Molecular Structure", width=250)

                info = get_molecule_info(mol)
                st.markdown("### Molecule Info:")
                for k, v in info.items():
                    st.markdown(f"**{k}:** {v}")

                st.plotly_chart(plot_distribution(df, 'fp', prob), use_container_width=True)
            else:
                st.error("❌ Invalid SMILES input. Please check your string.")

    with st.expander("📌 Example SMILES to Try"):
        st.markdown("""
        - `CCO` (Ethanol)
        - `CC(=O)O` (Acetic Acid)
        - `c1ccccc1` (Benzene)
        - `CCN(CC)CC` (Triethylamine)
        - `C1=CC=CN=C1` (Pyridine)
        """)

    with st.expander("🧪 Top 5 Toxic Predictions from Test Set (Fingerprint Model)"):
        # The fp_prob column is always created, so its presence alone does not
        # mean predictions exist; check the load flag as well.
        if fp_loaded and 'fp_prob' in df:
            # Coerce to numeric so rows without a prediction become NaN and
            # compare as False instead of raising.
            fp_scores = pd.to_numeric(df['fp_prob'], errors='coerce')
            top_toxic_fp = df[fp_scores > 0.5].sort_values('fp_prob', ascending=False)

            def is_valid_fp(smi):
                """Return True when ``smi`` can be parsed into a molecule."""
                return Chem.MolFromSmiles(smi) is not None

            top_toxic_fp = top_toxic_fp[top_toxic_fp['smiles'].apply(is_valid_fp)].head(5)

            if not top_toxic_fp.empty:
                st.table(top_toxic_fp[['smiles', 'fp_prob']].rename(columns={'fp_prob': 'Predicted Probability'}))
            else:
                st.info("No valid top fingerprint predictions available.")
        else:
            st.info("Fingerprint model predictions not available.")
362
+
363
+
364
# GCN tab: single-molecule prediction form, example inputs, CSV export of the
# dataset predictions, and the highest-scoring molecules.
with tab2:
    st.subheader("Graph Neural Network Prediction")

    SUPPORTED_ATOMS = {1, 6, 7, 8, 9, 16, 17, 35, 53}  # H, C, N, O, F, S, Cl, Br, I

    def is_supported(mol):
        """Return True when every atom of ``mol`` is in SUPPORTED_ATOMS."""
        return all(atom.GetAtomicNum() in SUPPORTED_ATOMS for atom in mol.GetAtoms())

    def is_valid_gcn(smi):
        """Return True when ``smi`` parses, uses supported atoms and converts to a graph."""
        mol = Chem.MolFromSmiles(smi)
        return mol is not None and is_supported(mol) and smiles_to_graph(smi) is not None

    def get_gcn_molecule_info(mol):
        """Return the descriptor summary shown under a GCN prediction.

        Uses the already imported ``Descriptors`` and ``rdMolDescriptors``
        modules, and has its own name so the module-level
        ``get_molecule_info`` helper is not overwritten.
        """
        return {
            "Molecular Weight": round(Descriptors.MolWt(mol), 2),
            "LogP": round(Descriptors.MolLogP(mol), 2),
            "Num H-Bond Donors": Descriptors.NumHDonors(mol),
            "Num H-Bond Acceptors": Descriptors.NumHAcceptors(mol),
            "TPSA": round(rdMolDescriptors.CalcTPSA(mol), 2),
            "Num Rotatable Bonds": Descriptors.NumRotatableBonds(mol),
        }

    with st.form("gcn_form"):
        smiles_gcn = st.text_input("Enter SMILES", "c1ccccc1", key="gcn_smiles")
        show_debug = st.checkbox("🐞 Show Debug Info (raw score/logit)")
        gcn_btn = st.form_submit_button("🔍 Predict")

    if gcn_btn and not gcn_loaded:
        # Without the checkpoint the network still holds its initial weights,
        # so any score it produced would be meaningless.
        st.error("❌ GCN model weights are not loaded, so no prediction can be made.")
    elif gcn_btn:
        with st.spinner("Predicting..."):
            mol = Chem.MolFromSmiles(smiles_gcn)

            if mol is None:
                st.error("❌ Invalid SMILES: could not parse molecule.")
            elif not is_supported(mol):
                st.error("⚠️ This molecule contains unsupported atoms (e.g. Sn, P, etc.). GCN model only supports common organic elements.")
            else:
                graph = smiles_to_graph(smiles_gcn)
                if graph is None:
                    st.error("❌ SMILES is valid but could not be converted to graph. Possibly malformed structure.")
                else:
                    batch = Batch.from_data_list([graph])
                    with torch.no_grad():
                        out = gcn_model(batch)
                        prob = torch.sigmoid(out).item()
                        raw_score = out.item()
                    label = "Toxic" if prob > best_threshold else "Non-toxic"
                    color = "red" if label == "Toxic" else "green"

                    # label and color are fixed strings chosen above, so the
                    # raw HTML below contains no user-supplied text.
                    st.markdown(f"<h4>🧾 GCN Prediction: <span style='color:{color}'>{label}</span> — <code>{prob:.3f}</code></h4>", unsafe_allow_html=True)

                    if show_debug:
                        st.code(f"📉 Raw Logit: {raw_score:.4f}", language='text')

                    st.image(Draw.MolToImage(mol), caption="Molecular Structure", width=250)

                    info = get_gcn_molecule_info(mol)
                    st.markdown("### Molecule Info:")
                    for k, v in info.items():
                        st.markdown(f"**{k}:** {v}")

                    st.plotly_chart(plot_distribution(df, 'gcn', prob), use_container_width=True)

    with st.expander("📌 Example SMILES to Try"):
        st.markdown("""
        - `c1ccccc1` (Benzene)
        - `C1=CC=CC=C1O` (Phenol)
        - `CC(=O)OC1=CC=CC=C1C(=O)O` (Aspirin)
        - `NCC(O)=O` (Glycine)
        - `C1CCC(CC1)NC(=O)C2=CC=CC=C2` (Cyclohexylbenzamide)
        """)

    # The gcn_prob column is always created, so its presence alone does not
    # mean predictions exist; check the load flag as well.
    gcn_ready = gcn_loaded and 'gcn_prob' in df
    # Run the validity scan once and share it between both expanders.
    df_valid = df[df['smiles'].apply(is_valid_gcn)].copy() if gcn_ready else None

    with st.expander("📥 Download GCN Model Predictions"):
        if gcn_ready:
            csv_gcn = df_valid[['smiles', 'gcn_prob', 'SR-HSE']].dropna().to_csv(index=False)
            st.download_button("Download CSV", csv_gcn, "gcn_predictions.csv", "text/csv")
        else:
            st.info("Predictions not available yet.")

    with st.expander("🧪 Top 5 Toxic Predictions from Test Set"):
        if gcn_ready:
            # Coerce to numeric so rows without a prediction become NaN and
            # compare as False instead of raising.
            gcn_scores = pd.to_numeric(df_valid['gcn_prob'], errors='coerce')
            top_toxic = df_valid[gcn_scores > best_threshold]
            top_toxic = top_toxic.sort_values('gcn_prob', ascending=False).head(5)

            if not top_toxic.empty:
                st.table(top_toxic[['smiles', 'gcn_prob']].rename(columns={'gcn_prob': 'Predicted Probability'}))
            else:
                st.info("No valid top predictions available.")
        else:
            st.info("GCN model predictions not available.")
459