Spaces:
Running
Running
import time | |
import altair as alt | |
import numpy as np | |
import pandas as pd | |
import streamlit as st | |
import streamlit_vertical_slider as svs | |
import torch | |
from scenarios import bimodal, dirac, gauss | |
# Seconds an automated demo scenario stays on screen before the next one is shown.
DEMO_INTERVAL: float = 1.5
# Factor applied to the squared error in the NTL-MSE loss.
NTL_MSE_SCALING: float = 0.5
# Upper bound of the y-axis of the "Loss Evolution Over Scenarios" chart.
MAX_LOSS_PLOT: int = 15
# Not referenced in the visible part of this file.
LAST_STEP: int = -1

# TODO:
# - Remove flickering of loss evolution scenario plot (lower ylim?)
# - Move manual part down (predicted token probabilities)
# - Allow to set GT token for each demo
# - Add text token to loss evolution barplot
# - pick good default (4?)
# Token vocabulary shown in the UI: the ten digits followed by a "Text" token.
# Defined globally because it is used both for state initialisation and widgets.
options = [str(i) for i in range(10)] + ["Text"]

# --- Session State Initialization ---
# Each key is seeded only when it is missing, so values survive Streamlit reruns
# and exist before any widget that reads them is created.
_session_defaults = {
    "running_demo": False,
    "demo_step": 0,
    "last_update_time": 0,
    "loss_container": None,
    "previous_chart_html": "",
    # Scenario list loaded on first show.
    "active_scenarios": dirac,
    "loss_history": [],
}
# One slider state per token (0-9 + "Text"), starting as a uniform distribution.
_session_defaults.update(
    {f"slider_{idx}": 1.0 / len(options) for idx in range(len(options))}
)
_session_defaults["ground_truth"] = options[0]  # Default to "0"

for _key, _default in _session_defaults.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
# Page title followed by the static usage instructions of the demo.
st.title("Number Token Loss - Demo")
_INSTRUCTIONS_MD = """
**Instructions**
1. **Pick a ground truth token (0–9).**
2. **Select one of the three automated demos:**
- **Dirac**: a one-hot (Dirac) distribution whose single 1.0 mass moves from token 0 all the way to “Text.”
- **Gaussian**: a peaked Gaussian (0.6 mass at center, 0.4 spread) that slides its center from token 0 to “Text.”
- **Bimodal**: two equal peaks (0.5 each) that start at (0,8) and then move symmetrically away from the GT token.
"""
st.markdown(_INSTRUCTIONS_MD)
# Ground-truth token selector.
# Its value lives in st.session_state["ground_truth"], which is seeded (to "0")
# during session-state initialisation. The widget is therefore driven purely by
# its key. The old fallback to "4" here could never run. An explicit `index` is
# redundant: Streamlit may warn when a widget gets both a default value and a
# value set through the Session State API.
gt = st.selectbox(
    "Ground Truth Token",
    options=options,
    key="ground_truth",
)
def apply_scenario(step_idx):
    """Copy the slider values of scenario ``step_idx`` into session state.

    Reads ``st.session_state.active_scenarios[step_idx]["values"]`` and stores its
    i-th entry in ``st.session_state["slider_{i}"]``. Streamlit does not allow
    changing a widget's session-state value after that widget was created in the
    same run, so this has to be called before the sliders are rendered.
    """
    scenario = st.session_state.active_scenarios[step_idx]
    for idx, value in enumerate(scenario["values"]):
        st.session_state[f"slider_{idx}"] = value


def _start_demo(scenarios):
    """Make ``scenarios`` the running demo and show its first step.

    The loss history is cleared so that the "Loss Evolution" chart only contains
    entries recorded for this demo run. Stale entries from an earlier demo would
    otherwise remain and prevent new ones from being recorded.
    """
    st.session_state.active_scenarios = scenarios
    st.session_state.running_demo = True
    st.session_state.demo_step = 0
    st.session_state.loss_history = []
    st.session_state.last_update_time = time.time()
    apply_scenario(0)


