Add ProtT5

- hexviz/attention.py +64 -19
- hexviz/models.py +15 -0
- hexviz/pages/1_🗺️Identify_Interesting_Heads.py +1 -0
- hexviz/pages/2_📄Documentation.py +1 -0
- hexviz/🧬Attention_Visualization.py +3 -2
hexviz/attention.py
CHANGED

@@ -6,7 +6,13 @@ import streamlit as st
 import torch
 from Bio.PDB import PDBParser, Polypeptide, Structure

-from hexviz.models import
+from hexviz.models import (
+    ModelType,
+    get_prot_bert,
+    get_prot_t5,
+    get_tape_bert,
+    get_zymctrl,
+)


 def get_structure(pdb_code: str) -> Structure:
@@ -20,6 +26,7 @@ def get_structure(pdb_code: str) -> Structure:
     structure = parser.get_structure(pdb_code, file)
     return structure

+
 def get_pdb_file(pdb_code: str) -> Structure:
     """
     Get structure from PDB
@@ -29,6 +36,7 @@ def get_pdb_file(pdb_code: str) -> Structure:
     file = StringIO(pdb_data)
     return file

+
 @st.cache
 def get_pdb_from_seq(sequence: str) -> str:
     """
@@ -39,6 +47,7 @@ def get_pdb_from_seq(sequence: str) -> str:
     pdb_str = res.text
     return pdb_str

+
 def get_chains(structure: Structure) -> list[str]:
     """
     Get list of chains in a structure
@@ -49,6 +58,7 @@ def get_chains(structure: Structure) -> list[str]:
         chains.append(chain.id)
     return chains

+
 def get_sequence(chain) -> str:
     """
     Get sequence from a chain
@@ -57,13 +67,18 @@ def get_sequence(chain) -> str:
     """
     residues = [residue.get_resname() for residue in chain.get_residues()]
     # TODO ask if using protein_letters_3to1_extended makes sense
-    residues_single_letter = map(
+    residues_single_letter = map(
+        lambda x: Polypeptide.protein_letters_3to1.get(x, "X"), residues
+    )

     return "".join(list(residues_single_letter))

+
 def clean_and_validate_sequence(sequence: str) -> tuple[str, str | None]:
     lines = sequence.split("\n")
-    cleaned_sequence = "".join(
+    cleaned_sequence = "".join(
+        line.upper() for line in lines if not line.startswith(">")
+    )
     cleaned_sequence = cleaned_sequence.replace(" ", "")
     valid_residues = set(Polypeptide.protein_letters_3to1.values())
     residues_in_sequence = set(cleaned_sequence)
@@ -84,9 +99,7 @@ def clean_and_validate_sequence(sequence: str) -> tuple[str, str | None]:


 @st.cache
-def get_attention(
-    sequence: str, model_type: ModelType = ModelType.TAPE_BERT
-):
+def get_attention(sequence: str, model_type: ModelType = ModelType.TAPE_BERT):
     """
     Returns a tensor of shape [n_layers, n_heads, n_res, n_res] with attention weights
     """
@@ -104,11 +117,15 @@ def get_attention(

     elif model_type == ModelType.ZymCTRL:
         tokenizer, model = get_zymctrl()
-        inputs = tokenizer(sequence, return_tensors=
-        attention_mask = tokenizer(sequence, return_tensors=
+        inputs = tokenizer(sequence, return_tensors="pt").input_ids.to(device)
+        attention_mask = tokenizer(sequence, return_tensors="pt").attention_mask.to(
+            device
+        )

         with torch.no_grad():
-            outputs = model(
+            outputs = model(
+                inputs, attention_mask=attention_mask, output_attentions=True
+            )
         attentions = outputs.attentions

         # torch.Size([1, n_heads, n_res, n_res]) -> torch.Size([n_heads, n_res, n_res])
@@ -128,12 +145,27 @@ def get_attention(
         attentions = [attention[:, :, 1:-1, 1:-1] for attention in attentions]
         attentions = torch.stack([attention.squeeze(0) for attention in attentions])

+    elif model_type == ModelType.PROT_T5:
+        tokenizer, model = get_prot_t5()
+        sequence_separated = " ".join(sequence)
+        token_idxs = tokenizer.encode(sequence_separated)
+        inputs = torch.tensor(token_idxs).unsqueeze(0).to(device)
+        with torch.no_grad():
+            attentions = model(inputs, output_attentions=True)[
+                -1
+            ]  # Do you need an attention mask?
+
+        # Remove attention to <pad> (first) and <extra_id_1>, <extra_id_2> (last) tokens
+        attentions = [attention[:, :, 3:-3, 3:-3] for attention in attentions]
+        attentions = torch.stack([attention.squeeze(0) for attention in attentions])
+
     else:
         raise ValueError(f"Model {model_type} not supported")

     # Transfer to CPU to avoid issues with streamlit caching
     return attentions.cpu()

+
 def unidirectional_avg_filtered(attention, layer, head, threshold):
     num_layers, num_heads, seq_len, _ = attention.shape
     attention_head = attention[layer, head]
@@ -147,7 +179,7 @@ def unidirectional_avg_filtered(attention, layer, head, threshold):
         if avg >= threshold:
             unidirectional_avg_for_head.append((avg, i, j))
     return unidirectional_avg_for_head
-
+

 # Passing the pdb_str here is a workaround for streamlit caching
 # where I need the input to be hashable and not changing
@@ -155,7 +187,15 @@ def unidirectional_avg_filtered(attention, layer, head, threshold):
 # Thist twice. If streamlit is upgaded to past 0.17 this can be
 # fixed.
 @st.cache
-def get_attention_pairs(
+def get_attention_pairs(
+    pdb_str: str,
+    layer: int,
+    head: int,
+    chain_ids: list[str] | None,
+    threshold: int = 0.2,
+    model_type: ModelType = ModelType.TAPE_BERT,
+    top_n: int = 2,
+):
     structure = PDBParser().get_structure("pdb", StringIO(pdb_str))
     if chain_ids:
         chains = [ch for ch in structure.get_chains() if ch.id in chain_ids]
@@ -167,9 +207,11 @@ def get_attention_pairs(pdb_str: str, layer: int, head: int, chain_ids: list[str
     for chain in chains:
         sequence = get_sequence(chain)
         attention = get_attention(sequence=sequence, model_type=model_type)
-        attention_unidirectional = unidirectional_avg_filtered(
+        attention_unidirectional = unidirectional_avg_filtered(
+            attention, layer, head, threshold
+        )

-        # Store sum of attention in to a resiue (from the unidirectional attention)
+        # Store sum of attention in to a resiue (from the unidirectional attention)
         residue_attention = {}
         for attn_value, res_1, res_2 in attention_unidirectional:
             try:
@@ -178,15 +220,18 @@ def get_attention_pairs(pdb_str: str, layer: int, head: int, chain_ids: list[str
             except KeyError:
                 continue

-            attention_pairs.append(
+            attention_pairs.append(
+                (attn_value, coord_1, coord_2, chain.id, res_1, res_2)
+            )
             residue_attention[res_1] = residue_attention.get(res_1, 0) + attn_value
             residue_attention[res_2] = residue_attention.get(res_2, 0) + attn_value
-
-        top_n_residues = sorted(
-
+
+        top_n_residues = sorted(
+            residue_attention.items(), key=lambda x: x[1], reverse=True
+        )[:top_n]
+
         for res, attn_sum in top_n_residues:
             coord = chain[res]["CA"].coord.tolist()
             top_residues.append((attn_sum, coord, chain.id, res))
-
-    return attention_pairs, top_residues

+    return attention_pairs, top_residues
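The new ProtT5 branch trims three token positions from each end of the attention matrices and leaves an open question about the attention mask in a comment. A small inspection sketch, not part of the commit, shows how to see which special tokens the Rostlab tokenizer actually emits for a space-separated sequence; the example sequence is a placeholder:

```python
# Hypothetical inspection snippet, not part of the commit: lists the tokens that
# surround a space-separated sequence after ProtT5 tokenization, i.e. the
# positions the attention[:, :, 3:-3, 3:-3] slice above is meant to drop.
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained(
    "Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False
)
sequence = "MKTAYIA"  # placeholder sequence
token_ids = tokenizer.encode(" ".join(sequence))
print(tokenizer.convert_ids_to_tokens(token_ids))
```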
hexviz/models.py
CHANGED

@@ -10,6 +10,8 @@ from transformers import (
     BertTokenizer,
     GPT2LMHeadModel,
     GPT2TokenizerFast,
+    T5EncoderModel,
+    T5Tokenizer,
 )


@@ -17,6 +19,7 @@ class ModelType(str, Enum):
     TAPE_BERT = "TapeBert"
     ZymCTRL = "ZymCTRL"
     PROT_BERT = "ProtBert"
+    PROT_T5 = "ProtT5"


 class Model:
@@ -49,3 +52,15 @@ def get_prot_bert() -> tuple[BertTokenizer, BertModel]:
     tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
     model = BertModel.from_pretrained("Rostlab/prot_bert").to(device)
     return tokenizer, model
+
+
+@st.cache
+def get_prot_t5():
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    tokenizer = T5Tokenizer.from_pretrained(
+        "Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False
+    )
+    model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc").to(
+        device
+    )
+    return tokenizer, model
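The layers=24 and heads=32 values registered for ProtT5 in the page files below can be cross-checked against the Hugging Face model config. A minimal sketch, not part of the commit:

```python
# Hypothetical check, not part of the commit: read the encoder's layer and head
# counts from the hub config instead of hard-coding them.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc")
print(config.num_layers, config.num_heads)  # the pages register layers=24, heads=32
```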
hexviz/pages/1_🗺️Identify_Interesting_Heads.py
CHANGED

@@ -23,6 +23,7 @@ models = [
     Model(name=ModelType.TAPE_BERT, layers=12, heads=12),
     Model(name=ModelType.ZymCTRL, layers=36, heads=16),
     Model(name=ModelType.PROT_BERT, layers=30, heads=16),
+    Model(name=ModelType.PROT_T5, layers=24, heads=32),
 ]

 with st.expander(
hexviz/pages/2_📄Documentation.py
CHANGED

@@ -45,6 +45,7 @@ Hexviz currently supports the following models:
 1. [ProtBERT](https://huggingface.co/Rostlab/prot_bert_bfd)
 2. [ZymCTRL](https://huggingface.co/nferruz/ZymCTRL)
 3. [TapeBert](https://github.com/songlab-cal/tape/blob/master/tape/models/modeling_bert.py) - a nickname coined in BERTOLOGY meets biology for the Bert Base model pre-trained on Pfam in [TAPE](https://www.biorxiv.org/content/10.1101/676825v1). TapeBert is used extensively in BERTOlogy meets biology.
+4. [ProtT5 half](https://huggingface.co/Rostlab/prot_t5_xl_half_uniref50-enc)

 ## FAQ
 1. I can't see any attention- "bars" in the visualization, what is wrong? -> Lower the `minimum attention`.
hexviz/🧬Attention_Visualization.py
CHANGED

@@ -22,6 +22,7 @@ models = [
     Model(name=ModelType.TAPE_BERT, layers=12, heads=12),
     Model(name=ModelType.ZymCTRL, layers=36, heads=16),
     Model(name=ModelType.PROT_BERT, layers=30, heads=16),
+    Model(name=ModelType.PROT_T5, layers=24, heads=32),
 ]

 with st.expander(
@@ -219,8 +220,8 @@ st.table(df)
 st.markdown(
     """
 ### Check out the other pages
-[🗺️Identify Interesting heads](Identify_Interesting_Heads)
-
+[🗺️Identify Interesting heads](Identify_Interesting_Heads) gives a bird's eye view of attention patterns for a model.
+This can help you pick what specific attention heads to look at for your protein.

 [📄Documentation](Documentation) has information on protein language models, attention analysis and hexviz."""
 )
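For reference, a hypothetical end-to-end sketch (not part of the commit) of how the app's new model option maps onto the updated get_attention_pairs signature. The PDB id is a placeholder, and the @st.cache-wrapped helpers are assumed to behave outside a Streamlit run:

```python
# Hypothetical usage sketch, not part of the commit.
import requests

from hexviz.attention import get_attention_pairs
from hexviz.models import ModelType

pdb_str = requests.get("https://files.rcsb.org/view/1AKE.pdb").text  # placeholder structure
attention_pairs, top_residues = get_attention_pairs(
    pdb_str,
    layer=10,
    head=5,
    chain_ids=["A"],
    threshold=0.2,
    model_type=ModelType.PROT_T5,
    top_n=2,
)
# attention_pairs entries: (attn_value, coord_1, coord_2, chain.id, res_1, res_2)
# top_residues entries: (attn_sum, coord, chain.id, res)
```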