# NumberTokenLoss — src/streamlit_app.py
# Last change: commit 9914a10 ("wip") by jannisborn, marked unverified; file size 16.3 kB.
import time
import altair as alt
import numpy as np
import pandas as pd
import streamlit as st
import streamlit_vertical_slider as svs
import torch
from scenarios import bimodal, dirac, gauss
DEMO_INTERVAL = 1.5  # seconds between two automatic demo steps
NTL_MSE_SCALING = 0.5  # factor applied to the squared error in NTL-MSE
MAX_LOSS_PLOT = 15  # fixed upper y-limit of the "Loss Evolution" chart
LAST_STEP = -1

# TODO:
# - Remove flickering of loss evolution scenario plot (lower ylim?)
# - Move manual part down (predicted token probabilities)
# - Allow to set GT token for each demo
# - Add text token to loss evolution barplot
# - pick good default (4?)

# Selectable tokens: the digits "0".."9" followed by "Text". Defined globally
# because both the session-state initialisation and the UI need it.
options = [*map(str, range(10)), "Text"]

# --- Session State Initialization ---
# Every key below must exist before any widget or chart reads it. A key is only
# written when it is missing, so values from earlier script runs are kept.
_uniform_prob = 1.0 / len(options)
_session_defaults = {
    "running_demo": False,
    "demo_step": 0,
    "last_update_time": 0,
    "loss_container": None,
    "previous_chart_html": "",
    "active_scenarios": dirac,  # scenario set loaded on first show
    "loss_history": [],
    # One probability slider per option ("slider_0" ... "slider_10"), uniform at start.
    **{f"slider_{idx}": _uniform_prob for idx in range(len(options))},
    "ground_truth": options[0],  # default ground truth: "0"
}
for state_key, default_value in _session_defaults.items():
    if state_key not in st.session_state:
        st.session_state[state_key] = default_value
st.title("Number Token Loss - Demo")
st.markdown(
    """
**Instructions**
1. **Pick a ground truth token (0–9).**
2. **Select one of the three automated demos:**
- **Dirac**: a one-hot (Dirac) distribution whose single 1.0 mass moves from token 0 all the way to “Text.”
- **Gaussian**: a peaked Gaussian (0.6 mass at center, 0.4 spread) that slides its center from token 0 to “Text.”
- **Bimodal**: two equal peaks (0.5 each) that start at (0,8) and then move symmetrically away from the GT token.
"""
)
# The selectbox is bound to st.session_state["ground_truth"] through `key`.
# That entry is always initialised during the session-state setup above, so the
# widget takes its value from session state. No fallback initialisation or
# `index=` default is needed; a default would only duplicate the session-state
# value and can trigger Streamlit's default-vs-Session-State warning.
gt = st.selectbox(
    "Ground Truth Token",
    options=options,
    key="ground_truth",
)
def apply_scenario(step_idx):
    """Load step ``step_idx`` of the active scenario set into the sliders.

    Each entry of that step's ``"values"`` list is written to the matching
    ``slider_<i>`` session-state key. The charts read the predicted
    probabilities from those keys.
    """
    step_values = st.session_state.active_scenarios[step_idx]["values"]
    for slider_idx, prob in enumerate(step_values):
        st.session_state[f"slider_{slider_idx}"] = prob
def _start_demo(scenarios):
    """Activate ``scenarios`` and start playing it from its first step.

    This is shared by the three ``start_*_demo`` callbacks. Besides resetting the
    demo counters, it clears ``loss_history``. History entries are recorded per
    demo step, so without a reset the "Loss Evolution" chart would keep showing
    values from a previous demo (or from the idle page) for the new one.
    """
    st.session_state.active_scenarios = scenarios
    st.session_state.running_demo = True
    st.session_state.demo_step = 0
    st.session_state.last_update_time = time.time()
    st.session_state.loss_history = []
    apply_scenario(0)


