# Hugging Face Space page residue: author Yashwanth11187, commit 81b3d47 ("Update app.py", verified).
import streamlit as st
import torch
import pandas as pd
import numpy as np
from transformers import BertTokenizer, BertModel
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import requests
import py3Dmol
from Bio import SeqIO
import io
from Bio.SeqUtils.ProtParam import ProteinAnalysis
import plotly.express as px
import plotly.graph_objects as go
from collections import Counter
import matplotlib.pyplot as plt
import seaborn as sns
# import shap
# Browser-tab title and icon, plus the full-width ("wide") page layout.
st.set_page_config(page_title="Parkinson's Protein Classifier", page_icon="🧬", layout="wide")
# Load ProtBERT Model
@st.cache_resource
def load_protbert():
    """Return the ProtBERT ``(tokenizer, model)`` pair, ready for inference.

    Decorated with ``st.cache_resource`` so Streamlit reuses the loaded
    objects across reruns instead of loading the weights again each time.
    """
    checkpoint = "Rostlab/prot_bert"
    # ProtBERT expects upper-case residue letters, so case folding stays off.
    protbert_tokenizer = BertTokenizer.from_pretrained(checkpoint, do_lower_case=False)
    protbert_model = BertModel.from_pretrained(checkpoint)
    protbert_model.eval()  # evaluation mode: dropout disabled for stable embeddings
    return protbert_tokenizer, protbert_model
# Embedding Function
def get_protbert_embedding(sequence, tokenizer, model):
    """Embed a protein sequence as the mean of the encoder's last hidden layer.

    Args:
        sequence: Amino-acid sequence in one-letter code. All whitespace
            (spaces, newlines, tabs, carriage returns) is removed and letters
            are upper-cased before tokenization.
        tokenizer: Tokenizer from ``load_protbert``; called as
            ``tokenizer(text, return_tensors='pt')``.
        model: Encoder from ``load_protbert``; its output must expose
            ``last_hidden_state``.

    Returns:
        1-D numpy array with one value per hidden unit, averaged over every
        token position the tokenizer produced (for a BERT tokenizer this
        presumably includes the added [CLS]/[SEP] tokens).
    """
    # The tokenizer is loaded with do_lower_case=False, so a lower-case residue
    # would map to [UNK]; pasted/FASTA text can also contain \n, \t or \r.
    cleaned = "".join(sequence.split()).upper()
    # The ProtBERT model card maps the rare residues U, Z, O and B to X.
    cleaned = cleaned.translate(str.maketrans("UZOB", "XXXX"))
    # ProtBERT tokenizes one residue per whitespace-separated token.
    spaced = " ".join(cleaned)
    tokens = tokenizer(spaced, return_tensors='pt')
    with torch.no_grad():
        outputs = model(**tokens)
    # (1, tokens, hidden) -> (1, hidden): average over the token axis.
    embedding = outputs.last_hidden_state.mean(dim=1)
    # Drop only the batch axis, so the result is always 1-D.
    return embedding.squeeze(0).numpy()
# Protein Analysis Function
# One-letter codes accepted by analyze_protein (the 20 standard amino acids).
_STANDARD_AMINO_ACIDS = frozenset("ACDEFGHIKLMNPQRSTVWY")

# Known pathogenic substitutions checked by position:
# 1-based residue number -> (mutant residue, mutation name).
# Positions are only compared for full-length (140-residue) sequences, where
# residue numbering lines up with the wild-type reference used below.
_KNOWN_PATHOGENIC_SUBSTITUTIONS = {
    46: ('K', 'E46K'),
    53: ('T', 'A53T'),
    83: ('Q', 'E83Q'),
}

# Residues whose presence is reported as an informational note.
_HIGH_RISK_RESIDUES = {
    'C': "Cysteine residues can promote aggregation",
    'G': "Glycine substitutions often pathogenic",
    'P': "Proline substitutions can disrupt structure",
}


