aksell commited on
Commit
92ba471
·
1 Parent(s): fd3cdd5

Check session state with .get

Browse files
Files changed (1) hide show
  1. hexviz/view.py +3 -11
hexviz/view.py CHANGED
@@ -57,15 +57,9 @@ def select_heads_and_layers(sidebar, model):
57
  ---
58
  """
59
  )
60
- if "plot_heads" not in st.session_state:
61
- "setting heads"
62
- st.session_state.plot_heads = (1, model.heads//2)
63
- if "plot_layers" not in st.session_state:
64
- "setting layers"
65
- st.session_state.plot_layers = (1, model.layers//2)
66
- head_range = sidebar.slider("Heads to plot", min_value=1, max_value=model.heads, value=st.session_state.plot_heads, step=1)
67
  st.session_state.plot_heads = head_range
68
- layer_range = sidebar.slider("Layers to plot", min_value=1, max_value=model.layers, value=st.session_state.plot_layers, step=1)
69
  st.session_state.plot_layers = layer_range
70
 
71
  step_size = sidebar.number_input("Optional step size to skip heads and layers", value=1, min_value=1, max_value=model.layers)
@@ -76,8 +70,6 @@ def select_heads_and_layers(sidebar, model):
76
 
77
  def select_sequence_slice(sequence_length):
78
  st.sidebar.markdown("Sequence segment to plot")
79
- if "sequence_slice" not in st.session_state:
80
- st.session_state["sequence_slice"] = (1, min(50, sequence_length))
81
- slice = st.sidebar.slider("Sequence", value=st.session_state.sequence_slice, min_value=1, max_value=sequence_length, step=1)
82
  st.session_state.sequence_slice = slice
83
  return slice
 
57
  ---
58
  """
59
  )
60
+ head_range = sidebar.slider("Heads to plot", min_value=1, max_value=model.heads, value=st.session_state.get("plot_heads", (1, model.heads//2)), step=1)
 
 
 
 
 
 
61
  st.session_state.plot_heads = head_range
62
+ layer_range = sidebar.slider("Layers to plot", min_value=1, max_value=model.layers, value=st.session_state.get("plot_layers", (1, model.layers//2)), step=1)
63
  st.session_state.plot_layers = layer_range
64
 
65
  step_size = sidebar.number_input("Optional step size to skip heads and layers", value=1, min_value=1, max_value=model.layers)
 
70
 
71
  def select_sequence_slice(sequence_length):
72
  st.sidebar.markdown("Sequence segment to plot")
73
+ slice = st.sidebar.slider("Sequence", value=st.session_state.get("sequence_slice", (1, min(50, sequence_length))), min_value=1, max_value=sequence_length, step=1)
 
 
74
  st.session_state.sequence_slice = slice
75
  return slice