jannisborn committed on
Commit 2bd2b96 · unverified · 1 Parent(s): b958e0e
Files changed (1)
  1. src/streamlit_app.py +493 -90
src/streamlit_app.py CHANGED
@@ -2,9 +2,35 @@ import altair as alt
 import pandas as pd
 import streamlit_vertical_slider as svs
 import torch
-from streamlit_vertical_slider import vertical_slider
-
 import streamlit as st

 st.title("Number Token Loss - Demo")

@@ -15,109 +41,486 @@ to form a valid probability distribution, visualizes it, and computes the corres
 Cross Entropy, NTL-MSE, and NTL-WAS losses.
 """)

-# Vertical sliders for predicted probabilities of tokens 0-9 and "Text"
-st.markdown("#### Predicted Token Probabilities")
-cols = st.columns(11)
-prob_values = []
-for i, col in enumerate(cols):
-    label = f"Token {i}" if i < 10 else "Text"
-    with col:
-        val = svs.vertical_slider(
-            label=label,
-            min_value=0.0,
-            max_value=1.0,
-            step=0.1,
-            height=50,
-            key=f"slider_{i}",
-            slider_color="green",
-            track_color="lightgray",
-            thumb_color="black",
-        )
-        prob_values.append(val)
-
-# Normalize the probabilities to sum to 1
-total = sum(prob_values)
-probs = (
-    torch.ones(11) / 11.0
-    if total == 0
-    else torch.tensor([v / total for v in prob_values])
 )

-# Token labels
-options = [str(i) for i in range(10)] + ["Text"]

-# Ground truth token selection
-gt_choice = st.selectbox("Ground Truth Token", options=options, index=0)
-if gt_choice == "Text":
-    gt_index = 10
-    gt_numeric = None
 else:
-    gt_index = int(gt_choice)
-    gt_numeric = gt_index

-# Visualize the input distribution with highlighted ground truth bar
 st.markdown("#### Input Probability Distribution")
-df_dist = pd.DataFrame({"token": options, "probability": probs.numpy()})
 chart = (
-    alt.Chart(df_dist)
-    .mark_bar()
-    .encode(
-        x=alt.X("token:N", title="Token"),
         y=alt.Y("probability:Q", title="Probability", scale=alt.Scale(domain=[0, 1])),
-        color=alt.condition(
-            alt.datum.token == gt_choice,
-            alt.value("green"),  # Highlight ground truth token
-            alt.value("steelblue"),  # Other tokens
-        ),
-    )
-    .properties(height=300)
 )
 st.altair_chart(chart, use_container_width=True)

-# Compute Cross Entropy loss: -log(predicted probability of the ground truth)
-ce_loss = -torch.log(torch.clamp(probs[gt_index], min=1e-9))

-# Compute NTL-MSE loss
-if gt_numeric is None:
-    ntl_mse_loss = torch.tensor(0.0)
-else:
-    numeric_probs = probs[:10]
-    values = torch.arange(0, 10, dtype=torch.float32)
-    pred_value = torch.sum(numeric_probs * values)
-    ntl_mse_loss = (pred_value - float(gt_numeric)) ** 2
-
-# Compute NTL-WAS loss
-if gt_numeric is None:
-    ntl_was_loss = torch.tensor(0.0)
-else:
-    numeric_probs = probs[:10]
-    values = torch.arange(0, 10, dtype=torch.float32)
-    abs_diff = torch.abs(values - float(gt_numeric))
-    ntl_was_loss = torch.sum(numeric_probs * abs_diff)

-# Convert losses to Python floats and round to 3 decimals
 ce_val = round(ce_loss.item(), 3)
-mse_val = round(ntl_mse_loss.item(), 3)
-was_val = round(ntl_was_loss.item(), 3)

-# Display numeric values of the losses
-st.subheader("Loss Values")
-st.write(f"**Cross Entropy:** {ce_val:.3f}")
-st.write(f"**NTL-MSE:** {mse_val:.3f}")
-st.write(f"**NTL-WAS:** {was_val:.3f}")

-# Bar chart comparison of the three losses
-st.subheader("Loss Comparison Chart")
-loss_df = pd.DataFrame(
-    {
-        "Loss": ["Cross Entropy", "NTL-MSE", "NTL-WAS"],
-        "Value": [ce_val, mse_val, was_val],
-    }
-).set_index("Loss")
-st.bar_chart(loss_df)

-# References / resources section with links
-st.markdown("### Resources")
-st.markdown(
-    "- **Paper:** [Regress, Don't Guess – A Regression-like Loss on Number Tokens for Language Models](https://arxiv.org/abs/2411.02083) \n- **Code:** [tum-ai/number-token-loss (GitHub)](https://github.com/tum-ai/number-token-loss)"
 )

 import pandas as pd
 import streamlit_vertical_slider as svs
 import torch
+# from streamlit_vertical_slider import vertical_slider  # Not directly used, svs.vertical_slider is
 import streamlit as st
+import time
+import plotly.graph_objects as go  # Add Plotly import
+
+# Define options globally as it's used in initialization and UI
+options = [str(i) for i in range(10)] + ["Text"]
+
+# --- Session State Initialization ---
+# Ensure all session state variables are initialized before first use, especially by widgets.
+if 'running_demo' not in st.session_state:
+    st.session_state.running_demo = False
+if 'demo_step' not in st.session_state:
+    st.session_state.demo_step = 0
+if 'last_update_time' not in st.session_state:
+    st.session_state.last_update_time = 0
+if 'loss_container' not in st.session_state:
+    st.session_state.loss_container = None
+if 'previous_chart_html' not in st.session_state:
+    st.session_state.previous_chart_html = ""
+
+# Initialize states for sliders and ground_truth selector
+# Using len(options) to correctly size for 0-9 + "Text"
+for i in range(len(options)):
+    if f"slider_{i}" not in st.session_state:
+        st.session_state[f"slider_{i}"] = 1.0 / len(options)
+if 'ground_truth' not in st.session_state:
+    st.session_state['ground_truth'] = options[0]  # Default to "0"
+

 st.title("Number Token Loss - Demo")

 Cross Entropy, NTL-MSE, and NTL-WAS losses.
 """)