def analyze_protein(sequence):
    """Compute physicochemical properties of a protein and heuristic Parkinson's flags.

    Args:
        sequence: Amino-acid sequence in one-letter code. Whitespace of any
            kind is removed and letters are upper-cased before analysis.

    Returns:
        On success, ``(results, parkinsons_analysis)``:

        * ``results`` has the keys ``'basic'``, ``'aa_composition'`` (percent
          per residue), ``'secondary_structure'`` and ``'flexibility'``.
        * ``parkinsons_analysis`` has the list-valued keys ``'risk_factors'``
          and ``'notes'``.

        On empty or invalid input, returns ``(error_message, None)`` instead,
        so callers must check whether the first element is a string.
    """
    sequence = "".join(sequence.split()).upper()
    if not sequence:
        # The percentage computation below divides by the length.
        return "Empty amino acid sequence!", None
    if not set(sequence) <= _STANDARD_AMINO_ACIDS:
        return "Invalid amino acid sequence!", None

    analysis = ProteinAnalysis(sequence)
    length = len(sequence)
    mw = analysis.molecular_weight()
    aromaticity = analysis.aromaticity()
    instability = analysis.instability_index()
    gravy = analysis.gravy()
    aa_counts = analysis.count_amino_acids()
    # Biopython returns the fractions in (helix, turn, sheet) order.
    helix, turn, sheet = analysis.secondary_structure_fraction()

    results = {
        'basic': {
            'Length': length,
            'Molecular Weight (Da)': mw,
            'Aromaticity': aromaticity,
            'Instability Index': instability,
            'GRAVY (Hydrophobicity)': gravy,
            'Isoelectric Point (pI)': analysis.isoelectric_point(),
        },
        'aa_composition': {aa: count / length * 100 for aa, count in aa_counts.items()},
        'secondary_structure': {
            'Helix': helix,
            'Turn': turn,
            'Sheet': sheet,
        },
        'flexibility': analysis.flexibility(),
    }
    parkinsons_analysis = _assess_parkinsons_risk(
        sequence, mw, aromaticity, instability, gravy, aa_counts
    )
    return results, parkinsons_analysis


def _assess_parkinsons_risk(sequence, mw, aromaticity, instability, gravy, aa_counts):
    """Return ``{'risk_factors': [...], 'notes': [...]}`` for a validated sequence.

    Thresholds compare against the wild-type reference values quoted in the
    messages (length 140, ~14.46 kDa).
    """
    length = len(sequence)
    risk_factors = []
    notes = []

    if length != 140:
        risk_factors.append(f"Sequence length ({length}) deviates from wild-type (140)")
    if mw > 14660 or mw < 14400:
        risk_factors.append(f"Molecular weight ({mw:.2f} Da) differs from wild-type (14.46 kDa)")
    if aromaticity > 0.05:
        risk_factors.append("High aromaticity (potential aggregation risk)")
    if instability > 45:
        risk_factors.append(f"High instability index ({instability:.2f}) suggests toxic form")
    if gravy > -0.3:
        risk_factors.append(f"Hydrophobicity (GRAVY: {gravy:.3f}) suggests aggregation-prone variant")

    # Residue numbers only line up with the reference for full-length input.
    if length == 140:
        for position, (mutant, name) in _KNOWN_PATHOGENIC_SUBSTITUTIONS.items():
            if sequence[position - 1] == mutant:
                risk_factors.append(f"{name} substitution at position {position} (known pathogenic)")

    for aa, risk in _HIGH_RISK_RESIDUES.items():
        count = aa_counts.get(aa, 0)
        if count > 0:
            notes.append(f"{risk} ({count} {aa} residues)")

    return {'risk_factors': risk_factors, 'notes': notes}
