"""Streamlit app: classify alpha-synuclein variants with ProtBERT + Random Forest.

The app embeds protein sequences with ProtBERT, trains a Random Forest on a
small built-in sample set, reports evaluation metrics, and lets the user
analyse / classify a new sequence and view an ESMFold-predicted structure.
"""

import io
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import py3Dmol
import requests
import seaborn as sns
import streamlit as st
import torch
from Bio import SeqIO
from Bio.SeqUtils.ProtParam import ProteinAnalysis
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from transformers import BertModel, BertTokenizer
# import shap

st.set_page_config(
    page_title="Parkinson's Protein Classifier",
    page_icon="🧬",
    layout="wide"
)

# The 20 standard amino acids accepted by analyze_protein.
STANDARD_AMINO_ACIDS = frozenset("ACDEFGHIKLMNPQRSTVWY")

# Length used by analyze_protein as the wild-type reference.
WILD_TYPE_LENGTH = 140

# 1-based position -> (mutant residue, label). Only meaningful when the
# sequence numbering lines up with the wild type, so analyze_protein checks
# these only for sequences of WILD_TYPE_LENGTH.
KNOWN_PATHOGENIC_SUBSTITUTIONS = {
    53: ("T", "A53T (known pathogenic)"),
    46: ("K", "E46K (known pathogenic)"),
    83: ("Q", "E83Q (known pathogenic)"),
}

# ProtBERT preprocessing maps the rare residues U, Z, O and B to X.
_RARE_TO_X = str.maketrans("UZOB", "XXXX")

ESMFOLD_URL = "https://api.esmatlas.com/foldSequence/v1/pdb/"


def _clean_sequence(sequence):
    """Return *sequence* upper-cased with all whitespace removed."""
    return "".join(sequence.split()).upper()


# Load ProtBERT Model
@st.cache_resource
def load_protbert():
    """Load the ProtBERT tokenizer and model (cached by Streamlit).

    Returns:
        A ``(tokenizer, model)`` tuple; the model is switched to eval mode.
    """
    tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
    model = BertModel.from_pretrained("Rostlab/prot_bert")
    model.eval()
    return tokenizer, model


# Embedding Function
def get_protbert_embedding(sequence, tokenizer, model):
    """Return the mean-pooled ProtBERT embedding of *sequence*.

    The sequence is stripped of whitespace, upper-cased, rare residues
    (U/Z/O/B) are mapped to X, and residues are space-separated before
    tokenization.

    Args:
        sequence: Amino-acid sequence as a string.
        tokenizer: Tokenizer returned by ``load_protbert``.
        model: Model returned by ``load_protbert``.

    Returns:
        A 1-D NumPy array: the mean of the last hidden state over tokens.

    Raises:
        ValueError: If the sequence is empty after cleaning.
    """
    residues = _clean_sequence(sequence).translate(_RARE_TO_X)
    if not residues:
        raise ValueError("Cannot embed an empty sequence.")
    tokens = tokenizer(" ".join(residues), return_tensors="pt")
    with torch.no_grad():
        outputs = model(**tokens)
    embedding = torch.mean(outputs.last_hidden_state, dim=1)
    return embedding.squeeze().numpy()