def start_dirac_demo():
    """Start the Dirac (one-hot) demo."""
    _start_demo(dirac)


def start_gauss_demo():
    """Start the Gaussian demo."""
    _start_demo(gauss)


def start_bimodal_demo():
    """Start the bimodal demo."""
    _start_demo(bimodal)


def stop_demo():
    """Stop the running demo; slider values are left as they are."""
    st.session_state.running_demo = False
# --- Demo State Advancement Logic ---
# While a demo is running, move to the next scenario once more than
# DEMO_INTERVAL seconds have passed since the last step (wrapping back to the
# first). The new slider values are written to session state before any widget
# is drawn, and st.rerun() restarts the script so widgets and charts show them.
if st.session_state.running_demo:
    scenario_set = st.session_state.active_scenarios
    elapsed = time.time() - st.session_state.last_update_time
    if elapsed > DEMO_INTERVAL:
        st.session_state.demo_step = (st.session_state.demo_step + 1) % len(scenario_set)
        apply_scenario(st.session_state.demo_step)
        st.session_state.last_update_time = time.time()  # restart the step timer
        st.rerun()
# --- UI Rendering ---
# Demo controls. State changes happen in on_click callbacks, which Streamlit runs
# before the script body of the triggered run. A click is therefore applied even
# when the advancement block above calls st.rerun() before this point is reached.
# A plain `if st.button(...)` check can lose that click. The callbacks also make
# an explicit st.rerun() after starting/stopping unnecessary.
if st.session_state.running_demo:
    st.info(
        f"Showing scenario {st.session_state.demo_step + 1}"
        f"/{len(st.session_state.active_scenarios)}: "
        f"{st.session_state.active_scenarios[st.session_state.demo_step]['name']}"
    )
    st.button("Stop Demo", on_click=stop_demo)
else:
    demo_buttons = (
        ("Run: Dirac", start_dirac_demo),
        ("Run: Gauss", start_gauss_demo),
        ("Run: Bimodal", start_bimodal_demo),
    )
    for button_column, (button_label, start_callback) in zip(
        st.columns(len(demo_buttons)), demo_buttons
    ):
        with button_column:
            st.button(button_label, on_click=start_callback)
