Spaces:
Sleeping
Sleeping
import streamlit as st | |
import re | |
import numpy as np | |
import pandas as pd | |
import pickle | |
import sklearn | |
import catboost | |
import shap | |
from shap_plots import shap_summary_plot | |
from dynamic_shap_plots import matplotlib_to_plotly, summary_plot_plotly_fig | |
import plotly.tools as tls | |
from dash import dcc | |
import matplotlib.pyplot as plt | |
import plotly.graph_objs as go | |
try: | |
import matplotlib.pyplot as pl | |
from matplotlib.colors import LinearSegmentedColormap | |
from matplotlib.ticker import MaxNLocator | |
except ImportError: | |
pass | |
# Silence Streamlit's warning about calling st.pyplot() on the global figure.
# Recent Streamlit releases removed this config option and reject unknown keys
# with StreamlitAPIException, which would otherwise crash the app on load.
from streamlit.errors import StreamlitAPIException

try:
    st.set_option('deprecation.showPyplotGlobalUse', False)
except StreamlitAPIException:
    pass  # option no longer exists in this Streamlit version; nothing to silence

seed = 0

# Per-gene feature table: missing annotations become 0 and the gene symbol is
# the row index, so model inputs and outputs line up by gene.
annotations = pd.read_csv("all_genes_merged_ml_data.csv")
annotations.fillna(0, inplace=True)
annotations = annotations.set_index("Gene")

# Pre-fitted classifier.
# NOTE(review): pickle.load can execute arbitrary code; only load a trusted,
# locally produced model file here.
model_path = "best_model_fitted.pkl"
with open(model_path, 'rb') as file:
    catboost_model = pickle.load(file)

# Score every gene. The three column labels assume the model's class order is
# (most likely, probable, least likely) -- TODO confirm against catboost_model.classes_.
# NOTE(review): Streamlit re-executes this script on every widget interaction, so
# the CSV load, unpickling and prediction all rerun each time; consider
# st.cache_data / st.cache_resource if the installed Streamlit version has them.
probabilities = catboost_model.predict_proba(annotations)
prob_df = pd.DataFrame(
    probabilities,
    index=annotations.index,
    columns=['Probability_Most_Likely', 'Probability_Probable', 'Probability_Least_Likely'],
)
# One table: the three class probabilities followed by every input feature.
df_total = pd.concat([prob_df, annotations], axis=1)
# Sidebar page selector: a radio button whose value decides which page below renders.
with st.sidebar:
    st.title(body="Navigation")
    tab = st.radio(
        label="Go to",
        options=("Gene Prioritisation", "Interactive SHAP Plot", "Supervised SHAP Clustering"),
    )

# Main-area header shown above whichever page is selected.
st.title(body='Blood Pressure Gene Prioritisation Post-GWAS')
st.markdown(body="""A machine learning pipeline for predicting disease-causing genes post-genome-wide association study in blood pressure.""")
def collect_genes(x):
    """Split free-text gene input into a list of gene symbols.

    Commas and runs of whitespace both act as separators, so "AGT, ACE  NOS3"
    and "AGT,ACE,NOS3" both give ['AGT', 'ACE', 'NOS3']. Empty tokens produced
    by leading, trailing or doubled separators are dropped.

    Args:
        x: Raw text from the input box; None or "" yields an empty list.

    Returns:
        The non-empty tokens, in input order.
    """
    if not x:
        return []
    # Raw string: a plain "\s" is an invalid escape sequence warning in Python.
    return [token for token in re.split(r"[,\s]+", x) if token]
# Free-text box for several genes at once; collect_genes turns it into symbols.
input_gene_list = st.text_input(label="Input a list of multiple HGNC genes (enter comma separated):")
gene_list = collect_genes(input_gene_list)

# SHAP tree explainer for the fitted classifier, shared by the plots on every page.
explainer = shap.TreeExplainer(model=catboost_model)
def convert_df(df):
    """Serialise *df* to UTF-8 encoded CSV bytes, without the index.

    The bytes are what the download buttons hand to the browser.
    """
    csv_text = df.to_csv(index=False)
    return csv_text.encode(encoding="utf-8")
# Names of the three class-probability columns that sit at the front of df_total.
probability_columns = ['Probability_Most_Likely', 'Probability_Probable', 'Probability_Least_Likely']

# Model input features: every other df_total column, in its existing order.
is_probability = df_total.columns.isin(probability_columns)
features_list = df_total.columns[~is_probability].tolist()
features = df_total[features_list]
# Page rendering. `tab` is the sidebar radio value; only the selected page runs.

# Display names for the model's three classes, in the same order as the
# probability columns built from predict_proba.
class_names = ["Most likely", "Probable", "Least likely"]


def _per_class_shap(values):
    """Return SHAP values as a list holding one (n_samples, n_features) array per class.

    Depending on the installed shap release, ``explainer.shap_values`` returns
    either a list of per-class arrays or one array with the class on the last
    axis. Normalising keeps ``result[k]`` meaning "class k" in both cases.
    """
    if isinstance(values, list):
        return values
    values = np.asarray(values)
    if values.ndim == 3:
        return [values[:, :, k] for k in range(values.shape[2])]
    return [values]