+# --- Scenario Definitions ---
+scenarios = [
+    {
+        "name": "Probability mass at 0",
+        "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0],  # 11 values
+        "ground_truth": "0",
+        "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
+    },
+    {
+        "name": "Probability mass at 0",
+        "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0],  # 11 values
+        "ground_truth": "1",
+        "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
+    },
+    {
+        "name": "Probability mass at 0",
+        "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0],  # 11 values
+        "ground_truth": "2",
+        "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
+    },
+    {
+        "name": "Probability mass at 0",
+        "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0],  # 11 values
+        "ground_truth": "3",
+        "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
+    },
+    {
+        "name": "Probability mass at 0",
+        "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0],  # 11 values
+        "ground_truth": "4",
+        "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
+    },
+    {
+        "name": "Probability mass at 0",
+        "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0],  # 11 values
+        "ground_truth": "5",
+        "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
+    },
+    {
+        "name": "Probability mass at 0",
+        "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0],  # 11 values
+        "ground_truth": "6",
+        "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
+    },
+    {
+        "name": "Probability mass at 0",
+        "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0],  # 11 values
+        "ground_truth": "7",
+        "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
+    },
+    {
+        "name": "Probability mass at 0",
+        "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0],  # 11 values
+        "ground_truth": "8",
+        "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
+    },
+    {
+        "name": "Probability mass at 0",
+        "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0],  # 11 values
+        "ground_truth": "9",
+        "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
+    },
+
+
+    {
+        "name": "Probability mass around 5",
+        "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0],  # 11 values
+        "ground_truth": "0",
+        "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
+    },
+    {
+        "name": "Probability mass around 5",
+        "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0],  # 11 values
+        "ground_truth": "1",
+        "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
+    },
+    {
+        "name": "Probability mass around 5",
+        "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0],  # 11 values
+        "ground_truth": "2",
+        "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
+    },
+    {
+        "name": "Probability mass around 5",
+        "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0],  # 11 values
+        "ground_truth": "3",
+        "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
+    },
+    {
+        "name": "Probability mass around 5",
+        "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0],  # 11 values
+        "ground_truth": "4",
+        "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
+    },
+    {
+        "name": "Probability mass around ground truth (5)",
+        "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0],  # 11 values
+        "ground_truth": "5",
+        "explanation": "Cross Entropy is moderate, NTL is low because predictions are close to ground truth."
+    },
+    {
+        "name": "Probability mass around 5",
+        "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0],  # 11 values
+        "ground_truth": "6",
+        "explanation": "Cross Entropy is moderate, NTL is low because predictions are close to ground truth."
+    },
+    {
+        "name": "Probability mass around 5",
+        "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0],  # 11 values
+        "ground_truth": "7",
+        "explanation": "Cross Entropy is moderate, NTL is low because predictions are close to ground truth."
+    },
+    {
+        "name": "Probability mass around 5",
+        "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0],  # 11 values
+        "ground_truth": "8",
+        "explanation": "Cross Entropy is high, NTL is higher but still penalizes less than CE because distribution knows it's a number."
+    },
+    {
+        "name": "Probability mass around 5",
+        "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0],  # 11 values
+        "ground_truth": "9",
+        "explanation": "Cross Entropy is moderate, NTL is low because predictions are close to ground truth."
+    },
+
+    {
+        "name": "Probability mass concentrated on 5",
+        "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0],  # 11 values
+        "ground_truth": "0",
+        "explanation": "Both CE and NTL are high because the prediction is far from correct."
+    },
+    {
+        "name": "Probability mass concentrated on 5",
+        "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0],  # 11 values
+        "ground_truth": "1",
+        "explanation": "Both CE and NTL are high because the prediction is far from correct."
+    },
+    {
+        "name": "Probability mass concentrated on 5",
+        "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0],  # 11 values
+        "ground_truth": "2",
+        "explanation": "Both CE and NTL are high because the prediction is far from correct."
+    },
+    {
+        "name": "Probability mass concentrated on 5",
+        "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0],  # 11 values
+        "ground_truth": "3",
+        "explanation": "Both CE and NTL are high because the prediction is far from correct."
+    },
+    {
+        "name": "Probability mass concentrated on 5",
+        "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0],  # 11 values
+        "ground_truth": "4",
+        "explanation": "Both CE and NTL are high because the prediction is far from correct."
+    },
+    {
+        "name": "Probability mass concentrated on 5",
+        "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0],  # 11 values
+        "ground_truth": "5",
+        "explanation": "Both CE and NTL are high because the prediction is far from correct."
+    },
+    {
+        "name": "Probability mass concentrated on 5",
+        "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0],  # 11 values
+        "ground_truth": "6",
+        "explanation": "Both CE and NTL are high because the prediction is far from correct."
+    },
+    {
+        "name": "Probability mass concentrated on 5",
+        "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0],  # 11 values
+        "ground_truth": "7",
+        "explanation": "Both CE and NTL are high because the prediction is far from correct."
+    },
+    {
+        "name": "Probability mass concentrated on 5",
+        "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0],  # 11 values
+        "ground_truth": "8",
+        "explanation": "Both CE and NTL are high because the prediction is far from correct."
+    },
+    {
+        "name": "Probability mass concentrated on 5",
+        "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0],  # 11 values
+        "ground_truth": "9",
+        "explanation": "Both CE and NTL are high because the prediction is far from correct."
+    },
+
+
+    {
+        "name": "Probability mass concentrated on 1",
+        "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0],  # 11 values
+        "ground_truth": "0",
+        "explanation": "Both losses are low because the prediction is correct."
+    },
+    {
+        "name": "Probability mass concentrated on 1",
+        "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0],  # 11 values
+        "ground_truth": "1",
+        "explanation": "Both losses are low because the prediction is correct."
+    },
+    {
+        "name": "Probability mass concentrated on 1",
+        "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0],  # 11 values
+        "ground_truth": "2",
+        "explanation": "Both losses are low because the prediction is correct."
+    },
+    {
+        "name": "Probability mass concentrated on 1",
+        "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0],  # 11 values
+        "ground_truth": "3",
+        "explanation": "Both losses are low because the prediction is correct."
+    },
+    {
+        "name": "Probability mass concentrated on 1",
+        "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0],  # 11 values
+        "ground_truth": "4",
+        "explanation": "Both losses are low because the prediction is correct."
+    },
+    {
+        "name": "Probability mass concentrated on 1",
+        "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0],  # 11 values
+        "ground_truth": "5",
+        "explanation": "Both losses are low because the prediction is correct."
+    },
+    {
+        "name": "Probability mass concentrated on 1",
+        "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0],  # 11 values
+        "ground_truth": "6",
+        "explanation": "Both losses are low because the prediction is correct."
+    },
+    {
+        "name": "Probability mass concentrated on 1",
+        "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0],  # 11 values
+        "ground_truth": "7",
+        "explanation": "Both losses are low because the prediction is correct."
+    },
+    {
+        "name": "Probability mass concentrated on 1",
+        "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0],  # 11 values
+        "ground_truth": "8",
+        "explanation": "Both losses are low because the prediction is correct."
+    },
+    {
+        "name": "Probability mass concentrated on 1",
+        "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0],  # 11 values
+        "ground_truth": "9",
+        "explanation": "Both losses are low because the prediction is correct."
+    },
+
+
+    {
+        "name": "Almost correct (1 vs 2)",
+        "values": [0.1, 0.1, 0.7, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],  # 11 values
+        "ground_truth": "0",
+        "explanation": "CE penalizes harshly, but NTL-WAS remains low because prediction is numerically close."
+    },
+    {
+        "name": "Almost correct (1 vs 2)",
+        "values": [0.1, 0.1, 0.7, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],  # 11 values
+        "ground_truth": "1",
+        "explanation": "CE penalizes harshly, but NTL-WAS remains low because prediction is numerically close."
+    },
+    {
+        "name": "Almost correct (1 vs 2)",
+        "values": [0.1, 0.1, 0.7, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],  # 11 values
+        "ground_truth": "2",
+        "explanation": "CE penalizes harshly, but NTL-WAS remains low because prediction is numerically close."
+    },
+    {
+        "name": "Almost correct (1 vs 2)",
+        "values": [0.1, 0.1, 0.7, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],  # 11 values
+        "ground_truth": "3",
+        "explanation": "CE penalizes harshly, but NTL-WAS remains low because prediction is numerically close."
+    }
+]
+
+# --- Helper Functions ---
+def apply_scenario(step_idx):
+    scenario = scenarios[step_idx]
+    # These assignments modify session state. They must be done *before* the widgets
+    # are rendered in the script run that should display these new values.
+    for i, val in enumerate(scenario["values"]):
+        st.session_state[f"slider_{i}"] = val
+    st.session_state['ground_truth'] = scenario["ground_truth"]
+
+def start_demo():
+    st.session_state.running_demo = True
+    st.session_state.demo_step = 0
+    st.session_state.last_update_time = time.time()
+    apply_scenario(0)  # Apply the first scenario's state
+    # The button click that calls start_demo() will itself cause a rerun.
+
+def stop_demo():
+    st.session_state.running_demo = False
+
+# --- Demo State Advancement Logic ---
+# This block handles advancing the demo. If it advances, it updates session state
+# and then reruns. This ensures widgets are drawn with the new state in the next run.
+if st.session_state.running_demo:
+    current_time = time.time()
+    if current_time - st.session_state.last_update_time > 3.0:  # 3 seconds per scenario
+        next_step = (st.session_state.demo_step + 1) % len(scenarios)
+        st.session_state.demo_step = next_step
+        apply_scenario(next_step)  # Update session state for the new scenario
+        st.session_state.last_update_time = time.time()  # Reset timer
+        st.rerun()  # Crucial: Rerun to reflect changes in widgets and charts
+
+# --- UI Rendering ---
+# This section renders the main UI. It executes after any potential rerun from the block above.
+
+if st.session_state.running_demo:
+    st.info(f"Showing scenario {st.session_state.demo_step + 1}/{len(scenarios)}: {scenarios[st.session_state.demo_step]['name']}")
+    st.markdown(f"**Explanation:** {scenarios[st.session_state.demo_step]['explanation']}")
+    if st.button("Stop Demo"):
+        stop_demo()
+        st.rerun()
+else:  # Not st.session_state.running_demo
+    if st.button("Start Automated Demo"):
+        start_demo()  # This calls apply_scenario(0)
+        st.rerun()  # Rerun to enter demo mode and draw scenario 0 correctly
+
+# Sliders and Ground Truth Selector
+# These widgets will read their initial values from st.session_state.
+# User interactions will update st.session_state directly due to their keys.
+if not st.session_state.running_demo:
+    st.markdown("#### Predicted Token Probabilities")
+    cols = st.columns(len(options))
+    for i, col in enumerate(cols):
+        label = options[i]  # Use token name directly for label
+        with col:
+            svs.vertical_slider(
+                label=label, min_value=0.0, max_value=1.0, step=0.01, height=50,
+                key=f"slider_{i}",  # This key links the widget to st.session_state[f"slider_{i}"]
+                slider_color="green", track_color="lightgray", thumb_color="black"
+            )
+
+    # Ground truth selectbox
+    st.selectbox(
+        "Ground Truth Token", options=options,
+        index=options.index(st.session_state['ground_truth']),  # Display value from session state
+        key='ground_truth'  # Links widget to st.session_state['ground_truth']
     )

