aksell commited on
Commit
8799d0b
·
1 Parent(s): 28de159

Allow selecting heads to plot

Browse files
hexviz/pages/Identify_interesting_Heads.py CHANGED
@@ -32,11 +32,18 @@ slice_end = st.sidebar.number_input(f"Section end(1-{l})",value=50, min_value=1,
32
  truncated_sequence = sequence[slice_start-1:slice_end]
33
 
34
 
 
 
 
 
 
 
 
35
  st.markdown(f"Each tile is a heatmap of attention for a section of {pdb_id}(A) from residue {slice_start} to {slice_end}. Adjust the section length and starting point in the sidebar.")
36
 
37
  # TODO: Decide if you should get attention for the full sequence or just the truncated sequence
38
  # Attention values will change depending on what we do.
39
  attention = get_attention(sequence=truncated_sequence, model_type=selected_model.name)
40
 
41
- fig = plot_tiled_heatmap(attention, layer_count=selected_model.layers, head_count=selected_model.heads)
42
  st.pyplot(fig)
 
32
  truncated_sequence = sequence[slice_start-1:slice_end]
33
 
34
 
35
+ layer_range = st.sidebar.slider("Heads", min_value=1, max_value=selected_model.layers, value=(1, selected_model.layers), step=1)
36
+ head_range = st.sidebar.slider("Layers", min_value=1, max_value=selected_model.heads, value=(1, selected_model.heads), step=1)
37
+ step_size = st.sidebar.number_input("Step size", value=2, min_value=1, max_value=selected_model.layers)
38
+ layer_sequence = list(range(layer_range[0]-1, layer_range[1], step_size))
39
+ head_sequence = list(range(head_range[0]-1, head_range[1], step_size))
40
+
41
+
42
  st.markdown(f"Each tile is a heatmap of attention for a section of {pdb_id}(A) from residue {slice_start} to {slice_end}. Adjust the section length and starting point in the sidebar.")
43
 
44
  # TODO: Decide if you should get attention for the full sequence or just the truncated sequence
45
  # Attention values will change depending on what we do.
46
  attention = get_attention(sequence=truncated_sequence, model_type=selected_model.name)
47
 
48
+ fig = plot_tiled_heatmap(attention, layer_sequence=layer_sequence, head_sequence=head_sequence)
49
  st.pyplot(fig)
hexviz/plot.py CHANGED
@@ -1,10 +1,12 @@
 
 
1
  import matplotlib.pyplot as plt
2
 
3
 
4
- def plot_tiled_heatmap(tensor, layer_count=12, head_count=12):
5
- tensor = tensor[:layer_count:2, :head_count:2, :, :] # Slice the tensor according to the provided arguments
6
- num_layers = layer_count // 2
7
- num_heads = head_count // 2
8
  fig, axes = plt.subplots(num_layers, num_heads, figsize=(12, 12))
9
  for i in range(num_layers):
10
  for j in range(num_heads):
@@ -13,11 +15,11 @@ def plot_tiled_heatmap(tensor, layer_count=12, head_count=12):
13
 
14
  # Enumerate the axes
15
  if i == 0:
16
- axes[i, j].set_title(f'Head {j * 2 + 1}', fontsize=10, y=1.05)
17
 
18
  # Add layer labels on the right Y-axis
19
  for i in range(num_layers):
20
- fig.text(0.98, (num_layers - i - 1) / num_layers + 0.025, f'Layer {i * 2 + 1}', fontsize=10, rotation=0, ha='right', va='center')
21
 
22
  plt.subplots_adjust(wspace=0.1, hspace=0.1)
23
- return fig
 
1
+ from typing import List
2
+
3
  import matplotlib.pyplot as plt
4
 
5
 
6
+ def plot_tiled_heatmap(tensor, layer_sequence: List[int], head_sequence: List[int]):
7
+ tensor = tensor[layer_sequence, :][:, head_sequence, :, :] # Slice the tensor according to the provided sequences and sequence_count
8
+ num_layers = len(layer_sequence)
9
+ num_heads = len(head_sequence)
10
  fig, axes = plt.subplots(num_layers, num_heads, figsize=(12, 12))
11
  for i in range(num_layers):
12
  for j in range(num_heads):
 
15
 
16
  # Enumerate the axes
17
  if i == 0:
18
+ axes[i, j].set_title(f'Head {head_sequence[j] + 1}', fontsize=10, y=1.05)
19
 
20
  # Add layer labels on the right Y-axis
21
  for i in range(num_layers):
22
+ fig.text(0.98, (num_layers - i - 1) / num_layers + 0.025, f'Layer {layer_sequence[i]+1}', fontsize=10, rotation=0, ha='right', va='center')
23
 
24
  plt.subplots_adjust(wspace=0.1, hspace=0.1)
25
+ return fig