def start_dirac_demo():
    """Start the automated Dirac (one-hot) demo."""
    _start_demo(dirac)


def start_gauss_demo():
    """Start the automated Gaussian demo."""
    _start_demo(gauss)


def start_bimodal_demo():
    """Start the automated bimodal demo."""
    _start_demo(bimodal)


def stop_demo():
    """Stop the automated demo; slider states keep their last values."""
    st.session_state.running_demo = False
# --- Demo State Advancement Logic ---
# While a demo runs, switch to the next scenario (wrapping around) once more than
# DEMO_INTERVAL seconds have passed, then rerun immediately. The widgets and charts
# of the next run are therefore drawn from the updated session state.
_demo_step_due = st.session_state.running_demo and (
    time.time() - st.session_state.last_update_time > DEMO_INTERVAL
)
if _demo_step_due:
    _scenario_count = len(st.session_state.active_scenarios)
    st.session_state.demo_step = (st.session_state.demo_step + 1) % _scenario_count
    apply_scenario(st.session_state.demo_step)  # push the new values into state
    st.session_state.last_update_time = time.time()  # restart the interval timer
    st.rerun()
# --- UI Rendering ---
# Runs after any advancement rerun above. During a demo it shows the current
# scenario and a stop button; otherwise it shows one launch button per demo.
if st.session_state.running_demo:
    _scenarios = st.session_state.active_scenarios
    _step = st.session_state.demo_step
    st.info(
        f"Showing scenario {_step + 1}"
        f"/{len(_scenarios)}: "
        f"{_scenarios[_step]['name']}"
    )
    if st.button("Stop Demo"):
        stop_demo()
        st.rerun()
else:
    _demo_launchers = (
        ("Run: Dirac", start_dirac_demo),
        ("Run: Gauss", start_gauss_demo),
        ("Run: Bimodal", start_bimodal_demo),
    )
    for _column, (_button_label, _launch) in zip(
        st.columns(len(_demo_launchers)), _demo_launchers
    ):
        with _column:
            if st.button(_button_label):
                _launch()
                st.rerun()
