"""Streamlit demo: compare Cross Entropy with the NTL-MSE and NTL-WAS losses.

The page shows eleven vertical sliders (digits 0-9 plus a "Text" token),
normalizes their values into a probability distribution, plots it, and
reports the three losses for a user-selected ground-truth token.
"""

import altair as alt
import pandas as pd
import streamlit as st
import streamlit_vertical_slider as svs
import torch
from streamlit_vertical_slider import vertical_slider

# Vocabulary layout used throughout the script: indices 0-9 are the digit
# tokens, index 10 is the single non-numeric "Text" token.
NUM_DIGITS = 10
TEXT_INDEX = NUM_DIGITS
NUM_TOKENS = NUM_DIGITS + 1


def _normalize(weights):
    """Convert raw slider weights into a probability tensor.

    Args:
        weights: Sequence of non-negative numbers, one per token.

    Returns:
        A 1-D tensor with ``len(weights)`` entries. When every weight is zero
        there is nothing to divide by, so a uniform distribution is returned.
    """
    weight_sum = sum(weights)
    if weight_sum == 0:
        return torch.ones(len(weights)) / float(len(weights))
    return torch.tensor([w / weight_sum for w in weights])


def _cross_entropy(probs, target_index):
    """Return ``-log(probs[target_index])`` as a 0-d tensor.

    The probability is clamped to at least 1e-9 so a zero probability gives a
    large finite loss instead of ``inf``.
    """
    return -torch.log(torch.clamp(probs[target_index], min=1e-9))


def _ntl_losses(probs, target_value):
    """Compute the NTL-MSE and NTL-WAS losses for a numeric ground truth.

    Args:
        probs: 1-D probability tensor whose first ``NUM_DIGITS`` entries
            belong to the digit tokens 0-9.
        target_value: Ground-truth digit, or ``None`` when the ground truth is
            the "Text" token.

    Returns:
        Tuple ``(ntl_mse, ntl_was)`` of 0-d tensors. Both are 0.0 when
        ``target_value`` is ``None``.
    """
    if target_value is None:
        return torch.tensor(0.0), torch.tensor(0.0)

    # NOTE: the digit probabilities are taken as-is from the 11-token
    # distribution; they are not renormalized over the digits, so any mass on
    # "Text" shrinks both weighted sums below.
    digit_probs = probs[:NUM_DIGITS]
    digit_values = torch.arange(0, NUM_DIGITS, dtype=torch.float32)
    target = float(target_value)

    # NTL-MSE: squared distance between the probability-weighted digit value
    # and the ground-truth digit.
    expected_value = torch.sum(digit_probs * digit_values)
    mse = (expected_value - target) ** 2

    # NTL-WAS: probability-weighted absolute distance of each digit to the
    # ground-truth digit.
    was = torch.sum(digit_probs * torch.abs(digit_values - target))
    return mse, was


st.title("Number Token Loss - Demo")

st.markdown(
    ' Adjust the sliders to set a predicted probability for each token (0-9 and "Text"). '
    "The sliders are vertical and compact. "
    "The app normalizes the slider values to form a valid probability distribution, "
    "visualizes it, and computes the corresponding Cross Entropy, NTL-MSE, and NTL-WAS losses. "
)

# Vertical sliders for predicted probabilities of tokens 0-9 and "Text"
st.markdown("#### Predicted Token Probabilities")
cols = st.columns(NUM_TOKENS)
prob_values = []
for i, col in enumerate(cols):
    label = f"Token {i}" if i < NUM_DIGITS else "Text"
    with col:
        val = svs.vertical_slider(
            label=label,
            min_value=0.0,
            max_value=1.0,
            step=0.1,
            height=50,
            key=f"slider_{i}",
            slider_color="green",
            track_color="lightgray",
            thumb_color="black",
        )
    prob_values.append(val)

# Normalize the probabilities to sum to 1
total = sum(prob_values)
probs = _normalize(prob_values)

# Token labels
options = [str(i) for i in range(NUM_DIGITS)] + ["Text"]

# Ground truth token selection
gt_choice = st.selectbox("Ground Truth Token", options=options, index=0)
if gt_choice == "Text":
    gt_index = TEXT_INDEX
    gt_numeric = None
else:
    gt_index = int(gt_choice)
    gt_numeric = gt_index

# Visualize the input distribution with highlighted ground truth bar
st.markdown("#### Input Probability Distribution")
df_dist = pd.DataFrame({"token": options, "probability": probs.numpy()})
chart = (
    alt.Chart(df_dist)
    .mark_bar()
    .encode(
        x=alt.X("token:N", title="Token"),
        y=alt.Y("probability:Q", title="Probability", scale=alt.Scale(domain=[0, 1])),
        color=alt.condition(
            alt.datum.token == gt_choice,
            alt.value("green"),  # Highlight ground truth token
            alt.value("steelblue"),  # Other tokens
        ),
    )
    .properties(height=300)
)
st.altair_chart(chart, use_container_width=True)

# Compute Cross Entropy loss: -log(predicted probability of the ground truth)
ce_loss = _cross_entropy(probs, gt_index)

# Compute NTL-MSE and NTL-WAS losses (both 0 when the ground truth is "Text")
ntl_mse_loss, ntl_was_loss = _ntl_losses(probs, gt_numeric)

# Convert losses to Python floats and round to 3 decimals
ce_val = round(ce_loss.item(), 3)
mse_val = round(ntl_mse_loss.item(), 3)
was_val = round(ntl_was_loss.item(), 3)

# Display numeric values of the losses
st.subheader("Loss Values")
st.write(f"**Cross Entropy:** {ce_val:.3f}")
st.write(f"**NTL-MSE:** {mse_val:.3f}")
st.write(f"**NTL-WAS:** {was_val:.3f}")

# Bar chart comparison of the three losses
st.subheader("Loss Comparison Chart")
loss_df = pd.DataFrame(
    {
        "Loss": ["Cross Entropy", "NTL-MSE", "NTL-WAS"],
        "Value": [ce_val, mse_val, was_val],
    }
).set_index("Loss")
st.bar_chart(loss_df)

# References / resources section with links
st.markdown("### Resources")
st.markdown(
    "- **Paper:** [Regress, Don't Guess – A Regression-like Loss on Number Tokens for Language Models](https://arxiv.org/abs/2411.02083) \n"
    "- **Code:** [tum-ai/number-token-loss (GitHub)](https://github.com/tum-ai/number-token-loss)"
)