Spaces:
Build error
Build error
import streamlit as st | |
import torch | |
import pandas as pd | |
import numpy as np | |
from transformers import BertTokenizer, BertModel | |
from sklearn.ensemble import RandomForestClassifier | |
from sklearn.model_selection import train_test_split | |
from sklearn.metrics import classification_report, confusion_matrix | |
import requests | |
import py3Dmol | |
from Bio import SeqIO | |
import io | |
from Bio.SeqUtils.ProtParam import ProteinAnalysis | |
import plotly.express as px | |
import plotly.graph_objects as go | |
from collections import Counter | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
# import shap | |
# Page-wide Streamlit settings. Streamlit's docs require set_page_config to
# run before any other st.* command in the script, so it sits right after the
# imports.
_PAGE_CONFIG = {
    "page_title": "Parkinson's Protein Classifier",
    "page_icon": "🧬",
    "layout": "wide",
}
st.set_page_config(**_PAGE_CONFIG)
@st.cache_resource
def load_protbert():
    """Load the ProtBERT tokenizer and encoder, cached across Streamlit reruns.

    Streamlit re-executes the whole script on every widget interaction. Without
    ``st.cache_resource`` the ``Rostlab/prot_bert`` weights would be re-loaded
    from disk each time a button is clicked. The cached pair is shared by all
    sessions, and the model is put in evaluation mode before it is returned.

    Returns:
        tuple: ``(tokenizer, model)`` where ``tokenizer`` is a case-sensitive
        ``BertTokenizer`` (``do_lower_case=False``) and ``model`` is the
        ``BertModel`` in ``eval()`` mode.
    """
    tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
    model = BertModel.from_pretrained("Rostlab/prot_bert")
    model.eval()
    return tokenizer, model
def get_protbert_embedding(sequence, tokenizer, model):
    """Return a mean-pooled ProtBERT embedding for one protein sequence.

    The input is normalised before tokenisation:

    * every whitespace character (spaces, tabs, ``\\r\\n`` from pasted FASTA
      text) is removed;
    * residues are upper-cased, because the tokenizer is loaded with
      ``do_lower_case=False`` and lower-case letters would otherwise not match
      the upper-case residue vocabulary;
    * the rare residues U, Z, O and B are mapped to X, as recommended on the
      Rostlab/prot_bert model card.

    ProtBERT expects residues separated by single spaces.

    Args:
        sequence: Amino-acid sequence as a string.
        tokenizer: Called as ``tokenizer(text, return_tensors='pt')``; must
            return keyword arguments for ``model``.
        model: Encoder whose output exposes ``last_hidden_state``.

    Returns:
        numpy.ndarray: ``last_hidden_state`` averaged over all token positions
        (including any special tokens the tokenizer inserts), squeezed.
    """
    residues = "".join(sequence.split()).upper()
    residues = residues.translate(str.maketrans("UZOB", "XXXX"))
    tokens = tokenizer(" ".join(residues), return_tensors="pt")
    with torch.no_grad():
        outputs = model(**tokens)
    # Mean over the token axis (dim=1) of the hidden states.
    embedding = torch.mean(outputs.last_hidden_state, dim=1)
    return embedding.squeeze().numpy()
