aksell committed on
Commit
a2cfd88
·
1 Parent(s): 8cef26d

Format with Black

Browse files
hexviz/models.py CHANGED
@@ -4,12 +4,17 @@ import streamlit as st
4
  import torch
5
  from tape import ProteinBertModel, TAPETokenizer
6
  from tokenizers import Tokenizer
7
- from transformers import (AutoTokenizer, BertModel, BertTokenizer,
8
- GPT2LMHeadModel, GPT2TokenizerFast)
 
 
 
 
 
9
 
10
 
11
  class ModelType(str, Enum):
12
- TAPE_BERT = "TAPE-BERT"
13
  ZymCTRL = "ZymCTRL"
14
  PROT_BERT = "ProtBert"
15
 
@@ -24,22 +29,23 @@ class Model:
24
  @st.cache
25
  def get_tape_bert() -> tuple[TAPETokenizer, ProteinBertModel]:
26
  tokenizer = TAPETokenizer()
27
- model = ProteinBertModel.from_pretrained('bert-base', output_attentions=True)
28
  return tokenizer, model
29
 
 
30
  # Streamlit is not able to hash the tokenizer for ZymCTRL
31
  # With streamlit 1.19 cache_object should work without this
32
  @st.cache(hash_funcs={Tokenizer: lambda _: None})
33
  def get_zymctrl() -> tuple[GPT2TokenizerFast, GPT2LMHeadModel]:
34
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
35
- tokenizer = AutoTokenizer.from_pretrained('nferruz/ZymCTRL')
36
- model = GPT2LMHeadModel.from_pretrained('nferruz/ZymCTRL').to(device)
37
  return tokenizer, model
38
 
39
 
40
  @st.cache(hash_funcs={BertTokenizer: lambda _: None})
41
  def get_prot_bert() -> tuple[BertTokenizer, BertModel]:
42
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
43
- tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False )
44
  model = BertModel.from_pretrained("Rostlab/prot_bert").to(device)
45
- return tokenizer, model
 
4
  import torch
5
  from tape import ProteinBertModel, TAPETokenizer
6
  from tokenizers import Tokenizer
7
+ from transformers import (
8
+ AutoTokenizer,
9
+ BertModel,
10
+ BertTokenizer,
11
+ GPT2LMHeadModel,
12
+ GPT2TokenizerFast,
13
+ )
14
 
15
 
16
  class ModelType(str, Enum):
17
+ TAPE_BERT = "TapeBert"
18
  ZymCTRL = "ZymCTRL"
19
  PROT_BERT = "ProtBert"
20
 
 
29
  @st.cache
30
  def get_tape_bert() -> tuple[TAPETokenizer, ProteinBertModel]:
31
  tokenizer = TAPETokenizer()
32
+ model = ProteinBertModel.from_pretrained("bert-base", output_attentions=True)
33
  return tokenizer, model
34
 
35
+
36
  # Streamlit is not able to hash the tokenizer for ZymCTRL
37
  # With streamlit 1.19 cache_object should work without this
38
  @st.cache(hash_funcs={Tokenizer: lambda _: None})
39
  def get_zymctrl() -> tuple[GPT2TokenizerFast, GPT2LMHeadModel]:
40
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
41
+ tokenizer = AutoTokenizer.from_pretrained("nferruz/ZymCTRL")
42
+ model = GPT2LMHeadModel.from_pretrained("nferruz/ZymCTRL").to(device)
43
  return tokenizer, model
44
 
45
 
46
  @st.cache(hash_funcs={BertTokenizer: lambda _: None})
47
  def get_prot_bert() -> tuple[BertTokenizer, BertModel]:
48
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
49
+ tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
50
  model = BertModel.from_pretrained("Rostlab/prot_bert").to(device)
51
+ return tokenizer, model
hexviz/view.py CHANGED
@@ -6,17 +6,26 @@ from Bio.PDB import PDBParser
6
  from hexviz.attention import get_pdb_file, get_pdb_from_seq
7
 
