Spaces:
Sleeping
Sleeping
Format with Black
Browse files- hexviz/models.py +14 -8
- hexviz/view.py +50 -21
- hexviz/🧬Attention_Visualization.py +92 -30
hexviz/models.py
CHANGED
@@ -4,12 +4,17 @@ import streamlit as st
|
|
4 |
import torch
|
5 |
from tape import ProteinBertModel, TAPETokenizer
|
6 |
from tokenizers import Tokenizer
|
7 |
-
from transformers import (
|
8 |
-
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
|
11 |
class ModelType(str, Enum):
|
12 |
-
TAPE_BERT = "
|
13 |
ZymCTRL = "ZymCTRL"
|
14 |
PROT_BERT = "ProtBert"
|
15 |
|
@@ -24,22 +29,23 @@ class Model:
|
|
24 |
@st.cache
|
25 |
def get_tape_bert() -> tuple[TAPETokenizer, ProteinBertModel]:
|
26 |
tokenizer = TAPETokenizer()
|
27 |
-
model = ProteinBertModel.from_pretrained(
|
28 |
return tokenizer, model
|
29 |
|
|
|
30 |
# Streamlit is not able to hash the tokenizer for ZymCTRL
|
31 |
# With streamlit 1.19 cache_object should work without this
|
32 |
@st.cache(hash_funcs={Tokenizer: lambda _: None})
|
33 |
def get_zymctrl() -> tuple[GPT2TokenizerFast, GPT2LMHeadModel]:
|
34 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
35 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
36 |
-
model = GPT2LMHeadModel.from_pretrained(
|
37 |
return tokenizer, model
|
38 |
|
39 |
|
40 |
@st.cache(hash_funcs={BertTokenizer: lambda _: None})
|
41 |
def get_prot_bert() -> tuple[BertTokenizer, BertModel]:
|
42 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
43 |
-
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False
|
44 |
model = BertModel.from_pretrained("Rostlab/prot_bert").to(device)
|
45 |
-
return tokenizer, model
|
|
|
4 |
import torch
|
5 |
from tape import ProteinBertModel, TAPETokenizer
|
6 |
from tokenizers import Tokenizer
|
7 |
+
from transformers import (
|
8 |
+
AutoTokenizer,
|
9 |
+
BertModel,
|
10 |
+
BertTokenizer,
|
11 |
+
GPT2LMHeadModel,
|
12 |
+
GPT2TokenizerFast,
|
13 |
+
)
|
14 |
|
15 |
|
16 |
class ModelType(str, Enum):
|
17 |
+
TAPE_BERT = "TapeBert"
|
18 |
ZymCTRL = "ZymCTRL"
|
19 |
PROT_BERT = "ProtBert"
|
20 |
|
|
|
29 |
@st.cache
|
30 |
def get_tape_bert() -> tuple[TAPETokenizer, ProteinBertModel]:
|
31 |
tokenizer = TAPETokenizer()
|
32 |
+
model = ProteinBertModel.from_pretrained("bert-base", output_attentions=True)
|
33 |
return tokenizer, model
|
34 |
|
35 |
+
|
36 |
# Streamlit is not able to hash the tokenizer for ZymCTRL
|
37 |
# With streamlit 1.19 cache_object should work without this
|
38 |
@st.cache(hash_funcs={Tokenizer: lambda _: None})
|
39 |
def get_zymctrl() -> tuple[GPT2TokenizerFast, GPT2LMHeadModel]:
|
40 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
41 |
+
tokenizer = AutoTokenizer.from_pretrained("nferruz/ZymCTRL")
|
42 |
+
model = GPT2LMHeadModel.from_pretrained("nferruz/ZymCTRL").to(device)
|
43 |
return tokenizer, model
|
44 |
|
45 |
|
46 |
@st.cache(hash_funcs={BertTokenizer: lambda _: None})
|
47 |
def get_prot_bert() -> tuple[BertTokenizer, BertModel]:
|
48 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
49 |
+
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
|
50 |
model = BertModel.from_pretrained("Rostlab/prot_bert").to(device)
|
51 |
+
return tokenizer, model
|
hexviz/view.py
CHANGED
@@ -6,17 +6,26 @@ from Bio.PDB import PDBParser
|
|
6 |
from hexviz.attention import get_pdb_file, get_pdb_from_seq
|
7 |
|
8 |
menu_items = {
|
9 |
-
"Get Help": "https://huggingface.co/spaces/aksell/hexviz/discussions/new",
|
10 |
-
"Report a bug": "https://huggingface.co/spaces/aksell/hexviz/discussions/new",
|
11 |
-
"About": "Created by [Aksel Lenes](https://github.com/aksell/) from Noelia Ferruz's group at the Institute of Molecular Biology of Barcelona. Read more at https://www.aiproteindesign.com/"
|
12 |
-
|
|
|
13 |
|
14 |
def get_selecte_model_index(models):
|
15 |
selected_model_name = st.session_state.get("selected_model_name", None)
|
16 |
if selected_model_name is None:
|
17 |
return 0
|
18 |
else:
|
19 |
-
return next(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
def clear_model_state():
|
22 |
if "plot_heads" in st.session_state:
|
@@ -32,13 +41,22 @@ def clear_model_state():
|
|
32 |
if "plot_heads" in st.session_state:
|
33 |
del st.session_state.plot_heads
|
34 |
|
|
|
35 |
def select_model(models):
|
36 |
if "selected_model_name" not in st.session_state:
|
37 |
st.session_state.selected_model_name = models[0].name.value
|
38 |
-
selected_model_name = st.selectbox(
|
39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
return select_model
|
41 |
|
|
|
42 |
def clear_pdb_state():
|
43 |
if "selected_chains" in st.session_state:
|
44 |
del st.session_state.selected_chains
|
@@ -49,16 +67,14 @@ def clear_pdb_state():
|
|
49 |
if "uploaded_pdb_str" in st.session_state:
|
50 |
del st.session_state.uploaded_pdb_str
|
51 |
|
|
|
52 |
def select_pdb():
|
53 |
if "pdb_id" not in st.session_state:
|
54 |
st.session_state.pdb_id = "2FZ5"
|
55 |
-
pdb_id = st.text_input(
|
56 |
-
label = "1.PDB ID",
|
57 |
-
key = "pdb_id",
|
58 |
-
on_change=clear_pdb_state
|
59 |
-
)
|
60 |
return pdb_id
|
61 |
|
|
|
62 |
def select_protein(pdb_code, uploaded_file, input_sequence):
|
63 |
# We get the pdb from 1 of 3 places:
|
64 |
# 1. Cached pdb from session storage
|
@@ -85,6 +101,7 @@ def select_protein(pdb_code, uploaded_file, input_sequence):
|
|
85 |
structure = parser.get_structure(pdb_code, StringIO(pdb_str))
|
86 |
return pdb_str, structure, source
|
87 |
|
|
|
88 |
def select_heads_and_layers(sidebar, model):
|
89 |
sidebar.markdown(
|
90 |
"""
|
@@ -93,23 +110,35 @@ def select_heads_and_layers(sidebar, model):
|
|
93 |
"""
|
94 |
)
|
95 |
if "plot_heads" not in st.session_state:
|
96 |
-
st.session_state.plot_heads = (1, model.heads//2)
|
97 |
-
head_range = sidebar.slider(
|
|
|
|
|
98 |
if "plot_layers" not in st.session_state:
|
99 |
-
st.session_state.plot_layers = (1, model.layers//2)
|
100 |
-
layer_range = sidebar.slider(
|
|
|
|
|
101 |
|
102 |
if "plot_step_size" not in st.session_state:
|
103 |
st.session_state.plot_step_size = 1
|
104 |
-
step_size = sidebar.number_input(
|
105 |
-
|
106 |
-
|
|
|
|
|
|
|
|
|
|
|
107 |
|
108 |
return layer_sequence, head_sequence
|
109 |
|
|
|
110 |
def select_sequence_slice(sequence_length):
|
111 |
st.sidebar.markdown("Sequence segment to plot")
|
112 |
if "sequence_slice" not in st.session_state:
|
113 |
st.session_state.sequence_slice = (1, min(50, sequence_length))
|
114 |
-
slice = st.sidebar.slider(
|
115 |
-
|
|
|
|
|
|
6 |
from hexviz.attention import get_pdb_file, get_pdb_from_seq
|
7 |
|
8 |
menu_items = {
|
9 |
+
"Get Help": "https://huggingface.co/spaces/aksell/hexviz/discussions/new",
|
10 |
+
"Report a bug": "https://huggingface.co/spaces/aksell/hexviz/discussions/new",
|
11 |
+
"About": "Created by [Aksel Lenes](https://github.com/aksell/) from Noelia Ferruz's group at the Institute of Molecular Biology of Barcelona. Read more at https://www.aiproteindesign.com/",
|
12 |
+
}
|
13 |
+
|
14 |
|
15 |
def get_selecte_model_index(models):
|
16 |
selected_model_name = st.session_state.get("selected_model_name", None)
|
17 |
if selected_model_name is None:
|
18 |
return 0
|
19 |
else:
|
20 |
+
return next(
|
21 |
+
(
|
22 |
+
i
|
23 |
+
for i, model in enumerate(models)
|
24 |
+
if model.name.value == selected_model_name
|
25 |
+
),
|
26 |
+
None,
|
27 |
+
)
|
28 |
+
|
29 |
|
30 |
def clear_model_state():
|
31 |
if "plot_heads" in st.session_state:
|
|
|
41 |
if "plot_heads" in st.session_state:
|
42 |
del st.session_state.plot_heads
|
43 |
|
44 |
+
|
45 |
def select_model(models):
|
46 |
if "selected_model_name" not in st.session_state:
|
47 |
st.session_state.selected_model_name = models[0].name.value
|
48 |
+
selected_model_name = st.selectbox(
|
49 |
+
"Select model",
|
50 |
+
[model.name.value for model in models],
|
51 |
+
key="selected_model_name",
|
52 |
+
on_change=clear_model_state,
|
53 |
+
)
|
54 |
+
select_model = next(
|
55 |
+
(model for model in models if model.name.value == selected_model_name), None
|
56 |
+
)
|
57 |
return select_model
|
58 |
|
59 |
+
|
60 |
def clear_pdb_state():
|
61 |
if "selected_chains" in st.session_state:
|
62 |
del st.session_state.selected_chains
|
|
|
67 |
if "uploaded_pdb_str" in st.session_state:
|
68 |
del st.session_state.uploaded_pdb_str
|
69 |
|
70 |
+
|
71 |
def select_pdb():
|
72 |
if "pdb_id" not in st.session_state:
|
73 |
st.session_state.pdb_id = "2FZ5"
|
74 |
+
pdb_id = st.text_input(label="1.PDB ID", key="pdb_id", on_change=clear_pdb_state)
|
|
|
|
|
|
|
|
|
75 |
return pdb_id
|
76 |
|
77 |
+
|
78 |
def select_protein(pdb_code, uploaded_file, input_sequence):
|
79 |
# We get the pdb from 1 of 3 places:
|
80 |
# 1. Cached pdb from session storage
|
|
|
101 |
structure = parser.get_structure(pdb_code, StringIO(pdb_str))
|
102 |
return pdb_str, structure, source
|
103 |
|
104 |
+
|
105 |
def select_heads_and_layers(sidebar, model):
|
106 |
sidebar.markdown(
|
107 |
"""
|
|
|
110 |
"""
|
111 |
)
|
112 |
if "plot_heads" not in st.session_state:
|
113 |
+
st.session_state.plot_heads = (1, model.heads // 2)
|
114 |
+
head_range = sidebar.slider(
|
115 |
+
"Heads to plot", min_value=1, max_value=model.heads, key="plot_heads", step=1
|
116 |
+
)
|
117 |
if "plot_layers" not in st.session_state:
|
118 |
+
st.session_state.plot_layers = (1, model.layers // 2)
|
119 |
+
layer_range = sidebar.slider(
|
120 |
+
"Layers to plot", min_value=1, max_value=model.layers, key="plot_layers", step=1
|
121 |
+
)
|
122 |
|
123 |
if "plot_step_size" not in st.session_state:
|
124 |
st.session_state.plot_step_size = 1
|
125 |
+
step_size = sidebar.number_input(
|
126 |
+
"Optional step size to skip heads and layers",
|
127 |
+
key="plot_step_size",
|
128 |
+
min_value=1,
|
129 |
+
max_value=model.layers,
|
130 |
+
)
|
131 |
+
layer_sequence = list(range(layer_range[0] - 1, layer_range[1], step_size))
|
132 |
+
head_sequence = list(range(head_range[0] - 1, head_range[1], step_size))
|
133 |
|
134 |
return layer_sequence, head_sequence
|
135 |
|
136 |
+
|
137 |
def select_sequence_slice(sequence_length):
|
138 |
st.sidebar.markdown("Sequence segment to plot")
|
139 |
if "sequence_slice" not in st.session_state:
|
140 |
st.session_state.sequence_slice = (1, min(50, sequence_length))
|
141 |
+
slice = st.sidebar.slider(
|
142 |
+
"Sequence", key="sequence_slice", min_value=1, max_value=sequence_length, step=1
|
143 |
+
)
|
144 |
+
return slice
|
hexviz/🧬Attention_Visualization.py
CHANGED
@@ -4,8 +4,11 @@ import stmol
|
|
4 |
import streamlit as st
|
5 |
from stmol import showmol
|
6 |
|
7 |
-
from hexviz.attention import (
|
8 |
-
|
|
|
|
|
|
|
9 |
from hexviz.models import Model, ModelType
|
10 |
from hexviz.view import menu_items, select_model, select_pdb, select_protein
|
11 |
|
@@ -21,10 +24,14 @@ models = [
|
|
21 |
Model(name=ModelType.PROT_BERT, layers=30, heads=16),
|
22 |
]
|
23 |
|
24 |
-
with st.expander(
|
|
|
|
|
25 |
pdb_id = select_pdb()
|
26 |
uploaded_file = st.file_uploader("2.Upload PDB", type=["pdb"])
|
27 |
-
input_sequence = st.text_area(
|
|
|
|
|
28 |
sequence, error = clean_and_validate_sequence(input_sequence)
|
29 |
if error:
|
30 |
st.error(error)
|
@@ -35,14 +42,19 @@ st.sidebar.markdown(
|
|
35 |
"""
|
36 |
Configure visualization
|
37 |
---
|
38 |
-
"""
|
|
|
39 |
chains = get_chains(structure)
|
40 |
|
41 |
if "selected_chains" not in st.session_state:
|
42 |
st.session_state.selected_chains = chains
|
43 |
-
selected_chains = st.sidebar.multiselect(
|
|
|
|
|
44 |
|
45 |
-
show_ligands = st.sidebar.checkbox(
|
|
|
|
|
46 |
st.session_state.show_ligands = show_ligands
|
47 |
|
48 |
|
@@ -50,9 +62,14 @@ st.sidebar.markdown(
|
|
50 |
"""
|
51 |
Attention parameters
|
52 |
---
|
53 |
-
"""
|
54 |
-
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
56 |
label_highest = st.sidebar.checkbox("Label highest attention residues", value=True)
|
57 |
sidechain_highest = st.sidebar.checkbox("Show sidechains", value=True)
|
58 |
# TODO add avg or max attention as params
|
@@ -60,7 +77,9 @@ sidechain_highest = st.sidebar.checkbox("Show sidechains", value=True)
|
|
60 |
|
61 |
with st.sidebar.expander("Label residues manually"):
|
62 |
hl_chain = st.selectbox(label="Chain to label", options=selected_chains, index=0)
|
63 |
-
hl_resi_list = st.multiselect(
|
|
|
|
|
64 |
|
65 |
label_resi = st.checkbox(label="Label Residues", value=True)
|
66 |
|
@@ -71,26 +90,46 @@ with left:
|
|
71 |
with mid:
|
72 |
if "selected_layer" not in st.session_state:
|
73 |
st.session_state["selected_layer"] = 5
|
74 |
-
layer_one = st.selectbox(
|
|
|
|
|
|
|
|
|
75 |
layer = layer_one - 1
|
76 |
with right:
|
77 |
if "selected_head" not in st.session_state:
|
78 |
st.session_state["selected_head"] = 1
|
79 |
-
head_one = st.selectbox(
|
|
|
|
|
|
|
|
|
80 |
head = head_one - 1
|
81 |
|
82 |
-
|
83 |
if selected_model.name == ModelType.ZymCTRL:
|
84 |
try:
|
85 |
ec_class = structure.header["compound"]["1"]["ec"]
|
86 |
except KeyError:
|
87 |
-
|
88 |
-
|
89 |
-
|
|
|
90 |
|
91 |
-
attention_pairs, top_residues = get_attention_pairs(pdb_str=pdb_str, chain_ids=selected_chains, layer=layer, head=head, threshold=min_attn, model_type=selected_model.name, top_n=n_highest_resis)
|
92 |
|
93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
def get_3dview(pdb):
|
96 |
xyzview = py3Dmol.view()
|
@@ -100,38 +139,61 @@ def get_3dview(pdb):
|
|
100 |
|
101 |
# Show all ligands as stick (heteroatoms)
|
102 |
if show_ligands:
|
103 |
-
xyzview.addStyle({"hetflag": True},
|
104 |
-
{"stick": {"radius": 0.2}})
|
105 |
|
106 |
# If no chains are selected, show all chains
|
107 |
if selected_chains:
|
108 |
hidden_chains = [x for x in chains if x not in selected_chains]
|
109 |
for chain in hidden_chains:
|
110 |
-
xyzview.setStyle({"chain": chain},{"cross":{"hidden":"true"}})
|
111 |
# Hide ligands for chain too
|
112 |
-
xyzview.addStyle(
|
|
|
|
|
113 |
|
114 |
if len(selected_chains) == 1:
|
115 |
-
xyzview.zoomTo({
|
116 |
else:
|
117 |
xyzview.zoomTo()
|
118 |
|
119 |
for att_weight, first, second, _, _, _ in attention_pairs:
|
120 |
-
stmol.add_cylinder(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
|
122 |
if label_resi:
|
123 |
for hl_resi in hl_resi_list:
|
124 |
-
xyzview.addResLabels(
|
125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
|
127 |
if label_highest:
|
128 |
for _, _, chain, res in top_residues:
|
129 |
-
xyzview.addResLabels(
|
130 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
if sidechain_highest:
|
132 |
-
xyzview.addStyle(
|
|
|
|
|
133 |
return xyzview
|
134 |
|
|
|
135 |
xyzview = get_3dview(pdb_id)
|
136 |
showmol(xyzview, height=500, width=800)
|
137 |
|
|
|
4 |
import streamlit as st
|
5 |
from stmol import showmol
|
6 |
|
7 |
+
from hexviz.attention import (
|
8 |
+
clean_and_validate_sequence,
|
9 |
+
get_attention_pairs,
|
10 |
+
get_chains,
|
11 |
+
)
|
12 |
from hexviz.models import Model, ModelType
|
13 |
from hexviz.view import menu_items, select_model, select_pdb, select_protein
|
14 |
|
|
|
24 |
Model(name=ModelType.PROT_BERT, layers=30, heads=16),
|
25 |
]
|
26 |
|
27 |
+
with st.expander(
|
28 |
+
"Input a PDB id, upload a PDB file or input a sequence", expanded=True
|
29 |
+
):
|
30 |
pdb_id = select_pdb()
|
31 |
uploaded_file = st.file_uploader("2.Upload PDB", type=["pdb"])
|
32 |
+
input_sequence = st.text_area(
|
33 |
+
"3.Input sequence", "", key="input_sequence", max_chars=400
|
34 |
+
)
|
35 |
sequence, error = clean_and_validate_sequence(input_sequence)
|
36 |
if error:
|
37 |
st.error(error)
|
|
|
42 |
"""
|
43 |
Configure visualization
|
44 |
---
|
45 |
+
"""
|
46 |
+
)
|
47 |
chains = get_chains(structure)
|
48 |
|
49 |
if "selected_chains" not in st.session_state:
|
50 |
st.session_state.selected_chains = chains
|
51 |
+
selected_chains = st.sidebar.multiselect(
|
52 |
+
label="Select Chain(s)", options=chains, key="selected_chains"
|
53 |
+
)
|
54 |
|
55 |
+
show_ligands = st.sidebar.checkbox(
|
56 |
+
"Show ligands", value=st.session_state.get("show_ligands", True)
|
57 |
+
)
|
58 |
st.session_state.show_ligands = show_ligands
|
59 |
|
60 |
|
|
|
62 |
"""
|
63 |
Attention parameters
|
64 |
---
|
65 |
+
"""
|
66 |
+
)
|
67 |
+
min_attn = st.sidebar.slider(
|
68 |
+
"Minimum attention", min_value=0.0, max_value=0.4, value=0.1
|
69 |
+
)
|
70 |
+
n_highest_resis = st.sidebar.number_input(
|
71 |
+
"Num highest attention resis to label", value=2, min_value=1, max_value=100
|
72 |
+
)
|
73 |
label_highest = st.sidebar.checkbox("Label highest attention residues", value=True)
|
74 |
sidechain_highest = st.sidebar.checkbox("Show sidechains", value=True)
|
75 |
# TODO add avg or max attention as params
|
|
|
77 |
|
78 |
with st.sidebar.expander("Label residues manually"):
|
79 |
hl_chain = st.selectbox(label="Chain to label", options=selected_chains, index=0)
|
80 |
+
hl_resi_list = st.multiselect(
|
81 |
+
label="Selected Residues", options=list(range(1, 5000))
|
82 |
+
)
|
83 |
|
84 |
label_resi = st.checkbox(label="Label Residues", value=True)
|
85 |
|
|
|
90 |
with mid:
|
91 |
if "selected_layer" not in st.session_state:
|
92 |
st.session_state["selected_layer"] = 5
|
93 |
+
layer_one = st.selectbox(
|
94 |
+
"Layer",
|
95 |
+
options=[i for i in range(1, selected_model.layers + 1)],
|
96 |
+
key="selected_layer",
|
97 |
+
)
|
98 |
layer = layer_one - 1
|
99 |
with right:
|
100 |
if "selected_head" not in st.session_state:
|
101 |
st.session_state["selected_head"] = 1
|
102 |
+
head_one = st.selectbox(
|
103 |
+
"Head",
|
104 |
+
options=[i for i in range(1, selected_model.heads + 1)],
|
105 |
+
key="selected_head",
|
106 |
+
)
|
107 |
head = head_one - 1
|
108 |
|
109 |
+
ec_class = ""
|
110 |
if selected_model.name == ModelType.ZymCTRL:
|
111 |
try:
|
112 |
ec_class = structure.header["compound"]["1"]["ec"]
|
113 |
except KeyError:
|
114 |
+
pass
|
115 |
+
ec_class = st.sidebar.text_input(
|
116 |
+
"Enzyme classification number fetched from PDB", ec_class
|
117 |
+
)
|
118 |
|
|
|
119 |
|
120 |
+
attention_pairs, top_residues = get_attention_pairs(
|
121 |
+
pdb_str=pdb_str,
|
122 |
+
chain_ids=selected_chains,
|
123 |
+
layer=layer,
|
124 |
+
head=head,
|
125 |
+
threshold=min_attn,
|
126 |
+
model_type=selected_model.name,
|
127 |
+
ec_class=ec_class,
|
128 |
+
top_n=n_highest_resis,
|
129 |
+
)
|
130 |
+
|
131 |
+
sorted_by_attention = sorted(attention_pairs, key=lambda x: x[0], reverse=True)
|
132 |
+
|
133 |
|
134 |
def get_3dview(pdb):
|
135 |
xyzview = py3Dmol.view()
|
|
|
139 |
|
140 |
# Show all ligands as stick (heteroatoms)
|
141 |
if show_ligands:
|
142 |
+
xyzview.addStyle({"hetflag": True}, {"stick": {"radius": 0.2}})
|
|
|
143 |
|
144 |
# If no chains are selected, show all chains
|
145 |
if selected_chains:
|
146 |
hidden_chains = [x for x in chains if x not in selected_chains]
|
147 |
for chain in hidden_chains:
|
148 |
+
xyzview.setStyle({"chain": chain}, {"cross": {"hidden": "true"}})
|
149 |
# Hide ligands for chain too
|
150 |
+
xyzview.addStyle(
|
151 |
+
{"chain": chain, "hetflag": True}, {"cross": {"hidden": "true"}}
|
152 |
+
)
|
153 |
|
154 |
if len(selected_chains) == 1:
|
155 |
+
xyzview.zoomTo({"chain": f"{selected_chains[0]}"})
|
156 |
else:
|
157 |
xyzview.zoomTo()
|
158 |
|
159 |
for att_weight, first, second, _, _, _ in attention_pairs:
|
160 |
+
stmol.add_cylinder(
|
161 |
+
xyzview,
|
162 |
+
start=first,
|
163 |
+
end=second,
|
164 |
+
cylradius=att_weight,
|
165 |
+
cylColor="red",
|
166 |
+
dashed=False,
|
167 |
+
)
|
168 |
|
169 |
if label_resi:
|
170 |
for hl_resi in hl_resi_list:
|
171 |
+
xyzview.addResLabels(
|
172 |
+
{"chain": hl_chain, "resi": hl_resi},
|
173 |
+
{
|
174 |
+
"backgroundColor": "lightgray",
|
175 |
+
"fontColor": "black",
|
176 |
+
"backgroundOpacity": 0.5,
|
177 |
+
},
|
178 |
+
)
|
179 |
|
180 |
if label_highest:
|
181 |
for _, _, chain, res in top_residues:
|
182 |
+
xyzview.addResLabels(
|
183 |
+
{"chain": chain, "resi": res},
|
184 |
+
{
|
185 |
+
"backgroundColor": "lightgray",
|
186 |
+
"fontColor": "black",
|
187 |
+
"backgroundOpacity": 0.5,
|
188 |
+
},
|
189 |
+
)
|
190 |
if sidechain_highest:
|
191 |
+
xyzview.addStyle(
|
192 |
+
{"chain": chain, "resi": res}, {"stick": {"radius": 0.2}}
|
193 |
+
)
|
194 |
return xyzview
|
195 |
|
196 |
+
|
197 |
xyzview = get_3dview(pdb_id)
|
198 |
showmol(xyzview, height=500, width=800)
|
199 |
|