# Safe monkey patch to fix Streamlit reloader crash due to torch.classes bug import types import torch try: # Check if torch.classes exists and fix the __path__ attribute if needed if hasattr(torch, 'classes') and not hasattr(torch.classes, "__path__"): torch.classes.__path__ = types.SimpleNamespace(_path=[]) except Exception: pass # Safe fallback if torch.classes doesn't exist or can't be patched # ------------------- Imports ------------------- import streamlit as st import numpy as np import pandas as pd import torch import time import torch.nn as nn import torch.nn.functional as F from sklearn.metrics import accuracy_score, roc_auc_score, precision_score, recall_score, f1_score, confusion_matrix, roc_curve from sklearn.decomposition import PCA from sklearn.cluster import KMeans from sklearn.manifold import TSNE from sklearn.metrics.pairwise import cosine_similarity # Optional SHAP import for interpretability try: import shap import matplotlib.pyplot as plt SHAP_AVAILABLE = True except ImportError: SHAP_AVAILABLE = False from rdkit import Chem from rdkit.Chem import rdMolDescriptors from rdkit.Chem.rdFingerprintGenerator import GetMorganGenerator from torch_geometric.data import Data from torch_geometric.nn import GCNConv, global_mean_pool from torch_geometric.loader import DataLoader import plotly.express as px from rdkit.Chem import Draw from torch_geometric.data import Batch from rdkit.Chem import Descriptors import time import gzip import io # ------------------- Models ------------------- class ToxicityNet(nn.Module): def __init__(self): super().__init__() self.model = nn.Sequential( nn.Linear(1024, 512), nn.ReLU(), nn.Dropout(0.3), nn.Linear(512, 128), nn.ReLU(), nn.Linear(128, 1) ) def forward(self, x): return self.model(x) class RichGCNModel(nn.Module): def __init__(self): super().__init__() self.conv1 = GCNConv(10, 64) self.bn1 = nn.BatchNorm1d(64) self.conv2 = GCNConv(64, 128) self.bn2 = nn.BatchNorm1d(128) self.dropout = nn.Dropout(0.2) self.fc1 = nn.Linear(128, 64) self.fc2 = nn.Linear(64, 1) def forward(self, data): x, edge_index, batch = data.x, data.edge_index, data.batch x = F.relu(self.bn1(self.conv1(x, edge_index))) x = F.relu(self.bn2(self.conv2(x, edge_index))) x = global_mean_pool(x, batch) x = self.dropout(x) x = F.relu(self.fc1(x)) return self.fc2(x) # ------------------- UI Setup ------------------- st.set_page_config(layout="wide", page_title="Drug Toxicity Predictor") st.title("๐Ÿงช Drug Toxicity Prediction Dashboard") # Performance info st.info("๐Ÿš€ **Fast Loading**: This app uses optimized sampling for quick startup. Full dataset processing happens on-demand.") # ------------------- Load Models with Spinner ------------------- # ------------------- Load Models with Temporary Messages ------------------- fp_model = ToxicityNet() gcn_model = RichGCNModel() fp_loaded = gcn_loaded = False # Load Fingerprint Model try: fp_model.load_state_dict(torch.load("tox_model.pt", map_location=torch.device("cpu"))) fp_model.eval() fp_loaded = True except Exception as e: st.warning(f"โš ๏ธ Fingerprint model not loaded: {e}") # Load GCN Model try: # Load the state dict state_dict = torch.load("gcn_model_improved.pt", map_location=torch.device("cpu")) # Handle DataParallel wrapper - fix nested module prefixes new_state_dict = {} for key, value in state_dict.items(): # Handle cases like "bn1.module.weight" -> "bn1.weight" if ".module." in key: new_key = key.replace(".module.", ".") new_state_dict[new_key] = value # Handle cases like "module.conv1.weight" -> "conv1.weight" elif key.startswith("module."): new_key = key[7:] # Remove "module." prefix new_state_dict[new_key] = value else: new_state_dict[key] = value gcn_model.load_state_dict(new_state_dict) gcn_model.eval() gcn_loaded = True except Exception as e: st.warning(f"โš ๏ธ GCN model not loaded: {e}") # Load Best Threshold try: best_threshold = float(np.load("gcn_best_threshold.npy")) except Exception as e: best_threshold = 0.5 st.warning(f"โš ๏ธ Using default threshold (0.5) for GCN model. Reason: {e}") # ------------------- Enhanced SMILES Processing ------------------- def preprocess_smiles(smiles): """Preprocess SMILES to handle complex cases""" if not smiles or not isinstance(smiles, str): return None, "Empty or invalid SMILES" # Remove whitespace smiles = smiles.strip() # Handle multi-component molecules (salts, etc.) if '.' in smiles: components = smiles.split('.') # Take the largest component (usually the main molecule) main_component = max(components, key=len) return main_component, f"Multi-component molecule detected. Using largest component: {main_component}" return smiles, None def safe_mol_from_smiles(smiles): """Safely create molecule from SMILES with better error handling""" try: processed_smiles, warning = preprocess_smiles(smiles) if processed_smiles is None: return None, warning mol = Chem.MolFromSmiles(processed_smiles) if mol is None: # Try to sanitize try: mol = Chem.MolFromSmiles(processed_smiles, sanitize=False) if mol is not None: Chem.SanitizeMol(mol) except: pass return mol, warning except Exception as e: return None, f"Error processing SMILES: {str(e)}" # ------------------- Utility Functions ------------------- fp_gen = GetMorganGenerator(radius=2, fpSize=1024) def get_molecule_info(mol): return { "Formula": Chem.rdMolDescriptors.CalcMolFormula(mol), "Weight": round(Descriptors.MolWt(mol), 2), "Atoms": mol.GetNumAtoms(), "Bonds": mol.GetNumBonds() } def predict_gcn(smiles): mol, warning = safe_mol_from_smiles(smiles) if mol is None: return None, None graph = smiles_to_graph(smiles) if graph is None: return None, None batch = Batch.from_data_list([graph]) with torch.no_grad(): out = gcn_model(batch) prob = torch.sigmoid(out).item() return ("Toxic" if prob > best_threshold else "Non-toxic"), prob def atom_feats(atom): return [ atom.GetAtomicNum(), atom.GetDegree(), atom.GetFormalCharge(), atom.GetNumExplicitHs(), atom.GetNumImplicitHs(), atom.GetIsAromatic(), atom.GetMass(), int(atom.IsInRing()), int(atom.GetChiralTag()), int(atom.GetHybridization()) ] def smiles_to_graph(smiles, label=None): mol, warning = safe_mol_from_smiles(smiles) if mol is None or mol.GetNumAtoms() == 0: return None atoms = [atom_feats(a) for a in mol.GetAtoms()] if not atoms: return None # No atoms present edges = [] for b in mol.GetBonds(): i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx() edges += [[i, j], [j, i]] # Handle molecules with no bonds (e.g. single atom) if len(edges) == 0: edges = [[0, 0]] edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous() x = torch.tensor(atoms, dtype=torch.float) batch = torch.zeros(x.size(0), dtype=torch.long) data = Data(x=x, edge_index=edge_index, batch=batch) if label is not None: data.y = torch.tensor([label], dtype=torch.float) return data # def predict_gcn(smiles): # graph = smiles_to_graph(smiles) # if graph is None or graph.x.size(0) == 0: # return None, None # batch = Batch.from_data_list([graph]) # with torch.no_grad(): # out = gcn_model(batch) # raw = out.item() # prob = torch.sigmoid(out).item() # print(f"Raw logit: {raw:.4f}, Prob: {prob:.4f}") # return ("Toxic" if prob > best_threshold else "Non-toxic"), prob # ------------------- Load Dataset (Enhanced Full Dataset Loading) ------------------- @st.cache_data(show_spinner="Loading full dataset...") def load_full_dataset(): """Load and preprocess the complete Tox21 dataset""" df = pd.read_csv("tox21.csv")[['smiles', 'SR-HSE']].dropna() df = df[df['SR-HSE'].isin([0, 1])].reset_index(drop=True) st.info(f"๐Ÿ“Š Loaded {len(df)} molecules from Tox21 dataset") # Validate all SMILES for both models def validate_smiles(smi): mol = Chem.MolFromSmiles(smi) if mol is None: return False, False # Check fingerprint compatibility fp_valid = True try: fp_gen.GetFingerprint(mol) except: fp_valid = False # Check GCN compatibility gcn_valid = True try: graph = smiles_to_graph(smi) if graph is None: gcn_valid = False # Check supported atoms SUPPORTED_ATOMS = {1, 6, 7, 8, 9, 16, 17, 35, 53} if not all(atom.GetAtomicNum() in SUPPORTED_ATOMS for atom in mol.GetAtoms()): gcn_valid = False except: gcn_valid = False return fp_valid, gcn_valid # Process validation with progress valid_results = [] progress_bar = st.progress(0) status_text = st.empty() for i, smi in enumerate(df['smiles']): fp_valid, gcn_valid = validate_smiles(smi) valid_results.append((fp_valid, gcn_valid)) if i % 100 == 0: progress_bar.progress((i + 1) / len(df)) status_text.text(f"Validated {i + 1}/{len(df)} molecules...") progress_bar.empty() status_text.empty() # Add validation columns df['fp_valid'] = [r[0] for r in valid_results] df['gcn_valid'] = [r[1] for r in valid_results] # Statistics fp_valid_count = df['fp_valid'].sum() gcn_valid_count = df['gcn_valid'].sum() both_valid_count = (df['fp_valid'] & df['gcn_valid']).sum() st.success(f""" โœ… **Dataset Validation Complete:** - Fingerprint Model: {fp_valid_count}/{len(df)} valid molecules ({fp_valid_count/len(df)*100:.1f}%) - GCN Model: {gcn_valid_count}/{len(df)} valid molecules ({gcn_valid_count/len(df)*100:.1f}%) - Both Models: {both_valid_count}/{len(df)} valid molecules ({both_valid_count/len(df)*100:.1f}%) """) return df @st.cache_data(show_spinner="Loading dataset...") def load_dataset(): df = pd.read_csv("tox21.csv")[['smiles', 'SR-HSE']].dropna() df = df[df['SR-HSE'].isin([0, 1])].reset_index(drop=True) # โœ… Filter invalid or unprocessable SMILES (but only for a sample to speed up) def is_valid_graph(smi): mol = Chem.MolFromSmiles(smi) return mol is not None and smiles_to_graph(smi) is not None # Only validate a subset for faster loading sample_size = min(1000, len(df)) df_sample = df.sample(n=sample_size, random_state=42) valid_indices = df_sample[df_sample['smiles'].apply(is_valid_graph)].index df_valid = df.loc[valid_indices].reset_index(drop=True) return df_valid # Load dataset lazily (sample for quick startup) df = load_dataset() # Full dataset will be loaded on-demand df_full = None @st.cache_data(show_spinner="Creating graph dataset...") def create_graph_dataset_cached(smiles_list, labels): data_list = [] for smi, label in zip(smiles_list, labels): data = smiles_to_graph(smi, label) if data: data_list.append(data) return data_list # Only create graph dataset when needed for evaluation def get_graph_data(): return create_graph_dataset_cached(df['smiles'].tolist(), df['SR-HSE'].tolist()) # ------------------- Plot Function ------------------- def plot_distribution(df_to_plot, model_type, input_prob=None): col = 'fp_prob' if model_type == 'fp' else 'gcn_prob' df_plot = df_to_plot[df_to_plot[col].notna()].copy() df_plot["Label"] = df_plot["SR-HSE"].map({0: "Non-toxic", 1: "Toxic"}) fig = px.histogram(df_plot, x=col, color="Label", nbins=30, barmode="overlay", color_discrete_map={"Non-toxic": "green", "Toxic": "red"}, title=f"{model_type.upper()} Model - Sample Distribution (n={len(df_plot)})") if input_prob: fig.add_vline(x=input_prob, line_dash="dash", line_color="yellow", annotation_text="Your Input") return fig # ------------------- Enhanced Prediction Functions ------------------- @st.cache_data(show_spinner="Generating full dataset predictions...") def get_full_predictions(model_type='fp', max_molecules=None): """Generate predictions for the full dataset""" global df_full if df_full is None: df_full = load_full_dataset() if model_type == 'fp': valid_df = df_full[df_full['fp_valid']].copy() if not fp_loaded: st.error("Fingerprint model not loaded") return None else: valid_df = df_full[df_full['gcn_valid']].copy() if not gcn_loaded: st.error("GCN model not loaded") return None if max_molecules: valid_df = valid_df.head(max_molecules) predictions = [] probabilities = [] progress_bar = st.progress(0) status_text = st.empty() for i, (_, row) in enumerate(valid_df.iterrows()): try: if model_type == 'fp': pred, prob = predict_fp(row['smiles']) else: pred, prob = predict_gcn(row['smiles']) predictions.append(pred) probabilities.append(prob) except Exception as e: predictions.append("Error") probabilities.append(None) if i % 50 == 0: progress_bar.progress((i + 1) / len(valid_df)) status_text.text(f"Processed {i + 1}/{len(valid_df)} molecules...") progress_bar.empty() status_text.empty() valid_df[f'{model_type}_prediction'] = predictions valid_df[f'{model_type}_probability'] = probabilities return valid_df @st.cache_data(show_spinner="Evaluating model performance...") def comprehensive_evaluation(model_type='fp'): """Comprehensive evaluation on full dataset""" predictions_df = get_full_predictions(model_type) if predictions_df is None: return None # Remove error predictions clean_df = predictions_df[ (predictions_df[f'{model_type}_prediction'] != "Error") & (predictions_df[f'{model_type}_probability'].notna()) ].copy() # Convert predictions to binary if model_type == 'fp': threshold = 0.5 else: threshold = best_threshold y_true = clean_df['SR-HSE'].values y_prob = clean_df[f'{model_type}_probability'].values y_pred = (y_prob > threshold).astype(int) # Calculate metrics accuracy = accuracy_score(y_true, y_pred) precision = precision_score(y_true, y_pred) recall = recall_score(y_true, y_pred) f1 = f1_score(y_true, y_pred) roc_auc = roc_auc_score(y_true, y_prob) # Confusion matrix cm = confusion_matrix(y_true, y_pred) results = { 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'roc_auc': roc_auc, 'confusion_matrix': cm, 'predictions_df': clean_df, 'threshold': threshold, 'n_samples': len(clean_df) } return results def create_comprehensive_plots(results, model_type): """Create comprehensive visualization plots""" df = results['predictions_df'] prob_col = f'{model_type}_probability' # 1. Probability Distribution fig_dist = px.histogram( df, x=prob_col, color=df['SR-HSE'].map({0: "Non-toxic", 1: "Toxic"}), nbins=50, barmode="overlay", color_discrete_map={"Non-toxic": "green", "Toxic": "red"}, title=f"{model_type.upper()} Model - Full Dataset Probability Distribution (n={len(df)})" ) fig_dist.add_vline(x=results['threshold'], line_dash="dash", line_color="black", annotation_text=f"Threshold ({results['threshold']:.3f})") # 2. ROC Curve fpr, tpr, _ = roc_curve(df['SR-HSE'], df[prob_col]) fig_roc = px.line( x=fpr, y=tpr, title=f"{model_type.upper()} Model - ROC Curve (AUC: {results['roc_auc']:.3f})", labels={'x': 'False Positive Rate', 'y': 'True Positive Rate'} ) fig_roc.add_scatter(x=[0, 1], y=[0, 1], mode='lines', line=dict(dash='dash', color='gray'), name='Random Classifier', showlegend=False) # 3. Confusion Matrix cm = results['confusion_matrix'] fig_cm = px.imshow( cm, text_auto=True, aspect="auto", title="Confusion Matrix", labels=dict(x="Predicted", y="Actual"), x=['Non-toxic', 'Toxic'], y=['Non-toxic', 'Toxic'] ) return fig_dist, fig_roc, fig_cm def show_detailed_statistics(results, model_type): """Show detailed performance statistics""" col1, col2, col3 = st.columns(3) with col1: st.metric("Accuracy", f"{results['accuracy']:.3f}") st.metric("Precision", f"{results['precision']:.3f}") with col2: st.metric("Recall", f"{results['recall']:.3f}") st.metric("F1-Score", f"{results['f1']:.3f}") with col3: st.metric("ROC-AUC", f"{results['roc_auc']:.3f}") st.metric("Samples", f"{results['n_samples']:,}") # Confusion Matrix Details cm = results['confusion_matrix'] tn, fp, fn, tp = cm.ravel() st.subheader("Detailed Classification Results") col1, col2 = st.columns(2) with col1: st.write("**True Negatives (Correct Non-toxic):**", tn) st.write("**False Positives (Incorrect Toxic):**", fp) with col2: st.write("**False Negatives (Missed Toxic):**", fn) st.write("**True Positives (Correct Toxic):**", tp) # Calculate additional metrics specificity = tn / (tn + fp) if (tn + fp) > 0 else 0 sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0 st.write(f"**Specificity (True Negative Rate):** {specificity:.3f}") st.write(f"**Sensitivity (True Positive Rate):** {sensitivity:.3f}") # ------------------- Advanced Analytics ------------------- # ------------------- Model Performance Insights Functions ------------------- def generate_single_model_insights(model_type, sample_size): """Generate performance insights for a single model""" model_name = "Fingerprint" if model_type == 'fp' else "GCN" threshold = 0.5 if model_type == 'fp' else best_threshold # Sample data for analysis sample_data = df.sample(n=min(sample_size, len(df))) predictions = [] probabilities = [] confidences = [] valid_count = 0 for _, row in sample_data.iterrows(): try: if model_type == 'fp': pred, prob = predict_fp(row['smiles']) else: pred, prob = predict_gcn(row['smiles']) if pred and prob is not None: predictions.append(pred) probabilities.append(prob) confidences.append(max(prob, 1-prob) if threshold == 0.5 else max(prob - threshold, threshold - prob) + threshold) valid_count += 1 except: predictions.append(None) probabilities.append(None) confidences.append(None) if valid_count == 0: st.error(f"โŒ No valid predictions generated for {model_name} model") return # Create analysis dataframe analysis_df = sample_data.copy() analysis_df['Predicted'] = predictions analysis_df['Probability'] = probabilities analysis_df['Confidence'] = confidences analysis_df['Toxicity_Label'] = analysis_df['SR-HSE'].map({0: 'Non-Toxic', 1: 'Toxic'}) # Filter valid predictions valid_df = analysis_df[analysis_df['Predicted'].notna()].copy() if len(valid_df) == 0: st.error(f"โŒ No valid predictions for analysis") return st.success(f"โœ… **{model_name} Model Insights** - Analyzed {len(valid_df)} valid predictions") # Performance metrics col1, col2, col3, col4 = st.columns(4) # Calculate basic metrics y_true = valid_df['SR-HSE'].values y_pred_binary = (valid_df['Probability'] > threshold).astype(int) y_prob = valid_df['Probability'].values accuracy = accuracy_score(y_true, y_pred_binary) precision = precision_score(y_true, y_pred_binary) recall = recall_score(y_true, y_pred_binary) roc_auc = roc_auc_score(y_true, y_prob) with col1: st.metric("Accuracy", f"{accuracy:.3f}") with col2: st.metric("Precision", f"{precision:.3f}") with col3: st.metric("Recall", f"{recall:.3f}") with col4: st.metric("ROC-AUC", f"{roc_auc:.3f}") # Visualization col1, col2 = st.columns(2) with col1: # Confidence distribution fig_conf = px.histogram( valid_df, x='Confidence', color='Predicted', title=f"{model_name} Model - Confidence Distribution", nbins=20, color_discrete_map={'Toxic': 'red', 'Non-toxic': 'green'} ) st.plotly_chart(fig_conf, use_container_width=True) with col2: # Probability distribution by true label fig_prob = px.histogram( valid_df, x='Probability', color='Toxicity_Label', title=f"{model_name} Model - Probability by True Label", nbins=20, color_discrete_map={'Toxic': 'red', 'Non-Toxic': 'green'} ) fig_prob.add_vline(x=threshold, line_dash="dash", annotation_text=f"Threshold ({threshold})") st.plotly_chart(fig_prob, use_container_width=True) # Confidence-based accuracy analysis valid_df['Correct'] = (valid_df['Predicted'] == valid_df['Toxicity_Label']) # Binned confidence analysis valid_df['Confidence_Bin'] = pd.cut(valid_df['Confidence'], bins=5, labels=['Very Low', 'Low', 'Medium', 'High', 'Very High']) conf_analysis = valid_df.groupby('Confidence_Bin').agg({ 'Correct': ['mean', 'count'], 'Confidence': 'mean' }).round(3) st.subheader(f"๐ŸŽฏ {model_name} Model Reliability Analysis") # High/Low confidence metrics high_conf_mask = valid_df['Confidence'] > 0.8 low_conf_mask = valid_df['Confidence'] < 0.6 high_conf_accuracy = valid_df[high_conf_mask]['Correct'].mean() if high_conf_mask.any() else 0 low_conf_accuracy = valid_df[low_conf_mask]['Correct'].mean() if low_conf_mask.any() else 0 avg_confidence = valid_df['Confidence'].mean() col1, col2, col3 = st.columns(3) with col1: st.metric("High Confidence Accuracy (>0.8)", f"{high_conf_accuracy:.1%}", f"{high_conf_mask.sum()} samples") with col2: st.metric("Low Confidence Accuracy (<0.6)", f"{low_conf_accuracy:.1%}", f"{low_conf_mask.sum()} samples") with col3: st.metric("Average Confidence", f"{avg_confidence:.3f}") # Model-specific insights st.markdown(f""" **๐Ÿ’ก {model_name} Model Insights:** **Strengths:** - {"High precision for confident predictions" if high_conf_accuracy > 0.9 else "Moderate precision for confident predictions"} - {"Good calibration between confidence and accuracy" if abs(high_conf_accuracy - low_conf_accuracy) > 0.2 else "Confidence and accuracy correlation needs improvement"} **Recommendations:** - {"Review low-confidence predictions manually" if low_conf_mask.sum() > 10 else "Most predictions are high confidence"} - {"Model shows good reliability" if accuracy > 0.8 else "Consider model improvement or ensemble methods"} - {"Threshold optimization may improve performance" if abs(precision - recall) > 0.1 else "Balanced precision-recall performance"} **Model Characteristics:** - **Decision Threshold:** {threshold} - **Feature Space:** {"1024-dimensional molecular fingerprints" if model_type == 'fp' else "128-dimensional graph neural network features"} - **Architecture:** {"Feed-forward neural network with dropout" if model_type == 'fp' else "Graph convolutional network with batch normalization"} """) def generate_dual_model_insights(sample_size): """Generate comparative insights between both models""" st.success("โœ… **Dual Model Comparison Analysis**") # Sample data for analysis sample_data = df.sample(n=min(sample_size, len(df))) fp_predictions = [] fp_probabilities = [] gcn_predictions = [] gcn_probabilities = [] valid_indices = [] for idx, (_, row) in enumerate(sample_data.iterrows()): try: fp_pred, fp_prob = predict_fp(row['smiles']) gcn_pred, gcn_prob = predict_gcn(row['smiles']) if fp_pred and fp_prob is not None and gcn_pred and gcn_prob is not None: fp_predictions.append(fp_pred) fp_probabilities.append(fp_prob) gcn_predictions.append(gcn_pred) gcn_probabilities.append(gcn_prob) valid_indices.append(idx) except: continue if len(valid_indices) == 0: st.error("โŒ No valid predictions from both models for comparison") return # Create comparison dataframe valid_sample = sample_data.iloc[valid_indices].copy() valid_sample['FP_Prediction'] = fp_predictions valid_sample['FP_Probability'] = fp_probabilities valid_sample['GCN_Prediction'] = gcn_predictions valid_sample['GCN_Probability'] = gcn_probabilities valid_sample['Toxicity_Label'] = valid_sample['SR-HSE'].map({0: 'Non-Toxic', 1: 'Toxic'}) st.info(f"๐Ÿ“Š Comparing {len(valid_sample)} molecules across both models") # Model agreement analysis agreement = (valid_sample['FP_Prediction'] == valid_sample['GCN_Prediction']).mean() # Calculate individual model accuracies fp_accuracy = accuracy_score(valid_sample['SR-HSE'], (valid_sample['FP_Probability'] > 0.5).astype(int)) gcn_accuracy = accuracy_score(valid_sample['SR-HSE'], (valid_sample['GCN_Probability'] > best_threshold).astype(int)) # Correlation between probabilities prob_correlation = np.corrcoef(valid_sample['FP_Probability'], valid_sample['GCN_Probability'])[0, 1] # Display comparison metrics col1, col2, col3, col4 = st.columns(4) with col1: st.metric("Model Agreement", f"{agreement:.1%}") with col2: st.metric("FP Accuracy", f"{fp_accuracy:.3f}") with col3: st.metric("GCN Accuracy", f"{gcn_accuracy:.3f}") with col4: st.metric("Probability Correlation", f"{prob_correlation:.3f}") # Visualizations col1, col2 = st.columns(2) with col1: # Probability correlation scatter plot fig_corr = px.scatter( valid_sample, x='FP_Probability', y='GCN_Probability', color='Toxicity_Label', title="Model Probability Correlation", labels={'x': 'Fingerprint Probability', 'y': 'GCN Probability'}, color_discrete_map={'Toxic': 'red', 'Non-Toxic': 'green'} ) fig_corr.add_scatter(x=[0, 1], y=[0, 1], mode='lines', line=dict(dash='dash', color='gray'), name='Perfect Correlation', showlegend=False) st.plotly_chart(fig_corr, use_container_width=True) with col2: # Agreement by prediction confidence valid_sample['Agreement'] = valid_sample['FP_Prediction'] == valid_sample['GCN_Prediction'] valid_sample['Avg_Confidence'] = (valid_sample['FP_Probability'].apply(lambda x: max(x, 1-x)) + valid_sample['GCN_Probability'].apply(lambda x: max(x, 1-x))) / 2 conf_bins = pd.cut(valid_sample['Avg_Confidence'], bins=5, labels=['Very Low', 'Low', 'Medium', 'High', 'Very High']) agreement_by_conf = valid_sample.groupby(conf_bins)['Agreement'].mean() fig_agreement = px.bar( x=agreement_by_conf.index.astype(str), y=agreement_by_conf.values, title="Model Agreement by Confidence Level", labels={'x': 'Average Confidence Bin', 'y': 'Agreement Rate'} ) st.plotly_chart(fig_agreement, use_container_width=True) # Disagreement analysis disagreement_cases = valid_sample[~valid_sample['Agreement']].copy() st.subheader("๐Ÿ” Model Disagreement Analysis") if len(disagreement_cases) > 0: st.write(f"**Found {len(disagreement_cases)} disagreement cases ({len(disagreement_cases)/len(valid_sample)*100:.1f}% of predictions)**") # Show some example disagreements st.write("**Example Disagreement Cases:**") display_cols = ['smiles', 'FP_Prediction', 'FP_Probability', 'GCN_Prediction', 'GCN_Probability', 'Toxicity_Label'] st.dataframe(disagreement_cases[display_cols].head(5)) # Disagreement patterns fp_toxic_gcn_nontoxic = ((disagreement_cases['FP_Prediction'] == 'Toxic') & (disagreement_cases['GCN_Prediction'] == 'Non-toxic')).sum() gcn_toxic_fp_nontoxic = ((disagreement_cases['GCN_Prediction'] == 'Toxic') & (disagreement_cases['FP_Prediction'] == 'Non-toxic')).sum() st.write(f"**Disagreement Patterns:**") st.write(f"- FP says Toxic, GCN says Non-toxic: {fp_toxic_gcn_nontoxic} cases") st.write(f"- GCN says Toxic, FP says Non-toxic: {gcn_toxic_fp_nontoxic} cases") else: st.success("๐ŸŽ‰ Perfect agreement between models on this sample!") # Ensemble performance st.subheader("๐Ÿค Ensemble Analysis") # Simple ensemble: average probabilities ensemble_prob = (valid_sample['FP_Probability'] + valid_sample['GCN_Probability']) / 2 ensemble_pred = (ensemble_prob > 0.5).astype(int) ensemble_accuracy = accuracy_score(valid_sample['SR-HSE'], ensemble_pred) # Voting ensemble fp_votes = (valid_sample['FP_Probability'] > 0.5).astype(int) gcn_votes = (valid_sample['GCN_Probability'] > best_threshold).astype(int) voting_pred = (fp_votes + gcn_votes >= 1).astype(int) # At least one model predicts toxic voting_accuracy = accuracy_score(valid_sample['SR-HSE'], voting_pred) col1, col2, col3 = st.columns(3) with col1: st.metric("Average Ensemble Accuracy", f"{ensemble_accuracy:.3f}") with col2: st.metric("Voting Ensemble Accuracy", f"{voting_accuracy:.3f}") with col3: best_single = max(fp_accuracy, gcn_accuracy) improvement = max(ensemble_accuracy, voting_accuracy) - best_single st.metric("Best Improvement", f"+{improvement:.3f}" if improvement > 0 else f"{improvement:.3f}") st.markdown(""" **๐Ÿ’ก Dual Model Insights:** **Model Complementarity:** - **Fingerprint Model**: Excels at chemical similarity patterns, traditional QSAR relationships - **GCN Model**: Captures structural topology, graph-based molecular patterns - **Together**: Provide comprehensive molecular understanding **Ensemble Recommendations:** - Use **voting ensemble** for high-stakes decisions (conservative approach) - Use **average ensemble** for balanced sensitivity-specificity - Manual review recommended for disagreement cases **Best Practices:** - High agreement cases: Trust the prediction - Disagreement cases: Consider additional validation - Low confidence from both: Seek experimental validation """) @st.cache_data(show_spinner="Generating SHAP explanations...") def get_shap_explanations(smiles_list, model_type='fp', num_samples=100): """Generate SHAP explanations for model interpretability""" if not SHAP_AVAILABLE: return None, "SHAP library not available. Install with: pip install shap" try: if model_type == 'fp': # Create a sample of fingerprints for SHAP analysis sample_smiles = smiles_list[:num_samples] fingerprints = [] for smiles in sample_smiles: mol, _ = safe_mol_from_smiles(smiles) if mol is not None: fp = fp_gen.GetFingerprint(mol) fingerprints.append(np.array(fp)) if len(fingerprints) == 0: return None, "No valid molecules for SHAP analysis" X = np.array(fingerprints) # Create SHAP explainer for fingerprint model def model_predict(X): with torch.no_grad(): preds = [] for fp in X: tensor = torch.tensor(fp.reshape(1, -1)).float() pred = torch.sigmoid(fp_model(tensor)).item() preds.append(pred) return np.array(preds) # Use a subset for background background = X[:min(10, len(X))] explainer = shap.Explainer(model_predict, background) shap_values = explainer(X[:min(20, len(X))]) return shap_values, None elif model_type == 'gcn': # Create a sample of GCN features for SHAP analysis if not gcn_loaded: return None, "GCN model not loaded" sample_smiles = smiles_list[:num_samples] gcn_features, valid_smiles = extract_gcn_features(sample_smiles) if len(gcn_features) == 0: return None, "No valid molecules for GCN SHAP analysis" X = np.array(gcn_features) # Create SHAP explainer for GCN model using extracted features def gcn_model_predict(X): with torch.no_grad(): preds = [] for features in X: # Convert features back to a format the model can use # We'll use the extracted features directly for explanation # This represents the contribution of the learned features tensor = torch.tensor(features.reshape(1, -1)).float() # Apply the final layers of GCN model (fc1 -> fc2) x = torch.relu(gcn_model.fc1(tensor)) pred = torch.sigmoid(gcn_model.fc2(x)).item() preds.append(pred) return np.array(preds) # Use a subset for background background = X[:min(10, len(X))] explainer = shap.Explainer(gcn_model_predict, background) shap_values = explainer(X[:min(20, len(X))]) return shap_values, None except Exception as e: return None, f"SHAP analysis failed: {str(e)}" def extract_gcn_features(smiles_list): """Extract features from GCN model's hidden layers""" features = [] valid_smiles = [] for smiles in smiles_list: try: mol, _ = safe_mol_from_smiles(smiles) if mol is None: continue # Check if molecule is supported by GCN SUPPORTED_ATOMS = {1, 6, 7, 8, 9, 16, 17, 35, 53} if not all(atom.GetAtomicNum() in SUPPORTED_ATOMS for atom in mol.GetAtoms()): continue graph = smiles_to_graph(smiles) if graph is None: continue batch = Batch.from_data_list([graph]) # Extract features from GCN model's hidden layers with torch.no_grad(): # Get the second-to-last layer features (before final classification) x = batch.x edge_index = batch.edge_index batch_idx = batch.batch # Forward through GCN layers x = gcn_model.conv1(x, edge_index) x = torch.relu(x) x = gcn_model.conv2(x, edge_index) x = torch.relu(x) # Global pooling to get graph-level features graph_features = global_mean_pool(x, batch_idx) features.append(graph_features.cpu().numpy().flatten()) valid_smiles.append(smiles) except Exception: continue return features, valid_smiles @st.cache_data(show_spinner="Performing PCA analysis...") def perform_pca_analysis(df_sample, model_type='fp'): """Perform Principal Component Analysis on molecular features""" try: features = [] labels = [] smiles_list = [] if model_type == 'fp': # Use molecular fingerprints for _, row in df_sample.iterrows(): mol, _ = safe_mol_from_smiles(row['smiles']) if mol is not None: fp = fp_gen.GetFingerprint(mol) features.append(np.array(fp)) labels.append(row['SR-HSE']) smiles_list.append(row['smiles']) elif model_type == 'gcn': # Use GCN learned features if not gcn_loaded: return None, "GCN model not loaded" sample_smiles = df_sample['smiles'].tolist() gcn_features, valid_smiles = extract_gcn_features(sample_smiles) # Map back to labels smiles_to_label = dict(zip(df_sample['smiles'], df_sample['SR-HSE'])) for i, smiles in enumerate(valid_smiles): features.append(gcn_features[i]) labels.append(smiles_to_label[smiles]) smiles_list.append(smiles) if len(features) < 10: return None, f"Insufficient valid molecules for PCA (got {len(features)}, need at least 10)" X = np.array(features) # Perform PCA pca = PCA(n_components=2) X_pca = pca.fit_transform(X) # Create DataFrame for plotting pca_df = pd.DataFrame({ 'PC1': X_pca[:, 0], 'PC2': X_pca[:, 1], 'Label': ['Toxic' if l == 1 else 'Non-toxic' for l in labels], 'SMILES': smiles_list }) explained_variance = pca.explained_variance_ratio_ return { 'pca_df': pca_df, 'explained_variance': explained_variance, 'pca_model': pca, 'feature_type': 'Fingerprints' if model_type == 'fp' else 'GCN Features' }, None except Exception as e: return None, f"PCA analysis failed: {str(e)}" @st.cache_data(show_spinner="Performing clustering analysis...") def perform_clustering_analysis(df_sample, n_clusters=3, model_type='fp'): """Perform K-means clustering on molecular features""" try: features = [] labels = [] smiles_list = [] if model_type == 'fp': # Use molecular fingerprints for _, row in df_sample.iterrows(): mol, _ = safe_mol_from_smiles(row['smiles']) if mol is not None: fp = fp_gen.GetFingerprint(mol) features.append(np.array(fp)) labels.append(row['SR-HSE']) smiles_list.append(row['smiles']) elif model_type == 'gcn': # Use GCN learned features if not gcn_loaded: return None, "GCN model not loaded" sample_smiles = df_sample['smiles'].tolist() gcn_features, valid_smiles = extract_gcn_features(sample_smiles) # Map back to labels smiles_to_label = dict(zip(df_sample['smiles'], df_sample['SR-HSE'])) for i, smiles in enumerate(valid_smiles): features.append(gcn_features[i]) labels.append(smiles_to_label[smiles]) smiles_list.append(smiles) if len(features) < n_clusters: return None, f"Need at least {n_clusters} valid molecules for clustering (got {len(features)})" X = np.array(features) # Perform K-means clustering kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10) cluster_labels = kmeans.fit_predict(X) # Perform t-SNE for visualization tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(X)-1)) X_tsne = tsne.fit_transform(X) # Create DataFrame for plotting cluster_df = pd.DataFrame({ 'tSNE1': X_tsne[:, 0], 'tSNE2': X_tsne[:, 1], 'Cluster': [f'Cluster {i}' for i in cluster_labels], 'Toxicity': ['Toxic' if l == 1 else 'Non-toxic' for l in labels], 'SMILES': smiles_list }) return { 'cluster_df': cluster_df, 'kmeans_model': kmeans, 'feature_type': 'Fingerprints' if model_type == 'fp' else 'GCN Features' }, None except Exception as e: return None, f"Clustering analysis failed: {str(e)}" @st.cache_data(show_spinner="Finding similar molecules...") def find_similar_molecules(query_smiles, df_sample, top_k=5, model_type='fp'): """Find molecules similar to the query molecule based on feature similarity""" try: if model_type == 'fp': # Use fingerprint similarity query_mol, _ = safe_mol_from_smiles(query_smiles) if query_mol is None: return None, "Invalid query SMILES" query_fp = np.array(fp_gen.GetFingerprint(query_mol)).reshape(1, -1) similarities = [] valid_molecules = [] for _, row in df_sample.iterrows(): mol, _ = safe_mol_from_smiles(row['smiles']) if mol is not None: fp = np.array(fp_gen.GetFingerprint(mol)).reshape(1, -1) similarity = cosine_similarity(query_fp, fp)[0][0] similarities.append(similarity) valid_molecules.append(row) elif model_type == 'gcn': # Use GCN feature similarity if not gcn_loaded: return None, "GCN model not loaded" # Get query features query_features, query_valid = extract_gcn_features([query_smiles]) if len(query_features) == 0: return None, "Could not extract GCN features for query molecule" query_feature = np.array(query_features[0]).reshape(1, -1) # Get features for all molecules in sample sample_smiles = df_sample['smiles'].tolist() sample_features, valid_smiles = extract_gcn_features(sample_smiles) if len(sample_features) == 0: return None, "No valid molecules found for GCN similarity search" # Map back to original data smiles_to_data = dict(zip(df_sample['smiles'], df_sample.to_dict('records'))) similarities = [] valid_molecules = [] for i, smiles in enumerate(valid_smiles): feature = np.array(sample_features[i]).reshape(1, -1) similarity = cosine_similarity(query_feature, feature)[0][0] similarities.append(similarity) valid_molecules.append(smiles_to_data[smiles]) if len(similarities) == 0: return None, "No valid molecules for comparison" # Get top similar molecules similarity_df = pd.DataFrame(valid_molecules) similarity_df['Similarity'] = similarities similarity_df = similarity_df.sort_values('Similarity', ascending=False) return similarity_df.head(top_k), None except Exception as e: return None, f"Similarity search failed: {str(e)}" def uncertainty_quantification(smiles, model_type='fp', num_samples=100): """Estimate prediction uncertainty using Monte Carlo dropout""" try: mol, _ = safe_mol_from_smiles(smiles) if mol is None: return None, None, None, "Invalid SMILES" if model_type == 'fp': fp = fp_gen.GetFingerprint(mol) fp_array = np.array(fp).reshape(1, -1) tensor = torch.tensor(fp_array).float() # Enable training mode for Monte Carlo dropout fp_model.train() predictions = [] for _ in range(num_samples): with torch.no_grad(): logits = fp_model(tensor) prob = torch.sigmoid(logits).item() predictions.append(prob) fp_model.eval() # Back to evaluation mode mean_pred = np.mean(predictions) std_pred = np.std(predictions) confidence_interval = ( np.percentile(predictions, 2.5), np.percentile(predictions, 97.5) ) return mean_pred, std_pred, confidence_interval, None elif model_type == 'gcn': # Check if molecule is supported by GCN SUPPORTED_ATOMS = {1, 6, 7, 8, 9, 16, 17, 35, 53} if not all(atom.GetAtomicNum() in SUPPORTED_ATOMS for atom in mol.GetAtoms()): return None, None, None, "Molecule contains unsupported atoms for GCN model" graph = smiles_to_graph(smiles) if graph is None: return None, None, None, "Could not convert SMILES to graph" batch = Batch.from_data_list([graph]) # Enable training mode for Monte Carlo dropout gcn_model.train() predictions = [] for _ in range(num_samples): with torch.no_grad(): logits = gcn_model(batch) prob = torch.sigmoid(logits).item() predictions.append(prob) gcn_model.eval() # Back to evaluation mode mean_pred = np.mean(predictions) std_pred = np.std(predictions) confidence_interval = ( np.percentile(predictions, 2.5), np.percentile(predictions, 97.5) ) return mean_pred, std_pred, confidence_interval, None except Exception as e: return None, None, None, f"Uncertainty quantification failed: {str(e)}" @st.cache_data(show_spinner="Processing batch predictions...") def process_batch_predictions(uploaded_df, model_type='fp'): """Process batch predictions from uploaded CSV with optimizations for large datasets""" try: results = [] total_count = len(uploaded_df) # Create progress tracking progress_container = st.container() progress_bar = progress_container.progress(0) status_text = progress_container.empty() # Process in chunks for better memory management and progress updates chunk_size = min(100, max(10, total_count // 20)) # Adaptive chunk size start_time = time.time() for i, (_, row) in enumerate(uploaded_df.iterrows()): smiles = row.get('smiles', row.get('SMILES', '')) if model_type == 'fp': pred, prob = predict_fp(smiles) # Add uncertainty quantification mean_pred, std_pred, ci, _ = uncertainty_quantification(smiles, 'fp') results.append({ 'SMILES': smiles, 'Prediction': pred, 'Probability': prob, 'Mean_Probability': mean_pred, 'Std_Probability': std_pred, 'CI_Lower': ci[0] if ci else None, 'CI_Upper': ci[1] if ci else None }) else: # GCN model pred, prob = predict_gcn(smiles) # Add uncertainty quantification for GCN mean_pred, std_pred, ci, _ = uncertainty_quantification(smiles, 'gcn') results.append({ 'SMILES': smiles, 'Prediction': pred, 'Probability': prob, 'Mean_Probability': mean_pred, 'Std_Probability': std_pred, 'CI_Lower': ci[0] if ci else None, 'CI_Upper': ci[1] if ci else None }) # Update progress every chunk_size molecules if (i + 1) % chunk_size == 0 or i == total_count - 1: progress = (i + 1) / total_count progress_bar.progress(progress) # Calculate ETA elapsed_time = time.time() - start_time if i > 0: avg_time_per_molecule = elapsed_time / (i + 1) remaining_molecules = total_count - (i + 1) eta_seconds = avg_time_per_molecule * remaining_molecules eta_minutes = eta_seconds / 60 status_text.text(f"Processed {i+1:,}/{total_count:,} molecules ({progress:.1%}) - ETA: {eta_minutes:.1f} min") else: status_text.text(f"Processed {i+1:,}/{total_count:,} molecules ({progress:.1%})") # Clean up progress indicators progress_bar.empty() status_text.empty() return pd.DataFrame(results), None except Exception as e: return None, f"Batch processing failed: {str(e)}" # ------------------- Prediction Cache ------------------- @st.cache_data(show_spinner="Generating predictions...") def predict_fp(smiles): try: mol, warning = safe_mol_from_smiles(smiles) if mol is None: return f"Invalid SMILES: {warning}", 0.0 fp = fp_gen.GetFingerprint(mol) fp_array = np.array(fp).reshape(1, -1) with torch.no_grad(): logits = fp_model(torch.tensor(fp_array).float()) prob = torch.sigmoid(logits).item() return ("Toxic" if prob > 0.5 else "Non-toxic"), prob except Exception as e: return f"Error: {str(e)}", 0.0 @st.cache_data(show_spinner="Generating sample predictions...") def get_sample_predictions(model_type='fp', sample_size=100): """Generate predictions for a small sample of the dataset for visualization""" sample_df = df.sample(n=min(sample_size, len(df)), random_state=42) preds = [] for smi in sample_df['smiles']: try: p = predict_fp(smi)[1] if model_type == 'fp' else predict_gcn(smi)[1] preds.append(p) except: preds.append(None) sample_df = sample_df.copy() sample_df[f'{model_type}_prob'] = preds return sample_df # Only generate sample predictions for visualization (much faster) sample_df_fp = get_sample_predictions('fp', 100) if fp_loaded else None sample_df_gcn = get_sample_predictions('gcn', 100) if gcn_loaded else None # ------------------- Evaluation Function ------------------- @st.cache_data(show_spinner="Evaluating model on test set...") def evaluate_gcn_test_set(model): graph_data = get_graph_data() test_loader = DataLoader(graph_data, batch_size=32) model.eval() all_preds, all_labels = [], [] with torch.no_grad(): for batch in test_loader: batch = batch.to("cpu") # Ensure on CPU out = model(batch) probs = torch.sigmoid(out) all_preds.extend(probs.cpu().numpy()) all_labels.extend(batch.y.cpu().numpy()) acc = accuracy_score(all_labels, (np.array(all_preds) > 0.5).astype(int)) roc = roc_auc_score(all_labels, all_preds) df_eval = pd.DataFrame({ "Predicted Probability": all_preds, "Label": ["Non-toxic" if i == 0 else "Toxic" for i in all_labels] }) fig = px.histogram(df_eval, x="Predicted Probability", color="Label", nbins=30, barmode="overlay", color_discrete_map={"Non-toxic": "green", "Toxic": "red"}, title="GCN Test Set - Probability Distribution") fig.update_layout(bargap=0.1) st.success(f"โœ… Accuracy: `{acc:.4f}`, ROC-AUC: `{roc:.4f}`") st.plotly_chart(fig, use_container_width=True) # ------------------- Tabs ------------------- tab1, tab2, tab3, tab4, tab5 = st.tabs([ "๐Ÿ”ฌ Fingerprint Model", "๐Ÿงฌ GCN Model", "โš–๏ธ Model Comparison", "๐Ÿง  AI Interpretability", "๐Ÿ“Š Advanced Analytics" ]) with tab1: st.subheader("Fingerprint-based Prediction") # Add full dataset evaluation section with st.expander("๐Ÿ“Š Full Dataset Performance Analysis"): if st.button("๐Ÿš€ Run Full Dataset Evaluation (Fingerprint)", key="fp_eval"): with st.spinner("Evaluating fingerprint model on full dataset..."): fp_results = comprehensive_evaluation('fp') if fp_results: show_detailed_statistics(fp_results, 'fp') # Create and show plots fig_dist, fig_roc, fig_cm = create_comprehensive_plots(fp_results, 'fp') st.plotly_chart(fig_dist, use_container_width=True) col1, col2 = st.columns(2) with col1: st.plotly_chart(fig_roc, use_container_width=True) with col2: st.plotly_chart(fig_cm, use_container_width=True) # Download predictions csv_data = fp_results['predictions_df'].to_csv(index=False) st.download_button( "๐Ÿ“ฅ Download Full Predictions", csv_data, "fingerprint_full_predictions.csv", "text/csv" ) with st.form("fp_form"): smiles_fp = st.text_input("Enter SMILES", "CCO") show_debug_fp = st.checkbox("๐Ÿž Show Debug Info (raw score/logit)", key="fp_debug") predict_btn = st.form_submit_button("๐Ÿ” Predict") if predict_btn: mol, warning = safe_mol_from_smiles(smiles_fp) if warning: st.info(f"โ„น๏ธ {warning}") if mol: fp = fp_gen.GetFingerprint(mol) arr = np.array(fp).reshape(1, -1) tensor = torch.tensor(arr).float() with torch.no_grad(): output = fp_model(tensor) prob = torch.sigmoid(output).item() raw_score = output.item() label = "Toxic" if prob > 0.5 else "Non-toxic" color = "red" if label == "Toxic" else "green" st.markdown(f"

