aksell commited on
Commit
5ddca75
·
1 Parent(s): 0fa194b

Use dynamic offset for row labels, makes them nice ✨

Browse files
hexviz/pages/Identify_interesting_Heads.py CHANGED
@@ -35,7 +35,7 @@ truncated_sequence = sequence[slice_start-1:slice_end]
35
 
36
  head_range = st.sidebar.slider("Heads to plot", min_value=1, max_value=selected_model.heads, value=(1, selected_model.heads), step=1)
37
  layer_range = st.sidebar.slider("Layers to plot", min_value=1, max_value=selected_model.layers, value=(1, selected_model.layers), step=1)
38
- step_size = st.sidebar.number_input("Optional step size to skip heads", value=2, min_value=1, max_value=selected_model.layers)
39
  layer_sequence = list(range(layer_range[0]-1, layer_range[1], step_size))
40
  head_sequence = list(range(head_range[0]-1, head_range[1], step_size))
41
 
 
35
 
36
  head_range = st.sidebar.slider("Heads to plot", min_value=1, max_value=selected_model.heads, value=(1, selected_model.heads), step=1)
37
  layer_range = st.sidebar.slider("Layers to plot", min_value=1, max_value=selected_model.layers, value=(1, selected_model.layers), step=1)
38
+ step_size = st.sidebar.number_input("Optional step size to skip heads and layers", value=2, min_value=1, max_value=selected_model.layers)
39
  layer_sequence = list(range(layer_range[0]-1, layer_range[1], step_size))
40
  head_sequence = list(range(head_range[0]-1, head_range[1], step_size))
41
 
hexviz/plot.py CHANGED
@@ -20,9 +20,12 @@ def plot_tiled_heatmap(tensor, layer_sequence: List[int], head_sequence: List[in
20
  if i == 0:
21
  axes[i, j].set_title(f'Head {head_sequence[j] + 1}', fontsize=10, y=1.05)
22
 
23
- # Add layer labels on the right Y-axis
24
- for i in range(num_layers):
25
- fig.text(0.98, (num_layers - i - 1) / num_layers + 0.025, f'Layer {layer_sequence[i]+1}', fontsize=10, rotation=0, ha='right', va='center')
 
 
 
26
 
27
  plt.subplots_adjust(wspace=0.1, hspace=0.1)
28
  return fig
 
20
  if i == 0:
21
  axes[i, j].set_title(f'Head {head_sequence[j] + 1}', fontsize=10, y=1.05)
22
 
23
+ # Calculate the row label offset based on the number of columns
24
+ offset = 0.02 + (12 - num_heads) * 0.0015
25
+ for i, ax_row in enumerate(axes):
26
+ row_label = f"{layer_sequence[i]+1}"
27
+ row_pos = ax_row[num_heads-1].get_position()
28
+ fig.text(row_pos.x1+offset, (row_pos.y1+row_pos.y0)/2, row_label, va='center')
29
 
30
  plt.subplots_adjust(wspace=0.1, hspace=0.1)
31
  return fig