# Safe monkey patch to fix Streamlit reloader crash due to torch.classes bug import types import torch try: # Check if torch.classes exists and fix the __path__ attribute if needed if hasattr(torch, 'classes') and not hasattr(torch.classes, "__path__"): torch.classes.__path__ = types.SimpleNamespace(_path=[]) except Exception: pass # Safe fallback if torch.classes doesn't exist or can't be patched # ------------------- Imports ------------------- import streamlit as st import numpy as np import pandas as pd import torch import time import torch.nn as nn import torch.nn.functional as F from sklearn.metrics import accuracy_score, roc_auc_score, precision_score, recall_score, f1_score, confusion_matrix, roc_curve from sklearn.decomposition import PCA from sklearn.cluster import KMeans from sklearn.manifold import TSNE from sklearn.metrics.pairwise import cosine_similarity # Optional SHAP import for interpretability try: import shap import matplotlib.pyplot as plt SHAP_AVAILABLE = True except ImportError: SHAP_AVAILABLE = False from rdkit import Chem from rdkit.Chem import rdMolDescriptors from rdkit.Chem.rdFingerprintGenerator import GetMorganGenerator from torch_geometric.data import Data from torch_geometric.nn import GCNConv, global_mean_pool from torch_geometric.loader import DataLoader import plotly.express as px from rdkit.Chem import Draw from torch_geometric.data import Batch from rdkit.Chem import Descriptors import time import gzip import io # ------------------- Models ------------------- class ToxicityNet(nn.Module): def __init__(self): super().__init__() self.model = nn.Sequential( nn.Linear(1024, 512), nn.ReLU(), nn.Dropout(0.3), nn.Linear(512, 128), nn.ReLU(), nn.Linear(128, 1) ) def forward(self, x): return self.model(x) class RichGCNModel(nn.Module): def __init__(self): super().__init__() self.conv1 = GCNConv(10, 64) self.bn1 = nn.BatchNorm1d(64) self.conv2 = GCNConv(64, 128) self.bn2 = nn.BatchNorm1d(128) self.dropout = nn.Dropout(0.2) self.fc1 = nn.Linear(128, 64) self.fc2 = nn.Linear(64, 1) def forward(self, data): x, edge_index, batch = data.x, data.edge_index, data.batch x = F.relu(self.bn1(self.conv1(x, edge_index))) x = F.relu(self.bn2(self.conv2(x, edge_index))) x = global_mean_pool(x, batch) x = self.dropout(x) x = F.relu(self.fc1(x)) return self.fc2(x) # ------------------- UI Setup ------------------- st.set_page_config(layout="wide", page_title="Drug Toxicity Predictor") st.title("๐งช Drug Toxicity Prediction Dashboard") # Performance info st.info("๐ **Fast Loading**: This app uses optimized sampling for quick startup. Full dataset processing happens on-demand.") # ------------------- Load Models with Spinner ------------------- # ------------------- Load Models with Temporary Messages ------------------- fp_model = ToxicityNet() gcn_model = RichGCNModel() fp_loaded = gcn_loaded = False # Load Fingerprint Model try: fp_model.load_state_dict(torch.load("tox_model.pt", map_location=torch.device("cpu"))) fp_model.eval() fp_loaded = True except Exception as e: st.warning(f"โ ๏ธ Fingerprint model not loaded: {e}") # Load GCN Model try: # Load the state dict state_dict = torch.load("gcn_model_improved.pt", map_location=torch.device("cpu")) # Handle DataParallel wrapper - fix nested module prefixes new_state_dict = {} for key, value in state_dict.items(): # Handle cases like "bn1.module.weight" -> "bn1.weight" if ".module." in key: new_key = key.replace(".module.", ".") new_state_dict[new_key] = value # Handle cases like "module.conv1.weight" -> "conv1.weight" elif key.startswith("module."): new_key = key[7:] # Remove "module." prefix new_state_dict[new_key] = value else: new_state_dict[key] = value gcn_model.load_state_dict(new_state_dict) gcn_model.eval() gcn_loaded = True except Exception as e: st.warning(f"โ ๏ธ GCN model not loaded: {e}") # Load Best Threshold try: best_threshold = float(np.load("gcn_best_threshold.npy")) except Exception as e: best_threshold = 0.5 st.warning(f"โ ๏ธ Using default threshold (0.5) for GCN model. Reason: {e}") # ------------------- Enhanced SMILES Processing ------------------- def preprocess_smiles(smiles): """Preprocess SMILES to handle complex cases""" if not smiles or not isinstance(smiles, str): return None, "Empty or invalid SMILES" # Remove whitespace smiles = smiles.strip() # Handle multi-component molecules (salts, etc.) if '.' in smiles: components = smiles.split('.') # Take the largest component (usually the main molecule) main_component = max(components, key=len) return main_component, f"Multi-component molecule detected. Using largest component: {main_component}" return smiles, None def safe_mol_from_smiles(smiles): """Safely create molecule from SMILES with better error handling""" try: processed_smiles, warning = preprocess_smiles(smiles) if processed_smiles is None: return None, warning mol = Chem.MolFromSmiles(processed_smiles) if mol is None: # Try to sanitize try: mol = Chem.MolFromSmiles(processed_smiles, sanitize=False) if mol is not None: Chem.SanitizeMol(mol) except: pass return mol, warning except Exception as e: return None, f"Error processing SMILES: {str(e)}" # ------------------- Utility Functions ------------------- fp_gen = GetMorganGenerator(radius=2, fpSize=1024) def get_molecule_info(mol): return { "Formula": Chem.rdMolDescriptors.CalcMolFormula(mol), "Weight": round(Descriptors.MolWt(mol), 2), "Atoms": mol.GetNumAtoms(), "Bonds": mol.GetNumBonds() } def predict_gcn(smiles): mol, warning = safe_mol_from_smiles(smiles) if mol is None: return None, None graph = smiles_to_graph(smiles) if graph is None: return None, None batch = Batch.from_data_list([graph]) with torch.no_grad(): out = gcn_model(batch) prob = torch.sigmoid(out).item() return ("Toxic" if prob > best_threshold else "Non-toxic"), prob def atom_feats(atom): return [ atom.GetAtomicNum(), atom.GetDegree(), atom.GetFormalCharge(), atom.GetNumExplicitHs(), atom.GetNumImplicitHs(), atom.GetIsAromatic(), atom.GetMass(), int(atom.IsInRing()), int(atom.GetChiralTag()), int(atom.GetHybridization()) ] def smiles_to_graph(smiles, label=None): mol, warning = safe_mol_from_smiles(smiles) if mol is None or mol.GetNumAtoms() == 0: return None atoms = [atom_feats(a) for a in mol.GetAtoms()] if not atoms: return None # No atoms present edges = [] for b in mol.GetBonds(): i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx() edges += [[i, j], [j, i]] # Handle molecules with no bonds (e.g. single atom) if len(edges) == 0: edges = [[0, 0]] edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous() x = torch.tensor(atoms, dtype=torch.float) batch = torch.zeros(x.size(0), dtype=torch.long) data = Data(x=x, edge_index=edge_index, batch=batch) if label is not None: data.y = torch.tensor([label], dtype=torch.float) return data # def predict_gcn(smiles): # graph = smiles_to_graph(smiles) # if graph is None or graph.x.size(0) == 0: # return None, None # batch = Batch.from_data_list([graph]) # with torch.no_grad(): # out = gcn_model(batch) # raw = out.item() # prob = torch.sigmoid(out).item() # print(f"Raw logit: {raw:.4f}, Prob: {prob:.4f}") # return ("Toxic" if prob > best_threshold else "Non-toxic"), prob # ------------------- Load Dataset (Enhanced Full Dataset Loading) ------------------- @st.cache_data(show_spinner="Loading full dataset...") def load_full_dataset(): """Load and preprocess the complete Tox21 dataset""" df = pd.read_csv("tox21.csv")[['smiles', 'SR-HSE']].dropna() df = df[df['SR-HSE'].isin([0, 1])].reset_index(drop=True) st.info(f"๐ Loaded {len(df)} molecules from Tox21 dataset") # Validate all SMILES for both models def validate_smiles(smi): mol = Chem.MolFromSmiles(smi) if mol is None: return False, False # Check fingerprint compatibility fp_valid = True try: fp_gen.GetFingerprint(mol) except: fp_valid = False # Check GCN compatibility gcn_valid = True try: graph = smiles_to_graph(smi) if graph is None: gcn_valid = False # Check supported atoms SUPPORTED_ATOMS = {1, 6, 7, 8, 9, 16, 17, 35, 53} if not all(atom.GetAtomicNum() in SUPPORTED_ATOMS for atom in mol.GetAtoms()): gcn_valid = False except: gcn_valid = False return fp_valid, gcn_valid # Process validation with progress valid_results = [] progress_bar = st.progress(0) status_text = st.empty() for i, smi in enumerate(df['smiles']): fp_valid, gcn_valid = validate_smiles(smi) valid_results.append((fp_valid, gcn_valid)) if i % 100 == 0: progress_bar.progress((i + 1) / len(df)) status_text.text(f"Validated {i + 1}/{len(df)} molecules...") progress_bar.empty() status_text.empty() # Add validation columns df['fp_valid'] = [r[0] for r in valid_results] df['gcn_valid'] = [r[1] for r in valid_results] # Statistics fp_valid_count = df['fp_valid'].sum() gcn_valid_count = df['gcn_valid'].sum() both_valid_count = (df['fp_valid'] & df['gcn_valid']).sum() st.success(f""" โ **Dataset Validation Complete:** - Fingerprint Model: {fp_valid_count}/{len(df)} valid molecules ({fp_valid_count/len(df)*100:.1f}%) - GCN Model: {gcn_valid_count}/{len(df)} valid molecules ({gcn_valid_count/len(df)*100:.1f}%) - Both Models: {both_valid_count}/{len(df)} valid molecules ({both_valid_count/len(df)*100:.1f}%) """) return df @st.cache_data(show_spinner="Loading dataset...") def load_dataset(): df = pd.read_csv("tox21.csv")[['smiles', 'SR-HSE']].dropna() df = df[df['SR-HSE'].isin([0, 1])].reset_index(drop=True) # โ Filter invalid or unprocessable SMILES (but only for a sample to speed up) def is_valid_graph(smi): mol = Chem.MolFromSmiles(smi) return mol is not None and smiles_to_graph(smi) is not None # Only validate a subset for faster loading sample_size = min(1000, len(df)) df_sample = df.sample(n=sample_size, random_state=42) valid_indices = df_sample[df_sample['smiles'].apply(is_valid_graph)].index df_valid = df.loc[valid_indices].reset_index(drop=True) return df_valid # Load dataset lazily (sample for quick startup) df = load_dataset() # Full dataset will be loaded on-demand df_full = None @st.cache_data(show_spinner="Creating graph dataset...") def create_graph_dataset_cached(smiles_list, labels): data_list = [] for smi, label in zip(smiles_list, labels): data = smiles_to_graph(smi, label) if data: data_list.append(data) return data_list # Only create graph dataset when needed for evaluation def get_graph_data(): return create_graph_dataset_cached(df['smiles'].tolist(), df['SR-HSE'].tolist()) # ------------------- Plot Function ------------------- def plot_distribution(df_to_plot, model_type, input_prob=None): col = 'fp_prob' if model_type == 'fp' else 'gcn_prob' df_plot = df_to_plot[df_to_plot[col].notna()].copy() df_plot["Label"] = df_plot["SR-HSE"].map({0: "Non-toxic", 1: "Toxic"}) fig = px.histogram(df_plot, x=col, color="Label", nbins=30, barmode="overlay", color_discrete_map={"Non-toxic": "green", "Toxic": "red"}, title=f"{model_type.upper()} Model - Sample Distribution (n={len(df_plot)})") if input_prob: fig.add_vline(x=input_prob, line_dash="dash", line_color="yellow", annotation_text="Your Input") return fig # ------------------- Enhanced Prediction Functions ------------------- @st.cache_data(show_spinner="Generating full dataset predictions...") def get_full_predictions(model_type='fp', max_molecules=None): """Generate predictions for the full dataset""" global df_full if df_full is None: df_full = load_full_dataset() if model_type == 'fp': valid_df = df_full[df_full['fp_valid']].copy() if not fp_loaded: st.error("Fingerprint model not loaded") return None else: valid_df = df_full[df_full['gcn_valid']].copy() if not gcn_loaded: st.error("GCN model not loaded") return None if max_molecules: valid_df = valid_df.head(max_molecules) predictions = [] probabilities = [] progress_bar = st.progress(0) status_text = st.empty() for i, (_, row) in enumerate(valid_df.iterrows()): try: if model_type == 'fp': pred, prob = predict_fp(row['smiles']) else: pred, prob = predict_gcn(row['smiles']) predictions.append(pred) probabilities.append(prob) except Exception as e: predictions.append("Error") probabilities.append(None) if i % 50 == 0: progress_bar.progress((i + 1) / len(valid_df)) status_text.text(f"Processed {i + 1}/{len(valid_df)} molecules...") progress_bar.empty() status_text.empty() valid_df[f'{model_type}_prediction'] = predictions valid_df[f'{model_type}_probability'] = probabilities return valid_df @st.cache_data(show_spinner="Evaluating model performance...") def comprehensive_evaluation(model_type='fp'): """Comprehensive evaluation on full dataset""" predictions_df = get_full_predictions(model_type) if predictions_df is None: return None # Remove error predictions clean_df = predictions_df[ (predictions_df[f'{model_type}_prediction'] != "Error") & (predictions_df[f'{model_type}_probability'].notna()) ].copy() # Convert predictions to binary if model_type == 'fp': threshold = 0.5 else: threshold = best_threshold y_true = clean_df['SR-HSE'].values y_prob = clean_df[f'{model_type}_probability'].values y_pred = (y_prob > threshold).astype(int) # Calculate metrics accuracy = accuracy_score(y_true, y_pred) precision = precision_score(y_true, y_pred) recall = recall_score(y_true, y_pred) f1 = f1_score(y_true, y_pred) roc_auc = roc_auc_score(y_true, y_prob) # Confusion matrix cm = confusion_matrix(y_true, y_pred) results = { 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'roc_auc': roc_auc, 'confusion_matrix': cm, 'predictions_df': clean_df, 'threshold': threshold, 'n_samples': len(clean_df) } return results def create_comprehensive_plots(results, model_type): """Create comprehensive visualization plots""" df = results['predictions_df'] prob_col = f'{model_type}_probability' # 1. Probability Distribution fig_dist = px.histogram( df, x=prob_col, color=df['SR-HSE'].map({0: "Non-toxic", 1: "Toxic"}), nbins=50, barmode="overlay", color_discrete_map={"Non-toxic": "green", "Toxic": "red"}, title=f"{model_type.upper()} Model - Full Dataset Probability Distribution (n={len(df)})" ) fig_dist.add_vline(x=results['threshold'], line_dash="dash", line_color="black", annotation_text=f"Threshold ({results['threshold']:.3f})") # 2. ROC Curve fpr, tpr, _ = roc_curve(df['SR-HSE'], df[prob_col]) fig_roc = px.line( x=fpr, y=tpr, title=f"{model_type.upper()} Model - ROC Curve (AUC: {results['roc_auc']:.3f})", labels={'x': 'False Positive Rate', 'y': 'True Positive Rate'} ) fig_roc.add_scatter(x=[0, 1], y=[0, 1], mode='lines', line=dict(dash='dash', color='gray'), name='Random Classifier', showlegend=False) # 3. Confusion Matrix cm = results['confusion_matrix'] fig_cm = px.imshow( cm, text_auto=True, aspect="auto", title="Confusion Matrix", labels=dict(x="Predicted", y="Actual"), x=['Non-toxic', 'Toxic'], y=['Non-toxic', 'Toxic'] ) return fig_dist, fig_roc, fig_cm def show_detailed_statistics(results, model_type): """Show detailed performance statistics""" col1, col2, col3 = st.columns(3) with col1: st.metric("Accuracy", f"{results['accuracy']:.3f}") st.metric("Precision", f"{results['precision']:.3f}") with col2: st.metric("Recall", f"{results['recall']:.3f}") st.metric("F1-Score", f"{results['f1']:.3f}") with col3: st.metric("ROC-AUC", f"{results['roc_auc']:.3f}") st.metric("Samples", f"{results['n_samples']:,}") # Confusion Matrix Details cm = results['confusion_matrix'] tn, fp, fn, tp = cm.ravel() st.subheader("Detailed Classification Results") col1, col2 = st.columns(2) with col1: st.write("**True Negatives (Correct Non-toxic):**", tn) st.write("**False Positives (Incorrect Toxic):**", fp) with col2: st.write("**False Negatives (Missed Toxic):**", fn) st.write("**True Positives (Correct Toxic):**", tp) # Calculate additional metrics specificity = tn / (tn + fp) if (tn + fp) > 0 else 0 sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0 st.write(f"**Specificity (True Negative Rate):** {specificity:.3f}") st.write(f"**Sensitivity (True Positive Rate):** {sensitivity:.3f}") # ------------------- Advanced Analytics ------------------- # ------------------- Model Performance Insights Functions ------------------- def generate_single_model_insights(model_type, sample_size): """Generate performance insights for a single model""" model_name = "Fingerprint" if model_type == 'fp' else "GCN" threshold = 0.5 if model_type == 'fp' else best_threshold # Sample data for analysis sample_data = df.sample(n=min(sample_size, len(df))) predictions = [] probabilities = [] confidences = [] valid_count = 0 for _, row in sample_data.iterrows(): try: if model_type == 'fp': pred, prob = predict_fp(row['smiles']) else: pred, prob = predict_gcn(row['smiles']) if pred and prob is not None: predictions.append(pred) probabilities.append(prob) confidences.append(max(prob, 1-prob) if threshold == 0.5 else max(prob - threshold, threshold - prob) + threshold) valid_count += 1 except: predictions.append(None) probabilities.append(None) confidences.append(None) if valid_count == 0: st.error(f"โ No valid predictions generated for {model_name} model") return # Create analysis dataframe analysis_df = sample_data.copy() analysis_df['Predicted'] = predictions analysis_df['Probability'] = probabilities analysis_df['Confidence'] = confidences analysis_df['Toxicity_Label'] = analysis_df['SR-HSE'].map({0: 'Non-Toxic', 1: 'Toxic'}) # Filter valid predictions valid_df = analysis_df[analysis_df['Predicted'].notna()].copy() if len(valid_df) == 0: st.error(f"โ No valid predictions for analysis") return st.success(f"โ **{model_name} Model Insights** - Analyzed {len(valid_df)} valid predictions") # Performance metrics col1, col2, col3, col4 = st.columns(4) # Calculate basic metrics y_true = valid_df['SR-HSE'].values y_pred_binary = (valid_df['Probability'] > threshold).astype(int) y_prob = valid_df['Probability'].values accuracy = accuracy_score(y_true, y_pred_binary) precision = precision_score(y_true, y_pred_binary) recall = recall_score(y_true, y_pred_binary) roc_auc = roc_auc_score(y_true, y_prob) with col1: st.metric("Accuracy", f"{accuracy:.3f}") with col2: st.metric("Precision", f"{precision:.3f}") with col3: st.metric("Recall", f"{recall:.3f}") with col4: st.metric("ROC-AUC", f"{roc_auc:.3f}") # Visualization col1, col2 = st.columns(2) with col1: # Confidence distribution fig_conf = px.histogram( valid_df, x='Confidence', color='Predicted', title=f"{model_name} Model - Confidence Distribution", nbins=20, color_discrete_map={'Toxic': 'red', 'Non-toxic': 'green'} ) st.plotly_chart(fig_conf, use_container_width=True) with col2: # Probability distribution by true label fig_prob = px.histogram( valid_df, x='Probability', color='Toxicity_Label', title=f"{model_name} Model - Probability by True Label", nbins=20, color_discrete_map={'Toxic': 'red', 'Non-Toxic': 'green'} ) fig_prob.add_vline(x=threshold, line_dash="dash", annotation_text=f"Threshold ({threshold})") st.plotly_chart(fig_prob, use_container_width=True) # Confidence-based accuracy analysis valid_df['Correct'] = (valid_df['Predicted'] == valid_df['Toxicity_Label']) # Binned confidence analysis valid_df['Confidence_Bin'] = pd.cut(valid_df['Confidence'], bins=5, labels=['Very Low', 'Low', 'Medium', 'High', 'Very High']) conf_analysis = valid_df.groupby('Confidence_Bin').agg({ 'Correct': ['mean', 'count'], 'Confidence': 'mean' }).round(3) st.subheader(f"๐ฏ {model_name} Model Reliability Analysis") # High/Low confidence metrics high_conf_mask = valid_df['Confidence'] > 0.8 low_conf_mask = valid_df['Confidence'] < 0.6 high_conf_accuracy = valid_df[high_conf_mask]['Correct'].mean() if high_conf_mask.any() else 0 low_conf_accuracy = valid_df[low_conf_mask]['Correct'].mean() if low_conf_mask.any() else 0 avg_confidence = valid_df['Confidence'].mean() col1, col2, col3 = st.columns(3) with col1: st.metric("High Confidence Accuracy (>0.8)", f"{high_conf_accuracy:.1%}", f"{high_conf_mask.sum()} samples") with col2: st.metric("Low Confidence Accuracy (<0.6)", f"{low_conf_accuracy:.1%}", f"{low_conf_mask.sum()} samples") with col3: st.metric("Average Confidence", f"{avg_confidence:.3f}") # Model-specific insights st.markdown(f""" **๐ก {model_name} Model Insights:** **Strengths:** - {"High precision for confident predictions" if high_conf_accuracy > 0.9 else "Moderate precision for confident predictions"} - {"Good calibration between confidence and accuracy" if abs(high_conf_accuracy - low_conf_accuracy) > 0.2 else "Confidence and accuracy correlation needs improvement"} **Recommendations:** - {"Review low-confidence predictions manually" if low_conf_mask.sum() > 10 else "Most predictions are high confidence"} - {"Model shows good reliability" if accuracy > 0.8 else "Consider model improvement or ensemble methods"} - {"Threshold optimization may improve performance" if abs(precision - recall) > 0.1 else "Balanced precision-recall performance"} **Model Characteristics:** - **Decision Threshold:** {threshold} - **Feature Space:** {"1024-dimensional molecular fingerprints" if model_type == 'fp' else "128-dimensional graph neural network features"} - **Architecture:** {"Feed-forward neural network with dropout" if model_type == 'fp' else "Graph convolutional network with batch normalization"} """) def generate_dual_model_insights(sample_size): """Generate comparative insights between both models""" st.success("โ **Dual Model Comparison Analysis**") # Sample data for analysis sample_data = df.sample(n=min(sample_size, len(df))) fp_predictions = [] fp_probabilities = [] gcn_predictions = [] gcn_probabilities = [] valid_indices = [] for idx, (_, row) in enumerate(sample_data.iterrows()): try: fp_pred, fp_prob = predict_fp(row['smiles']) gcn_pred, gcn_prob = predict_gcn(row['smiles']) if fp_pred and fp_prob is not None and gcn_pred and gcn_prob is not None: fp_predictions.append(fp_pred) fp_probabilities.append(fp_prob) gcn_predictions.append(gcn_pred) gcn_probabilities.append(gcn_prob) valid_indices.append(idx) except: continue if len(valid_indices) == 0: st.error("โ No valid predictions from both models for comparison") return # Create comparison dataframe valid_sample = sample_data.iloc[valid_indices].copy() valid_sample['FP_Prediction'] = fp_predictions valid_sample['FP_Probability'] = fp_probabilities valid_sample['GCN_Prediction'] = gcn_predictions valid_sample['GCN_Probability'] = gcn_probabilities valid_sample['Toxicity_Label'] = valid_sample['SR-HSE'].map({0: 'Non-Toxic', 1: 'Toxic'}) st.info(f"๐ Comparing {len(valid_sample)} molecules across both models") # Model agreement analysis agreement = (valid_sample['FP_Prediction'] == valid_sample['GCN_Prediction']).mean() # Calculate individual model accuracies fp_accuracy = accuracy_score(valid_sample['SR-HSE'], (valid_sample['FP_Probability'] > 0.5).astype(int)) gcn_accuracy = accuracy_score(valid_sample['SR-HSE'], (valid_sample['GCN_Probability'] > best_threshold).astype(int)) # Correlation between probabilities prob_correlation = np.corrcoef(valid_sample['FP_Probability'], valid_sample['GCN_Probability'])[0, 1] # Display comparison metrics col1, col2, col3, col4 = st.columns(4) with col1: st.metric("Model Agreement", f"{agreement:.1%}") with col2: st.metric("FP Accuracy", f"{fp_accuracy:.3f}") with col3: st.metric("GCN Accuracy", f"{gcn_accuracy:.3f}") with col4: st.metric("Probability Correlation", f"{prob_correlation:.3f}") # Visualizations col1, col2 = st.columns(2) with col1: # Probability correlation scatter plot fig_corr = px.scatter( valid_sample, x='FP_Probability', y='GCN_Probability', color='Toxicity_Label', title="Model Probability Correlation", labels={'x': 'Fingerprint Probability', 'y': 'GCN Probability'}, color_discrete_map={'Toxic': 'red', 'Non-Toxic': 'green'} ) fig_corr.add_scatter(x=[0, 1], y=[0, 1], mode='lines', line=dict(dash='dash', color='gray'), name='Perfect Correlation', showlegend=False) st.plotly_chart(fig_corr, use_container_width=True) with col2: # Agreement by prediction confidence valid_sample['Agreement'] = valid_sample['FP_Prediction'] == valid_sample['GCN_Prediction'] valid_sample['Avg_Confidence'] = (valid_sample['FP_Probability'].apply(lambda x: max(x, 1-x)) + valid_sample['GCN_Probability'].apply(lambda x: max(x, 1-x))) / 2 conf_bins = pd.cut(valid_sample['Avg_Confidence'], bins=5, labels=['Very Low', 'Low', 'Medium', 'High', 'Very High']) agreement_by_conf = valid_sample.groupby(conf_bins)['Agreement'].mean() fig_agreement = px.bar( x=agreement_by_conf.index.astype(str), y=agreement_by_conf.values, title="Model Agreement by Confidence Level", labels={'x': 'Average Confidence Bin', 'y': 'Agreement Rate'} ) st.plotly_chart(fig_agreement, use_container_width=True) # Disagreement analysis disagreement_cases = valid_sample[~valid_sample['Agreement']].copy() st.subheader("๐ Model Disagreement Analysis") if len(disagreement_cases) > 0: st.write(f"**Found {len(disagreement_cases)} disagreement cases ({len(disagreement_cases)/len(valid_sample)*100:.1f}% of predictions)**") # Show some example disagreements st.write("**Example Disagreement Cases:**") display_cols = ['smiles', 'FP_Prediction', 'FP_Probability', 'GCN_Prediction', 'GCN_Probability', 'Toxicity_Label'] st.dataframe(disagreement_cases[display_cols].head(5)) # Disagreement patterns fp_toxic_gcn_nontoxic = ((disagreement_cases['FP_Prediction'] == 'Toxic') & (disagreement_cases['GCN_Prediction'] == 'Non-toxic')).sum() gcn_toxic_fp_nontoxic = ((disagreement_cases['GCN_Prediction'] == 'Toxic') & (disagreement_cases['FP_Prediction'] == 'Non-toxic')).sum() st.write(f"**Disagreement Patterns:**") st.write(f"- FP says Toxic, GCN says Non-toxic: {fp_toxic_gcn_nontoxic} cases") st.write(f"- GCN says Toxic, FP says Non-toxic: {gcn_toxic_fp_nontoxic} cases") else: st.success("๐ Perfect agreement between models on this sample!") # Ensemble performance st.subheader("๐ค Ensemble Analysis") # Simple ensemble: average probabilities ensemble_prob = (valid_sample['FP_Probability'] + valid_sample['GCN_Probability']) / 2 ensemble_pred = (ensemble_prob > 0.5).astype(int) ensemble_accuracy = accuracy_score(valid_sample['SR-HSE'], ensemble_pred) # Voting ensemble fp_votes = (valid_sample['FP_Probability'] > 0.5).astype(int) gcn_votes = (valid_sample['GCN_Probability'] > best_threshold).astype(int) voting_pred = (fp_votes + gcn_votes >= 1).astype(int) # At least one model predicts toxic voting_accuracy = accuracy_score(valid_sample['SR-HSE'], voting_pred) col1, col2, col3 = st.columns(3) with col1: st.metric("Average Ensemble Accuracy", f"{ensemble_accuracy:.3f}") with col2: st.metric("Voting Ensemble Accuracy", f"{voting_accuracy:.3f}") with col3: best_single = max(fp_accuracy, gcn_accuracy) improvement = max(ensemble_accuracy, voting_accuracy) - best_single st.metric("Best Improvement", f"+{improvement:.3f}" if improvement > 0 else f"{improvement:.3f}") st.markdown(""" **๐ก Dual Model Insights:** **Model Complementarity:** - **Fingerprint Model**: Excels at chemical similarity patterns, traditional QSAR relationships - **GCN Model**: Captures structural topology, graph-based molecular patterns - **Together**: Provide comprehensive molecular understanding **Ensemble Recommendations:** - Use **voting ensemble** for high-stakes decisions (conservative approach) - Use **average ensemble** for balanced sensitivity-specificity - Manual review recommended for disagreement cases **Best Practices:** - High agreement cases: Trust the prediction - Disagreement cases: Consider additional validation - Low confidence from both: Seek experimental validation """) @st.cache_data(show_spinner="Generating SHAP explanations...") def get_shap_explanations(smiles_list, model_type='fp', num_samples=100): """Generate SHAP explanations for model interpretability""" if not SHAP_AVAILABLE: return None, "SHAP library not available. Install with: pip install shap" try: if model_type == 'fp': # Create a sample of fingerprints for SHAP analysis sample_smiles = smiles_list[:num_samples] fingerprints = [] for smiles in sample_smiles: mol, _ = safe_mol_from_smiles(smiles) if mol is not None: fp = fp_gen.GetFingerprint(mol) fingerprints.append(np.array(fp)) if len(fingerprints) == 0: return None, "No valid molecules for SHAP analysis" X = np.array(fingerprints) # Create SHAP explainer for fingerprint model def model_predict(X): with torch.no_grad(): preds = [] for fp in X: tensor = torch.tensor(fp.reshape(1, -1)).float() pred = torch.sigmoid(fp_model(tensor)).item() preds.append(pred) return np.array(preds) # Use a subset for background background = X[:min(10, len(X))] explainer = shap.Explainer(model_predict, background) shap_values = explainer(X[:min(20, len(X))]) return shap_values, None elif model_type == 'gcn': # Create a sample of GCN features for SHAP analysis if not gcn_loaded: return None, "GCN model not loaded" sample_smiles = smiles_list[:num_samples] gcn_features, valid_smiles = extract_gcn_features(sample_smiles) if len(gcn_features) == 0: return None, "No valid molecules for GCN SHAP analysis" X = np.array(gcn_features) # Create SHAP explainer for GCN model using extracted features def gcn_model_predict(X): with torch.no_grad(): preds = [] for features in X: # Convert features back to a format the model can use # We'll use the extracted features directly for explanation # This represents the contribution of the learned features tensor = torch.tensor(features.reshape(1, -1)).float() # Apply the final layers of GCN model (fc1 -> fc2) x = torch.relu(gcn_model.fc1(tensor)) pred = torch.sigmoid(gcn_model.fc2(x)).item() preds.append(pred) return np.array(preds) # Use a subset for background background = X[:min(10, len(X))] explainer = shap.Explainer(gcn_model_predict, background) shap_values = explainer(X[:min(20, len(X))]) return shap_values, None except Exception as e: return None, f"SHAP analysis failed: {str(e)}" def extract_gcn_features(smiles_list): """Extract features from GCN model's hidden layers""" features = [] valid_smiles = [] for smiles in smiles_list: try: mol, _ = safe_mol_from_smiles(smiles) if mol is None: continue # Check if molecule is supported by GCN SUPPORTED_ATOMS = {1, 6, 7, 8, 9, 16, 17, 35, 53} if not all(atom.GetAtomicNum() in SUPPORTED_ATOMS for atom in mol.GetAtoms()): continue graph = smiles_to_graph(smiles) if graph is None: continue batch = Batch.from_data_list([graph]) # Extract features from GCN model's hidden layers with torch.no_grad(): # Get the second-to-last layer features (before final classification) x = batch.x edge_index = batch.edge_index batch_idx = batch.batch # Forward through GCN layers x = gcn_model.conv1(x, edge_index) x = torch.relu(x) x = gcn_model.conv2(x, edge_index) x = torch.relu(x) # Global pooling to get graph-level features graph_features = global_mean_pool(x, batch_idx) features.append(graph_features.cpu().numpy().flatten()) valid_smiles.append(smiles) except Exception: continue return features, valid_smiles @st.cache_data(show_spinner="Performing PCA analysis...") def perform_pca_analysis(df_sample, model_type='fp'): """Perform Principal Component Analysis on molecular features""" try: features = [] labels = [] smiles_list = [] if model_type == 'fp': # Use molecular fingerprints for _, row in df_sample.iterrows(): mol, _ = safe_mol_from_smiles(row['smiles']) if mol is not None: fp = fp_gen.GetFingerprint(mol) features.append(np.array(fp)) labels.append(row['SR-HSE']) smiles_list.append(row['smiles']) elif model_type == 'gcn': # Use GCN learned features if not gcn_loaded: return None, "GCN model not loaded" sample_smiles = df_sample['smiles'].tolist() gcn_features, valid_smiles = extract_gcn_features(sample_smiles) # Map back to labels smiles_to_label = dict(zip(df_sample['smiles'], df_sample['SR-HSE'])) for i, smiles in enumerate(valid_smiles): features.append(gcn_features[i]) labels.append(smiles_to_label[smiles]) smiles_list.append(smiles) if len(features) < 10: return None, f"Insufficient valid molecules for PCA (got {len(features)}, need at least 10)" X = np.array(features) # Perform PCA pca = PCA(n_components=2) X_pca = pca.fit_transform(X) # Create DataFrame for plotting pca_df = pd.DataFrame({ 'PC1': X_pca[:, 0], 'PC2': X_pca[:, 1], 'Label': ['Toxic' if l == 1 else 'Non-toxic' for l in labels], 'SMILES': smiles_list }) explained_variance = pca.explained_variance_ratio_ return { 'pca_df': pca_df, 'explained_variance': explained_variance, 'pca_model': pca, 'feature_type': 'Fingerprints' if model_type == 'fp' else 'GCN Features' }, None except Exception as e: return None, f"PCA analysis failed: {str(e)}" @st.cache_data(show_spinner="Performing clustering analysis...") def perform_clustering_analysis(df_sample, n_clusters=3, model_type='fp'): """Perform K-means clustering on molecular features""" try: features = [] labels = [] smiles_list = [] if model_type == 'fp': # Use molecular fingerprints for _, row in df_sample.iterrows(): mol, _ = safe_mol_from_smiles(row['smiles']) if mol is not None: fp = fp_gen.GetFingerprint(mol) features.append(np.array(fp)) labels.append(row['SR-HSE']) smiles_list.append(row['smiles']) elif model_type == 'gcn': # Use GCN learned features if not gcn_loaded: return None, "GCN model not loaded" sample_smiles = df_sample['smiles'].tolist() gcn_features, valid_smiles = extract_gcn_features(sample_smiles) # Map back to labels smiles_to_label = dict(zip(df_sample['smiles'], df_sample['SR-HSE'])) for i, smiles in enumerate(valid_smiles): features.append(gcn_features[i]) labels.append(smiles_to_label[smiles]) smiles_list.append(smiles) if len(features) < n_clusters: return None, f"Need at least {n_clusters} valid molecules for clustering (got {len(features)})" X = np.array(features) # Perform K-means clustering kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10) cluster_labels = kmeans.fit_predict(X) # Perform t-SNE for visualization tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(X)-1)) X_tsne = tsne.fit_transform(X) # Create DataFrame for plotting cluster_df = pd.DataFrame({ 'tSNE1': X_tsne[:, 0], 'tSNE2': X_tsne[:, 1], 'Cluster': [f'Cluster {i}' for i in cluster_labels], 'Toxicity': ['Toxic' if l == 1 else 'Non-toxic' for l in labels], 'SMILES': smiles_list }) return { 'cluster_df': cluster_df, 'kmeans_model': kmeans, 'feature_type': 'Fingerprints' if model_type == 'fp' else 'GCN Features' }, None except Exception as e: return None, f"Clustering analysis failed: {str(e)}" @st.cache_data(show_spinner="Finding similar molecules...") def find_similar_molecules(query_smiles, df_sample, top_k=5, model_type='fp'): """Find molecules similar to the query molecule based on feature similarity""" try: if model_type == 'fp': # Use fingerprint similarity query_mol, _ = safe_mol_from_smiles(query_smiles) if query_mol is None: return None, "Invalid query SMILES" query_fp = np.array(fp_gen.GetFingerprint(query_mol)).reshape(1, -1) similarities = [] valid_molecules = [] for _, row in df_sample.iterrows(): mol, _ = safe_mol_from_smiles(row['smiles']) if mol is not None: fp = np.array(fp_gen.GetFingerprint(mol)).reshape(1, -1) similarity = cosine_similarity(query_fp, fp)[0][0] similarities.append(similarity) valid_molecules.append(row) elif model_type == 'gcn': # Use GCN feature similarity if not gcn_loaded: return None, "GCN model not loaded" # Get query features query_features, query_valid = extract_gcn_features([query_smiles]) if len(query_features) == 0: return None, "Could not extract GCN features for query molecule" query_feature = np.array(query_features[0]).reshape(1, -1) # Get features for all molecules in sample sample_smiles = df_sample['smiles'].tolist() sample_features, valid_smiles = extract_gcn_features(sample_smiles) if len(sample_features) == 0: return None, "No valid molecules found for GCN similarity search" # Map back to original data smiles_to_data = dict(zip(df_sample['smiles'], df_sample.to_dict('records'))) similarities = [] valid_molecules = [] for i, smiles in enumerate(valid_smiles): feature = np.array(sample_features[i]).reshape(1, -1) similarity = cosine_similarity(query_feature, feature)[0][0] similarities.append(similarity) valid_molecules.append(smiles_to_data[smiles]) if len(similarities) == 0: return None, "No valid molecules for comparison" # Get top similar molecules similarity_df = pd.DataFrame(valid_molecules) similarity_df['Similarity'] = similarities similarity_df = similarity_df.sort_values('Similarity', ascending=False) return similarity_df.head(top_k), None except Exception as e: return None, f"Similarity search failed: {str(e)}" def uncertainty_quantification(smiles, model_type='fp', num_samples=100): """Estimate prediction uncertainty using Monte Carlo dropout""" try: mol, _ = safe_mol_from_smiles(smiles) if mol is None: return None, None, None, "Invalid SMILES" if model_type == 'fp': fp = fp_gen.GetFingerprint(mol) fp_array = np.array(fp).reshape(1, -1) tensor = torch.tensor(fp_array).float() # Enable training mode for Monte Carlo dropout fp_model.train() predictions = [] for _ in range(num_samples): with torch.no_grad(): logits = fp_model(tensor) prob = torch.sigmoid(logits).item() predictions.append(prob) fp_model.eval() # Back to evaluation mode mean_pred = np.mean(predictions) std_pred = np.std(predictions) confidence_interval = ( np.percentile(predictions, 2.5), np.percentile(predictions, 97.5) ) return mean_pred, std_pred, confidence_interval, None elif model_type == 'gcn': # Check if molecule is supported by GCN SUPPORTED_ATOMS = {1, 6, 7, 8, 9, 16, 17, 35, 53} if not all(atom.GetAtomicNum() in SUPPORTED_ATOMS for atom in mol.GetAtoms()): return None, None, None, "Molecule contains unsupported atoms for GCN model" graph = smiles_to_graph(smiles) if graph is None: return None, None, None, "Could not convert SMILES to graph" batch = Batch.from_data_list([graph]) # Enable training mode for Monte Carlo dropout gcn_model.train() predictions = [] for _ in range(num_samples): with torch.no_grad(): logits = gcn_model(batch) prob = torch.sigmoid(logits).item() predictions.append(prob) gcn_model.eval() # Back to evaluation mode mean_pred = np.mean(predictions) std_pred = np.std(predictions) confidence_interval = ( np.percentile(predictions, 2.5), np.percentile(predictions, 97.5) ) return mean_pred, std_pred, confidence_interval, None except Exception as e: return None, None, None, f"Uncertainty quantification failed: {str(e)}" @st.cache_data(show_spinner="Processing batch predictions...") def process_batch_predictions(uploaded_df, model_type='fp'): """Process batch predictions from uploaded CSV with optimizations for large datasets""" try: results = [] total_count = len(uploaded_df) # Create progress tracking progress_container = st.container() progress_bar = progress_container.progress(0) status_text = progress_container.empty() # Process in chunks for better memory management and progress updates chunk_size = min(100, max(10, total_count // 20)) # Adaptive chunk size start_time = time.time() for i, (_, row) in enumerate(uploaded_df.iterrows()): smiles = row.get('smiles', row.get('SMILES', '')) if model_type == 'fp': pred, prob = predict_fp(smiles) # Add uncertainty quantification mean_pred, std_pred, ci, _ = uncertainty_quantification(smiles, 'fp') results.append({ 'SMILES': smiles, 'Prediction': pred, 'Probability': prob, 'Mean_Probability': mean_pred, 'Std_Probability': std_pred, 'CI_Lower': ci[0] if ci else None, 'CI_Upper': ci[1] if ci else None }) else: # GCN model pred, prob = predict_gcn(smiles) # Add uncertainty quantification for GCN mean_pred, std_pred, ci, _ = uncertainty_quantification(smiles, 'gcn') results.append({ 'SMILES': smiles, 'Prediction': pred, 'Probability': prob, 'Mean_Probability': mean_pred, 'Std_Probability': std_pred, 'CI_Lower': ci[0] if ci else None, 'CI_Upper': ci[1] if ci else None }) # Update progress every chunk_size molecules if (i + 1) % chunk_size == 0 or i == total_count - 1: progress = (i + 1) / total_count progress_bar.progress(progress) # Calculate ETA elapsed_time = time.time() - start_time if i > 0: avg_time_per_molecule = elapsed_time / (i + 1) remaining_molecules = total_count - (i + 1) eta_seconds = avg_time_per_molecule * remaining_molecules eta_minutes = eta_seconds / 60 status_text.text(f"Processed {i+1:,}/{total_count:,} molecules ({progress:.1%}) - ETA: {eta_minutes:.1f} min") else: status_text.text(f"Processed {i+1:,}/{total_count:,} molecules ({progress:.1%})") # Clean up progress indicators progress_bar.empty() status_text.empty() return pd.DataFrame(results), None except Exception as e: return None, f"Batch processing failed: {str(e)}" # ------------------- Prediction Cache ------------------- @st.cache_data(show_spinner="Generating predictions...") def predict_fp(smiles): try: mol, warning = safe_mol_from_smiles(smiles) if mol is None: return f"Invalid SMILES: {warning}", 0.0 fp = fp_gen.GetFingerprint(mol) fp_array = np.array(fp).reshape(1, -1) with torch.no_grad(): logits = fp_model(torch.tensor(fp_array).float()) prob = torch.sigmoid(logits).item() return ("Toxic" if prob > 0.5 else "Non-toxic"), prob except Exception as e: return f"Error: {str(e)}", 0.0 @st.cache_data(show_spinner="Generating sample predictions...") def get_sample_predictions(model_type='fp', sample_size=100): """Generate predictions for a small sample of the dataset for visualization""" sample_df = df.sample(n=min(sample_size, len(df)), random_state=42) preds = [] for smi in sample_df['smiles']: try: p = predict_fp(smi)[1] if model_type == 'fp' else predict_gcn(smi)[1] preds.append(p) except: preds.append(None) sample_df = sample_df.copy() sample_df[f'{model_type}_prob'] = preds return sample_df # Only generate sample predictions for visualization (much faster) sample_df_fp = get_sample_predictions('fp', 100) if fp_loaded else None sample_df_gcn = get_sample_predictions('gcn', 100) if gcn_loaded else None # ------------------- Evaluation Function ------------------- @st.cache_data(show_spinner="Evaluating model on test set...") def evaluate_gcn_test_set(model): graph_data = get_graph_data() test_loader = DataLoader(graph_data, batch_size=32) model.eval() all_preds, all_labels = [], [] with torch.no_grad(): for batch in test_loader: batch = batch.to("cpu") # Ensure on CPU out = model(batch) probs = torch.sigmoid(out) all_preds.extend(probs.cpu().numpy()) all_labels.extend(batch.y.cpu().numpy()) acc = accuracy_score(all_labels, (np.array(all_preds) > 0.5).astype(int)) roc = roc_auc_score(all_labels, all_preds) df_eval = pd.DataFrame({ "Predicted Probability": all_preds, "Label": ["Non-toxic" if i == 0 else "Toxic" for i in all_labels] }) fig = px.histogram(df_eval, x="Predicted Probability", color="Label", nbins=30, barmode="overlay", color_discrete_map={"Non-toxic": "green", "Toxic": "red"}, title="GCN Test Set - Probability Distribution") fig.update_layout(bargap=0.1) st.success(f"โ Accuracy: `{acc:.4f}`, ROC-AUC: `{roc:.4f}`") st.plotly_chart(fig, use_container_width=True) # ------------------- Tabs ------------------- tab1, tab2, tab3, tab4, tab5 = st.tabs([ "๐ฌ Fingerprint Model", "๐งฌ GCN Model", "โ๏ธ Model Comparison", "๐ง AI Interpretability", "๐ Advanced Analytics" ]) with tab1: st.subheader("Fingerprint-based Prediction") # Add full dataset evaluation section with st.expander("๐ Full Dataset Performance Analysis"): if st.button("๐ Run Full Dataset Evaluation (Fingerprint)", key="fp_eval"): with st.spinner("Evaluating fingerprint model on full dataset..."): fp_results = comprehensive_evaluation('fp') if fp_results: show_detailed_statistics(fp_results, 'fp') # Create and show plots fig_dist, fig_roc, fig_cm = create_comprehensive_plots(fp_results, 'fp') st.plotly_chart(fig_dist, use_container_width=True) col1, col2 = st.columns(2) with col1: st.plotly_chart(fig_roc, use_container_width=True) with col2: st.plotly_chart(fig_cm, use_container_width=True) # Download predictions csv_data = fp_results['predictions_df'].to_csv(index=False) st.download_button( "๐ฅ Download Full Predictions", csv_data, "fingerprint_full_predictions.csv", "text/csv" ) with st.form("fp_form"): smiles_fp = st.text_input("Enter SMILES", "CCO") show_debug_fp = st.checkbox("๐ Show Debug Info (raw score/logit)", key="fp_debug") predict_btn = st.form_submit_button("๐ Predict") if predict_btn: mol, warning = safe_mol_from_smiles(smiles_fp) if warning: st.info(f"โน๏ธ {warning}") if mol: fp = fp_gen.GetFingerprint(mol) arr = np.array(fp).reshape(1, -1) tensor = torch.tensor(arr).float() with torch.no_grad(): output = fp_model(tensor) prob = torch.sigmoid(output).item() raw_score = output.item() label = "Toxic" if prob > 0.5 else "Non-toxic" color = "red" if label == "Toxic" else "green" st.markdown(f"
{prob:.3f}
{prob:.3f}