Spaces:
Sleeping
Sleeping
Persist selected PDB, chain and model in session state
Browse filesSelected chain(s) are not shared between views.
Because one of the views allows multiselect of chains and the
other does not.
hexviz/pages/🗺️Identify_Interesting_Heads.py
CHANGED
@@ -3,6 +3,7 @@ import streamlit as st
|
|
3 |
from hexviz.attention import get_attention, get_sequence, get_structure
|
4 |
from hexviz.models import Model, ModelType
|
5 |
from hexviz.plot import plot_tiled_heatmap
|
|
|
6 |
|
7 |
st.set_page_config(layout="wide")
|
8 |
st.subheader("Find interesting heads and layers")
|
@@ -13,13 +14,15 @@ models = [
|
|
13 |
Model(name=ModelType.ZymCTRL, layers=36, heads=16),
|
14 |
]
|
15 |
|
16 |
-
selected_model_name = st.
|
|
|
17 |
selected_model = next((model for model in models if model.name.value == selected_model_name), None)
|
18 |
|
19 |
pdb_id = st.sidebar.text_input(
|
20 |
label="PDB ID",
|
21 |
-
value="
|
22 |
)
|
|
|
23 |
|
24 |
|
25 |
structure = get_structure(pdb_id)
|
@@ -29,7 +32,9 @@ chain_ids = [chain.id for chain in chains]
|
|
29 |
chain_selection = st.sidebar.selectbox(
|
30 |
label="Select Chain",
|
31 |
options=chain_ids,
|
|
|
32 |
)
|
|
|
33 |
|
34 |
selected_chain = next(chain for chain in chains if chain.id == chain_selection)
|
35 |
sequence = get_sequence(selected_chain)
|
|
|
3 |
from hexviz.attention import get_attention, get_sequence, get_structure
|
4 |
from hexviz.models import Model, ModelType
|
5 |
from hexviz.plot import plot_tiled_heatmap
|
6 |
+
from hexviz.view import get_selecte_model_index
|
7 |
|
8 |
st.set_page_config(layout="wide")
|
9 |
st.subheader("Find interesting heads and layers")
|
|
|
14 |
Model(name=ModelType.ZymCTRL, layers=36, heads=16),
|
15 |
]
|
16 |
|
17 |
+
selected_model_name = st.selectbox("Select a model", [model.name.value for model in models], index=get_selecte_model_index(models))
|
18 |
+
st.session_state.selected_model_name = selected_model_name
|
19 |
selected_model = next((model for model in models if model.name.value == selected_model_name), None)
|
20 |
|
21 |
pdb_id = st.sidebar.text_input(
|
22 |
label="PDB ID",
|
23 |
+
value=st.session_state.get("pdb_id", "2FZ5"),
|
24 |
)
|
25 |
+
st.session_state.pdb_id = pdb_id
|
26 |
|
27 |
|
28 |
structure = get_structure(pdb_id)
|
|
|
32 |
chain_selection = st.sidebar.selectbox(
|
33 |
label="Select Chain",
|
34 |
options=chain_ids,
|
35 |
+
index=st.session_state.get("selected_chain_index", 0)
|
36 |
)
|
37 |
+
st.session_state.selected_chain_index = chain_ids.index(chain_selection)
|
38 |
|
39 |
selected_chain = next(chain for chain in chains if chain.id == chain_selection)
|
40 |
sequence = get_sequence(selected_chain)
|
hexviz/view.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
|
4 |
+
def get_selecte_model_index(models):
|
5 |
+
selected_model_name = st.session_state.get("selected_model_name", None)
|
6 |
+
if selected_model_name is None:
|
7 |
+
return 0
|
8 |
+
else:
|
9 |
+
return next((i for i, model in enumerate(models) if model.name.value == selected_model_name), None)
|
hexviz/🧬Attention_Visualization.py
CHANGED
@@ -6,6 +6,7 @@ from stmol import showmol
|
|
6 |
|
7 |
from hexviz.attention import get_attention_pairs, get_chains, get_structure
|
8 |
from hexviz.models import Model, ModelType
|
|
|
9 |
|
10 |
st.title("Attention Visualization on proteins")
|
11 |
|
@@ -22,11 +23,13 @@ st.sidebar.markdown(
|
|
22 |
""")
|
23 |
pdb_id = st.sidebar.text_input(
|
24 |
label="PDB ID",
|
25 |
-
value="2FZ5"
|
26 |
-
|
27 |
structure = get_structure(pdb_id)
|
28 |
chains = get_chains(structure)
|
29 |
-
|
|
|
|
|
30 |
|
31 |
|
32 |
st.sidebar.markdown(
|
@@ -56,7 +59,8 @@ label_highest = st.sidebar.checkbox("Label highest attention pairs", value=True)
|
|
56 |
|
57 |
left, mid, right = st.columns(3)
|
58 |
with left:
|
59 |
-
selected_model_name = st.selectbox("Select a model", [model.name.value for model in models], index=
|
|
|
60 |
selected_model = next((model for model in models if model.name.value == selected_model_name), None)
|
61 |
with mid:
|
62 |
layer_one = st.number_input("Layer", value=5, min_value=1, max_value=selected_model.layers)
|
|
|
6 |
|
7 |
from hexviz.attention import get_attention_pairs, get_chains, get_structure
|
8 |
from hexviz.models import Model, ModelType
|
9 |
+
from hexviz.view import get_selecte_model_index
|
10 |
|
11 |
st.title("Attention Visualization on proteins")
|
12 |
|
|
|
23 |
""")
|
24 |
pdb_id = st.sidebar.text_input(
|
25 |
label="PDB ID",
|
26 |
+
value=st.session_state.get("pdb_id", "2FZ5"))
|
27 |
+
st.session_state.pdb_id = pdb_id
|
28 |
structure = get_structure(pdb_id)
|
29 |
chains = get_chains(structure)
|
30 |
+
|
31 |
+
selected_chains = st.sidebar.multiselect(label="Chain(s)", options=chains, default=st.session_state.get("selected_chains", None) or chains)
|
32 |
+
st.session_state.selected_chains = selected_chains
|
33 |
|
34 |
|
35 |
st.sidebar.markdown(
|
|
|
59 |
|
60 |
left, mid, right = st.columns(3)
|
61 |
with left:
|
62 |
+
selected_model_name = st.selectbox("Select a model", [model.name.value for model in models], index=get_selecte_model_index(models))
|
63 |
+
st.session_state.selected_model_name = selected_model_name
|
64 |
selected_model = next((model for model in models if model.name.value == selected_model_name), None)
|
65 |
with mid:
|
66 |
layer_one = st.number_input("Layer", value=5, min_value=1, max_value=selected_model.layers)
|