# Placeholder for charts and loss calculations that will be updated.
# This section always reads the current st.session_state to generate its content.
current_prob_values_from_state = [
    st.session_state.get(f"slider_{j}", 1.0 / len(options)) for j in range(len(options))
]
total_from_state = sum(current_prob_values_from_state)
# Normalise the raw slider values into a probability distribution. Fall back to
# a uniform distribution when every slider is at zero.
probs_for_charts = (
    torch.ones(len(options)) / len(options)
    if total_from_state == 0
    else torch.tensor([v / total_from_state for v in current_prob_values_from_state])
)
gt_choice_for_charts = st.session_state.get("ground_truth", options[0])
# Position of the ground-truth token in `options`. It is looked up instead of
# hard-coding that "Text" sits at index 10.
gt_index_for_charts = options.index(gt_choice_for_charts)
# Numeric value of the ground truth, or None for the non-numeric "Text" token.
gt_numeric_for_charts = None if gt_choice_for_charts == "Text" else gt_index_for_charts
gt = st.session_state["ground_truth"]
st.markdown(f"#### Predicted Probability Distribution — Ground truth token {gt}")
# Pixel height shared by all layers of the distribution chart.
dist_chart_height = 300
df_dist = pd.DataFrame(
    {"token": options, "probability": probs_for_charts.numpy().round(2)}
)
# Tag every token so the ground-truth bar gets its own colour.
df_dist["type"] = [
    "Ground Truth" if token == gt_choice_for_charts else "Prediction"
    for token in options
]
# 1) Light-gray column behind the ground-truth token.
#    `alt.value` positions are pixels measured from the top of the plot, so the
#    column spans 0..dist_chart_height. The old y2=alt.value(1) drew a band only
#    1 px tall. There is no x2 encoding: with x2 on the same band field both
#    edges land on one point. Without it, the bar is `size` px wide in the band.
bg = (
    alt.Chart(pd.DataFrame({"token": [gt]}))
    .mark_bar(size=40, color="lightgray", opacity=0.4)
    .encode(
        x=alt.X("token:N", sort=options),
        y=alt.value(0),  # top edge of the plot area
        y2=alt.value(dist_chart_height),  # bottom edge of the plot area
    )
)
# 2) Predicted probability per token; the ground-truth token is drawn in green.
bars = (
    alt.Chart(df_dist)
    .mark_bar()
    .encode(
        x=alt.X(
            "token:N",
            title="Token",
            sort=options,
            axis=alt.Axis(labelAngle=0, labelFontSize=14, titleFontSize=16),
        ),
        y=alt.Y(
            "probability:Q",
            title="Probability",
            scale=alt.Scale(domain=[0, 1]),
            axis=alt.Axis(format=".2f", labelFontSize=14, titleFontSize=16),
        ),
        color=alt.Color(
            "type:N",
            scale=alt.Scale(
                domain=["Ground Truth", "Prediction"], range=["green", "steelblue"]
            ),
            legend=alt.Legend(title="Token Type", titleFontSize=16, labelFontSize=14),
        ),
        tooltip=[
            alt.Tooltip("token:N", title="Token"),
            alt.Tooltip("probability:Q", title="Probability", format=".2f"),
            alt.Tooltip("type:N", title="Type"),
        ],
    )
    .properties(height=dist_chart_height)
)
# 3) Two-line label above the plot at the ground-truth token.
#    First line: "⬇ Ground", 25 px above the top edge.
annot1 = (
    alt.Chart(pd.DataFrame({"token": [gt]}))
    .mark_text(
        text="⬇ Ground",
        dy=-25,  # 25 px up: the upper of the two lines
        dx=25,
        fontSize=14,
        fontWeight="bold",
        color="green",
    )
    .encode(x=alt.X("token:N", sort=options), y=alt.value(1))
)
# Second line: "truth=<gt>", 10 px above the top edge (just below line 1).
annot2 = (
    alt.Chart(pd.DataFrame({"token": [gt]}))
    .mark_text(
        text=f"truth={gt}",
        dy=-10,  # 10 px up: the lower of the two lines
        dx=35,
        fontSize=14,
        fontWeight="bold",
        color="green",
    )
    .encode(x=alt.X("token:N", sort=options), y=alt.value(1))
)
# 4) Layer them in order: background, bars, annotations.
final_chart = (bg + bars + annot1 + annot2).properties(height=dist_chart_height)
st.altair_chart(final_chart, use_container_width=True)
# Cross entropy of the ground-truth token. The clamp keeps -log finite when the
# predicted probability is 0 (max value: -ln(1e-9) ≈ 20.7).
ce_loss = -torch.log(torch.clamp(probs_for_charts[gt_index_for_charts], min=1e-9))

if gt_numeric_for_charts is None:
    # Text token: NTL-MSE and NTL-WAS are not applicable.
    ntl_mse_loss = torch.tensor(float("nan"))
    ntl_was_loss = torch.tensor(float("nan"))
else:
    # NTL only considers the numeric tokens "0".."9" (every option except the
    # trailing "Text"). Their probabilities are renormalised to sum to 1.
    num_numeric = len(options) - 1
    token_values = torch.arange(0, num_numeric, dtype=torch.float32)
    numeric_probs = probs_for_charts[:num_numeric]
    numeric_mass = torch.sum(numeric_probs)
    gt_value = float(gt_numeric_for_charts)
    if numeric_mass > 1e-6:
        renormalised = numeric_probs / numeric_mass
        # Expected numeric value under the renormalised distribution.
        pred_value = torch.sum(renormalised * token_values)
        # Wasserstein-1 distance to a point mass at the ground-truth value.
        ntl_was_loss = torch.sum(renormalised * torch.abs(token_values - gt_value))
    else:
        # (Almost) all mass is on "Text", so the expected value is ill-defined.
        # By convention the prediction is taken as 0 and the WAS loss as 0.
        # NOTE(review): a WAS of 0 reads as a perfect prediction; consider "N/A".
        pred_value = torch.tensor(0.0)
        ntl_was_loss = torch.tensor(0.0)
    ntl_mse_loss = NTL_MSE_SCALING * (pred_value - gt_value) ** 2