def get_sample_data():
    """Return the built-in demo dataset of protein sequence variants.

    Returns:
        ``pandas.DataFrame`` with one row per variant and the columns
        ``'sequence'`` (one-letter amino-acid string), ``'label'``
        (1 = Parkinson-associated, 0 = not; shown in the UI as
        'Parkinson' / 'Non-Parkinson') and ``'mutation'`` (the variant name
        recorded by the author).
    """
    # Each record keeps sequence, label and mutation name together so the
    # three columns cannot drift out of alignment when rows are edited.
    records = [
        ("MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVTTVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA", 1, 'A53T'),
        ("MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVVNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA", 0, 'None'),
        ("MDVFMKGLSKAKEGVVAAAIKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA", 0, 'None'),
        ("MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDVEPEA", 1, 'Unknown'),
        ("MDVFMKGLSGAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA", 1, 'K10G'),
        ("MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGGVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA", 1, 'F94G'),
        ("MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEACEMPSEEGYQDYEPEA", 1, 'Y125C'),
        # NOTE(review): named 'A11C' but residue 11 here is H, not C -- fix the name or the sequence.
        ("MDVFMKGLSKHKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA", 0, 'A11C'),
        ("MDVFMKGLSKAKEGVVAASEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA", 0, 'A19S'),
        ("MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLTVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA", 1, 'Y39T'),
        # NOTE(review): named 'M5A' but residue 5 is still M; this sequence is identical to the 'F4G' row.
        ("MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA", 1, 'M5A'),
        ("MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGCQDYEPEA", 1, 'Y133C'),
        ("MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDGLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA", 1, 'Q99G'),
        ("MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDCEPEA", 1, 'Y136C'),
        ("MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLCVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA", 1, 'Y39C'),
        ("MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDAPVDPDNEAYEMPSEEGYQDYEPEA", 1, 'M116A'),
        ("MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTGEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA", 1, 'K45G'),
        ("MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVGKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA", 1, 'K96G'),
        ("MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDGPVDPDNEAYEMPSEEGYQDYEPEA", 1, 'M116G'),
        ("MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAPEMPSEEGYQDYEPEA", 1, 'Y125P'),
        ("GDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA", 1, 'M1G'),
        ("MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGSVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA", 1, 'F94S'),
        ("MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGVQDYEPEA", 1, 'Y133V'),
        ("MDVFMKGLSKAKEGVVAAAEKTKGGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA", 1, 'Q24G'),
        ("MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEGGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA", 1, 'E105G'),
        ("MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDTEPEA", 1, 'Y136T'),
        ("MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKGGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA", 1, 'E35G'),
        ("MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGTQDYEPEA", 1, 'Y133T'),
        ("MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTGEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA", 1, 'K60G'),
        # NOTE(review): named 'F4G' but residue 4 is still F; this sequence is identical to the 'M5A' row.
        ("MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA", 1, 'F4G'),
        ("MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVFGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA", 0, 'E83F'),
        ("MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTKVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA", 0, 'N65K'),
        ("MDVFMKGLSKSKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA", 0, 'A11S'),
        ("MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTEVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA", 0, 'N65E'),
    ]
    return pd.DataFrame(records, columns=['sequence', 'label', 'mutation'])
def train_classifier(X, y, test_size=0.2, random_state=42, n_estimators=100):
    """Split the data, fit a random forest, and return it with the held-out split.

    Args:
        X: Feature matrix, one row per sample (in this app: ProtBERT embeddings).
        y: Class labels aligned with the rows of ``X``.
        test_size: Fraction of samples held out for evaluation.
        random_state: Seed used for both the split and the forest.
        n_estimators: Number of trees in the forest.

    Returns:
        ``(clf, X_test, y_test)``: the fitted classifier and the held-out data.
    """
    try:
        # Stratify so every class appears in the (small) held-out set;
        # otherwise metrics may end up computed on a single class.
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, random_state=random_state, stratify=y
        )
    except ValueError:
        # Stratification is impossible when a class has fewer than two
        # members or the test set is smaller than the number of classes;
        # fall back to a plain random split.
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, random_state=random_state
        )
    clf = RandomForestClassifier(n_estimators=n_estimators, random_state=random_state)
    clf.fit(X_train, y_train)
    return clf, X_test, y_test
# Main App
# Default sequence pre-filled in the "Predict New Sequence" text area.
_DEFAULT_SEQUENCE = (
    "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVTTVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA"
)

# Class ids used by the classifier and their display names.
_CLASS_IDS = [0, 1]
_CLASS_NAMES = ['Non-Parkinson', 'Parkinson']


def _clean_sequence(raw):
    """Remove all whitespace from ``raw`` and upper-case it."""
    return "".join(raw.split()).upper()


def _fetch_structure(sequence):
    """POST ``sequence`` to the ESMFold API and return the predicted PDB text.

    Raises:
        Exception: on a network error or a non-200 response; the message
            starts with "Failed to fetch structure:".
    """
    url = "https://api.esmatlas.com/foldSequence/v1/pdb/"
    headers = {"Content-Type": "text/plain"}
    try:
        response = requests.post(url, data=sequence, headers=headers, timeout=30)
    except requests.RequestException as e:
        raise Exception(f"Failed to fetch structure: {str(e)}") from e
    if response.status_code != 200:
        raise Exception(
            f"Failed to fetch structure: API returned status code {response.status_code}"
        )
    return response.text