# Protein Analysis Function
def analyze_protein(sequence):
    """Compute physico-chemical properties and heuristic risk flags.

    Args:
        sequence: Amino-acid sequence; whitespace and case are ignored.

    Returns:
        ``(results, parkinsons_analysis)`` on success, where ``results`` has
        the keys ``'basic'``, ``'aa_composition'``, ``'secondary_structure'``
        and ``'flexibility'`` and ``parkinsons_analysis`` has
        ``'risk_factors'`` and ``'notes'`` lists.
        ``("Invalid amino acid sequence!", None)`` when the cleaned sequence
        is empty or contains characters outside the 20 standard residues.
    """
    sequence = _clean_sequence(sequence)
    # An empty sequence would pass the residue check and then divide by zero.
    if not sequence or not STANDARD_AMINO_ACIDS.issuperset(sequence):
        return "Invalid amino acid sequence!", None

    analysis = ProteinAnalysis(sequence)
    length = len(sequence)
    mw = analysis.molecular_weight()
    aromaticity = analysis.aromaticity()
    instability = analysis.instability_index()
    gravy = analysis.gravy()
    aa_counts = analysis.count_amino_acids()
    aa_percent = {k: v / length * 100 for k, v in aa_counts.items()}
    # Secondary structure
    sec_struct = analysis.secondary_structure_fraction()
    # Isoelectric point
    pI = analysis.isoelectric_point()
    # Flexibility
    flexibility = analysis.flexibility()

    results = {
        'basic': {
            'Length': length,
            'Molecular Weight (Da)': mw,
            'Aromaticity': aromaticity,
            'Instability Index': instability,
            'GRAVY (Hydrophobicity)': gravy,
            'Isoelectric Point (pI)': pI
        },
        'aa_composition': aa_percent,
        'secondary_structure': {
            'Helix': sec_struct[0],
            'Turn': sec_struct[1],
            'Sheet': sec_struct[2]
        },
        'flexibility': flexibility
    }

    parkinsons_analysis = {
        'risk_factors': [],
        'notes': []
    }
    risk_factors = parkinsons_analysis['risk_factors']

    if length != WILD_TYPE_LENGTH:
        risk_factors.append(f"Sequence length ({length}) deviates from wild-type (140)")
    if mw > 14660 or mw < 14400:
        risk_factors.append(f"Molecular weight ({mw:.2f} Da) differs from wild-type (14.46 kDa)")
    if aromaticity > 0.05:
        risk_factors.append("High aromaticity (potential aggregation risk)")
    if instability > 45:
        risk_factors.append(f"High instability index ({instability:.2f}) suggests toxic form")
    if gravy > -0.3:
        risk_factors.append(f"Hydrophobicity (GRAVY: {gravy:.3f}) suggests aggregation-prone variant")

    # Position-specific checks are only trustworthy when no insertion or
    # deletion has shifted the numbering relative to the wild type.
    if length == WILD_TYPE_LENGTH:
        for position, (mutant, label) in sorted(KNOWN_PATHOGENIC_SUBSTITUTIONS.items()):
            if sequence[position - 1] == mutant:
                risk_factors.append(f"Substitution {label} detected")

    high_risk_aas = {
        'C': "Cysteine residues can promote aggregation",
        'G': "Glycine substitutions often pathogenic",
        'P': "Proline substitutions can disrupt structure"
    }
    for aa, risk in high_risk_aas.items():
        count = aa_counts.get(aa, 0)
        if count > 0:
            parkinsons_analysis['notes'].append(f"{risk} ({count} {aa} residues)")

    return results, parkinsons_analysis


