jannisborn committed
Commit 9914a10 · unverified · 1 Parent(s): ac2c591
Files changed (2)
  1. src/scenarios.py +60 -0
  2. src/streamlit_app.py +337 -384
src/scenarios.py ADDED
@@ -0,0 +1,60 @@
+ import numpy as np
+
+ # (1) Dirac: a one-hot distribution whose single unit of mass moves from token 0 to token 10 ("Text")
+ dirac = [
+     {
+         "name": f"Dirac: all mass on token {i}",
+         "values": [1.0 if j == i else 0.0 for j in range(11)],
+         "ground_truth": "4",
+         "explanation": "A Dirac distribution: all probability on a single token.",
+     }
+     for i in range(11)
+ ]
+
+
+ # (2) Gaussian: peak_mass=0.6 at the center, the remaining 0.4 spread over the neighbors by a Gaussian kernel
+ def make_gauss_values(center, n=11, sigma=1.5, peak_mass=0.6):
+     xs = np.arange(n)
+     # unnormalized Gaussian
+     kernel = np.exp(-0.5 * ((xs - center) / sigma) ** 2)
+     # zero out the center, re-normalize the *other* weights to sum to 1
+     others = kernel.copy()
+     others[center] = 0.0
+     others /= others.sum()
+     # allocate peak_mass (0.6) to the center and the remaining 0.4 to the rest
+     vals = others * (1.0 - peak_mass)
+     vals[center] = peak_mass
+     return vals.tolist()
+
+
+ gauss = [
+     {
+         "name": f"Gaussian: center at token {c}",
+         "values": make_gauss_values(c),
+         "ground_truth": "4",
+         "explanation": "Gaussian-style: 0.6 mass at the highlighted token, 0.4 spread smoothly to its neighbors.",
+     }
+     for c in range(11)
+ ]
+
+
+ # (3) Bimodal: two spikes of 0.5 mass each, symmetrically offset from the ground truth (4)
+ def make_bimodal_values(offset, n=11, gt=4):
+     # clamp to [0, n-1]
+     left = max(0, gt - offset)
+     right = min(n - 1, gt + offset)
+     vals = [0.0] * n
+     vals[left] = 0.5
+     vals[right] = 0.5
+     return vals
+
+
+ bimodal = [
+     {
+         "name": f"Bimodal: peaks at tokens {max(0, 4 - d)} & {min(10, 4 + d)}",
+         "values": make_bimodal_values(d),
+         "ground_truth": "4",
+         "explanation": "Two-point (bimodal) distribution: equal 0.5 mass on each peak, placed ±offset tokens from the ground truth.",
+     }
+     for d in range(11)
+ ]
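The three generators above are plain Python lists of dicts, so they can be sanity-checked outside Streamlit. Below is a minimal sketch (my illustration, not part of this commit; it assumes `scenarios.py` is importable from the working directory) that prints the probability mass carried by each scenario. Note that the bimodal case at offset 0 stacks both 0.5 peaks on the same token, and the app re-normalizes slider values anyway, so the lists only need to be non-negative.

```python
# sanity_check.py -- hypothetical helper, not included in this commit
from scenarios import bimodal, dirac, gauss

for name, scenario_list in [("dirac", dirac), ("gauss", gauss), ("bimodal", bimodal)]:
    masses = [round(sum(s["values"]), 3) for s in scenario_list]
    print(f"{name}: {len(scenario_list)} scenarios, total mass per scenario: {masses}")
```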
src/streamlit_app.py CHANGED
@@ -1,392 +1,163 @@
 
 
1
  import altair as alt
 
2
  import pandas as pd
 
3
  import streamlit_vertical_slider as svs
4
  import torch
5
- # from streamlit_vertical_slider import vertical_slider # Not directly used, svs.vertical_slider is
6
- import streamlit as st
7
- import time
8
- import plotly.graph_objects as go # Add Plotly import
 
 
 
 
9
 
10
  # Define options globally as it's used in initialization and UI
11
  options = [str(i) for i in range(10)] + ["Text"]
12
 
13
  # --- Session State Initialization ---
14
  # Ensure all session state variables are initialized before first use, especially by widgets.
15
- if 'running_demo' not in st.session_state:
16
  st.session_state.running_demo = False
17
- if 'demo_step' not in st.session_state:
18
  st.session_state.demo_step = 0
19
- if 'last_update_time' not in st.session_state:
20
  st.session_state.last_update_time = 0
21
- if 'loss_container' not in st.session_state:
22
  st.session_state.loss_container = None
23
- if 'previous_chart_html' not in st.session_state:
24
  st.session_state.previous_chart_html = ""
 
 
 
 
 
25
 
26
  # Initialize states for sliders and ground_truth selector
27
  # Using len(options) to correctly size for 0-9 + "Text"
28
  for i in range(len(options)):
29
  if f"slider_{i}" not in st.session_state:
30
  st.session_state[f"slider_{i}"] = 1.0 / len(options)
31
- if 'ground_truth' not in st.session_state:
32
- st.session_state['ground_truth'] = options[0] # Default to "0"
33
 
34
 
35
  st.title("Number Token Loss - Demo")
36
 
37
- st.markdown("""
38
- Adjust the sliders to set a predicted probability for each token (0-9 and "Text").
39
- The sliders are vertical and compact. The app normalizes the slider values
40
- to form a valid probability distribution, visualizes it, and computes the corresponding
41
- Cross Entropy, NTL-MSE, and NTL-WAS losses.
42
- """)
 
 
 
 
43
 
44
- # --- Scenario Definitions ---
45
- scenarios = [
46
- {
47
- "name": "Probability mass at 0",
48
- "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0], # 11 values
49
- "ground_truth": "0",
50
- "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
51
- },
52
- {
53
- "name": "Probability mass at 0",
54
- "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0], # 11 values
55
- "ground_truth": "1",
56
- "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
57
- },
58
- {
59
- "name": "Probability mass at 0",
60
- "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0], # 11 values
61
- "ground_truth": "2",
62
- "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
63
- },
64
- {
65
- "name": "Probability mass at 0",
66
- "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0], # 11 values
67
- "ground_truth": "3",
68
- "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
69
- },
70
- {
71
- "name": "Probability mass at 0",
72
- "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0], # 11 values
73
- "ground_truth": "4",
74
- "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
75
- },
76
- {
77
- "name": "Probability mass at 0",
78
- "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0], # 11 values
79
- "ground_truth": "5",
80
- "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
81
- },
82
- {
83
- "name": "Probability mass at 0",
84
- "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0], # 11 values
85
- "ground_truth": "6",
86
- "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
87
- },
88
- {
89
- "name": "Probability mass at 0",
90
- "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0], # 11 values
91
- "ground_truth": "7",
92
- "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
93
- },
94
- {
95
- "name": "Probability mass at 0",
96
- "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0], # 11 values
97
- "ground_truth": "8",
98
- "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
99
- },
100
- {
101
- "name": "Probability mass at 0",
102
- "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0], # 11 values
103
- "ground_truth": "9",
104
- "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
105
- },
106
-
107
-
108
- {
109
- "name": "Probability mass around 5",
110
- "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0], # 11 values
111
- "ground_truth": "0",
112
- "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
113
- },
114
- {
115
- "name": "Probability mass around 5",
116
- "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0], # 11 values
117
- "ground_truth": "1",
118
- "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
119
- },
120
- {
121
- "name": "Probability mass around 5",
122
- "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0], # 11 values
123
- "ground_truth": "2",
124
- "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
125
- },
126
- {
127
- "name": "Probability mass around 5",
128
- "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0], # 11 values
129
- "ground_truth": "3",
130
- "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
131
- },
132
- {
133
- "name": "Probability mass around 5",
134
- "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0], # 11 values
135
- "ground_truth": "4",
136
- "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
137
- },
138
- {
139
- "name": "Probability mass around ground truth (5)",
140
- "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0], # 11 values
141
- "ground_truth": "5",
142
- "explanation": "Cross Entropy is moderate, NTL is low because predictions are close to ground truth."
143
- },
144
- {
145
- "name": "Probability mass around 5",
146
- "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0], # 11 values
147
- "ground_truth": "6",
148
- "explanation": "Cross Entropy is moderate, NTL is low because predictions are close to ground truth."
149
- },
150
- {
151
- "name": "Probability mass around 5",
152
- "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0], # 11 values
153
- "ground_truth": "7",
154
- "explanation": "Cross Entropy is moderate, NTL is low because predictions are close to ground truth."
155
- },
156
- {
157
- "name": "Probability mass around 5",
158
- "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0], # 11 values
159
- "ground_truth": "8",
160
- "explanation": "Cross Entropy is high, NTL is higher but still penalizes less than CE because distribution knows it's a number."
161
- },
162
- {
163
- "name": "Probability mass around 5",
164
- "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0], # 11 values
165
- "ground_truth": "9",
166
- "explanation": "Cross Entropy is moderate, NTL is low because predictions are close to ground truth."
167
- },
168
-
169
- {
170
- "name": "Probability mass concentrated on 5",
171
- "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0], # 11 values
172
- "ground_truth": "0",
173
- "explanation": "Both CE and NTL are high because the prediction is far from correct."
174
- },
175
- {
176
- "name": "Probability mass concentrated on 5",
177
- "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0], # 11 values
178
- "ground_truth": "1",
179
- "explanation": "Both CE and NTL are high because the prediction is far from correct."
180
- },
181
- {
182
- "name": "Probability mass concentrated on 5",
183
- "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0], # 11 values
184
- "ground_truth": "2",
185
- "explanation": "Both CE and NTL are high because the prediction is far from correct."
186
- },
187
- {
188
- "name": "Probability mass concentrated on 5",
189
- "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0], # 11 values
190
- "ground_truth": "3",
191
- "explanation": "Both CE and NTL are high because the prediction is far from correct."
192
- },
193
- {
194
- "name": "Probability mass concentrated on 5",
195
- "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0], # 11 values
196
- "ground_truth": "4",
197
- "explanation": "Both CE and NTL are high because the prediction is far from correct."
198
- },
199
- {
200
- "name": "Probability mass concentrated on 5",
201
- "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0], # 11 values
202
- "ground_truth": "5",
203
- "explanation": "Both CE and NTL are high because the prediction is far from correct."
204
- },
205
- {
206
- "name": "Probability mass concentrated on 5",
207
- "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0], # 11 values
208
- "ground_truth": "6",
209
- "explanation": "Both CE and NTL are high because the prediction is far from correct."
210
- },
211
- {
212
- "name": "Probability mass concentrated on 5",
213
- "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0], # 11 values
214
- "ground_truth": "7",
215
- "explanation": "Both CE and NTL are high because the prediction is far from correct."
216
- },
217
- {
218
- "name": "Probability mass concentrated on 5",
219
- "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0], # 11 values
220
- "ground_truth": "8",
221
- "explanation": "Both CE and NTL are high because the prediction is far from correct."
222
- },
223
- {
224
- "name": "Probability mass concentrated on 5",
225
- "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0], # 11 values
226
- "ground_truth": "9",
227
- "explanation": "Both CE and NTL are high because the prediction is far from correct."
228
- },
229
-
230
-
231
- {
232
- "name": "Probability mass concentrated on 1",
233
- "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0], # 11 values
234
- "ground_truth": "0",
235
- "explanation": "Both losses are low because the prediction is correct."
236
- },
237
- {
238
- "name": "Probability mass concentrated on 1",
239
- "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0], # 11 values
240
- "ground_truth": "1",
241
- "explanation": "Both losses are low because the prediction is correct."
242
- },
243
- {
244
- "name": "Probability mass concentrated on 1",
245
- "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0], # 11 values
246
- "ground_truth": "2",
247
- "explanation": "Both losses are low because the prediction is correct."
248
- },
249
- {
250
- "name": "Probability mass concentrated on 1",
251
- "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0], # 11 values
252
- "ground_truth": "3",
253
- "explanation": "Both losses are low because the prediction is correct."
254
- },
255
- {
256
- "name": "Probability mass concentrated on 1",
257
- "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0], # 11 values
258
- "ground_truth": "4",
259
- "explanation": "Both losses are low because the prediction is correct."
260
- },
261
- {
262
- "name": "Probability mass concentrated on 1",
263
- "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0], # 11 values
264
- "ground_truth": "5",
265
- "explanation": "Both losses are low because the prediction is correct."
266
- },
267
- {
268
- "name": "Probability mass concentrated on 1",
269
- "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0], # 11 values
270
- "ground_truth": "6",
271
- "explanation": "Both losses are low because the prediction is correct."
272
- },
273
- {
274
- "name": "Probability mass concentrated on 1",
275
- "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0], # 11 values
276
- "ground_truth": "7",
277
- "explanation": "Both losses are low because the prediction is correct."
278
- },
279
- {
280
- "name": "Probability mass concentrated on 1",
281
- "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0], # 11 values
282
- "ground_truth": "8",
283
- "explanation": "Both losses are low because the prediction is correct."
284
- },
285
- {
286
- "name": "Probability mass concentrated on 1",
287
- "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0], # 11 values
288
- "ground_truth": "9",
289
- "explanation": "Both losses are low because the prediction is correct."
290
- },
291
-
292
-
293
- {
294
- "name": "Almost correct (1 vs 2)",
295
- "values": [0.1, 0.1, 0.7, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # 11 values
296
- "ground_truth": "0",
297
- "explanation": "CE penalizes harshly, but NTL-WAS remains low because prediction is numerically close."
298
- },
299
- {
300
- "name": "Almost correct (1 vs 2)",
301
- "values": [0.1, 0.1, 0.7, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # 11 values
302
- "ground_truth": "1",
303
- "explanation": "CE penalizes harshly, but NTL-WAS remains low because prediction is numerically close."
304
- },
305
- {
306
- "name": "Almost correct (1 vs 2)",
307
- "values": [0.1, 0.1, 0.7, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # 11 values
308
- "ground_truth": "2",
309
- "explanation": "CE penalizes harshly, but NTL-WAS remains low because prediction is numerically close."
310
- },
311
- {
312
- "name": "Almost correct (1 vs 2)",
313
- "values": [0.1, 0.1, 0.7, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # 11 values
314
- "ground_truth": "3",
315
- "explanation": "CE penalizes harshly, but NTL-WAS remains low because prediction is numerically close."
316
- }
317
- ]
318
 
319
- # --- Helper Functions ---
320
  def apply_scenario(step_idx):
321
- scenario = scenarios[step_idx]
322
- # These assignments modify session state. They must be done *before* the widgets
323
- # are rendered in the script run that should display these new values.
324
  for i, val in enumerate(scenario["values"]):
325
  st.session_state[f"slider_{i}"] = val
326
- st.session_state['ground_truth'] = scenario["ground_truth"]
327
 
328
- def start_demo():
 
 
 
 
329
  st.session_state.running_demo = True
330
  st.session_state.demo_step = 0
331
  st.session_state.last_update_time = time.time()
332
- apply_scenario(0) # Apply the first scenario's state
333
- # The button click that calls start_demo() will itself cause a rerun.
334
 
335
  def stop_demo():
336
  st.session_state.running_demo = False
337
 
 
338
  # --- Demo State Advancement Logic ---
339
  # This block handles advancing the demo. If it advances, it updates session state
340
  # and then reruns. This ensures widgets are drawn with the new state in the next run.
341
  if st.session_state.running_demo:
 
342
  current_time = time.time()
343
- if current_time - st.session_state.last_update_time > 3.0: # 3 seconds per scenario
344
- next_step = (st.session_state.demo_step + 1) % len(scenarios)
345
  st.session_state.demo_step = next_step
346
  apply_scenario(next_step) # Update session state for the new scenario
347
- st.session_state.last_update_time = time.time() # Reset timer
348
  st.rerun() # Crucial: Rerun to reflect changes in widgets and charts
349
 
350
  # --- UI Rendering ---
351
  # This section renders the main UI. It executes after any potential rerun from the block above.
352
 
353
  if st.session_state.running_demo:
354
- st.info(f"Showing scenario {st.session_state.demo_step + 1}/{len(scenarios)}: {scenarios[st.session_state.demo_step]['name']}")
355
- st.markdown(f"**Explanation:** {scenarios[st.session_state.demo_step]['explanation']}")
 
 
 
356
  if st.button("Stop Demo"):
357
- stop_demo()
358
  st.rerun()
359
- else: # Not st.session_state.running_demo
360
- if st.button("Start Automated Demo"):
361
- start_demo() # This calls apply_scenario(0)
362
- st.rerun() # Rerun to enter demo mode and draw scenario 0 correctly
363
-
364
- # Sliders and Ground Truth Selector
365
- # These widgets will read their initial values from st.session_state.
366
- # User interactions will update st.session_state directly due to their keys.
367
- if not st.session_state.running_demo:
368
- st.markdown("#### Predicted Token Probabilities")
369
- cols = st.columns(len(options))
370
- for i, col in enumerate(cols):
371
- label = options[i] # Use token name directly for label
372
- with col:
373
- svs.vertical_slider(
374
- label=label, min_value=0.0, max_value=1.0, step=0.01, height=50,
375
- key=f"slider_{i}", # This key links the widget to st.session_state[f"slider_{i}"]
376
- slider_color="green", track_color="lightgray", thumb_color="black"
377
- )
378
 
379
- # Ground truth selectbox
380
- st.selectbox(
381
- "Ground Truth Token", options=options,
382
- index=options.index(st.session_state['ground_truth']), # Display value from session state
383
- key='ground_truth' # Links widget to st.session_state['ground_truth']
384
- )
385
 
386
  # Placeholder for charts and loss calculations that will be updated
387
  # This section always reads the current st.session_state to generate its content.
388
 
389
- current_prob_values_from_state = [st.session_state.get(f"slider_{j}", 1.0/len(options)) for j in range(len(options))]
 
 
390
  total_from_state = sum(current_prob_values_from_state)
391
  probs_for_charts = (
392
  torch.ones(len(options)) / len(options)
@@ -394,63 +165,144 @@ probs_for_charts = (
394
  else torch.tensor([v / total_from_state for v in current_prob_values_from_state])
395
  )
396
 
397
- gt_choice_for_charts = st.session_state.get('ground_truth', options[0])
398
  if gt_choice_for_charts == "Text":
399
- gt_index_for_charts = 10 # Assuming "Text" is the 11th item (index 10)
400
  gt_numeric_for_charts = None
401
  else:
402
  gt_index_for_charts = int(gt_choice_for_charts)
403
  gt_numeric_for_charts = gt_index_for_charts
404
 
405
- st.markdown("#### Input Probability Distribution")
406
- df_dist = pd.DataFrame({"token": options, "probability": probs_for_charts.numpy()})
407
- df_dist["type"] = ["Ground Truth" if token == gt_choice_for_charts else "Prediction" for token in options]
408
- chart = (
409
- alt.Chart(df_dist).mark_bar().encode(
410
- x=alt.X("token:N", title="Token", sort=options), # Ensure consistent sort order
411
- y=alt.Y("probability:Q", title="Probability", scale=alt.Scale(domain=[0, 1])),
412
- color=alt.Color("type:N", scale=alt.Scale(domain=["Ground Truth", "Prediction"], range=["green", "steelblue"]), legend=alt.Legend(title="Token Type"))
413
- ).properties(height=300)
 
 
 
 
414
  )
415
- st.altair_chart(chart, use_container_width=True)
 
 
 
 
416
 
417
  ce_loss = -torch.log(torch.clamp(probs_for_charts[gt_index_for_charts], min=1e-9))
418
- if gt_numeric_for_charts is None: # Text token
419
- ntl_mse_loss = torch.tensor(float('nan')) # MSE not applicable for text
420
- ntl_was_loss = torch.tensor(float('nan')) # WAS not applicable for text
421
- else: # Numeric token
422
- numeric_probs_for_loss = probs_for_charts[:10] # Probabilities for 0-9
 
423
  # Ensure numeric_probs_for_loss sums to 1 for NTL calculations if it's a subset
424
  numeric_probs_sum = torch.sum(numeric_probs_for_loss)
425
- if numeric_probs_sum > 1e-6 : # Avoid division by zero
426
- normalized_numeric_probs = numeric_probs_for_loss / numeric_probs_sum
427
  else:
428
- normalized_numeric_probs = torch.zeros_like(numeric_probs_for_loss)
429
-
430
 
431
  loss_values_tensor = torch.arange(0, 10, dtype=torch.float32)
432
 
433
  # Use normalized probabilities for NTL if only considering numeric tokens
434
- if gt_choice_for_charts != "Text" and torch.sum(probs_for_charts[:10]) > 1e-6 :
435
- pred_value = torch.sum( (probs_for_charts[:10]/torch.sum(probs_for_charts[:10])) * loss_values_tensor)
436
- elif gt_choice_for_charts != "Text": # if sum is zero, pred_value is ill-defined or 0
437
- pred_value = torch.tensor(0.0)
438
- else: # Should not happen if gt_numeric_for_charts is not None
439
- pred_value = torch.tensor(float('nan'))
440
-
 
 
 
 
441
 
442
  if not torch.isnan(pred_value):
443
- ntl_mse_loss = (pred_value - float(gt_numeric_for_charts)) ** 2
 
 
444
  abs_diff = torch.abs(loss_values_tensor - float(gt_numeric_for_charts))
445
  if gt_choice_for_charts != "Text" and torch.sum(probs_for_charts[:10]) > 1e-6:
446
- ntl_was_loss = torch.sum((probs_for_charts[:10]/torch.sum(probs_for_charts[:10])) * abs_diff)
 
 
447
  elif gt_choice_for_charts != "Text":
448
- ntl_was_loss = torch.tensor(0.0) # Or some other default if all numeric probs are zero
449
  else:
450
- ntl_was_loss = torch.tensor(float('nan'))
 
451
  else:
452
- ntl_mse_loss = torch.tensor(float('nan'))
453
- ntl_was_loss = torch.tensor(float('nan'))
454
 
455
 
456
  ce_val = round(ce_loss.item(), 3)
@@ -458,6 +310,38 @@ mse_val = round(ntl_mse_loss.item(), 3) if not torch.isnan(ntl_mse_loss) else "N
458
  was_val = round(ntl_was_loss.item(), 3) if not torch.isnan(ntl_was_loss) else "N/A"
459
 
460
 
 
 
 
 
461
  loss_data = {"Loss": ["Cross Entropy"], "Value": [ce_val]}
462
  if was_val != "N/A":
463
  loss_data["Loss"].append("NTL-WAS")
@@ -469,34 +353,103 @@ if mse_val != "N/A":
469
  loss_df = pd.DataFrame(loss_data)
470
 
471
  # ============== Chart Display ==============
 
 
 
 
472
  # Create a single chart for loss visualization
473
  st.subheader("Loss Comparison")
 
 
 
 
 
 
 
474
 
475
  # Create an Altair chart that will look good and redraw cleanly
476
- chart = alt.Chart(loss_df).mark_bar().encode(
477
- x=alt.X('Loss:N', sort=loss_df["Loss"].tolist()),
478
- y=alt.Y('Value:Q', scale=alt.Scale(domain=[0, max(loss_df["Value"].max() * 1.2, 20 if st.session_state.running_demo else 0.5)])),
479
- color=alt.Color('Loss:N', scale=alt.Scale(
480
- domain=['Cross Entropy', 'NTL-WAS', 'NTL-MSE'],
481
- range=['steelblue', 'red', 'forestgreen']
482
- )),
483
- tooltip=['Loss', 'Value']
484
- ).properties(
485
- height=300
 
 
 
 
486
  )
487
 
 
 
 
 
488
  # Add value labels on top of bars
489
- text = chart.mark_text(
490
- align='center',
491
- baseline='bottom',
492
- dy=-5,
493
- fontSize=14
494
- ).encode(
495
- text=alt.Text('Value:Q', format='.3f')
496
  )
497
 
498
  # Combine chart and text
499
- final_chart = (chart + text)
500
 
501
  # Display chart with the full container width
502
  st.altair_chart(final_chart, use_container_width=True)
@@ -507,7 +460,7 @@ st.altair_chart(final_chart, use_container_width=True)
507
  if st.session_state.running_demo:
508
  # This check is implicitly: if we are here and demo is running, it means
509
  # the time-based advance condition was NOT met in the block at the top.
510
- time.sleep(0.1) # Adjusted from 0.2 to 0.5 (or try 1.0)
511
  st.rerun()
512
 
513
  # Add explanation of the demonstration
 
+ import time
+
  import altair as alt
+ import numpy as np
  import pandas as pd
+ import streamlit as st
  import streamlit_vertical_slider as svs
  import torch
+
+ from scenarios import bimodal, dirac, gauss
+
+ DEMO_INTERVAL = 1.5
+ NTL_MSE_SCALING = 0.5
+ MAX_LOSS_PLOT = 15
+ LAST_STEP = -1
+
+ # """TODO:
+ # - Remove flickering of loss evolution scenario plot (lower ylim?)
+ # - Move manual part down (predicted token probabilities)
+ # - Allow to set GT token for each demo
+ # - Add text token to loss evolution barplot
+ # - pick good default (4?)
+ # """
+
25
 
26
  # Define options globally as it's used in initialization and UI
27
  options = [str(i) for i in range(10)] + ["Text"]
28
 
29
  # --- Session State Initialization ---
30
  # Ensure all session state variables are initialized before first use, especially by widgets.
31
+ if "running_demo" not in st.session_state:
32
  st.session_state.running_demo = False
33
+ if "demo_step" not in st.session_state:
34
  st.session_state.demo_step = 0
35
+ if "last_update_time" not in st.session_state:
36
  st.session_state.last_update_time = 0
37
+ if "loss_container" not in st.session_state:
38
  st.session_state.loss_container = None
39
+ if "previous_chart_html" not in st.session_state:
40
  st.session_state.previous_chart_html = ""
41
+ if "active_scenarios" not in st.session_state:
42
+ # default if you want one to load on first show
43
+ st.session_state.active_scenarios = dirac
44
+ if "loss_history" not in st.session_state:
45
+ st.session_state.loss_history = []
46
 
47
  # Initialize states for sliders and ground_truth selector
48
  # Using len(options) to correctly size for 0-9 + "Text"
49
  for i in range(len(options)):
50
  if f"slider_{i}" not in st.session_state:
51
  st.session_state[f"slider_{i}"] = 1.0 / len(options)
52
+ if "ground_truth" not in st.session_state:
53
+ st.session_state["ground_truth"] = options[0] # Default to "0"
54
 
55
 
56
  st.title("Number Token Loss - Demo")
57
 
58
+ st.markdown(
59
+ """
60
+ **Instructions**
61
+
62
+ 1. **Pick a ground truth token (0–9).**
63
+ 2. **Select one of the three automated demos:**
64
+ - **Dirac**: a one-hot (Dirac) distribution whose single 1.0 mass moves from token 0 all the way to “Text.”
65
+ - **Gaussian**: a peaked Gaussian (0.6 mass at center, 0.4 spread) that slides its center from token 0 to “Text.”
66
+ - **Bimodal**: two equal peaks (0.5 each) that start together at the ground-truth token and move symmetrically outward (clamping at tokens 0 and 10).
67
+ """
68
+ )
69
+
70
+ if "ground_truth" not in st.session_state:
71
+ st.session_state["ground_truth"] = "4"
72
+ gt = st.selectbox(
73
+ "Ground Truth Token",
74
+ options=options,
75
+ index=options.index(st.session_state["ground_truth"]),
76
+ key="ground_truth",
77
+ )
78
 
 
 
 
 
79
 
 
80
  def apply_scenario(step_idx):
81
+ scenario = st.session_state.active_scenarios[step_idx]
 
 
82
  for i, val in enumerate(scenario["values"]):
83
  st.session_state[f"slider_{i}"] = val
 
84
 
85
+
86
+ def start_dirac_demo():
87
+ st.session_state.active_scenarios = dirac
88
+ st.session_state.running_demo = True
89
+ st.session_state.demo_step = 0
90
+ st.session_state.last_update_time = time.time()
91
+ apply_scenario(0)
92
+
93
+
94
+ def start_gauss_demo():
95
+ st.session_state.active_scenarios = gauss
96
+ st.session_state.running_demo = True
97
+ st.session_state.demo_step = 0
98
+ st.session_state.last_update_time = time.time()
99
+ apply_scenario(0)
100
+
101
+
102
+ def start_bimodal_demo():
103
+ st.session_state.active_scenarios = bimodal
104
  st.session_state.running_demo = True
105
  st.session_state.demo_step = 0
106
  st.session_state.last_update_time = time.time()
107
+ apply_scenario(0)
108
+
109
 
110
  def stop_demo():
111
  st.session_state.running_demo = False
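The three `start_*_demo` helpers above differ only in which scenario list they activate. A possible consolidation, sketched under the assumption that it would live in this same module next to `apply_scenario` (it is not part of this commit), is a single parameterized starter:

```python
# Hypothetical refactor (not in this commit): one parameterized starter
# instead of start_dirac_demo / start_gauss_demo / start_bimodal_demo.
def start_demo(scenario_list):
    st.session_state.active_scenarios = scenario_list
    st.session_state.running_demo = True
    st.session_state.demo_step = 0
    st.session_state.last_update_time = time.time()
    apply_scenario(0)

# usage: start_demo(dirac), start_demo(gauss), start_demo(bimodal)
```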
112
 
113
+
114
  # --- Demo State Advancement Logic ---
115
  # This block handles advancing the demo. If it advances, it updates session state
116
  # and then reruns. This ensures widgets are drawn with the new state in the next run.
117
  if st.session_state.running_demo:
118
+ scenario = st.session_state.active_scenarios
119
  current_time = time.time()
120
+ if current_time - st.session_state.last_update_time > DEMO_INTERVAL:
121
+ next_step = (st.session_state.demo_step + 1) % len(scenario)
122
  st.session_state.demo_step = next_step
123
  apply_scenario(next_step) # Update session state for the new scenario
124
+ st.session_state.last_update_time = time.time() # Reset timer
125
  st.rerun() # Crucial: Rerun to reflect changes in widgets and charts
126
 
127
  # --- UI Rendering ---
128
  # This section renders the main UI. It executes after any potential rerun from the block above.
129
 
130
  if st.session_state.running_demo:
131
+ st.info(
132
+ f"Showing scenario {st.session_state.demo_step + 1}"
133
+ f"/{len(st.session_state.active_scenarios)}: "
134
+ f"{st.session_state.active_scenarios[st.session_state.demo_step]['name']}"
135
+ )
136
  if st.button("Stop Demo"):
137
+ st.session_state.running_demo = False
138
  st.rerun()
139
+ else:
140
+ col1, col2, col3 = st.columns(3)
141
+ with col1:
142
+ if st.button("Run: Dirac"):
143
+ start_dirac_demo()
144
+ st.rerun()
145
+ with col2:
146
+ if st.button("Run: Gauss"):
147
+ start_gauss_demo()
148
+ st.rerun()
149
+ with col3:
150
+ if st.button("Run: Bimodal"):
151
+ start_bimodal_demo()
152
+ st.rerun()
 
 
 
 
 
153
 
 
 
 
 
 
 
154
 
155
  # Placeholder for charts and loss calculations that will be updated
156
  # This section always reads the current st.session_state to generate its content.
157
 
158
+ current_prob_values_from_state = [
159
+ st.session_state.get(f"slider_{j}", 1.0 / len(options)) for j in range(len(options))
160
+ ]
161
  total_from_state = sum(current_prob_values_from_state)
162
  probs_for_charts = (
163
  torch.ones(len(options)) / len(options)
 
165
  else torch.tensor([v / total_from_state for v in current_prob_values_from_state])
166
  )
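For reference, here is what that normalization does with concrete numbers. This is an illustrative sketch, not app code; the guard condition itself is swallowed by the hunk header above, so the uniform fallback for an all-zero total is an assumption.

```python
import torch

raw = [0.2, 0.4, 0.4] + [0.0] * 8      # hypothetical slider values for tokens 0-9 and "Text"
total = sum(raw)
probs = (
    torch.ones(len(raw)) / len(raw)    # assumed fallback: uniform when every slider is zero
    if total == 0
    else torch.tensor([v / total for v in raw])
)
print(probs)                           # tensor([0.2000, 0.4000, 0.4000, 0.0000, ...]) -- sums to 1
```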
167
 
168
+ gt_choice_for_charts = st.session_state.get("ground_truth", options[0])
169
  if gt_choice_for_charts == "Text":
170
+ gt_index_for_charts = 10 # Assuming "Text" is the 11th item (index 10)
171
  gt_numeric_for_charts = None
172
  else:
173
  gt_index_for_charts = int(gt_choice_for_charts)
174
  gt_numeric_for_charts = gt_index_for_charts
175
 
176
+ gt = st.session_state["ground_truth"]
177
+
178
+ st.markdown(f"#### Predicted Probability Distribution (ground-truth token: {gt})")
179
+ df_dist = pd.DataFrame(
180
+ {"token": options, "probability": probs_for_charts.numpy().round(2)}
181
+ )
182
+ df_dist["type"] = [
183
+ "Ground Truth" if token == gt_choice_for_charts else "Prediction"
184
+ for token in options
185
+ ]
186
+ bg = (
187
+ alt.Chart(pd.DataFrame({"token": [gt]}))
188
+ .mark_bar(size=40, color="lightgray", opacity=0.4)
189
+ .encode(
190
+ x=alt.X("token:N", sort=options),
191
+ x2=alt.X2("token:N"), # pin the right edge to the same category
192
+ y=alt.value(0), # bottom at y=0
193
+ y2=alt.value(1), # top at y=1 (full height)
194
+ )
195
+ )
196
+
197
+ bars = (
198
+ alt.Chart(df_dist)
199
+ .mark_bar()
200
+ .encode(
201
+ x=alt.X(
202
+ "token:N",
203
+ title="Token",
204
+ sort=options,
205
+ axis=alt.Axis(labelAngle=0, labelFontSize=14, titleFontSize=16),
206
+ ),
207
+ y=alt.Y(
208
+ "probability:Q",
209
+ title="Probability",
210
+ scale=alt.Scale(domain=[0, 1]),
211
+ axis=alt.Axis(format=".2f", labelFontSize=14, titleFontSize=16),
212
+ ),
213
+ color=alt.Color(
214
+ "type:N",
215
+ scale=alt.Scale(
216
+ domain=["Ground Truth", "Prediction"], range=["green", "steelblue"]
217
+ ),
218
+ legend=alt.Legend(title="Token Type", titleFontSize=16, labelFontSize=14),
219
+ ),
220
+ tooltip=[
221
+ alt.Tooltip("token:N", title="Token"),
222
+ alt.Tooltip("probability:Q", title="Probability", format=".2f"),
223
+ alt.Tooltip("type:N", title="Type"),
224
+ ],
225
+ )
226
+ .properties(height=300)
227
+ )
228
+ annot1 = (
229
+ alt.Chart(pd.DataFrame({"token": [gt]}))
230
+ .mark_text(
231
+ text="⬇ Ground",
232
+ dy=-25, # 25px above the top of the bar
233
+ dx=25,
234
+ fontSize=14,
235
+ fontWeight="bold",
236
+ color="green",
237
+ )
238
+ .encode(x=alt.X("token:N", sort=options), y=alt.value(1))
239
+ )
240
+
241
+ # second line: “truth=4”
242
+ annot2 = (
243
+ alt.Chart(pd.DataFrame({"token": [gt]}))
244
+ .mark_text(
245
+ text=f"truth={gt}",
246
+ dy=-10, # 10px above the top, so it sits just below the first line
247
+ dx=35,
248
+ fontSize=14,
249
+ fontWeight="bold",
250
+ color="green",
251
+ )
252
+ .encode(x=alt.X("token:N", sort=options), y=alt.value(1))
253
  )
254
+
255
+ # 4) Layer them in order: background, bars, annotation
256
+ final_chart = (bg + bars + annot1 + annot2).properties(height=300)
257
+
258
+ st.altair_chart(final_chart, use_container_width=True)
259
 
260
  ce_loss = -torch.log(torch.clamp(probs_for_charts[gt_index_for_charts], min=1e-9))
261
+
262
+ if gt_numeric_for_charts is None: # Text token
263
+ ntl_mse_loss = torch.tensor(float("nan")) # MSE not applicable for text
264
+ ntl_was_loss = torch.tensor(float("nan")) # WAS not applicable for text
265
+ else: # Numeric token
266
+ numeric_probs_for_loss = probs_for_charts[:10] # Probabilities for 0-9
267
  # Ensure numeric_probs_for_loss sums to 1 for NTL calculations if it's a subset
268
  numeric_probs_sum = torch.sum(numeric_probs_for_loss)
269
+ if numeric_probs_sum > 1e-6: # Avoid division by zero
270
+ normalized_numeric_probs = numeric_probs_for_loss / numeric_probs_sum
271
  else:
272
+ normalized_numeric_probs = torch.zeros_like(numeric_probs_for_loss)
 
273
 
274
  loss_values_tensor = torch.arange(0, 10, dtype=torch.float32)
275
 
276
  # Use normalized probabilities for NTL if only considering numeric tokens
277
+ if gt_choice_for_charts != "Text" and torch.sum(probs_for_charts[:10]) > 1e-6:
278
+ pred_value = torch.sum(
279
+ (probs_for_charts[:10] / torch.sum(probs_for_charts[:10]))
280
+ * loss_values_tensor
281
+ )
282
+ elif (
283
+ gt_choice_for_charts != "Text"
284
+ ): # if sum is zero, pred_value is ill-defined or 0
285
+ pred_value = torch.tensor(0.0)
286
+ else: # Should not happen if gt_numeric_for_charts is not None
287
+ pred_value = torch.tensor(float("nan"))
288
 
289
  if not torch.isnan(pred_value):
290
+ ntl_mse_loss = (
+ NTL_MSE_SCALING * (pred_value - float(gt_numeric_for_charts)) ** 2
+ )
293
  abs_diff = torch.abs(loss_values_tensor - float(gt_numeric_for_charts))
294
  if gt_choice_for_charts != "Text" and torch.sum(probs_for_charts[:10]) > 1e-6:
295
+ ntl_was_loss = torch.sum(
296
+ (probs_for_charts[:10] / torch.sum(probs_for_charts[:10])) * abs_diff
297
+ )
298
  elif gt_choice_for_charts != "Text":
299
+ ntl_was_loss = torch.tensor(0.0)
300
  else:
301
+ ntl_was_loss = torch.tensor(float("nan"))
302
+
303
  else:
304
+ ntl_mse_loss = torch.tensor(float("nan"))
305
+ ntl_was_loss = torch.tensor(float("nan"))
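To make the three quantities concrete, here is a standalone recomputation of the same formulas on a toy distribution. This is my sketch, not code from this commit; it mirrors the Cross Entropy, NTL-MSE (with the 0.5 scaling constant defined above), and NTL-WAS expressions used in the block above.

```python
import torch

probs = torch.tensor([0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02])  # over digits 0-9
gt = 5
values = torch.arange(0, 10, dtype=torch.float32)

ce = -torch.log(probs[gt])                            # Cross Entropy: -log p(ground truth)
pred_value = torch.sum(probs * values)                # expected numeric value of the prediction
ntl_mse = 0.5 * (pred_value - gt) ** 2                # NTL-MSE, scaled like NTL_MSE_SCALING above
ntl_was = torch.sum(probs * torch.abs(values - gt))   # NTL-WAS: expected |token - ground truth|
print(round(ce.item(), 3), round(ntl_mse.item(), 3), round(ntl_was.item(), 3))  # roughly 1.204 0.168 1.42
```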
306
 
307
 
308
  ce_val = round(ce_loss.item(), 3)
 
310
  was_val = round(ntl_was_loss.item(), 3) if not torch.isnan(ntl_was_loss) else "N/A"
311
 
312
 
313
+ if len(st.session_state.loss_history) < st.session_state.demo_step + 1:
314
+ st.session_state.loss_history.append(
315
+ {
316
+ "token_index": np.argmax(
317
+ st.session_state.active_scenarios[st.session_state["demo_step"]][
318
+ "values"
319
+ ]
320
+ ),
321
+ # int(np.argmax(st.session_state['values']))
322
+ # int(),
323
+ "CE": ce_val,
324
+ "NTL-MSE": mse_val if mse_val != "N/A" else None,
325
+ "NTL-WAS": was_val if was_val != "N/A" else None,
326
+ }
327
+ )
328
+ last_step = st.session_state.demo_step
329
+
330
+ if st.session_state.loss_history:
331
+ loss_plot_data = []
332
+ for entry in st.session_state.loss_history:
333
+ for loss_type in ["CE", "NTL-MSE", "NTL-WAS"]:
334
+ if entry[loss_type] is not None:
335
+ loss_plot_data.append(
336
+ {
337
+ "Token Index": entry["token_index"],
338
+ "Loss Type": loss_type,
339
+ "Loss Value": entry[loss_type], # TODO: clip to MAX_LOSS_PLOT?
340
+ }
341
+ )
342
+
343
+ df_loss_plot = pd.DataFrame(loss_plot_data)
344
+
345
  loss_data = {"Loss": ["Cross Entropy"], "Value": [ce_val]}
346
  if was_val != "N/A":
347
  loss_data["Loss"].append("NTL-WAS")
 
353
  loss_df = pd.DataFrame(loss_data)
354
 
355
  # ============== Chart Display ==============
356
+
357
+
358
+ st.subheader("Loss Evolution Over Scenarios")
359
+
360
+ x_domain = list(range(10))
361
+
362
+ grouped_chart = (
363
+ alt.Chart(df_loss_plot)
364
+ .mark_bar()
365
+ .encode(
366
+ x=alt.X(
367
+ "Token Index:O",
368
+ title="Predicted Token Index",
369
+ axis=alt.Axis(labelAngle=0),
370
+ scale=alt.Scale(domain=x_domain),
371
+ ),
372
+ y=alt.Y(
373
+ "Loss Value:Q", title="Loss", scale=alt.Scale(domain=[0, MAX_LOSS_PLOT])
374
+ ),
375
+ color=alt.Color("Loss Type:N", legend=alt.Legend(title="Loss")),
376
+ xOffset="Loss Type:N", # <== this causes the grouping instead of stacking
377
+ )
378
+ .properties(height=300)
379
+ )
380
+
381
+ st.altair_chart(grouped_chart, use_container_width=True)
382
+
383
+
384
  # Create a single chart for loss visualization
385
  st.subheader("Loss Comparison")
386
+ st.markdown("""
387
+ Adjust the sliders to set a predicted probability for each token (0-9 and "Text").
388
+ The sliders are vertical and compact. The app normalizes the slider values
389
+ to form a valid probability distribution, visualizes it, and computes the corresponding
390
+ Cross Entropy, NTL-MSE, and NTL-WAS losses.
391
+ """)
392
+
393
 
394
  # Create an Altair chart that will look good and redraw cleanly
395
+ chart = (
396
+ alt.Chart(loss_df)
397
+ .mark_bar()
398
+ .encode(
399
+ x=alt.X("Loss:N", sort=loss_df["Loss"].tolist()),
400
+ y=alt.Y(
401
+ "Value:Q",
402
+ scale=alt.Scale(
403
+ domain=[
404
+ 0,
405
+ max(
406
+ loss_df["Value"].max() * 1.2,
407
+ 20 if st.session_state.running_demo else 0.5,
408
+ ),
409
+ ]
410
+ ),
411
+ ),
412
+ color=alt.Color(
413
+ "Loss:N",
414
+ scale=alt.Scale(
415
+ domain=["Cross Entropy", "NTL-WAS", "NTL-MSE"],
416
+ range=["steelblue", "red", "forestgreen"],
417
+ ),
418
+ ),
419
+ tooltip=["Loss", "Value"],
420
+ )
421
+ .properties(height=300)
422
  )
423
 
424
+ # Sliders and Ground Truth Selector
425
+ # These widgets will read their initial values from st.session_state.
426
+ # User interactions will update st.session_state directly due to their keys.
427
+ if not st.session_state.running_demo:
428
+ st.markdown("#### Predicted Token Probabilities")
429
+ cols = st.columns(len(options))
430
+ for i, col in enumerate(cols):
431
+ label = options[i] # Use token name directly for label
432
+ with col:
433
+ svs.vertical_slider(
434
+ label=label,
435
+ min_value=0.0,
436
+ max_value=1.0,
437
+ step=0.01,
438
+ height=50,
439
+ key=f"slider_{i}", # This key links the widget to st.session_state[f"slider_{i}"]
440
+ slider_color="green",
441
+ track_color="lightgray",
442
+ thumb_color="black",
443
+ )
444
+
445
+
446
  # Add value labels on top of bars
447
+ text = chart.mark_text(align="center", baseline="bottom", dy=-5, fontSize=14).encode(
448
+ text=alt.Text("Value:Q", format=".3f")
 
 
 
 
 
449
  )
450
 
451
  # Combine chart and text
452
+ final_chart = chart + text
453
 
454
  # Display chart with the full container width
455
  st.altair_chart(final_chart, use_container_width=True)
 
460
  if st.session_state.running_demo:
461
  # This check is implicitly: if we are here and demo is running, it means
462
  # the time-based advance condition was NOT met in the block at the top.
463
+ time.sleep(0.1)
464
  st.rerun()
465
 
466
  # Add explanation of the demonstration