Spaces:
Sleeping
Sleeping
Allow selecting heads to plot
Browse files
hexviz/pages/Identify_interesting_Heads.py
CHANGED
@@ -32,11 +32,18 @@ slice_end = st.sidebar.number_input(f"Section end(1-{l})",value=50, min_value=1,
|
|
32 |
truncated_sequence = sequence[slice_start-1:slice_end]
|
33 |
|
34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
st.markdown(f"Each tile is a heatmap of attention for a section of {pdb_id}(A) from residue {slice_start} to {slice_end}. Adjust the section length and starting point in the sidebar.")
|
36 |
|
37 |
# TODO: Decide if you should get attention for the full sequence or just the truncated sequence
|
38 |
# Attention values will change depending on what we do.
|
39 |
attention = get_attention(sequence=truncated_sequence, model_type=selected_model.name)
|
40 |
|
41 |
-
fig = plot_tiled_heatmap(attention,
|
42 |
st.pyplot(fig)
|
|
|
32 |
truncated_sequence = sequence[slice_start-1:slice_end]
|
33 |
|
34 |
|
35 |
+
layer_range = st.sidebar.slider("Heads", min_value=1, max_value=selected_model.layers, value=(1, selected_model.layers), step=1)
|
36 |
+
head_range = st.sidebar.slider("Layers", min_value=1, max_value=selected_model.heads, value=(1, selected_model.heads), step=1)
|
37 |
+
step_size = st.sidebar.number_input("Step size", value=2, min_value=1, max_value=selected_model.layers)
|
38 |
+
layer_sequence = list(range(layer_range[0]-1, layer_range[1], step_size))
|
39 |
+
head_sequence = list(range(head_range[0]-1, head_range[1], step_size))
|
40 |
+
|
41 |
+
|
42 |
st.markdown(f"Each tile is a heatmap of attention for a section of {pdb_id}(A) from residue {slice_start} to {slice_end}. Adjust the section length and starting point in the sidebar.")
|
43 |
|
44 |
# TODO: Decide if you should get attention for the full sequence or just the truncated sequence
|
45 |
# Attention values will change depending on what we do.
|
46 |
attention = get_attention(sequence=truncated_sequence, model_type=selected_model.name)
|
47 |
|
48 |
+
fig = plot_tiled_heatmap(attention, layer_sequence=layer_sequence, head_sequence=head_sequence)
|
49 |
st.pyplot(fig)
|
hexviz/plot.py
CHANGED
@@ -1,10 +1,12 @@
|
|
|
|
|
|
1 |
import matplotlib.pyplot as plt
|
2 |
|
3 |
|
4 |
-
def plot_tiled_heatmap(tensor,
|
5 |
-
tensor = tensor[
|
6 |
-
num_layers =
|
7 |
-
num_heads =
|
8 |
fig, axes = plt.subplots(num_layers, num_heads, figsize=(12, 12))
|
9 |
for i in range(num_layers):
|
10 |
for j in range(num_heads):
|
@@ -13,11 +15,11 @@ def plot_tiled_heatmap(tensor, layer_count=12, head_count=12):
|
|
13 |
|
14 |
# Enumerate the axes
|
15 |
if i == 0:
|
16 |
-
axes[i, j].set_title(f'Head {j
|
17 |
|
18 |
# Add layer labels on the right Y-axis
|
19 |
for i in range(num_layers):
|
20 |
-
fig.text(0.98, (num_layers - i - 1) / num_layers + 0.025, f'Layer {i
|
21 |
|
22 |
plt.subplots_adjust(wspace=0.1, hspace=0.1)
|
23 |
-
return fig
|
|
|
1 |
+
from typing import List
|
2 |
+
|
3 |
import matplotlib.pyplot as plt
|
4 |
|
5 |
|
6 |
+
def plot_tiled_heatmap(tensor, layer_sequence: List[int], head_sequence: List[int]):
|
7 |
+
tensor = tensor[layer_sequence, :][:, head_sequence, :, :] # Slice the tensor according to the provided sequences and sequence_count
|
8 |
+
num_layers = len(layer_sequence)
|
9 |
+
num_heads = len(head_sequence)
|
10 |
fig, axes = plt.subplots(num_layers, num_heads, figsize=(12, 12))
|
11 |
for i in range(num_layers):
|
12 |
for j in range(num_heads):
|
|
|
15 |
|
16 |
# Enumerate the axes
|
17 |
if i == 0:
|
18 |
+
axes[i, j].set_title(f'Head {head_sequence[j] + 1}', fontsize=10, y=1.05)
|
19 |
|
20 |
# Add layer labels on the right Y-axis
|
21 |
for i in range(num_layers):
|
22 |
+
fig.text(0.98, (num_layers - i - 1) / num_layers + 0.025, f'Layer {layer_sequence[i]+1}', fontsize=10, rotation=0, ha='right', va='center')
|
23 |
|
24 |
plt.subplots_adjust(wspace=0.1, hspace=0.1)
|
25 |
+
return fig
|