def _display_structure(pdb_data, color="chain", show_sidechains=True, show_mainchains=False):
    """Build a py3Dmol viewer for ``pdb_data`` with the requested cartoon colouring."""
    view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
    view.addModel(pdb_data, 'pdb')
    cartoon_styles = {
        "rainbow": {'color': 'spectrum'},
        "chain": {'color': 'chain'},
        "residue": {'colorscheme': 'residue'},
    }
    view.setStyle({'cartoon': cartoon_styles.get(color, {'color': 'white'})})
    if show_sidechains:
        # Every atom except the backbone C/O/N atoms.
        view.addStyle({'and': [{'atom': ['C', 'O', 'N'], 'invert': True}]},
                      {'stick': {'colorscheme': "WhiteCarbon", 'radius': 0.3}})
    if show_mainchains:
        view.addStyle({'atom': ['C', 'O', 'N', 'CA']},
                      {'stick': {'colorscheme': "WhiteCarbon", 'radius': 0.3}})
    view.zoomTo()
    return view


def _render_train_tab(tokenizer, model):
    """Embed the sample data, train the classifier and store it in session state."""
    st.header("Train Classification Model")
    if not st.button("Train Model with Sample Data"):
        return
    with st.spinner("Training in progress..."):
        df = get_sample_data()
        sequences = df['sequence'].tolist()
        total = len(sequences)
        # Placeholder vector for sequences that fail to embed; same width as the encoder output.
        hidden_size = model.config.hidden_size
        embeddings = []
        progress_bar = st.progress(0)
        status_text = st.empty()
        for i, seq in enumerate(sequences, start=1):
            status_text.text(f"Processing sequence {i}/{total}...")
            progress_bar.progress(i / total)
            try:
                embeddings.append(get_protbert_embedding(seq, tokenizer, model))
            except Exception as e:
                st.warning(f"Error with sequence {i}: {str(e)}")
                embeddings.append(np.zeros(hidden_size))
        X = np.array(embeddings)
        y = df['label'].values
        clf, X_test, y_test = train_classifier(X, y)
    st.session_state.classifier = clf
    st.session_state.X_test = X_test
    st.session_state.y_test = y_test
    st.session_state.training_data = df
    st.success("Model trained successfully!")
    st.subheader("Sample Training Data")
    st.dataframe(df)
    st.subheader("Class Distribution")
    class_counts = df['label'].value_counts()
    fig = px.pie(values=class_counts, names=class_counts.index.map({0: 'Non-Parkinson', 1: 'Parkinson'}))
    st.plotly_chart(fig, use_container_width=True)


def _render_evaluate_tab():
    """Show the report, confusion matrix and feature importances on the held-out split."""
    st.header("Evaluate Model Performance")
    clf = st.session_state.classifier
    if clf is None:
        st.warning("Please train the model first using the 'Train Model' tab.")
        return
    X_test = st.session_state.X_test
    y_test = st.session_state.y_test
    y_pred = clf.predict(X_test)

    st.subheader("Classification Report")
    # Pin labels so both classes are reported even if the small test split
    # (or the predictions) happen to contain only one of them.
    report = classification_report(y_test, y_pred, labels=_CLASS_IDS, target_names=_CLASS_NAMES,
                                   output_dict=True, zero_division=0)
    st.dataframe(pd.DataFrame(report).transpose())

    st.subheader("Confusion Matrix")
    # labels= keeps the matrix 2x2 so it matches the two axis labels below.
    cm = confusion_matrix(y_test, y_pred, labels=_CLASS_IDS)
    fig = px.imshow(cm,
                    labels=dict(x="Predicted", y="Actual", color="Count"),
                    x=_CLASS_NAMES,
                    y=_CLASS_NAMES,
                    text_auto=True)
    st.plotly_chart(fig, use_container_width=True)

    st.subheader("Feature Importance")
    try:
        importances = clf.feature_importances_
        top_n = 20
        indices = np.argsort(importances)[-top_n:]
        fig = go.Figure()
        fig.add_trace(go.Bar(
            y=[f"Feature {i}" for i in indices],
            x=importances[indices],
            orientation='h'
        ))
        fig.update_layout(title=f"Top {top_n} Important Features",
                          xaxis_title="Importance Score")
        st.plotly_chart(fig, use_container_width=True)
    except Exception as e:
        st.warning(f"Could not display feature importance: {str(e)}")