8
  menu_items = {
9
- "Get Help": "https://huggingface.co/spaces/aksell/hexviz/discussions/new",
10
- "Report a bug": "https://huggingface.co/spaces/aksell/hexviz/discussions/new",
11
- "About": "Created by [Aksel Lenes](https://github.com/aksell/) from Noelia Ferruz's group at the Institute of Molecular Biology of Barcelona. Read more at https://www.aiproteindesign.com/"
12
- }
 
13
 
14
  def get_selecte_model_index(models):
15
  selected_model_name = st.session_state.get("selected_model_name", None)
16
  if selected_model_name is None:
17
  return 0
18
  else:
19
- return next((i for i, model in enumerate(models) if model.name.value == selected_model_name), None)
 
 
 
 
 
 
 
 
20
 
21
  def clear_model_state():
22
  if "plot_heads" in st.session_state:
@@ -32,13 +41,22 @@ def clear_model_state():
32
  if "plot_heads" in st.session_state:
33
  del st.session_state.plot_heads
34
 
 
35
  def select_model(models):
36
  if "selected_model_name" not in st.session_state:
37
  st.session_state.selected_model_name = models[0].name.value
38
- selected_model_name = st.selectbox("Select model", [model.name.value for model in models], key="selected_model_name", on_change=clear_model_state)
39
- select_model = next((model for model in models if model.name.value == selected_model_name), None)
 
 
 
 
 
 
 
40
  return select_model
41
 
 
42
  def clear_pdb_state():
43
  if "selected_chains" in st.session_state:
44
  del st.session_state.selected_chains
@@ -49,16 +67,14 @@ def clear_pdb_state():
49
  if "uploaded_pdb_str" in st.session_state:
50
  del st.session_state.uploaded_pdb_str
51
 
 
52
  def select_pdb():
53
  if "pdb_id" not in st.session_state:
54
  st.session_state.pdb_id = "2FZ5"
55
- pdb_id = st.text_input(
56
- label = "1.PDB ID",
57
- key = "pdb_id",
58
- on_change=clear_pdb_state
59
- )
60
  return pdb_id
61
 
 
62
  def select_protein(pdb_code, uploaded_file, input_sequence):
63
  # We get the pdb from 1 of 3 places:
64
  # 1. Cached pdb from session storage
@@ -85,6 +101,7 @@ def select_protein(pdb_code, uploaded_file, input_sequence):
85
  structure = parser.get_structure(pdb_code, StringIO(pdb_str))
86
  return pdb_str, structure, source
87
 
 
88
  def select_heads_and_layers(sidebar, model):
89
  sidebar.markdown(
90
  """
@@ -93,23 +110,35 @@ def select_heads_and_layers(sidebar, model):
93
  """
94
  )
95
  if "plot_heads" not in st.session_state:
