aksell commited on
Commit
708426b
·
1 Parent(s): ae26fb8

Persist selected PDB, chain and model in session state

Browse files

Selected chain(s) are not shared between views.
Because one of the views allows multiselect of chains and the
other does not.

hexviz/pages/🗺️Identify_Interesting_Heads.py CHANGED
@@ -3,6 +3,7 @@ import streamlit as st
3
  from hexviz.attention import get_attention, get_sequence, get_structure
4
  from hexviz.models import Model, ModelType
5
  from hexviz.plot import plot_tiled_heatmap
 
6
 
7
  st.set_page_config(layout="wide")
8
  st.subheader("Find interesting heads and layers")
@@ -13,13 +14,15 @@ models = [
13
  Model(name=ModelType.ZymCTRL, layers=36, heads=16),
14
  ]
15
 
16
- selected_model_name = st.sidebar.selectbox("Select a model", [model.name.value for model in models], index=0)
 
17
  selected_model = next((model for model in models if model.name.value == selected_model_name), None)
18
 
19
  pdb_id = st.sidebar.text_input(
20
  label="PDB ID",
21
- value="1AKE",
22
  )
 
23
 
24
 
25
  structure = get_structure(pdb_id)
@@ -29,7 +32,9 @@ chain_ids = [chain.id for chain in chains]
29
  chain_selection = st.sidebar.selectbox(
30
  label="Select Chain",
31
  options=chain_ids,
 
32
  )
 
33
 
34
  selected_chain = next(chain for chain in chains if chain.id == chain_selection)
35
  sequence = get_sequence(selected_chain)
 
3
  from hexviz.attention import get_attention, get_sequence, get_structure
4
  from hexviz.models import Model, ModelType
5
  from hexviz.plot import plot_tiled_heatmap
6
+ from hexviz.view import get_selecte_model_index
7
 
8
  st.set_page_config(layout="wide")
9
  st.subheader("Find interesting heads and layers")
 
14
  Model(name=ModelType.ZymCTRL, layers=36, heads=16),
15
  ]
16
 
17
+ selected_model_name = st.selectbox("Select a model", [model.name.value for model in models], index=get_selecte_model_index(models))
18
+ st.session_state.selected_model_name = selected_model_name
19
  selected_model = next((model for model in models if model.name.value == selected_model_name), None)
20
 
21
  pdb_id = st.sidebar.text_input(
22
  label="PDB ID",
23
+ value=st.session_state.get("pdb_id", "2FZ5"),
24
  )
25
+ st.session_state.pdb_id = pdb_id
26
 
27
 
28
  structure = get_structure(pdb_id)
 
32
  chain_selection = st.sidebar.selectbox(
33
  label="Select Chain",
34
  options=chain_ids,
35
+ index=st.session_state.get("selected_chain_index", 0)
36
  )
37
+ st.session_state.selected_chain_index = chain_ids.index(chain_selection)
38
 
39
  selected_chain = next(chain for chain in chains if chain.id == chain_selection)
40
  sequence = get_sequence(selected_chain)
hexviz/view.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+
4
+ def get_selecte_model_index(models):
5
+ selected_model_name = st.session_state.get("selected_model_name", None)
6
+ if selected_model_name is None:
7
+ return 0
8
+ else:
9
+ return next((i for i, model in enumerate(models) if model.name.value == selected_model_name), None)
hexviz/🧬Attention_Visualization.py CHANGED
@@ -6,6 +6,7 @@ from stmol import showmol
6
 
7
  from hexviz.attention import get_attention_pairs, get_chains, get_structure
8
  from hexviz.models import Model, ModelType
 
9
 
10
  st.title("Attention Visualization on proteins")
11
 
@@ -22,11 +23,13 @@ st.sidebar.markdown(
22
  """)
23
  pdb_id = st.sidebar.text_input(
24
  label="PDB ID",
25
- value="2FZ5",
26
- )
27
  structure = get_structure(pdb_id)
28
  chains = get_chains(structure)
29
- selected_chains = st.sidebar.multiselect(label="Chain(s)", options=chains, default=chains)
 
 
30
 
31
 
32
  st.sidebar.markdown(
@@ -56,7 +59,8 @@ label_highest = st.sidebar.checkbox("Label highest attention pairs", value=True)
56
 
57
  left, mid, right = st.columns(3)
58
  with left:
59
- selected_model_name = st.selectbox("Select a model", [model.name.value for model in models], index=0)
 
60
  selected_model = next((model for model in models if model.name.value == selected_model_name), None)
61
  with mid:
62
  layer_one = st.number_input("Layer", value=5, min_value=1, max_value=selected_model.layers)
 
6
 
7
  from hexviz.attention import get_attention_pairs, get_chains, get_structure
8
  from hexviz.models import Model, ModelType
9
+ from hexviz.view import get_selecte_model_index
10
 
11
  st.title("Attention Visualization on proteins")
12
 
 
23
  """)
24
  pdb_id = st.sidebar.text_input(
25
  label="PDB ID",
26
+ value=st.session_state.get("pdb_id", "2FZ5"))
27
+ st.session_state.pdb_id = pdb_id
28
  structure = get_structure(pdb_id)
29
  chains = get_chains(structure)
30
+
31
+ selected_chains = st.sidebar.multiselect(label="Chain(s)", options=chains, default=st.session_state.get("selected_chains", None) or chains)
32
+ st.session_state.selected_chains = selected_chains
33
 
34
 
35
  st.sidebar.markdown(
 
59
 
60
  left, mid, right = st.columns(3)
61
  with left:
62
+ selected_model_name = st.selectbox("Select a model", [model.name.value for model in models], index=get_selecte_model_index(models))
63
+ st.session_state.selected_model_name = selected_model_name
64
  selected_model = next((model for model in models if model.name.value == selected_model_name), None)
65
  with mid:
66
  layer_one = st.number_input("Layer", value=5, min_value=1, max_value=selected_model.layers)