aksell committed · Commit 42ad7ef · 1 Parent(s): 599e725

Label residues with highest sum of attention not highest pairs

hexviz/attention.py CHANGED
@@ -107,7 +107,7 @@ def unidirectional_avg_filtered(attention, layer, head, threshold):
     return unidirectional_avg_for_head
 
 @st.cache
-def get_attention_pairs(pdb_code: str, layer: int, head: int, chain_ids: Optional[str] = None, threshold: int = 0.2, model_type: ModelType = ModelType.TAPE_BERT):
+def get_attention_pairs(pdb_code: str, layer: int, head: int, chain_ids: Optional[str] = None, threshold: int = 0.2, model_type: ModelType = ModelType.TAPE_BERT, top_n: int = 2):
     structure = get_structure(pdb_code=pdb_code)
 
     if chain_ids:
@@ -120,12 +120,26 @@ def get_attention_pairs(pdb_code: str, layer: int, head: int, chain_ids: Optiona
         sequence = get_sequence(chain)
         attention = get_attention(sequence=sequence, model_type=model_type)
         attention_unidirectional = unidirectional_avg_filtered(attention, layer, head, threshold)
+
+        # Store the sum of attention into a residue (from the unidirectional attention)
+        residue_attention = {}
         for attn_value, res_1, res_2 in attention_unidirectional:
             try:
                 coord_1 = chain[res_1]["CA"].coord.tolist()
                 coord_2 = chain[res_2]["CA"].coord.tolist()
             except KeyError:
                 continue
+
             attention_pairs.append((attn_value, coord_1, coord_2, chain.id, res_1, res_2))
+            residue_attention[res_1] = residue_attention.get(res_1, 0) + attn_value
+            residue_attention[res_2] = residue_attention.get(res_2, 0) + attn_value
+
+    top_n_residues = sorted(residue_attention.items(), key=lambda x: x[1], reverse=True)[:top_n]
 
-    return attention_pairs
+    top_residues = []
+    for res, attn_sum in top_n_residues:
+        coord = chain[res]["CA"].coord.tolist()
+        top_residues.append((attn_sum, coord, chain.id, res))
+
+    return attention_pairs, top_residues
+
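The gist of the change: instead of ranking individual attention pairs, each pair's weight is now credited to both of its residues, and the residues with the largest totals are returned. A minimal standalone sketch of that selection logic; the attention values and residue indices below are made up for illustration:

```python
# Toy (attention, res_1, res_2) tuples standing in for the output of
# unidirectional_avg_filtered; the values are invented for this example.
attention_unidirectional = [(0.5, 5, 12), (0.25, 5, 40), (0.125, 12, 40)]
top_n = 2

# Credit each pair's weight to both residues, disregarding direction
residue_attention = {}
for attn_value, res_1, res_2 in attention_unidirectional:
    residue_attention[res_1] = residue_attention.get(res_1, 0) + attn_value
    residue_attention[res_2] = residue_attention.get(res_2, 0) + attn_value

# Rank residues by summed attention and keep the top_n
top_n_residues = sorted(residue_attention.items(), key=lambda x: x[1], reverse=True)[:top_n]
print(top_n_residues)  # [(5, 0.75), (12, 0.625)]
```

Note that a residue can rank highly here without belonging to any single high-attention pair, and vice versa, which is why the labels can change under the new scheme.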
hexviz/🧬Attention_Visualization.py CHANGED
@@ -31,8 +31,8 @@ st.sidebar.markdown(
     ---
     """)
 min_attn = st.sidebar.slider("Minimum attention", min_value=0.0, max_value=0.4, value=0.1)
-n_pairs = st.sidebar.number_input("Num attention pairs labeled", value=2, min_value=1, max_value=100)
-label_highest = st.sidebar.checkbox("Label highest attention pairs", value=True)
+n_highest_resis = st.sidebar.number_input("Num highest attention resis to label", value=2, min_value=1, max_value=100)
+label_highest = st.sidebar.checkbox("Label highest attention residues", value=True)
 # TODO add avg or max attention as params
 
 
@@ -64,10 +64,9 @@ if selected_model.name == ModelType.ZymCTRL:
 if ec_class and selected_model.name == ModelType.ZymCTRL:
     ec_class = st.sidebar.text_input("Enzyme classification number fetched from PDB", ec_class)
 
-attention_pairs = get_attention_pairs(pdb_id, chain_ids=selected_chains, layer=layer, head=head, threshold=min_attn, model_type=selected_model.name)
+attention_pairs, top_residues = get_attention_pairs(pdb_id, chain_ids=selected_chains, layer=layer, head=head, threshold=min_attn, model_type=selected_model.name, top_n=n_highest_resis)
 
 sorted_by_attention = sorted(attention_pairs, key=lambda x: x[0], reverse=True)
-top_n = sorted_by_attention[:n_pairs]
 
 def get_3dview(pdb):
     xyzview = py3Dmol.view(query=f"pdb:{pdb}")
@@ -93,11 +92,10 @@ def get_3dview(pdb):
             {"backgroundColor": "lightgray","fontColor": "black","backgroundOpacity": 0.5})
 
     if label_highest:
-        for _, _, _, chain, a, b in top_n:
-            xyzview.addResLabels({"chain": chain,"resi": a},
-                {"backgroundColor": "lightgray","fontColor": "black","backgroundOpacity": 0.5})
-            xyzview.addResLabels({"chain": chain,"resi": b},
-                {"backgroundColor": "lightgray","fontColor": "black","backgroundOpacity": 0.5})
+        for _, _, chain, res in top_residues:
+            xyzview.addResLabels({"chain": chain, "resi": res},
+                {"backgroundColor": "lightgray", "fontColor": "black", "backgroundOpacity": 0.5})
+
     return xyzview
 
 xyzview = get_3dview(pdb_id)
@@ -106,21 +104,20 @@ showmol(xyzview, height=500, width=800)
 st.markdown(f"""
 Visualize attention weights from protein language models on protein structures.
 Currently attention weights for PDB: [{pdb_id}](https://www.rcsb.org/structure/{pdb_id}) from layer: {layer_one}, head: {head_one} above {min_attn} from {selected_model.name.value}
-are visualized as red bars. The highest {n_pairs} attention pairs are labeled.
+are visualized as red bars. The {n_highest_resis} residues with the highest sum of attention are labeled.
 Visualize attention weights on protein structures for the protein language models TAPE-BERT and ZymCTRL.
 Pick a PDB ID, layer and head to visualize attention.
 """, unsafe_allow_html=True)
 
 chain_dict = {f"{chain.id}": chain for chain in list(structure.get_chains())}
 data = []
-for att_weight, _, _, chain, first, second in top_n:
-    res1 = chain_dict[chain][first]
-    res2 = chain_dict[chain][second]
-    el = (att_weight, f"{res1.resname:3}{res1.id[1]:0>3} - {res2.resname:3}{res2.id[1]:0>3} ({chain})")
+for att_weight, _, chain, resi in top_residues:
+    res = chain_dict[chain][resi]
+    el = (att_weight, f"{res.resname:3}{res.id[1]:0>3}")
     data.append(el)
 
-df = pd.DataFrame(data, columns=['Avg attention', 'Residue pair'])
-st.markdown(f"The {n_pairs} residue pairs with the highest average attention weights are labeled in the visualization and listed below:")
+df = pd.DataFrame(data, columns=['Total attention (disregarding direction)', 'Residue'])
+st.markdown(f"The {n_highest_resis} residues with the highest attention sum are labeled in the visualization and listed below:")
 st.table(df)
 
 st.markdown("""Click into the [Identify Interesting heads](#Identify-Interesting-heads) page to get an overview of attention