def analyze_protein(sequence):
    """Compute physico-chemical properties plus heuristic Parkinson's flags.

    Args:
        sequence: Amino-acid sequence. Whitespace of any kind (spaces, tabs,
            ``\\r\\n`` line breaks from pasted FASTA text) is ignored and the
            residues are upper-cased before validation.

    Returns:
        On bad input, ``(message, None)`` where ``message`` is a string
        ("Empty amino acid sequence!" or "Invalid amino acid sequence!").
        Otherwise ``(results, parkinsons_analysis)``:

        * ``results`` holds the Biopython ``ProteinAnalysis`` metrics under the
          keys ``'basic'``, ``'aa_composition'`` (percentages),
          ``'secondary_structure'`` and ``'flexibility'``;
        * ``parkinsons_analysis`` is ``{'risk_factors': [...], 'notes': [...]}``.

        The flags compare the input with a 140-residue wild-type reference
        (alpha-synuclein) using fixed thresholds. They are heuristics, not a
        diagnosis.
    """
    sequence = "".join(sequence.split()).upper()
    if not sequence:
        # aa_percent below divides by the sequence length.
        return "Empty amino acid sequence!", None
    if not all(residue in "ACDEFGHIKLMNPQRSTVWY" for residue in sequence):
        return "Invalid amino acid sequence!", None

    analysis = ProteinAnalysis(sequence)
    length = len(sequence)
    mw = analysis.molecular_weight()
    aromaticity = analysis.aromaticity()
    instability = analysis.instability_index()
    gravy = analysis.gravy()
    aa_counts = analysis.count_amino_acids()
    aa_percent = {k: v / length * 100 for k, v in aa_counts.items()}
    # Biopython returns the fractions in (helix, turn, sheet) order.
    helix, turn, sheet = analysis.secondary_structure_fraction()
    pI = analysis.isoelectric_point()
    flexibility = analysis.flexibility()

    results = {
        'basic': {
            'Length': length,
            'Molecular Weight (Da)': mw,
            'Aromaticity': aromaticity,
            'Instability Index': instability,
            'GRAVY (Hydrophobicity)': gravy,
            'Isoelectric Point (pI)': pI,
        },
        'aa_composition': aa_percent,
        'secondary_structure': {
            'Helix': helix,
            'Turn': turn,
            'Sheet': sheet,
        },
        'flexibility': flexibility,
    }

    parkinsons_analysis = {
        'risk_factors': [],
        'notes': [],
    }
    risk_factors = parkinsons_analysis['risk_factors']

    if length != 140:
        risk_factors.append(f"Sequence length ({length}) deviates from wild-type (140)")
    if mw > 14660 or mw < 14400:
        risk_factors.append(f"Molecular weight ({mw:.2f} Da) differs from wild-type (14.46 kDa)")
    if aromaticity > 0.05:
        risk_factors.append("High aromaticity (potential aggregation risk)")
    if instability > 45:
        risk_factors.append(f"High instability index ({instability:.2f}) suggests toxic form")
    if gravy > -0.3:
        risk_factors.append(f"Hydrophobicity (GRAVY: {gravy:.3f}) suggests aggregation-prone variant")

    # 1-based position -> (wild-type residue, known pathogenic variant there).
    # Any residue other than the wild type at these sites is flagged.
    key_positions = {
        30: ('A', 'A30P (known pathogenic)'),
        46: ('E', 'E46K (known pathogenic)'),
        53: ('A', 'A53T (known pathogenic)'),
        83: ('E', 'E83Q (known pathogenic)'),
    }
    for position, (wild_type, variant) in key_positions.items():
        if position <= length and sequence[position - 1] != wild_type:
            observed = sequence[position - 1]
            risk_factors.append(
                f"Substitution {wild_type}{position}{observed} at the site of {variant}"
            )

    high_risk_aas = {
        'C': "Cysteine residues can promote aggregation",
        'G': "Glycine substitutions often pathogenic",
        'P': "Proline substitutions can disrupt structure",
    }
    for aa, risk in high_risk_aas.items():
        count = aa_counts.get(aa, 0)
        if count > 0:
            parkinsons_analysis['notes'].append(f"{risk} ({count} {aa} residues)")

    return results, parkinsons_analysis
def _apply_substitution(sequence, substitution):
    """Return ``sequence`` with one point substitution applied.

    ``substitution`` uses ``<wild-type residue><1-based position><new residue>``
    notation, e.g. ``"A53T"``. Raises ``ValueError`` if the position is out of
    range or the residue there is not the stated wild-type residue, so a typo
    in the variant table cannot silently produce a mislabelled sequence.
    """
    wild_type, new = substitution[0], substitution[-1]
    index = int(substitution[1:-1]) - 1
    if not 0 <= index < len(sequence) or sequence[index] != wild_type:
        raise ValueError(f"{substitution} does not match the reference sequence")
    return sequence[:index] + new + sequence[index + 1:]


