Commit 852aa1f by aksell · Parent: c663b1c

Add ProtT5

hexviz/attention.py CHANGED

@@ -6,7 +6,13 @@ import streamlit as st
 import torch
 from Bio.PDB import PDBParser, Polypeptide, Structure
 
-from hexviz.models import ModelType, get_prot_bert, get_tape_bert, get_zymctrl
+from hexviz.models import (
+    ModelType,
+    get_prot_bert,
+    get_prot_t5,
+    get_tape_bert,
+    get_zymctrl,
+)
 
 
 def get_structure(pdb_code: str) -> Structure:
@@ -20,6 +26,7 @@ def get_structure(pdb_code: str) -> Structure:
     structure = parser.get_structure(pdb_code, file)
     return structure
 
+
 def get_pdb_file(pdb_code: str) -> Structure:
     """
     Get structure from PDB
@@ -29,6 +36,7 @@ def get_pdb_file(pdb_code: str) -> Structure:
     file = StringIO(pdb_data)
     return file
 
+
 @st.cache
 def get_pdb_from_seq(sequence: str) -> str:
     """
@@ -39,6 +47,7 @@ def get_pdb_from_seq(sequence: str) -> str:
     pdb_str = res.text
     return pdb_str
 
+
 def get_chains(structure: Structure) -> list[str]:
     """
     Get list of chains in a structure
@@ -49,6 +58,7 @@ def get_chains(structure: Structure) -> list[str]:
         chains.append(chain.id)
     return chains
 
+
 def get_sequence(chain) -> str:
     """
     Get sequence from a chain
@@ -57,13 +67,18 @@ def get_sequence(chain) -> str:
     """
     residues = [residue.get_resname() for residue in chain.get_residues()]
     # TODO ask if using protein_letters_3to1_extended makes sense
-    residues_single_letter = map(lambda x: Polypeptide.protein_letters_3to1.get(x, "X"), residues)
+    residues_single_letter = map(
+        lambda x: Polypeptide.protein_letters_3to1.get(x, "X"), residues
+    )
 
     return "".join(list(residues_single_letter))
 
+
 def clean_and_validate_sequence(sequence: str) -> tuple[str, str | None]:
     lines = sequence.split("\n")
-    cleaned_sequence = "".join(line.upper() for line in lines if not line.startswith(">"))
+    cleaned_sequence = "".join(
+        line.upper() for line in lines if not line.startswith(">")
+    )
     cleaned_sequence = cleaned_sequence.replace(" ", "")
     valid_residues = set(Polypeptide.protein_letters_3to1.values())
     residues_in_sequence = set(cleaned_sequence)
@@ -84,9 +99,7 @@ def clean_and_validate_sequence(sequence: str) -> tuple[str, str | None]:
 
 
 @st.cache
-def get_attention(
-    sequence: str, model_type: ModelType = ModelType.TAPE_BERT
-):
+def get_attention(sequence: str, model_type: ModelType = ModelType.TAPE_BERT):
     """
     Returns a tensor of shape [n_layers, n_heads, n_res, n_res] with attention weights
     """
@@ -104,11 +117,15 @@ def get_attention(
 
     elif model_type == ModelType.ZymCTRL:
         tokenizer, model = get_zymctrl()
-        inputs = tokenizer(sequence, return_tensors='pt').input_ids.to(device)
-        attention_mask = tokenizer(sequence, return_tensors='pt').attention_mask.to(device)
+        inputs = tokenizer(sequence, return_tensors="pt").input_ids.to(device)
+        attention_mask = tokenizer(sequence, return_tensors="pt").attention_mask.to(
+            device
+        )
 
         with torch.no_grad():
-            outputs = model(inputs, attention_mask=attention_mask, output_attentions=True)
+            outputs = model(
+                inputs, attention_mask=attention_mask, output_attentions=True
+            )
         attentions = outputs.attentions
 
         # torch.Size([1, n_heads, n_res, n_res]) -> torch.Size([n_heads, n_res, n_res])
@@ -128,12 +145,27 @@ def get_attention(
         attentions = [attention[:, :, 1:-1, 1:-1] for attention in attentions]
         attentions = torch.stack([attention.squeeze(0) for attention in attentions])
 
+    elif model_type == ModelType.PROT_T5:
+        tokenizer, model = get_prot_t5()
+        sequence_separated = " ".join(sequence)
+        token_idxs = tokenizer.encode(sequence_separated)
+        inputs = torch.tensor(token_idxs).unsqueeze(0).to(device)
+        with torch.no_grad():
+            attentions = model(inputs, output_attentions=True)[
+                -1
+            ]  # Do you need an attention mask?
+
+        # Remove attention to <pad> (first) and <extra_id_1>, <extra_id_2> (last) tokens
+        attentions = [attention[:, :, 3:-3, 3:-3] for attention in attentions]
+        attentions = torch.stack([attention.squeeze(0) for attention in attentions])
+
     else:
         raise ValueError(f"Model {model_type} not supported")
 
     # Transfer to CPU to avoid issues with streamlit caching
     return attentions.cpu()
 
+
 def unidirectional_avg_filtered(attention, layer, head, threshold):
     num_layers, num_heads, seq_len, _ = attention.shape
     attention_head = attention[layer, head]
@@ -147,7 +179,7 @@ def unidirectional_avg_filtered(attention, layer, head, threshold):
             if avg >= threshold:
                 unidirectional_avg_for_head.append((avg, i, j))
     return unidirectional_avg_for_head
-
+
 
 # Passing the pdb_str here is a workaround for streamlit caching
 # where I need the input to be hashable and not changing
@@ -155,7 +187,15 @@ def unidirectional_avg_filtered(attention, layer, head, threshold):
 # Thist twice. If streamlit is upgaded to past 0.17 this can be
 # fixed.
 @st.cache
-def get_attention_pairs(pdb_str: str, layer: int, head: int, chain_ids: list[str] | None ,threshold: int = 0.2, model_type: ModelType = ModelType.TAPE_BERT, top_n: int = 2):
+def get_attention_pairs(
+    pdb_str: str,
+    layer: int,
+    head: int,
+    chain_ids: list[str] | None,
+    threshold: int = 0.2,
+    model_type: ModelType = ModelType.TAPE_BERT,
+    top_n: int = 2,
+):
     structure = PDBParser().get_structure("pdb", StringIO(pdb_str))
     if chain_ids:
         chains = [ch for ch in structure.get_chains() if ch.id in chain_ids]
@@ -167,9 +207,11 @@ def get_attention_pairs(pdb_str: str, layer: int, head: int, chain_ids: list[str
     for chain in chains:
        sequence = get_sequence(chain)
        attention = get_attention(sequence=sequence, model_type=model_type)
-        attention_unidirectional = unidirectional_avg_filtered(attention, layer, head, threshold)
+        attention_unidirectional = unidirectional_avg_filtered(
+            attention, layer, head, threshold
+        )
 
-        # Store sum of attention in to a resiue (from the unidirectional attention)
+        # Store sum of attention in to a resiue (from the unidirectional attention)
        residue_attention = {}
        for attn_value, res_1, res_2 in attention_unidirectional:
            try:
@@ -178,15 +220,18 @@ def get_attention_pairs(pdb_str: str, layer: int, head: int, chain_ids: list[str
            except KeyError:
                continue
 
-            attention_pairs.append((attn_value, coord_1, coord_2, chain.id, res_1, res_2))
+            attention_pairs.append(
+                (attn_value, coord_1, coord_2, chain.id, res_1, res_2)
+            )
            residue_attention[res_1] = residue_attention.get(res_1, 0) + attn_value
            residue_attention[res_2] = residue_attention.get(res_2, 0) + attn_value
-
-        top_n_residues = sorted(residue_attention.items(), key=lambda x: x[1], reverse=True)[:top_n]
-
+
+        top_n_residues = sorted(
+            residue_attention.items(), key=lambda x: x[1], reverse=True
+        )[:top_n]
+
        for res, attn_sum in top_n_residues:
            coord = chain[res]["CA"].coord.tolist()
            top_residues.append((attn_sum, coord, chain.id, res))
-
-    return attention_pairs, top_residues
 
+    return attention_pairs, top_residues
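
For reference, here is a minimal standalone sketch of what the new `ModelType.PROT_T5` branch of `get_attention` does, with the Streamlit caching stripped away. It assumes `torch`, `transformers`, and `sentencepiece` are installed and that the Rostlab/prot_t5_xl_half_uniref50-enc weights can be downloaded; the example sequence is purely illustrative, and the special-token trimming simply mirrors the slice chosen in this commit.

```python
import torch
from transformers import T5EncoderModel, T5Tokenizer

# Illustrative input; any single-letter amino acid string works here.
sequence = "MKTAYIAKQR"

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
tokenizer = T5Tokenizer.from_pretrained(
    "Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False
)
model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc").to(device)

# ProtT5 expects residues separated by spaces ("M K T A ...").
token_idxs = tokenizer.encode(" ".join(sequence))
inputs = torch.tensor(token_idxs).unsqueeze(0).to(device)

with torch.no_grad():
    # Tuple of per-layer tensors, each [1, n_heads, n_tokens, n_tokens]
    attentions = model(inputs, output_attentions=True).attentions

# Trim special-token positions the same way the commit does, then stack
# into a single [n_layers, n_heads, n_res, n_res] tensor.
attentions = [attention[:, :, 3:-3, 3:-3] for attention in attentions]
attentions = torch.stack([attention.squeeze(0) for attention in attentions]).cpu()
print(attentions.shape)
```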
hexviz/models.py CHANGED

@@ -10,6 +10,8 @@ from transformers import (
     BertTokenizer,
     GPT2LMHeadModel,
     GPT2TokenizerFast,
+    T5EncoderModel,
+    T5Tokenizer,
 )
 
 
@@ -17,6 +19,7 @@ class ModelType(str, Enum):
     TAPE_BERT = "TapeBert"
     ZymCTRL = "ZymCTRL"
     PROT_BERT = "ProtBert"
+    PROT_T5 = "ProtT5"
 
 
 class Model:
@@ -49,3 +52,15 @@ def get_prot_bert() -> tuple[BertTokenizer, BertModel]:
     tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
     model = BertModel.from_pretrained("Rostlab/prot_bert").to(device)
     return tokenizer, model
+
+
+@st.cache
+def get_prot_t5():
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    tokenizer = T5Tokenizer.from_pretrained(
+        "Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False
+    )
+    model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc").to(
+        device
+    )
+    return tokenizer, model
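
A hedged end-to-end usage sketch of the newly registered model type through hexviz's own helpers; the function and enum names come from this commit, while the input sequence and the shape check are illustrative assumptions (and calling the `st.cache`-decorated helper outside a running Streamlit app may only emit a caching warning).

```python
# Assumes hexviz is importable and the ProtT5 weights can be downloaded.
from hexviz.attention import get_attention
from hexviz.models import ModelType

sequence = "MKTAYIAKQR"  # illustrative input, not from the commit

# Exercises the new ModelType.PROT_T5 branch added in hexviz/attention.py above.
attention = get_attention(sequence=sequence, model_type=ModelType.PROT_T5)

# Per get_attention's docstring the result is [n_layers, n_heads, n_res, n_res];
# the pages below register ProtT5 with 24 layers and 32 heads.
print(attention.shape)
```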
hexviz/pages/1_🗺️Identify_Interesting_Heads.py CHANGED

@@ -23,6 +23,7 @@ models = [
     Model(name=ModelType.TAPE_BERT, layers=12, heads=12),
     Model(name=ModelType.ZymCTRL, layers=36, heads=16),
     Model(name=ModelType.PROT_BERT, layers=30, heads=16),
+    Model(name=ModelType.PROT_T5, layers=24, heads=32),
 ]
 
 with st.expander(
hexviz/pages/2_📄Documentation.py CHANGED

@@ -45,6 +45,7 @@ Hexviz currently supports the following models:
 1. [ProtBERT](https://huggingface.co/Rostlab/prot_bert_bfd)
 2. [ZymCTRL](https://huggingface.co/nferruz/ZymCTRL)
 3. [TapeBert](https://github.com/songlab-cal/tape/blob/master/tape/models/modeling_bert.py) - a nickname coined in BERTOLOGY meets biology for the Bert Base model pre-trained on Pfam in [TAPE](https://www.biorxiv.org/content/10.1101/676825v1). TapeBert is used extensively in BERTOlogy meets biology.
+4. [ProtT5 half](https://huggingface.co/Rostlab/prot_t5_xl_half_uniref50-enc)
 
 ## FAQ
 1. I can't see any attention- "bars" in the visualization, what is wrong? -> Lower the `minimum attention`.
hexviz/🧬Attention_Visualization.py CHANGED

@@ -22,6 +22,7 @@ models = [
     Model(name=ModelType.TAPE_BERT, layers=12, heads=12),
     Model(name=ModelType.ZymCTRL, layers=36, heads=16),
     Model(name=ModelType.PROT_BERT, layers=30, heads=16),
+    Model(name=ModelType.PROT_T5, layers=24, heads=32),
 ]
 
 with st.expander(
@@ -219,8 +220,8 @@ st.table(df)
 st.markdown(
     """
 ### Check out the other pages
-[🗺️Identify Interesting heads](Identify_Interesting_Heads) give a birds-eye view of attention patterns for a model,
-this can help you pick what specific attention heads to look at for your protein.
+[🗺️Identify Interesting heads](Identify_Interesting_Heads) gives a bird's eye view of attention patterns for a model.
+This can help you pick what specific attention heads to look at for your protein.
 
 [📄Documentation](Documentation) has information on protein language models, attention analysis and hexviz."""
 )