jannisborn committed
Commit 0dc70d1 · unverified · 1 Parent(s): 9914a10
Files changed (3)
  1. .gitignore +2 -0
  2. src/scenarios.py +33 -20
  3. src/streamlit_app.py +273 -214
.gitignore ADDED
@@ -0,0 +1,2 @@
+*.DS_Store
+*__pycache__
src/scenarios.py CHANGED
@@ -1,9 +1,12 @@
 import numpy as np
 
 # (1) A one-hot moving from token 0 to token 10 (“Text”)
 dirac = [
     {
-        "name": f"Dirac: all mass on token {i}",
         "values": [1.0 if j == i else 0.0 for j in range(11)],
         "ground_truth": "4",
         "explanation": "A Dirac distribution: all probability on a single token.",
@@ -29,7 +32,7 @@ def make_gauss_values(center, n=11, sigma=1.5, peak_mass=0.6):
 
 gauss = [
     {
-        "name": f"Gaussian: center at token {c}",
         "values": make_gauss_values(c),
         "ground_truth": "4",
         "explanation": "Gaussian-style: 0.6 mass at the highlighted token, 0.4 spread smoothly to its neighbors.",
@@ -38,23 +41,33 @@ gauss = [
 ]
 
 
-# (3) Bimodal: two spikes of 0.5 mass each, symmetrically offset from the GT=4 ---
-def make_bimodal_values(offset, n=11, gt=4):
-    # clamp to [0,n-1]
-    left = max(0, gt - offset)
-    right = min(n - 1, gt + offset)
-    vals = [0.0] * n
-    vals[left] = 0.5
-    vals[right] = 0.5
-    return vals
-
-
-bimodal = [
-    {
-        "name": f"Bimodal: peaks at tokens {max(0, 4 - d)} & {min(10, 4 + d)}",
-        "values": make_bimodal_values(d),
-        "ground_truth": "4",
-        "explanation": "Two-point (bimodal) distribution: equal 0.5 mass on each peak, which move ±offset from the ground truth.",
-    }
-    for d in range(11)
-]

 import numpy as np
 
+options = [str(i) for i in range(10)] + ["Text"]
+
+
 # (1) A one-hot moving from token 0 to token 10 (“Text”)
 dirac = [
     {
+        "name": f"Dirac: all mass on token {options[i]}",
         "values": [1.0 if j == i else 0.0 for j in range(11)],
         "ground_truth": "4",
         "explanation": "A Dirac distribution: all probability on a single token.",
 
 gauss = [
     {
+        "name": f"Gaussian: center at token {options[c]}",
         "values": make_gauss_values(c),
         "ground_truth": "4",
         "explanation": "Gaussian-style: 0.6 mass at the highlighted token, 0.4 spread smoothly to its neighbors.",
 
 ]
 
 
+def make_bimodal_scenarios(gt_token: str, options: list[str]) -> list[dict]:
+    """
+    Build a list of { name, values, explanation } dicts, where
+    each scenario splits 50/50 between tokens (gt±offset),
+    wrapping around via Python’s % operator.
+    """
+    n = len(options)
+    gt_idx = options.index(gt_token)
+    scenarios = []
+    for offset in range(n):
+        left = (gt_idx - offset) % n
+        right = (gt_idx + offset) % n
+
+        # build the 50/50 (or 1.0 at gt when offset=0) vector
+        vals = [0.0] * n
+        if left == right:
+            vals[left] = 1.0
+        else:
+            vals[left] = 0.5
+            vals[right] = 0.5
+
+        label = f"({options[left]}, {options[right]})"
+        scenarios.append(
+            {
+                "name": label,
+                "values": vals,
+                "explanation": "50/50 mass at these two tokens (wrapping).",
+            }
+        )
+    return scenarios
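A quick illustrative sketch of how the new helper behaves (this is an editor's example, not part of the commit; it assumes `src/` is on the import path). With the default ground truth "4", the first scenario puts all mass on the target and later ones split it 50/50 over peaks that drift apart and eventually wrap past "Text":

# Hypothetical usage sketch for make_bimodal_scenarios (not part of the commit).
from scenarios import make_bimodal_scenarios

options = [str(i) for i in range(10)] + ["Text"]

for s in make_bimodal_scenarios("4", options)[:3]:
    # offset 0 -> "(4, 4)" with 1.0 on token 4,
    # offset 1 -> "(3, 5)" with 0.5 on tokens 3 and 5,
    # offset 2 -> "(2, 6)" with 0.5 on tokens 2 and 6, and so on.
    print(s["name"], s["values"])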
src/streamlit_app.py CHANGED
@@ -1,3 +1,4 @@
 import time
 
 import altair as alt
@@ -7,25 +8,38 @@ import streamlit as st
 import streamlit_vertical_slider as svs
 import torch
 
-from scenarios import bimodal, dirac, gauss
 
 DEMO_INTERVAL = 1.5
-NTL_MSE_SCALING = 0.5
-MAX_LOSS_PLOT = 15
 LAST_STEP = -1
 
-# """TODO:
-# - Remove flickering of loss evolution scenario plot (lower ylim?)
-# - Move manual part down (predicted token probabilities)
-# - Allow to set GT token for each demo
-# - Add text token to loss evolution barplot
-# - pick good default (4?)
-# """
-
 
 # Define options globally as it's used in initialization and UI
 options = [str(i) for i in range(10)] + ["Text"]
 
 # --- Session State Initialization ---
 # Ensure all session state variables are initialized before first use, especially by widgets.
 if "running_demo" not in st.session_state:
@@ -44,37 +58,44 @@ if "active_scenarios" not in st.session_state:
 if "loss_history" not in st.session_state:
     st.session_state.loss_history = []
 
 # Initialize states for sliders and ground_truth selector
 # Using len(options) to correctly size for 0-9 + "Text"
 for i in range(len(options)):
     if f"slider_{i}" not in st.session_state:
-        st.session_state[f"slider_{i}"] = 1.0 / len(options)
 if "ground_truth" not in st.session_state:
-    st.session_state["ground_truth"] = options[0]  # Default to "0"
 
 
-st.title("Number Token Loss - Demo")
 
 st.markdown(
-    """
-    **Instructions**
-
-    1. **Pick a ground truth token (0–9).**
-    2. **Select one of the three automated demos:**
-       - **Dirac**: a one-hot (Dirac) distribution whose single 1.0 mass moves from token 0 all the way to “Text.”
-       - **Gaussian**: a peaked Gaussian (0.6 mass at center, 0.4 spread) that slides its center from token 0 to “Text.”
-       - **Bimodal**: two equal peaks (0.5 each) that start at (0,8) and then move symmetrically away from the GT token.
     """
 )
 
 if "ground_truth" not in st.session_state:
     st.session_state["ground_truth"] = "4"
-gt = st.selectbox(
-    "Ground Truth Token",
-    options=options,
-    index=options.index(st.session_state["ground_truth"]),
-    key="ground_truth",
-)
 
 
 def apply_scenario(step_idx):
@@ -84,7 +105,9 @@ def apply_scenario(step_idx):
 
 
 def start_dirac_demo():
     st.session_state.active_scenarios = dirac
     st.session_state.running_demo = True
     st.session_state.demo_step = 0
     st.session_state.last_update_time = time.time()
@@ -92,7 +115,9 @@ def start_dirac_demo():
 
 
 def start_gauss_demo():
     st.session_state.active_scenarios = gauss
     st.session_state.running_demo = True
     st.session_state.demo_step = 0
     st.session_state.last_update_time = time.time()
@@ -100,7 +125,11 @@ def start_gauss_demo():
 
 
 def start_bimodal_demo():
-    st.session_state.active_scenarios = bimodal
     st.session_state.running_demo = True
     st.session_state.demo_step = 0
     st.session_state.last_update_time = time.time()
@@ -118,11 +147,15 @@ if st.session_state.running_demo:
     scenario = st.session_state.active_scenarios
     current_time = time.time()
     if current_time - st.session_state.last_update_time > DEMO_INTERVAL:
-        next_step = (st.session_state.demo_step + 1) % len(scenario)
-        st.session_state.demo_step = next_step
-        apply_scenario(next_step)  # Update session state for the new scenario
-        st.session_state.last_update_time = time.time()  # Reset timer
-        st.rerun()  # Crucial: Rerun to reflect changes in widgets and charts
 
 # --- UI Rendering ---
 # This section renders the main UI. It executes after any potential rerun from the block above.
@@ -151,12 +184,9 @@ else:
         start_bimodal_demo()
         st.rerun()
 
-
-# Placeholder for charts and loss calculations that will be updated
-# This section always reads the current st.session_state to generate its content.
-
 current_prob_values_from_state = [
-    st.session_state.get(f"slider_{j}", 1.0 / len(options)) for j in range(len(options))
 ]
 total_from_state = sum(current_prob_values_from_state)
 probs_for_charts = (
@@ -165,7 +195,12 @@ probs_for_charts = (
     else torch.tensor([v / total_from_state for v in current_prob_values_from_state])
 )
 
-gt_choice_for_charts = st.session_state.get("ground_truth", options[0])
 if gt_choice_for_charts == "Text":
     gt_index_for_charts = 10  # Assuming "Text" is the 11th item (index 10)
     gt_numeric_for_charts = None
@@ -174,8 +209,9 @@ else:
     gt_numeric_for_charts = gt_index_for_charts
 
 gt = st.session_state["ground_truth"]
 
-st.markdown(f"#### Predicted Probability Distribution Ground truth token {gt}")
 df_dist = pd.DataFrame(
     {"token": options, "probability": probs_for_charts.numpy().round(2)}
 )
@@ -183,26 +219,22 @@ df_dist["type"] = [
     "Ground Truth" if token == gt_choice_for_charts else "Prediction"
     for token in options
 ]
-bg = (
-    alt.Chart(pd.DataFrame({"token": [gt]}))
-    .mark_bar(size=40, color="lightgray", opacity=0.4)
-    .encode(
-        x=alt.X("token:N", sort=options),
-        x2=alt.X2("token:N"),  # pin the right edge to the same category
-        y=alt.value(0),  # bottom at y=0
-        y2=alt.value(1),  # top at y=1 (full height)
-    )
-)
 
 bars = (
     alt.Chart(df_dist)
-    .mark_bar()
     .encode(
         x=alt.X(
             "token:N",
             title="Token",
             sort=options,
-            axis=alt.Axis(labelAngle=0, labelFontSize=14, titleFontSize=16),
         ),
         y=alt.Y(
             "probability:Q",
@@ -210,21 +242,34 @@ bars = (
             scale=alt.Scale(domain=[0, 1]),
             axis=alt.Axis(format=".2f", labelFontSize=14, titleFontSize=16),
         ),
-        color=alt.Color(
-            "type:N",
-            scale=alt.Scale(
-                domain=["Ground Truth", "Prediction"], range=["green", "steelblue"]
-            ),
-            legend=alt.Legend(title="Token Type", titleFontSize=16, labelFontSize=14),
-        ),
         tooltip=[
             alt.Tooltip("token:N", title="Token"),
-            alt.Tooltip("probability:Q", title="Probability", format=".2f"),
-            alt.Tooltip("type:N", title="Type"),
         ],
     )
-    .properties(height=300)
 )
 annot1 = (
     alt.Chart(pd.DataFrame({"token": [gt]}))
     .mark_text(
@@ -233,12 +278,11 @@ annot1 = (
         dx=25,
         fontSize=14,
         fontWeight="bold",
-        color="green",
     )
     .encode(x=alt.X("token:N", sort=options), y=alt.value(1))
 )
 
-# second line: “truth=4”
 annot2 = (
     alt.Chart(pd.DataFrame({"token": [gt]}))
     .mark_text(
@@ -247,185 +291,164 @@ annot2 = (
         dx=35,
         fontSize=14,
         fontWeight="bold",
-        color="green",
     )
     .encode(x=alt.X("token:N", sort=options), y=alt.value(1))
 )
 
 # 4) Layer them in order: background, bars, annotation
-final_chart = (bg + bars + annot1 + annot2).properties(height=300)
 
 st.altair_chart(final_chart, use_container_width=True)
 
-ce_loss = -torch.log(torch.clamp(probs_for_charts[gt_index_for_charts], min=1e-9))
-
-if gt_numeric_for_charts is None:  # Text token
-    ntl_mse_loss = torch.tensor(float("nan"))  # MSE not applicable for text
-    ntl_was_loss = torch.tensor(float("nan"))  # WAS not applicable for text
-else:  # Numeric token
-    numeric_probs_for_loss = probs_for_charts[:10]  # Probabilities for 0-9
-    # Ensure numeric_probs_for_loss sums to 1 for NTL calculations if it's a subset
-    numeric_probs_sum = torch.sum(numeric_probs_for_loss)
-    if numeric_probs_sum > 1e-6:  # Avoid division by zero
-        normalized_numeric_probs = numeric_probs_for_loss / numeric_probs_sum
-    else:
-        normalized_numeric_probs = torch.zeros_like(numeric_probs_for_loss)
-
-    loss_values_tensor = torch.arange(0, 10, dtype=torch.float32)
 
-    # Use normalized probabilities for NTL if only considering numeric tokens
-    if gt_choice_for_charts != "Text" and torch.sum(probs_for_charts[:10]) > 1e-6:
-        pred_value = torch.sum(
-            (probs_for_charts[:10] / torch.sum(probs_for_charts[:10]))
-            * loss_values_tensor
-        )
-    elif (
-        gt_choice_for_charts != "Text"
-    ):  # if sum is zero, pred_value is ill-defined or 0
-        pred_value = torch.tensor(0.0)
-    else:  # Should not happen if gt_numeric_for_charts is not None
-        pred_value = torch.tensor(float("nan"))
-
-    if not torch.isnan(pred_value):
-        ntl_mse_loss = ntl_mse_loss = (
-            NTL_MSE_SCALING * (pred_value - float(gt_numeric_for_charts)) ** 2
-        )
-        abs_diff = torch.abs(loss_values_tensor - float(gt_numeric_for_charts))
-        if gt_choice_for_charts != "Text" and torch.sum(probs_for_charts[:10]) > 1e-6:
-            ntl_was_loss = torch.sum(
-                (probs_for_charts[:10] / torch.sum(probs_for_charts[:10])) * abs_diff
-            )
-        elif gt_choice_for_charts != "Text":
-            ntl_was_loss = torch.tensor(0.0)
-        else:
-            ntl_was_loss = torch.tensor(float("nan"))
     else:
-        ntl_mse_loss = torch.tensor(float("nan"))
-        ntl_was_loss = torch.tensor(float("nan"))
-
-
-ce_val = round(ce_loss.item(), 3)
-mse_val = round(ntl_mse_loss.item(), 3) if not torch.isnan(ntl_mse_loss) else "N/A"
-was_val = round(ntl_was_loss.item(), 3) if not torch.isnan(ntl_was_loss) else "N/A"
 
-
-if len(st.session_state.loss_history) < st.session_state.demo_step + 1:
     st.session_state.loss_history.append(
         {
-            "token_index": np.argmax(
-                st.session_state.active_scenarios[st.session_state["demo_step"]][
-                    "values"
-                ]
-            ),
-            # int(np.argmax(st.session_state['values']))
-            # int(),
-            "CE": ce_val,
-            "NTL-MSE": mse_val if mse_val != "N/A" else None,
-            "NTL-WAS": was_val if was_val != "N/A" else None,
         }
     )
-    last_step = st.session_state.demo_step
-
-if st.session_state.loss_history:
-    loss_plot_data = []
-    for entry in st.session_state.loss_history:
-        for loss_type in ["CE", "NTL-MSE", "NTL-WAS"]:
-            if entry[loss_type] is not None:
-                loss_plot_data.append(
-                    {
-                        "Token Index": entry["token_index"],
-                        "Loss Type": loss_type,
-                        "Loss Value": entry[loss_type],  # TODO: clip to MAX_LOSS_PLOT?
-                    }
-                )
-
-    df_loss_plot = pd.DataFrame(loss_plot_data)
 
 loss_data = {"Loss": ["Cross Entropy"], "Value": [ce_val]}
 if was_val != "N/A":
     loss_data["Loss"].append("NTL-WAS")
     loss_data["Value"].append(was_val)
-if mse_val != "N/A":
-    loss_data["Loss"].append("NTL-MSE")
-    loss_data["Value"].append(mse_val)
 
 loss_df = pd.DataFrame(loss_data)
-# ============== Chart Display ==============
 
 
 
-st.subheader("Loss Evolution Over Scenarios")
 
-x_domain = list(range(10))
 
 grouped_chart = (
     alt.Chart(df_loss_plot)
     .mark_bar()
     .encode(
         x=alt.X(
-            "Token Index:O",
-            title="Predicted Token Index",
-            axis=alt.Axis(labelAngle=0),
-            scale=alt.Scale(domain=x_domain),
        ),
        y=alt.Y(
-            "Loss Value:Q", title="Loss", scale=alt.Scale(domain=[0, MAX_LOSS_PLOT])
        ),
-        color=alt.Color("Loss Type:N", legend=alt.Legend(title="Loss")),
-        xOffset="Loss Type:N",  # <== this causes the grouping instead of stacking
    )
-    .properties(height=300)
 )
-
 st.altair_chart(grouped_chart, use_container_width=True)
 
 
 # Create a single chart for loss visualization
-st.subheader("Loss Comparison")
-st.markdown("""
-Adjust the sliders to set a predicted probability for each token (0-9 and "Text").
-The sliders are vertical and compact. The app normalizes the slider values
-to form a valid probability distribution, visualizes it, and computes the corresponding
-Cross Entropy, NTL-MSE, and NTL-WAS losses.
-""")
 
 
-# Create an Altair chart that will look good and redraw cleanly
-chart = (
-    alt.Chart(loss_df)
-    .mark_bar()
-    .encode(
-        x=alt.X("Loss:N", sort=loss_df["Loss"].tolist()),
-        y=alt.Y(
-            "Value:Q",
-            scale=alt.Scale(
-                domain=[
-                    0,
-                    max(
-                        loss_df["Value"].max() * 1.2,
-                        20 if st.session_state.running_demo else 0.5,
-                    ),
-                ]
-            ),
-        ),
-        color=alt.Color(
-            "Loss:N",
-            scale=alt.Scale(
-                domain=["Cross Entropy", "NTL-WAS", "NTL-MSE"],
-                range=["steelblue", "red", "forestgreen"],
-            ),
-        ),
-        tooltip=["Loss", "Value"],
     )
-    .properties(height=300)
-)
 
-# Sliders and Ground Truth Selector
-# These widgets will read their initial values from st.session_state.
-# User interactions will update st.session_state directly due to their keys.
-if not st.session_state.running_demo:
-    st.markdown("#### Predicted Token Probabilities")
 cols = st.columns(len(options))
 for i, col in enumerate(cols):
     label = options[i]  # Use token name directly for label
@@ -436,23 +459,58 @@ if not st.session_state.running_demo:
             max_value=1.0,
             step=0.01,
             height=50,
-            key=f"slider_{i}",  # This key links the widget to st.session_state[f"slider_{i}"]
             slider_color="green",
             track_color="lightgray",
             thumb_color="black",
         )
 
 
-# Add value labels on top of bars
-text = chart.mark_text(align="center", baseline="bottom", dy=-5, fontSize=14).encode(
-    text=alt.Text("Value:Q", format=".3f")
-)
 
-# Combine chart and text
-final_chart = chart + text
 
 # Display chart with the full container width
-st.altair_chart(final_chart, use_container_width=True)
 
 # --- Polling Rerun for Demo Mode ---
 # If the demo is running and we haven't just advanced (which would have caused a rerun),
@@ -463,17 +521,18 @@ if st.session_state.running_demo:
     time.sleep(0.1)
     st.rerun()
 
-# Add explanation of the demonstration
 st.markdown("""
-### What Does This Demo Show?
 
-- **Cross Entropy Loss**: Only cares if the prediction is exactly right or wrong - it doesn't consider how "close" a numerical prediction is.
-- **Number Token Loss (NTL)**: Considers numerical proximity - predicting "7" when the true value is "8" is better than predicting "2".
 """)
 
-# References / resources section with links (common to both modes)
-st.markdown("### Resources")
 st.markdown("""
-- [Paper: Number Token Loss (ArXiv)](https://arxiv.org/abs/2411.02083)
-- [GitHub: Number Token Loss](https://github.com/tum-ai/number-token-loss)
 """)
 
+import logging
 import time
 
 import altair as alt
 
 import streamlit_vertical_slider as svs
 import torch
 
+from scenarios import dirac, gauss, make_bimodal_scenarios
+
+logging.getLogger("streamlit.watcher.local_sources_watcher").setLevel(logging.ERROR)
 
 DEMO_INTERVAL = 1.5
+CE_SCALING = 0.25
+MAX_LOSS_PLOT = 6
 LAST_STEP = -1
 
 
 # Define options globally as it's used in initialization and UI
 options = [str(i) for i in range(10)] + ["Text"]
 
+
+def compute_losses(probs: torch.Tensor, gt_token: str) -> tuple[float, float, float]:
+    """Compute CE, NTL-MAE, NTL-WAS losses for the given probability vector and ground truth token."""
+    ce_loss = CE_SCALING * -torch.log(
+        torch.clamp(probs[options.index(gt_token)], min=1e-9)
+    )
+
+    numeric_mass = probs[:10].sum()
+
+    if gt_token == "Text" or numeric_mass < 1e-6:
+        return ce_loss.item(), 0.0, 0.0
+
+    gt_numeric = int(gt_token)
+    token_vals = torch.arange(10, dtype=torch.float32)
+    mae = numeric_mass * abs(torch.dot(token_vals, probs[:10]) - gt_numeric)
+    was = numeric_mass * torch.dot(probs[:10], torch.abs(token_vals - gt_numeric))
+    return round(ce_loss.item(), 3), round(mae.item(), 3), round(was.item(), 3)
+
+
 # --- Session State Initialization ---
 # Ensure all session state variables are initialized before first use, especially by widgets.
 if "running_demo" not in st.session_state:
 
  if "loss_history" not in st.session_state:
59
  st.session_state.loss_history = []
60
 
61
+
62
  # Initialize states for sliders and ground_truth selector
63
  # Using len(options) to correctly size for 0-9 + "Text"
64
  for i in range(len(options)):
65
  if f"slider_{i}" not in st.session_state:
66
+ st.session_state[f"slider_{i}"] = 0
67
  if "ground_truth" not in st.session_state:
68
+ st.session_state["ground_truth"] = options[5]
69
+ if "manual_ground_truth" not in st.session_state:
70
+ st.session_state["manual_ground_truth"] = options[5]
71
+ if "demo_name" not in st.session_state:
72
+ st.session_state["demo_name"] = "Dirac"
73
 
74
 
75
+ st.title("NTL -- The Number Token Loss 🚀")
76
 
77
  st.markdown(
78
+ """This is the interactive demo for our [ICML 2025](https://arxiv.org/abs/2411.02083) paper!🎉
79
+ ➡️ NTL augments cross-entropy to help LMs reason better with numbers 🧠
 
 
 
 
 
 
80
  """
81
  )
82
 
83
+ st.subheader("Demo 1 — NTL vs. Cross Entropy in 3 Scenarios")
84
+
85
+ st.markdown("""
86
+ 1️⃣ Pick a ground truth token: a digit (0–9) or "Text" 📝 (simulates generic text tokens).
87
+ 2️⃣ Choose a demo:
88
+ - **Dirac** ⚡: All probability mass on one token.
89
+ - **Gaussian** 🌊: Soft bell-curve around the true number.
90
+ - **Bimodal** 🎯: Two peaks moving away from the target.
91
+
92
+ Watch how losses evolve as predictions get worse — and see how NTL shines compared to CE! 🌟
93
+ """)
94
+
95
+
96
  if "ground_truth" not in st.session_state:
97
  st.session_state["ground_truth"] = "4"
98
+ gt = st.selectbox("Ground Truth Token", options=options, key="ground_truth")
 
 
 
 
 
99
 
100
 
101
 def apply_scenario(step_idx):
 
 
 def start_dirac_demo():
+    st.session_state.loss_history = []
     st.session_state.active_scenarios = dirac
+    st.session_state.demo_name = "Dirac"
     st.session_state.running_demo = True
     st.session_state.demo_step = 0
     st.session_state.last_update_time = time.time()
 
 
 def start_gauss_demo():
+    st.session_state.loss_history = []
     st.session_state.active_scenarios = gauss
+    st.session_state.demo_name = "Gauss"
     st.session_state.running_demo = True
     st.session_state.demo_step = 0
     st.session_state.last_update_time = time.time()
 
 
 def start_bimodal_demo():
+    st.session_state.loss_history = []
+    gt = st.session_state["ground_truth"]
+    st.session_state.active_scenarios = make_bimodal_scenarios(gt, options)
+
+    st.session_state.demo_name = f"Bimodal (GT={gt})"
     st.session_state.running_demo = True
     st.session_state.demo_step = 0
     st.session_state.last_update_time = time.time()
 
     scenario = st.session_state.active_scenarios
     current_time = time.time()
     if current_time - st.session_state.last_update_time > DEMO_INTERVAL:
+        # if we haven’t yet shown the last scenario, advance
+        if st.session_state.demo_step < len(scenario) - 1:
+            st.session_state.demo_step += 1
+            apply_scenario(st.session_state.demo_step)
+            st.session_state.last_update_time = current_time
+            st.rerun()
+        else:
+            # we just displayed the final case → stop
+            st.session_state.running_demo = False
 
 # --- UI Rendering ---
 # This section renders the main UI. It executes after any potential rerun from the block above.
 
185
  st.rerun()
186
 
 
 
 
 
187
  current_prob_values_from_state = [
188
+ st.session_state.get(f"slider_{j}", 0)
189
+ for j in range(len(options)) # 1.0 / len(options)) for j in range(len(options))
190
  ]
191
  total_from_state = sum(current_prob_values_from_state)
192
  probs_for_charts = (
 
195
  else torch.tensor([v / total_from_state for v in current_prob_values_from_state])
196
  )
197
 
198
+ # Use manual GT token when not in running demo
199
+ gt_choice_for_charts = (
200
+ st.session_state["manual_ground_truth"]
201
+ if not st.session_state.running_demo
202
+ else st.session_state["ground_truth"]
203
+ )
204
  if gt_choice_for_charts == "Text":
205
  gt_index_for_charts = 10 # Assuming "Text" is the 11th item (index 10)
206
  gt_numeric_for_charts = None
 
209
  gt_numeric_for_charts = gt_index_for_charts
210
 
211
  gt = st.session_state["ground_truth"]
212
+ demo_name = st.session_state["demo_name"]
213
 
214
+ st.markdown(f"#### Predicted distributionground truth: {gt}")
215
 df_dist = pd.DataFrame(
     {"token": options, "probability": probs_for_charts.numpy().round(2)}
 )
 
     "Ground Truth" if token == gt_choice_for_charts else "Prediction"
     for token in options
 ]
 
 bars = (
     alt.Chart(df_dist)
+    .mark_bar(color="dodgerblue", size=40)
     .encode(
         x=alt.X(
             "token:N",
             title="Token",
             sort=options,
+            axis=alt.Axis(
+                labelAngle=0,
+                labelFontSize=14,
+                titleFontSize=16,
+                labelAlign="center",
+                labelFlush=False,
+            ),
         ),
         y=alt.Y(
             "probability:Q",
 
             scale=alt.Scale(domain=[0, 1]),
             axis=alt.Axis(format=".2f", labelFontSize=14, titleFontSize=16),
         ),
         tooltip=[
             alt.Tooltip("token:N", title="Token"),
+            alt.Tooltip("probability:Q", title="Predicted Prob.", format=".2f"),
         ],
     )
 )
+
+bg_bar = pd.DataFrame({"token": [gt], "height": [1.0]})
+gt_bar = (
+    alt.Chart(bg_bar)
+    .mark_bar(
+        color="darkgreen",
+        size=20,
+        opacity=0.3,
+        stroke="gray",
+        strokeWidth=2,
+        strokeDash=[4, 4],
+    )
+    .encode(
+        x=alt.X("token:N", sort=options),
+        y=alt.Y("height:Q", scale=alt.Scale(domain=[0, 1])),
+        tooltip=[
+            alt.Tooltip("token:N", title="Ground Truth"),
+            alt.Tooltip("height:Q", title="Desired mass", format=".2f"),
+        ],
+    )
+)
+
 annot1 = (
     alt.Chart(pd.DataFrame({"token": [gt]}))
     .mark_text(
 
         dx=25,
         fontSize=14,
         fontWeight="bold",
+        color="darkgreen",
     )
     .encode(x=alt.X("token:N", sort=options), y=alt.value(1))
 )
 
 annot2 = (
     alt.Chart(pd.DataFrame({"token": [gt]}))
     .mark_text(
 
         dx=35,
         fontSize=14,
         fontWeight="bold",
+        color="darkgreen",
     )
     .encode(x=alt.X("token:N", sort=options), y=alt.value(1))
 )
 
 # 4) Layer them in order: background, bars, annotation
+final_chart = (gt_bar + bars + annot1 + annot2).properties(height=200)
 
 st.altair_chart(final_chart, use_container_width=True)
+ce_val, mae_val, was_val = compute_losses(probs_for_charts, gt_choice_for_charts)
 
 
+if (
+    st.session_state.running_demo
+    and len(st.session_state.loss_history) < st.session_state.demo_step + 1
+):
+    step = st.session_state.demo_step
+    scenario = st.session_state.active_scenarios[step]
+    ce, mae, was = compute_losses(probs_for_charts, gt_choice_for_charts)
 
+    # pick x_val differently for bimodal vs others
+    if st.session_state.demo_name.startswith("Bimodal"):
+        x_val = scenario["name"]  # e.g. "(4,4)", "(3,5)", …
     else:
+        # exactly like before:
+        best_idx = np.argmax(scenario["values"])
+        x_val = options[best_idx]  # "0", "1", …, or "Text"
 
     st.session_state.loss_history.append(
         {
+            "step": step,
+            "x_val": x_val,
+            "Cross Entropy": ce,
+            "NTL-MAE": mae,
+            "NTL-WAS": was,
         }
     )
+
+
+# 1) build a raw DF from histories
+df = pd.DataFrame(st.session_state.loss_history)
+
+if df.empty:
+    # define an empty "melted" DataFrame with the right columns
+    df_loss_plot = pd.DataFrame(columns=["step", "x_val", "Loss Type", "Loss Value"])
+else:
+    # now it's safe to melt
+    df_loss_plot = df.melt(
+        id_vars=["step", "x_val"],
+        value_vars=["Cross Entropy", "NTL-MAE", "NTL-WAS"],
+        var_name="Loss Type",
+        value_name="Loss Value",
+    )
+
 
 loss_data = {"Loss": ["Cross Entropy"], "Value": [ce_val]}
 if was_val != "N/A":
     loss_data["Loss"].append("NTL-WAS")
     loss_data["Value"].append(was_val)
+if mae_val != "N/A":
+    loss_data["Loss"].append("NTL-MAE")
+    loss_data["Value"].append(mae_val)
 
 loss_df = pd.DataFrame(loss_data)
 
+if st.session_state.demo_name.startswith("Bimodal"):
+    domain = [sc["name"] for sc in st.session_state.active_scenarios]
+    x_title = f"Offset from GT {st.session_state['ground_truth']}"
+else:
+    domain = options
+    x_title = f"Maximum of predicted {st.session_state['demo_name']} distribution"
 
 
+# ============== Chart Display ==============
+
 
+st.markdown("#### Loss as a function of predicted distribution")
 
 grouped_chart = (
     alt.Chart(df_loss_plot)
     .mark_bar()
     .encode(
         x=alt.X(
+            "x_val:N",
+            title=x_title,
+            sort=domain,
+            scale=alt.Scale(domain=domain),
+            axis=alt.Axis(labelAngle=0, labelFontSize=14, titleFontSize=16),
         ),
         y=alt.Y(
+            "Loss Value:Q",
+            title="Loss Value",
+            scale=alt.Scale(domain=[0, MAX_LOSS_PLOT], nice=False, clamp=True),
+            axis=alt.Axis(labelFontSize=14, titleFontSize=16),
         ),
+        color=alt.Color(
+            "Loss Type:N",
+            scale=alt.Scale(
+                domain=["Cross Entropy", "NTL-WAS", "NTL-MAE"],
+                range=["red", "limegreen", "blueviolet"],
+            ),
+            legend=alt.Legend(
+                title="",
+                orient="top",
+                direction="horizontal",
+                columns=3,
+            ),
+        ),
+        xOffset="Loss Type:N",  # grouped bars
+        tooltip=[
+            alt.Tooltip("x_val:N", title="Scenario"),
+            alt.Tooltip("Loss Type:N", title="Loss Type"),
+            alt.Tooltip("Loss Value:Q", title="Value", format=".3f"),
+        ],
     )
+    .properties(height=250)
 )
 st.altair_chart(grouped_chart, use_container_width=True)
 
 
 # Create a single chart for loss visualization
+if not st.session_state.running_demo:
+    for i in range(len(options)):
+        st.session_state[f"slider_{i}"] = 0.0
+    st.session_state.demo_step = 0
 
+st.subheader("Demo 2 -- Manual loss comparison")
+st.subheader("🧪 Demo 2 — Craft your own distribution")
+st.markdown("""
+This demo gives you more control but is harder to interpret. See it as a playground! 🎨
+Manually adjust the sliders to change the predicted probabilities for each token.
+The demo normalizes the values to form a valid probability distribution and calculates the losses.
+
+👣 **Steps:**
+- Use the **vertical sliders** to allocate probability to each token.
+- Choose the correct **Ground Truth Token** (0–9 or "Text" 📜).
+- Observe how each loss function reacts.
+
+💡 **Tip:** Want to trick the loss? Try putting all mass on the wrong token or spread it wildly. See how NTL handles it! 😈
+""")
+
+manual_gt = st.selectbox(
+    "Ground Truth Token",
+    options=options,
+    key="manual_ground_truth",
+)
 
+loss_df = pd.DataFrame(
+    {
+        "Loss": ["Cross Entropy", "NTL-MAE", "NTL-WAS"],
+        "Value": [ce_val, mae_val, was_val],
+    }
 )
 
+# Sliders and Ground Truth Selector
+# These widgets will read their initial values from st.session_state.
+# User interactions will update st.session_state directly due to their keys.
+st.markdown("#### Adjust the predicted token probability")
 cols = st.columns(len(options))
 for i, col in enumerate(cols):
     label = options[i]  # Use token name directly for label
 
             max_value=1.0,
             step=0.01,
             height=50,
+            key=f"slider_{i}",
             slider_color="green",
             track_color="lightgray",
             thumb_color="black",
         )
 
+chart = (
+    alt.Chart(loss_df)
+    .mark_bar()
+    .encode(
+        x=alt.X("Loss:N", sort=loss_df["Loss"].tolist()),
+        y=alt.Y(
+            "Value:Q",
+            scale=alt.Scale(
+                domain=[
+                    0,
+                    max(
+                        loss_df["Value"].max() * 1.2,
+                        20 if st.session_state.running_demo else 0.5,
+                    ),
+                ]
+            ),
+        ),
+        color=alt.Color(
+            "Loss:N",
+            scale=alt.Scale(
+                domain=["Cross Entropy", "NTL-WAS", "NTL-MAE"],
+                range=["orangered", "limegreen", "blueviolet"],
+            ),
+        ),
+        tooltip=["Loss", "Value"],
+    )
+    .properties(height=300)
+)
+
+text = chart.mark_text(
+    align="center", baseline="bottom", dy=-5, fontSize=14
+).encode(text=alt.Text("Value:Q", format=".3f"))
+final_chart = chart + text
+st.altair_chart(final_chart, use_container_width=True)
 
+# # Add value labels on top of bars
+# text = chart.mark_text(align="center", baseline="bottom", dy=-5, fontSize=14).encode(
+#     text=alt.Text("Value:Q", format=".3f")
+# )
+
+# # Combine chart and text
+# final_chart = chart + text
 
 # Display chart with the full container width
+# st.altair_chart(final_chart, use_container_width=True)
 
 # --- Polling Rerun for Demo Mode ---
 # If the demo is running and we haven't just advanced (which would have caused a rerun),
 
     time.sleep(0.1)
     st.rerun()
 
+
 st.markdown("""
+### 🤔 TL;DR Why NTL?
+Cross Entropy only cares if the prediction is exactly right or wrong ❌✅ — it doesn’t care *how close* a guess is!
+That’s bad for LLMs doing math and numeric reasoning 🧮.
 
+💥 NTL fixes that: it behaves like a regression loss on the token head, rewarding predictions that are numerically close.
 """)
 
+st.markdown("#### 📚 Further Resources")
 st.markdown("""
+- 📄 [ICML 2025 Paper](https://arxiv.org/abs/2411.02083)
+- 🌐 [NTL Landing Page](https://tum-ai.github.io/number-token-loss/)
+- 💻 [GitHub Code](https://github.com/tum-ai/number-token-loss)
 """)