def get_sample_data():
    """Return the small labelled variant dataset used to train the classifier.

    Every row is the 140-residue wild-type reference sequence with exactly one
    point substitution applied. The sequences are generated from the
    substitution codes instead of being pasted by hand, so each row is
    guaranteed to carry the mutation it is labelled with. No row is the
    unmodified wild type.

    Returns:
        pandas.DataFrame with columns:

        * ``sequence`` (str);
        * ``label`` (int): 1 = Parkinson-associated, 0 = not;
        * ``mutation`` (str): display name. The two unnamed control variants
          are shown as ``'None'`` and the unlabelled one as ``'Unknown'``.
    """
    wild_type = (
        "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAV"
        "VTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA"
    )
    # (substitution applied to the wild type, value for the 'mutation' column, label)
    variants = [
        ("A53T", "A53T", 1),
        ("T64V", "None", 0),
        ("E20I", "None", 0),
        ("Y136V", "Unknown", 1),
        ("K10G", "K10G", 1),
        ("F94G", "F94G", 1),
        ("Y125C", "Y125C", 1),
        ("A11C", "A11C", 0),
        ("A19S", "A19S", 0),
        ("Y39T", "Y39T", 1),
        ("M5A", "M5A", 1),
        ("Y133C", "Y133C", 1),
        ("Q99G", "Q99G", 1),
        ("Y136C", "Y136C", 1),
        ("Y39C", "Y39C", 1),
        ("M116A", "M116A", 1),
        ("K45G", "K45G", 1),
        ("K96G", "K96G", 1),
        ("M116G", "M116G", 1),
        ("Y125P", "Y125P", 1),
        ("M1G", "M1G", 1),
        ("F94S", "F94S", 1),
        ("Y133V", "Y133V", 1),
        ("Q24G", "Q24G", 1),
        ("E105G", "E105G", 1),
        ("Y136T", "Y136T", 1),
        ("E35G", "E35G", 1),
        ("Y133T", "Y133T", 1),
        ("K60G", "K60G", 1),
        ("F4G", "F4G", 1),
        ("E83F", "E83F", 0),
        ("N65K", "N65K", 0),
        ("A11S", "A11S", 0),
        ("N65E", "N65E", 0),
    ]
    return pd.DataFrame({
        'sequence': [_apply_substitution(wild_type, sub) for sub, _, _ in variants],
        'label': [label for _, _, label in variants],
        'mutation': [name for _, name, _ in variants],
    })
def train_classifier(X, y, test_size=0.2, random_state=42):
    """Fit a random forest on a (stratified when possible) train split.

    With a small, imbalanced dataset an unstratified split can leave the
    held-out set with a single class, which makes the confusion matrix and
    classification report degenerate. When every class has at least two
    members (the minimum ``train_test_split(stratify=...)`` accepts), the class
    proportions are therefore preserved in both splits. Otherwise the split
    falls back to a plain random split.

    Args:
        X: Feature matrix, one row per sample.
        y: Class labels aligned with ``X``.
        test_size: Fraction of samples held out (default 0.2).
        random_state: Seed for both the split and the forest (default 42).

    Returns:
        tuple: ``(clf, X_test, y_test)`` with the fitted
        ``RandomForestClassifier`` and the held-out data.
    """
    _, class_counts = np.unique(y, return_counts=True)
    can_stratify = len(class_counts) > 1 and class_counts.min() >= 2
    X_train, X_test, y_train, y_test = train_test_split(
        X, y,
        test_size=test_size,
        random_state=random_state,
        stratify=y if can_stratify else None,
    )
    clf = RandomForestClassifier(n_estimators=100, random_state=random_state)
    clf.fit(X_train, y_train)
    return clf, X_test, y_test
def fetch_structure(sequence):
    """Predict a 3D structure for ``sequence`` with the public ESMFold API.

    Args:
        sequence: Bare amino-acid string (no whitespace, no FASTA header).

    Returns:
        str: the PDB-format text returned by the API.

    Raises:
        RuntimeError: if the request fails or the API answers with a non-200
            status. The message starts with "Failed to fetch structure: ".
    """
    url = "https://api.esmatlas.com/foldSequence/v1/pdb/"
    headers = {"Content-Type": "text/plain"}
    try:
        response = requests.post(url, data=sequence, headers=headers, timeout=30)
    except requests.RequestException as e:
        raise RuntimeError(f"Failed to fetch structure: {str(e)}") from e
    if response.status_code != 200:
        raise RuntimeError(
            f"Failed to fetch structure: API returned status code {response.status_code}"
        )
    return response.text


