File size: 5,593 Bytes
2734263
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import streamlit as st
st.set_page_config(layout="wide")

import pandas as pd
import numpy as np
from zipfile import ZipFile

import plotly.express as px
import plotly.graph_objs as go

# Zip archive that load_LLR reads the per-protein score matrices from.
LLR_FILE = 'UniProtKB_human_VESM_llrs.zip'

# Table of selectable entries, indexed by the first CSV column.
df = pd.read_csv('UniProtKB_id_names.csv', index_col=0)
# Shuffle only when no shuffled copy is stored in the session state yet, then
# always work with the stored copy so every later run sees the same order.
if 'shuffled_df' not in st.session_state:
    st.session_state.shuffled_df = df.sample(frac=1)
df = st.session_state.shuffled_df
# ClinVar annotations used by plot_interactive for the optional overlay.
clinvar = pd.read_csv('clinvar_0325.csv.gz', index_col=0)

# Scalar logistic-regression parameters applied by load_LLR when calibrated
# scores are requested. The values are extracted inside the context manager so
# the underlying .npz archive is closed instead of being left open.
with np.load("logreg_params.npz") as f:
    coef, intercept = f["coef"].item(), f["intercept"].item()

def load_LLR(uniprot_id):
    '''Loads the LLRs for a given uniprot id. Returns a 20xL dataframe.

    Rows are indexed by AA change,
    (AAorder=['K','R','H','E','D','N','Q','T','S','C','G','A','V','L','I','M','P','Y','F','W'])
    Columns indexed by WT_AA+position e.g., "G 12".

    The matrix is read from ``<first archive entry>LLRs/<uniprot_id>.csv``
    inside the ``LLR_FILE`` zip archive. When the module-level ``sigmoid``
    flag is truthy, every value is passed through the logistic function
    ``1 / (1 + exp(-(LLR * coef + intercept)))`` and rounded to 4 decimals;
    otherwise the values are returned as stored.

    Raises ``KeyError`` (from ``ZipFile.open``) if the archive has no member
    for ``uniprot_id``.

    Usage: load_LLR('P01116') or load_LLR('P01116-2')
    '''
    with ZipFile(LLR_FILE) as myzip:
        member = myzip.namelist()[0] + 'LLRs/' + uniprot_id + '.csv'
        # Parse while the archive and the member are both still open, and let
        # the context managers close both handles; previously the member was
        # read after the archive had been closed and was never closed itself.
        with myzip.open(member) as data:
            LLR = pd.read_csv(data, index_col=0)
    if sigmoid:
        # Element-wise logistic transform; broadcasting keeps the 2-D shape,
        # so no ravel/reshape round-trip is needed.
        p = 1 / (1 + np.exp(-(LLR.values * coef + intercept)))
        LLR = pd.DataFrame(p, index=LLR.index, columns=LLR.columns).round(4)
    return LLR

def meltLLR(LLR, gene_prefix=None, ignore_pos=False):
    '''Reshape a score matrix into one row per variant.

    Each cell of ``LLR`` becomes a row indexed by the variant string built
    from the column label with its spaces removed followed by the row label,
    e.g. column "G 12" and row "D" give "G12D". Cells are emitted column by
    column.

    Parameters:
        LLR: DataFrame whose column labels are strings such as "G 12" and
            whose row labels are strings such as "D".
        gene_prefix: if not None, ``gene_prefix + '_'`` is prepended to every
            index label.
        ignore_pos: if False (default), add an integer ``pos`` column parsed
            from the characters between the first and last character of the
            variant string.

    Returns:
        DataFrame with index named ``variant`` and columns ``score`` and,
        unless ``ignore_pos`` is set, ``pos``.
    '''
    # Name the melted columns explicitly: without var_name, melt uses the
    # name of the columns axis when one is set, and the 'variable' lookup
    # below would fail.
    long_df = LLR.melt(ignore_index=False, var_name='variable', value_name='value')
    variants = [
        site.replace(' ', '') + alt
        for site, alt in zip(long_df['variable'], long_df.index)
    ]
    melted = pd.DataFrame(
        {'score': long_df['value'].to_numpy()},
        index=pd.Index(variants, name='variant'),
    )
    if not ignore_pos:
        # Position is parsed from the un-prefixed variant string.
        melted['pos'] = [int(v[1:-1]) for v in variants]
    if gene_prefix is not None:
        melted.index = gene_prefix + '_' + melted.index
    return melted
    
    
