aksell commited on
Commit
c2b2898
·
1 Parent(s): 708426b

Share pdb selector between views

Browse files
hexviz/pages/🗺️Identify_Interesting_Heads.py CHANGED
@@ -3,7 +3,7 @@ import streamlit as st
3
  from hexviz.attention import get_attention, get_sequence, get_structure
4
  from hexviz.models import Model, ModelType
5
  from hexviz.plot import plot_tiled_heatmap
6
- from hexviz.view import get_selecte_model_index
7
 
8
  st.set_page_config(layout="wide")
9
  st.subheader("Find interesting heads and layers")
@@ -14,16 +14,9 @@ models = [
14
  Model(name=ModelType.ZymCTRL, layers=36, heads=16),
15
  ]
16
 
17
- selected_model_name = st.selectbox("Select a model", [model.name.value for model in models], index=get_selecte_model_index(models))
18
- st.session_state.selected_model_name = selected_model_name
19
- selected_model = next((model for model in models if model.name.value == selected_model_name), None)
20
-
21
- pdb_id = st.sidebar.text_input(
22
- label="PDB ID",
23
- value=st.session_state.get("pdb_id", "2FZ5"),
24
- )
25
- st.session_state.pdb_id = pdb_id
26
 
 
27
 
28
  structure = get_structure(pdb_id)
29
 
 
3
  from hexviz.attention import get_attention, get_sequence, get_structure
4
  from hexviz.models import Model, ModelType
5
  from hexviz.plot import plot_tiled_heatmap
6
+ from hexviz.view import select_model, select_pdb
7
 
8
  st.set_page_config(layout="wide")
9
  st.subheader("Find interesting heads and layers")
 
14
  Model(name=ModelType.ZymCTRL, layers=36, heads=16),
15
  ]
16
 
17
+ selected_model = select_model(models)
 
 
 
 
 
 
 
 
18
 
19
+ pdb_id = select_pdb()
20
 
21
  structure = get_structure(pdb_id)
22
 
hexviz/view.py CHANGED
@@ -6,4 +6,27 @@ def get_selecte_model_index(models):
6
  if selected_model_name is None:
7
  return 0
8
  else:
9
- return next((i for i, model in enumerate(models) if model.name.value == selected_model_name), None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  if selected_model_name is None:
7
  return 0
8
  else:
9
+ return next((i for i, model in enumerate(models) if model.name.value == selected_model_name), None)
10
+
11
+ def select_model(models):
12
+ """
13
+ Select model, prefil selector with selected model from session storage
14
+
15
+ Saves the selected model in session storage.
16
+ """
17
+ selected_model_name = st.selectbox("Select model", [model.name.value for model in models], index=get_selecte_model_index(models))
18
+ st.session_state.selected_model_name = selected_model_name
19
+ select_model = next((model for model in models if model.name.value == selected_model_name), None)
20
+ return select_model
21
+
22
+ def select_pdb():
23
+ st.sidebar.markdown(
24
+ """
25
+ Select Protein
26
+ ---
27
+ """)
28
+ pdb_id = st.sidebar.text_input(
29
+ label="PDB ID",
30
+ value=st.session_state.get("pdb_id", "2FZ5"))
31
+ st.session_state.pdb_id = pdb_id
32
+ return pdb_id
hexviz/🧬Attention_Visualization.py CHANGED
@@ -6,7 +6,7 @@ from stmol import showmol
6
 
7
  from hexviz.attention import get_attention_pairs, get_chains, get_structure
8
  from hexviz.models import Model, ModelType
9
- from hexviz.view import get_selecte_model_index
10
 
11
  st.title("Attention Visualization on proteins")
12
 
@@ -16,19 +16,11 @@ models = [
16
  Model(name=ModelType.ZymCTRL, layers=36, heads=16),
17
  ]
18
 
19
- st.sidebar.markdown(
20
- """
21
- Select Protein
22
- ---
23
- """)
24
- pdb_id = st.sidebar.text_input(
25
- label="PDB ID",
26
- value=st.session_state.get("pdb_id", "2FZ5"))
27
- st.session_state.pdb_id = pdb_id
28
  structure = get_structure(pdb_id)
29
  chains = get_chains(structure)
30
 
31
- selected_chains = st.sidebar.multiselect(label="Chain(s)", options=chains, default=st.session_state.get("selected_chains", None) or chains)
32
  st.session_state.selected_chains = selected_chains
33
 
34
 
@@ -59,9 +51,7 @@ label_highest = st.sidebar.checkbox("Label highest attention pairs", value=True)
59
 
60
  left, mid, right = st.columns(3)
61
  with left:
62
- selected_model_name = st.selectbox("Select a model", [model.name.value for model in models], index=get_selecte_model_index(models))
63
- st.session_state.selected_model_name = selected_model_name
64
- selected_model = next((model for model in models if model.name.value == selected_model_name), None)
65
  with mid:
66
  layer_one = st.number_input("Layer", value=5, min_value=1, max_value=selected_model.layers)
67
  layer = layer_one - 1
 
6
 
7
  from hexviz.attention import get_attention_pairs, get_chains, get_structure
8
  from hexviz.models import Model, ModelType
9
+ from hexviz.view import select_model, select_pdb
10
 
11
  st.title("Attention Visualization on proteins")
12
 
 
16
  Model(name=ModelType.ZymCTRL, layers=36, heads=16),
17
  ]
18
 
19
+ pdb_id = select_pdb()
 
 
 
 
 
 
 
 
20
  structure = get_structure(pdb_id)
21
  chains = get_chains(structure)
22
 
23
+ selected_chains = st.sidebar.multiselect(label="Select Chain(s)", options=chains, default=st.session_state.get("selected_chains", None) or chains)
24
  st.session_state.selected_chains = selected_chains
25
 
26
 
 
51
 
52
  left, mid, right = st.columns(3)
53
  with left:
54
+ selected_model = select_model(models)
 
 
55
  with mid:
56
  layer_one = st.number_input("Layer", value=5, min_value=1, max_value=selected_model.layers)
57
  layer = layer_one - 1