Update src/streamlit_app.py
Browse files- src/streamlit_app.py +12 -15
src/streamlit_app.py
CHANGED
@@ -301,26 +301,23 @@ tab1, tab2 = st.tabs(["π¬ Fingerprint Model", "𧬠GCN Model"])
|
|
301 |
with tab1:
|
302 |
st.subheader("Fingerprint-based Prediction")
|
303 |
|
304 |
-
# β
SMILES validation function
|
305 |
-
def is_valid_smiles(smiles):
|
306 |
-
mol = Chem.MolFromSmiles(smiles)
|
307 |
-
return mol is not None
|
308 |
-
|
309 |
with st.form("fp_form"):
|
310 |
smiles_fp = st.text_input("Enter SMILES", "CCO")
|
311 |
show_debug_fp = st.checkbox("π Show Debug Info (raw score/logit)", key="fp_debug")
|
312 |
predict_btn = st.form_submit_button("π Predict")
|
313 |
|
314 |
if predict_btn:
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
mol = Chem.MolFromSmiles(smiles_fp)
|
321 |
fp = fp_gen.GetFingerprint(mol)
|
322 |
arr = np.array(fp).reshape(1, -1)
|
323 |
-
tensor = torch.tensor(arr).float()
|
|
|
|
|
324 |
with torch.no_grad():
|
325 |
output = fp_model(tensor)
|
326 |
prob = torch.sigmoid(output).item()
|
@@ -344,6 +341,10 @@ with tab1:
|
|
344 |
|
345 |
st.plotly_chart(plot_distribution(df, 'fp', prob), use_container_width=True)
|
346 |
|
|
|
|
|
|
|
|
|
347 |
with st.expander("π Example SMILES to Try"):
|
348 |
st.markdown("""
|
349 |
- `CCO` (Ethanol)
|
@@ -374,10 +375,6 @@ with tab2:
|
|
374 |
|
375 |
SUPPORTED_ATOMS = {1, 6, 7, 8, 9, 16, 17, 35, 53} # H, C, N, O, F, S, Cl, Br, I
|
376 |
|
377 |
-
# β
Molecule validation functions
|
378 |
-
def is_valid_smiles(smiles):
|
379 |
-
mol = Chem.MolFromSmiles(smiles)
|
380 |
-
return mol is not None
|
381 |
|
382 |
def is_supported(mol):
|
383 |
return all(atom.GetAtomicNum() in SUPPORTED_ATOMS for atom in mol.GetAtoms())
|
|
|
301 |
with tab1:
|
302 |
st.subheader("Fingerprint-based Prediction")
|
303 |
|
|
|
|
|
|
|
|
|
|
|
304 |
with st.form("fp_form"):
|
305 |
smiles_fp = st.text_input("Enter SMILES", "CCO")
|
306 |
show_debug_fp = st.checkbox("π Show Debug Info (raw score/logit)", key="fp_debug")
|
307 |
predict_btn = st.form_submit_button("π Predict")
|
308 |
|
309 |
if predict_btn:
|
310 |
+
with st.spinner("Predicting..."):
|
311 |
+
if not is_valid_smiles(smiles_fp):
|
312 |
+
st.error("β Invalid SMILES input. Please check your string.")
|
313 |
+
else:
|
314 |
+
try:
|
315 |
mol = Chem.MolFromSmiles(smiles_fp)
|
316 |
fp = fp_gen.GetFingerprint(mol)
|
317 |
arr = np.array(fp).reshape(1, -1)
|
318 |
+
tensor = torch.tensor(arr).float().to("cpu")
|
319 |
+
fp_model.to("cpu") # Ensure model is on CPU
|
320 |
+
|
321 |
with torch.no_grad():
|
322 |
output = fp_model(tensor)
|
323 |
prob = torch.sigmoid(output).item()
|
|
|
341 |
|
342 |
st.plotly_chart(plot_distribution(df, 'fp', prob), use_container_width=True)
|
343 |
|
344 |
+
except Exception as e:
|
345 |
+
st.error(f"Prediction error: {str(e)}")
|
346 |
+
|
347 |
+
|
348 |
with st.expander("π Example SMILES to Try"):
|
349 |
st.markdown("""
|
350 |
- `CCO` (Ethanol)
|
|
|
375 |
|
376 |
SUPPORTED_ATOMS = {1, 6, 7, 8, 9, 16, 17, 35, 53} # H, C, N, O, F, S, Cl, Br, I
|
377 |
|
|
|
|
|
|
|
|
|
378 |
|
379 |
def is_supported(mol):
|
380 |
return all(atom.GetAtomicNum() in SUPPORTED_ATOMS for atom in mol.GetAtoms())
|