+# Placeholder for charts and loss calculations that will be updated
+# This section always reads the current st.session_state to generate its content.
+
+current_prob_values_from_state = [st.session_state.get(f"slider_{j}", 1.0/len(options)) for j in range(len(options))]
+total_from_state = sum(current_prob_values_from_state)
+probs_for_charts = (
+    torch.ones(len(options)) / len(options)
+    if total_from_state == 0
+    else torch.tensor([v / total_from_state for v in current_prob_values_from_state])
+)

+gt_choice_for_charts = st.session_state.get('ground_truth', options[0])
+if gt_choice_for_charts == "Text":
+    gt_index_for_charts = 10  # Assuming "Text" is the 11th item (index 10)
+    gt_numeric_for_charts = None
 else:
+    gt_index_for_charts = int(gt_choice_for_charts)
+    gt_numeric_for_charts = gt_index_for_charts

 st.markdown("#### Input Probability Distribution")
+df_dist = pd.DataFrame({"token": options, "probability": probs_for_charts.numpy()})
+df_dist["type"] = ["Ground Truth" if token == gt_choice_for_charts else "Prediction" for token in options]
 chart = (
+    alt.Chart(df_dist).mark_bar().encode(
+        x=alt.X("token:N", title="Token", sort=options),  # Ensure consistent sort order
         y=alt.Y("probability:Q", title="Probability", scale=alt.Scale(domain=[0, 1])),
+        color=alt.Color("type:N", scale=alt.Scale(domain=["Ground Truth", "Prediction"], range=["green", "steelblue"]), legend=alt.Legend(title="Token Type"))
+    ).properties(height=300)
 )
 st.altair_chart(chart, use_container_width=True)