๐Ÿงพ Prediction: {label} โ€” {prob:.3f}

", unsafe_allow_html=True) if show_debug_fp: st.code(f"๐Ÿ“‰ Raw Logit: {raw_score:.4f}", language='text') st.markdown("#### Fingerprint Vector (First 20 bits)") st.code(str(arr[0][:20]) + " ...", language="text") st.image(Draw.MolToImage(mol), caption="Molecular Structure", width=250) info = get_molecule_info(mol) st.markdown("### Molecule Info:") for k, v in info.items(): st.markdown(f"**{k}:** {v}") st.plotly_chart(plot_distribution(sample_df_fp if sample_df_fp is not None else df, 'fp', prob), use_container_width=True) else: st.error("โŒ Invalid SMILES input. Please check your string.") with st.expander("๐Ÿ“Œ Example SMILES to Try"): st.markdown(""" - `CCO` (Ethanol) - `CC(=O)O` (Acetic Acid) - `c1ccccc1` (Benzene) - `CCN(CC)CC` (Triethylamine) - `C1=CC=CN=C1` (Pyridine) **Note:** For complex salts/ionic compounds (containing `.` or charges), the app will automatically extract and analyze the largest component. """) with st.expander("โš—๏ธ Complex SMILES Examples"): st.markdown(""" **Ionic Liquids & Salts:** - `CCCCCCn1ccnc1` (Hexylimidazole - main component only) - `c1ccc(cc1)S(=O)(=O)O` (Benzenesulfonic acid) - `CC[N+](C)(C)C` (Tetramethylammonium cation) **Multi-component molecules:** The app extracts the largest molecular component for prediction. """) with st.expander("๐Ÿงช Top 5 Toxic Predictions from Test Set (Fingerprint Model)"): if sample_df_fp is not None and 'fp_prob' in sample_df_fp: top_toxic_fp = sample_df_fp[sample_df_fp['fp_prob'] > 0.5].sort_values('fp_prob', ascending=False) def is_valid_fp(smi): return Chem.MolFromSmiles(smi) is not None top_toxic_fp = top_toxic_fp[top_toxic_fp['smiles'].apply(is_valid_fp)].head(5) if not top_toxic_fp.empty: st.table(top_toxic_fp[['smiles', 'fp_prob']].rename(columns={'fp_prob': 'Predicted Probability'})) else: st.info("No valid top fingerprint predictions available in sample.") else: st.info("Fingerprint model predictions not available.") with tab2: st.subheader("Graph Neural Network Prediction") SUPPORTED_ATOMS = {1, 6, 7, 8, 9, 16, 17, 35, 53} # H, C, N, O, F, S, Cl, Br, I def is_supported(mol): return all(atom.GetAtomicNum() in SUPPORTED_ATOMS for atom in mol.GetAtoms()) # Add full dataset evaluation section with st.expander("๐Ÿ“Š Full Dataset Performance Analysis"): if st.button("๐Ÿš€ Run Full Dataset Evaluation (GCN)", key="gcn_eval"): with st.spinner("Evaluating GCN model on full dataset..."): gcn_results = comprehensive_evaluation('gcn') if gcn_results: show_detailed_statistics(gcn_results, 'gcn') # Create and show plots fig_dist, fig_roc, fig_cm = create_comprehensive_plots(gcn_results, 'gcn') st.plotly_chart(fig_dist, use_container_width=True) col1, col2 = st.columns(2) with col1: st.plotly_chart(fig_roc, use_container_width=True) with col2: st.plotly_chart(fig_cm, use_container_width=True) # Download predictions csv_data = gcn_results['predictions_df'].to_csv(index=False) st.download_button( "๐Ÿ“ฅ Download Full Predictions", csv_data, "gcn_full_predictions.csv", "text/csv" ) with st.form("gcn_form"): smiles_gcn = st.text_input("Enter SMILES", "c1ccccc1", key="gcn_smiles") show_debug = st.checkbox("๐Ÿž Show Debug Info (raw score/logit)") gcn_btn = st.form_submit_button("๐Ÿ” Predict") if gcn_btn: mol, warning = safe_mol_from_smiles(smiles_gcn) if warning: st.info(f"โ„น๏ธ {warning}") if mol is None: st.error("โŒ Invalid SMILES: could not parse molecule.") elif not is_supported(mol): st.error("โš ๏ธ This molecule contains unsupported atoms (e.g. Sn, P, etc.). GCN model only supports common organic elements.") else: graph = smiles_to_graph(smiles_gcn) if graph is None: st.error("โŒ SMILES is valid but could not be converted to graph. Possibly malformed structure.") else: batch = Batch.from_data_list([graph]) with torch.no_grad(): out = gcn_model(batch) prob = torch.sigmoid(out).item() raw_score = out.item() label = "Toxic" if prob > best_threshold else "Non-toxic" color = "red" if label == "Toxic" else "green" st.markdown(f"