def display_structure(pdb_data, color="chain", show_sidechains=True, show_mainchains=False):
    """Build a py3Dmol view of a PDB structure.

    Args:
        pdb_data: PDB-format text.
        color: Cartoon colouring: "rainbow" (spectrum), "chain", "residue",
            or any other value for plain white.
        show_sidechains: Add sticks for the inverted atom selection
            (atoms not named C, O or N).
        show_mainchains: Add sticks for atoms named C, O, N and CA.

    Returns:
        The configured and zoomed ``py3Dmol.view``.
    """
    view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
    view.addModel(pdb_data, 'pdb')
    cartoon_styles = {
        "rainbow": {'color': 'spectrum'},
        "chain": {'color': 'chain'},
        "residue": {'colorscheme': 'residue'},
    }
    view.setStyle({'cartoon': cartoon_styles.get(color, {'color': 'white'})})
    if show_sidechains:
        view.addStyle({'and': [{'atom': ['C', 'O', 'N'], 'invert': True}]},
                      {'stick': {'colorscheme': "WhiteCarbon", 'radius': 0.3}})
    if show_mainchains:
        view.addStyle({'atom': ['C', 'O', 'N', 'CA']},
                      {'stick': {'colorscheme': "WhiteCarbon", 'radius': 0.3}})
    view.zoomTo()
    return view


def _init_session_state():
    """Create the session-state slots the tabs read, once per session."""
    if 'classifier' not in st.session_state:
        st.session_state.classifier = None
        st.session_state.X_test = None
        st.session_state.y_test = None
        st.session_state.training_data = None


def _render_train_tab(tokenizer, model):
    """Render the training tab: embed the sample data and fit the classifier.

    On success the classifier, the held-out split and the training DataFrame
    are stored in ``st.session_state`` for the other tabs.
    """
    st.header("Train Classification Model")
    if not st.button("Train Model with Sample Data"):
        return

    with st.spinner("Training in progress..."):
        df = get_sample_data()
        sequences = df['sequence']
        total = len(sequences)
        # Placeholder for sequences ProtBERT fails on, sized to the encoder's
        # hidden size rather than a hard-coded 1024.
        fallback = np.zeros(model.config.hidden_size)
        embeddings = []
        progress_bar = st.progress(0)
        status_text = st.empty()
        for i, seq in enumerate(sequences):
            try:
                status_text.text(f"Processing sequence {i+1}/{total}...")
                progress_bar.progress((i + 1) / total)
                embeddings.append(get_protbert_embedding(seq, tokenizer, model))
            except Exception as e:
                st.warning(f"Error with sequence {i+1}: {str(e)}")
                embeddings.append(fallback)

        X = np.array(embeddings)
        y = df['label'].values
        clf, X_test, y_test = train_classifier(X, y)

    st.session_state.classifier = clf
    st.session_state.X_test = X_test
    st.session_state.y_test = y_test
    st.session_state.training_data = df
    st.success("Model trained successfully!")

    st.subheader("Sample Training Data")
    st.dataframe(df)

    st.subheader("Class Distribution")
    class_counts = df['label'].value_counts()
    fig = px.pie(values=class_counts, names=class_counts.index.map({0: 'Non-Parkinson', 1: 'Parkinson'}))
    st.plotly_chart(fig, use_container_width=True)


def _render_evaluate_tab():
    """Render metrics for the classifier stored in session state."""
    st.header("Evaluate Model Performance")
    clf = st.session_state.classifier
    if clf is None:
        st.warning("Please train the model first using the 'Train Model' tab.")
        return

    X_test = st.session_state.X_test
    y_test = st.session_state.y_test
    y_pred = clf.predict(X_test)

    # Fixed labels keep the report and matrix 2-class even when a small
    # held-out split (or its predictions) contains only one class; without
    # them confusion_matrix can return 1x1, which breaks the 2-label heatmap.
    st.subheader("Classification Report")
    report = classification_report(y_test, y_pred, labels=[0, 1],
                                   output_dict=True, zero_division=0)
    st.dataframe(pd.DataFrame(report).transpose())

    st.subheader("Confusion Matrix")
    cm = confusion_matrix(y_test, y_pred, labels=[0, 1])
    fig = px.imshow(cm,
                    labels=dict(x="Predicted", y="Actual", color="Count"),
                    x=['Non-Parkinson', 'Parkinson'],
                    y=['Non-Parkinson', 'Parkinson'],
                    text_auto=True)
    st.plotly_chart(fig, use_container_width=True)

    st.subheader("Feature Importance")
    try:
        importances = clf.feature_importances_
        top_n = 20
        indices = np.argsort(importances)[-top_n:]
        fig = go.Figure()
        fig.add_trace(go.Bar(
            y=[f"Feature {i}" for i in indices],
            x=importances[indices],
            orientation='h'
        ))
        fig.update_layout(title=f"Top {top_n} Important Features",
                          xaxis_title="Importance Score")
        st.plotly_chart(fig, use_container_width=True)
    except Exception as e:
        st.warning(f"Could not display feature importance: {str(e)}")


