aksell committed
Commit 1431daf · Parent: db81f50

Remove unused debugging code path

Files changed (1)
  1. hexviz/attention.py +6 -21
hexviz/attention.py CHANGED
@@ -103,27 +103,12 @@ def get_attention(
         with torch.no_grad():
             outputs = model(inputs, attention_mask=attention_mask, output_attentions=True)
             attentions = outputs.attentions
-            if attentions[0].shape[-1] == attentions[0].shape[-2] == 1:
-                reshaped = [attention.view(attention.shape[1], attention.shape[0]) for attention in attentions]
-                n_residues = reshaped[0].shape[-1]
-                n_heads = 16
-                i,j = torch.triu_indices(n_residues, n_residues)
-
-                attentions_symmetric = []
-                # Make symmetric attention matrix
-                for attention in reshaped:
-                    x = torch.zeros(n_heads, n_residues, n_residues)
-                    x[:,i,j] = attention
-                    x[:,j,i] = attention
-                    attentions_symmetric.append(x)
-                attentions = torch.stack([attention for attention in attentions_symmetric])
-            else:
-                # torch.Size([1, n_heads, n_res, n_res]) -> torch.Size([n_heads, n_res, n_res])
-                attention_squeezed = [torch.squeeze(attention) for attention in attentions]
-
-                # ([n_heads, n_res, n_res]*n_layers) -> [n_layers, n_heads, n_res, n_res]
-                attention_stacked = torch.stack([attention for attention in attention_squeezed])
-                attentions = attention_stacked
+
+            # torch.Size([1, n_heads, n_res, n_res]) -> torch.Size([n_heads, n_res, n_res])
+            attention_squeezed = [torch.squeeze(attention) for attention in attentions]
+            # ([n_heads, n_res, n_res]*n_layers) -> [n_layers, n_heads, n_res, n_res]
+            attention_stacked = torch.stack([attention for attention in attention_squeezed])
+            attentions = attention_stacked
 
     elif model_type == ModelType.PROT_T5:
         device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
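
For context, a hedged sketch of what the deleted branch appears to have done: when a layer's attention came back with its last two dimensions equal to 1 (a flattened form rather than a square matrix), the code rebuilt a full symmetric attention matrix from upper-triangular values via torch.triu_indices. Every size below is an illustrative stand-in, not a value from the source:

    import torch

    # Illustrative sizes only; the removed code hardcoded n_heads = 16.
    n_heads, n_res = 16, 5
    n_pairs = n_res * (n_res + 1) // 2       # upper-triangle entries incl. diagonal
    flat = torch.rand(n_heads, n_pairs)      # stand-in for one reshaped layer

    i, j = torch.triu_indices(n_res, n_res)  # row/col indices of the upper triangle
    x = torch.zeros(n_heads, n_res, n_res)
    x[:, i, j] = flat                        # fill the upper triangle
    x[:, j, i] = flat                        # mirror into the lower triangle
    assert torch.equal(x, x.transpose(-1, -2))  # symmetric by construction

Per the commit message this path was unused debugging code, so only the squeeze-and-stack path, sketched next, remains.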
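
The surviving path assumes each element of outputs.attentions is a [1, n_heads, n_res, n_res] tensor (batch of one). A minimal self-contained sketch with dummy tensors; n_layers, n_heads, and n_res are illustrative values, and torch.stack on the list directly is equivalent to the source's list-comprehension form:

    import torch

    # Illustrative sizes only; real values depend on the model and the sequence.
    n_layers, n_heads, n_res = 2, 16, 5

    # Stand-in for outputs.attentions: one tensor per layer, batch dim of 1.
    attentions = tuple(torch.rand(1, n_heads, n_res, n_res) for _ in range(n_layers))

    # torch.Size([1, n_heads, n_res, n_res]) -> torch.Size([n_heads, n_res, n_res])
    attention_squeezed = [torch.squeeze(attention) for attention in attentions]

    # ([n_heads, n_res, n_res]*n_layers) -> [n_layers, n_heads, n_res, n_res]
    attention_stacked = torch.stack(attention_squeezed)
    assert attention_stacked.shape == (n_layers, n_heads, n_res, n_res)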