Spaces:
Sleeping
Sleeping
Add text and slice selection to head selection view
Browse files
hexviz/pages/Identify_interesting_Heads.py
CHANGED
@@ -5,6 +5,7 @@ from hexviz.models import Model, ModelType
|
|
5 |
from hexviz.plot import plot_tiled_heatmap
|
6 |
|
7 |
st.set_page_config(layout="wide")
|
|
|
8 |
|
9 |
|
10 |
models = [
|
@@ -20,16 +21,22 @@ pdb_id = st.sidebar.text_input(
|
|
20 |
value="1AKE",
|
21 |
)
|
22 |
|
|
|
23 |
structure = get_structure(pdb_id)
|
24 |
chains = list(structure.get_chains())
|
25 |
|
26 |
sequence = get_sequence(chains[0])
|
27 |
l = len(sequence)
|
28 |
-
|
29 |
-
|
|
|
|
|
30 |
|
|
|
|
|
|
|
|
|
31 |
attention = get_attention(sequence=truncated_sequence, model_type=selected_model.name)
|
32 |
|
33 |
-
st.subheader("Find interesting heads and layers")
|
34 |
fig = plot_tiled_heatmap(attention, layer_count=selected_model.layers, head_count=selected_model.heads)
|
35 |
st.pyplot(fig)
|
|
|
5 |
from hexviz.plot import plot_tiled_heatmap
|
6 |
|
7 |
st.set_page_config(layout="wide")
|
8 |
+
st.subheader("Find interesting heads and layers")
|
9 |
|
10 |
|
11 |
models = [
|
|
|
21 |
value="1AKE",
|
22 |
)
|
23 |
|
24 |
+
|
25 |
structure = get_structure(pdb_id)
|
26 |
chains = list(structure.get_chains())
|
27 |
|
28 |
sequence = get_sequence(chains[0])
|
29 |
l = len(sequence)
|
30 |
+
slice_start= st.sidebar.number_input(f"Slice start(1-{l})",value=1, min_value=1, max_value=l)
|
31 |
+
slice_end = st.sidebar.number_input(f"Slice end(1-{l})",value=50, min_value=1, max_value=l)
|
32 |
+
truncated_sequence = sequence[slice_start-1:slice_end]
|
33 |
+
|
34 |
|
35 |
+
st.text(f"Each tile is a heatmap of attention for the slice {slice_start}:{slice_end} of {pdb_id} (chain A) for one attention head.")
|
36 |
+
|
37 |
+
# TODO: Decide if you should get attention for the full sequence or just the truncated sequence
|
38 |
+
# Attention values will change depending on what we do.
|
39 |
attention = get_attention(sequence=truncated_sequence, model_type=selected_model.name)
|
40 |
|
|
|
41 |
fig = plot_tiled_heatmap(attention, layer_count=selected_model.layers, head_count=selected_model.heads)
|
42 |
st.pyplot(fig)
|