def get_sample_data():
    """Return the built-in sample dataset as a DataFrame.

    Columns: ``sequence`` (str), ``label`` (1 = Parkinson-associated,
    0 = not, as used by the pie-chart mapping in the training tab) and
    ``mutation`` (annotation string). The three lists are positionally
    aligned and each has 34 entries.
    """
    # NOTE(review): some annotations do not match the sequences as written:
    # the entries annotated M5A and F4G still carry M at position 5 and F at
    # position 4, and the entry annotated A11C carries H at position 11.
    # Confirm the intended sequences/labels before relying on this data.
    data = {
        'sequence': [
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVTTVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  # A53T
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVVNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  # Random (non-pathogenic)
            "MDVFMKGLSKAKEGVVAAAIKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  # Random (non-pathogenic)
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDVEPEA",
            "MDVFMKGLSGAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  # K10G
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGGVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  #F94G
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEACEMPSEEGYQDYEPEA",  #Y125C
            "MDVFMKGLSKHKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  #A11C
            "MDVFMKGLSKAKEGVVAASEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  #A19S
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLTVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  #Y39T
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  #M5A
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGCQDYEPEA",  #Y133C
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDGLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  #Q99G
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDCEPEA",  #Y136C
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLCVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  #Y39C
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDAPVDPDNEAYEMPSEEGYQDYEPEA",  #M116A
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTGEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  #K45G
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVGKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  #K96G
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDGPVDPDNEAYEMPSEEGYQDYEPEA",  #M116G
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAPEMPSEEGYQDYEPEA",  #Y125P
            "GDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  #M1G
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGSVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  #F94S
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGVQDYEPEA",  #Y133V
            "MDVFMKGLSKAKEGVVAAAEKTKGGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  #Q24G
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEGGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  #E105G
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDTEPEA",  #Y136T
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKGGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  #E35G
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGTQDYEPEA",  #Y133T
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTGEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  #K60G
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  #F4G
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVFGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  #E83F
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTKVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  #N65K
            "MDVFMKGLSKSKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  #A11S
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTEVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA"  #N65E
        ],
        'label': [1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        'mutation': ['A53T', 'None', 'None', 'Unknown', 'K10G', 'F94G', 'Y125C', 'A11C', 'A19S', 'Y39T',
                     'M5A', 'Y133C', 'Q99G', 'Y136C', 'Y39C', 'M116A', 'K45G', 'K96G', 'M116G', 'Y125P',
                     'M1G', 'F94S', 'Y133V', 'Q24G', 'E105G', 'Y136T', 'E35G', 'Y133T', 'K60G', 'F4G',
                     'E83F', 'N65K', 'A11S', 'N65E']
    }
    return pd.DataFrame(data)


def train_classifier(X, y):
    """Train a Random Forest on an 80/20 split of ``(X, y)``.

    The split is stratified by label so that the small, imbalanced sample
    set keeps both classes in the test portion; if stratification is not
    possible (scikit-learn raises ``ValueError``, e.g. a class with a single
    member) it falls back to a plain random split.

    Returns:
        ``(clf, X_test, y_test)``.
    """
    try:
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42, stratify=y
        )
    except ValueError:
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42
        )
    clf = RandomForestClassifier(n_estimators=100, random_state=42)
    clf.fit(X_train, y_train)
    return clf, X_test, y_test


def fetch_structure(sequence):
    """Request a predicted structure for *sequence* from the ESMFold API.

    Args:
        sequence: Amino-acid sequence; whitespace and case are normalised
            before it is posted.

    Returns:
        The response body (PDB text) on HTTP 200.

    Raises:
        RuntimeError: On a network error or a non-200 status code.
    """
    headers = {"Content-Type": "text/plain"}
    try:
        response = requests.post(
            ESMFOLD_URL, data=_clean_sequence(sequence), headers=headers, timeout=30
        )
    except requests.RequestException as exc:
        raise RuntimeError(f"Failed to fetch structure: {exc}") from exc
    if response.status_code != 200:
        raise RuntimeError(
            f"Failed to fetch structure: API returned status code {response.status_code}"
        )
    return response.text


def display_structure(pdb_data, color="chain", show_sidechains=True, show_mainchains=False):
    """Build a py3Dmol view of *pdb_data*.

    Args:
        pdb_data: PDB-format text.
        color: ``"rainbow"``, ``"chain"`` or ``"residue"``; anything else
            renders a white cartoon.
        show_sidechains: Add sticks for atoms other than C, O and N.
        show_mainchains: Add sticks for C, O, N and CA atoms.

    Returns:
        The configured ``py3Dmol.view``.
    """
    cartoon_styles = {
        "rainbow": {'cartoon': {'color': 'spectrum'}},
        "chain": {'cartoon': {'color': 'chain'}},
        "residue": {'cartoon': {'colorscheme': 'residue'}},
    }
    view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
    view.addModel(pdb_data, 'pdb')
    view.setStyle(cartoon_styles.get(color, {'cartoon': {'color': 'white'}}))
    if show_sidechains:
        view.addStyle({'and': [{'atom': ['C', 'O', 'N'], 'invert': True}]},
                      {'stick': {'colorscheme': "WhiteCarbon", 'radius': 0.3}})
    if show_mainchains:
        view.addStyle({'atom': ['C', 'O', 'N', 'CA']},
                      {'stick': {'colorscheme': "WhiteCarbon", 'radius': 0.3}})
    view.zoomTo()
    return view


