Spaces:
Running
Running
import altair as alt | |
import pandas as pd | |
import streamlit_vertical_slider as svs | |
import torch | |
# from streamlit_vertical_slider import vertical_slider # Not directly used, svs.vertical_slider is | |
import streamlit as st | |
import time | |
import plotly.graph_objects as go # Add Plotly import | |
# Token vocabulary for the demo: the ten digit tokens "0".."9" followed by a
# single non-numeric "Text" token. Kept at module level because the session
# state setup, the widgets and the loss computation all index into it.
options = [*map(str, range(10)), "Text"]
# --- Session State Initialization ---
# Seed every session-state key before any widget reads it. Keys that already
# exist are left alone, so slider edits and demo progress survive reruns.
_uniform_share = 1.0 / len(options)
_state_defaults = {
    "running_demo": False,
    "demo_step": 0,
    "last_update_time": 0,
    # NOTE(review): the next two keys are not read anywhere in the code visible
    # in this file; kept in case other code relies on them.
    "loss_container": None,
    "previous_chart_html": "",
    # One slider per token (0-9 plus "Text"), starting from a uniform split.
    **{f"slider_{slot}": _uniform_share for slot in range(len(options))},
    # The ground-truth selector starts on the first token, "0".
    "ground_truth": options[0],
}
for _state_key, _default_value in _state_defaults.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _default_value
# Page header plus a short description of what the app computes.
_INTRO_TEXT = """
Adjust the sliders to set a predicted probability for each token (0-9 and "Text").
The sliders are vertical and compact. The app normalizes the slider values
to form a valid probability distribution, visualizes it, and computes the corresponding
Cross Entropy, NTL-MSE, and NTL-WAS losses.
"""
st.title("Number Token Loss - Demo")
st.markdown(_INTRO_TEXT)
# --- Scenario Definitions ---
# Each automated-demo step pairs one predicted distribution with one
# ground-truth digit. A distribution is given as 11 raw slider values for the
# tokens "0".."9" and "Text" (the app renormalises them before computing
# losses). Every distribution is replayed against several ground truths so
# the viewer can watch CE and NTL react as the target moves.
_SCENARIO_DISTRIBUTIONS = [
    # (name, raw slider values, ground-truth digits to replay against)
    (
        "Probability mass at 0",
        [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0],
        range(10),
    ),
    (
        "Probability mass around 5",
        [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0],
        range(10),
    ),
    (
        "Probability mass concentrated on 5",
        [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0],
        range(10),
    ),
    (
        "Probability mass concentrated on 1",
        [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0],
        range(10),
    ),
    (
        "Probability mass concentrated on 2",
        [0.1, 0.1, 0.7, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        range(4),
    ),
]


def _scenario_explanation(values, ground_truth):
    """Return a caption for one demo step that agrees with its actual numbers.

    Hand-written captions drifted out of sync with the data (e.g. claiming a
    prediction was correct when the mass sat on a different digit), so the
    caption is derived from where the distribution peaks relative to the
    ground truth.

    Args:
        values: Raw slider values for tokens "0".."9" followed by "Text".
        ground_truth: The ground-truth digit as an int in 0..9.

    Returns:
        A sentence contrasting what Cross Entropy and NTL-WAS respond to.
    """
    digit_values = values[:10]
    # First index holding the largest value; every built-in distribution has a single peak.
    peak = digit_values.index(max(digit_values))
    total = sum(values)
    # Quote the renormalised share, because that is what the chart and losses use.
    share = values[ground_truth] / total if total else 0.0
    distance = abs(ground_truth - peak)
    if distance == 0:
        return (
            f"The ground truth {ground_truth} is the most likely token, so Cross "
            f"Entropy is at its lowest for this distribution. NTL-WAS still "
            f"penalizes the probability spread onto other digits in proportion "
            f"to their distance from {ground_truth}."
        )
    if distance == 1:
        return (
            f"Near miss: the prediction peaks at {peak}, one step from the ground "
            f"truth {ground_truth}. Cross Entropy only sees the {share:.0%} "
            f"assigned to {ground_truth} and treats this like any other miss, "
            f"whereas NTL-WAS charges the peak's probability a distance of just 1."
        )
    return (
        f"The prediction peaks at {peak}, {distance} steps away from the ground "
        f"truth {ground_truth}. Cross Entropy only sees the {share:.0%} assigned "
        f"to {ground_truth}, so it ignores that distance, while NTL-WAS weighs "
        f"every digit's probability by its distance to {ground_truth} and "
        f"therefore grows as the mass sits further away."
    )


# Flatten into the list of step dicts the demo iterates over. Each step gets
# its own copy of the values list so no two steps share a mutable object.
scenarios = [
    {
        "name": name,
        "values": list(values),
        "ground_truth": str(gt),
        "explanation": _scenario_explanation(values, gt),
    }
    for name, values, ground_truths in _SCENARIO_DISTRIBUTIONS
    for gt in ground_truths
]
# --- Helper Functions --- | |
def apply_scenario(step_idx):
    """Load scenario ``step_idx`` into the session state backing the widgets.

    Writes every slider value (``slider_0`` ... ``slider_N``) and the
    ``ground_truth`` choice. Call it before the widgets are rendered in the run
    that should display the new values.

    Args:
        step_idx: Index into the module-level ``scenarios`` list.
    """
    chosen = scenarios[step_idx]
    state = st.session_state
    for slot, prob in enumerate(chosen["values"]):
        state[f"slider_{slot}"] = prob
    state['ground_truth'] = chosen["ground_truth"]
def start_demo():
    """Switch into demo mode at scenario 0 and start the step timer.

    The caller is expected to trigger a rerun afterwards so the page is redrawn
    in demo mode.
    """
    state = st.session_state
    state.running_demo, state.demo_step = True, 0
    state.last_update_time = time.time()
    # Populate the slider / ground-truth state with the first scenario.
    apply_scenario(0)
def stop_demo():
    """Leave demo mode; slider and ground-truth state are left as they are."""
    st.session_state["running_demo"] = False
# --- Demo State Advancement Logic ---
# While the demo runs, move to the next scenario once the current one has been
# shown for more than 3 seconds. Session state is updated first and the script
# reruns straight away, so the widgets and charts below are drawn from the new
# scenario in the following run.
_demo_step_due = (
    st.session_state.running_demo
    and time.time() - st.session_state.last_update_time > 3.0  # 3 seconds per scenario
)
if _demo_step_due:
    # Wrap back to the first scenario after the last one.
    st.session_state.demo_step = (st.session_state.demo_step + 1) % len(scenarios)
    apply_scenario(st.session_state.demo_step)
    st.session_state.last_update_time = time.time()  # restart the timer
    st.rerun()
# --- UI Rendering ---
# Demo controls. Outside demo mode offer a Start button; in demo mode show the
# current scenario with its explanation and a Stop button. Both buttons rerun
# immediately so the rest of the page is drawn in the new mode.
if not st.session_state.running_demo:
    if st.button("Start Automated Demo"):
        start_demo()  # also loads scenario 0 via apply_scenario(0)
        st.rerun()
else:
    _demo_idx = st.session_state.demo_step
    _demo_scenario = scenarios[_demo_idx]
    st.info(f"Showing scenario {_demo_idx + 1}/{len(scenarios)}: {_demo_scenario['name']}")
    st.markdown(f"**Explanation:** {_demo_scenario['explanation']}")
    if st.button("Stop Demo"):
        stop_demo()
        st.rerun()
# Sliders and Ground Truth Selector (hidden while the demo is running).
# Each widget is bound to st.session_state through its ``key``: it reads its
# value from session state and writes user changes back there.
if not st.session_state.running_demo:
    st.markdown("#### Predicted Token Probabilities")
    # One narrow column per token so the vertical sliders sit side by side.
    for i, (col, label) in enumerate(zip(st.columns(len(options)), options)):
        with col:
            # NOTE(review): no default_value is passed, so the thumb's initial
            # position may not mirror st.session_state[f"slider_{i}"]; confirm
            # against the streamlit_vertical_slider API.
            svs.vertical_slider(
                label=label, min_value=0.0, max_value=1.0, step=0.01, height=50,
                key=f"slider_{i}",  # links the widget to st.session_state[f"slider_{i}"]
                slider_color="green", track_color="lightgray", thumb_color="black",
            )
    # The key alone binds the selectbox to st.session_state['ground_truth'],
    # which is always seeded before this point. Also passing index= is
    # redundant and can trigger Streamlit's warning about a widget created
    # with a default value whose value was also set via the Session State API.
    st.selectbox("Ground Truth Token", options=options, key='ground_truth')
# Placeholder for charts and loss calculations that will be updated.
# Everything below is recomputed from st.session_state on every run.
current_prob_values_from_state = [
    st.session_state.get(f"slider_{j}", 1.0 / len(options)) for j in range(len(options))
]
total_from_state = sum(current_prob_values_from_state)
# Normalise the raw slider values into a probability distribution. If every
# slider is zero there is nothing to normalise, so fall back to uniform.
probs_for_charts = (
    torch.ones(len(options)) / len(options)
    if total_from_state == 0
    else torch.tensor([v / total_from_state for v in current_prob_values_from_state])
)
gt_choice_for_charts = st.session_state.get('ground_truth', options[0])
# Position of the ground-truth token in the probability vector. It is looked
# up in `options` instead of hard-coding index 10 for "Text".
gt_index_for_charts = options.index(gt_choice_for_charts)
# Numeric value of the ground truth, or None for the non-numeric "Text" token.
gt_numeric_for_charts = None if gt_choice_for_charts == "Text" else int(gt_choice_for_charts)

st.markdown("#### Input Probability Distribution")
df_dist = pd.DataFrame({"token": options, "probability": probs_for_charts.numpy()})
# Colour the ground-truth bar differently from the predicted-only bars.
df_dist["type"] = ["Ground Truth" if token == gt_choice_for_charts else "Prediction" for token in options]
chart = (
    alt.Chart(df_dist).mark_bar().encode(
        x=alt.X("token:N", title="Token", sort=options),  # keep the token order fixed
        y=alt.Y("probability:Q", title="Probability", scale=alt.Scale(domain=[0, 1])),
        color=alt.Color(
            "type:N",
            scale=alt.Scale(domain=["Ground Truth", "Prediction"], range=["green", "steelblue"]),
            legend=alt.Legend(title="Token Type"),
        ),
    ).properties(height=300)
)
st.altair_chart(chart, use_container_width=True)
# --- Loss computation ---
# Cross Entropy: negative log-probability of the ground-truth token. The clamp
# keeps a zero probability at a large finite loss instead of infinity.
ce_loss = -torch.log(torch.clamp(probs_for_charts[gt_index_for_charts], min=1e-9))

# Number Token Loss only applies to the digit tokens (every option except
# "Text"). NaN marks "not applicable" and is shown as N/A further down.
ntl_mse_loss = torch.tensor(float('nan'))
ntl_was_loss = torch.tensor(float('nan'))
if gt_numeric_for_charts is not None:
    digit_values = torch.tensor([float(tok) for tok in options if tok != "Text"])
    # The digits come first in `options`, so their probabilities are the prefix.
    digit_probs = probs_for_charts[: len(digit_values)]
    digit_mass = torch.sum(digit_probs)
    # With (almost) no mass on the digits there is no numeric prediction to
    # score, so both NTL variants stay N/A instead of reporting made-up values.
    if digit_mass > 1e-6:
        digit_probs = digit_probs / digit_mass  # renormalise over the digits only
        gt_value = float(gt_numeric_for_charts)
        # NTL-MSE: squared error between the expected digit and the ground truth.
        pred_value = torch.sum(digit_probs * digit_values)
        ntl_mse_loss = (pred_value - gt_value) ** 2
        # NTL-WAS: each digit's probability weighted by its distance from the ground truth.
        ntl_was_loss = torch.sum(digit_probs * torch.abs(digit_values - gt_value))

ce_val = round(ce_loss.item(), 3)
mse_val = round(ntl_mse_loss.item(), 3) if not torch.isnan(ntl_mse_loss) else "N/A"
was_val = round(ntl_was_loss.item(), 3) if not torch.isnan(ntl_was_loss) else "N/A"

# Table feeding the loss chart: Cross Entropy always, then each applicable
# NTL variant in the order WAS, MSE.
loss_data = {"Loss": ["Cross Entropy"], "Value": [ce_val]}
for loss_name, loss_value in (("NTL-WAS", was_val), ("NTL-MSE", mse_val)):
    if loss_value != "N/A":
        loss_data["Loss"].append(loss_name)
        loss_data["Value"].append(loss_value)
loss_df = pd.DataFrame(loss_data)
# ============== Chart Display ==============
# Bar chart of the losses in `loss_df`, each bar labelled with its value.
st.subheader("Loss Comparison")
# Y-axis: 20% headroom above the tallest bar, with a floor of 20 while the demo
# runs (presumably to keep the scale steady between scenarios) and 0.5 otherwise.
y_upper = max(loss_df["Value"].max() * 1.2, 20 if st.session_state.running_demo else 0.5)
loss_colors = alt.Color(
    'Loss:N',
    scale=alt.Scale(
        domain=['Cross Entropy', 'NTL-WAS', 'NTL-MSE'],
        range=['steelblue', 'red', 'forestgreen'],
    ),
)
loss_bars = (
    alt.Chart(loss_df)
    .mark_bar()
    .encode(
        x=alt.X('Loss:N', sort=loss_df["Loss"].tolist()),
        y=alt.Y('Value:Q', scale=alt.Scale(domain=[0, y_upper])),
        color=loss_colors,
        tooltip=['Loss', 'Value'],
    )
    .properties(height=300)
)
# Value labels drawn just above each bar, reusing the bar encodings.
loss_labels = loss_bars.mark_text(align='center', baseline='bottom', dy=-5, fontSize=14).encode(
    text=alt.Text('Value:Q', format='.3f')
)
st.altair_chart(loss_bars + loss_labels, use_container_width=True)
# Add explanation of the demonstration
st.markdown("""
### What Does This Demo Show?
- **Cross Entropy Loss**: Only cares if the prediction is exactly right or wrong - it doesn't consider how "close" a numerical prediction is.
- **Number Token Loss (NTL)**: Considers numerical proximity - predicting "7" when the true value is "8" is better than predicting "2".
""")
# References / resources section with links (common to both modes)
st.markdown("### Resources")
st.markdown("""
- [Paper: Number Token Loss (ArXiv)](https://arxiv.org/abs/2411.02083)
- [GitHub: Number Token Loss](https://github.com/tum-ai/number-token-loss)
""")
# --- Polling Rerun for Demo Mode ---
# This must stay at the very end of the script: st.rerun() ends the current
# run, so anything placed after it (such as the sections above) would never be
# shown while the demo is running.
# Reaching this point in demo mode means the time-based advance check near the
# top did not fire in this run; pause briefly, then rerun to check it again.
if st.session_state.running_demo:
    time.sleep(0.1)  # polling interval in seconds
    st.rerun()