File size: 3,918 Bytes
e448593
 
 
d830963
 
 
e448593
d830963
e448593
d830963
 
 
 
 
 
e448593
d830963
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e448593
d830963
 
 
 
 
 
 
e448593
d830963
 
e448593
d830963
 
 
 
 
 
 
 
e448593
d830963
 
 
 
 
 
e448593
d830963
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import altair as alt
import pandas as pd
import streamlit as st
import streamlit_vertical_slider as svs
import torch
from streamlit_vertical_slider import vertical_slider

# Page header and usage notes shown at the top of the demo.
st.title("Number Token Loss - Demo")

INTRO_TEXT = """
Adjust the sliders to set a predicted probability for each token (0-9 and "Text"). 
The sliders are vertical and compact. The app normalizes the slider values 
to form a valid probability distribution, visualizes it, and computes the corresponding 
Cross Entropy, NTL-MSE, and NTL-WAS losses.
"""
st.markdown(INTRO_TEXT)

# Vertical sliders for predicted probabilities of tokens 0-9 and "Text".
# One column per token; each slider contributes one raw (unnormalized) value.
st.markdown("#### Predicted Token Probabilities")

prob_values = []
for idx, column in enumerate(st.columns(11)):
    # Columns 0-9 are digit tokens; the last column is the "Text" token.
    slider_label = "Text" if idx == 10 else f"Token {idx}"
    with column:
        slider_val = svs.vertical_slider(
            label=slider_label,
            min_value=0.0,
            max_value=1.0,
            step=0.1,
            height=50,
            key=f"slider_{idx}",
            slider_color="green",
            track_color="lightgray",
            thumb_color="black",
        )
        prob_values.append(slider_val)

# Normalize the raw slider values into a valid probability distribution.
total = sum(prob_values)
if total == 0:
    # Every slider is at zero: fall back to a uniform distribution over
    # the 11 tokens instead of dividing by zero.
    probs = torch.ones(11) / 11.0
else:
    probs = torch.tensor([v / total for v in prob_values])

# Token labels: digits 0-9 plus a catch-all "Text" token.
options = [*map(str, range(10)), "Text"]

# Ground-truth token selection; map the choice to a tensor index and,
# for digit tokens, a numeric value used by the number-token losses.
gt_choice = st.selectbox("Ground Truth Token", options=options, index=0)
if gt_choice != "Text":
    gt_index = int(gt_choice)
    gt_numeric = gt_index
else:
    gt_index = 10
    gt_numeric = None  # "Text" carries no numeric value

# Visualize the normalized distribution, highlighting the ground-truth bar.
st.markdown("#### Input Probability Distribution")
dist_frame = pd.DataFrame({"token": options, "probability": probs.numpy()})

# Color rule: the selected ground-truth token is green, the rest steelblue.
bar_color = alt.condition(
    alt.datum.token == gt_choice,
    alt.value("green"),
    alt.value("steelblue"),
)

dist_chart = (
    alt.Chart(dist_frame)
    .mark_bar()
    .encode(
        x=alt.X("token:N", title="Token"),
        y=alt.Y("probability:Q", title="Probability", scale=alt.Scale(domain=[0, 1])),
        color=bar_color,
    )
    .properties(height=300)
)
st.altair_chart(dist_chart, use_container_width=True)

# --- Loss computations -------------------------------------------------------

# Cross Entropy: -log p(ground-truth token). The clamp guards against
# log(0) = -inf when the selected token has zero predicted probability.
ce_loss = -torch.log(torch.clamp(probs[gt_index], min=1e-9))

if gt_numeric is None:
    # The "Text" token has no numeric value, so the number-token losses
    # do not apply; report 0 for both.
    ntl_mse_loss = torch.tensor(0.0)
    ntl_was_loss = torch.tensor(0.0)
else:
    # Shared pieces for both number-token losses (previously duplicated
    # across two identical conditionals): the digit-token probabilities
    # and the digit values 0..9.
    numeric_probs = probs[:10]
    values = torch.arange(0, 10, dtype=torch.float32)

    # NTL-MSE: squared error between the expected predicted value
    # (probability-weighted sum of the digits) and the ground-truth digit.
    pred_value = torch.sum(numeric_probs * values)
    ntl_mse_loss = (pred_value - float(gt_numeric)) ** 2

    # NTL-WAS: Wasserstein-style loss — probability-weighted absolute
    # distance between each digit and the ground-truth digit.
    abs_diff = torch.abs(values - float(gt_numeric))
    ntl_was_loss = torch.sum(numeric_probs * abs_diff)

# Convert each loss tensor to a Python float rounded to 3 decimals.
ce_val, mse_val, was_val = (
    round(loss.item(), 3) for loss in (ce_loss, ntl_mse_loss, ntl_was_loss)
)

# Display numeric values of the losses.
st.subheader("Loss Values")
st.write(f"**Cross Entropy:** {ce_val:.3f}")
st.write(f"**NTL-MSE:** {mse_val:.3f}")
st.write(f"**NTL-WAS:** {was_val:.3f}")

# Bar chart comparing the three loss values side by side.
st.subheader("Loss Comparison Chart")
comparison = pd.DataFrame(
    {"Value": [ce_val, mse_val, was_val]},
    index=pd.Index(["Cross Entropy", "NTL-MSE", "NTL-WAS"], name="Loss"),
)
st.bar_chart(comparison)

# References / resources section linking the paper and the code repository.
st.markdown("### Resources")
st.markdown(
    "- **Paper:** [Regress, Don't Guess – A Regression-like Loss on Number Tokens for Language Models](https://arxiv.org/abs/2411.02083)  \n"
    "- **Code:** [tum-ai/number-token-loss (GitHub)](https://github.com/tum-ai/number-token-loss)"
)