def _read_sequence_input():
    """Render the FASTA uploader and text area and return the sequence to use.

    An uploaded FASTA file takes precedence over the text area. Only its
    first record is used.
    """
    uploaded_file = st.file_uploader("Upload a FASTA file:", type=["fasta", "fa"])
    seq_input = st.text_area(
        "Or enter protein sequence manually:",
        value="MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVTTVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",
        height=200
    )
    if uploaded_file is not None:
        try:
            # getvalue() returns the whole upload regardless of read position.
            fasta_content = uploaded_file.getvalue().decode("utf-8")
            record = next(SeqIO.parse(io.StringIO(fasta_content), "fasta"), None)
            if record is None:
                st.error("Error reading FASTA file: no sequence records found")
            else:
                seq_input = str(record.seq)
                st.success(f"Sequence loaded from FASTA file: {record.id}")
        except Exception as e:
            st.error(f"Error reading FASTA file: {e}")
    return seq_input


def _show_sequence_analysis(seq_input):
    """Display the ``analyze_protein`` results for ``seq_input``."""
    analysis_results, parkinsons_analysis = analyze_protein(seq_input)
    if isinstance(analysis_results, str):
        # analyze_protein reports bad input as (message, None).
        st.error(analysis_results)
        return

    st.subheader("Basic Properties")
    st.table(pd.DataFrame.from_dict(analysis_results['basic'], orient='index'))

    st.subheader("Amino Acid Composition")
    aa_df = pd.DataFrame.from_dict(analysis_results['aa_composition'], orient='index', columns=['Percentage'])
    st.bar_chart(aa_df)

    st.subheader("Secondary Structure")
    ss_df = pd.DataFrame.from_dict(analysis_results['secondary_structure'], orient='index', columns=['Fraction'])
    st.bar_chart(ss_df)

    st.subheader("Parkinson's Risk Analysis")
    if parkinsons_analysis['risk_factors']:
        st.warning("⚠️ Potential Parkinson's risk factors detected:")
        for factor in parkinsons_analysis['risk_factors']:
            st.write(f"- {factor}")
    else:
        st.success("No obvious Parkinson's risk factors detected")
    if parkinsons_analysis['notes']:
        st.info("Additional notes:")
        for note in parkinsons_analysis['notes']:
            st.write(f"- {note}")


def _show_prediction(seq_input, tokenizer, model):
    """Embed ``seq_input`` and display the stored classifier's prediction."""
    clf = st.session_state.classifier
    new_emb = get_protbert_embedding(seq_input, tokenizer, model).reshape(1, -1)
    prediction = clf.predict(new_emb)
    proba = clf.predict_proba(new_emb)

    st.subheader("Prediction Result")
    # TODO: the right-hand column was reserved for a SHAP explanation; that
    # code was removed along with the commented-out `shap` import.
    result_col, _explanation_col = st.columns(2)
    with result_col:
        if prediction[0] == 1:
            st.error("**Prediction: Parkinson-related protein**")
        else:
            st.success("**Prediction: Not Parkinson-related**")
        st.write(f"Confidence: {max(proba[0])*100:.2f}%")

        # Name probability columns from clf.classes_ (the order predict_proba
        # uses) instead of assuming exactly [0, 1].
        class_names = {0: "Not Parkinson-related", 1: "Parkinson-related"}
        proba_df = pd.DataFrame({
            "Class": [class_names.get(c, str(c)) for c in clf.classes_],
            "Probability": proba[0]
        })
        fig = px.bar(proba_df, x='Class', y='Probability',
                     color='Class',
                     color_discrete_map={
                         "Not Parkinson-related": "green",
                         "Parkinson-related": "red"
                     })
        st.plotly_chart(fig, use_container_width=True)


