Spaces:
Running
Running
import logging | |
import time | |
import altair as alt | |
import numpy as np | |
import pandas as pd | |
import streamlit as st | |
import streamlit_vertical_slider as svs | |
import torch | |
from scenarios import dirac, gauss, make_bimodal_scenarios | |
# Only report errors from Streamlit's file watcher (it is noisy otherwise).
logging.getLogger("streamlit.watcher.local_sources_watcher").setLevel(logging.ERROR)

# Seconds between automatic scenario advances while a demo is running.
DEMO_INTERVAL = 1.5
# Multiplier applied to the cross-entropy value in compute_losses.
CE_SCALING = 0.25
# Upper bound of the y-axis of the loss-history chart.
MAX_LOSS_PLOT = 6
# NOTE(review): not referenced anywhere in this file; candidate for removal.
LAST_STEP = -1

# Token vocabulary shown in the UI: the ten digit tokens followed by a generic
# "Text" token. Module-level because session-state init and the widgets use it.
options = [*map(str, range(10)), "Text"]
def compute_losses(probs: torch.Tensor, gt_token: str) -> tuple[float, float, float]:
    """Compute CE, NTL-MAE and NTL-WAS for one predicted distribution.

    Args:
        probs: Probability vector aligned with ``options`` (indices 0-9 are the
            digit tokens, index 10 is "Text"). Assumed to sum to 1.
        gt_token: Ground-truth token, one of ``options``.

    Returns:
        ``(ce, ntl_mae, ntl_was)``, each rounded to 3 decimals. CE is scaled by
        ``CE_SCALING``. Both NTL losses are 0.0 when the ground truth is "Text"
        or when (almost) no probability mass sits on digit tokens.
    """
    # Clamp before the log so a zero probability gives a large but finite loss.
    ce_loss = CE_SCALING * -torch.log(
        torch.clamp(probs[options.index(gt_token)], min=1e-9)
    )
    ce = round(ce_loss.item(), 3)

    digit_probs = probs[:10]
    numeric_mass = digit_probs.sum()
    if gt_token == "Text" or numeric_mass < 1e-6:
        return ce, 0.0, 0.0

    gt_numeric = int(gt_token)
    token_vals = torch.arange(10, dtype=probs.dtype)
    # Renormalize over the digit tokens so both NTL terms describe how the
    # digit mass is spread (the expected value must come from a proper
    # distribution), then weight by numeric_mass so they scale with how much of
    # the prediction is numeric at all. For all-digit predictions this equals
    # the unnormalized form; for mixed digit/"Text" predictions it keeps
    # MAE <= WAS (Jensen) and gives 0 when all digit mass is on the target.
    digit_dist = digit_probs / numeric_mass
    mae = numeric_mass * torch.abs(torch.dot(token_vals, digit_dist) - gt_numeric)
    was = numeric_mass * torch.dot(digit_dist, torch.abs(token_vals - gt_numeric))
    return ce, round(mae.item(), 3), round(was.item(), 3)
# --- Session State Initialization ---
# Seed each key only when it is missing (existing values are never overwritten),
# so widgets created later with these keys always find a value.
_session_defaults = {
    "running_demo": False,
    "demo_step": 0,
    "last_update_time": 0,
    "loss_container": None,
    "previous_chart_html": "",
    # Scenario list used before any demo button has been pressed.
    "active_scenarios": dirac,
    "loss_history": [],
    # One slider value per entry of `options` (digits 0-9 plus "Text").
    **{f"slider_{i}": 0 for i in range(len(options))},
    "ground_truth": options[5],
    "manual_ground_truth": options[5],
    "demo_name": "Dirac",
}
for _key, _default in _session_defaults.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
# NOTE(review): several string literals below contain what look like mis-decoded
# UTF-8 emoji/punctuation (rendered as Thai characters such as "๐", "โ");
# restore them from the original source.
st.title("NTL -- The Number Token Loss ๐")
st.markdown(
    """This is the interactive demo for our [ICML 2025](https://arxiv.org/abs/2411.02083) paper!๐
โก๏ธ NTL augments cross-entropy to help LMs reason better with numbers ๐ง
    """
)
st.subheader("Demo 1 โ NTL vs. Cross Entropy in 3 Scenarios")
st.markdown("""
1๏ธโฃ Pick a ground truth token: a digit (0โ9) or "Text" ๐ (simulates generic text tokens).
2๏ธโฃ Choose a demo:
   - **Dirac** โก: All probability mass on one token.
   - **Gaussian** ๐: Soft bell-curve around the true number.
   - **Bimodal** ๐ฏ: Two peaks moving away from the target.
Watch how losses evolve as predictions get worse โ and see how NTL shines compared to CE! ๐
""")
# Demo 1 ground-truth selector. Its "ground_truth" key is always seeded during
# session-state initialization (to options[5]), so no fallback default is set
# here; the widget reads and writes that key.
gt = st.selectbox("Ground Truth Token", options=options, key="ground_truth")
def apply_scenario(step_idx):
    """Load the slider values of scenario ``step_idx`` into session state.

    Takes ``st.session_state.active_scenarios[step_idx]["values"]`` and stores
    its i-th entry under the ``slider_{i}`` key, which the sliders and the
    distribution chart read from.
    """
    values = st.session_state.active_scenarios[step_idx]["values"]
    st.session_state.update(
        {f"slider_{slot}": value for slot, value in enumerate(values)}
    )