# Rounded values for display; NaN losses are shown as "N/A".
ce_val = round(ce_loss.item(), 3)
mse_val = round(ntl_mse_loss.item(), 3) if not torch.isnan(ntl_mse_loss) else "N/A"
was_val = round(ntl_was_loss.item(), 3) if not torch.isnan(ntl_was_loss) else "N/A"
# Record one loss entry per demo step, and only while an automated demo runs.
# In manual mode the sliders do not correspond to a scenario step. Recording
# there stored losses of the slider distribution under the token index of the
# demo's first scenario, and that entry then blocked the real step 0 entry.
if (
    st.session_state.running_demo
    and len(st.session_state.loss_history) < st.session_state.demo_step + 1
):
    _step_values = st.session_state.active_scenarios[st.session_state.demo_step][
        "values"
    ]
    st.session_state.loss_history.append(
        {
            # Position of the largest scenario probability (first one on ties).
            "token_index": int(np.argmax(_step_values)),
            "CE": ce_val,
            "NTL-MSE": mse_val if mse_val != "N/A" else None,
            "NTL-WAS": was_val if was_val != "N/A" else None,
        }
    )
last_step = st.session_state.demo_step

# Long-format rows (one per recorded step and available loss type) for the grouped
# bar chart. The frame is always defined, with its columns, so the chart code
# also works while the history is still empty.
loss_plot_data = [
    {
        "Token Index": entry["token_index"],
        "Loss Type": loss_type,
        "Loss Value": entry[loss_type],
    }
    for entry in st.session_state.loss_history
    for loss_type in ("CE", "NTL-MSE", "NTL-WAS")
    if entry[loss_type] is not None
]
df_loss_plot = pd.DataFrame(
    loss_plot_data, columns=["Token Index", "Loss Type", "Loss Value"]
)
# Current losses for the comparison chart. Cross Entropy is always listed; the NTL
# variants are listed only when defined (they are "N/A" for the "Text" token).
_loss_rows = [
    (loss_name, loss_value)
    for loss_name, loss_value in (
        ("Cross Entropy", ce_val),
        ("NTL-WAS", was_val),
        ("NTL-MSE", mse_val),
    )
    if loss_name == "Cross Entropy" or loss_value != "N/A"
]
loss_df = pd.DataFrame(_loss_rows, columns=["Loss", "Value"])
# ============== Chart Display ==============
st.subheader("Loss Evolution Over Scenarios")
# Every token position is in the domain, including "Text" (the last option). Demo
# steps peaking on the text token are therefore plotted instead of falling outside
# the axis. That position is labelled "Text".
x_domain = list(range(len(options)))
text_index = len(options) - 1
grouped_chart = (
    alt.Chart(df_loss_plot)
    # clip=True: losses above MAX_LOSS_PLOT would otherwise be drawn past the top
    # of the plot area (CE reaches ~20.7 at the 1e-9 probability clamp).
    .mark_bar(clip=True)
    .encode(
        x=alt.X(
            "Token Index:O",
            title="Predicted Token Index",
            axis=alt.Axis(
                labelAngle=0,
                labelExpr=f"datum.value == {text_index} ? 'Text' : datum.label",
            ),
            scale=alt.Scale(domain=x_domain),
        ),
        y=alt.Y(
            "Loss Value:Q", title="Loss", scale=alt.Scale(domain=[0, MAX_LOSS_PLOT])
        ),
        color=alt.Color("Loss Type:N", legend=alt.Legend(title="Loss")),
        xOffset="Loss Type:N",  # group the loss types side by side instead of stacking
    )
    .properties(height=300)
)
st.altair_chart(grouped_chart, use_container_width=True)
# Section introducing the per-distribution loss comparison.
st.subheader("Loss Comparison")
st.markdown("""
Adjust the sliders to set a predicted probability for each token (0-9 and "Text").
The sliders are vertical and compact. The app normalizes the slider values
to form a valid probability distribution, visualizes it, and computes the corresponding
Cross Entropy, NTL-MSE, and NTL-WAS losses.
""")
# Upper y-limit: 20% headroom above the largest loss, but at least a floor value
# (20 while a demo runs, 0.5 otherwise).
_y_floor = 20 if st.session_state.running_demo else 0.5
_y_top = max(loss_df["Value"].max() * 1.2, _y_floor)
# Fixed colour per loss type, independent of which losses are currently present.
_loss_palette = alt.Scale(
    domain=["Cross Entropy", "NTL-WAS", "NTL-MSE"],
    range=["steelblue", "red", "forestgreen"],
)
# Bar chart of the current losses; value labels are layered on further below.
chart = (
    alt.Chart(loss_df)
    .mark_bar()
    .encode(
        x=alt.X("Loss:N", sort=loss_df["Loss"].tolist()),
        y=alt.Y("Value:Q", scale=alt.Scale(domain=[0, _y_top])),
        color=alt.Color("Loss:N", scale=_loss_palette),
        tooltip=["Loss", "Value"],
    )
    .properties(height=300)
)
# Manual mode only: one vertical slider per token.
# Each slider is keyed as "slider_{i}", linking it to st.session_state["slider_{i}"].
# This relies on the component reading and writing its state through that key;
# the return value is not used.
if not st.session_state.running_demo:
    st.markdown("#### Predicted Token Probabilities")
    for slider_idx, (token_label, slider_col) in enumerate(
        zip(options, st.columns(len(options)))
    ):
        with slider_col:
            svs.vertical_slider(
                label=token_label,
                min_value=0.0,
                max_value=1.0,
                step=0.01,
                height=50,
                key=f"slider_{slider_idx}",
                slider_color="green",
                track_color="lightgray",
                thumb_color="black",
            )
