Spaces:
Sleeping
Sleeping
Store selected sequence slice
Browse files
hexviz/pages/🗺️Identify_Interesting_Heads.py
CHANGED
@@ -4,7 +4,7 @@ from hexviz.attention import get_attention, get_sequence, get_structure
|
|
4 |
from hexviz.models import Model, ModelType
|
5 |
from hexviz.plot import plot_tiled_heatmap
|
6 |
from hexviz.view import (menu_items, select_heads_and_layers, select_model,
|
7 |
-
select_pdb)
|
8 |
|
9 |
st.set_page_config(layout="wide", menu_items=menu_items)
|
10 |
st.subheader("Find interesting heads and layers")
|
@@ -34,12 +34,7 @@ selected_chain = next(chain for chain in chains if chain.id == chain_selection)
|
|
34 |
sequence = get_sequence(selected_chain)
|
35 |
|
36 |
l = len(sequence)
|
37 |
-
|
38 |
-
if "sequence-slice" not in st.session_state:
|
39 |
-
st.session_state["sequence-slice"] = (1, min(50, l))
|
40 |
-
slice_start, slice_end = st.sidebar.slider("Sequence", key="sequence-slice", min_value=1, max_value=l, step=1)
|
41 |
-
# slice_start= st.sidebar.number_input(f"Section start(1-{l})",value=1, min_value=1, max_value=l)
|
42 |
-
# slice_end = st.sidebar.number_input(f"Section end(1-{l})",value=50, min_value=1, max_value=l)
|
43 |
truncated_sequence = sequence[slice_start-1:slice_end]
|
44 |
|
45 |
|
|
|
4 |
from hexviz.models import Model, ModelType
|
5 |
from hexviz.plot import plot_tiled_heatmap
|
6 |
from hexviz.view import (menu_items, select_heads_and_layers, select_model,
|
7 |
+
select_pdb, select_sequence_slice)
|
8 |
|
9 |
st.set_page_config(layout="wide", menu_items=menu_items)
|
10 |
st.subheader("Find interesting heads and layers")
|
|
|
34 |
sequence = get_sequence(selected_chain)
|
35 |
|
36 |
l = len(sequence)
|
37 |
+
slice_start, slice_end = select_sequence_slice(l)
|
|
|
|
|
|
|
|
|
|
|
38 |
truncated_sequence = sequence[slice_start-1:slice_end]
|
39 |
|
40 |
|
hexviz/view.py
CHANGED
@@ -41,9 +41,12 @@ def select_pdb():
|
|
41 |
pdb_id = st.sidebar.text_input(
|
42 |
label="PDB ID",
|
43 |
value=stored_pdb or "2FZ5")
|
44 |
-
|
|
|
45 |
st.session_state.selected_chains = None
|
46 |
st.session_state.selected_chain_index = 0
|
|
|
|
|
47 |
st.session_state.pdb_id = pdb_id
|
48 |
return pdb_id
|
49 |
|
@@ -69,4 +72,12 @@ def select_heads_and_layers(sidebar, model):
|
|
69 |
layer_sequence = list(range(layer_range[0]-1, layer_range[1], step_size))
|
70 |
head_sequence = list(range(head_range[0]-1, head_range[1], step_size))
|
71 |
|
72 |
-
return layer_sequence, head_sequence
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
pdb_id = st.sidebar.text_input(
|
42 |
label="PDB ID",
|
43 |
value=stored_pdb or "2FZ5")
|
44 |
+
pdb_changed = stored_pdb != pdb_id
|
45 |
+
if pdb_changed:
|
46 |
st.session_state.selected_chains = None
|
47 |
st.session_state.selected_chain_index = 0
|
48 |
+
if "sequence_slice" in st.session_state:
|
49 |
+
del st.session_state.sequence_slice
|
50 |
st.session_state.pdb_id = pdb_id
|
51 |
return pdb_id
|
52 |
|
|
|
72 |
layer_sequence = list(range(layer_range[0]-1, layer_range[1], step_size))
|
73 |
head_sequence = list(range(head_range[0]-1, head_range[1], step_size))
|
74 |
|
75 |
+
return layer_sequence, head_sequence
|
76 |
+
|
77 |
+
def select_sequence_slice(sequence_length):
|
78 |
+
st.sidebar.markdown("Sequence segment to plot")
|
79 |
+
if "sequence_slice" not in st.session_state:
|
80 |
+
st.session_state["sequence_slice"] = (1, min(50, sequence_length))
|
81 |
+
slice = st.sidebar.slider("Sequence", value=st.session_state.sequence_slice, min_value=1, max_value=sequence_length, step=1)
|
82 |
+
st.session_state.sequence_slice = slice
|
83 |
+
return slice
|