def _start_demo(scenarios, name):
    """Begin auto-playing ``scenarios`` from the first step.

    Clears the recorded loss history, stores the scenario list and display
    name, marks the demo as running, stamps the time used by the advancement
    logic, and loads step 0 into the sliders.

    Args:
        scenarios: Sequence of scenario dicts (read via ``["values"]`` and
            ``["name"]`` elsewhere in the script).
        name: Stored as ``demo_name``; names starting with "Bimodal" switch
            the loss-history chart to per-scenario x labels.
    """
    st.session_state.loss_history = []
    st.session_state.active_scenarios = scenarios
    st.session_state.demo_name = name
    st.session_state.running_demo = True
    st.session_state.demo_step = 0
    st.session_state.last_update_time = time.time()
    apply_scenario(0)


def start_dirac_demo():
    """Start the Dirac demo (all probability mass on one token per step)."""
    _start_demo(dirac, "Dirac")


def start_gauss_demo():
    """Start the Gaussian demo."""
    _start_demo(gauss, "Gauss")


def start_bimodal_demo():
    """Start the Bimodal demo built around the current Demo 1 ground truth."""
    gt = st.session_state["ground_truth"]
    _start_demo(make_bimodal_scenarios(gt, options), f"Bimodal (GT={gt})")


def stop_demo():
    """Stop the running demo; the current slider values are left in place."""
    st.session_state.running_demo = False
# --- Demo State Advancement Logic ---
# Runs before any widget is drawn. Once DEMO_INTERVAL seconds have passed it
# either loads the next scenario into session state and reruns immediately (so
# widgets are built from the new state), or, if the final scenario has already
# been on screen for a full interval, ends the demo.
if st.session_state.running_demo:
    now = time.time()
    if now - st.session_state.last_update_time > DEMO_INTERVAL:
        final_index = len(st.session_state.active_scenarios) - 1
        if st.session_state.demo_step >= final_index:
            # The last scenario has been displayed -> stop.
            st.session_state.running_demo = False
        else:
            st.session_state.demo_step += 1
            apply_scenario(st.session_state.demo_step)
            st.session_state.last_update_time = now
            st.rerun()
# --- UI Rendering ---
# Drawn after the advancement logic above, which may already have rerun the script.
if st.session_state.running_demo:
    # Progress banner for the scenario currently on screen, plus an abort button.
    playing = st.session_state.active_scenarios
    playing_idx = st.session_state.demo_step
    st.info(
        f"Showing scenario {playing_idx + 1}/{len(playing)}: "
        f"{playing[playing_idx]['name']}"
    )
    if st.button("Stop Demo"):
        stop_demo()
        st.rerun()
else:
    # One launch button per demo, side by side.
    launchers = (
        ("Run: Dirac", start_dirac_demo),
        ("Run: Gauss", start_gauss_demo),
        ("Run: Bimodal", start_bimodal_demo),
    )
    for column, (label, launch) in zip(st.columns(len(launchers)), launchers):
        with column:
            if st.button(label):
                launch()
                st.rerun()
# Normalize the raw slider values (missing keys count as 0) into a probability
# vector. An all-zero state falls back to a uniform distribution.
raw_slider_values = [
    st.session_state.get(f"slider_{k}", 0) for k in range(len(options))
]
slider_total = sum(raw_slider_values)
if slider_total == 0:
    probs_for_charts = torch.ones(len(options)) / len(options)
