ShivamKum4r committed on
Commit
227ad73
·
verified ·
1 Parent(s): db6c0b7

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +100 -91
src/streamlit_app.py CHANGED
@@ -16,9 +16,9 @@ import plotly.express as px
16
  from rdkit.Chem import Draw
17
  from torch_geometric.data import Batch
18
  from rdkit.Chem import Descriptors
19
-
20
-
21
  import time
 
 
22
 
23
 
24
  # ------------------- Models -------------------
@@ -114,6 +114,16 @@ with msg_threshold.container():
114
  # ------------------- Utility Functions -------------------
115
  fp_gen = GetMorganGenerator(radius=2, fpSize=1024)
116
 
 
 
 
 
 
 
 
 
 
 
117
  def get_molecule_info(mol):
118
  return {
119
  "Formula": Chem.rdMolDescriptors.CalcMolFormula(mol),
@@ -122,8 +132,6 @@ def get_molecule_info(mol):
122
  "Bonds": mol.GetNumBonds()
123
  }
124
 
125
-
126
-
127
  def predict_gcn(smiles):
128
  graph = smiles_to_graph(smiles)
129
  if graph is None:
@@ -177,6 +185,8 @@ def smiles_to_graph(smiles, label=None):
177
  return data
178
 
179
 
 
 
180
  # def predict_gcn(smiles):
181
  # graph = smiles_to_graph(smiles)
182
  # if graph is None or graph.x.size(0) == 0:
@@ -199,26 +209,19 @@ def smiles_to_graph(smiles, label=None):
199
  # df['mol'] = df['smiles'].apply(Chem.MolFromSmiles)
200
  # df = df[df['mol'].notna()].reset_index(drop=True)
201
 
202
- df = pd.read_csv("tox21.csv")[['smiles', 'SR-HSE']].dropna()
203
  df = df[df['SR-HSE'].isin([0, 1])].reset_index(drop=True)
204
 
205
  # ✅ Filter invalid or unprocessable SMILES
206
  def is_valid_graph(smi):
207
- mol = Chem.MolFromSmiles(smi)
208
- return mol is not None and smiles_to_graph(smi) is not None
209
 
210
  df = df[df['smiles'].apply(is_valid_graph)].reset_index(drop=True)
211
 
212
 
213
 
214
-
215
  def create_graph_dataset(smiles_list, labels):
216
- data_list = []
217
- for smi, label in zip(smiles_list, labels):
218
- data = smiles_to_graph(smi, label)
219
- if data:
220
- data_list.append(data)
221
- return data_list
222
 
223
  graph_data = create_graph_dataset(df['smiles'], df['SR-HSE'])
224
  test_loader = DataLoader(graph_data, batch_size=32)
@@ -237,16 +240,15 @@ def plot_distribution(df, model_type, input_prob=None):
237
 
238
  # ------------------- Prediction Cache -------------------
239
  @st.cache_data(show_spinner="Generating predictions...")
240
-
241
  def predict_fp(smiles):
242
  try:
243
- mol = Chem.MolFromSmiles(smiles)
244
- if mol is None:
245
  return "Invalid SMILES", 0.0
 
246
  fp = fp_gen.GetFingerprint(mol)
247
- fp_array = np.array(fp).reshape(1, -1)
248
  with torch.no_grad():
249
- logits = fp_model(torch.tensor(fp_array).float())
250
  prob = torch.sigmoid(logits).item()
251
  return ("Toxic" if prob > 0.5 else "Non-toxic"), prob
252
  except Exception as e:
@@ -298,6 +300,7 @@ tab1, tab2 = st.tabs(["🔬 Fingerprint Model", "🧬 GCN Model"])
298
 
299
  with tab1:
300
  st.subheader("Fingerprint-based Prediction")
 
301
  with st.form("fp_form"):
302
  smiles_fp = st.text_input("Enter SMILES", "CCO")
303
  show_debug_fp = st.checkbox("🐞 Show Debug Info (raw score/logit)", key="fp_debug")