def _render_train_tab(tokenizer, model):
    """Render the 'Train Model' tab: embed the sample data and fit the classifier."""
    st.header("Train Classification Model")
    if not st.button("Train Model with Sample Data"):
        return

    with st.spinner("Training in progress..."):
        df = get_sample_data()
        total = len(df['sequence'])
        embeddings = []
        progress_bar = st.progress(0)
        status_text = st.empty()

        for i, seq in enumerate(df['sequence']):
            try:
                status_text.text(f"Processing sequence {i+1}/{total}...")
                progress_bar.progress((i + 1) / total)
                embeddings.append(get_protbert_embedding(seq, tokenizer, model))
            except Exception as e:
                # Best effort: keep row alignment with the labels by using a
                # zero vector of the model's embedding width.
                st.warning(f"Error with sequence {i+1}: {str(e)}")
                embeddings.append(np.zeros(model.config.hidden_size))

        X = np.array(embeddings)
        y = df['label'].values
        clf, X_test, y_test = train_classifier(X, y)

        st.session_state.classifier = clf
        st.session_state.X_test = X_test
        st.session_state.y_test = y_test
        st.session_state.training_data = df

        st.success("Model trained successfully!")
        st.subheader("Sample Training Data")
        st.dataframe(df)

        st.subheader("Class Distribution")
        class_counts = df['label'].value_counts()
        fig = px.pie(values=class_counts,
                     names=class_counts.index.map({0: 'Non-Parkinson', 1: 'Parkinson'}))
        st.plotly_chart(fig, use_container_width=True)


def _render_evaluate_tab():
    """Render the 'Evaluate Model' tab: report, confusion matrix, importances."""
    st.header("Evaluate Model Performance")
    if st.session_state.classifier is None:
        st.warning("Please train the model first using the 'Train Model' tab.")
        return

    clf = st.session_state.classifier
    X_test = st.session_state.X_test
    y_test = st.session_state.y_test
    y_pred = clf.predict(X_test)

    st.subheader("Classification Report")
    report = classification_report(y_test, y_pred, labels=[0, 1],
                                   output_dict=True, zero_division=0)
    st.dataframe(pd.DataFrame(report).transpose())

    st.subheader("Confusion Matrix")
    # Fix the label set so the matrix is always 2x2 and matches the axis
    # names below, even if the test split happens to contain one class.
    cm = confusion_matrix(y_test, y_pred, labels=[0, 1])
    fig = px.imshow(cm,
                    labels=dict(x="Predicted", y="Actual", color="Count"),
                    x=['Non-Parkinson', 'Parkinson'],
                    y=['Non-Parkinson', 'Parkinson'],
                    text_auto=True)
    st.plotly_chart(fig, use_container_width=True)

    st.subheader("Feature Importance")
    try:
        importances = clf.feature_importances_
        top_n = 20
        indices = np.argsort(importances)[-top_n:]
        fig = go.Figure()
        fig.add_trace(go.Bar(
            y=[f"Feature {i}" for i in indices],
            x=importances[indices],
            orientation='h'
        ))
        fig.update_layout(title=f"Top {top_n} Important Features",
                          xaxis_title="Importance Score")
        st.plotly_chart(fig, use_container_width=True)
    except Exception as e:
        st.warning(f"Could not display feature importance: {str(e)}")


def _read_fasta_upload(uploaded_file):
    """Return the first record of an uploaded FASTA file, or None if it has none."""
    fasta_content = uploaded_file.read().decode("utf-8")
    return next(SeqIO.parse(io.StringIO(fasta_content), "fasta"), None)