# Print each bar's value (3 decimals) just above the bar.
value_labels = chart.mark_text(
    align="center", baseline="bottom", dy=-5, fontSize=14
).encode(text=alt.Text("Value:Q", format=".3f"))
# Layer bars and labels, then render at the full container width.
final_chart = chart + value_labels
st.altair_chart(final_chart, use_container_width=True)
# Explanation of the demonstration. It is rendered BEFORE the polling rerun below
# because st.rerun() ends the current script run. Previously, this section and
# the resources were never drawn while a demo was running.
st.markdown("""
### What Does This Demo Show?
- **Cross Entropy Loss**: Only cares if the prediction is exactly right or wrong - it doesn't consider how "close" a numerical prediction is.
- **Number Token Loss (NTL)**: Considers numerical proximity - predicting "7" when the true value is "8" is better than predicting "2".
""")
# References / resources section with links (common to both modes)
st.markdown("### Resources")
st.markdown("""
- [Paper: Number Token Loss (ArXiv)](https://arxiv.org/abs/2411.02083)
- [GitHub: Number Token Loss](https://github.com/tum-ai/number-token-loss)
""")

# --- Polling Rerun for Demo Mode ---
# Reaching this point during a demo means the time-based advance at the top did
# not fire (that path reruns immediately). Sleep briefly, then rerun to keep
# polling. This stays the last statement so the whole page is drawn first.
if st.session_state.running_demo:
    time.sleep(0.1)
    st.rerun()