aksell commited on
Commit
51f3451
·
1 Parent(s): a9afebb

Add text and slice selection to head selection view

Browse files
hexviz/pages/Identify_interesting_Heads.py CHANGED
@@ -5,6 +5,7 @@ from hexviz.models import Model, ModelType
5
  from hexviz.plot import plot_tiled_heatmap
6
 
7
  st.set_page_config(layout="wide")
 
8
 
9
 
10
  models = [
@@ -20,16 +21,22 @@ pdb_id = st.sidebar.text_input(
20
  value="1AKE",
21
  )
22
 
 
23
  structure = get_structure(pdb_id)
24
  chains = list(structure.get_chains())
25
 
26
  sequence = get_sequence(chains[0])
27
  l = len(sequence)
28
- n_residues = st.sidebar.number_input(f"Residue count (1-{l})",value=50, min_value=1, max_value=l)
29
- truncated_sequence = sequence[:n_residues]
 
 
30
 
 
 
 
 
31
  attention = get_attention(sequence=truncated_sequence, model_type=selected_model.name)
32
 
33
- st.subheader("Find interesting heads and layers")
34
  fig = plot_tiled_heatmap(attention, layer_count=selected_model.layers, head_count=selected_model.heads)
35
  st.pyplot(fig)
 
5
  from hexviz.plot import plot_tiled_heatmap
6
 
7
  st.set_page_config(layout="wide")
8
+ st.subheader("Find interesting heads and layers")
9
 
10
 
11
  models = [
 
21
  value="1AKE",
22
  )
23
 
24
+
25
  structure = get_structure(pdb_id)
26
  chains = list(structure.get_chains())
27
 
28
  sequence = get_sequence(chains[0])
29
  l = len(sequence)
30
+ slice_start= st.sidebar.number_input(f"Slice start(1-{l})",value=1, min_value=1, max_value=l)
31
+ slice_end = st.sidebar.number_input(f"Slice end(1-{l})",value=50, min_value=1, max_value=l)
32
+ truncated_sequence = sequence[slice_start-1:slice_end]
33
+
34
 
35
+ st.text(f"Each tile is a heatmap of attention for the slice {slice_start}:{slice_end} of {pdb_id} (chain A) for one attention head.")
36
+
37
+ # TODO: Decide if you should get attention for the full sequence or just the truncated sequence
38
+ # Attention values will change depending on what we do.
39
  attention = get_attention(sequence=truncated_sequence, model_type=selected_model.name)
40
 
 
41
  fig = plot_tiled_heatmap(attention, layer_count=selected_model.layers, head_count=selected_model.heads)
42
  st.pyplot(fig)