Spaces:
Sleeping
Sleeping
Remove unused debugging code path
Browse files- hexviz/attention.py +6 -21
hexviz/attention.py
CHANGED
@@ -103,27 +103,12 @@ def get_attention(
|
|
103 |
with torch.no_grad():
|
104 |
outputs = model(inputs, attention_mask=attention_mask, output_attentions=True)
|
105 |
attentions = outputs.attentions
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
attentions_symmetric = []
|
113 |
-
# Make symmetric attention matrix
|
114 |
-
for attention in reshaped:
|
115 |
-
x = torch.zeros(n_heads, n_residues, n_residues)
|
116 |
-
x[:,i,j] = attention
|
117 |
-
x[:,j,i] = attention
|
118 |
-
attentions_symmetric.append(x)
|
119 |
-
attentions = torch.stack([attention for attention in attentions_symmetric])
|
120 |
-
else:
|
121 |
-
# torch.Size([1, n_heads, n_res, n_res]) -> torch.Size([n_heads, n_res, n_res])
|
122 |
-
attention_squeezed = [torch.squeeze(attention) for attention in attentions]
|
123 |
-
|
124 |
-
# ([n_heads, n_res, n_res]*n_layers) -> [n_layers, n_heads, n_res, n_res]
|
125 |
-
attention_stacked = torch.stack([attention for attention in attention_squeezed])
|
126 |
-
attentions = attention_stacked
|
127 |
|
128 |
elif model_type == ModelType.PROT_T5:
|
129 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
|
103 |
with torch.no_grad():
|
104 |
outputs = model(inputs, attention_mask=attention_mask, output_attentions=True)
|
105 |
attentions = outputs.attentions
|
106 |
+
|
107 |
+
# torch.Size([1, n_heads, n_res, n_res]) -> torch.Size([n_heads, n_res, n_res])
|
108 |
+
attention_squeezed = [torch.squeeze(attention) for attention in attentions]
|
109 |
+
# ([n_heads, n_res, n_res]*n_layers) -> [n_layers, n_heads, n_res, n_res]
|
110 |
+
attention_stacked = torch.stack([attention for attention in attention_squeezed])
|
111 |
+
attentions = attention_stacked
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
|
113 |
elif model_type == ModelType.PROT_T5:
|
114 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|