Update src/streamlit_app.py
Browse files- src/streamlit_app.py +100 -91
src/streamlit_app.py
CHANGED
@@ -16,9 +16,9 @@ import plotly.express as px
|
|
16 |
from rdkit.Chem import Draw
|
17 |
from torch_geometric.data import Batch
|
18 |
from rdkit.Chem import Descriptors
|
19 |
-
|
20 |
-
|
21 |
import time
|
|
|
|
|
22 |
|
23 |
|
24 |
# ------------------- Models -------------------
|
@@ -114,6 +114,16 @@ with msg_threshold.container():
|
|
114 |
# ------------------- Utility Functions -------------------
|
115 |
fp_gen = GetMorganGenerator(radius=2, fpSize=1024)
|
116 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
def get_molecule_info(mol):
|
118 |
return {
|
119 |
"Formula": Chem.rdMolDescriptors.CalcMolFormula(mol),
|
@@ -122,8 +132,6 @@ def get_molecule_info(mol):
|
|
122 |
"Bonds": mol.GetNumBonds()
|
123 |
}
|
124 |
|
125 |
-
|
126 |
-
|
127 |
def predict_gcn(smiles):
|
128 |
graph = smiles_to_graph(smiles)
|
129 |
if graph is None:
|
@@ -177,6 +185,8 @@ def smiles_to_graph(smiles, label=None):
|
|
177 |
return data
|
178 |
|
179 |
|
|
|
|
|
180 |
# def predict_gcn(smiles):
|
181 |
# graph = smiles_to_graph(smiles)
|
182 |
# if graph is None or graph.x.size(0) == 0:
|
@@ -199,26 +209,19 @@ def smiles_to_graph(smiles, label=None):
|
|
199 |
# df['mol'] = df['smiles'].apply(Chem.MolFromSmiles)
|
200 |
# df = df[df['mol'].notna()].reset_index(drop=True)
|
201 |
|
202 |
-
df = pd.read_csv("tox21.csv")[['smiles', 'SR-HSE']].dropna()
|
203 |
df = df[df['SR-HSE'].isin([0, 1])].reset_index(drop=True)
|
204 |
|
205 |
# ✅ Filter invalid or unprocessable SMILES
|
206 |
def is_valid_graph(smi):
|
207 |
-
|
208 |
-
return mol is not None and smiles_to_graph(smi) is not None
|
209 |
|
210 |
df = df[df['smiles'].apply(is_valid_graph)].reset_index(drop=True)
|
211 |
|
212 |
|
213 |
|
214 |
-
|
215 |
def create_graph_dataset(smiles_list, labels):
|
216 |
-
|
217 |
-
for smi, label in zip(smiles_list, labels):
|
218 |
-
data = smiles_to_graph(smi, label)
|
219 |
-
if data:
|
220 |
-
data_list.append(data)
|
221 |
-
return data_list
|
222 |
|
223 |
graph_data = create_graph_dataset(df['smiles'], df['SR-HSE'])
|
224 |
test_loader = DataLoader(graph_data, batch_size=32)
|
@@ -237,16 +240,15 @@ def plot_distribution(df, model_type, input_prob=None):
|
|
237 |
|
238 |
# ------------------- Prediction Cache -------------------
|
239 |
@st.cache_data(show_spinner="Generating predictions...")
|
240 |
-
|
241 |
def predict_fp(smiles):
|
242 |
try:
|
243 |
-
|
244 |
-
if mol is None:
|
245 |
return "Invalid SMILES", 0.0
|
|
|
246 |
fp = fp_gen.GetFingerprint(mol)
|
247 |
-
|
248 |
with torch.no_grad():
|
249 |
-
logits = fp_model(torch.tensor(
|
250 |
prob = torch.sigmoid(logits).item()
|
251 |
return ("Toxic" if prob > 0.5 else "Non-toxic"), prob
|
252 |
except Exception as e:
|
@@ -298,6 +300,7 @@ tab1, tab2 = st.tabs(["π¬ Fingerprint Model", "𧬠GCN Model"])
|
|
298 |
|
299 |
with tab1:
|
300 |
st.subheader("Fingerprint-based Prediction")
|
|
|
301 |
with st.form("fp_form"):
|
302 |
smiles_fp = st.text_input("Enter SMILES", "CCO")
|
303 |
show_debug_fp = st.checkbox("π Show Debug Info (raw score/logit)", key="fp_debug")
|
@@ -305,35 +308,42 @@ with tab1:
|
|
305 |
|
306 |
if predict_btn:
|
307 |
with st.spinner("Predicting..."):
|
308 |
-
|
309 |
-
if mol:
|
310 |
-
fp = fp_gen.GetFingerprint(mol)
|
311 |
-
arr = np.array(fp).reshape(1, -1)
|
312 |
-
tensor = torch.tensor(arr).float()
|
313 |
-
with torch.no_grad():
|
314 |
-
output = fp_model(tensor)
|
315 |
-
prob = torch.sigmoid(output).item()
|
316 |
-
raw_score = output.item()
|
317 |
-
label = "Toxic" if prob > 0.5 else "Non-toxic"
|
318 |
-
color = "red" if label == "Toxic" else "green"
|
319 |
-
|
320 |
-
st.markdown(f"<h4>π§Ύ Prediction: <span style='color:{color}'>{label}</span> β <code>{prob:.3f}</code></h4>", unsafe_allow_html=True)
|
321 |
-
|
322 |
-
if show_debug_fp:
|
323 |
-
st.code(f"π Raw Logit: {raw_score:.4f}", language='text')
|
324 |
-
st.markdown("#### Fingerprint Vector (First 20 bits)")
|
325 |
-
st.code(str(arr[0][:20]) + " ...", language="text")
|
326 |
-
|
327 |
-
st.image(Draw.MolToImage(mol), caption="Molecular Structure", width=250)
|
328 |
-
|
329 |
-
info = get_molecule_info(mol)
|
330 |
-
st.markdown("### Molecule Info:")
|
331 |
-
for k, v in info.items():
|
332 |
-
st.markdown(f"**{k}:** {v}")
|
333 |
-
|
334 |
-
st.plotly_chart(plot_distribution(df, 'fp', prob), use_container_width=True)
|
335 |
-
else:
|
336 |
st.error("β Invalid SMILES input. Please check your string.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
337 |
|
338 |
with st.expander("π Example SMILES to Try"):
|
339 |
st.markdown("""
|
@@ -360,12 +370,12 @@ with tab1:
|
|
360 |
else:
|
361 |
st.info("Fingerprint model predictions not available.")
|
362 |
|
363 |
-
|
364 |
with tab2:
|
365 |
st.subheader("Graph Neural Network Prediction")
|
366 |
|
367 |
SUPPORTED_ATOMS = {1, 6, 7, 8, 9, 16, 17, 35, 53} # H, C, N, O, F, S, Cl, Br, I
|
368 |
|
|
|
369 |
def is_supported(mol):
    """Return True when every atom of *mol* has an atomic number in SUPPORTED_ATOMS."""
    for atom in mol.GetAtoms():
        # Stop at the first atom that is not in the supported set.
        if atom.GetAtomicNum() not in SUPPORTED_ATOMS:
            return False
    return True
|
371 |
|
@@ -376,48 +386,48 @@ with tab2:
|
|
376 |
|
377 |
if gcn_btn:
|
378 |
with st.spinner("Predicting..."):
|
379 |
-
|
380 |
-
|
381 |
-
if mol is None:
|
382 |
st.error("β Invalid SMILES: could not parse molecule.")
|
383 |
-
elif not is_supported(mol):
|
384 |
-
st.error("β οΈ This molecule contains unsupported atoms (e.g. Sn, P, etc.). GCN model only supports common organic elements.")
|
385 |
else:
|
386 |
-
|
387 |
-
if
|
388 |
-
st.error("
|
389 |
else:
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
|
412 |
-
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
419 |
-
|
420 |
-
|
|
|
|
|
|
|
|
|
421 |
|
422 |
with st.expander("π Example SMILES to Try"):
|
423 |
st.markdown("""
|
@@ -441,19 +451,18 @@ with tab2:
|
|
441 |
st.info("Predictions not available yet.")
|
442 |
|
443 |
with st.expander("π§ͺ Top 5 Toxic Predictions from Test Set"):
|
444 |
-
if 'gcn_prob' in df:
|
445 |
def is_valid_gcn(smi):
    """Return True when *smi* parses, passes is_supported, and converts to a graph."""
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        return False
    if not is_supported(mol):
        return False
    # Only attempt the graph conversion once the cheaper checks have passed.
    return smiles_to_graph(smi) is not None
|
448 |
|
449 |
top_toxic = df[df['gcn_prob'] > best_threshold].copy()
|
450 |
top_toxic = top_toxic[top_toxic['smiles'].apply(is_valid_gcn)]
|
451 |
-
top_toxic = top_toxic.sort_values('gcn_prob', ascending=False).head(5)
|
452 |
|
453 |
-
if not top_toxic.empty:
|
|
|
454 |
st.table(top_toxic[['smiles', 'gcn_prob']].rename(columns={'gcn_prob': 'Predicted Probability'}))
|
455 |
else:
|
456 |
st.info("No valid top predictions available.")
|
457 |
else:
|
458 |
st.info("GCN model predictions not available.")
|
459 |
-
|
|
|
16 |
from rdkit.Chem import Draw
|
17 |
from torch_geometric.data import Batch
|
18 |
from rdkit.Chem import Descriptors
|
|
|
|
|
19 |
import time
|
20 |
+
from rdkit import RDLogger
|
21 |
+
RDLogger.DisableLog('rdApp.*')
|
22 |
|
23 |
|
24 |
# ------------------- Models -------------------
|
|
|
114 |
# ------------------- Utility Functions -------------------
|
115 |
fp_gen = GetMorganGenerator(radius=2, fpSize=1024)
|
116 |
|
117 |
+
def is_valid_smiles(smiles):
    """Return True when *smiles* parses into a non-empty, sanitizable molecule.

    Args:
        smiles: SMILES string to validate.

    Returns:
        bool: False when RDKit returns no molecule, the molecule has zero
        atoms, or parsing/sanitization raises an exception; True otherwise.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None or mol.GetNumAtoms() == 0:
            return False
        Chem.SanitizeMol(mol)  # Force check for chemical correctness
    except Exception:
        # Deliberately not a bare ``except``: KeyboardInterrupt, SystemExit and
        # other BaseException subclasses must keep propagating instead of
        # being reported as an invalid SMILES string.
        return False
    return True
|
126 |
+
|
127 |
def get_molecule_info(mol):
|
128 |
return {
|
129 |
"Formula": Chem.rdMolDescriptors.CalcMolFormula(mol),
|
|
|
132 |
"Bonds": mol.GetNumBonds()
|
133 |
}
|
134 |
|
|
|
|
|
135 |
def predict_gcn(smiles):
|
136 |
graph = smiles_to_graph(smiles)
|
137 |
if graph is None:
|
|
|
185 |
return data
|
186 |
|
187 |
|
188 |
+
|
189 |
+
|
190 |
# def predict_gcn(smiles):
|
191 |
# graph = smiles_to_graph(smiles)
|
192 |
# if graph is None or graph.x.size(0) == 0:
|
|
|
209 |
# df['mol'] = df['smiles'].apply(Chem.MolFromSmiles)
|
210 |
# df = df[df['mol'].notna()].reset_index(drop=True)
|
211 |
|
212 |
+
df = pd.read_csv("tox21.csv")[['smiles', 'SR-HSE']].dropna() #src changes
|
213 |
df = df[df['SR-HSE'].isin([0, 1])].reset_index(drop=True)
|
214 |
|
215 |
# ✅ Filter invalid or unprocessable SMILES
|
216 |
def is_valid_graph(smi):
    """Return True when *smi* is a valid SMILES that also converts to a graph."""
    if not is_valid_smiles(smi):
        return False
    graph = smiles_to_graph(smi)
    return graph is not None
|
|
|
218 |
|
219 |
df = df[df['smiles'].apply(is_valid_graph)].reset_index(drop=True)
|
220 |
|
221 |
|
222 |
|
|
|
223 |
def create_graph_dataset(smiles_list, labels):
    """Convert paired SMILES strings and labels into a list of graph objects.

    Args:
        smiles_list: Iterable of SMILES strings.
        labels: Iterable of labels, paired positionally with ``smiles_list``.

    Returns:
        list: One ``smiles_to_graph(smi, label)`` result per pair, skipping
        pairs for which the conversion returned ``None``.
    """
    data_list = []
    for smi, label in zip(smiles_list, labels):
        # Convert each molecule exactly once; calling smiles_to_graph both in
        # the filter and in the result doubles the work for every row.
        data = smiles_to_graph(smi, label)
        # Explicit None check, matching the other smiles_to_graph call sites,
        # so the result does not depend on the graph object's truthiness.
        if data is not None:
            data_list.append(data)
    return data_list
|
|
|
|
|
|
|
|
|
|
|
225 |
|
226 |
graph_data = create_graph_dataset(df['smiles'], df['SR-HSE'])
|
227 |
test_loader = DataLoader(graph_data, batch_size=32)
|
|
|
240 |
|
241 |
# ------------------- Prediction Cache -------------------
|
242 |
@st.cache_data(show_spinner="Generating predictions...")
|
|
|
243 |
def predict_fp(smiles):
|
244 |
try:
|
245 |
+
if not is_valid_smiles(smiles):
|
|
|
246 |
return "Invalid SMILES", 0.0
|
247 |
+
mol = Chem.MolFromSmiles(smiles)
|
248 |
fp = fp_gen.GetFingerprint(mol)
|
249 |
+
arr = np.array(fp).reshape(1, -1)
|
250 |
with torch.no_grad():
|
251 |
+
logits = fp_model(torch.tensor(arr).float())
|
252 |
prob = torch.sigmoid(logits).item()
|
253 |
return ("Toxic" if prob > 0.5 else "Non-toxic"), prob
|
254 |
except Exception as e:
|
|
|
300 |
|
301 |
with tab1:
|
302 |
st.subheader("Fingerprint-based Prediction")
|
303 |
+
|
304 |
with st.form("fp_form"):
|
305 |
smiles_fp = st.text_input("Enter SMILES", "CCO")
|
306 |
show_debug_fp = st.checkbox("π Show Debug Info (raw score/logit)", key="fp_debug")
|
|
|
308 |
|
309 |
if predict_btn:
|
310 |
with st.spinner("Predicting..."):
|
311 |
+
if not is_valid_smiles(smiles_fp):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
312 |
st.error("β Invalid SMILES input. Please check your string.")
|
313 |
+
else:
|
314 |
+
try:
|
315 |
+
mol = Chem.MolFromSmiles(smiles_fp)
|
316 |
+
fp = fp_gen.GetFingerprint(mol)
|
317 |
+
arr = np.array(fp).reshape(1, -1)
|
318 |
+
tensor = torch.tensor(arr).float().to("cpu")
|
319 |
+
fp_model.to("cpu") # Ensure model is on CPU
|
320 |
+
|
321 |
+
with torch.no_grad():
|
322 |
+
output = fp_model(tensor)
|
323 |
+
prob = torch.sigmoid(output).item()
|
324 |
+
raw_score = output.item()
|
325 |
+
label = "Toxic" if prob > 0.5 else "Non-toxic"
|
326 |
+
color = "red" if label == "Toxic" else "green"
|
327 |
+
|
328 |
+
st.markdown(f"<h4>π§Ύ Prediction: <span style='color:{color}'>{label}</span> β <code>{prob:.3f}</code></h4>", unsafe_allow_html=True)
|
329 |
+
|
330 |
+
if show_debug_fp:
|
331 |
+
st.code(f"π Raw Logit: {raw_score:.4f}", language='text')
|
332 |
+
st.markdown("#### Fingerprint Vector (First 20 bits)")
|
333 |
+
st.code(str(arr[0][:20]) + " ...", language="text")
|
334 |
+
|
335 |
+
st.image(Draw.MolToImage(mol), caption="Molecular Structure", width=250)
|
336 |
+
|
337 |
+
info = get_molecule_info(mol)
|
338 |
+
st.markdown("### Molecule Info:")
|
339 |
+
for k, v in info.items():
|
340 |
+
st.markdown(f"**{k}:** {v}")
|
341 |
+
|
342 |
+
st.plotly_chart(plot_distribution(df, 'fp', prob), use_container_width=True)
|
343 |
+
|
344 |
+
except Exception as e:
|
345 |
+
st.error(f"Prediction error: {str(e)}")
|
346 |
+
|
347 |
|
348 |
with st.expander("π Example SMILES to Try"):
|
349 |
st.markdown("""
|
|
|
370 |
else:
|
371 |
st.info("Fingerprint model predictions not available.")
|
372 |
|
|
|
373 |
with tab2:
|
374 |
st.subheader("Graph Neural Network Prediction")
|
375 |
|
376 |
SUPPORTED_ATOMS = {1, 6, 7, 8, 9, 16, 17, 35, 53} # H, C, N, O, F, S, Cl, Br, I
|
377 |
|
378 |
+
|
379 |
def is_supported(mol):
    """Return True when every atom of *mol* has an atomic number in SUPPORTED_ATOMS."""
    for atom in mol.GetAtoms():
        # Stop at the first atom that is not in the supported set.
        if atom.GetAtomicNum() not in SUPPORTED_ATOMS:
            return False
    return True
|
381 |
|
|
|
386 |
|
387 |
if gcn_btn:
|
388 |
with st.spinner("Predicting..."):
|
389 |
+
if not is_valid_smiles(smiles_gcn):
|
|
|
|
|
390 |
st.error("β Invalid SMILES: could not parse molecule.")
|
|
|
|
|
391 |
else:
|
392 |
+
mol = Chem.MolFromSmiles(smiles_gcn)
|
393 |
+
if not is_supported(mol):
|
394 |
+
st.error("β οΈ This molecule contains unsupported atoms (e.g. Sn, P, etc.). GCN model only supports common organic elements.")
|
395 |
else:
|
396 |
+
graph = smiles_to_graph(smiles_gcn)
|
397 |
+
if graph is None:
|
398 |
+
st.error("β SMILES is valid but could not be converted to graph. Possibly malformed structure.")
|
399 |
+
else:
|
400 |
+
batch = Batch.from_data_list([graph])
|
401 |
+
with torch.no_grad():
|
402 |
+
out = gcn_model(batch)
|
403 |
+
prob = torch.sigmoid(out).item()
|
404 |
+
raw_score = out.item()
|
405 |
+
label = "Toxic" if prob > best_threshold else "Non-toxic"
|
406 |
+
color = "red" if label == "Toxic" else "green"
|
407 |
+
|
408 |
+
st.markdown(f"<h4>π§Ύ GCN Prediction: <span style='color:{color}'>{label}</span> β <code>{prob:.3f}</code></h4>", unsafe_allow_html=True)
|
409 |
+
|
410 |
+
if show_debug:
|
411 |
+
st.code(f"π Raw Logit: {raw_score:.4f}", language='text')
|
412 |
+
|
413 |
+
st.image(Draw.MolToImage(mol), caption="Molecular Structure", width=250)
|
414 |
+
|
415 |
+
def get_molecule_info(mol):
    """Compute a dictionary of RDKit descriptors for *mol*.

    Returns:
        dict: Display name -> value. Molecular weight, LogP and TPSA are
        rounded to two decimals; the remaining entries are returned as
        computed by RDKit.
    """
    mol_weight = round(Chem.Descriptors.MolWt(mol), 2)
    log_p = round(Chem.Crippen.MolLogP(mol), 2)
    h_donors = Chem.Lipinski.NumHDonors(mol)
    h_acceptors = Chem.Lipinski.NumHAcceptors(mol)
    tpsa = round(Chem.rdMolDescriptors.CalcTPSA(mol), 2)
    rotatable_bonds = Chem.Lipinski.NumRotatableBonds(mol)
    return {
        "Molecular Weight": mol_weight,
        "LogP": log_p,
        "Num H-Bond Donors": h_donors,
        "Num H-Bond Acceptors": h_acceptors,
        "TPSA": tpsa,
        "Num Rotatable Bonds": rotatable_bonds,
    }
|
424 |
+
|
425 |
+
info = get_molecule_info(mol)
|
426 |
+
st.markdown("### Molecule Info:")
|
427 |
+
for k, v in info.items():
|
428 |
+
st.markdown(f"**{k}:** {v}")
|
429 |
+
|
430 |
+
st.plotly_chart(plot_distribution(df, 'gcn', prob), use_container_width=True)
|
431 |
|
432 |
with st.expander("π Example SMILES to Try"):
|
433 |
st.markdown("""
|
|
|
451 |
st.info("Predictions not available yet.")
|
452 |
|
453 |
with st.expander("π§ͺ Top 5 Toxic Predictions from Test Set"):
|
454 |
+
if 'gcn_prob' in df.columns:
|
455 |
def is_valid_gcn(smi):
    """Return True when *smi* parses, passes is_supported, and converts to a graph."""
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        return False
    if not is_supported(mol):
        return False
    # Only attempt the graph conversion once the cheaper checks have passed.
    return smiles_to_graph(smi) is not None
|
458 |
|
459 |
top_toxic = df[df['gcn_prob'] > best_threshold].copy()
|
460 |
top_toxic = top_toxic[top_toxic['smiles'].apply(is_valid_gcn)]
|
|
|
461 |
|
462 |
+
if not top_toxic.empty and 'gcn_prob' in top_toxic.columns:
|
463 |
+
top_toxic = top_toxic.sort_values('gcn_prob', ascending=False).head(5)
|
464 |
st.table(top_toxic[['smiles', 'gcn_prob']].rename(columns={'gcn_prob': 'Predicted Probability'}))
|
465 |
else:
|
466 |
st.info("No valid top predictions available.")
|
467 |
else:
|
468 |
st.info("GCN model predictions not available.")
|
|