Spaces:
Sleeping
Sleeping
Allow plotting single row or column or a single head
Browse files- hexviz/plot.py +1 -1
hexviz/plot.py
CHANGED
@@ -10,7 +10,7 @@ def plot_tiled_heatmap(tensor, layer_sequence: List[int], head_sequence: List[in
|
|
10 |
|
11 |
x_size = num_heads * 2
|
12 |
y_size = num_layers * 2
|
13 |
-
fig, axes = plt.subplots(num_layers, num_heads, figsize=(x_size, y_size))
|
14 |
for i in range(num_layers):
|
15 |
for j in range(num_heads):
|
16 |
axes[i, j].imshow(tensor[i, j].detach().numpy(), cmap='viridis', aspect='equal')
|
|
|
10 |
|
11 |
x_size = num_heads * 2
|
12 |
y_size = num_layers * 2
|
13 |
+
fig, axes = plt.subplots(num_layers, num_heads, figsize=(x_size, y_size), squeeze=False)
|
14 |
for i in range(num_layers):
|
15 |
for j in range(num_heads):
|
16 |
axes[i, j].imshow(tensor[i, j].detach().numpy(), cmap='viridis', aspect='equal')
|