import time

import altair as alt
import pandas as pd
import plotly.graph_objects as go  # noqa: F401  NOTE(review): not used in the visible script
import streamlit as st
import streamlit_vertical_slider as svs
import torch

# Token vocabulary shown in the UI: the ten digit tokens "0".."9" followed by a "Text" token.
# Defined globally because both the session-state initialisation and the widgets use it.
options = [str(i) for i in range(10)] + ["Text"]

# --- Session State Initialization ---
# Every key must exist before a widget reads it, so missing keys are filled with defaults on
# each run; values that already exist (from user interaction or the demo) are left untouched.
_STATE_DEFAULTS = {
    "running_demo": False,
    "demo_step": 0,
    "last_update_time": 0,
    # NOTE(review): the next two keys are not read anywhere in the visible script.
    "loss_container": None,
    "previous_chart_html": "",
    # One slider per token (len(options) covers 0-9 plus "Text"), starting uniform.
    **{f"slider_{i}": 1.0 / len(options) for i in range(len(options))},
    "ground_truth": options[0],  # default ground truth is "0"
}
for _key, _default in _STATE_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

st.title("Number Token Loss - Demo")

st.markdown("""
Adjust the sliders to set a predicted probability for each token (0-9 and "Text").
The sliders are vertical and compact.
The app normalizes the slider values to form a valid probability distribution, visualizes it, and computes the corresponding Cross Entropy, NTL-MSE, and NTL-WAS losses.
""")