96
- st.session_state.plot_heads = (1, model.heads//2)
97
- head_range = sidebar.slider("Heads to plot", min_value=1, max_value=model.heads, key="plot_heads", step=1)
 
 
98
  if "plot_layers" not in st.session_state:
99
- st.session_state.plot_layers = (1, model.layers//2)
100
- layer_range = sidebar.slider("Layers to plot", min_value=1, max_value=model.layers, key="plot_layers", step=1)
 
 
101
 
102
  if "plot_step_size" not in st.session_state:
103
  st.session_state.plot_step_size = 1
104
- step_size = sidebar.number_input("Optional step size to skip heads and layers", key="plot_step_size", min_value=1, max_value=model.layers)
105
- layer_sequence = list(range(layer_range[0]-1, layer_range[1], step_size))
106
- head_sequence = list(range(head_range[0]-1, head_range[1], step_size))
 
 
 
 
 
107
 
108
  return layer_sequence, head_sequence
109
 
 
110
  def select_sequence_slice(sequence_length):
111
  st.sidebar.markdown("Sequence segment to plot")
112
  if "sequence_slice" not in st.session_state:
113
  st.session_state.sequence_slice = (1, min(50, sequence_length))
114
- slice = st.sidebar.slider("Sequence", key="sequence_slice", min_value=1, max_value=sequence_length, step=1)
115
- return slice
 
 
 
6
  from hexviz.attention import get_pdb_file, get_pdb_from_seq
7
 
8
  menu_items = {
9
+ "Get Help": "https://huggingface.co/spaces/aksell/hexviz/discussions/new",
10
+ "Report a bug": "https://huggingface.co/spaces/aksell/hexviz/discussions/new",
11
+ "About": "Created by [Aksel Lenes](https://github.com/aksell/) from Noelia Ferruz's group at the Institute of Molecular Biology of Barcelona. Read more at https://www.aiproteindesign.com/",
12
+ }
13
+
14
 
15
  def get_selecte_model_index(models):
16
  selected_model_name = st.session_state.get("selected_model_name", None)
17
  if selected_model_name is None:
18
  return 0
19
  else:
20
+ return next(
21
+ (
22
+ i
23
+ for i, model in enumerate(models)
24
+ if model.name.value == selected_model_name
25
+ ),
26
+ None,
27
+ )
28
+
29
 
30
  def clear_model_state():
31
  if "plot_heads" in st.session_state:
 
41
  if "plot_heads" in st.session_state:
42
  del st.session_state.plot_heads
43
 
44
+
45
  def select_model(models):
46
  if "selected_model_name" not in st.session_state:
47
  st.session_state.selected_model_name = models[0].name.value
48
+ selected_model_name = st.selectbox(
49
+ "Select model",
50
+ [model.name.value for model in models],
51
+ key="selected_model_name",
52
+ on_change=clear_model_state,
53
+ )
54
+ select_model = next(
55
+ (model for model in models if model.name.value == selected_model_name), None
56
+ )
57
  return select_model
58
 
59
+
60
  def clear_pdb_state():
61
  if "selected_chains" in st.session_state:
62
  del st.session_state.selected_chains
 
67
  if "uploaded_pdb_str" in st.session_state:
68
  del st.session_state.uploaded_pdb_str
69
 
70
+
71
  def select_pdb():
72
  if "pdb_id" not in st.session_state:
73
  st.session_state.pdb_id = "2FZ5"
74
+ pdb_id = st.text_input(label="1.PDB ID", key="pdb_id", on_change=clear_pdb_state)
 
 
 
 
75
  return pdb_id
76
 
77
+
78
  def select_protein(pdb_code, uploaded_file, input_sequence):
79
  # We get the pdb from 1 of 3 places:
80
  # 1. Cached pdb from session storage
 
101
  structure = parser.get_structure(pdb_code, StringIO(pdb_str))
102
  return pdb_str, structure, source
103
 
104
+
105
  def select_heads_and_layers(sidebar, model):
106
  sidebar.markdown(
107
  """
 
110
  """
111
  )
112
  if "plot_heads" not in st.session_state:
113
+ st.session_state.plot_heads = (1, model.heads // 2)
114
+ head_range = sidebar.slider(
115
+ "Heads to plot", min_value=1, max_value=model.heads, key="plot_heads", step=1
116
+ )
117
  if "plot_layers" not in st.session_state:
118
+ st.session_state.plot_layers = (1, model.layers // 2)
119
+ layer_range = sidebar.slider(
120
+ "Layers to plot", min_value=1, max_value=model.layers, key="plot_layers", step=1
121
+ )
122
 
123
  if "plot_step_size" not in st.session_state:
124
  st.session_state.plot_step_size = 1
125
+ step_size = sidebar.number_input(
126
+ "Optional step size to skip heads and layers",
127
+ key="plot_step_size",
128
+ min_value=1,
129
+ max_value=model.layers,
130
+ )
131
+ layer_sequence = list(range(layer_range[0] - 1, layer_range[1], step_size))
132
+ head_sequence = list(range(head_range[0] - 1, head_range[1], step_size))
133
 
134
  return layer_sequence, head_sequence
135
 
136
+
137
  def select_sequence_slice(sequence_length):
138
  st.sidebar.markdown("Sequence segment to plot")
139
  if "sequence_slice" not in st.session_state:
140
  st.session_state.sequence_slice = (1, min(50, sequence_length))
141
+ slice = st.sidebar.slider(
142
+ "Sequence", key="sequence_slice", min_value=1, max_value=sequence_length, step=1
143
+ )
144
+ return slice
hexviz/🧬Attention_Visualization.py CHANGED
@@ -4,8 +4,11 @@ import stmol
4
  import streamlit as st
5
  from stmol import showmol
6
 
7
- from hexviz.attention import (clean_and_validate_sequence, get_attention_pairs,
8
- get_chains)
 
 
 
9
  from hexviz.models import Model, ModelType
10
  from hexviz.view import menu_items, select_model, select_pdb, select_protein
11
 
@@ -21,10 +24,14 @@ models = [
21
  Model(name=ModelType.PROT_BERT, layers=30, heads=16),
22
  ]
23
 
24
- with st.expander("Input a PDB id, upload a PDB file or input a sequence", expanded=True):
 
 
25
  pdb_id = select_pdb()
26
  uploaded_file = st.file_uploader("2.Upload PDB", type=["pdb"])
27
- input_sequence = st.text_area("3.Input sequence", "", key="input_sequence", max_chars=400)
 
 
28
  sequence, error = clean_and_validate_sequence(input_sequence)
29
  if error:
30
  st.error(error)
@@ -35,14 +42,19 @@ st.sidebar.markdown(
35
  """
36
  Configure visualization
37
  ---
38
- """)
 
39
  chains = get_chains(structure)
40
 
41
  if "selected_chains" not in st.session_state:
42
  st.session_state.selected_chains = chains
43
- selected_chains = st.sidebar.multiselect(label="Select Chain(s)", options=chains, key="selected_chains")
 
 
44
 
45
- show_ligands = st.sidebar.checkbox("Show ligands", value=st.session_state.get("show_ligands", True))
 
 
46
  st.session_state.show_ligands = show_ligands
47
 
48
 
@@ -50,9 +62,14 @@ st.sidebar.markdown(
50
  """
51
  Attention parameters
52
  ---
53
- """)
54
- min_attn = st.sidebar.slider("Minimum attention", min_value=0.0, max_value=0.4, value=0.1)
55
- n_highest_resis = st.sidebar.number_input("Num highest attention resis to label", value=2, min_value=1, max_value=100)
 
 
 
 
 
56
  label_highest = st.sidebar.checkbox("Label highest attention residues", value=True)
57
  sidechain_highest = st.sidebar.checkbox("Show sidechains", value=True)
58
  # TODO add avg or max attention as params
@@ -60,7 +77,9 @@ sidechain_highest = st.sidebar.checkbox("Show sidechains", value=True)
60
 
61
  with st.sidebar.expander("Label residues manually"):
62
  hl_chain = st.selectbox(label="Chain to label", options=selected_chains, index=0)
63
- hl_resi_list = st.multiselect(label="Selected Residues",options=list(range(1,5000)))
 
 
64
 
65
  label_resi = st.checkbox(label="Label Residues", value=True)
66
 
@@ -71,26 +90,46 @@ with left:
71
  with mid:
72
  if "selected_layer" not in st.session_state:
73
  st.session_state["selected_layer"] = 5
74
- layer_one = st.selectbox("Layer", options=[i for i in range(1, selected_model.layers+1)], key="selected_layer")
 
 
 
 
75
  layer = layer_one - 1
76
  with right:
77
  if "selected_head" not in st.session_state:
78
  st.session_state["selected_head"] = 1
79
- head_one = st.selectbox("Head", options=[i for i in range(1, selected_model.heads+1)], key="selected_head")
 
 
 
 
80
  head = head_one - 1
81
 
82
-
83
  if selected_model.name == ModelType.ZymCTRL:
84
  try:
85
  ec_class = structure.header["compound"]["1"]["ec"]
86
  except KeyError:
87
- ec_class = None
88
- if ec_class and selected_model.name == ModelType.ZymCTRL:
89
- ec_class = st.sidebar.text_input("Enzyme classification number fetched from PDB", ec_class)
 
90
 
91
- attention_pairs, top_residues = get_attention_pairs(pdb_str=pdb_str, chain_ids=selected_chains, layer=layer, head=head, threshold=min_attn, model_type=selected_model.name, top_n=n_highest_resis)
92
 
93
- sorted_by_attention = sorted(attention_pairs, key=lambda x: x[0], reverse=True)
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
  def get_3dview(pdb):
96
  xyzview = py3Dmol.view()
@@ -100,38 +139,61 @@ def get_3dview(pdb):
100
 
101
  # Show all ligands as stick (heteroatoms)
102
  if show_ligands:
103
- xyzview.addStyle({"hetflag": True},
104
- {"stick": {"radius": 0.2}})
105
 
106
  # If no chains are selected, show all chains
107
  if selected_chains:
108
  hidden_chains = [x for x in chains if x not in selected_chains]
109
  for chain in hidden_chains:
110
- xyzview.setStyle({"chain": chain},{"cross":{"hidden":"true"}})
111
  # Hide ligands for chain too
112
- xyzview.addStyle({"chain": chain, "hetflag": True},{"cross": {"hidden": "true"}})
 
 
113
 
114
  if len(selected_chains) == 1:
115
- xyzview.zoomTo({'chain': f'{selected_chains[0]}'})
116
  else:
117
  xyzview.zoomTo()
118
 
119
  for att_weight, first, second, _, _, _ in attention_pairs:
120
- stmol.add_cylinder(xyzview, start=first, end=second, cylradius=att_weight, cylColor='red', dashed=False)
 
 
 
 
 
 
 
121
 
122
  if label_resi:
123
  for hl_resi in hl_resi_list:
124
- xyzview.addResLabels({"chain": hl_chain,"resi": hl_resi},
125
- {"backgroundColor": "lightgray","fontColor": "black","backgroundOpacity": 0.5})
 
 
 
 
 
 
126
 
127
  if label_highest:
128
  for _, _, chain, res in top_residues:
129
- xyzview.addResLabels({"chain": chain, "resi": res},
130
- {"backgroundColor": "lightgray", "fontColor": "black", "backgroundOpacity": 0.5})
 
 
 
 
 
 
131
  if sidechain_highest:
132
- xyzview.addStyle({"chain": chain, "resi": res},{"stick": {"radius": 0.2}})
 
 
133
  return xyzview
134
 
 
135
  xyzview = get_3dview(pdb_id)
136
  showmol(xyzview, height=500, width=800)
137
 
 
4
  import streamlit as st
5
  from stmol import showmol
6
 
7
+ from hexviz.attention import (
8
+ clean_and_validate_sequence,
9
+ get_attention_pairs,
10
+ get_chains,
11
+ )
12
  from hexviz.models import Model, ModelType
13
  from hexviz.view import menu_items, select_model, select_pdb, select_protein
14
 
 
24
  Model(name=ModelType.PROT_BERT, layers=30, heads=16),
25
  ]
26
 
27
+ with st.expander(
28
+ "Input a PDB id, upload a PDB file or input a sequence", expanded=True
29
+ ):
30
  pdb_id = select_pdb()
31
  uploaded_file = st.file_uploader("2.Upload PDB", type=["pdb"])
32
+ input_sequence = st.text_area(
33
+ "3.Input sequence", "", key="input_sequence", max_chars=400
34
+ )
35
  sequence, error = clean_and_validate_sequence(input_sequence)
36
  if error:
37
  st.error(error)
 
42
  """
43
  Configure visualization
44
  ---
45
+ """
46
+ )
47
  chains = get_chains(structure)
48
 
49
  if "selected_chains" not in st.session_state:
50
  st.session_state.selected_chains = chains
51
+ selected_chains = st.sidebar.multiselect(
52
+ label="Select Chain(s)", options=chains, key="selected_chains"
53
+ )
54
 
55
+ show_ligands = st.sidebar.checkbox(
56
+ "Show ligands", value=st.session_state.get("show_ligands", True)
57
+ )
58
  st.session_state.show_ligands = show_ligands
59
 
60
 
 
62
  """
63
  Attention parameters
64
  ---
65
+ """
66
+ )
67
+ min_attn = st.sidebar.slider(
68
+ "Minimum attention", min_value=0.0, max_value=0.4, value=0.1
69
+ )
70
+ n_highest_resis = st.sidebar.number_input(
71
+ "Num highest attention resis to label", value=2, min_value=1, max_value=100
72
+ )
73
  label_highest = st.sidebar.checkbox("Label highest attention residues", value=True)
74
  sidechain_highest = st.sidebar.checkbox("Show sidechains", value=True)
75
  # TODO add avg or max attention as params
 
77
 
78
  with st.sidebar.expander("Label residues manually"):
79
  hl_chain = st.selectbox(label="Chain to label", options=selected_chains, index=0)
80
+ hl_resi_list = st.multiselect(
81
+ label="Selected Residues", options=list(range(1, 5000))
82
+ )
83
 
84
  label_resi = st.checkbox(label="Label Residues", value=True)
85
 
 
90
  with mid:
91
  if "selected_layer" not in st.session_state:
92
  st.session_state["selected_layer"] = 5
93
+ layer_one = st.selectbox(
94
+ "Layer",
95
+ options=[i for i in range(1, selected_model.layers + 1)],
96
+ key="selected_layer",
97
+ )
98
  layer = layer_one - 1
99
  with right:
100
  if "selected_head" not in st.session_state:
101
  st.session_state["selected_head"] = 1
102
+ head_one = st.selectbox(
103
+ "Head",
104
+ options=[i for i in range(1, selected_model.heads + 1)],
105
+ key="selected_head",
106
+ )
107
  head = head_one - 1
108
 
109
+ ec_class = ""
110
  if selected_model.name == ModelType.ZymCTRL:
111
  try:
112
  ec_class = structure.header["compound"]["1"]["ec"]
113
  except KeyError:
114
+ pass
115
+ ec_class = st.sidebar.text_input(
116
+ "Enzyme classification number fetched from PDB", ec_class
117
+ )
118
 
 
119
 
120
+ attention_pairs, top_residues = get_attention_pairs(
121
+ pdb_str=pdb_str,
122
+ chain_ids=selected_chains,
123
+ layer=layer,
124
+ head=head,
125
+ threshold=min_attn,
126
+ model_type=selected_model.name,
127
+ ec_class=ec_class,
128
+ top_n=n_highest_resis,
129
+ )
130
+
131
+ sorted_by_attention = sorted(attention_pairs, key=lambda x: x[0], reverse=True)
132
+
133
 
134
  def get_3dview(pdb):
135
  xyzview = py3Dmol.view()
 
139
 
140
  # Show all ligands as stick (heteroatoms)
141
  if show_ligands:
142
+ xyzview.addStyle({"hetflag": True}, {"stick": {"radius": 0.2}})
 
143
 
144
  # If no chains are selected, show all chains
145
  if selected_chains:
146
  hidden_chains = [x for x in chains if x not in selected_chains]
147
  for chain in hidden_chains:
148
+ xyzview.setStyle({"chain": chain}, {"cross": {"hidden": "true"}})
149
  # Hide ligands for chain too
150
+ xyzview.addStyle(
151
+ {"chain": chain, "hetflag": True}, {"cross": {"hidden": "true"}}
152
+ )
153
 
154
  if len(selected_chains) == 1:
155
+ xyzview.zoomTo({"chain": f"{selected_chains[0]}"})
156
  else:
157
  xyzview.zoomTo()
158
 
159
  for att_weight, first, second, _, _, _ in attention_pairs:
160
+ stmol.add_cylinder(
161
+ xyzview,
162
+ start=first,
163
+ end=second,
164
+ cylradius=att_weight,
165
+ cylColor="red",
166
+ dashed=False,
167
+ )
168
 
169
  if label_resi:
170
  for hl_resi in hl_resi_list:
171
+ xyzview.addResLabels(
172
+ {"chain": hl_chain, "resi": hl_resi},
173
+ {
174
+ "backgroundColor": "lightgray",
175
+ "fontColor": "black",
176
+ "backgroundOpacity": 0.5,
177
+ },
178
+ )
179
 
180
  if label_highest:
181
  for _, _, chain, res in top_residues:
182
+ xyzview.addResLabels(
183
+ {"chain": chain, "resi": res},
184
+ {
185
+ "backgroundColor": "lightgray",
186
+ "fontColor": "black",
187
+ "backgroundOpacity": 0.5,
188
+ },
189
+ )
190
  if sidechain_highest:
191
+ xyzview.addStyle(
192
+ {"chain": chain, "resi": res}, {"stick": {"radius": 0.2}}
193
+ )
194
  return xyzview
195
 
196
+
197
  xyzview = get_3dview(pdb_id)
198
  showmol(xyzview, height=500, width=800)
199