else:
    probs_for_charts = torch.tensor([v / slider_total for v in raw_slider_values])

# Ground truth for charts and losses: the Demo 1 selector while a demo runs,
# otherwise the manual (Demo 2) selector.
if st.session_state.running_demo:
    gt_choice_for_charts = st.session_state["ground_truth"]
else:
    gt_choice_for_charts = st.session_state["manual_ground_truth"]

# Position and numeric value of the ground-truth token ("Text" has no numeric
# value). NOTE(review): neither name is read anywhere else in this file.
if gt_choice_for_charts == "Text":
    gt_index_for_charts, gt_numeric_for_charts = options.index("Text"), None
else:
    gt_index_for_charts = gt_numeric_for_charts = int(gt_choice_for_charts)
# --- Predicted-distribution chart ---
# Highlight the same ground-truth token the losses are computed against
# (gt_choice_for_charts): the manual (Demo 2) selector outside a running demo,
# the Demo 1 selector during one. This keeps the chart and loss values in agreement.
gt = gt_choice_for_charts
demo_name = st.session_state["demo_name"]
st.markdown(f"#### Predicted distribution โ ground truth: {gt}")
df_dist = pd.DataFrame(
    {"token": options, "probability": probs_for_charts.numpy().round(2)}
)
# Blue bars: predicted probability per token (rounded to 2 decimals for display).
bars = (
    alt.Chart(df_dist)
    .mark_bar(color="dodgerblue", size=40)
    .encode(
        x=alt.X(
            "token:N",
            title="Token",
            sort=options,
            axis=alt.Axis(
                labelAngle=0,
                labelFontSize=14,
                titleFontSize=16,
                labelAlign="center",
                labelFlush=False,
            ),
        ),
        y=alt.Y(
            "probability:Q",
            title="Probability",
            scale=alt.Scale(domain=[0, 1]),
            axis=alt.Axis(format=".2f", labelFontSize=14, titleFontSize=16),
        ),
        tooltip=[
            alt.Tooltip("token:N", title="Token"),
            alt.Tooltip("probability:Q", title="Predicted Prob.", format=".2f"),
        ],
    )
)
# Translucent, dashed full-height bar behind the ground-truth token.
bg_bar = pd.DataFrame({"token": [gt], "height": [1.0]})
gt_bar = (
    alt.Chart(bg_bar)
    .mark_bar(
        color="darkgreen",
        size=20,
        opacity=0.3,
        stroke="gray",
        strokeWidth=2,
        strokeDash=[4, 4],
    )
    .encode(
        x=alt.X("token:N", sort=options),
        y=alt.Y("height:Q", scale=alt.Scale(domain=[0, 1])),
        tooltip=[
            alt.Tooltip("token:N", title="Ground Truth"),
            alt.Tooltip("height:Q", title="Desired mass", format=".2f"),
        ],
    )
)
# Two-line "Ground" / "truth=<gt>" label next to the ground-truth column.
# y=alt.value(1) pins the text at pixel row 1 (top edge of the plot); negative
# dy moves it further up, so line 1 (dy=-25) sits above line 2 (dy=-10).
annot1 = (
    alt.Chart(pd.DataFrame({"token": [gt]}))
    .mark_text(
        text="โฌ Ground",
        dy=-25,  # first line: 25px up
        dx=25,
        fontSize=14,
        fontWeight="bold",
        color="darkgreen",
    )
    .encode(x=alt.X("token:N", sort=options), y=alt.value(1))
)
annot2 = (
    alt.Chart(pd.DataFrame({"token": [gt]}))
    .mark_text(
        text=f"truth={gt}",
        dy=-10,  # second line: 10px up, just below "Ground"
        dx=35,
        fontSize=14,
        fontWeight="bold",
        color="darkgreen",
    )
    .encode(x=alt.X("token:N", sort=options), y=alt.value(1))
)
# Layer in drawing order: ground-truth background, prediction bars, labels.
final_chart = (gt_bar + bars + annot1 + annot2).properties(height=200)
st.altair_chart(final_chart, use_container_width=True)
# Losses of the distribution currently on screen. They are reused below for the
# demo loss history and later by the manual (Demo 2) loss chart.
ce_val, mae_val, was_val = compute_losses(probs_for_charts, gt_choice_for_charts)

