Spaces:
Sleeping
Sleeping
Share pdb selector between views
Browse files- hexviz/pages/🗺️Identify_Interesting_Heads.py +3 -10
- hexviz/view.py +24 -1
- hexviz/🧬Attention_Visualization.py +4 -14
hexviz/pages/🗺️Identify_Interesting_Heads.py
CHANGED
@@ -3,7 +3,7 @@ import streamlit as st
|
|
3 |
from hexviz.attention import get_attention, get_sequence, get_structure
|
4 |
from hexviz.models import Model, ModelType
|
5 |
from hexviz.plot import plot_tiled_heatmap
|
6 |
-
from hexviz.view import
|
7 |
|
8 |
st.set_page_config(layout="wide")
|
9 |
st.subheader("Find interesting heads and layers")
|
@@ -14,16 +14,9 @@ models = [
|
|
14 |
Model(name=ModelType.ZymCTRL, layers=36, heads=16),
|
15 |
]
|
16 |
|
17 |
-
|
18 |
-
st.session_state.selected_model_name = selected_model_name
|
19 |
-
selected_model = next((model for model in models if model.name.value == selected_model_name), None)
|
20 |
-
|
21 |
-
pdb_id = st.sidebar.text_input(
|
22 |
-
label="PDB ID",
|
23 |
-
value=st.session_state.get("pdb_id", "2FZ5"),
|
24 |
-
)
|
25 |
-
st.session_state.pdb_id = pdb_id
|
26 |
|
|
|
27 |
|
28 |
structure = get_structure(pdb_id)
|
29 |
|
|
|
3 |
from hexviz.attention import get_attention, get_sequence, get_structure
|
4 |
from hexviz.models import Model, ModelType
|
5 |
from hexviz.plot import plot_tiled_heatmap
|
6 |
+
from hexviz.view import select_model, select_pdb
|
7 |
|
8 |
st.set_page_config(layout="wide")
|
9 |
st.subheader("Find interesting heads and layers")
|
|
|
14 |
Model(name=ModelType.ZymCTRL, layers=36, heads=16),
|
15 |
]
|
16 |
|
17 |
+
selected_model = select_model(models)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
+
pdb_id = select_pdb()
|
20 |
|
21 |
structure = get_structure(pdb_id)
|
22 |
|
hexviz/view.py
CHANGED
@@ -6,4 +6,27 @@ def get_selecte_model_index(models):
|
|
6 |
if selected_model_name is None:
|
7 |
return 0
|
8 |
else:
|
9 |
-
return next((i for i, model in enumerate(models) if model.name.value == selected_model_name), None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
if selected_model_name is None:
|
7 |
return 0
|
8 |
else:
|
9 |
+
return next((i for i, model in enumerate(models) if model.name.value == selected_model_name), None)
|
10 |
+
|
11 |
+
def select_model(models):
|
12 |
+
"""
|
13 |
+
Select model, prefil selector with selected model from session storage
|
14 |
+
|
15 |
+
Saves the selected model in session storage.
|
16 |
+
"""
|
17 |
+
selected_model_name = st.selectbox("Select model", [model.name.value for model in models], index=get_selecte_model_index(models))
|
18 |
+
st.session_state.selected_model_name = selected_model_name
|
19 |
+
select_model = next((model for model in models if model.name.value == selected_model_name), None)
|
20 |
+
return select_model
|
21 |
+
|
22 |
+
def select_pdb():
|
23 |
+
st.sidebar.markdown(
|
24 |
+
"""
|
25 |
+
Select Protein
|
26 |
+
---
|
27 |
+
""")
|
28 |
+
pdb_id = st.sidebar.text_input(
|
29 |
+
label="PDB ID",
|
30 |
+
value=st.session_state.get("pdb_id", "2FZ5"))
|
31 |
+
st.session_state.pdb_id = pdb_id
|
32 |
+
return pdb_id
|
hexviz/🧬Attention_Visualization.py
CHANGED
@@ -6,7 +6,7 @@ from stmol import showmol
|
|
6 |
|
7 |
from hexviz.attention import get_attention_pairs, get_chains, get_structure
|
8 |
from hexviz.models import Model, ModelType
|
9 |
-
from hexviz.view import
|
10 |
|
11 |
st.title("Attention Visualization on proteins")
|
12 |
|
@@ -16,19 +16,11 @@ models = [
|
|
16 |
Model(name=ModelType.ZymCTRL, layers=36, heads=16),
|
17 |
]
|
18 |
|
19 |
-
|
20 |
-
"""
|
21 |
-
Select Protein
|
22 |
-
---
|
23 |
-
""")
|
24 |
-
pdb_id = st.sidebar.text_input(
|
25 |
-
label="PDB ID",
|
26 |
-
value=st.session_state.get("pdb_id", "2FZ5"))
|
27 |
-
st.session_state.pdb_id = pdb_id
|
28 |
structure = get_structure(pdb_id)
|
29 |
chains = get_chains(structure)
|
30 |
|
31 |
-
selected_chains = st.sidebar.multiselect(label="Chain(s)", options=chains, default=st.session_state.get("selected_chains", None) or chains)
|
32 |
st.session_state.selected_chains = selected_chains
|
33 |
|
34 |
|
@@ -59,9 +51,7 @@ label_highest = st.sidebar.checkbox("Label highest attention pairs", value=True)
|
|
59 |
|
60 |
left, mid, right = st.columns(3)
|
61 |
with left:
|
62 |
-
|
63 |
-
st.session_state.selected_model_name = selected_model_name
|
64 |
-
selected_model = next((model for model in models if model.name.value == selected_model_name), None)
|
65 |
with mid:
|
66 |
layer_one = st.number_input("Layer", value=5, min_value=1, max_value=selected_model.layers)
|
67 |
layer = layer_one - 1
|
|
|
6 |
|
7 |
from hexviz.attention import get_attention_pairs, get_chains, get_structure
|
8 |
from hexviz.models import Model, ModelType
|
9 |
+
from hexviz.view import select_model, select_pdb
|
10 |
|
11 |
st.title("Attention Visualization on proteins")
|
12 |
|
|
|
16 |
Model(name=ModelType.ZymCTRL, layers=36, heads=16),
|
17 |
]
|
18 |
|
19 |
+
pdb_id = select_pdb()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
structure = get_structure(pdb_id)
|
21 |
chains = get_chains(structure)
|
22 |
|
23 |
+
selected_chains = st.sidebar.multiselect(label="Select Chain(s)", options=chains, default=st.session_state.get("selected_chains", None) or chains)
|
24 |
st.session_state.selected_chains = selected_chains
|
25 |
|
26 |
|
|
|
51 |
|
52 |
left, mid, right = st.columns(3)
|
53 |
with left:
|
54 |
+
selected_model = select_model(models)
|
|
|
|
|
55 |
with mid:
|
56 |
layer_one = st.number_input("Layer", value=5, min_value=1, max_value=selected_model.layers)
|
57 |
layer = layer_one - 1
|