+ce_loss = -torch.log(torch.clamp(probs_for_charts[gt_index_for_charts], min=1e-9))
+if gt_numeric_for_charts is None:  # Text token
+    ntl_mse_loss = torch.tensor(float('nan'))  # MSE not applicable for text
+    ntl_was_loss = torch.tensor(float('nan'))  # WAS not applicable for text
+else:  # Numeric token
+    numeric_probs_for_loss = probs_for_charts[:10]  # Probabilities for 0-9
+    # Ensure numeric_probs_for_loss sums to 1 for NTL calculations if it's a subset
+    numeric_probs_sum = torch.sum(numeric_probs_for_loss)
+    if numeric_probs_sum > 1e-6:  # Avoid division by zero
+        normalized_numeric_probs = numeric_probs_for_loss / numeric_probs_sum
+    else:
+        normalized_numeric_probs = torch.zeros_like(numeric_probs_for_loss)
+
+    loss_values_tensor = torch.arange(0, 10, dtype=torch.float32)
+
+    # Use normalized probabilities for NTL if only considering numeric tokens
+    if gt_choice_for_charts != "Text" and torch.sum(probs_for_charts[:10]) > 1e-6:
+        pred_value = torch.sum((probs_for_charts[:10] / torch.sum(probs_for_charts[:10])) * loss_values_tensor)
+    elif gt_choice_for_charts != "Text":  # if sum is zero, pred_value is ill-defined or 0
+        pred_value = torch.tensor(0.0)
+    else:  # Should not happen if gt_numeric_for_charts is not None
+        pred_value = torch.tensor(float('nan'))
+
+    if not torch.isnan(pred_value):
+        ntl_mse_loss = (pred_value - float(gt_numeric_for_charts)) ** 2
+        abs_diff = torch.abs(loss_values_tensor - float(gt_numeric_for_charts))
+        if gt_choice_for_charts != "Text" and torch.sum(probs_for_charts[:10]) > 1e-6:
+            ntl_was_loss = torch.sum((probs_for_charts[:10] / torch.sum(probs_for_charts[:10])) * abs_diff)
+        elif gt_choice_for_charts != "Text":
+            ntl_was_loss = torch.tensor(0.0)  # Or some other default if all numeric probs are zero
+        else:
+            ntl_was_loss = torch.tensor(float('nan'))
+    else:
+        ntl_mse_loss = torch.tensor(float('nan'))
+        ntl_was_loss = torch.tensor(float('nan'))
 
 ce_val = round(ce_loss.item(), 3)
