Spaces:
Sleeping
Sleeping
Check session state with .get
Browse files- hexviz/view.py +3 -11
hexviz/view.py
CHANGED
@@ -57,15 +57,9 @@ def select_heads_and_layers(sidebar, model):
|
|
57 |
---
|
58 |
"""
|
59 |
)
|
60 |
-
|
61 |
-
"setting heads"
|
62 |
-
st.session_state.plot_heads = (1, model.heads//2)
|
63 |
-
if "plot_layers" not in st.session_state:
|
64 |
-
"setting layers"
|
65 |
-
st.session_state.plot_layers = (1, model.layers//2)
|
66 |
-
head_range = sidebar.slider("Heads to plot", min_value=1, max_value=model.heads, value=st.session_state.plot_heads, step=1)
|
67 |
st.session_state.plot_heads = head_range
|
68 |
-
layer_range = sidebar.slider("Layers to plot", min_value=1, max_value=model.layers, value=st.session_state.plot_layers, step=1)
|
69 |
st.session_state.plot_layers = layer_range
|
70 |
|
71 |
step_size = sidebar.number_input("Optional step size to skip heads and layers", value=1, min_value=1, max_value=model.layers)
|
@@ -76,8 +70,6 @@ def select_heads_and_layers(sidebar, model):
|
|
76 |
|
77 |
def select_sequence_slice(sequence_length):
|
78 |
st.sidebar.markdown("Sequence segment to plot")
|
79 |
-
|
80 |
-
st.session_state["sequence_slice"] = (1, min(50, sequence_length))
|
81 |
-
slice = st.sidebar.slider("Sequence", value=st.session_state.sequence_slice, min_value=1, max_value=sequence_length, step=1)
|
82 |
st.session_state.sequence_slice = slice
|
83 |
return slice
|
|
|
57 |
---
|
58 |
"""
|
59 |
)
|
60 |
+
head_range = sidebar.slider("Heads to plot", min_value=1, max_value=model.heads, value=st.session_state.get("plot_heads", (1, model.heads//2)), step=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
st.session_state.plot_heads = head_range
|
62 |
+
layer_range = sidebar.slider("Layers to plot", min_value=1, max_value=model.layers, value=st.session_state.get("plot_layers", (1, model.layers//2)), step=1)
|
63 |
st.session_state.plot_layers = layer_range
|
64 |
|
65 |
step_size = sidebar.number_input("Optional step size to skip heads and layers", value=1, min_value=1, max_value=model.layers)
|
|
|
70 |
|
71 |
def select_sequence_slice(sequence_length):
|
72 |
st.sidebar.markdown("Sequence segment to plot")
|
73 |
+
slice = st.sidebar.slider("Sequence", value=st.session_state.get("sequence_slice", (1, min(50, sequence_length))), min_value=1, max_value=sequence_length, step=1)
|
|
|
|
|
74 |
st.session_state.sequence_slice = slice
|
75 |
return slice
|