@@ -305,35 +308,42 @@ with tab1:
305
 
306
  if predict_btn:
307
  with st.spinner("Predicting..."):
308
- mol = Chem.MolFromSmiles(smiles_fp)
309
- if mol:
310
- fp = fp_gen.GetFingerprint(mol)
311
- arr = np.array(fp).reshape(1, -1)
312
- tensor = torch.tensor(arr).float()
313
- with torch.no_grad():
314
- output = fp_model(tensor)
315
- prob = torch.sigmoid(output).item()
316
- raw_score = output.item()
317
- label = "Toxic" if prob > 0.5 else "Non-toxic"
318
- color = "red" if label == "Toxic" else "green"
319
-
320
- st.markdown(f"<h4>🧾 Prediction: <span style='color:{color}'>{label}</span> β€” <code>{prob:.3f}</code></h4>", unsafe_allow_html=True)
321
-
322
- if show_debug_fp:
323
- st.code(f"πŸ“‰ Raw Logit: {raw_score:.4f}", language='text')
324
- st.markdown("#### Fingerprint Vector (First 20 bits)")
325
- st.code(str(arr[0][:20]) + " ...", language="text")
326
-
327
- st.image(Draw.MolToImage(mol), caption="Molecular Structure", width=250)
328
-
329
- info = get_molecule_info(mol)
330
- st.markdown("### Molecule Info:")
331
- for k, v in info.items():
332
- st.markdown(f"**{k}:** {v}")
333
-
334
- st.plotly_chart(plot_distribution(df, 'fp', prob), use_container_width=True)
335
- else:
336
  st.error("❌ Invalid SMILES input. Please check your string.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
 
338
  with st.expander("πŸ“Œ Example SMILES to Try"):
339
  st.markdown("""
@@ -360,12 +370,12 @@ with tab1:
360
  else:
361
  st.info("Fingerprint model predictions not available.")
362
 
363
-
364
  with tab2:
365
  st.subheader("Graph Neural Network Prediction")
366
 
367
  SUPPORTED_ATOMS = {1, 6, 7, 8, 9, 16, 17, 35, 53} # H, C, N, O, F, S, Cl, Br, I
368
 
 
369
  def is_supported(mol):
370
  return all(atom.GetAtomicNum() in SUPPORTED_ATOMS for atom in mol.GetAtoms())
371
 
@@ -376,48 +386,48 @@ with tab2:
376
 
377
  if gcn_btn:
378
  with st.spinner("Predicting..."):
379
- mol = Chem.MolFromSmiles(smiles_gcn)
380
-
381
- if mol is None:
382
  st.error("❌ Invalid SMILES: could not parse molecule.")
383
- elif not is_supported(mol):
384
- st.error("⚠️ This molecule contains unsupported atoms (e.g. Sn, P, etc.). GCN model only supports common organic elements.")
385
  else:
386
- graph = smiles_to_graph(smiles_gcn)
387
- if graph is None:
388
- st.error("❌ SMILES is valid but could not be converted to graph. Possibly malformed structure.")
389
  else:
390
- batch = Batch.from_data_list([graph])
391
- with torch.no_grad():
392
- out = gcn_model(batch)
393
- prob = torch.sigmoid(out).item()
394
- raw_score = out.item()
395
- label = "Toxic" if prob > best_threshold else "Non-toxic"
396
- color = "red" if label == "Toxic" else "green"
397
-
398
- st.markdown(f"<h4>🧾 GCN Prediction: <span style='color:{color}'>{label}</span> β€” <code>{prob:.3f}</code></h4>", unsafe_allow_html=True)
399
-
400
- if show_debug:
401
- st.code(f"πŸ“‰ Raw Logit: {raw_score:.4f}", language='text')
402
-
403
- st.image(Draw.MolToImage(mol), caption="Molecular Structure", width=250)
404
-
405
- def get_molecule_info(mol):
406
- return {
407
- "Molecular Weight": round(Chem.Descriptors.MolWt(mol), 2),
408
- "LogP": round(Chem.Crippen.MolLogP(mol), 2),
409
- "Num H-Bond Donors": Chem.Lipinski.NumHDonors(mol),
410
- "Num H-Bond Acceptors": Chem.Lipinski.NumHAcceptors(mol),
411
- "TPSA": round(Chem.rdMolDescriptors.CalcTPSA(mol), 2),
412
- "Num Rotatable Bonds": Chem.Lipinski.NumRotatableBonds(mol)
413
- }
414
-
415
- info = get_molecule_info(mol)
416
- st.markdown("### Molecule Info:")
417
- for k, v in info.items():
418
- st.markdown(f"**{k}:** {v}")
419
-
420
- st.plotly_chart(plot_distribution(df, 'gcn', prob), use_container_width=True)
 
 
 
 
421
 
422
  with st.expander("πŸ“Œ Example SMILES to Try"):
423
  st.markdown("""
@@ -441,19 +451,18 @@ with tab2:
441
  st.info("Predictions not available yet.")
442
 
443
  with st.expander("πŸ§ͺ Top 5 Toxic Predictions from Test Set"):
444
- if 'gcn_prob' in df:
445
  def is_valid_gcn(smi):
446
  mol = Chem.MolFromSmiles(smi)
447
  return mol is not None and is_supported(mol) and smiles_to_graph(smi) is not None
448
 
449
  top_toxic = df[df['gcn_prob'] > best_threshold].copy()
450
  top_toxic = top_toxic[top_toxic['smiles'].apply(is_valid_gcn)]
451
- top_toxic = top_toxic.sort_values('gcn_prob', ascending=False).head(5)
452
 
453
- if not top_toxic.empty:
 
454
  st.table(top_toxic[['smiles', 'gcn_prob']].rename(columns={'gcn_prob': 'Predicted Probability'}))
455
  else:
456
  st.info("No valid top predictions available.")
457
  else:
458
  st.info("GCN model predictions not available.")
459
-
 
16
  from rdkit.Chem import Draw
17
  from torch_geometric.data import Batch
18
  from rdkit.Chem import Descriptors
 
 
19
  import time
20
+ from rdkit import RDLogger
21
+ RDLogger.DisableLog('rdApp.*')
22
 
23
 
24
  # ------------------- Models -------------------
 
114
  # ------------------- Utility Functions -------------------
115
  fp_gen = GetMorganGenerator(radius=2, fpSize=1024)
116
 
117
+ def is_valid_smiles(smiles):
118
+ try:
119
+ mol = Chem.MolFromSmiles(smiles)
120
+ if mol is None or mol.GetNumAtoms() == 0:
121
+ return False
122
+ Chem.SanitizeMol(mol) # Force check for chemical correctness
123
+ return True
124
+ except:
125
+ return False
126
+
127
  def get_molecule_info(mol):
128
  return {
129
  "Formula": Chem.rdMolDescriptors.CalcMolFormula(mol),
 
132
  "Bonds": mol.GetNumBonds()
133
  }
134
 
 
 
135
  def predict_gcn(smiles):
136
  graph = smiles_to_graph(smiles)
137
  if graph is None:
 
185
  return data
186
 
187
 
188
+
189
+
190
  # def predict_gcn(smiles):
191
  # graph = smiles_to_graph(smiles)
192
  # if graph is None or graph.x.size(0) == 0:
 
209
  # df['mol'] = df['smiles'].apply(Chem.MolFromSmiles)
210
  # df = df[df['mol'].notna()].reset_index(drop=True)
211
 
212
+ df = pd.read_csv("tox21.csv")[['smiles', 'SR-HSE']].dropna() #src changes
213
  df = df[df['SR-HSE'].isin([0, 1])].reset_index(drop=True)
214
 
215
  # ✅ Filter invalid or unprocessable SMILES
216
  def is_valid_graph(smi):
217
+ return is_valid_smiles(smi) and smiles_to_graph(smi) is not None
 
218
 
219
  df = df[df['smiles'].apply(is_valid_graph)].reset_index(drop=True)
220
 
221
 
222
 
 
223
  def create_graph_dataset(smiles_list, labels):
224
+ return [smiles_to_graph(smi, label) for smi, label in zip(smiles_list, labels) if smiles_to_graph(smi, label)]
 
 
 
 
 
225
 
226
  graph_data = create_graph_dataset(df['smiles'], df['SR-HSE'])
227
  test_loader = DataLoader(graph_data, batch_size=32)
 
240
 
241
  # ------------------- Prediction Cache -------------------
242
  @st.cache_data(show_spinner="Generating predictions...")
 
243
  def predict_fp(smiles):
244
  try:
245
+ if not is_valid_smiles(smiles):
 
246
  return "Invalid SMILES", 0.0
247
+ mol = Chem.MolFromSmiles(smiles)
248
  fp = fp_gen.GetFingerprint(mol)
249
+ arr = np.array(fp).reshape(1, -1)
250
  with torch.no_grad():
251
+ logits = fp_model(torch.tensor(arr).float())
252
  prob = torch.sigmoid(logits).item()
253
  return ("Toxic" if prob > 0.5 else "Non-toxic"), prob
254
  except Exception as e:
 
300
 
301
  with tab1:
302
  st.subheader("Fingerprint-based Prediction")
303
+
304
  with st.form("fp_form"):
305
  smiles_fp = st.text_input("Enter SMILES", "CCO")
306
  show_debug_fp = st.checkbox("🐞 Show Debug Info (raw score/logit)", key="fp_debug")
 
308
 
309
  if predict_btn:
310
  with st.spinner("Predicting..."):
311
+ if not is_valid_smiles(smiles_fp):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
  st.error("❌ Invalid SMILES input. Please check your string.")
313
+ else:
314
+ try:
315
+ mol = Chem.MolFromSmiles(smiles_fp)
316
+ fp = fp_gen.GetFingerprint(mol)
317
+ arr = np.array(fp).reshape(1, -1)
318
+ tensor = torch.tensor(arr).float().to("cpu")
319
+ fp_model.to("cpu") # Ensure model is on CPU
320
+
321
+ with torch.no_grad():
322
+ output = fp_model(tensor)
323
+ prob = torch.sigmoid(output).item()
324
+ raw_score = output.item()
325
+ label = "Toxic" if prob > 0.5 else "Non-toxic"
326
+ color = "red" if label == "Toxic" else "green"
327
+
328
+ st.markdown(f"<h4>🧾 Prediction: <span style='color:{color}'>{label}</span> β€” <code>{prob:.3f}</code></h4>", unsafe_allow_html=True)
329
+
330
+ if show_debug_fp:
331
+ st.code(f"πŸ“‰ Raw Logit: {raw_score:.4f}", language='text')
332
+ st.markdown("#### Fingerprint Vector (First 20 bits)")
333
+ st.code(str(arr[0][:20]) + " ...", language="text")
334
+
335
+ st.image(Draw.MolToImage(mol), caption="Molecular Structure", width=250)
336
+
337
+ info = get_molecule_info(mol)
338
+ st.markdown("### Molecule Info:")
339
+ for k, v in info.items():
340
+ st.markdown(f"**{k}:** {v}")
341
+
342
+ st.plotly_chart(plot_distribution(df, 'fp', prob), use_container_width=True)
343
+
344
+ except Exception as e:
345
+ st.error(f"Prediction error: {str(e)}")
346
+
347
 
348
  with st.expander("πŸ“Œ Example SMILES to Try"):
349
  st.markdown("""
 
370
  else:
371
  st.info("Fingerprint model predictions not available.")
372
 
 
373
  with tab2:
374
  st.subheader("Graph Neural Network Prediction")
375
 
376
  SUPPORTED_ATOMS = {1, 6, 7, 8, 9, 16, 17, 35, 53} # H, C, N, O, F, S, Cl, Br, I
377
 
378
+
379
  def is_supported(mol):
380
  return all(atom.GetAtomicNum() in SUPPORTED_ATOMS for atom in mol.GetAtoms())
381
 
 
386
 
387
  if gcn_btn:
388
  with st.spinner("Predicting..."):
389
+ if not is_valid_smiles(smiles_gcn):
 
 
390
  st.error("❌ Invalid SMILES: could not parse molecule.")
 
 
391
  else:
392
+ mol = Chem.MolFromSmiles(smiles_gcn)
393
+ if not is_supported(mol):
394
+ st.error("⚠️ This molecule contains unsupported atoms (e.g. Sn, P, etc.). GCN model only supports common organic elements.")
395
  else:
396
+ graph = smiles_to_graph(smiles_gcn)
397
+ if graph is None:
398
+ st.error("❌ SMILES is valid but could not be converted to graph. Possibly malformed structure.")
399
+ else:
400
+ batch = Batch.from_data_list([graph])
401
+ with torch.no_grad():
402
+ out = gcn_model(batch)
403
+ prob = torch.sigmoid(out).item()
404
+ raw_score = out.item()
405
+ label = "Toxic" if prob > best_threshold else "Non-toxic"
406
+ color = "red" if label == "Toxic" else "green"
407
+
408
+ st.markdown(f"<h4>🧾 GCN Prediction: <span style='color:{color}'>{label}</span> β€” <code>{prob:.3f}</code></h4>", unsafe_allow_html=True)
409
+
410
+ if show_debug:
411
+ st.code(f"πŸ“‰ Raw Logit: {raw_score:.4f}", language='text')
412
+
413
+ st.image(Draw.MolToImage(mol), caption="Molecular Structure", width=250)
414
+
415
+ def get_molecule_info(mol):
416
+ return {
417
+ "Molecular Weight": round(Chem.Descriptors.MolWt(mol), 2),
418
+ "LogP": round(Chem.Crippen.MolLogP(mol), 2),
419
+ "Num H-Bond Donors": Chem.Lipinski.NumHDonors(mol),
420
+ "Num H-Bond Acceptors": Chem.Lipinski.NumHAcceptors(mol),
421
+ "TPSA": round(Chem.rdMolDescriptors.CalcTPSA(mol), 2),
422
+ "Num Rotatable Bonds": Chem.Lipinski.NumRotatableBonds(mol)
423
+ }
424
+
425
+ info = get_molecule_info(mol)
426
+ st.markdown("### Molecule Info:")
427
+ for k, v in info.items():
428
+ st.markdown(f"**{k}:** {v}")
429
+
430
+ st.plotly_chart(plot_distribution(df, 'gcn', prob), use_container_width=True)
431
 
432
  with st.expander("πŸ“Œ Example SMILES to Try"):
433
  st.markdown("""
 
451
  st.info("Predictions not available yet.")
452
 
453
  with st.expander("πŸ§ͺ Top 5 Toxic Predictions from Test Set"):
454
+ if 'gcn_prob' in df.columns:
455
  def is_valid_gcn(smi):
456
  mol = Chem.MolFromSmiles(smi)
457
  return mol is not None and is_supported(mol) and smiles_to_graph(smi) is not None
458
 
459
  top_toxic = df[df['gcn_prob'] > best_threshold].copy()
460
  top_toxic = top_toxic[top_toxic['smiles'].apply(is_valid_gcn)]
 
461
 
462
+ if not top_toxic.empty and 'gcn_prob' in top_toxic.columns:
463
+ top_toxic = top_toxic.sort_values('gcn_prob', ascending=False).head(5)
464
  st.table(top_toxic[['smiles', 'gcn_prob']].rename(columns={'gcn_prob': 'Predicted Probability'}))
465
  else:
466
  st.info("No valid top predictions available.")
467
  else:
468
  st.info("GCN model predictions not available.")