def plot_interactive(uniprot_id, show_clinvar=False):
    '''Build the interactive heatmap figure for one UniProt entry.

    The matrix returned by ``load_LLR(uniprot_id)`` is drawn with
    ``px.imshow``; colour range, colour scale and colour-bar label depend on
    the module-level ``sigmoid`` flag, and the title is the module-level
    ``selection``. When ``show_clinvar`` is true, variants listed in the
    module-level ``clinvar`` table for this entry (``GoldStars > 1``) are
    overlaid as markers: label 1.0 in red, label 0.0 in white.

    Returns the plotly figure.
    '''
    scores = load_LLR(uniprot_id)

    # Colour range upper/lower bound, colour scale and colour-bar label.
    z_high = 1.09 if sigmoid else 0
    z_low = 0 if sigmoid else -22
    palette = 'rdbu_r' if sigmoid else 'Viridis_r'
    colorbar_label = 'score' if sigmoid else 'LLR'

    axis_labels = dict(
        y="Amino acid change",
        x="Protein sequence",
        color=colorbar_label,
    )
    fig = px.imshow(
        scores.values,
        x=scores.columns,
        y=scores.index,
        color_continuous_scale=palette,
        zmax=z_high,
        zmin=z_low,
        labels=axis_labels,
        template='plotly_white',
        title=selection,
    )

    # Initial x window is [0, 99]; the range slider exposes the rest.
    fig.update_xaxes(
        tickangle=-90,
        range=[0, 99],
        rangeslider=dict(visible=True),
        dtick=1,
    )
    fig.update_yaxes(dtick=1)

    transparent = 'rgba(0, 0, 0, 0)'
    fig.update_layout(
        plot_bgcolor=transparent,
        paper_bgcolor=transparent,
        font={'family': 'Arial', 'size': 11},
        hoverlabel=dict(font=dict(family='Arial', size=14)),
    )
    fig.update_traces(
        hovertemplate="<b>%{x} %{y}</b> (%{z:.2f})<extra></extra>"
    )

    if not show_clinvar:
        return fig

    # ClinVar rows for this entry, restricted to GoldStars > 1.
    records = clinvar[clinvar.protein == uniprot_id]
    records = records[records.GoldStars > 1]
    benign_set = set(records[records.clinvar_label == 0.0].variant.values)
    pathogenic_set = set(records[records.clinvar_label == 1.0].variant.values)

    benign = {'x': [], 'y': [], 'score': []}
    pathogenic = {'x': [], 'y': [], 'score': []}

    for site in scores.columns:
        # Column label minus its second character, e.g. "G 12" -> "G12".
        site_code = site[0] + site[2:]
        for alt in scores.index:
            variant = site_code + alt
            # Benign membership is checked first, as a variant is only ever
            # placed in one bucket.
            if variant in benign_set:
                bucket = benign
            elif variant in pathogenic_set:
                bucket = pathogenic
            else:
                continue
            bucket['x'].append(site)
            bucket['y'].append(alt)
            bucket['score'].append(scores.loc[alt, site])

    # (points, marker colour, hover-label background); pathogenic is added
    # first, benign second.
    overlays = (
        (pathogenic, 'red', 'crimson'),
        (benign, 'white', 'white'),
    )
    for points, marker_color, label_bg in overlays:
        fig.add_trace(go.Scatter(
            x=points['x'],
            y=points['y'],
            customdata=points['score'],
            mode='markers',
            marker=dict(size=8, color=marker_color),
            showlegend=False,
            hoverlabel=dict(bgcolor=label_bg, font_color='black'),
            hovertemplate="<b>%{x} %{y}</b> (%{customdata:.2f})<extra></extra>",
        ))

    fig.update_layout(hovermode='closest', hoverdistance=10)

    return fig


# --- Page body: widgets, figure, download button and footer links. ---

# Preselect entry 'P32245' when it is present in the index; otherwise fall
# back to the first row of the (shuffled) table.
idx = df.index.get_loc('P32245') if 'P32245' in df.index else 0
# NOTE(review): the whole DataFrame is passed as the option list; the lookup
# on the next line assumes the displayed options are the values of the `txt`
# column -- confirm against the CSV layout and Streamlit's DataFrame handling.
selection = st.selectbox("uniprot_id:", df, index=idx)
# Index label of the first row whose `txt` equals the chosen option.
uid = df[df.txt == selection].index.values[0]

# Two side-by-side toggles. `sigmoid` is read as a module-level name by
# load_LLR and plot_interactive, so it must be assigned before they are called.
col1, col2 = st.columns(2)
with col1:
    sigmoid = st.checkbox(
        "Calibrated VESM predictions (0: benign, 1: pathogenic)",
        value=False
    )
with col2:
    show_clinvar = st.checkbox(
        "Show ClinVar annotations (red: pathogenic, white: benign)",
        value=False
    )

# `selection` (used as the figure title inside plot_interactive) and `sigmoid`
# are both defined above by this point.
fig = plot_interactive(uid, show_clinvar=show_clinvar)
fig.update_layout(width=800, height=600, autosize=False)
st.plotly_chart(fig, use_container_width=True)

# The matrix is loaded again for the export; because load_LLR reads `sigmoid`,
# the CSV holds calibrated scores whenever the first checkbox is ticked.
st.download_button(
    label="📥 Download as CSV",
    data=meltLLR(load_LLR(uid)).to_csv(),
    file_name=f"{selection}.csv",
    mime='text/csv'
)
# Horizontal rule separating the tool from the footer links.
st.markdown("---")

st.markdown("""
- Bulk download precomputed scores at [VESM Effect Scores](https://huggingface.co/datasets/ntranoslab/vesm_scores) for all UniProt, hg19, and hg38 variants.
- Use VESM locally: Access the source code and installation instructions on [GitHub](https://github.com/ntranoslab/vesm).
""")