aksell committed on
Commit
9d9c196
·
1 Parent(s): 8799d0b

Copy attention tensor to CPU for streamlit caching


Streamlit tries to cache it using numpy (I think), and
this does not work because the tensor is on the GPU.
Interestingly, this only causes issues for ZymCTRL, not for
TAPE-BERT.
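
A minimal sketch of the failure mode, assuming Streamlit's cache hasher converts return values through NumPy; the tensor shape and device handling below are illustrative, not hexviz code:

import torch

attentions = torch.rand(12, 12, 64, 64)  # illustrative [layers, heads, seq, seq] attention tensor
if torch.cuda.is_available():
    attentions = attentions.cuda()
    # Conversion via NumPy fails for CUDA tensors, which breaks hashing:
    # attentions.numpy() raises TypeError: can't convert cuda:0 device type tensor to numpy
attentions = attentions.cpu()  # copying to host memory makes .numpy() (and caching) work
as_numpy = attentions.numpy()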

Files changed (1)
  1. hexviz/attention.py +2 -1
hexviz/attention.py CHANGED
@@ -89,7 +89,8 @@ def get_attention(
     else:
         raise ValueError(f"Model {model_type} not supported")
 
-    return attentions
+    # Transfer to CPU to avoid issues with streamlit caching
+    return attentions.cpu()
 
 def unidirectional_avg_filtered(attention, layer, head, threshold):
     num_layers, num_heads, seq_len, _ = attention.shape
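
For context, a hedged sketch of the consuming side, assuming the app wraps get_attention with Streamlit's legacy st.cache (which also hashes the returned value to detect mutation); the wrapper name and parameter names are assumptions, not part of this diff:

import streamlit as st
from hexviz.attention import get_attention

@st.cache(show_spinner=False)  # st.cache hashes the output; a CPU tensor hashes cleanly
def cached_attention(sequence, model_type):
    # Safe now that get_attention returns attentions.cpu()
    return get_attention(sequence=sequence, model_type=model_type)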