aksell commited on
Commit
599e725
·
1 Parent(s): 92ba471

Store selected head and layer in attention visualization

Browse files
hexviz/view.py CHANGED
@@ -28,6 +28,10 @@ def select_model(models):
28
  del st.session_state.plot_heads
29
  if "plot_layers" in st.session_state:
30
  del st.session_state.plot_layers
 
 
 
 
31
  select_model = next((model for model in models if model.name.value == selected_model_name), None)
32
  return select_model
33
 
 
28
  del st.session_state.plot_heads
29
  if "plot_layers" in st.session_state:
30
  del st.session_state.plot_layers
31
+ if "selected_head" in st.session_state:
32
+ del st.session_state.selected_head
33
+ if "selected_layer" in st.session_state:
34
+ del st.session_state.selected_layer
35
  select_model = next((model for model in models if model.name.value == selected_model_name), None)
36
  return select_model
37
 
hexviz/🧬Attention_Visualization.py CHANGED
@@ -47,10 +47,12 @@ left, mid, right = st.columns(3)
47
  with left:
48
  selected_model = select_model(models)
49
  with mid:
50
- layer_one = st.number_input("Layer", value=5, min_value=1, max_value=selected_model.layers)
 
51
  layer = layer_one - 1
52
  with right:
53
- head_one = st.number_input("Head", value=1, min_value=1, max_value=selected_model.heads)
 
54
  head = head_one - 1
55
 
56
 
 
47
  with left:
48
  selected_model = select_model(models)
49
  with mid:
50
+ layer_one = st.number_input("Layer",value=st.session_state.get("selected_layer", 5), min_value=1, max_value=selected_model.layers)
51
+ st.session_state["selected_layer"] = layer_one
52
  layer = layer_one - 1
53
  with right:
54
+ head_one = st.number_input("Head", value=st.session_state.get("selected_head", 1), min_value=1, max_value=selected_model.heads)
55
+ st.session_state["selected_head"] = head_one
56
  head = head_one - 1
57
 
58