Plot avg attention not sum
- hexviz/attention.py +9 -10
- tests/test_attention.py +4 -4
hexviz/attention.py CHANGED
@@ -85,21 +85,20 @@ def get_attention(
 
     return attentions
 
-def
+def unidirectional_avg_filtered(attention, layer, head, threshold):
     num_layers, num_heads, seq_len, _ = attention.shape
     attention_head = attention[layer, head]
-
+    unidirectional_avg_for_head = []
     for i in range(seq_len):
         for j in range(i, seq_len):
             # Attention matrices for BERT models are asymetric.
-            # Bidirectional attention is
-            # attention values
-            # TODO think... does this operation make sense?
+            # Bidirectional attention is represented by the average of the two values
             sum = attention_head[i, j].item() + attention_head[j, i].item()
-
-
-
-
+            avg = sum / 2
+            if avg >= threshold:
+                unidirectional_avg_for_head.append((avg, i, j))
+    return unidirectional_avg_for_head
+
 @st.cache
 def get_attention_pairs(pdb_code: str, layer: int, head: int, threshold: int = 0.2, model_type: ModelType = ModelType.TAPE_BERT):
     # fetch structure
@@ -110,7 +109,7 @@ def get_attention_pairs(pdb_code: str, layer: int, head: int, threshold: int = 0
     attention_pairs = []
     for i, sequence in enumerate(sequences):
         attention = get_attention(sequence=sequence, model_type=model_type)
-        attention_unidirectional =
+        attention_unidirectional = unidirectional_avg_filtered(attention, layer, head, threshold)
         chain = list(structure.get_chains())[i]
         for attn_value, res_1, res_2 in attention_unidirectional:
             try:
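For orientation, here is a minimal, self-contained sketch of how the new helper behaves. The function body is copied from the diff above; the toy tensor and the threshold value are invented for this illustration.

import torch

# Copy of the helper added in this commit, reproduced standalone for illustration.
def unidirectional_avg_filtered(attention, layer, head, threshold):
    num_layers, num_heads, seq_len, _ = attention.shape
    attention_head = attention[layer, head]
    unidirectional_avg_for_head = []
    for i in range(seq_len):
        for j in range(i, seq_len):
            # Attention matrices for BERT models are asymmetric, so the two
            # directions (i -> j and j -> i) are averaged into one value.
            sum = attention_head[i, j].item() + attention_head[j, i].item()
            avg = sum / 2
            if avg >= threshold:
                unidirectional_avg_for_head.append((avg, i, j))
    return unidirectional_avg_for_head

# Toy input: 1 layer, 1 head, 2 residues (values made up for this sketch).
attention = torch.tensor([[[[0.1, 0.8],
                            [0.2, 0.3]]]])
print(unidirectional_avg_filtered(attention, layer=0, head=0, threshold=0.4))
# -> [(0.5, 0, 1)] up to float32 rounding; the diagonal pairs average to 0.1 and 0.3
#    and are dropped by the 0.4 threshold.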
tests/test_attention.py CHANGED
@@ -2,7 +2,7 @@ import torch
 from Bio.PDB.Structure import Structure
 
 from hexviz.attention import (ModelType, get_attention, get_sequences,
-                              get_structure,
+                              get_structure, unidirectional_avg_filtered)
 
 
 def test_get_structure():
@@ -58,14 +58,14 @@ def test_get_attention_prot_bert():
     assert result is not None
     assert result.shape == torch.Size([30, 16, 3, 3])
 
-def
+def test_get_unidirection_avg_filtered():
     # 1 head, 1 layer, 4 residues long attention tensor
     attention= torch.tensor([[[[1, 2, 3, 4],
                                [2, 5, 6, 7],
                                [3, 6, 8, 9],
                                [4, 7, 9, 11]]]], dtype=torch.float32)
 
-    result =
+    result = unidirectional_avg_filtered(attention, 0, 0, 0)
 
     assert result is not None
     assert len(result) == 10
@@ -74,6 +74,6 @@ def test_get_unidirection_sum_filtered():
                               [2, 5, 6],
                               [4, 7, 91]]]], dtype=torch.float32)
 
-    result =
+    result = unidirectional_avg_filtered(attention, 0, 0, 0)
 
     assert len(result) == 6
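A quick sanity check on the expected lengths in these tests: with a threshold of 0 nothing is filtered out, so the helper returns one tuple per pair (i, j) with j >= i, i.e. the upper triangle of the attention matrix including the diagonal, which has n(n+1)/2 entries. The helper name below is hypothetical and exists only for this check.

def expected_pair_count(seq_len: int) -> int:
    # Pairs (i, j) with j >= i: upper triangle plus the diagonal.
    return seq_len * (seq_len + 1) // 2

assert expected_pair_count(4) == 10  # the 4-residue tensor above expects len(result) == 10
assert expected_pair_count(3) == 6   # the 3-residue tensor above expects len(result) == 6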