def _read_sequence_input():
    """Return the sequence from an uploaded FASTA file, or else from the text area."""
    uploaded_file = st.file_uploader("Upload a FASTA file:", type=["fasta", "fa"])
    seq_input = st.text_area(
        "Or enter protein sequence manually:",
        value=_DEFAULT_SEQUENCE,
        height=200
    )
    if uploaded_file is not None:
        try:
            # getvalue() returns the whole buffer regardless of the read position.
            fasta_content = uploaded_file.getvalue().decode("utf-8")
            record = next(SeqIO.parse(io.StringIO(fasta_content), "fasta"), None)
            if record is None:
                st.error("Error reading FASTA file: no sequence records found")
            else:
                seq_input = str(record.seq)
                st.success(f"Sequence loaded from FASTA file: {record.id}")
        except Exception as e:
            st.error(f"Error reading FASTA file: {e}")
    return seq_input


def _show_sequence_analysis(seq_input):
    """Run ``analyze_protein`` on ``seq_input`` and render its tables and risk notes."""
    if not seq_input.strip():
        st.error("Please enter a protein sequence.")
        return
    with st.spinner("Analyzing sequence..."):
        try:
            analysis_results, parkinsons_analysis = analyze_protein(seq_input)
            if isinstance(analysis_results, str):
                # analyze_protein reports bad input as (message, None).
                st.error(analysis_results)
                return
            st.subheader("Basic Properties")
            st.table(pd.DataFrame.from_dict(analysis_results['basic'], orient='index'))
            st.subheader("Amino Acid Composition")
            aa_df = pd.DataFrame.from_dict(analysis_results['aa_composition'], orient='index', columns=['Percentage'])
            st.bar_chart(aa_df)
            st.subheader("Secondary Structure")
            ss_df = pd.DataFrame.from_dict(analysis_results['secondary_structure'], orient='index', columns=['Fraction'])
            st.bar_chart(ss_df)
            st.subheader("Parkinson's Risk Analysis")
            if parkinsons_analysis['risk_factors']:
                st.warning("⚠️ Potential Parkinson's risk factors detected:")
                for factor in parkinsons_analysis['risk_factors']:
                    st.write(f"- {factor}")
            else:
                st.success("No obvious Parkinson's risk factors detected")
            if parkinsons_analysis['notes']:
                st.info("Additional notes:")
                for note in parkinsons_analysis['notes']:
                    st.write(f"- {note}")
        except Exception as e:
            st.error(f"Error analyzing sequence: {str(e)}")


def _show_structure(sequence):
    """Fetch an ESMFold structure for ``sequence`` and render two 3D views plus a download."""
    st.subheader("3D Protein Structure Prediction")
    try:
        pdb_data = _fetch_structure(sequence)
        view_col1, view_col2 = st.columns(2)
        with view_col1:
            st.write("**Cartoon Representation**")
            view = _display_structure(pdb_data, color="chain")
            # NOTE: _make_html is a private py3Dmol method; may change across versions.
            st.components.v1.html(view._make_html(), height=500)
        with view_col2:
            st.write("**Residue Coloring**")
            view = _display_structure(pdb_data, color="residue", show_sidechains=True)
            st.components.v1.html(view._make_html(), height=500)
        st.download_button(
            label="Download PDB File",
            data=pdb_data,
            file_name="predicted_structure.pdb",
            mime="chemical/x-pdb"
        )
    except Exception as e:
        st.error(f"Could not fetch protein structure: {str(e)}")


