aksell commited on
Commit
13af878
·
1 Parent(s): 5ddca75

Allow plotting single row or column or a single head

Browse files
Files changed (1) hide show
  1. hexviz/plot.py +1 -1
hexviz/plot.py CHANGED
@@ -10,7 +10,7 @@ def plot_tiled_heatmap(tensor, layer_sequence: List[int], head_sequence: List[in
10
 
11
  x_size = num_heads * 2
12
  y_size = num_layers * 2
13
- fig, axes = plt.subplots(num_layers, num_heads, figsize=(x_size, y_size))
14
  for i in range(num_layers):
15
  for j in range(num_heads):
16
  axes[i, j].imshow(tensor[i, j].detach().numpy(), cmap='viridis', aspect='equal')
 
10
 
11
  x_size = num_heads * 2
12
  y_size = num_layers * 2
13
+ fig, axes = plt.subplots(num_layers, num_heads, figsize=(x_size, y_size), squeeze=False)
14
  for i in range(num_layers):
15
  for j in range(num_heads):
16
  axes[i, j].imshow(tensor[i, j].detach().numpy(), cmap='viridis', aspect='equal')