import altair as alt
import pandas as pd
import streamlit as st
import streamlit_vertical_slider as svs
import torch
st.title("Number Token Loss - Demo")
st.markdown("""
Adjust the sliders to set a predicted probability for each token (0-9 and "Text").
The sliders are vertical and compact. The app normalizes the slider values
to form a valid probability distribution, visualizes it, and computes the corresponding
Cross Entropy, NTL-MSE, and NTL-WAS losses.
""")
# Vertical sliders for predicted probabilities of tokens 0-9 and "Text"
st.markdown("#### Predicted Token Probabilities")
cols = st.columns(11)
prob_values = []
for i, col in enumerate(cols):
    label = f"Token {i}" if i < 10 else "Text"
    with col:
        val = svs.vertical_slider(
            label=label,
            min_value=0.0,
            max_value=1.0,
            step=0.1,
            height=50,
            key=f"slider_{i}",
            slider_color="green",
            track_color="lightgray",
            thumb_color="black",
        )
    # Guard against an unset slider returning None before the first interaction
    prob_values.append(val if val is not None else 0.0)
# Normalize the probabilities to sum to 1
total = sum(prob_values)
probs = (
    torch.ones(11) / 11.0
    if total == 0
    else torch.tensor([v / total for v in prob_values])
)
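# Worked example (illustrative, not part of the original app): raw slider values of
# 1.0 on token 2, 1.0 on token 4, and 0.0 elsewhere sum to 2.0, so the normalized
# distribution puts 0.5 on each of tokens 2 and 4 and 0 everywhere else.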
# Token labels
options = [str(i) for i in range(10)] + ["Text"]
# Ground truth token selection
gt_choice = st.selectbox("Ground Truth Token", options=options, index=0)
if gt_choice == "Text":
    gt_index = 10
    gt_numeric = None
else:
    gt_index = int(gt_choice)
    gt_numeric = gt_index
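# Explanatory note (added): when the ground truth is the non-numeric "Text" token,
# gt_numeric is None and the NTL losses computed below are reported as 0, since the
# Number Token Loss only applies to number tokens.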
# Visualize the input distribution with highlighted ground truth bar
st.markdown("#### Input Probability Distribution")
df_dist = pd.DataFrame({"token": options, "probability": probs.numpy()})
chart = (
    alt.Chart(df_dist)
    .mark_bar()
    .encode(
        x=alt.X("token:N", title="Token"),
        y=alt.Y("probability:Q", title="Probability", scale=alt.Scale(domain=[0, 1])),
        color=alt.condition(
            alt.datum.token == gt_choice,
            alt.value("green"),  # Highlight ground truth token
            alt.value("steelblue"),  # Other tokens
        ),
    )
    .properties(height=300)
)
st.altair_chart(chart, use_container_width=True)
# Compute Cross Entropy loss: -log(predicted probability of the ground truth)
ce_loss = -torch.log(torch.clamp(probs[gt_index], min=1e-9))
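# Illustrative example (not from the original app): if the ground truth is token 3 and
# probs[3] == 0.5, then CE = -log(0.5) ≈ 0.693; CE drops to 0 only when probs[3] == 1.0.
# The clamp avoids -log(0) = inf when the selected token has zero probability.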
# Compute NTL-MSE loss
if gt_numeric is None:
    ntl_mse_loss = torch.tensor(0.0)
else:
    numeric_probs = probs[:10]
    values = torch.arange(0, 10, dtype=torch.float32)
    pred_value = torch.sum(numeric_probs * values)
    ntl_mse_loss = (pred_value - float(gt_numeric)) ** 2
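# NTL-MSE takes the probability-weighted mean of the digit values 0-9 (the "expected"
# predicted number) and penalizes its squared distance from the ground truth.
# Illustrative example (not from the original app): with probability 0.5 on token 2,
# 0.5 on token 4, and ground truth 3, pred_value = 0.5*2 + 0.5*4 = 3.0, so
# NTL-MSE = 0 even though no probability mass sits on the correct token.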
# Compute NTL-WAS loss
if gt_numeric is None:
    ntl_was_loss = torch.tensor(0.0)
else:
    numeric_probs = probs[:10]
    values = torch.arange(0, 10, dtype=torch.float32)
    abs_diff = torch.abs(values - float(gt_numeric))
    ntl_was_loss = torch.sum(numeric_probs * abs_diff)
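# NTL-WAS is the expected absolute distance between the predicted digit distribution
# and the ground-truth value, i.e. a Wasserstein-style distance to the one-hot target.
# Illustrative example (not from the original app): with probability 0.5 on token 2,
# 0.5 on token 4, and ground truth 3, NTL-WAS = 0.5*|2-3| + 0.5*|4-3| = 1.0,
# whereas NTL-MSE is 0 for the same distribution.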
# Convert losses to Python floats and round to 3 decimals
ce_val = round(ce_loss.item(), 3)
mse_val = round(ntl_mse_loss.item(), 3)
was_val = round(ntl_was_loss.item(), 3)
# Display numeric values of the losses
st.subheader("Loss Values")
st.write(f"**Cross Entropy:** {ce_val:.3f}")
st.write(f"**NTL-MSE:** {mse_val:.3f}")
st.write(f"**NTL-WAS:** {was_val:.3f}")
# Bar chart comparison of the three losses
st.subheader("Loss Comparison Chart")
loss_df = pd.DataFrame(
    {
        "Loss": ["Cross Entropy", "NTL-MSE", "NTL-WAS"],
        "Value": [ce_val, mse_val, was_val],
    }
).set_index("Loss")
st.bar_chart(loss_df)
# References / resources section with links
st.markdown("### Resources")
st.markdown(
    "- **Paper:** [Regress, Don't Guess – A Regression-like Loss on Number Tokens for Language Models](https://arxiv.org/abs/2411.02083) \n"
    "- **Code:** [tum-ai/number-token-loss (GitHub)](https://github.com/tum-ai/number-token-loss)"
)
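# To run locally (assuming this file is saved as app.py and the dependencies are
# installed, e.g. `pip install streamlit streamlit-vertical-slider torch pandas altair`):
#     streamlit run app.py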