aksell commited on
Commit
a2eb24d
·
1 Parent(s): 9af4b80

Add options to label, highlight and show ligand

Browse files
hexviz/pages/3_🏗️Test:_Birds_Eye_View.py CHANGED
@@ -80,13 +80,22 @@ if selected_model.name == ModelType.ZymCTRL:
80
  )
81
 
82
 
83
- residues = [res for res in selected_chain.get_residues()]
84
- sequence = res_to_1letter(residues)
85
-
86
-
 
 
 
 
 
 
87
  layer_sequence, head_sequence = select_heads_and_layers(st.sidebar, selected_model)
 
88
 
89
- min_attn = st.sidebar.slider("Minimum attention", min_value=0.0, max_value=0.4, value=0.1)
 
 
90
 
91
  attention, tokens = get_attention(
92
  sequence=sequence,
@@ -107,7 +116,6 @@ xyzview = py3Dmol.view(
107
  linked=False,
108
  viewergrid=(grid_rows, grid_cols),
109
  )
110
- xyzview.setStyle({"cartoon": {"color": "white"}})
111
 
112
 
113
  for row, layer in enumerate(layer_sequence):
@@ -140,4 +148,22 @@ for row, layer in enumerate(layer_sequence):
140
  viewer=(row, col),
141
  )
142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  stmol.showmol(xyzview, height=viewer_height, width=viewer_width)
 
80
  )
81
 
82
 
83
+ min_attn = st.sidebar.slider("Minimum attention", min_value=0.0, max_value=0.4, value=0.1)
84
+ if "show_ligands" not in st.session_state:
85
+ st.session_state.show_ligands = True
86
+ show_ligands = st.sidebar.checkbox("Show ligands", key="show_ligands")
87
+
88
+ with st.sidebar.expander("Highlight residues"):
89
+ st.write("Residue will be highlighted in yellow")
90
+ hl_resi_list = st.multiselect(label="Selected Residues", options=list(range(1, 5000)))
91
+ highlight_resi = st.checkbox(label="Highlight residues", value=True)
92
+ label_resi = st.checkbox(label="Label residue names", value=False)
93
  layer_sequence, head_sequence = select_heads_and_layers(st.sidebar, selected_model)
94
+ # TODO add slider for widht of grid
95
 
96
+
97
+ residues = [res for res in selected_chain.get_residues()]
98
+ sequence = res_to_1letter(residues)
99
 
100
  attention, tokens = get_attention(
101
  sequence=sequence,
 
116
  linked=False,
117
  viewergrid=(grid_rows, grid_cols),
118
  )
 
119
 
120
 
121
  for row, layer in enumerate(layer_sequence):
 
148
  viewer=(row, col),
149
  )
150
 
151
+
152
+ xyzview.setStyle({"cartoon": {"color": "white"}})
153
+ if highlight_resi:
154
+ for res in hl_resi_list:
155
+ xyzview.setStyle({"resi": res}, {"cartoon": {"color": "yellow"}})
156
+ if label_resi:
157
+ for hl_resi in hl_resi_list:
158
+ xyzview.addResLabels(
159
+ {"resi": hl_resi},
160
+ {
161
+ "backgroundColor": "lightgray",
162
+ "fontColor": "black",
163
+ "backgroundOpacity": 0.5,
164
+ },
165
+ )
166
+ if show_ligands:
167
+ xyzview.addStyle({"hetflag": True}, {"stick": {"radius": 0.2}})
168
+
169
  stmol.showmol(xyzview, height=viewer_height, width=viewer_width)