# --- Scenario Definitions ---
# Each entry is (name, values, ground truths). `values` holds one weight per token in `options`
# order (digits 0-9, then "Text"); the app normalises them before display. Each distribution is
# replayed against every listed ground truth, in order, so the demo shows how the losses change
# as the target moves relative to where the probability mass sits.
_BASE_DISTRIBUTIONS = [
    (
        "Probability mass at 0",
        [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0],
        range(10),
    ),
    (
        "Probability mass around 5",
        [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0],
        range(10),
    ),
    (
        "Probability mass concentrated on 5",
        [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0],
        range(10),
    ),
    (
        "Probability mass concentrated on 1",
        [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0],
        range(10),
    ),
    (
        # Previously labelled "Almost correct (1 vs 2)", which did not match the data:
        # the peak is at 2 and it is paired with ground truths 0-3.
        "Probability mass concentrated on 2",
        [0.1, 0.1, 0.7, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        range(4),
    ),
]


def _scenario_explanation(values, ground_truth):
    """Return an explanation of the losses for `values` against the digit `ground_truth`.

    The text is derived from the data itself (where the digit distribution peaks and how far
    that peak is from the ground truth), so it is accurate for every scenario it is attached to.
    Hand-written per-scenario texts previously contradicted the data in many cases.
    """
    peak = max(range(10), key=lambda digit: values[digit])
    if peak == ground_truth:
        return (
            f"The ground truth {ground_truth} is the most likely token, so Cross Entropy is comparatively low. "
            f"NTL additionally reflects how far the remaining probability mass lies from {ground_truth}."
        )
    distance = abs(peak - ground_truth)
    unit = "step" if distance == 1 else "steps"
    return (
        f"The predicted distribution peaks at {peak}, {distance} {unit} away from the ground truth {ground_truth}. "
        "Cross Entropy only depends on the probability assigned to the ground-truth token, "
        "whereas NTL grows with the numerical distance between the predicted digits and the ground truth."
    )


def _build_scenarios():
    """Expand `_BASE_DISTRIBUTIONS` into the flat list of scenario dicts used by the demo.

    Each dict has the keys "name", "values", "ground_truth" (a token string from `options`)
    and "explanation".

    Raises:
        ValueError: if a distribution does not provide exactly one value per token in `options`.
    """
    built = []
    for name, values, ground_truths in _BASE_DISTRIBUTIONS:
        if len(values) != len(options):
            raise ValueError(
                f"Scenario {name!r} has {len(values)} values, expected {len(options)}"
            )
        for gt in ground_truths:
            built.append({
                "name": name,
                "values": list(values),
                "ground_truth": str(gt),
                "explanation": _scenario_explanation(values, gt),
            })
    return built


scenarios = _build_scenarios()


# --- Helper Functions ---
def apply_scenario(step_idx):
    """Copy scenario `step_idx` into the slider and ground-truth session-state keys.

    These assignments modify session state, so they must happen *before* the widgets are
    rendered in the script run that should display the new values.
    """
    scenario = scenarios[step_idx]
    for i, val in enumerate(scenario["values"]):
        st.session_state[f"slider_{i}"] = val
    st.session_state['ground_truth'] = scenario["ground_truth"]


def start_demo():
    """Switch to demo mode, reset the timer and show the first scenario.

    The caller is responsible for triggering the rerun that displays it.
    """
    st.session_state.running_demo = True
    st.session_state.demo_step = 0
    st.session_state.last_update_time = time.time()
    apply_scenario(0)


def stop_demo():
    """Leave demo mode; slider and ground-truth values from the last scenario are kept."""
    st.session_state.running_demo = False


# --- Demo State Advancement Logic ---
# This block handles advancing the demo. If it advances, it updates session state
# and then reruns. This ensures widgets are drawn with the new state in the next run.
# Time-based advance: once the current scenario has been on screen for more than 3 seconds,
# step to the next one (wrapping around), push its values into session state, and rerun so the
# widgets and charts further down are drawn from the new state.
demo_state = st.session_state
if demo_state.running_demo and time.time() - demo_state.last_update_time > 3.0:
    demo_state.demo_step = (demo_state.demo_step + 1) % len(scenarios)
    apply_scenario(demo_state.demo_step)
    demo_state.last_update_time = time.time()  # restart the per-scenario timer
    st.rerun()

# --- UI Rendering ---
# Demo mode shows which scenario is active plus a "Stop Demo" button; manual mode shows a
# "Start Automated Demo" button. Either button switches mode and reruns straight away.
if demo_state.running_demo:
    active_scenario = scenarios[demo_state.demo_step]
    st.info(f"Showing scenario {demo_state.demo_step + 1}/{len(scenarios)}: {active_scenario['name']}")
    st.markdown(f"**Explanation:** {active_scenario['explanation']}")
    button_clicked, on_click = st.button("Stop Demo"), stop_demo
else:
    button_clicked, on_click = st.button("Start Automated Demo"), start_demo
if button_clicked:
    on_click()  # start_demo() also applies scenario 0 before the rerun
    st.rerun()

# Sliders and Ground Truth Selector
# These widgets will read their initial values from st.session_state.
# User interactions will update st.session_state directly due to their keys.
# Manual mode only: the keyed widgets below read their values from st.session_state and write
# user changes straight back into it. They are hidden while the automated demo is running.
if not st.session_state.running_demo:
    st.markdown("#### Predicted Token Probabilities")
    cols = st.columns(len(options))
    for i, col in enumerate(cols):
        with col:
            svs.vertical_slider(
                label=options[i],  # token name as label
                min_value=0.0,
                max_value=1.0,
                step=0.01,
                height=50,
                key=f"slider_{i}",  # links the widget to st.session_state[f"slider_{i}"]
                slider_color="green",
                track_color="lightgray",
                thumb_color="black",
            )

    # No `index=` here: the value already lives in st.session_state["ground_truth"], and also
    # passing a default makes Streamlit warn that the widget was given a default value while its
    # value was set via the Session State API.
    st.selectbox("Ground Truth Token", options=options, key="ground_truth")

# Everything below always reads the current st.session_state to build its content.
# Normalise the raw slider values into a probability distribution; an all-zero input falls back
# to the uniform distribution so charts and losses remain defined.
current_prob_values_from_state = [
    st.session_state.get(f"slider_{j}", 1.0 / len(options)) for j in range(len(options))
]
total_from_state = sum(current_prob_values_from_state)
probs_for_charts = (
    torch.ones(len(options)) / len(options)
    if total_from_state == 0
    else torch.tensor([v / total_from_state for v in current_prob_values_from_state])
)

gt_choice_for_charts = st.session_state.get("ground_truth", options[0])
gt_index_for_charts = options.index(gt_choice_for_charts)  # no hard-coded position for "Text"
gt_numeric_for_charts = None if gt_choice_for_charts == "Text" else int(gt_choice_for_charts)

st.markdown("#### Input Probability Distribution")
df_dist = pd.DataFrame({"token": options, "probability": probs_for_charts.numpy()})
df_dist["type"] = [
    "Ground Truth" if token == gt_choice_for_charts else "Prediction" for token in options
]
chart = (
    alt.Chart(df_dist)
    .mark_bar()
    .encode(
        x=alt.X("token:N", title="Token", sort=options),  # keep vocabulary order
        y=alt.Y("probability:Q", title="Probability", scale=alt.Scale(domain=[0, 1])),
        color=alt.Color(
            "type:N",
            scale=alt.Scale(domain=["Ground Truth", "Prediction"], range=["green", "steelblue"]),
            legend=alt.Legend(title="Token Type"),
        ),
    )
    .properties(height=300)
)
st.altair_chart(chart, use_container_width=True)

# Tokens "0".."9" occupy the first ten entries of `options`; NTL is only defined over them.
_DIGIT_COUNT = 10


def compute_token_losses(probs, gt_index, gt_numeric):
    """Compute Cross Entropy, NTL-MSE and NTL-WAS for one predicted distribution.

    Args:
        probs: 1-D tensor of probabilities in `options` order (digits first, then "Text").
        gt_index: position of the ground-truth token in `options`.
        gt_numeric: the ground-truth digit as an int, or None when the ground truth is "Text".

    Returns:
        Tuple (ce, ntl_mse, ntl_was) of 0-d tensors. The two NTL terms are NaN when undefined:
        for a "Text" ground truth, or when (almost) no probability mass sits on digit tokens.
    """
    # Clamp so a zero probability gives a large but finite loss (-log(1e-9) ~= 20.7), not inf.
    ce = -torch.log(torch.clamp(probs[gt_index], min=1e-9))
    undefined = torch.tensor(float("nan"))
    if gt_numeric is None:
        return ce, undefined, undefined

    digit_probs = probs[:_DIGIT_COUNT]
    digit_mass = torch.sum(digit_probs)
    if digit_mass <= 1e-6:
        # No numeric prediction exists to compare against. The previous code reported
        # NTL-WAS = 0 here, which looked like a perfect score.
        return ce, undefined, undefined

    # Renormalise over the digits only, so the expectation is a proper weighted average.
    digit_probs = digit_probs / digit_mass
    digit_values = torch.arange(0, _DIGIT_COUNT, dtype=torch.float32)
    target = float(gt_numeric)
    pred_value = torch.sum(digit_probs * digit_values)
    ntl_mse = (pred_value - target) ** 2
    ntl_was = torch.sum(digit_probs * torch.abs(digit_values - target))
    return ce, ntl_mse, ntl_was


ce_loss, ntl_mse_loss, ntl_was_loss = compute_token_losses(
    probs_for_charts, gt_index_for_charts, gt_numeric_for_charts
)

# Cross Entropy is always shown; NTL bars are omitted when the loss is undefined (NaN).
loss_rows = [("Cross Entropy", round(ce_loss.item(), 3))]
if not torch.isnan(ntl_was_loss):
    loss_rows.append(("NTL-WAS", round(ntl_was_loss.item(), 3)))
if not torch.isnan(ntl_mse_loss):
    loss_rows.append(("NTL-MSE", round(ntl_mse_loss.item(), 3)))
loss_df = pd.DataFrame(loss_rows, columns=["Loss", "Value"])

# ============== Chart Display ==============
st.subheader("Loss Comparison")

# 20% headroom above the tallest bar for the value labels; the axis never shrinks below
# 20 in demo mode or 0.5 otherwise.
y_max = max(loss_df["Value"].max() * 1.2, 20 if st.session_state.running_demo else 0.5)
loss_chart = (
    alt.Chart(loss_df)
    .mark_bar()
    .encode(
        x=alt.X("Loss:N", sort=loss_df["Loss"].tolist()),
        y=alt.Y("Value:Q", scale=alt.Scale(domain=[0, y_max])),
        color=alt.Color(
            "Loss:N",
            scale=alt.Scale(
                domain=["Cross Entropy", "NTL-WAS", "NTL-MSE"],
                range=["steelblue", "red", "forestgreen"],
            ),
        ),
        tooltip=["Loss", "Value"],
    )
    .properties(height=300)
)
# Value labels drawn just above each bar.
value_labels = loss_chart.mark_text(
    align="center", baseline="bottom", dy=-5, fontSize=14
).encode(text=alt.Text("Value:Q", format=".3f"))
st.altair_chart(loss_chart + value_labels, use_container_width=True)

# Add explanation of the demonstration
st.markdown("""
### What Does This Demo Show?
- **Cross Entropy Loss**: Only cares if the prediction is exactly right or wrong - it doesn't consider how "close" a numerical prediction is.
- **Number Token Loss (NTL)**: Considers numerical proximity - predicting "7" when the true value is "8" is better than predicting "2".
""")

# References / resources section with links (common to both modes)
st.markdown("### Resources")
st.markdown("""
- [Paper: Number Token Loss (ArXiv)](https://arxiv.org/abs/2411.02083)
- [GitHub: Number Token Loss](https://github.com/tum-ai/number-token-loss)
""")

# --- Polling Rerun for Demo Mode ---
# Must stay at the very end: st.rerun() halts the current run, so anything placed after it
# (previously the explanation and resources above) would never render while the demo runs.
# Reaching this point in demo mode means the time-based advance at the top did not fire; the
# short sleep throttles the polling loop before the next check.
if st.session_state.running_demo:
    time.sleep(0.1)
    st.rerun()