# Append one history entry per demo step, on the first render of that step
# (the history is shorter than demo_step + 1 only until it has been recorded).
if (
    st.session_state.running_demo
    and len(st.session_state.loss_history) < st.session_state.demo_step + 1
):
    step = st.session_state.demo_step
    scenario = st.session_state.active_scenarios[step]
    if st.session_state.demo_name.startswith("Bimodal"):
        # Bimodal steps are labelled by their scenario name, e.g. "(4,4)", "(3,5)".
        x_val = scenario["name"]
    else:
        # Other demos are labelled by the token holding the largest slider value.
        x_val = options[int(np.argmax(scenario["values"]))]
    st.session_state.loss_history.append(
        {
            "step": step,
            "x_val": x_val,
            # Same inputs as ce_val/mae_val/was_val, so reuse them instead of
            # calling compute_losses a second time.
            "Cross Entropy": ce_val,
            "NTL-MAE": mae_val,
            "NTL-WAS": was_val,
        }
    )
# Long-format ("melted") view of the loss history for the grouped bar chart:
# one row per (step, loss type).
history_df = pd.DataFrame(st.session_state.loss_history)
if history_df.empty:
    # No history yet: an empty frame that still has every column the chart encodes.
    df_loss_plot = pd.DataFrame(columns=["step", "x_val", "Loss Type", "Loss Value"])
else:
    df_loss_plot = history_df.melt(
        id_vars=["step", "x_val"],
        value_vars=["Cross Entropy", "NTL-MAE", "NTL-WAS"],
        var_name="Loss Type",
        value_name="Loss Value",
    )

# X-axis categories and title for the loss-history chart. Bimodal demos are
# indexed by scenario name; the other demos are indexed by token.
if st.session_state.demo_name.startswith("Bimodal"):
    domain = [sc["name"] for sc in st.session_state.active_scenarios]
    x_title = f"Offset from GT {st.session_state['ground_truth']}"
else:
    domain = options
    x_title = f"Maximum of predicted {st.session_state['demo_name']} distribution"
# ============== Chart Display ==============
st.markdown("#### Loss as a function of predicted distribution")

