NumberTokenLoss / src /streamlit_app.py
jannisborn's picture
update
2bd2b96 unverified
raw
history blame
23.5 kB
import altair as alt
import pandas as pd
import streamlit_vertical_slider as svs
import torch
# from streamlit_vertical_slider import vertical_slider # Not directly used, svs.vertical_slider is
import streamlit as st
import time
import plotly.graph_objects as go # Add Plotly import
# Token choices offered in the UI: the ten digits followed by a catch-all "Text" token.
# Kept at module level because both the session-state setup and the widgets use it.
options = [str(i) for i in range(10)] + ["Text"]

# --- Session State Initialization ---
# Seed every session-state entry before any widget reads it. Keys that are already
# present in st.session_state are left untouched.
_state_defaults = {
    "running_demo": False,
    "demo_step": 0,
    "last_update_time": 0,
    "loss_container": None,
    "previous_chart_html": "",
}
# One slider per entry of `options` (0-9 + "Text"), each starting at a uniform share.
_state_defaults.update(
    {f"slider_{i}": 1.0 / len(options) for i in range(len(options))}
)
_state_defaults["ground_truth"] = options[0]  # Default to "0"

for _state_key, _state_default in _state_defaults.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# Page title followed by a short description of what the controls do.
st.title("Number Token Loss - Demo")
intro_markdown = """
Adjust the sliders to set a predicted probability for each token (0-9 and "Text").
The sliders are vertical and compact. The app normalizes the slider values
to form a valid probability distribution, visualizes it, and computes the corresponding
Cross Entropy, NTL-MSE, and NTL-WAS losses.
"""
st.markdown(intro_markdown)
# --- Scenario Definitions ---
def _build_scenarios():
    """Return the ordered list of scenarios played by the automated demo.

    Every scenario is a dict with four keys:
      - "name": headline shown while the scenario is displayed,
      - "values": 11 raw slider values (tokens 0-9, then "Text"),
      - "ground_truth": the target token as a string,
      - "explanation": text shown below the headline.

    The demo sweeps the ground truth over a few fixed predicted distributions,
    so the list is generated per distribution instead of being spelled out
    entry by entry. Each scenario receives its own copy of the value list.

    NOTE(review): the explanation strings are reproduced verbatim. Several are
    reused for every ground truth of a distribution (e.g. "the prediction is
    correct" is also shown when the mass sits on 1 but the ground truth is 0);
    confirm whether that is intended.
    """
    ce_ignores_distance = "Cross Entropy does not penalize if the prediction is far from the ground truth."
    ntl_low_when_close = "Cross Entropy is moderate, NTL is low because predictions are close to ground truth."
    ntl_knows_number = "Cross Entropy is high, NTL is higher but still penalizes less than CE because distribution knows it's a number."
    both_high = "Both CE and NTL are high because the prediction is far from correct."
    both_low = "Both losses are low because the prediction is correct."
    near_miss = "CE penalizes harshly, but NTL-WAS remains low because prediction is numerically close."

    digits = [str(d) for d in range(10)]

    def sweep(name, explanation, targets):
        # Same headline and explanation for every ground truth in `targets`.
        return [(gt, name, explanation) for gt in targets]

    # The "around 5" distribution is the only one whose text depends on the target.
    around_5_explanations = {gt: ce_ignores_distance for gt in digits[:5]}
    around_5_explanations.update({gt: ntl_low_when_close for gt in ("5", "6", "7", "9")})
    around_5_explanations["8"] = ntl_knows_number
    around_5 = [
        (
            gt,
            "Probability mass around ground truth (5)" if gt == "5" else "Probability mass around 5",
            around_5_explanations[gt],
        )
        for gt in digits
    ]

    # (raw slider values, [(ground_truth, name, explanation), ...]) in display order.
    groups = [
        (
            [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0],
            sweep("Probability mass at 0", ce_ignores_distance, digits),
        ),
        (
            [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0],
            around_5,
        ),
        (
            [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0],
            sweep("Probability mass concentrated on 5", both_high, digits),
        ),
        (
            [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0],
            sweep("Probability mass concentrated on 1", both_low, digits),
        ),
        (
            [0.1, 0.1, 0.7, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            sweep("Almost correct (1 vs 2)", near_miss, digits[:4]),
        ),
    ]

    return [
        {
            "name": name,
            "values": list(values),
            "ground_truth": gt,
            "explanation": explanation,
        }
        for values, entries in groups
        for gt, name, explanation in entries
    ]


scenarios = _build_scenarios()
# --- Helper Functions ---
def apply_scenario(step_idx):
    """Copy scenario ``step_idx`` into session state.

    Writes the scenario's values to the ``slider_<i>`` keys and its target
    token to ``ground_truth``.

    Args:
        step_idx: Index into the module-level ``scenarios`` list.
    """
    selected = scenarios[step_idx]
    pending = {
        f"slider_{slot}": value for slot, value in enumerate(selected["values"])
    }
    pending["ground_truth"] = selected["ground_truth"]
    # Session state must be written *before* the widgets bound to these keys are
    # rendered in the script run that is supposed to show the new values.
    for state_key, state_value in pending.items():
        st.session_state[state_key] = state_value
def start_demo():
    """Switch to demo mode, starting from the first scenario.

    Sets the ``running_demo`` flag, resets ``demo_step`` to 0, records the
    current time as ``last_update_time`` and loads scenario 0 into session state.
    """
    state = st.session_state
    state["running_demo"] = True
    state["demo_step"] = 0
    state["last_update_time"] = time.time()
    # Load the first scenario's slider values and ground truth.
    apply_scenario(0)
    # The button click that calls start_demo() will itself cause a rerun.
def stop_demo():
    """Leave demo mode by clearing the ``running_demo`` flag in session state."""
    st.session_state["running_demo"] = False
# --- Demo State Advancement Logic ---
# While the demo is running, switch to the next scenario once the current one has
# been on screen for more than 3 seconds. Session state is updated first and the
# script is then rerun, so the next run draws widgets and charts from the new state.
if st.session_state.running_demo:
    seconds_on_screen = time.time() - st.session_state.last_update_time
    if seconds_on_screen > 3.0:  # 3 seconds per scenario
        # Wrap around to the first scenario after the last one.
        upcoming_step = (st.session_state.demo_step + 1) % len(scenarios)
        st.session_state.demo_step = upcoming_step
        apply_scenario(upcoming_step)  # Load the new scenario into session state
        st.session_state.last_update_time = time.time()  # Reset timer
        st.rerun()  # Crucial: rerun so widgets and charts reflect the change
# --- UI Rendering ---
# Main UI; runs after any rerun triggered by the advancement block above.
# Outside demo mode only the start button is shown; in demo mode the active
# scenario's headline, its explanation and a stop button are shown.
if not st.session_state.running_demo:
    if st.button("Start Automated Demo"):
        start_demo()  # Also loads scenario 0 via apply_scenario(0)
        st.rerun()  # Rerun to enter demo mode and draw scenario 0 correctly
else:
    active_step = st.session_state.demo_step
    active_scenario = scenarios[active_step]
    st.info(f"Showing scenario {active_step + 1}/{len(scenarios)}: {active_scenario['name']}")
    st.markdown(f"**Explanation:** {active_scenario['explanation']}")
    if st.button("Stop Demo"):
        stop_demo()
        st.rerun()
# Sliders and Ground Truth Selector
# Manual controls: one vertical slider per entry of `options` and a selectbox for
# the ground-truth token. Every widget is given a `key` that matches a session-state
# entry (`slider_<i>` / `ground_truth`).
# These widgets will read their initial values from st.session_state.
# User interactions will update st.session_state directly due to their keys.
# NOTE(review): indentation is missing in this copy of the file. The slider grid is
# guarded by `not running_demo`; presumably the selectbox below is inside the same
# guard, but that cannot be told from here -- confirm against the original source.
if not st.session_state.running_demo:
st.markdown("#### Predicted Token Probabilities")
# One column per token so the sliders sit side by side.
cols = st.columns(len(options))
for i, col in enumerate(cols):
label = options[i] # Use token name directly for label
with col:
# Slider range is 0.0-1.0 in steps of 0.01; the raw values are normalized
# later when the charts and losses are computed.
svs.vertical_slider(
label=label, min_value=0.0, max_value=1.0, step=0.01, height=50,
key=f"slider_{i}", # This key links the widget to st.session_state[f"slider_{i}"]
slider_color="green", track_color="lightgray", thumb_color="black"
)
# Ground truth selectbox
# `index` is derived from the current session-state value, and the same
# session-state entry is also bound through `key`.
st.selectbox(
"Ground Truth Token", options=options,
index=options.index(st.session_state['ground_truth']), # Display value from session state
key='ground_truth' # Links widget to st.session_state['ground_truth']
)
# Inputs for the charts and losses, always derived from the current session state.
# Raw slider values are read (falling back to a uniform share for a missing key)
# and rescaled so that they sum to 1.
uniform_share = 1.0 / len(options)
raw_slider_values = [
    st.session_state.get(f"slider_{j}", uniform_share) for j in range(len(options))
]
raw_total = sum(raw_slider_values)
if raw_total == 0:
    # All sliders at zero: use a uniform distribution instead of dividing by zero.
    probs_for_charts = torch.ones(len(options)) / len(options)
else:
    probs_for_charts = torch.tensor([v / raw_total for v in raw_slider_values])

# Translate the selected ground-truth token into an index into `options` and,
# for digit tokens, its numeric value (None for the "Text" token).
gt_choice_for_charts = st.session_state.get('ground_truth', options[0])
gt_is_text = gt_choice_for_charts == "Text"
# "Text" is the 11th item (index 10); digit tokens sit at their own value.
gt_index_for_charts = 10 if gt_is_text else int(gt_choice_for_charts)
gt_numeric_for_charts = None if gt_is_text else gt_index_for_charts
# Bar chart of the normalized distribution; the ground-truth token is drawn in a
# different colour from the other tokens.
st.markdown("#### Input Probability Distribution")
token_roles = [
    "Ground Truth" if token == gt_choice_for_charts else "Prediction"
    for token in options
]
df_dist = pd.DataFrame(
    {
        "token": options,
        "probability": probs_for_charts.numpy(),
        "type": token_roles,
    }
)
token_axis = alt.X("token:N", title="Token", sort=options)  # Keep tokens in `options` order
probability_axis = alt.Y("probability:Q", title="Probability", scale=alt.Scale(domain=[0, 1]))
role_color = alt.Color(
    "type:N",
    scale=alt.Scale(domain=["Ground Truth", "Prediction"], range=["green", "steelblue"]),
    legend=alt.Legend(title="Token Type"),
)
chart = (
    alt.Chart(df_dist)
    .mark_bar()
    .encode(x=token_axis, y=probability_axis, color=role_color)
    .properties(height=300)
)
st.altair_chart(chart, use_container_width=True)
# --- Loss computation ---
# Cross entropy uses the full 11-token distribution; the clamp keeps log() finite
# when the ground-truth token has zero probability.
ce_loss = -torch.log(torch.clamp(probs_for_charts[gt_index_for_charts], min=1e-9))
if gt_numeric_for_charts is None:  # Text token
    # Both NTL variants compare numeric values, so they do not apply to "Text".
    ntl_mse_loss = torch.tensor(float('nan'))
    ntl_was_loss = torch.tensor(float('nan'))
else:  # Numeric token
    numeric_probs_for_loss = probs_for_charts[:10]  # Probabilities for 0-9
    # Renormalize over the digit tokens only, so mass on "Text" is ignored.
    numeric_probs_sum = torch.sum(numeric_probs_for_loss)
    if numeric_probs_sum > 1e-6:  # Avoid division by zero
        normalized_numeric_probs = numeric_probs_for_loss / numeric_probs_sum
    else:
        # (Almost) no mass on any digit: fall back to all zeros, which makes the
        # predicted value and NTL-WAS come out as 0.
        normalized_numeric_probs = torch.zeros_like(numeric_probs_for_loss)
    loss_values_tensor = torch.arange(0, 10, dtype=torch.float32)
    target_value = float(gt_numeric_for_charts)
    # NTL-MSE: squared distance between the expected digit and the ground truth.
    pred_value = torch.sum(normalized_numeric_probs * loss_values_tensor)
    ntl_mse_loss = (pred_value - target_value) ** 2
    # NTL-WAS: probability-weighted absolute distance of each digit to the ground truth.
    abs_diff = torch.abs(loss_values_tensor - target_value)
    ntl_was_loss = torch.sum(normalized_numeric_probs * abs_diff)
# Round the losses for display; a NaN NTL loss (not applicable) becomes "N/A"
# and is left out of the table.
ce_val = round(ce_loss.item(), 3)
mse_val = "N/A" if torch.isnan(ntl_mse_loss) else round(ntl_mse_loss.item(), 3)
was_val = "N/A" if torch.isnan(ntl_was_loss) else round(ntl_was_loss.item(), 3)

# Row order determines the bar order in the chart: CE, then NTL-WAS, then NTL-MSE.
loss_rows = [("Cross Entropy", ce_val)]
loss_rows.extend(
    (loss_name, loss_value)
    for loss_name, loss_value in (("NTL-WAS", was_val), ("NTL-MSE", mse_val))
    if loss_value != "N/A"
)
loss_data = {
    "Loss": [loss_name for loss_name, _ in loss_rows],
    "Value": [loss_value for _, loss_value in loss_rows],
}
loss_df = pd.DataFrame(loss_data)
# ============== Chart Display ==============
# A single bar chart comparing the losses, with the value printed above each bar.
st.subheader("Loss Comparison")

# The y-axis reaches at least 20 in demo mode and at least 0.5 otherwise, and
# always leaves 20% headroom above the largest loss.
minimum_y_top = 20 if st.session_state.running_demo else 0.5
y_top = max(loss_df["Value"].max() * 1.2, minimum_y_top)
loss_palette = alt.Scale(
    domain=['Cross Entropy', 'NTL-WAS', 'NTL-MSE'],
    range=['steelblue', 'red', 'forestgreen'],
)
chart = (
    alt.Chart(loss_df)
    .mark_bar()
    .encode(
        x=alt.X('Loss:N', sort=loss_df["Loss"].tolist()),  # Keep the table's row order
        y=alt.Y('Value:Q', scale=alt.Scale(domain=[0, y_top])),
        color=alt.Color('Loss:N', scale=loss_palette),
        tooltip=['Loss', 'Value'],
    )
    .properties(height=300)
)

# Value labels on top of the bars, layered over the bar chart.
label_style = dict(align='center', baseline='bottom', dy=-5, fontSize=14)
text = chart.mark_text(**label_style).encode(text=alt.Text('Value:Q', format='.3f'))
final_chart = chart + text

# Display chart with the full container width
st.altair_chart(final_chart, use_container_width=True)
# Static sections are rendered *before* the polling rerun below: st.rerun() ends
# the current script run, so anything placed after it would never be drawn while
# the demo is running.

# Add explanation of the demonstration
st.markdown("""
### What Does This Demo Show?
- **Cross Entropy Loss**: Only cares if the prediction is exactly right or wrong - it doesn't consider how "close" a numerical prediction is.
- **Number Token Loss (NTL)**: Considers numerical proximity - predicting "7" when the true value is "8" is better than predicting "2".
""")
# References / resources section with links (common to both modes)
st.markdown("### Resources")
st.markdown("""
- [Paper: Number Token Loss (ArXiv)](https://arxiv.org/abs/2411.02083)
- [GitHub: Number Token Loss](https://github.com/tum-ai/number-token-loss)
""")

# --- Polling Rerun for Demo Mode ---
# If the demo is running and we get here, the time-based advance at the top of the
# script did not fire in this run (it would have rerun already). Sleep briefly and
# rerun so the timer keeps being checked. This must stay the last statement of the
# script.
if st.session_state.running_demo:
    time.sleep(0.1)  # Polling interval in seconds
    st.rerun()