# --- Current distribution and ground truth ---
# Everything below reads the current st.session_state. The raw slider values are
# normalised into a probability distribution; an all-zero setting falls back to
# the uniform distribution.
current_prob_values_from_state = [
    st.session_state.get(f"slider_{j}", 1.0 / len(options)) for j in range(len(options))
]
total_from_state = sum(current_prob_values_from_state)
probs_for_charts = (
    torch.ones(len(options)) / len(options)
    if total_from_state == 0
    else torch.tensor([v / total_from_state for v in current_prob_values_from_state])
)
gt_choice_for_charts = st.session_state.get("ground_truth", options[0])
# Look up the position in `options` instead of hard-coding 10 for "Text", so the
# index stays correct if the token list changes.
gt_index_for_charts = options.index(gt_choice_for_charts)
# The NTL losses need the numeric value of the ground truth; "Text" has none.
gt_numeric_for_charts = (
    None if gt_choice_for_charts == "Text" else int(gt_choice_for_charts)
)
gt = st.session_state["ground_truth"]
st.markdown(f"#### Predicted Probability Distribution — Ground truth token {gt}")
# Per-token probabilities (rounded to 2 decimals for display), tagged so the
# ground-truth token gets its own colour.
df_dist = pd.DataFrame(
    {"token": options, "probability": probs_for_charts.numpy().round(2)}
)
df_dist["type"] = [
    "Ground Truth" if token == gt_choice_for_charts else "Prediction"
    for token in options
]
# Pixel height shared by all layers of this chart.
dist_chart_height = 300
# 1) Light-gray column behind the ground-truth token.
#    alt.value() on y/y2 is a pixel offset from the top of the plot area, not a
#    probability. The column therefore spans 0 -> dist_chart_height; a data-like
#    0 -> 1 would only draw a 1-px sliver at the top. With no x2, `size` sets
#    the column width inside the token's band.
bg = (
    alt.Chart(pd.DataFrame({"token": [gt]}))
    .mark_bar(size=40, color="lightgray", opacity=0.4)
    .encode(
        x=alt.X("token:N", sort=options),
        y=alt.value(0),  # top edge of the plot area
        y2=alt.value(dist_chart_height),  # bottom edge of the plot area
    )
)
# 2) Probability bars: green for the ground truth, steelblue otherwise.
bars = (
    alt.Chart(df_dist)
    .mark_bar()
    .encode(
        x=alt.X(
            "token:N",
            title="Token",
            sort=options,
            axis=alt.Axis(labelAngle=0, labelFontSize=14, titleFontSize=16),
        ),
        y=alt.Y(
            "probability:Q",
            title="Probability",
            scale=alt.Scale(domain=[0, 1]),
            axis=alt.Axis(format=".2f", labelFontSize=14, titleFontSize=16),
        ),
        color=alt.Color(
            "type:N",
            scale=alt.Scale(
                domain=["Ground Truth", "Prediction"], range=["green", "steelblue"]
            ),
            legend=alt.Legend(title="Token Type", titleFontSize=16, labelFontSize=14),
        ),
        tooltip=[
            alt.Tooltip("token:N", title="Token"),
            alt.Tooltip("probability:Q", title="Probability", format=".2f"),
            alt.Tooltip("type:N", title="Type"),
        ],
    )
    .properties(height=dist_chart_height)
)
# 3) Two-line label over the ground-truth column. Both lines are anchored 1 px
#    below the top edge (y=alt.value(1)) and shifted up with dy:
#    "⬇ Ground" moves 25 px up and "truth=<gt>" 10 px up, so the second line
#    sits just below the first.
annot1 = (
    alt.Chart(pd.DataFrame({"token": [gt]}))
    .mark_text(
        text="⬇ Ground",
        dy=-25,  # first line: 25 px above the anchor
        dx=25,
        fontSize=14,
        fontWeight="bold",
        color="green",
    )
    .encode(x=alt.X("token:N", sort=options), y=alt.value(1))
)
annot2 = (
    alt.Chart(pd.DataFrame({"token": [gt]}))
    .mark_text(
        text=f"truth={gt}",
        dy=-10,  # second line: 10 px above the anchor, below the first line
        dx=35,
        fontSize=14,
        fontWeight="bold",
        color="green",
    )
    .encode(x=alt.X("token:N", sort=options), y=alt.value(1))
)
# 4) Layer them in order: background, bars, annotations.
final_chart = (bg + bars + annot1 + annot2).properties(height=dist_chart_height)
st.altair_chart(final_chart, use_container_width=True)
# --- Losses for the current distribution ---
# Cross entropy only looks at the probability of the ground-truth token. The
# clamp keeps log() finite (CE is at most -log(1e-9) ≈ 20.7) when that
# probability is 0.
ce_loss = -torch.log(torch.clamp(probs_for_charts[gt_index_for_charts], min=1e-9))
if gt_numeric_for_charts is None:
    # "Text" ground truth: the Number Token Losses are not applicable.
    ntl_mse_loss = torch.tensor(float("nan"))
    ntl_was_loss = torch.tensor(float("nan"))