def _show_prediction(seq_input, tokenizer, model):
    """Classify ``seq_input`` with the trained model and render the result and structure."""
    clf = st.session_state.classifier
    if clf is None:
        st.error("Please train the model first using the 'Train Model' tab.")
        return
    # Normalise once so the embedding and the ESMFold request see the same sequence.
    sequence = _clean_sequence(seq_input)
    if not sequence:
        st.error("Please enter a protein sequence.")
        return
    with st.spinner("Generating embedding and making prediction..."):
        try:
            new_emb = get_protbert_embedding(sequence, tokenizer, model).reshape(1, -1)
            prediction = clf.predict(new_emb)
            proba = clf.predict_proba(new_emb)
            st.subheader("Prediction Result")
            if prediction[0] == 1:
                st.error("**Prediction: Parkinson-related protein**")
            else:
                st.success("**Prediction: Not Parkinson-related**")
            st.write(f"Confidence: {max(proba[0])*100:.2f}%")
            # Column order of predict_proba follows clf.classes_.
            display_names = {0: "Not Parkinson-related", 1: "Parkinson-related"}
            proba_df = pd.DataFrame({
                "Class": [display_names.get(c, str(c)) for c in clf.classes_],
                "Probability": proba[0]
            })
            fig = px.bar(proba_df, x='Class', y='Probability',
                         color='Class',
                         color_discrete_map={
                             "Not Parkinson-related": "green",
                             "Parkinson-related": "red"
                         })
            st.plotly_chart(fig, use_container_width=True)
            # TODO: SHAP explanations are disabled (the shap import is commented out).
            _show_structure(sequence)
        except Exception as e:
            st.error(f"Error processing sequence: {str(e)}")


def _render_predict_tab(tokenizer, model):
    """Input column on the left; analysis/prediction buttons and results on the right."""
    st.header("Predict New Sequence")
    col1, col2 = st.columns(2)
    with col1:
        seq_input = _read_sequence_input()
    with col2:
        if st.button("Analyze Sequence"):
            _show_sequence_analysis(seq_input)
        if st.button("Predict Parkinson's Association"):
            _show_prediction(seq_input, tokenizer, model)


def _render_exploration_tab():
    """Plot mutation counts, label distribution and sequence lengths of the training data."""
    st.header("Data Exploration")
    df = st.session_state.training_data
    if df is None:
        st.warning("Please train the model first to explore the data.")
        return
    st.subheader("Training Data Overview")
    st.dataframe(df)
    st.subheader("Mutation Analysis")
    mutation_counts = df['mutation'].value_counts().reset_index()
    mutation_counts.columns = ['Mutation', 'Count']
    fig = px.bar(mutation_counts, x='Mutation', y='Count')
    st.plotly_chart(fig, use_container_width=True)
    # Work on a copy so the DataFrame kept in session state is not mutated on
    # every rerun; string labels make plotly treat colour as discrete so that
    # color_discrete_map applies.
    plot_df = df.assign(length=df['sequence'].str.len(), label=df['label'].astype(str))
    st.subheader("Label Distribution by Mutation")
    fig = px.histogram(plot_df, x='mutation', color='label',
                       barmode='group',
                       color_discrete_map={'0': 'green', '1': 'red'})
    st.plotly_chart(fig, use_container_width=True)
    st.subheader("Sequence Length Distribution")
    fig = px.histogram(plot_df, x='length', color='label',
                       color_discrete_map={'0': 'green', '1': 'red'})
    st.plotly_chart(fig, use_container_width=True)


def main():
    """Streamlit entry point: load ProtBERT, initialise session state and render the tabs."""
    st.title("🧬 Parkinson's Disease Protein Sequence Classifier")
    st.markdown(
        "This app uses ProtBERT to generate protein sequence embeddings and a Random Forest classifier\n"
        "to predict whether a protein sequence is associated with Parkinson's disease."
    )
    st.sidebar.header("About")
    st.sidebar.info(
        "This tool uses:\n"
        "- ProtBERT for protein sequence embeddings\n"
        "- Random Forest for classification\n"
        "- Sample dataset of known variants\n"
        "- 3D structure prediction via ESMFold API"
    )
    with st.spinner("Loading ProtBERT model..."):
        tokenizer, model = load_protbert()
    # Session state survives reruns; create the slots once per browser session.
    for key in ("classifier", "X_test", "y_test", "training_data"):
        if key not in st.session_state:
            st.session_state[key] = None
    tab1, tab2, tab3, tab4 = st.tabs(["Train Model", "Evaluate Model", "Predict New Sequence", "Data Exploration"])
    with tab1:
        _render_train_tab(tokenizer, model)
    with tab2:
        _render_evaluate_tab()
    with tab3:
        _render_predict_tab(tokenizer, model)
    with tab4:
        _render_exploration_tab()
if __name__ == "__main__":
    # Script entry point: Streamlit executes this file as __main__ on each run.
    main()