def _render_sequence_analysis(seq_input):
    """Run analyze_protein on *seq_input* and render its results."""
    analysis_results, parkinsons_analysis = analyze_protein(seq_input)
    if isinstance(analysis_results, str):
        st.error(analysis_results)
        return

    st.subheader("Basic Properties")
    st.table(pd.DataFrame.from_dict(analysis_results['basic'], orient='index'))

    st.subheader("Amino Acid Composition")
    aa_df = pd.DataFrame.from_dict(analysis_results['aa_composition'],
                                   orient='index', columns=['Percentage'])
    st.bar_chart(aa_df)

    st.subheader("Secondary Structure")
    ss_df = pd.DataFrame.from_dict(analysis_results['secondary_structure'],
                                   orient='index', columns=['Fraction'])
    st.bar_chart(ss_df)

    st.subheader("Parkinson's Risk Analysis")
    if parkinsons_analysis['risk_factors']:
        st.warning("⚠️ Potential Parkinson's risk factors detected:")
        for factor in parkinsons_analysis['risk_factors']:
            st.write(f"- {factor}")
    else:
        st.success("No obvious Parkinson's risk factors detected")

    if parkinsons_analysis['notes']:
        st.info("Additional notes:")
        for note in parkinsons_analysis['notes']:
            st.write(f"- {note}")


def _render_prediction(seq_input, tokenizer, model):
    """Classify *seq_input* and render the prediction and predicted structure."""
    classifier = st.session_state.classifier
    new_emb = get_protbert_embedding(seq_input, tokenizer, model).reshape(1, -1)
    prediction = classifier.predict(new_emb)
    proba = classifier.predict_proba(new_emb)

    st.subheader("Prediction Result")
    result_col, _explain_col = st.columns(2)
    with result_col:
        if prediction[0] == 1:
            st.error("**Prediction: Parkinson-related protein**")
        else:
            st.success("**Prediction: Not Parkinson-related**")
        st.write(f"Confidence: {max(proba[0])*100:.2f}%")

        proba_df = pd.DataFrame({
            "Class": ["Not Parkinson-related", "Parkinson-related"],
            "Probability": proba[0]
        })
        fig = px.bar(proba_df, x='Class', y='Probability', color='Class',
                     color_discrete_map={
                         "Not Parkinson-related": "green",
                         "Parkinson-related": "red"
                     })
        st.plotly_chart(fig, use_container_width=True)

    # with col2:
    #     # Show SHAP values if available
    #     try:
    #         explainer = shap.TreeExplainer(st.session_state.classifier)
    #         shap_values = explainer.shap_values(new_emb)
    #         fig, ax = plt.subplots()
    #         shap.summary_plot(shap_values, new_emb,
    #                           feature_names=[f"Feature {i}" for i in range(new_emb.shape[1])],
    #                           plot_type="bar",
    #                           show=False)
    #         st.pyplot(fig)
    #         plt.close()
    #     except Exception as e:
    #         st.warning(f"Could not generate SHAP explanation: {str(e)}")

    st.subheader("3D Protein Structure Prediction")
    try:
        pdb_data = fetch_structure(seq_input)
        cartoon_col, residue_col = st.columns(2)
        with cartoon_col:
            st.write("**Cartoon Representation**")
            view = display_structure(pdb_data, color="chain")
            st.components.v1.html(view._make_html(), height=500)
        with residue_col:
            st.write("**Residue Coloring**")
            view = display_structure(pdb_data, color="residue", show_sidechains=True)
            st.components.v1.html(view._make_html(), height=500)
        st.download_button(
            label="Download PDB File",
            data=pdb_data,
            file_name="predicted_structure.pdb",
            mime="chemical/x-pdb"
        )
    except Exception as e:
        st.error(f"Could not fetch protein structure: {str(e)}")