def _show_structure(sequence):
    """Fetch an ESMFold prediction for ``sequence`` and render two 3D views."""
    st.subheader("3D Protein Structure Prediction")
    try:
        pdb_data = fetch_structure(sequence)
        cartoon_col, residue_col = st.columns(2)
        with cartoon_col:
            st.write("**Cartoon Representation**")
            view = display_structure(pdb_data, color="chain")
            st.components.v1.html(view._make_html(), height=500)
        with residue_col:
            st.write("**Residue Coloring**")
            view = display_structure(pdb_data, color="residue", show_sidechains=True)
            st.components.v1.html(view._make_html(), height=500)
        st.download_button(
            label="Download PDB File",
            data=pdb_data,
            file_name="predicted_structure.pdb",
            mime="chemical/x-pdb"
        )
    except Exception as e:
        st.error(f"Could not fetch protein structure: {str(e)}")


def _render_predict_tab(tokenizer, model):
    """Render the prediction tab: input, property analysis, prediction, structure."""
    st.header("Predict New Sequence")
    input_col, action_col = st.columns(2)
    with input_col:
        seq_input = _read_sequence_input()

    with action_col:
        if st.button("Analyze Sequence"):
            if not seq_input.strip():
                st.error("Please enter a protein sequence.")
            else:
                with st.spinner("Analyzing sequence..."):
                    try:
                        _show_sequence_analysis(seq_input)
                    except Exception as e:
                        st.error(f"Error analyzing sequence: {str(e)}")

        if st.button("Predict Parkinson's Association"):
            if st.session_state.classifier is None:
                st.error("Please train the model first using the 'Train Model' tab.")
            elif not seq_input.strip():
                st.error("Please enter a protein sequence.")
            else:
                with st.spinner("Generating embedding and making prediction..."):
                    try:
                        _show_prediction(seq_input, tokenizer, model)
                        # ESMFold receives the request body verbatim: strip the
                        # line breaks/spaces of pasted text and upper-case it.
                        _show_structure("".join(seq_input.split()).upper())
                    except Exception as e:
                        st.error(f"Error processing sequence: {str(e)}")


def _render_exploration_tab():
    """Render plots describing the training data stored in session state."""
    st.header("Data Exploration")
    if st.session_state.training_data is None:
        st.warning("Please train the model first to explore the data.")
        return

    # Work on a copy so the derived 'length' column below is not written back
    # into the DataFrame held in session state.
    df = st.session_state.training_data.copy()

    st.subheader("Training Data Overview")
    st.dataframe(df)

    st.subheader("Mutation Analysis")
    mutation_counts = df['mutation'].value_counts().reset_index()
    mutation_counts.columns = ['Mutation', 'Count']
    fig = px.bar(mutation_counts, x='Mutation', y='Count')
    st.plotly_chart(fig, use_container_width=True)

    st.subheader("Label Distribution by Mutation")
    fig = px.histogram(df, x='mutation', color='label',
                       barmode='group',
                       color_discrete_map={0: 'green', 1: 'red'})
    st.plotly_chart(fig, use_container_width=True)

    st.subheader("Sequence Length Distribution")
    df['length'] = df['sequence'].apply(len)
    fig = px.histogram(df, x='length', color='label',
                       color_discrete_map={0: 'green', 1: 'red'})
    st.plotly_chart(fig, use_container_width=True)


def main():
    """Streamlit entry point: page header, sidebar, model loading and the four tabs."""
    st.title("🧬 Parkinson's Disease Protein Sequence Classifier")
    st.markdown(
        "This app uses ProtBERT to generate protein sequence embeddings and a Random Forest classifier\n"
        "to predict whether a protein sequence is associated with Parkinson's disease."
    )

    st.sidebar.header("About")
    st.sidebar.info(
        "This tool uses:\n"
        "- ProtBERT for protein sequence embeddings\n"
        "- Random Forest for classification\n"
        "- Sample dataset of known variants\n"
        "- 3D structure prediction via ESMFold API"
    )

    with st.spinner("Loading ProtBERT model..."):
        tokenizer, model = load_protbert()

    _init_session_state()

    tab1, tab2, tab3, tab4 = st.tabs(["Train Model", "Evaluate Model", "Predict New Sequence", "Data Exploration"])
    with tab1:
        _render_train_tab(tokenizer, model)
    with tab2:
        _render_evaluate_tab()
    with tab3:
        _render_predict_tab(tokenizer, model)
    with tab4:
        _render_exploration_tab()


# `streamlit run <this file>` executes the script with __name__ == "__main__".
if __name__ == "__main__":
    main()