aksell committed on
Commit
ee057cd
·
1 Parent(s): 6fb3f0b

Add test view of grid of attention on structure

Browse files
hexviz/pages/3_🏗️Test:_Birds_Eye_View.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Streamlit page: "bird's eye view" of attention on a protein structure.
#
# Renders a py3Dmol viewer grid with one cell per (layer, head) pair taken
# from the sidebar selection, and draws the attention pairs returned by
# `get_attention_pairs` as red cylinders in the matching cell.
import re

import py3Dmol
import stmol
import streamlit as st

from hexviz.attention import (
    clean_and_validate_sequence,
    get_attention,
    get_attention_pairs,
    res_to_1letter,
)
from hexviz.models import Model, ModelType
from hexviz.view import (
    menu_items,
    select_heads_and_layers,
    select_model,
    select_pdb,
    select_protein,
)

# Four dot-separated groups of 1-3 digits, e.g. "1.2.3.21".
EC_NUMBER_PATTERN = re.compile(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$")

# Total pixel width of the viewer grid; each cell is square.
VIEWER_WIDTH = 1300

st.set_page_config(layout="wide", menu_items=menu_items)
st.title("Test: Attention Bird's Eye View")


# Write every session-state entry back onto itself.
# NOTE(review): presumably this keeps widget state alive when switching
# between pages of the multipage app -- confirm against Streamlit docs.
for k, v in st.session_state.items():
    st.session_state[k] = v

models = [
    Model(name=ModelType.TAPE_BERT, layers=12, heads=12),
    Model(name=ModelType.ZymCTRL, layers=36, heads=16),
    Model(name=ModelType.PROT_BERT, layers=30, heads=16),
    Model(name=ModelType.PROT_T5, layers=24, heads=32),
]

with st.expander("Input a PDB id, upload a PDB file or input a sequence", expanded=True):
    pdb_id = select_pdb() or "2FZ5"
    uploaded_file = st.file_uploader("2.Upload PDB", type=["pdb"])
    input_sequence = st.text_area("3.Input sequence", "", key="input_sequence", max_chars=400)
    sequence, error = clean_and_validate_sequence(input_sequence)
    if error:
        st.error(error)
    pdb_str, structure, source = select_protein(pdb_id, uploaded_file, sequence)
    st.write(f"Visualizing: {source}")

selected_model = select_model(models)


chains = list(structure.get_chains())
chain_ids = [chain.id for chain in chains]
if "selected_chain" not in st.session_state:
    st.session_state.selected_chain = chain_ids[0] if chain_ids else None
chain_selection = st.sidebar.selectbox(
    label="Select Chain",
    options=chain_ids,
    key="selected_chain",
)

# Use a default so that a structure without chains (or a selection that
# matches no chain) yields a readable message instead of StopIteration.
selected_chain = next((chain for chain in chains if chain.id == chain_selection), None)
if selected_chain is None:
    st.error("No chain is available for the selected structure.")
    st.stop()

ec_number = ""
if selected_model.name == ModelType.ZymCTRL:
    st.sidebar.markdown(
        """
        ZymCTRL EC number
        ---
        """
    )
    # Pre-fill the input from the structure header when it carries an EC number.
    try:
        ec_number = structure.header["compound"]["1"]["ec"]
    except KeyError:
        pass
    ec_number = st.sidebar.text_input("Enzyme Commission number (EC)", ec_number)

    # Validate EC number
    if not EC_NUMBER_PATTERN.match(ec_number):
        st.sidebar.error(
            """Please enter a valid Enzyme Commission number in the format of 4
            integers separated by periods (e.g., 1.2.3.21)"""
        )


residues = list(selected_chain.get_residues())
sequence = res_to_1letter(residues)


layer_sequence, head_sequence = select_heads_and_layers(st.sidebar, selected_model)

min_attn = st.sidebar.slider("Minimum attention", min_value=0.0, max_value=0.4, value=0.1)

# NOTE(review): the returned attention and tokens are never read below; the
# grid is built solely from get_attention_pairs. The call is kept in case it
# is relied on for caching or validation -- confirm and drop it if not.
attention, tokens = get_attention(
    sequence=sequence,
    model_type=selected_model.name,
    ec_number=ec_number,
)

grid_rows = len(layer_sequence)
grid_cols = len(head_sequence)
if grid_rows == 0 or grid_cols == 0:
    # An empty selection would otherwise divide by zero when sizing the cells.
    st.warning("Select at least one layer and one head to display the grid.")
    st.stop()

viewer_width = VIEWER_WIDTH
cell_width = viewer_width / grid_cols
viewer_height = int(cell_width * grid_rows)

# NOTE(review): the viewer always loads the structure by `pdb_id`, while the
# attention pairs below are computed from `pdb_str`. When an uploaded file or
# an input sequence is the source these may be different structures -- confirm.
xyzview = py3Dmol.view(
    width=viewer_width,
    height=viewer_height,
    query=f"pdb:{pdb_id}",
    linked=False,
    viewergrid=(grid_rows, grid_cols),
)
xyzview.setStyle({"cartoon": {"color": "white"}})


for row, layer in enumerate(layer_sequence):
    for col, head in enumerate(head_sequence):
        # NOTE(review): `chain_ids` and `ec_numbers` are passed as None, so the
        # chain and EC number chosen in the sidebar are not forwarded here --
        # confirm whether that is intended for this test view.
        attention_pairs, top_residues = get_attention_pairs(
            pdb_str=pdb_str,
            chain_ids=None,
            layer=layer,
            head=head,
            threshold=min_attn,
            model_type=selected_model.name,
            top_n=1,
            ec_numbers=None,
        )

        # One cylinder per attention pair, with the attention weight as radius,
        # drawn only in the grid cell of this (layer, head).
        for att_weight, first, second in attention_pairs:
            xyzview.addCylinder(
                {
                    "start": {"x": first[0], "y": first[1], "z": first[2]},
                    "end": {"x": second[0], "y": second[1], "z": second[2]},
                    "radius": att_weight,
                    "fromCap": True,
                    "toCap": True,
                    "color": "red",
                    "dashed": False,
                },
                viewer=(row, col),
            )

stmol.showmol(xyzview, height=viewer_height, width=viewer_width)