def _render_predict_tab(tokenizer, model):
    """Render the 'Predict New Sequence' tab."""
    st.header("Predict New Sequence")
    col1, col2 = st.columns(2)

    with col1:
        uploaded_file = st.file_uploader("Upload a FASTA file:", type=["fasta", "fa"])
        seq_input = st.text_area(
            "Or enter protein sequence manually:",
            value="MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVTTVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",
            height=200
        )
        if uploaded_file is not None:
            try:
                record = _read_fasta_upload(uploaded_file)
                if record is None:
                    st.error("Error reading FASTA file: no sequence records found")
                else:
                    # An uploaded sequence takes precedence over the text area.
                    seq_input = str(record.seq)
                    st.success(f"Sequence loaded from FASTA file: {record.id}")
            except Exception as e:
                st.error(f"Error reading FASTA file: {e}")

    with col2:
        if st.button("Analyze Sequence"):
            if not seq_input.strip():
                st.error("Please enter a protein sequence.")
            else:
                with st.spinner("Analyzing sequence..."):
                    try:
                        _render_sequence_analysis(seq_input)
                    except Exception as e:
                        st.error(f"Error analyzing sequence: {str(e)}")

    if st.button("Predict Parkinson's Association"):
        if st.session_state.classifier is None:
            st.error("Please train the model first using the 'Train Model' tab.")
        elif not seq_input.strip():
            st.error("Please enter a protein sequence.")
        else:
            with st.spinner("Generating embedding and making prediction..."):
                try:
                    _render_prediction(seq_input, tokenizer, model)
                except Exception as e:
                    st.error(f"Error processing sequence: {str(e)}")


def _render_explore_tab():
    """Render the 'Data Exploration' tab for the stored training data."""
    st.header("Data Exploration")
    if st.session_state.training_data is None:
        st.warning("Please train the model first to explore the data.")
        return

    df = st.session_state.training_data
    st.subheader("Training Data Overview")
    st.dataframe(df)

    st.subheader("Mutation Analysis")
    mutation_counts = df['mutation'].value_counts().reset_index()
    mutation_counts.columns = ['Mutation', 'Count']
    fig = px.bar(mutation_counts, x='Mutation', y='Count')
    st.plotly_chart(fig, use_container_width=True)

    st.subheader("Label Distribution by Mutation")
    fig = px.histogram(df, x='mutation', color='label', barmode='group',
                       color_discrete_map={0: 'green', 1: 'red'})
    st.plotly_chart(fig, use_container_width=True)

    st.subheader("Sequence Length Distribution")
    # Work on a derived frame so the DataFrame kept in session state is not
    # modified on every rerun.
    with_length = df.assign(length=df['sequence'].apply(len))
    fig = px.histogram(with_length, x='length', color='label',
                       color_discrete_map={0: 'green', 1: 'red'})
    st.plotly_chart(fig, use_container_width=True)


# Main App
def main():
    """Build the Streamlit page: intro, sidebar, model loading and the four tabs."""
    st.title("🧬 Parkinson's Disease Protein Sequence Classifier")
    st.markdown("""
    This app uses ProtBERT to generate protein sequence embeddings and a Random Forest classifier
    to predict whether a protein sequence is associated with Parkinson's disease.
    """)

    st.sidebar.header("About")
    st.sidebar.info("""
    This tool uses:
    - ProtBERT for protein sequence embeddings
    - Random Forest for classification
    - Sample dataset of known variants
    - 3D structure prediction via ESMFold API
    """)

    with st.spinner("Loading ProtBERT model..."):
        tokenizer, model = load_protbert()

    if 'classifier' not in st.session_state:
        st.session_state.classifier = None
        st.session_state.X_test = None
        st.session_state.y_test = None
        st.session_state.training_data = None

    tab1, tab2, tab3, tab4 = st.tabs(
        ["Train Model", "Evaluate Model", "Predict New Sequence", "Data Exploration"]
    )
    with tab1:
        _render_train_tab(tokenizer, model)
    with tab2:
        _render_evaluate_tab()
    with tab3:
        _render_predict_tab(tokenizer, model)
    with tab4:
        _render_explore_tab()


if __name__ == "__main__":
    main()