aksell commited on
Commit
fd3cdd5
·
1 Parent(s): 2d68959

Store selected sequence slice

Browse files
hexviz/pages/🗺️Identify_Interesting_Heads.py CHANGED
@@ -4,7 +4,7 @@ from hexviz.attention import get_attention, get_sequence, get_structure
4
  from hexviz.models import Model, ModelType
5
  from hexviz.plot import plot_tiled_heatmap
6
  from hexviz.view import (menu_items, select_heads_and_layers, select_model,
7
- select_pdb)
8
 
9
  st.set_page_config(layout="wide", menu_items=menu_items)
10
  st.subheader("Find interesting heads and layers")
@@ -34,12 +34,7 @@ selected_chain = next(chain for chain in chains if chain.id == chain_selection)
34
  sequence = get_sequence(selected_chain)
35
 
36
  l = len(sequence)
37
- st.sidebar.markdown("Sequence segment to plot")
38
- if "sequence-slice" not in st.session_state:
39
- st.session_state["sequence-slice"] = (1, min(50, l))
40
- slice_start, slice_end = st.sidebar.slider("Sequence", key="sequence-slice", min_value=1, max_value=l, step=1)
41
- # slice_start= st.sidebar.number_input(f"Section start(1-{l})",value=1, min_value=1, max_value=l)
42
- # slice_end = st.sidebar.number_input(f"Section end(1-{l})",value=50, min_value=1, max_value=l)
43
  truncated_sequence = sequence[slice_start-1:slice_end]
44
 
45
 
 
4
  from hexviz.models import Model, ModelType
5
  from hexviz.plot import plot_tiled_heatmap
6
  from hexviz.view import (menu_items, select_heads_and_layers, select_model,
7
+ select_pdb, select_sequence_slice)
8
 
9
  st.set_page_config(layout="wide", menu_items=menu_items)
10
  st.subheader("Find interesting heads and layers")
 
34
  sequence = get_sequence(selected_chain)
35
 
36
  l = len(sequence)
37
+ slice_start, slice_end = select_sequence_slice(l)
 
 
 
 
 
38
  truncated_sequence = sequence[slice_start-1:slice_end]
39
 
40
 
hexviz/view.py CHANGED
@@ -41,9 +41,12 @@ def select_pdb():
41
  pdb_id = st.sidebar.text_input(
42
  label="PDB ID",
43
  value=stored_pdb or "2FZ5")
44
- if pdb_id != stored_pdb:
 
45
  st.session_state.selected_chains = None
46
  st.session_state.selected_chain_index = 0
 
 
47
  st.session_state.pdb_id = pdb_id
48
  return pdb_id
49
 
@@ -69,4 +72,12 @@ def select_heads_and_layers(sidebar, model):
69
  layer_sequence = list(range(layer_range[0]-1, layer_range[1], step_size))
70
  head_sequence = list(range(head_range[0]-1, head_range[1], step_size))
71
 
72
- return layer_sequence, head_sequence
 
 
 
 
 
 
 
 
 
41
  pdb_id = st.sidebar.text_input(
42
  label="PDB ID",
43
  value=stored_pdb or "2FZ5")
44
+ pdb_changed = stored_pdb != pdb_id
45
+ if pdb_changed:
46
  st.session_state.selected_chains = None
47
  st.session_state.selected_chain_index = 0
48
+ if "sequence_slice" in st.session_state:
49
+ del st.session_state.sequence_slice
50
  st.session_state.pdb_id = pdb_id
51
  return pdb_id
52
 
 
72
  layer_sequence = list(range(layer_range[0]-1, layer_range[1], step_size))
73
  head_sequence = list(range(head_range[0]-1, head_range[1], step_size))
74
 
75
+ return layer_sequence, head_sequence
76
+
77
+ def select_sequence_slice(sequence_length):
78
+ st.sidebar.markdown("Sequence segment to plot")
79
+ if "sequence_slice" not in st.session_state:
80
+ st.session_state["sequence_slice"] = (1, min(50, sequence_length))
81
+ slice = st.sidebar.slider("Sequence", value=st.session_state.sequence_slice, min_value=1, max_value=sequence_length, step=1)
82
+ st.session_state.sequence_slice = slice
83
+ return slice