# Grouped bar chart of the recorded demo losses: one group per x_val (ordered
# by `domain`), one bar per loss type inside each group.
history_x = alt.X(
    "x_val:N",
    title=x_title,
    sort=domain,
    scale=alt.Scale(domain=domain),
    axis=alt.Axis(labelAngle=0, labelFontSize=14, titleFontSize=16),
)
# Fixed range [0, MAX_LOSS_PLOT]; clamp=True keeps larger values at the top edge.
history_y = alt.Y(
    "Loss Value:Q",
    title="Loss Value",
    scale=alt.Scale(domain=[0, MAX_LOSS_PLOT], nice=False, clamp=True),
    axis=alt.Axis(labelFontSize=14, titleFontSize=16),
)
history_color = alt.Color(
    "Loss Type:N",
    scale=alt.Scale(
        domain=["Cross Entropy", "NTL-WAS", "NTL-MAE"],
        range=["red", "limegreen", "blueviolet"],
    ),
    legend=alt.Legend(
        title="",
        orient="top",
        direction="horizontal",
        columns=3,
    ),
)
history_tooltip = [
    alt.Tooltip("x_val:N", title="Scenario"),
    alt.Tooltip("Loss Type:N", title="Loss Type"),
    alt.Tooltip("Loss Value:Q", title="Value", format=".3f"),
]
grouped_chart = (
    alt.Chart(df_loss_plot)
    .mark_bar()
    .encode(
        x=history_x,
        y=history_y,
        color=history_color,
        xOffset="Loss Type:N",  # side-by-side bars within each x group
        tooltip=history_tooltip,
    )
    .properties(height=250)
)
st.altair_chart(grouped_chart, use_container_width=True)
# --- Demo 2: manual playground (rendered only while no demo is running) ---
if not st.session_state.running_demo:
    # Zero every slider and rewind the demo step.
    # NOTE(review): this runs on every non-demo rerun, before the sliders below
    # are created, so slider values from the previous run are overwritten with
    # 0.0 here (the charts above have already read them). Confirm this is
    # intended rather than a one-time reset after a demo finishes.
    for i in range(len(options)):
        st.session_state[f"slider_{i}"] = 0.0
    st.session_state.demo_step = 0

    # NOTE(review): some string literals below contain what look like
    # mis-decoded UTF-8 emoji/punctuation (rendered as Thai characters such as
    # "๐", "โ"); restore them from the original source.
    st.subheader("๐งช Demo 2 โ Craft your own distribution")
    st.markdown("""
This demo gives you more control but is harder to interpret. See it as a playground! ๐จ
Manually adjust the sliders to change the predicted probabilities for each token.
The demo normalizes the values to form a valid probability distribution and calculates the losses.
๐ฃ **Steps:**
- Use the **vertical sliders** to allocate probability to each token.
- Choose the correct **Ground Truth Token** (0โ9 or "Text" ๐).
- Observe how each loss function reacts.
๐ก **Tip:** Want to trick the loss? Try putting all mass on the wrong token or spread it wildly. See how NTL handles it! ๐
    """)
    # Widgets take their initial value from st.session_state via their key, and
    # user interactions write back to that key.
    manual_gt = st.selectbox(
        "Ground Truth Token",
        options=options,
        key="manual_ground_truth",
    )
    # Losses computed earlier in this run (ce_val, mae_val, was_val) from the
    # normalized slider distribution and the manual ground truth.
    loss_df = pd.DataFrame(
        {
            "Loss": ["Cross Entropy", "NTL-MAE", "NTL-WAS"],
            "Value": [ce_val, mae_val, was_val],
        }
    )
    st.markdown("#### Adjust the predicted token probability")
    cols = st.columns(len(options))
    for i, col in enumerate(cols):
        with col:
            svs.vertical_slider(
                label=options[i],  # token name as the slider label
                min_value=0.0,
                max_value=1.0,
                step=0.01,
                height=50,
                key=f"slider_{i}",
                slider_color="green",
                track_color="lightgray",
                thumb_color="black",
            )
    # Single bar chart of the three losses, each bar labelled with its value.
    # The y-axis top is 20% above the largest loss (room for the labels drawn
    # above the bars), but never below 0.5. (This branch only runs while no demo
    # is running, so the former "20 if running_demo" alternative was unreachable.)
    chart = (
        alt.Chart(loss_df)
        .mark_bar()
        .encode(
            x=alt.X("Loss:N", sort=loss_df["Loss"].tolist()),
            y=alt.Y(
                "Value:Q",
                scale=alt.Scale(domain=[0, max(loss_df["Value"].max() * 1.2, 0.5)]),
            ),
            color=alt.Color(
                "Loss:N",
                scale=alt.Scale(
                    domain=["Cross Entropy", "NTL-WAS", "NTL-MAE"],
                    range=["orangered", "limegreen", "blueviolet"],
                ),
            ),
            tooltip=["Loss", "Value"],
        )
        .properties(height=300)
    )
    text = chart.mark_text(
        align="center", baseline="bottom", dy=-5, fontSize=14
    ).encode(text=alt.Text("Value:Q", format=".3f"))
    final_chart = chart + text
    st.altair_chart(final_chart, use_container_width=True)
# --- Footer ---
# NOTE(review): the strings below contain what look like mis-decoded UTF-8
# emoji/punctuation (rendered as Thai characters); restore from the original source.
st.markdown("""
### ๐ค TL;DR โ Why NTL?
Cross Entropy only cares if the prediction is exactly right or wrong โโ โ it doesnโt care *how close* a guess is!
Thatโs bad for LLMs doing math and numeric reasoning ๐งฎ.
๐ฅ NTL fixes that: it behaves like a regression loss on the token head, rewarding predictions that are numerically close.
""")
st.markdown("#### ๐ Further Resources")
st.markdown("""
- ๐ [ICML 2025 Paper](https://arxiv.org/abs/2411.02083)
- ๐ [NTL Landing Page](https://tum-ai.github.io/number-token-loss/)
- ๐ป [GitHub Code](https://github.com/tum-ai/number-token-loss)
""")

# --- Polling Rerun for Demo Mode ---
# Keep this as the last statement of the script: st.rerun() halts the current
# run immediately, so anything placed after it (such as the footer above) would
# never be rendered while a demo is playing.
if st.session_state.running_demo:
    # Reaching here with a running demo means the advance interval had not yet
    # elapsed at the top of the script; wait briefly, then rerun to keep polling.
    time.sleep(0.1)
    st.rerun()