# vesm-variants / app.py
# Uploaded by vasilisNt — commit 2734263 ("Upload 5 files", verified).
import streamlit as st
st.set_page_config(layout="wide")
import pandas as pd
import numpy as np
from zipfile import ZipFile
import plotly.express as px
import plotly.graph_objs as go
# Zip archive containing one CSV of VESM log-likelihood ratios per UniProt id.
LLR_FILE = 'UniProtKB_human_VESM_llrs.zip'

# Protein list (indexed by UniProt id), shuffled once per browser session and
# kept in session_state so the order is stable across reruns. The CSV is only
# read when the session copy does not exist yet, instead of on every rerun.
if 'shuffled_df' not in st.session_state:
    st.session_state.shuffled_df = pd.read_csv(
        'UniProtKB_id_names.csv', index_col=0
    ).sample(frac=1)
df = st.session_state.shuffled_df

# ClinVar annotations used for the optional overlay in plot_interactive.
clinvar = pd.read_csv('clinvar_0325.csv.gz', index_col=0)

# Logistic-regression calibration parameters (scalars) that map a raw LLR to a
# probability. The NpzFile holds an open file handle, so close it via `with`
# once the two scalars have been extracted.
with np.load("logreg_params.npz") as f:
    coef, intercept = f["coef"].item(), f["intercept"].item()
def load_LLR(uniprot_id, calibrated=None):
    '''Load the VESM scores for one UniProt id as a 20 x L DataFrame.

    Rows are indexed by the substituted amino acid
    (AAorder=['K','R','H','E','D','N','Q','T','S','C','G','A','V','L','I','M','P','Y','F','W']).
    Columns are indexed by WT_AA + position, e.g. "G 12".

    Args:
        uniprot_id: accession, optionally with an isoform suffix.
            Usage: load_LLR('P01116') or load_LLR('P01116-2')
        calibrated: if truthy, map each LLR through the logistic calibration
            1 / (1 + exp(-(LLR * coef + intercept))) using the module-level
            `coef` / `intercept`, then round to 4 decimals. None (default)
            falls back to the module-level `sigmoid` checkbox value, which
            reproduces the original behaviour.

    Returns:
        DataFrame of raw LLRs, or of calibrated probabilities when calibrated.

    Raises:
        KeyError: if the archive contains no CSV for `uniprot_id`.
    '''
    if calibrated is None:
        # `sigmoid` is the Streamlit checkbox created later in the script.
        calibrated = sigmoid
    with ZipFile(LLR_FILE) as myzip:
        # NOTE(review): assumes the archive's first entry is its top-level
        # folder, under which the per-protein CSVs live in 'LLRs/'.
        member = myzip.namelist()[0] + 'LLRs/' + uniprot_id + '.csv'
        # Close the member handle explicitly instead of leaving it dangling.
        with myzip.open(member) as data:
            LLR = pd.read_csv(data, index_col=0)
    if calibrated:
        # Elementwise logistic transform; index and columns are preserved.
        LLR = (1 / (1 + np.exp(-(LLR * coef + intercept)))).round(4)
    return LLR
def meltLLR(LLR, gene_prefix=None, ignore_pos=False):
    """Flatten a (substitution x position) score matrix to one row per variant.

    Each cell of `LLR` becomes a row named "<WT><pos><ALT>" (column label
    with its space removed, plus the row label). The row's value is stored
    in a 'score' column, and the index is named 'variant'.

    Args:
        LLR: matrix as returned by load_LLR (columns like "G 12").
        gene_prefix: if given, each variant name is prefixed with
            "<gene_prefix>_".
        ignore_pos: if False, also add an integer 'pos' column parsed from
            the variant name (the characters between WT and ALT).

    Returns:
        DataFrame with column 'score' (and 'pos' unless ignore_pos).
    """
    long_form = LLR.melt(ignore_index=False)
    # After melt(ignore_index=False) the index still holds the ALT residue.
    long_form['variant'] = [
        wt_pos.replace(' ', '') + alt
        for wt_pos, alt in zip(long_form['variable'], long_form.index)
    ]
    table = (
        long_form.drop(columns='variable')
        .rename(columns={'value': 'score'})
        .set_index('variant')
    )
    if not ignore_pos:
        table['pos'] = [int(name[1:-1]) for name in table.index]
    if gene_prefix is not None:
        table.index = gene_prefix + '_' + table.index
    return table
def _clinvar_markers(llr, benign, pathogenic):
    """Collect heatmap cells whose variant appears in a ClinVar set.

    Walks `llr` column by column (labels like "G 12") and row by row,
    building the variant name "G12<alt>". A name found in `benign` is
    reported as benign even if it is also in `pathogenic`.

    Returns:
        (benign_pts, pathogenic_pts), each a tuple of three lists:
        column labels, row labels and the matching cell values.
    """
    benign_pts, pathogenic_pts = ([], [], []), ([], [], [])
    for col in llr.columns:
        wt_pos = col[0] + col[2:]
        for alt in list(llr.index):
            name = wt_pos + alt
            if name in benign:
                bucket = benign_pts
            elif name in pathogenic:
                bucket = pathogenic_pts
            else:
                continue
            xs, ys, vals = bucket
            xs.append(col)
            ys.append(alt)
            vals.append(llr.loc[alt, col])
    return benign_pts, pathogenic_pts