else:
    # The NTL losses only use the digit tokens 0-9 (the first 10 options),
    # renormalised so that their probabilities sum to 1.
    digit_probs = probs_for_charts[:10]
    digit_mass = torch.sum(digit_probs)
    digit_values = torch.arange(0, 10, dtype=torch.float32)
    gt_value = float(gt_numeric_for_charts)
    if digit_mass > 1e-6:
        digit_dist = digit_probs / digit_mass
        # Expected digit under the renormalised distribution.
        pred_value = torch.sum(digit_dist * digit_values)
        # NTL-WAS: probability-weighted absolute distance to the ground truth.
        ntl_was_loss = torch.sum(digit_dist * torch.abs(digit_values - gt_value))
    else:
        # (Almost) no mass on digit tokens: the prediction falls back to 0 and
        # NTL-WAS to 0.
        pred_value = torch.tensor(0.0)
        ntl_was_loss = torch.tensor(0.0)
    # NTL-MSE: scaled squared error between the expected digit and the ground truth.
    ntl_mse_loss = NTL_MSE_SCALING * (pred_value - gt_value) ** 2

# Rounded values for display; "N/A" marks a loss that is undefined (NaN).
ce_val = round(ce_loss.item(), 3)
mse_val = round(ntl_mse_loss.item(), 3) if not torch.isnan(ntl_mse_loss) else "N/A"
was_val = round(ntl_was_loss.item(), 3) if not torch.isnan(ntl_was_loss) else "N/A"
# --- Loss history for the "Loss Evolution" chart ---
# One entry per demo step, recorded only while a demo is running. Slider values
# on the idle page do not belong to any scenario step, and recording them would
# leave a misleading entry in slot 0. An existing entry for the current step is
# overwritten, so it always reflects the latest ground-truth choice.
if st.session_state.running_demo:
    step = st.session_state.demo_step
    history_entry = {
        # x-position: the token carrying the most mass in this scenario step.
        "token_index": int(
            np.argmax(st.session_state.active_scenarios[step]["values"])
        ),
        "CE": ce_val,
        "NTL-MSE": mse_val if mse_val != "N/A" else None,
        "NTL-WAS": was_val if was_val != "N/A" else None,
    }
    if step < len(st.session_state.loss_history):
        st.session_state.loss_history[step] = history_entry
    else:
        st.session_state.loss_history.append(history_entry)
last_step = st.session_state.demo_step

# Long format (one row per step and loss type) for the grouped bar chart.
# Undefined losses (NTL with a "Text" ground truth) are left out. The explicit
# columns keep the frame, and the chart built from it, valid while the history
# is still empty.
df_loss_plot = pd.DataFrame(
    [
        {
            "Token Index": entry["token_index"],
            "Loss Type": loss_type,
            "Loss Value": entry[loss_type],  # TODO: clip to MAX_LOSS_PLOT?
        }
        for entry in st.session_state.loss_history
        for loss_type in ("CE", "NTL-MSE", "NTL-WAS")
        if entry[loss_type] is not None
    ],
    columns=["Token Index", "Loss Type", "Loss Value"],
)

# Current losses for the "Loss Comparison" chart; undefined ("N/A") NTL values
# are omitted.
loss_data = {"Loss": ["Cross Entropy"], "Value": [ce_val]}
if was_val != "N/A":
    loss_data["Loss"].append("NTL-WAS")
    loss_data["Value"].append(was_val)
if mse_val != "N/A":
    loss_data["Loss"].append("NTL-MSE")
    loss_data["Value"].append(mse_val)
