Spaces:
Sleeping
Sleeping
Store selected head and layer in attention visualization
Browse files- hexviz/view.py +4 -0
- hexviz/🧬Attention_Visualization.py +4 -2
hexviz/view.py
CHANGED
@@ -28,6 +28,10 @@ def select_model(models):
|
|
28 |
del st.session_state.plot_heads
|
29 |
if "plot_layers" in st.session_state:
|
30 |
del st.session_state.plot_layers
|
|
|
|
|
|
|
|
|
31 |
select_model = next((model for model in models if model.name.value == selected_model_name), None)
|
32 |
return select_model
|
33 |
|
|
|
28 |
del st.session_state.plot_heads
|
29 |
if "plot_layers" in st.session_state:
|
30 |
del st.session_state.plot_layers
|
31 |
+
if "selected_head" in st.session_state:
|
32 |
+
del st.session_state.selected_head
|
33 |
+
if "selected_layer" in st.session_state:
|
34 |
+
del st.session_state.selected_layer
|
35 |
select_model = next((model for model in models if model.name.value == selected_model_name), None)
|
36 |
return select_model
|
37 |
|
hexviz/🧬Attention_Visualization.py
CHANGED
@@ -47,10 +47,12 @@ left, mid, right = st.columns(3)
|
|
47 |
with left:
|
48 |
selected_model = select_model(models)
|
49 |
with mid:
|
50 |
-
layer_one = st.number_input("Layer",
|
|
|
51 |
layer = layer_one - 1
|
52 |
with right:
|
53 |
-
head_one = st.number_input("Head", value=1, min_value=1, max_value=selected_model.heads)
|
|
|
54 |
head = head_one - 1
|
55 |
|
56 |
|
|
|
47 |
with left:
|
48 |
selected_model = select_model(models)
|
49 |
with mid:
|
50 |
+
layer_one = st.number_input("Layer",value=st.session_state.get("selected_layer", 5), min_value=1, max_value=selected_model.layers)
|
51 |
+
st.session_state["selected_layer"] = layer_one
|
52 |
layer = layer_one - 1
|
53 |
with right:
|
54 |
+
head_one = st.number_input("Head", value=st.session_state.get("selected_head", 1), min_value=1, max_value=selected_model.heads)
|
55 |
+
st.session_state["selected_head"] = head_one
|
56 |
head = head_one - 1
|
57 |
|
58 |
|