Spaces:
Sleeping
Sleeping
Fix subplot sizing and bug
Browse files
hexviz/Attention_Visualization.py
CHANGED
@@ -110,7 +110,6 @@ def get_3dview(pdb):
|
|
110 |
{"backgroundColor": "lightgray","fontColor": "black","backgroundOpacity": 0.5})
|
111 |
return xyzview
|
112 |
|
113 |
-
|
114 |
xyzview = get_3dview(pdb_id)
|
115 |
showmol(xyzview, height=500, width=800)
|
116 |
st.markdown(f'PDB: [{pdb_id}](https://www.rcsb.org/structure/{pdb_id})', unsafe_allow_html=True)
|
|
|
110 |
{"backgroundColor": "lightgray","fontColor": "black","backgroundOpacity": 0.5})
|
111 |
return xyzview
|
112 |
|
|
|
113 |
xyzview = get_3dview(pdb_id)
|
114 |
showmol(xyzview, height=500, width=800)
|
115 |
st.markdown(f'PDB: [{pdb_id}](https://www.rcsb.org/structure/{pdb_id})', unsafe_allow_html=True)
|
hexviz/pages/Identify_interesting_Heads.py
CHANGED
@@ -27,14 +27,15 @@ chains = list(structure.get_chains())
|
|
27 |
|
28 |
sequence = get_sequence(chains[0])
|
29 |
l = len(sequence)
|
30 |
-
|
31 |
-
slice_end = st.sidebar.
|
|
|
|
|
32 |
truncated_sequence = sequence[slice_start-1:slice_end]
|
33 |
|
34 |
-
|
35 |
-
layer_range = st.sidebar.slider("
|
36 |
-
|
37 |
-
step_size = st.sidebar.number_input("Step size", value=2, min_value=1, max_value=selected_model.layers)
|
38 |
layer_sequence = list(range(layer_range[0]-1, layer_range[1], step_size))
|
39 |
head_sequence = list(range(head_range[0]-1, head_range[1], step_size))
|
40 |
|
|
|
27 |
|
28 |
sequence = get_sequence(chains[0])
|
29 |
l = len(sequence)
|
30 |
+
st.sidebar.markdown("Sequence segment to plot")
|
31 |
+
slice_start, slice_end = st.sidebar.slider("Sequence", min_value=1, max_value=l, value=(1, 50), step=1)
|
32 |
+
# slice_start= st.sidebar.number_input(f"Section start(1-{l})",value=1, min_value=1, max_value=l)
|
33 |
+
# slice_end = st.sidebar.number_input(f"Section end(1-{l})",value=50, min_value=1, max_value=l)
|
34 |
truncated_sequence = sequence[slice_start-1:slice_end]
|
35 |
|
36 |
+
head_range = st.sidebar.slider("Heads to plot", min_value=1, max_value=selected_model.heads, value=(1, selected_model.heads), step=1)
|
37 |
+
layer_range = st.sidebar.slider("Layers to plot", min_value=1, max_value=selected_model.layers, value=(1, selected_model.layers), step=1)
|
38 |
+
step_size = st.sidebar.number_input("Optional step size to skip heads", value=2, min_value=1, max_value=selected_model.layers)
|
|
|
39 |
layer_sequence = list(range(layer_range[0]-1, layer_range[1], step_size))
|
40 |
head_sequence = list(range(head_range[0]-1, head_range[1], step_size))
|
41 |
|
hexviz/plot.py
CHANGED
@@ -7,10 +7,13 @@ def plot_tiled_heatmap(tensor, layer_sequence: List[int], head_sequence: List[in
|
|
7 |
tensor = tensor[layer_sequence, :][:, head_sequence, :, :] # Slice the tensor according to the provided sequences and sequence_count
|
8 |
num_layers = len(layer_sequence)
|
9 |
num_heads = len(head_sequence)
|
10 |
-
|
|
|
|
|
|
|
11 |
for i in range(num_layers):
|
12 |
for j in range(num_heads):
|
13 |
-
axes[i, j].imshow(tensor[i, j].detach().numpy(), cmap='viridis', aspect='
|
14 |
axes[i, j].axis('off')
|
15 |
|
16 |
# Enumerate the axes
|
|
|
7 |
tensor = tensor[layer_sequence, :][:, head_sequence, :, :] # Slice the tensor according to the provided sequences and sequence_count
|
8 |
num_layers = len(layer_sequence)
|
9 |
num_heads = len(head_sequence)
|
10 |
+
|
11 |
+
x_size = num_heads * 2
|
12 |
+
y_size = num_layers * 2
|
13 |
+
fig, axes = plt.subplots(num_layers, num_heads, figsize=(x_size, y_size))
|
14 |
for i in range(num_layers):
|
15 |
for j in range(num_heads):
|
16 |
+
axes[i, j].imshow(tensor[i, j].detach().numpy(), cmap='viridis', aspect='equal')
|
17 |
axes[i, j].axis('off')
|
18 |
|
19 |
# Enumerate the axes
|