Spaces:
Sleeping
Sleeping
Label residues with highest sum of attention not highest pairs
Browse files- hexviz/attention.py +16 -2
- hexviz/🧬Attention_Visualization.py +13 -16
hexviz/attention.py
CHANGED
@@ -107,7 +107,7 @@ def unidirectional_avg_filtered(attention, layer, head, threshold):
|
|
107 |
return unidirectional_avg_for_head
|
108 |
|
109 |
@st.cache
|
110 |
-
def get_attention_pairs(pdb_code: str, layer: int, head: int, chain_ids: Optional[str] = None ,threshold: int = 0.2, model_type: ModelType = ModelType.TAPE_BERT):
|
111 |
structure = get_structure(pdb_code=pdb_code)
|
112 |
|
113 |
if chain_ids:
|
@@ -120,12 +120,26 @@ def get_attention_pairs(pdb_code: str, layer: int, head: int, chain_ids: Optiona
|
|
120 |
sequence = get_sequence(chain)
|
121 |
attention = get_attention(sequence=sequence, model_type=model_type)
|
122 |
attention_unidirectional = unidirectional_avg_filtered(attention, layer, head, threshold)
|
|
|
|
|
|
|
123 |
for attn_value, res_1, res_2 in attention_unidirectional:
|
124 |
try:
|
125 |
coord_1 = chain[res_1]["CA"].coord.tolist()
|
126 |
coord_2 = chain[res_2]["CA"].coord.tolist()
|
127 |
except KeyError:
|
128 |
continue
|
|
|
129 |
attention_pairs.append((attn_value, coord_1, coord_2, chain.id, res_1, res_2))
|
|
|
|
|
|
|
|
|
130 |
|
131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
return unidirectional_avg_for_head
|
108 |
|
109 |
@st.cache
|
110 |
+
def get_attention_pairs(pdb_code: str, layer: int, head: int, chain_ids: Optional[str] = None ,threshold: int = 0.2, model_type: ModelType = ModelType.TAPE_BERT, top_n: int = 2):
|
111 |
structure = get_structure(pdb_code=pdb_code)
|
112 |
|
113 |
if chain_ids:
|
|
|
120 |
sequence = get_sequence(chain)
|
121 |
attention = get_attention(sequence=sequence, model_type=model_type)
|
122 |
attention_unidirectional = unidirectional_avg_filtered(attention, layer, head, threshold)
|
123 |
+
|
124 |
+
# Store sum of attention into a residue (from the unidirectional attention)
|
125 |
+
residue_attention = {}
|
126 |
for attn_value, res_1, res_2 in attention_unidirectional:
|
127 |
try:
|
128 |
coord_1 = chain[res_1]["CA"].coord.tolist()
|
129 |
coord_2 = chain[res_2]["CA"].coord.tolist()
|
130 |
except KeyError:
|
131 |
continue
|
132 |
+
|
133 |
attention_pairs.append((attn_value, coord_1, coord_2, chain.id, res_1, res_2))
|
134 |
+
residue_attention[res_1] = residue_attention.get(res_1, 0) + attn_value
|
135 |
+
residue_attention[res_2] = residue_attention.get(res_2, 0) + attn_value
|
136 |
+
|
137 |
+
top_n_residues = sorted(residue_attention.items(), key=lambda x: x[1], reverse=True)[:top_n]
|
138 |
|
139 |
+
top_residues = []
|
140 |
+
for res, attn_sum in top_n_residues:
|
141 |
+
coord = chain[res]["CA"].coord.tolist()
|
142 |
+
top_residues.append((attn_sum, coord, chain.id, res))
|
143 |
+
|
144 |
+
return attention_pairs, top_residues
|
145 |
+
|
hexviz/🧬Attention_Visualization.py
CHANGED
@@ -31,8 +31,8 @@ st.sidebar.markdown(
|
|
31 |
---
|
32 |
""")
|
33 |
min_attn = st.sidebar.slider("Minimum attention", min_value=0.0, max_value=0.4, value=0.1)
|
34 |
-
|
35 |
-
label_highest = st.sidebar.checkbox("Label highest attention
|
36 |
# TODO add avg or max attention as params
|
37 |
|
38 |
|
@@ -64,10 +64,9 @@ if selected_model.name == ModelType.ZymCTRL:
|
|
64 |
if ec_class and selected_model.name == ModelType.ZymCTRL:
|
65 |
ec_class = st.sidebar.text_input("Enzyme classification number fetched from PDB", ec_class)
|
66 |
|
67 |
-
attention_pairs = get_attention_pairs(pdb_id, chain_ids=selected_chains, layer=layer, head=head, threshold=min_attn, model_type=selected_model.name)
|
68 |
|
69 |
sorted_by_attention = sorted(attention_pairs, key=lambda x: x[0], reverse=True)
|
70 |
-
top_n = sorted_by_attention[:n_pairs]
|
71 |
|
72 |
def get_3dview(pdb):
|
73 |
xyzview = py3Dmol.view(query=f"pdb:{pdb}")
|
@@ -93,11 +92,10 @@ def get_3dview(pdb):
|
|
93 |
{"backgroundColor": "lightgray","fontColor": "black","backgroundOpacity": 0.5})
|
94 |
|
95 |
if label_highest:
|
96 |
-
for _, _,
|
97 |
-
xyzview.addResLabels({"chain": chain,"resi":
|
98 |
-
{"backgroundColor": "lightgray","fontColor": "black","backgroundOpacity": 0.5})
|
99 |
-
|
100 |
-
{"backgroundColor": "lightgray","fontColor": "black","backgroundOpacity": 0.5})
|
101 |
return xyzview
|
102 |
|
103 |
xyzview = get_3dview(pdb_id)
|
@@ -106,21 +104,20 @@ showmol(xyzview, height=500, width=800)
|
|
106 |
st.markdown(f"""
|
107 |
Visualize attention weights from protein language models on protein structures.
|
108 |
Currently attention weights for PDB: [{pdb_id}](https://www.rcsb.org/structure/{pdb_id}) from layer: {layer_one}, head: {head_one} above {min_attn} from {selected_model.name.value}
|
109 |
-
are visualized as red bars. The
|
110 |
Visualize attention weights on protein structures for the protein language models TAPE-BERT and ZymCTRL.
|
111 |
Pick a PDB ID, layer and head to visualize attention.
|
112 |
""", unsafe_allow_html=True)
|
113 |
|
114 |
chain_dict = {f"{chain.id}": chain for chain in list(structure.get_chains())}
|
115 |
data = []
|
116 |
-
for att_weight, _ ,
|
117 |
-
|
118 |
-
|
119 |
-
el = (att_weight, f"{res1.resname:3}{res1.id[1]:0>3} - {res2.resname:3}{res2.id[1]:0>3} ({chain})")
|
120 |
data.append(el)
|
121 |
|
122 |
-
df = pd.DataFrame(data, columns=['
|
123 |
-
st.markdown(f"The {
|
124 |
st.table(df)
|
125 |
|
126 |
st.markdown("""Click into the [Identify Interesting heads](#Identify-Interesting-heads) page to get an overview of attention
|
|
|
31 |
---
|
32 |
""")
|
33 |
min_attn = st.sidebar.slider("Minimum attention", min_value=0.0, max_value=0.4, value=0.1)
|
34 |
+
n_highest_resis = st.sidebar.number_input("Num highest attention resis to label", value=2, min_value=1, max_value=100)
|
35 |
+
label_highest = st.sidebar.checkbox("Label highest attention residues", value=True)
|
36 |
# TODO add avg or max attention as params
|
37 |
|
38 |
|
|
|
64 |
if ec_class and selected_model.name == ModelType.ZymCTRL:
|
65 |
ec_class = st.sidebar.text_input("Enzyme classification number fetched from PDB", ec_class)
|
66 |
|
67 |
+
attention_pairs, top_residues = get_attention_pairs(pdb_id, chain_ids=selected_chains, layer=layer, head=head, threshold=min_attn, model_type=selected_model.name, top_n=n_highest_resis)
|
68 |
|
69 |
sorted_by_attention = sorted(attention_pairs, key=lambda x: x[0], reverse=True)
|
|
|
70 |
|
71 |
def get_3dview(pdb):
|
72 |
xyzview = py3Dmol.view(query=f"pdb:{pdb}")
|
|
|
92 |
{"backgroundColor": "lightgray","fontColor": "black","backgroundOpacity": 0.5})
|
93 |
|
94 |
if label_highest:
|
95 |
+
for _, _, chain, res in top_residues:
|
96 |
+
xyzview.addResLabels({"chain": chain, "resi": res},
|
97 |
+
{"backgroundColor": "lightgray", "fontColor": "black", "backgroundOpacity": 0.5})
|
98 |
+
|
|
|
99 |
return xyzview
|
100 |
|
101 |
xyzview = get_3dview(pdb_id)
|
|
|
104 |
st.markdown(f"""
|
105 |
Visualize attention weights from protein language models on protein structures.
|
106 |
Currently attention weights for PDB: [{pdb_id}](https://www.rcsb.org/structure/{pdb_id}) from layer: {layer_one}, head: {head_one} above {min_attn} from {selected_model.name.value}
|
107 |
+
are visualized as red bars. The {n_highest_resis} residues with the highest sum of attention are labeled.
|
108 |
Visualize attention weights on protein structures for the protein language models TAPE-BERT and ZymCTRL.
|
109 |
Pick a PDB ID, layer and head to visualize attention.
|
110 |
""", unsafe_allow_html=True)
|
111 |
|
112 |
chain_dict = {f"{chain.id}": chain for chain in list(structure.get_chains())}
|
113 |
data = []
|
114 |
+
for att_weight, _ , chain, resi in top_residues:
|
115 |
+
res = chain_dict[chain][resi]
|
116 |
+
el = (att_weight, f"{res.resname:3}{res.id[1]:0>3}")
|
|
|
117 |
data.append(el)
|
118 |
|
119 |
+
df = pd.DataFrame(data, columns=['Total attention (disregarding direction)', 'Residue'])
|
120 |
+
st.markdown(f"The {n_highest_resis} residues with the highest attention sum are labeled in the visualization and listed below:")
|
121 |
st.table(df)
|
122 |
|
123 |
st.markdown("""Click into the [Identify Interesting heads](#Identify-Interesting-heads) page to get an overview of attention
|