loss_df = pd.DataFrame(loss_data)
# ============== Chart Display ==============
st.subheader("Loss Evolution Over Scenarios")
# Digit tokens 0-9 only. A "Text" entry (token index 10) falls outside this
# domain (see the TODO list at the top of the file).
x_domain = list(range(10))
grouped_chart = (
    alt.Chart(df_loss_plot)
    # clip=True keeps bars inside the plot. The y-domain is fixed to
    # [0, MAX_LOSS_PLOT], but CE reaches -log(1e-9) ≈ 20.7 and NTL-MSE up to
    # NTL_MSE_SCALING * 9**2 (40.5 at 0.5). Without clipping, those bars would be
    # drawn past the top of the chart.
    .mark_bar(clip=True)
    .encode(
        x=alt.X(
            "Token Index:O",
            title="Predicted Token Index",
            axis=alt.Axis(labelAngle=0),
            scale=alt.Scale(domain=x_domain),
        ),
        y=alt.Y(
            "Loss Value:Q", title="Loss", scale=alt.Scale(domain=[0, MAX_LOSS_PLOT])
        ),
        color=alt.Color("Loss Type:N", legend=alt.Legend(title="Loss")),
        xOffset="Loss Type:N",  # groups the loss types side by side instead of stacking
    )
    .properties(height=300)
)
st.altair_chart(grouped_chart, use_container_width=True)
# --- Loss Comparison: bar chart of the losses for the current distribution ---
st.subheader("Loss Comparison")
st.markdown("""
Adjust the sliders to set a predicted probability for each token (0-9 and "Text").
The sliders are vertical and compact. The app normalizes the slider values
to form a valid probability distribution, visualizes it, and computes the corresponding
Cross Entropy, NTL-MSE, and NTL-WAS losses.
""")
# The y-axis leaves 20% head-room above the largest loss, with a floor of 20
# while a demo is running and 0.5 otherwise.
loss_order = loss_df["Loss"].tolist()
y_floor = 20 if st.session_state.running_demo else 0.5
y_upper = max(loss_df["Value"].max() * 1.2, y_floor)
loss_colors = alt.Scale(
    domain=["Cross Entropy", "NTL-WAS", "NTL-MSE"],
    range=["steelblue", "red", "forestgreen"],
)
# Built here; value labels are added and the chart is drawn after the sliders.
chart = (
    alt.Chart(loss_df)
    .mark_bar()
    .encode(
        x=alt.X("Loss:N", sort=loss_order),
        y=alt.Y("Value:Q", scale=alt.Scale(domain=[0, y_upper])),
        color=alt.Color("Loss:N", scale=loss_colors),
        tooltip=["Loss", "Value"],
    )
    .properties(height=300)
)
# --- Manual probability sliders (hidden while a demo is running) ---
# One vertical slider per token, registered under the key "slider_<i>". These
# are the same keys that apply_scenario() writes and the charts above read.
if not st.session_state.running_demo:
    st.markdown("#### Predicted Token Probabilities")
    slider_columns = st.columns(len(options))
    for token_idx, (slider_column, token_label) in enumerate(
        zip(slider_columns, options)
    ):
        with slider_column:
            svs.vertical_slider(
                label=token_label,
                min_value=0.0,
                max_value=1.0,
                step=0.01,
                height=50,
                key=f"slider_{token_idx}",
                slider_color="green",
                track_color="lightgray",
                thumb_color="black",
            )
# Print each loss (3 decimals) just above its bar, then draw the Loss Comparison
# chart at full container width.
value_labels = chart.mark_text(
    align="center", baseline="bottom", dy=-5, fontSize=14
).encode(text=alt.Text("Value:Q", format=".3f"))
final_chart = chart + value_labels
st.altair_chart(final_chart, use_container_width=True)
# Add explanation of the demonstration
st.markdown("""
### What Does This Demo Show?
- **Cross Entropy Loss**: Only cares if the prediction is exactly right or wrong - it doesn't consider how "close" a numerical prediction is.
- **Number Token Loss (NTL)**: Considers numerical proximity - predicting "7" when the true value is "8" is better than predicting "2".
""")
# References / resources section with links (common to both modes)
st.markdown("### Resources")
st.markdown("""
- [Paper: Number Token Loss (ArXiv)](https://arxiv.org/abs/2411.02083)
- [GitHub: Number Token Loss](https://github.com/tum-ai/number-token-loss)
""")

# --- Polling Rerun for Demo Mode ---
# This must stay the last statement of the script. st.rerun() abandons the
# current run, so anything placed after it is never rendered while a demo is
# running. The advancement block at the top only reruns when a step is due.
# When execution reaches this point during a demo, the step was not due yet, so
# wait briefly and rerun to keep polling the timer.
if st.session_state.running_demo:
    time.sleep(0.1)
    st.rerun()