+mse_val = round(ntl_mse_loss.item(), 3) if not torch.isnan(ntl_mse_loss) else "N/A"
+was_val = round(ntl_was_loss.item(), 3) if not torch.isnan(ntl_was_loss) else "N/A"

+loss_data = {"Loss": ["Cross Entropy"], "Value": [ce_val]}
+if was_val != "N/A":
+    loss_data["Loss"].append("NTL-WAS")
+    loss_data["Value"].append(was_val)
+if mse_val != "N/A":
+    loss_data["Loss"].append("NTL-MSE")
+    loss_data["Value"].append(mse_val)

+loss_df = pd.DataFrame(loss_data)
+
+# ============== Chart Display ==============
+# Create a single chart for loss visualization
+st.subheader("Loss Comparison")
+
+# Create an Altair chart that will look good and redraw cleanly
+chart = alt.Chart(loss_df).mark_bar().encode(
+    x=alt.X('Loss:N', sort=loss_df["Loss"].tolist()),
+    y=alt.Y('Value:Q', scale=alt.Scale(domain=[0, max(loss_df["Value"].max() * 1.2, 20 if st.session_state.running_demo else 0.5)])),
+    color=alt.Color('Loss:N', scale=alt.Scale(
+        domain=['Cross Entropy', 'NTL-WAS', 'NTL-MSE'],
+        range=['steelblue', 'red', 'forestgreen']
+    )),
+    tooltip=['Loss', 'Value']
+).properties(
+    height=300
 )
