Spaces:
Running
Running
File size: 3,918 Bytes
e448593 d830963 e448593 d830963 e448593 d830963 e448593 d830963 e448593 d830963 e448593 d830963 e448593 d830963 e448593 d830963 e448593 d830963 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import altair as alt
import pandas as pd
import streamlit as st
import streamlit_vertical_slider as svs
import torch
from streamlit_vertical_slider import vertical_slider
# Page header and usage instructions for the demo.
st.title("Number Token Loss - Demo")
st.markdown("""
Adjust the sliders to set a predicted probability for each token (0-9 and "Text").
The sliders are vertical and compact. The app normalizes the slider values
to form a valid probability distribution, visualizes it, and computes the corresponding
Cross Entropy, NTL-MSE, and NTL-WAS losses.
""")
# Vertical sliders for predicted probabilities of tokens 0-9 and "Text".
# One column per token so the sliders sit side by side.
st.markdown("#### Predicted Token Probabilities")
cols = st.columns(11)
prob_values = []
for i, col in enumerate(cols):
    # Sliders 0-9 are the digit tokens; the 11th slider is the "Text" token.
    label = f"Token {i}" if i < 10 else "Text"
    with col:
        val = svs.vertical_slider(
            label=label,
            min_value=0.0,
            max_value=1.0,
            step=0.1,
            height=50,
            key=f"slider_{i}",
            slider_color="green",
            track_color="lightgray",
            thumb_color="black",
        )
    # NOTE(review): some versions of streamlit_vertical_slider return None
    # until the user first moves a slider — coerce to 0.0 so the later
    # sum()/normalization cannot fail on a non-numeric value. Confirm against
    # the pinned component version.
    prob_values.append(val if val is not None else 0.0)
# Turn the raw slider values into a valid probability distribution.
# If every slider sits at zero, fall back to a uniform distribution
# over all 11 tokens instead of dividing by zero.
total = sum(prob_values)
if total == 0:
    probs = torch.ones(11) / 11.0
else:
    probs = torch.tensor([v / total for v in prob_values])
# Token labels: digit tokens "0"-"9" followed by the aggregate "Text" token.
options = [str(i) for i in range(10)] + ["Text"]
# Ground truth token picked by the user; defaults to token "0".
gt_choice = st.selectbox("Ground Truth Token", options=options, index=0)
# gt_index addresses the 11-way distribution ("Text" maps to 10);
# gt_numeric is the digit's value, or None when the target is "Text".
gt_index = options.index(gt_choice)
gt_numeric = gt_index if gt_index < 10 else None
# Bar chart of the normalized input distribution; the ground-truth token's
# bar is drawn in green, every other token in steelblue.
st.markdown("#### Input Probability Distribution")
df_dist = pd.DataFrame({"token": options, "probability": probs.numpy()})
x_enc = alt.X("token:N", title="Token")
y_enc = alt.Y("probability:Q", title="Probability", scale=alt.Scale(domain=[0, 1]))
bar_color = alt.condition(
    alt.datum.token == gt_choice,
    alt.value("green"),  # Highlight ground truth token
    alt.value("steelblue"),  # Other tokens
)
chart = alt.Chart(df_dist).mark_bar().encode(x=x_enc, y=y_enc, color=bar_color)
chart = chart.properties(height=300)
st.altair_chart(chart, use_container_width=True)
# Cross Entropy: negative log-likelihood of the ground-truth token.
# The clamp guards against log(0) when that token has zero probability.
gt_prob = torch.clamp(probs[gt_index], min=1e-9)
ce_loss = -torch.log(gt_prob)
# Compute NTL-MSE and NTL-WAS losses over the digit tokens 0-9.
# Both losses need the same digit-probability slice and digit-value vector,
# and both are defined only for numeric ground truths, so compute them in a
# single branch instead of duplicating the setup (the original repeated
# numeric_probs/values and the gt_numeric check for each loss).
if gt_numeric is None:
    # "Text" target: no numeric value to regress against, report 0.
    ntl_mse_loss = torch.tensor(0.0)
    ntl_was_loss = torch.tensor(0.0)
else:
    numeric_probs = probs[:10]  # probabilities assigned to digits 0-9
    values = torch.arange(0, 10, dtype=torch.float32)
    # NTL-MSE: squared error between the distribution's expected numeric
    # value and the ground-truth digit.
    pred_value = torch.sum(numeric_probs * values)
    ntl_mse_loss = (pred_value - float(gt_numeric)) ** 2
    # NTL-WAS: expected absolute distance between predicted digit mass and
    # the ground-truth digit (Wasserstein-style).
    abs_diff = torch.abs(values - float(gt_numeric))
    ntl_was_loss = torch.sum(numeric_probs * abs_diff)
# Convert each loss tensor to a Python float rounded to 3 decimals,
# then show the three values.
ce_val, mse_val, was_val = (
    round(loss.item(), 3) for loss in (ce_loss, ntl_mse_loss, ntl_was_loss)
)
st.subheader("Loss Values")
st.write(f"**Cross Entropy:** {ce_val:.3f}")
st.write(f"**NTL-MSE:** {mse_val:.3f}")
st.write(f"**NTL-WAS:** {was_val:.3f}")
# Side-by-side bar chart comparing the three loss values.
st.subheader("Loss Comparison Chart")
loss_rows = {
    "Loss": ["Cross Entropy", "NTL-MSE", "NTL-WAS"],
    "Value": [ce_val, mse_val, was_val],
}
loss_df = pd.DataFrame(loss_rows).set_index("Loss")
st.bar_chart(loss_df)
# References / resources section linking the NTL paper and its code release.
st.markdown("### Resources")
st.markdown(
    "- **Paper:** [Regress, Don't Guess – A Regression-like Loss on Number Tokens for Language Models](https://arxiv.org/abs/2411.02083) \n- **Code:** [tum-ai/number-token-loss (GitHub)](https://github.com/tum-ai/number-token-loss)"
)
|