Spaces:
Sleeping
Sleeping
Use dynamic offset for row labels, makes them nice ✨
Browse files
hexviz/pages/Identify_interesting_Heads.py
CHANGED
@@ -35,7 +35,7 @@ truncated_sequence = sequence[slice_start-1:slice_end]
|
|
35 |
|
36 |
head_range = st.sidebar.slider("Heads to plot", min_value=1, max_value=selected_model.heads, value=(1, selected_model.heads), step=1)
|
37 |
layer_range = st.sidebar.slider("Layers to plot", min_value=1, max_value=selected_model.layers, value=(1, selected_model.layers), step=1)
|
38 |
-
step_size = st.sidebar.number_input("Optional step size to skip heads", value=2, min_value=1, max_value=selected_model.layers)
|
39 |
layer_sequence = list(range(layer_range[0]-1, layer_range[1], step_size))
|
40 |
head_sequence = list(range(head_range[0]-1, head_range[1], step_size))
|
41 |
|
|
|
35 |
|
36 |
head_range = st.sidebar.slider("Heads to plot", min_value=1, max_value=selected_model.heads, value=(1, selected_model.heads), step=1)
|
37 |
layer_range = st.sidebar.slider("Layers to plot", min_value=1, max_value=selected_model.layers, value=(1, selected_model.layers), step=1)
|
38 |
+
step_size = st.sidebar.number_input("Optional step size to skip heads and layers", value=2, min_value=1, max_value=selected_model.layers)
|
39 |
layer_sequence = list(range(layer_range[0]-1, layer_range[1], step_size))
|
40 |
head_sequence = list(range(head_range[0]-1, head_range[1], step_size))
|
41 |
|
hexviz/plot.py
CHANGED
@@ -20,9 +20,12 @@ def plot_tiled_heatmap(tensor, layer_sequence: List[int], head_sequence: List[in
|
|
20 |
if i == 0:
|
21 |
axes[i, j].set_title(f'Head {head_sequence[j] + 1}', fontsize=10, y=1.05)
|
22 |
|
23 |
-
#
|
24 |
-
|
25 |
-
|
|
|
|
|
|
|
26 |
|
27 |
plt.subplots_adjust(wspace=0.1, hspace=0.1)
|
28 |
return fig
|
|
|
20 |
if i == 0:
|
21 |
axes[i, j].set_title(f'Head {head_sequence[j] + 1}', fontsize=10, y=1.05)
|
22 |
|
23 |
+
# Calculate the row label offset based on the number of columns
|
24 |
+
offset = 0.02 + (12 - num_heads) * 0.0015
|
25 |
+
for i, ax_row in enumerate(axes):
|
26 |
+
row_label = f"{layer_sequence[i]+1}"
|
27 |
+
row_pos = ax_row[num_heads-1].get_position()
|
28 |
+
fig.text(row_pos.x1+offset, (row_pos.y1+row_pos.y0)/2, row_label, va='center')
|
29 |
|
30 |
plt.subplots_adjust(wspace=0.1, hspace=0.1)
|
31 |
return fig
|