๐Ÿงพ GCN Prediction: {label} โ€” {prob:.3f}

", unsafe_allow_html=True) if show_debug: st.code(f"๐Ÿ“‰ Raw Logit: {raw_score:.4f}", language='text') st.image(Draw.MolToImage(mol), caption="Molecular Structure", width=250) def get_molecule_info(mol): return { "Molecular Weight": round(Chem.Descriptors.MolWt(mol), 2), "LogP": round(Chem.Crippen.MolLogP(mol), 2), "Num H-Bond Donors": Chem.Lipinski.NumHDonors(mol), "Num H-Bond Acceptors": Chem.Lipinski.NumHAcceptors(mol), "TPSA": round(Chem.rdMolDescriptors.CalcTPSA(mol), 2), "Num Rotatable Bonds": Chem.Lipinski.NumRotatableBonds(mol) } info = get_molecule_info(mol) st.markdown("### Molecule Info:") for k, v in info.items(): st.markdown(f"**{k}:** {v}") st.plotly_chart(plot_distribution(sample_df_gcn if sample_df_gcn is not None else df, 'gcn', prob), use_container_width=True) with st.expander("๐Ÿ“Œ Example SMILES to Try"): st.markdown(""" - `c1ccccc1` (Benzene) - `C1=CC=CC=C1O` (Phenol) - `CC(=O)OC1=CC=CC=C1C(=O)O` (Aspirin) - `NCC(O)=O` (Glycine) - `C1CCC(CC1)NC(=O)C2=CC=CC=C2` (Cyclohexylbenzamide) """) with st.expander("๐Ÿ“ฅ Download GCN Model Predictions"): if sample_df_gcn is not None and 'gcn_prob' in sample_df_gcn: def is_valid_gcn(smi): mol = Chem.MolFromSmiles(smi) return mol is not None and is_supported(mol) and smiles_to_graph(smi) is not None df_valid = sample_df_gcn[sample_df_gcn['smiles'].apply(is_valid_gcn)].copy() csv_gcn = df_valid[['smiles', 'gcn_prob', 'SR-HSE']].dropna().to_csv(index=False) st.download_button("Download Sample CSV", csv_gcn, "gcn_sample_predictions.csv", "text/csv") else: st.info("Predictions not available yet.") with st.expander("๐Ÿงช Top 5 Toxic Predictions from Test Set"): if sample_df_gcn is not None and 'gcn_prob' in sample_df_gcn: def is_valid_gcn(smi): mol = Chem.MolFromSmiles(smi) return mol is not None and is_supported(mol) and smiles_to_graph(smi) is not None top_toxic = sample_df_gcn[sample_df_gcn['gcn_prob'] > best_threshold].copy() top_toxic = top_toxic[top_toxic['smiles'].apply(is_valid_gcn)] top_toxic = top_toxic.sort_values('gcn_prob', ascending=False).head(5) if not top_toxic.empty: st.table(top_toxic[['smiles', 'gcn_prob']].rename(columns={'gcn_prob': 'Predicted Probability'})) else: st.info("No valid top predictions available in sample.") else: st.info("GCN model predictions not available.") with tab3: st.subheader("Model Comparison Analysis") if st.button("๐Ÿ” Compare Models on Full Dataset"): with st.spinner("Running comprehensive model comparison..."): # Get predictions from both models fp_results = comprehensive_evaluation('fp') gcn_results = comprehensive_evaluation('gcn') if fp_results and gcn_results: st.success("โœ… Both models evaluated successfully!") # Comparison metrics table comparison_data = { 'Metric': ['Accuracy', 'Precision', 'Recall', 'F1-Score', 'ROC-AUC', 'Samples'], 'Fingerprint': [ f"{fp_results['accuracy']:.3f}", f"{fp_results['precision']:.3f}", f"{fp_results['recall']:.3f}", f"{fp_results['f1']:.3f}", f"{fp_results['roc_auc']:.3f}", f"{fp_results['n_samples']:,}" ], 'GCN': [ f"{gcn_results['accuracy']:.3f}", f"{gcn_results['precision']:.3f}", f"{gcn_results['recall']:.3f}", f"{gcn_results['f1']:.3f}", f"{gcn_results['roc_auc']:.3f}", f"{gcn_results['n_samples']:,}" ] } comparison_df = pd.DataFrame(comparison_data) st.table(comparison_df) # Side-by-side probability distributions col1, col2 = st.columns(2) with col1: fig_fp, _, _ = create_comprehensive_plots(fp_results, 'fp') st.plotly_chart(fig_fp, use_container_width=True) with col2: fig_gcn, _, _ = create_comprehensive_plots(gcn_results, 'gcn') st.plotly_chart(fig_gcn, use_container_width=True) # Agreement analysis st.subheader("Model Agreement Analysis") # Find common molecules fp_df = fp_results['predictions_df'][['smiles', 'fp_probability']].set_index('smiles') gcn_df = gcn_results['predictions_df'][['smiles', 'gcn_probability']].set_index('smiles') common_df = fp_df.join(gcn_df, how='inner').reset_index() if len(common_df) > 0: # Calculate agreement fp_pred = (common_df['fp_probability'] > 0.5).astype(int) gcn_pred = (common_df['gcn_probability'] > best_threshold).astype(int) agreement = (fp_pred == gcn_pred).mean() st.metric("Model Agreement", f"{agreement:.1%}") # Scatter plot of probabilities fig_scatter = px.scatter( common_df, x='fp_probability', y='gcn_probability', title=f"Model Probability Correlation (n={len(common_df)})", labels={'x': 'Fingerprint Probability', 'y': 'GCN Probability'} ) fig_scatter.add_scatter(x=[0, 1], y=[0, 1], mode='lines', line=dict(dash='dash', color='gray'), name='Perfect Agreement', showlegend=False) st.plotly_chart(fig_scatter, use_container_width=True) # Download comparison data csv_comparison = common_df.to_csv(index=False) st.download_button( "๐Ÿ“ฅ Download Model Comparison", csv_comparison, "model_comparison.csv", "text/csv" ) else: st.error("โŒ Could not evaluate one or both models") with tab4: st.subheader("๐Ÿง  AI Interpretability & Explainability") # SHAP Analysis Section with st.expander("๐ŸŽฏ SHAP Analysis - Model Explanations"): st.markdown(""" **SHAP (SHapley Additive exPlanations)** reveals which molecular features contribute most to toxicity predictions. """) if not SHAP_AVAILABLE: st.warning("๏ฟฝ SHAP library not installed. Install with: `pip install shap` for full interpretability features.") st.info("๐Ÿ’ก Using simplified feature importance analysis instead.") if st.button("๐Ÿš€ Generate Feature Importance Analysis", key="shap_btn"): if fp_loaded and df is not None: with st.spinner("Generating feature importance analysis..."): try: # Feature importance analysis using fingerprint activation patterns sample_molecules = df.sample(n=min(50, len(df))) feature_importance = [] for _, row in sample_molecules.iterrows(): mol, _ = safe_mol_from_smiles(row['smiles']) if mol is not None: fp = np.array(fp_gen.GetFingerprint(mol)) pred, prob = predict_fp(row['smiles']) # Simple feature importance based on fingerprint activation importance = fp * prob # Weight by prediction feature_importance.append(importance) if feature_importance: avg_importance = np.mean(feature_importance, axis=0) top_features = np.argsort(avg_importance)[-20:][::-1] # Create feature importance plot fig_importance = px.bar( x=top_features, y=avg_importance[top_features], title="Top 20 Most Important Fingerprint Features", labels={'x': 'Fingerprint Bit Index', 'y': 'Average Importance'} ) st.plotly_chart(fig_importance, use_container_width=True) st.success("โœ… Feature importance analysis completed!") st.info("๐Ÿ’ก Higher values indicate features that strongly influence toxicity predictions.") if SHAP_AVAILABLE: st.info("๐Ÿ”ฌ For more detailed SHAP explanations, the full SHAP library is available!") except Exception as e: st.error(f"โŒ Analysis failed: {str(e)}") else: st.warning("โš ๏ธ Fingerprint model not loaded or no data available") # Enhanced SHAP Analysis for Both Models with st.expander("๐Ÿงฌ Advanced SHAP Analysis - Dual Model Support"): st.markdown(""" **Enhanced SHAP Analysis** supports both Fingerprint and GCN models for comprehensive interpretability. """) # Model selection for enhanced SHAP analysis col1, col2 = st.columns(2) with col1: enhanced_shap_model = st.selectbox( "Select model for enhanced SHAP analysis", ["Fingerprint", "GCN"], key="enhanced_shap_model" ) with col2: enhanced_num_samples = st.slider("Number of samples", 10, 100, 50, key="enhanced_shap_samples") if st.button("๐Ÿš€ Generate Enhanced SHAP Analysis", key="enhanced_shap_btn"): model_key = 'fp' if enhanced_shap_model == 'Fingerprint' else 'gcn' if (model_key == 'fp' and fp_loaded) or (model_key == 'gcn' and gcn_loaded): if df is not None and SHAP_AVAILABLE: with st.spinner(f"Generating enhanced SHAP analysis for {enhanced_shap_model} model..."): smiles_list = df['smiles'].tolist() shap_values, error = get_shap_explanations(smiles_list, model_key, enhanced_num_samples) if error: st.error(f"โŒ {error}") elif shap_values is not None: st.success(f"โœ… Enhanced SHAP analysis completed for {enhanced_shap_model} model!") # Display comprehensive SHAP metrics mean_shap = np.mean(np.abs(shap_values.values)) max_shap = np.max(np.abs(shap_values.values)) std_shap = np.std(shap_values.values) col1, col2, col3, col4 = st.columns(4) with col1: st.metric("Samples Analyzed", len(shap_values.values)) with col2: st.metric("Mean |SHAP|", f"{mean_shap:.4f}") with col3: st.metric("Max |SHAP|", f"{max_shap:.4f}") with col4: st.metric("SHAP Std Dev", f"{std_shap:.4f}") # Create detailed feature importance visualization mean_importance = np.mean(np.abs(shap_values.values), axis=0) top_n = min(25, len(mean_importance)) top_features = np.argsort(mean_importance)[-top_n:][::-1] if model_key == 'fp': feature_names = [f"FP Bit {i}" for i in top_features] title = f"Top {top_n} Most Important Fingerprint Features (Enhanced SHAP)" explanation = """ **๐Ÿ”ฌ Enhanced Fingerprint SHAP Analysis:** - Morgan fingerprint bits ranked by SHAP importance - Each bit represents specific molecular substructures - Positive SHAP values increase toxicity probability - Negative SHAP values decrease toxicity probability """ else: feature_names = [f"GCN Feat {i}" for i in top_features] title = f"Top {top_n} Most Important GCN Features (Enhanced SHAP)" explanation = """ **๐Ÿงฌ Enhanced GCN SHAP Analysis:** - Learned graph features from GCN hidden layers - Captures complex molecular topology patterns - Features represent graph convolution outputs - Shows importance of molecular graph structure """ # Create enhanced visualization fig_enhanced = px.bar( x=feature_names, y=mean_importance[top_features], title=title, labels={'x': 'Feature', 'y': 'Mean |SHAP Value|'}, color=mean_importance[top_features], color_continuous_scale='Viridis' ) fig_enhanced.update_xaxes(tickangle=45) fig_enhanced.update_layout(height=500) st.plotly_chart(fig_enhanced, use_container_width=True) st.markdown(explanation) # Feature distribution analysis st.subheader("๐Ÿ“ˆ SHAP Value Distribution Analysis") # Create histogram of SHAP values all_shap_values = shap_values.values.flatten() fig_dist = px.histogram( x=all_shap_values, nbins=50, title=f"Distribution of SHAP Values - {enhanced_shap_model} Model", labels={'x': 'SHAP Value', 'y': 'Frequency'} ) st.plotly_chart(fig_dist, use_container_width=True) # Summary statistics st.markdown(f""" **๐Ÿ“Š SHAP Value Statistics:** - **Total Features:** {len(mean_importance)} - **Non-zero Features:** {np.sum(mean_importance > 1e-6)} - **Feature Sparsity:** {(1 - np.sum(mean_importance > 1e-6) / len(mean_importance)) * 100:.1f}% - **Model Type:** {enhanced_shap_model} """) else: st.warning("No SHAP values generated. Check if molecules are valid for the selected model.") elif not SHAP_AVAILABLE: st.error("โŒ SHAP library not available. Please install with: pip install shap") else: st.warning("โš ๏ธ No dataset loaded for SHAP analysis") else: model_name = "Fingerprint" if model_key == 'fp' else "GCN" st.warning(f"โš ๏ธ {model_name} model not loaded") # Uncertainty Quantification Section with st.expander("๐ŸŽฒ Uncertainty Quantification"): st.markdown(""" **Uncertainty Quantification** provides confidence intervals for predictions using Monte Carlo dropout. This analysis estimates prediction uncertainty by running multiple forward passes with dropout enabled, providing insights into model confidence and reliability. """) with st.form("uncertainty_form"): col1, col2 = st.columns(2) with col1: unc_smiles = st.text_input("Enter SMILES for uncertainty analysis", "CCO") unc_model = st.selectbox("Select model", ["Fingerprint", "GCN"], key="unc_model") with col2: unc_samples = st.slider("Number of Monte Carlo samples", 50, 500, 100) st.info("๐Ÿ’ก More samples = more accurate uncertainty estimates but slower computation") unc_btn = st.form_submit_button("๐Ÿ” Analyze Uncertainty") if unc_btn: model_key = 'fp' if unc_model == 'Fingerprint' else 'gcn' model_loaded = fp_loaded if model_key == 'fp' else gcn_loaded if model_loaded: with st.spinner(f"Analyzing uncertainty with {unc_model} model..."): mean_pred, std_pred, ci, error = uncertainty_quantification( unc_smiles, model_key, unc_samples ) if error: st.error(f"โŒ {error}") elif mean_pred is not None: # Display results col1, col2, col3 = st.columns(3) with col1: st.metric("Mean Prediction", f"{mean_pred:.3f}") st.metric("Std Deviation", f"{std_pred:.4f}") with col2: st.metric("95% CI Lower", f"{ci[0]:.3f}") st.metric("95% CI Upper", f"{ci[1]:.3f}") with col3: confidence_width = ci[1] - ci[0] st.metric("CI Width", f"{confidence_width:.4f}") certainty = max(0, 1 - confidence_width) st.metric("Model Certainty", f"{certainty:.1%}") # Interpretation st.subheader("๐Ÿง  Uncertainty Interpretation") if std_pred < 0.05: uncertainty_level = "๐ŸŸข Very Low" interpretation = "High confidence prediction" elif std_pred < 0.1: uncertainty_level = "๐ŸŸก Low" interpretation = "Moderate confidence prediction" elif std_pred < 0.2: uncertainty_level = "๐ŸŸ  Moderate" interpretation = "Some uncertainty in prediction" else: uncertainty_level = "๐Ÿ”ด High" interpretation = "High uncertainty - use caution" st.write(f"**Uncertainty Level:** {uncertainty_level}") st.write(f"**Interpretation:** {interpretation}") # Show molecule structure mol, _ = safe_mol_from_smiles(unc_smiles) if mol: st.image(Draw.MolToImage(mol), caption="Analyzed Molecule", width=250) # Additional insights threshold = 0.5 if model_key == 'fp' else best_threshold if ci[0] > threshold: st.success("โœ… Confident TOXIC prediction (entire CI above threshold)") elif ci[1] < threshold: st.success("โœ… Confident NON-TOXIC prediction (entire CI below threshold)") else: st.warning("โš ๏ธ Uncertain prediction (CI spans decision threshold)") # Visualize uncertainty uncertainty_data = pd.DataFrame({ 'Prediction': [mean_pred], 'Lower_CI': [ci[0]], 'Upper_CI': [ci[1]], 'Model': [unc_model] }) fig_unc = px.scatter( uncertainty_data, x='Prediction', y=[0], error_x=[std_pred], title=f"{unc_model} Model - Prediction Uncertainty Visualization", labels={'x': 'Predicted Probability'} ) fig_unc.add_vline(x=threshold, line_dash="dash", annotation_text=f"Decision Threshold ({threshold})") fig_unc.add_vrect(x0=ci[0], x1=ci[1], fillcolor="blue", opacity=0.2, annotation_text="95% CI") st.plotly_chart(fig_unc, use_container_width=True) else: st.error("โŒ Could not perform uncertainty analysis") else: model_name = "Fingerprint" if model_key == 'fp' else "GCN" st.warning(f"โš ๏ธ {model_name} model not loaded") # Batch Upload Section with st.expander("๐Ÿ“ Batch Prediction Upload"): st.markdown(""" **Batch Processing** allows you to upload a CSV file with multiple SMILES for analysis. Features include: - Support for both Fingerprint and GCN models - Uncertainty quantification with confidence intervals (both models) - Comprehensive statistics and downloadable results """) uploaded_file = st.file_uploader( "Choose a CSV file", type=["csv", "gz"], help="CSV should contain a 'smiles' or 'SMILES' column. Supports compressed (.gz) files." ) if uploaded_file is not None: try: # Handle different file types file_name = uploaded_file.name.lower() if file_name.endswith('.gz') or file_name.endswith('.csv.gz'): # Handle gzipped files import gzip import io # Read the compressed file compressed_data = uploaded_file.read() # Decompress decompressed_data = gzip.decompress(compressed_data) # Convert to string and create StringIO object text_data = decompressed_data.decode('utf-8') batch_df = pd.read_csv(io.StringIO(text_data)) st.info(f"๐Ÿ“ฆ Successfully decompressed file: {uploaded_file.name}") elif file_name.endswith('.csv'): # Handle regular CSV files batch_df = pd.read_csv(uploaded_file) else: # Try different encodings for problematic files try: # Reset file pointer uploaded_file.seek(0) batch_df = pd.read_csv(uploaded_file, encoding='latin-1') st.warning("โš ๏ธ Used Latin-1 encoding for file reading") except: uploaded_file.seek(0) batch_df = pd.read_csv(uploaded_file, encoding='cp1252') st.warning("โš ๏ธ Used CP1252 encoding for file reading") st.write("๐Ÿ“‹ **Preview of uploaded data:**") st.dataframe(batch_df.head()) # Check for SMILES column smiles_col = None for col in ['smiles', 'SMILES', 'Smiles']: if col in batch_df.columns: smiles_col = col break if smiles_col: st.success(f"โœ… Found SMILES column: '{smiles_col}'") # Show file statistics total_molecules = len(batch_df) st.info(f"๐Ÿ“Š Dataset contains {total_molecules:,} molecules") # Add sampling options for large datasets if total_molecules > 10000: st.warning(f"โš ๏ธ Large dataset detected ({total_molecules:,} molecules). Processing all molecules may take a very long time.") col1, col2 = st.columns(2) with col1: use_sampling = st.checkbox("๐ŸŽฏ Use random sampling for faster processing", value=True) with col2: if use_sampling: sample_size = st.number_input( "Sample size", min_value=100, max_value=min(50000, total_molecules), value=min(5000, total_molecules), step=100, help="Smaller samples process faster but may be less representative" ) if use_sampling: # Show estimated processing time estimated_minutes = sample_size / 100 # Rough estimate: ~100 molecules/minute st.info(f"โฑ๏ธ Estimated processing time: ~{estimated_minutes:.1f} minutes for {sample_size:,} molecules") else: use_sampling = False sample_size = total_molecules model_choice = st.selectbox("Select model for batch prediction", ["Fingerprint", "GCN"]) if st.button("๐Ÿš€ Process Batch Predictions"): model_type = 'fp' if model_choice == 'Fingerprint' else 'gcn' if (model_type == 'fp' and fp_loaded) or (model_type == 'gcn' and gcn_loaded): # Prepare dataset for processing if use_sampling and total_molecules > sample_size: # Random sampling processing_df = batch_df.sample(n=sample_size, random_state=42).reset_index(drop=True) st.info(f"๐ŸŽฏ Processing random sample of {sample_size:,} molecules from {total_molecules:,} total") else: processing_df = batch_df st.info(f"๐Ÿ”„ Processing all {len(processing_df):,} molecules") results_df, error = process_batch_predictions(processing_df, model_type) if error: st.error(f"โŒ {error}") else: st.success("โœ… Batch processing completed!") st.dataframe(results_df) # Download results csv_results = results_df.to_csv(index=False) st.download_button( "๐Ÿ“ฅ Download Results", csv_results, f"batch_predictions_{model_choice.lower()}.csv", "text/csv" ) # Summary statistics toxic_count = len(results_df[results_df['Prediction'] == 'Toxic']) total_count = len(results_df) st.markdown(f""" **๐Ÿ“Š Batch Summary:** - Total molecules: {total_count} - Predicted toxic: {toxic_count} ({toxic_count/total_count*100:.1f}%) - Predicted non-toxic: {total_count-toxic_count} ({(total_count-toxic_count)/total_count*100:.1f}%) """) else: st.error(f"โŒ {model_choice} model not loaded") else: st.error("โŒ No SMILES column found. Please ensure your CSV has a column named 'smiles' or 'SMILES'") st.info("๐Ÿ’ก Available columns: " + ", ".join(batch_df.columns.tolist())) except UnicodeDecodeError as e: st.error(f"โŒ File encoding error: {str(e)}") st.info(""" ๐Ÿ’ก **File Format Issues:** - Your file might be compressed (.gz) - we now support compressed files! - Try saving your file with UTF-8 encoding - Ensure it's a valid CSV file with proper formatting """) except pd.errors.EmptyDataError: st.error("โŒ The uploaded file is empty or contains no valid data") except pd.errors.ParserError as e: st.error(f"โŒ CSV parsing error: {str(e)}") st.info("๐Ÿ’ก Please check that your file is a valid CSV with proper delimiters") except Exception as e: st.error(f"โŒ Unexpected error reading file: {str(e)}") st.info(""" ๐Ÿ’ก **Troubleshooting:** - Ensure file is a valid CSV or compressed CSV (.csv.gz) - Check that the file contains a 'smiles' or 'SMILES' column - Verify the file is not corrupted - Try with a smaller sample file first """) with tab5: st.subheader("๐Ÿ“Š Advanced Analytics & Insights") # PCA Analysis Section with st.expander("๐Ÿ” Principal Component Analysis (PCA)"): st.markdown(""" **PCA** reduces high-dimensional molecular features to 2D for visualization and pattern discovery. """) # Model selection for PCA col1, col2 = st.columns(2) with col1: pca_model_type = st.selectbox( "Select feature type for PCA", ["Fingerprints", "GCN Features"], key="pca_model_select" ) with col2: st.info(f"**{pca_model_type}**: {'Morgan fingerprints (1024-dim)' if pca_model_type == 'Fingerprints' else 'GCN learned features (variable-dim)'}") if st.button("๐Ÿš€ Generate PCA Analysis", key="pca_btn"): if df is not None: model_type = 'fp' if pca_model_type == 'Fingerprints' else 'gcn' with st.spinner(f"Performing PCA analysis on {pca_model_type.lower()}..."): pca_result, error = perform_pca_analysis(df, model_type) if error: st.error(f"โŒ {error}") else: pca_df = pca_result['pca_df'] explained_var = pca_result['explained_variance'] feature_type = pca_result['feature_type'] col1, col2 = st.columns(2) with col1: # 2D PCA Plot fig_pca = px.scatter( pca_df, x='PC1', y='PC2', color='Label', title=f"PCA - {feature_type}", hover_data=['SMILES'], color_discrete_map={'Toxic': 'red', 'Non-toxic': 'blue'} ) st.plotly_chart(fig_pca, use_container_width=True) with col2: # Explained Variance Plot fig_var = px.bar( x=range(1, len(explained_var) + 1), y=explained_var, title="Explained Variance by Principal Component", labels={'x': 'Principal Component', 'y': 'Explained Variance Ratio'} ) st.plotly_chart(fig_var, use_container_width=True) st.markdown(f""" **๐Ÿ“Š PCA Results ({feature_type}):** - **PC1 explains {explained_var[0]*100:.1f}%** of variance - **PC2 explains {explained_var[1]*100:.1f}%** of variance - **Combined: {(explained_var[0] + explained_var[1])*100:.1f}%** of total variance """) if st.checkbox("Show detailed component loadings"): pca_model = pca_result['pca_model'] components = pca_model.components_ st.write("**Top features contributing to PC1:**") pc1_loadings = pd.DataFrame({ 'Feature': range(len(components[0])), 'Loading': components[0] }).sort_values('Loading', key=abs, ascending=False).head(10) st.dataframe(pc1_loadings) else: st.warning("โš ๏ธ No dataset loaded for PCA analysis") # Clustering Analysis Section with st.expander("๐ŸŽฏ Molecular Clustering Analysis"): st.markdown(""" **K-Means Clustering** groups similar molecules based on their molecular features. """) col1, col2, col3 = st.columns(3) with col1: n_clusters = st.slider("Number of clusters", 2, 10, 4) with col2: cluster_method = st.selectbox("Clustering method", ["K-Means", "Hierarchical"]) with col3: cluster_model_type = st.selectbox( "Feature type", ["Fingerprints", "GCN Features"], key="cluster_model_select" ) if st.button("๐Ÿš€ Perform Clustering Analysis", key="cluster_btn"): if df is not None: model_type = 'fp' if cluster_model_type == 'Fingerprints' else 'gcn' with st.spinner(f"Performing clustering analysis on {cluster_model_type.lower()}..."): cluster_result, error = perform_clustering_analysis(df, n_clusters, model_type) if error: st.error(f"โŒ {error}") else: cluster_df = cluster_result['cluster_df'] feature_type = cluster_result['feature_type'] col1, col2 = st.columns(2) with col1: # Cluster visualization with t-SNE fig_cluster = px.scatter( cluster_df, x='tSNE1', y='tSNE2', color='Cluster', symbol='Toxicity', title=f"{cluster_method} Clustering Results ({feature_type})", hover_data=['SMILES'], symbol_map={'Toxic': 'triangle-up', 'Non-toxic': 'circle'} ) st.plotly_chart(fig_cluster, use_container_width=True) with col2: # Cluster composition cluster_stats = cluster_df.groupby(['Cluster', 'Toxicity']).size().reset_index(name='Count') fig_composition = px.bar( cluster_stats, x='Cluster', y='Count', color='Toxicity', title="Cluster Composition by Toxicity", color_discrete_map={'Toxic': 'red', 'Non-toxic': 'blue'} ) st.plotly_chart(fig_composition, use_container_width=True) # Calculate silhouette score from sklearn.metrics import silhouette_score kmeans_model = cluster_result['kmeans_model'] # Get features for silhouette calculation if model_type == 'fp': features = [] for _, row in df.iterrows(): mol, _ = safe_mol_from_smiles(row['smiles']) if mol is not None: fp = np.array(fp_gen.GetFingerprint(mol)) features.append(fp) else: sample_smiles = df['smiles'].tolist() features, _ = extract_gcn_features(sample_smiles) if len(features) > n_clusters: X = np.array(features) silhouette = silhouette_score(X, kmeans_model.labels_) else: silhouette = 0.0 st.markdown(f""" **๐Ÿ“Š Clustering Results ({feature_type}):** - **Silhouette Score:** {silhouette:.3f} (higher is better, range: -1 to 1) - **Number of clusters:** {n_clusters} - **Method:** {cluster_method} """) # Cluster summary table cluster_summary = cluster_df.groupby('Cluster').agg({ 'Toxicity': lambda x: f"{(x=='Toxic').sum()}/{len(x)} toxic", 'SMILES': 'count' }).rename(columns={'SMILES': 'Total Count'}) st.write("**Cluster Summary:**") st.dataframe(cluster_summary) else: st.warning("โš ๏ธ No dataset loaded for clustering analysis") # Molecular Similarity Search with st.expander("๐Ÿ” Molecular Similarity Search"): st.markdown(""" **Similarity Search** finds molecules most similar to your query based on molecular features. """) col1, col2 = st.columns(2) with col1: query_smiles = st.text_input("Query SMILES", "CCO") with col2: similarity_model_type = st.selectbox( "Feature type", ["Fingerprints", "GCN Features"], key="similarity_model_select" ) col1, col2 = st.columns(2) with col1: n_similar = st.slider("Number of similar molecules to find", 5, 50, 10) with col2: similarity_threshold = st.slider("Minimum similarity threshold", 0.0, 1.0, 0.5) if st.button("๐Ÿ” Find Similar Molecules", key="similarity_search_btn"): if df is not None: model_type = 'fp' if similarity_model_type == 'Fingerprints' else 'gcn' with st.spinner(f"Finding similar molecules using {similarity_model_type.lower()}..."): similar_mols, error = find_similar_molecules( query_smiles, df, n_similar, similarity_threshold, model_type ) if error: st.error(f"โŒ {error}") elif similar_mols is not None and len(similar_mols) > 0: st.success(f"โœ… Found {len(similar_mols)} similar molecules using {similarity_model_type}!") # Convert to DataFrame for display similar_data = [] for mol in similar_mols: similar_data.append({ 'SMILES': mol['smiles'], 'Similarity': mol['similarity'], 'Toxicity': mol['toxicity'] }) similar_df = pd.DataFrame(similar_data) # Display results table st.dataframe(similar_df, use_container_width=True) # Similarity distribution fig_sim = px.histogram( similar_df, x='Similarity', color='Toxicity', title=f"Similarity Distribution ({similarity_model_type})", nbins=20, color_discrete_map={'Toxic': 'red', 'Non-toxic': 'blue'} ) st.plotly_chart(fig_sim, use_container_width=True) # Statistics toxic_similar = len(similar_df[similar_df['Toxicity'] == 'Toxic']) avg_similarity = similar_df['Similarity'].mean() st.markdown(f""" **๐Ÿ“Š Similarity Analysis ({similarity_model_type}):** - **Average similarity:** {avg_similarity:.3f} - **Toxic molecules found:** {toxic_similar}/{len(similar_df)} ({toxic_similar/len(similar_df)*100:.1f}%) - **Query molecule:** `{query_smiles}` - **Feature type:** {similarity_model_type} """) else: st.warning(f"No similar molecules found with similarity โ‰ฅ {similarity_threshold}") else: st.warning("โš ๏ธ No dataset loaded for similarity search") # Advanced Model Insights with st.expander("๐Ÿง  Model Performance Insights - Dual Model Analysis"): st.markdown(""" **Model Insights** provide deep analysis of model behavior and performance patterns for both AI models. **Available Models:** - ๐Ÿ”ฌ **Fingerprint Model**: Traditional molecular fingerprint-based neural network (1024 features) - ๐Ÿงฌ **GCN Model**: Graph Convolutional Network using molecular graph structure (128 learned features) """) # Model selection for insights col1, col2 = st.columns(2) with col1: insights_model = st.selectbox( "Select model for performance insights", ["Fingerprint", "GCN", "Both Models Comparison"], key="insights_model" ) with col2: insights_sample_size = st.slider("Sample size for analysis", 100, 500, 200, key="insights_sample") if st.button("๐Ÿš€ Generate Model Insights", key="insights_btn"): if df is not None: if insights_model == "Fingerprint" and fp_loaded: with st.spinner("Analyzing Fingerprint model performance patterns..."): generate_single_model_insights('fp', insights_sample_size) elif insights_model == "GCN" and gcn_loaded: with st.spinner("Analyzing GCN model performance patterns..."): generate_single_model_insights('gcn', insights_sample_size) elif insights_model == "Both Models Comparison": if fp_loaded and gcn_loaded: with st.spinner("Generating comprehensive dual-model analysis..."): generate_dual_model_insights(insights_sample_size) else: missing_models = [] if not fp_loaded: missing_models.append("Fingerprint") if not gcn_loaded: missing_models.append("GCN") st.error(f"โŒ Missing models: {', '.join(missing_models)}") else: model_name = insights_model if model_name == "Fingerprint" and not fp_loaded: st.error("โŒ Fingerprint model not loaded") elif model_name == "GCN" and not gcn_loaded: st.error("โŒ GCN model not loaded") else: st.warning("โš ๏ธ No dataset loaded for analysis") # Footer st.markdown("---") st.markdown("""
๐Ÿงฌ Drug Toxicity Predictor | Built with Streamlit, PyTorch & RDKit | ๐Ÿš€ Enhanced with Advanced ML Analytics
""", unsafe_allow_html=True)