def _render_pyplot_figure(fig=None, **pyplot_kwargs):
    """Render *fig* (default: pyplot's current figure) in Streamlit, then close it.

    shap's matplotlib plots draw on the current figure. Passing the figure
    explicitly avoids Streamlit's deprecated global-figure mode, and closing it
    keeps figures from accumulating across reruns.
    Extra keyword arguments go to ``st.pyplot`` (e.g. ``bbox_inches``).
    """
    if fig is None:
        fig = plt.gcf()
    st.pyplot(fig, **pyplot_kwargs)
    plt.close(fig)


def _gene_table(frame):
    """Return gene-indexed *frame* with the index moved into a leading 'Gene' column.

    The remaining columns keep frame's order (probabilities, then features).
    """
    return frame.rename_axis('Gene').reset_index()


def _show_gene_list_table(genes):
    """Show prioritisation rows for *genes* and offer them as a CSV download.

    Warns about genes missing from df_total. Returns the feature-only frame
    (probability and 'Gene' columns dropped) for the genes found, or None when
    none were found so the caller can skip the SHAP plots.
    """
    # .copy() semantics via reset_index: no chained assignment on a df_total slice.
    df = _gene_table(df_total[df_total.index.isin(genes)])
    found = set(df['Gene'])
    missing = [gene for gene in dict.fromkeys(genes) if gene not in found]
    if missing:
        st.warning("Genes not found in the dataset: " + ", ".join(missing))
    if df.empty:
        return None  # nothing to display, explain or plot
    st.dataframe(df)
    csv = convert_df(df[['Gene'] + probability_columns])
    st.download_button("Download Gene Prioritisation", csv, "bp_gene_prioritisation.csv", "text/csv", key='download-csv')
    return df.drop(columns=probability_columns + ['Gene'])


def _show_single_gene(raw_gene):
    """Show the row and one SHAP force plot per class for a single gene symbol."""
    gene = raw_gene.strip()
    # Validate before the lookup: a stray separator would otherwise only ever
    # surface as a misleading "not found" message.
    if re.search(r"[\s,]", gene):
        st.write('Input Error: Please input only a single HGNC gene name with no white spaces or commas.')
        return
    df2 = df_total[df_total.index == gene]
    if df2.empty:
        st.write("Gene not found in the dataset.")
        return
    st.dataframe(_gene_table(df2))
    df2_shap = df2.drop(columns=probability_columns)
    shap_values = _per_class_shap(explainer.shap_values(df2_shap))
    for class_index, class_name in enumerate(class_names):
        st.subheader(f"Force Plot for {class_name} Prediction")
        force_plot = shap.force_plot(
            explainer.expected_value[class_index],
            shap_values[class_index],
            df2_shap,
            matplotlib=True,
            show=False,
        )
        _render_pyplot_figure(force_plot)


# Page 1: Gene Prioritisation
if tab == "Gene Prioritisation":
    # Any non-empty gene list is shown, including a single gene.
    if gene_list:
        df_shap = _show_gene_list_table(gene_list)
        if df_shap is not None:
            shap_values = _per_class_shap(explainer.shap_values(df_shap))
            # 2x2 grid: global bar summary, then one summary plot per class.
            col1, col2 = st.columns(2)
            col3, col4 = st.columns(2)
            with col1:
                st.subheader("Global SHAP Summary Plot")
                shap.summary_plot(shap_values, df_shap, plot_type="bar", class_names=class_names, show=False)
                _render_pyplot_figure(bbox_inches='tight')
            for class_index, column in enumerate((col2, col3, col4)):
                with column:
                    st.subheader(f"{class_names[class_index]} Gene Prediction")
                    shap.summary_plot(shap_values[class_index], df_shap, show=False)
                    _render_pyplot_figure(bbox_inches='tight')

    input_gene = st.text_input("Input an individual HGNC gene:")
    if input_gene:
        _show_single_gene(input_gene)

    st.markdown("### Total Gene Prioritisation Results for All Genes:")
    # assign() returns a new frame, so df_total itself is not mutated.
    df_total_output = df_total.assign(Gene=df_total.index)
    st.dataframe(df_total_output)
    csv = convert_df(df_total_output)
    st.download_button("Download Gene Prioritisation", csv, "all_genes_bp_prioritisation.csv", "text/csv", key='download-all-csv')

# Page 2: Interactive SHAP Plot
elif tab == "Interactive SHAP Plot":
    st.title("Interactive SHAP Plot")
    if gene_list:
        df_shap = _show_gene_list_table(gene_list)
        if df_shap is not None:
            shap_values = _per_class_shap(explainer.shap_values(df_shap))
            summary_plot = summary_plot_plotly_fig(df_shap, shap_values[0], max_display=10)
            # summary_plot_plotly_fig lives in dynamic_shap_plots (not visible here);
            # its name suggests a Plotly figure, which st.pyplot cannot render,
            # so dispatch on the type actually returned.
            if isinstance(summary_plot, go.Figure):
                st.plotly_chart(summary_plot)
            else:
                _render_pyplot_figure(summary_plot)
            st.caption("SHAP Summary Plot of All Input Genes")

# Page 3: Supervised SHAP Clustering
elif tab == "Supervised SHAP Clustering":
    st.title("Supervised SHAP Clustering")
    # TODO: implement supervised SHAP clustering.