+
+# Add value labels on top of bars
+text = chart.mark_text(
+    align='center',
+    baseline='bottom',
+    dy=-5,
+    fontSize=14
+).encode(
+    text=alt.Text('Value:Q', format='.3f')
+)
+
+# Combine chart and text
+final_chart = (chart + text)
+
+# Display chart with the full container width
+st.altair_chart(final_chart, use_container_width=True)
+
+# --- Polling Rerun for Demo Mode ---
+# If the demo is running and we haven't just advanced (which would have caused a rerun),
+# then we do a short sleep and rerun to keep the polling loop alive.
+if st.session_state.running_demo:
+    # This check is implicitly: if we are here and demo is running, it means
+    # the time-based advance condition was NOT met in the block at the top.
+    time.sleep(0.1)  # Adjusted from 0.2 to 0.5 (or try 1.0)
+    st.rerun()
+
+# Add explanation of the demonstration
+st.markdown("""
+### What Does This Demo Show?
+
+- **Cross Entropy Loss**: Only cares if the prediction is exactly right or wrong - it doesn't consider how "close" a numerical prediction is.
+- **Number Token Loss (NTL)**: Considers numerical proximity - predicting "7" when the true value is "8" is better than predicting "2".
+""")
+
+# References / resources section with links (common to both modes)
+st.markdown("### Resources")
+st.markdown("""
+- [Paper: Number Token Loss (ArXiv)](https://arxiv.org/abs/2411.02083)
+- [GitHub: Number Token Loss](https://github.com/tum-ai/number-token-loss)
+""")