aksell commited on
Commit
0fa194b
·
1 Parent(s): 9d9c196

Fix subplot sizing and bug

Browse files
hexviz/Attention_Visualization.py CHANGED
@@ -110,7 +110,6 @@ def get_3dview(pdb):
110
  {"backgroundColor": "lightgray","fontColor": "black","backgroundOpacity": 0.5})
111
  return xyzview
112
 
113
-
114
  xyzview = get_3dview(pdb_id)
115
  showmol(xyzview, height=500, width=800)
116
  st.markdown(f'PDB: [{pdb_id}](https://www.rcsb.org/structure/{pdb_id})', unsafe_allow_html=True)
 
110
  {"backgroundColor": "lightgray","fontColor": "black","backgroundOpacity": 0.5})
111
  return xyzview
112
 
 
113
  xyzview = get_3dview(pdb_id)
114
  showmol(xyzview, height=500, width=800)
115
  st.markdown(f'PDB: [{pdb_id}](https://www.rcsb.org/structure/{pdb_id})', unsafe_allow_html=True)
hexviz/pages/Identify_interesting_Heads.py CHANGED
@@ -27,14 +27,15 @@ chains = list(structure.get_chains())
27
 
28
  sequence = get_sequence(chains[0])
29
  l = len(sequence)
30
- slice_start= st.sidebar.number_input(f"Section start(1-{l})",value=1, min_value=1, max_value=l)
31
- slice_end = st.sidebar.number_input(f"Section end(1-{l})",value=50, min_value=1, max_value=l)
 
 
32
  truncated_sequence = sequence[slice_start-1:slice_end]
33
 
34
-
35
- layer_range = st.sidebar.slider("Heads", min_value=1, max_value=selected_model.layers, value=(1, selected_model.layers), step=1)
36
- head_range = st.sidebar.slider("Layers", min_value=1, max_value=selected_model.heads, value=(1, selected_model.heads), step=1)
37
- step_size = st.sidebar.number_input("Step size", value=2, min_value=1, max_value=selected_model.layers)
38
  layer_sequence = list(range(layer_range[0]-1, layer_range[1], step_size))
39
  head_sequence = list(range(head_range[0]-1, head_range[1], step_size))
40
 
 
27
 
28
  sequence = get_sequence(chains[0])
29
  l = len(sequence)
30
+ st.sidebar.markdown("Sequence segment to plot")
31
+ slice_start, slice_end = st.sidebar.slider("Sequence", min_value=1, max_value=l, value=(1, 50), step=1)
32
+ # slice_start= st.sidebar.number_input(f"Section start(1-{l})",value=1, min_value=1, max_value=l)
33
+ # slice_end = st.sidebar.number_input(f"Section end(1-{l})",value=50, min_value=1, max_value=l)
34
  truncated_sequence = sequence[slice_start-1:slice_end]
35
 
36
+ head_range = st.sidebar.slider("Heads to plot", min_value=1, max_value=selected_model.heads, value=(1, selected_model.heads), step=1)
37
+ layer_range = st.sidebar.slider("Layers to plot", min_value=1, max_value=selected_model.layers, value=(1, selected_model.layers), step=1)
38
+ step_size = st.sidebar.number_input("Optional step size to skip heads", value=2, min_value=1, max_value=selected_model.layers)
 
39
  layer_sequence = list(range(layer_range[0]-1, layer_range[1], step_size))
40
  head_sequence = list(range(head_range[0]-1, head_range[1], step_size))
41
 
hexviz/plot.py CHANGED
@@ -7,10 +7,13 @@ def plot_tiled_heatmap(tensor, layer_sequence: List[int], head_sequence: List[in
7
  tensor = tensor[layer_sequence, :][:, head_sequence, :, :] # Slice the tensor according to the provided sequences and sequence_count
8
  num_layers = len(layer_sequence)
9
  num_heads = len(head_sequence)
10
- fig, axes = plt.subplots(num_layers, num_heads, figsize=(12, 12))
 
 
 
11
  for i in range(num_layers):
12
  for j in range(num_heads):
13
- axes[i, j].imshow(tensor[i, j].detach().numpy(), cmap='viridis', aspect='auto')
14
  axes[i, j].axis('off')
15
 
16
  # Enumerate the axes
 
7
  tensor = tensor[layer_sequence, :][:, head_sequence, :, :] # Slice the tensor according to the provided sequences and sequence_count
8
  num_layers = len(layer_sequence)
9
  num_heads = len(head_sequence)
10
+
11
+ x_size = num_heads * 2
12
+ y_size = num_layers * 2
13
+ fig, axes = plt.subplots(num_layers, num_heads, figsize=(x_size, y_size))
14
  for i in range(num_layers):
15
  for j in range(num_heads):
16
+ axes[i, j].imshow(tensor[i, j].detach().numpy(), cmap='viridis', aspect='equal')
17
  axes[i, j].axis('off')
18
 
19
  # Enumerate the axes