def plot_interactive(uniprot_id, show_clinvar=False):
    """Build the interactive VESM heatmap for one protein.

    Uses the module-level `sigmoid` checkbox to choose between calibrated-score
    and raw-LLR colouring. Uses the module-level `selection` as the figure
    title and, for the overlay, the module-level `clinvar` table.

    Args:
        uniprot_id: accession passed to load_LLR.
        show_clinvar: overlay ClinVar variants with GoldStars > 1
            (red = pathogenic label 1.0, white = benign label 0.0).

    Returns:
        The plotly figure.
    """
    scores = load_LLR(uniprot_id)

    if sigmoid:
        zmin, zmax, cmap, color_label = 0, 1.09, 'rdbu_r', 'score'
    else:
        zmin, zmax, cmap, color_label = -22, 0, 'Viridis_r', 'LLR'

    fig = px.imshow(
        scores.values,
        x=scores.columns,
        y=scores.index,
        color_continuous_scale=cmap,
        zmax=zmax,
        zmin=zmin,
        labels=dict(y="Amino acid change", x="Protein sequence", color=color_label),
        template='plotly_white',
        title=selection,
    )
    # Initial x-range covers the first 100 columns; a range slider is shown.
    fig.update_xaxes(tickangle=-90, range=[0, 99], rangeslider=dict(visible=True), dtick=1)
    fig.update_yaxes(dtick=1)
    fig.update_layout(
        plot_bgcolor='rgba(0, 0, 0, 0)',
        paper_bgcolor='rgba(0, 0, 0, 0)',
        font={'family': 'Arial', 'size': 11},
        hoverlabel=dict(font=dict(family='Arial', size=14)),
    )
    # Applied before the marker traces exist, so it only affects the heatmap.
    fig.update_traces(hovertemplate="<b>%{x} %{y}</b> (%{z:.2f})<extra></extra>")

    if not show_clinvar:
        return fig

    reviewed = clinvar[clinvar.protein == uniprot_id]
    reviewed = reviewed[reviewed.GoldStars > 1]
    benign = set(reviewed[reviewed.clinvar_label == 0.0].variant.values)
    pathogenic = set(reviewed[reviewed.clinvar_label == 1.0].variant.values)
    benign_pts, pathogenic_pts = _clinvar_markers(scores, benign, pathogenic)

    marker_hover = "<b>%{x} %{y}</b> (%{customdata:.2f})<extra></extra>"
    # Trace order: pathogenic markers first, then benign.
    for (xs, ys, vals), dot_color, hover_bg in (
        (pathogenic_pts, 'red', 'crimson'),
        (benign_pts, 'white', 'white'),
    ):
        fig.add_trace(go.Scatter(
            x=xs,
            y=ys,
            customdata=vals,
            mode='markers',
            marker=dict(size=8, color=dot_color),
            showlegend=False,
            hoverlabel=dict(bgcolor=hover_bg, font_color='black'),
            hovertemplate=marker_hover,
        ))
    fig.update_layout(hovermode='closest', hoverdistance=10)
    return fig
# Preselect this accession when it is in the (shuffled) table, else row 0.
DEFAULT_UID = 'P32245'
idx = 0
if DEFAULT_UID in df.index:
    idx = df.index.get_loc(DEFAULT_UID)

selection = st.selectbox("uniprot_id:", df, index=idx)
# `selection` is compared against the `txt` column; map it back to the
# accession stored in the index (first match wins).
uid = df.loc[df.txt == selection].index[0]

# Two side-by-side toggles; `sigmoid` and `show_clinvar` are read elsewhere.
col1, col2 = st.columns(2)
sigmoid = col1.checkbox(
    "Calibrated VESM predictions (0: benign, 1: pathogenic)",
    value=False,
)
show_clinvar = col2.checkbox(
    "Show ClinVar annotations (red: pathogenic, white: benign)",
    value=False,
)

fig = plot_interactive(uid, show_clinvar=show_clinvar)
fig.update_layout(width=800, height=600, autosize=False)
st.plotly_chart(fig, use_container_width=True)

# Long-format table (one row per variant) of the scores currently displayed.
csv_payload = meltLLR(load_LLR(uid)).to_csv()
st.download_button(
    label="📥 Download as CSV",
    data=csv_payload,
    file_name=f"{selection}.csv",
    mime='text/csv',
)

st.markdown("---")
st.markdown("""
- Bulk download precomputed scores at [VESM Effect Scores](https://huggingface.co/datasets/ntranoslab/vesm_scores) for all UniProt, hg19, and hg38 variants.
- Use VESM locally: Access the source code and installation instructions on [GitHub](https://github.com/ntranoslab/vesm).
""")