Copy attention tensor to CPU for streamlit caching
Streamlit tries to cache the attention tensor using numpy (I think), and
this does not work because the tensor is on the GPU. Interestingly, it
only causes issues for ZymCTRL, not for TAPE-BERT.
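The numpy angle is easy to demonstrate outside of Streamlit: PyTorch refuses to convert a CUDA tensor to a numpy array directly, which is consistent with a numpy-based cache choking on GPU tensors. A minimal repro sketch (not hexviz code):

```python
import torch

if torch.cuda.is_available():
    t = torch.rand(2, 2, device="cuda")
    try:
        t.numpy()  # raises: can't convert cuda:0 device type tensor to numpy
    except TypeError as e:
        print(e)   # suggests Tensor.cpu() to copy the tensor to host memory first
    arr = t.cpu().numpy()  # works: the copy lives in host memory
```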
hexviz/attention.py  CHANGED  (+2 -1)

@@ -89,7 +89,8 @@ def get_attention(
     else:
         raise ValueError(f"Model {model_type} not supported")
 
-    return attentions
+    # Transfer to CPU to avoid issues with streamlit caching
+    return attentions.cpu()
 
 def unidirectional_avg_filtered(attention, layer, head, threshold):
     num_layers, num_heads, seq_len, _ = attention.shape
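The hunk only shows the return statement; presumably get_attention sits behind one of Streamlit's cache decorators, which is why the returned tensor gets hashed at all. A minimal sketch of that pattern, assuming st.cache_data and placeholder model output shapes (none of this is hexviz's actual code):

```python
import streamlit as st
import torch

@st.cache_data  # caches the return value; GPU-resident tensors can trip this up
def get_attention_sketch(sequence: str) -> torch.Tensor:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Stand-in for a model forward pass producing attentions on the GPU,
    # shaped (layers, heads, seq_len, seq_len):
    attentions = torch.rand(12, 12, len(sequence), len(sequence), device=device)
    # Move to CPU before returning so the cache layer never sees a CUDA tensor.
    return attentions.cpu()
```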