import time

import altair as alt
import numpy as np
import pandas as pd
import streamlit as st
import streamlit_vertical_slider as svs
import torch

from scenarios import bimodal, dirac, gauss

DEMO_INTERVAL = 1.5
NTL_MSE_SCALING = 0.5
MAX_LOSS_PLOT = 15
LAST_STEP = -1

# TODO:
# - Remove flickering of loss evolution scenario plot (lower ylim?)
# - Move manual part down (predicted token probabilities)
# - Allow to set GT token for each demo
# - Add text token to loss evolution barplot
# - pick good default (4?)

# Define options globally as it's used in initialization and UI
options = [str(i) for i in range(10)] + ["Text"]

# --- Session State Initialization ---
# Streamlit re-executes this script on every interaction, so a default is only
# written when its key is missing; otherwise demo/user state would be reset on
# each rerun. Everything must exist before the first widget reads it.
_STATE_DEFAULTS = {
    "running_demo": False,
    "demo_step": 0,
    "last_update_time": 0,
    "loss_container": None,
    "previous_chart_html": "",
    # Scenario list driven by the demo loop; Dirac is loaded on first show.
    "active_scenarios": dirac,
    # Per-step loss records plotted in the "Loss Evolution" chart.
    "loss_history": [],
    # One slider per token (digits 0-9 plus "Text"), uniform by default.
    **{f"slider_{i}": 1.0 / len(options) for i in range(len(options))},
    "ground_truth": options[0],  # Default to "0"
}
for _key, _default in _STATE_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

st.title("Number Token Loss - Demo")

st.markdown(
    """
**Instructions**

1. **Pick a ground truth token (0–9).**
2. **Select one of the three automated demos:**
   - **Dirac**: a one-hot (Dirac) distribution whose single 1.0 mass moves from token 0 all the way to “Text.”
   - **Gaussian**: a peaked Gaussian (0.6 mass at center, 0.4 spread) that slides its center from token 0 to “Text.”
   - **Bimodal**: two equal peaks (0.5 each) that start at (0,8) and then move symmetrically away from the GT token.
"""
)

# The widget is bound to st.session_state["ground_truth"] through its key
# (initialized above), so the session state alone decides the selection.
gt = st.selectbox(
    "Ground Truth Token",
    options=options,
    key="ground_truth",
)


def apply_scenario(step_idx):
    """Copy the slider values of step ``step_idx`` of the active demo into session state.

    Args:
        step_idx: Index into ``st.session_state.active_scenarios``. Each entry
            must have a ``"values"`` sequence; value ``i`` is written to
            ``slider_{i}``.
    """
    scenario = st.session_state.active_scenarios[step_idx]
    for i, val in enumerate(scenario["values"]):
        st.session_state[f"slider_{i}"] = val


def _start_demo(scenarios):
    """Make ``scenarios`` the running demo and show its first step.

    The loss history is cleared as well. History entries are recorded per demo
    step, so entries from a previous demo or from manual mode would otherwise
    keep being displayed for the new demo.
    """
    st.session_state.active_scenarios = scenarios
    st.session_state.running_demo = True
    st.session_state.demo_step = 0
    st.session_state.loss_history = []
    st.session_state.last_update_time = time.time()
    apply_scenario(0)


def start_dirac_demo():
    """Start the Dirac (one-hot) demo."""
    _start_demo(dirac)


def start_gauss_demo():
    """Start the Gaussian demo."""
    _start_demo(gauss)


def start_bimodal_demo():
    """Start the bimodal demo."""
    _start_demo(bimodal)


def stop_demo():
    """Stop the running demo; the sliders keep the last applied scenario values."""
    st.session_state.running_demo = False
# --- Demo State Advancement Logic ---
# While a demo runs, move to the next scenario (wrapping around) once more than
# DEMO_INTERVAL seconds have passed since the last step. The new values are
# written into session state, then the script reruns so every widget and chart
# below is drawn from that updated state.
if st.session_state.running_demo:
    elapsed = time.time() - st.session_state.last_update_time
    if elapsed > DEMO_INTERVAL:
        step_count = len(st.session_state.active_scenarios)
        st.session_state.demo_step = (st.session_state.demo_step + 1) % step_count
        apply_scenario(st.session_state.demo_step)
        st.session_state.last_update_time = time.time()  # restart the step timer
        st.rerun()

# --- UI Rendering ---
# Reached only when no advance happened above: show the running demo's
# progress and a stop button, or the buttons that launch a demo.
if st.session_state.running_demo:
    shown_step = st.session_state.demo_step
    shown_scenarios = st.session_state.active_scenarios
    st.info(
        f"Showing scenario {shown_step + 1}"
        f"/{len(shown_scenarios)}: "
        f"{shown_scenarios[shown_step]['name']}"
    )
    if st.button("Stop Demo"):
        stop_demo()
        st.rerun()
else:
    demo_launchers = (
        ("Run: Dirac", start_dirac_demo),
        ("Run: Gauss", start_gauss_demo),
        ("Run: Bimodal", start_bimodal_demo),
    )
    launcher_columns = st.columns(len(demo_launchers))
    for column, (caption, launch) in zip(launcher_columns, demo_launchers):
        with column:
            if st.button(caption):
                launch()
                st.rerun()
# --- Charts and loss calculations ---
# Everything below is recomputed from st.session_state on every run, so it
# reflects both manual slider edits and the scenario applied by a running demo.

# Slider values are raw weights; normalize them into a distribution over
# `options` (uniform when every weight is zero).
raw_weights = [
    st.session_state.get(f"slider_{j}", 1.0 / len(options))
    for j in range(len(options))
]
total_weight = sum(raw_weights)
probs = (
    torch.ones(len(options)) / len(options)
    if total_weight == 0
    else torch.tensor([w / total_weight for w in raw_weights])
)

gt = st.session_state.get("ground_truth", options[0])
gt_index = options.index(gt)
# Number-token losses need a numeric target; "Text" has none.
gt_numeric = None if gt == "Text" else int(gt)

st.markdown(f"#### Predicted Probability Distribution — Ground truth token {gt}")

DIST_CHART_HEIGHT = 300  # px; also sizes the ground-truth highlight column

df_dist = pd.DataFrame({"token": options, "probability": probs.numpy().round(2)})
df_dist["type"] = [
    "Ground Truth" if token == gt else "Prediction" for token in options
]
gt_frame = pd.DataFrame({"token": [gt]})

# Light-gray column behind the ground-truth token. `alt.value` on y/y2 is a
# pixel position (0 = top of the plot), so spanning 0..DIST_CHART_HEIGHT covers
# the full chart height.
bg = (
    alt.Chart(gt_frame)
    .mark_bar(size=40, color="lightgray", opacity=0.4)
    .encode(
        x=alt.X("token:N", sort=options),
        x2=alt.X2("token:N"),
        y=alt.value(0),
        y2=alt.value(DIST_CHART_HEIGHT),
    )
)

bars = (
    alt.Chart(df_dist)
    .mark_bar()
    .encode(
        x=alt.X(
            "token:N",
            title="Token",
            sort=options,
            axis=alt.Axis(labelAngle=0, labelFontSize=14, titleFontSize=16),
        ),
        y=alt.Y(
            "probability:Q",
            title="Probability",
            scale=alt.Scale(domain=[0, 1]),
            axis=alt.Axis(format=".2f", labelFontSize=14, titleFontSize=16),
        ),
        color=alt.Color(
            "type:N",
            scale=alt.Scale(
                domain=["Ground Truth", "Prediction"], range=["green", "steelblue"]
            ),
            legend=alt.Legend(title="Token Type", titleFontSize=16, labelFontSize=14),
        ),
        tooltip=[
            alt.Tooltip("token:N", title="Token"),
            alt.Tooltip("probability:Q", title="Probability", format=".2f"),
            alt.Tooltip("type:N", title="Type"),
        ],
    )
    .properties(height=DIST_CHART_HEIGHT)
)

# Two-line "⬇ Ground / truth=<gt>" label above the top edge of the plot.
# y is pixel 1 (the top); a negative dy moves text upward, so line 1
# (dy=-25) sits above line 2 (dy=-10).
annot1 = (
    alt.Chart(gt_frame)
    .mark_text(
        text="⬇ Ground",
        dy=-25,
        dx=25,
        fontSize=14,
        fontWeight="bold",
        color="green",
    )
    .encode(x=alt.X("token:N", sort=options), y=alt.value(1))
)
annot2 = (
    alt.Chart(gt_frame)
    .mark_text(
        text=f"truth={gt}",
        dy=-10,
        dx=35,
        fontSize=14,
        fontWeight="bold",
        color="green",
    )
    .encode(x=alt.X("token:N", sort=options), y=alt.value(1))
)

# Layer order: background, bars, annotations.
final_chart = (bg + bars + annot1 + annot2).properties(height=DIST_CHART_HEIGHT)
st.altair_chart(final_chart, use_container_width=True)


def _number_token_losses(token_probs, gt_value):
    """Return ``(ntl_mse, ntl_was)`` as floats, or ``(None, None)`` if undefined.

    Both losses use only the probability mass on the digit tokens 0-9,
    renormalized to sum to 1:

    * NTL-MSE = NTL_MSE_SCALING * (E[digit] - gt_value) ** 2
    * NTL-WAS = E[|digit - gt_value|]

    Both are undefined when the ground truth is not a digit
    (``gt_value is None``). They are also undefined when (almost) no mass lies
    on the digit tokens, because the renormalized digit distribution does not
    exist then. Earlier, MSE treated that case as a prediction of 0 while WAS
    reported a perfect 0.
    """
    if gt_value is None:
        return None, None
    digit_probs = token_probs[:10]
    digit_mass = torch.sum(digit_probs)
    if digit_mass <= 1e-6:
        return None, None
    digit_probs = digit_probs / digit_mass
    digits = torch.arange(0, 10, dtype=torch.float32)
    target = float(gt_value)
    pred_value = torch.sum(digit_probs * digits)
    mse = NTL_MSE_SCALING * (pred_value - target) ** 2
    was = torch.sum(digit_probs * torch.abs(digits - target))
    return mse.item(), was.item()


# Cross entropy on the ground-truth token; the clamp keeps -log finite
# (at most -log(1e-9) ≈ 20.7).
ce_val = round(-torch.log(torch.clamp(probs[gt_index], min=1e-9)).item(), 3)
mse_raw, was_raw = _number_token_losses(probs, gt_numeric)
mse_val = "N/A" if mse_raw is None else round(mse_raw, 3)
was_val = "N/A" if was_raw is None else round(was_raw, 3)

# Record losses only while a demo runs; manual slider states are not scenario
# steps. Each entry is stored at the index of its demo step and refreshed every
# time that step is shown, so it follows the current ground truth.
if st.session_state.running_demo:
    step = st.session_state.demo_step
    scenarios = st.session_state.active_scenarios
    history = st.session_state.loss_history  # mutated in place (persists)
    del history[len(scenarios):]  # drop leftovers of a longer, earlier demo
    while len(history) <= step:
        history.append(None)
    history[step] = {
        "token_index": int(np.argmax(scenarios[step]["values"])),
        "CE": ce_val,
        "NTL-MSE": None if mse_raw is None else mse_val,
        "NTL-WAS": None if was_raw is None else was_val,
    }

loss_plot_data = []
for entry in st.session_state.loss_history:
    if entry is None:
        continue
    for loss_type in ["CE", "NTL-MSE", "NTL-WAS"]:
        if entry[loss_type] is not None:
            loss_plot_data.append(
                {
                    # Label by token name so the "Text" token (index 10) is plotted too.
                    "Token": options[entry["token_index"]],
                    "Loss Type": loss_type,
                    "Loss Value": entry[loss_type],
                }
            )
# Explicit columns keep the chart valid even before any demo step is recorded.
df_loss_plot = pd.DataFrame(loss_plot_data, columns=["Token", "Loss Type", "Loss Value"])

loss_data = {"Loss": ["Cross Entropy"], "Value": [ce_val]}
if was_val != "N/A":
    loss_data["Loss"].append("NTL-WAS")
    loss_data["Value"].append(was_val)
if mse_val != "N/A":
    loss_data["Loss"].append("NTL-MSE")
    loss_data["Value"].append(mse_val)
loss_df = pd.DataFrame(loss_data)

# ============== Chart Display ==============
st.subheader("Loss Evolution Over Scenarios")
grouped_chart = (
    alt.Chart(df_loss_plot)
    # clip=True keeps bars taller than MAX_LOSS_PLOT (CE can reach ~20.7) inside
    # the plot area instead of extending past its top edge; the tooltip still
    # shows the exact value.
    .mark_bar(clip=True)
    .encode(
        x=alt.X(
            "Token:O",
            title="Predicted Token Index",
            axis=alt.Axis(labelAngle=0),
            scale=alt.Scale(domain=options),
        ),
        y=alt.Y(
            "Loss Value:Q", title="Loss", scale=alt.Scale(domain=[0, MAX_LOSS_PLOT])
        ),
        color=alt.Color("Loss Type:N", legend=alt.Legend(title="Loss")),
        xOffset="Loss Type:N",  # groups the loss types side by side instead of stacking
        tooltip=[
            "Token:O",
            "Loss Type:N",
            alt.Tooltip("Loss Value:Q", format=".3f"),
        ],
    )
    .properties(height=300)
)
st.altair_chart(grouped_chart, use_container_width=True)

# Create a single chart for loss visualization
st.subheader("Loss Comparison")

st.markdown("""
Adjust the sliders to set a predicted probability for each token (0-9 and "Text").
The sliders are vertical and compact. The app normalizes the slider values
to form a valid probability distribution, visualizes it, and computes
the corresponding Cross Entropy, NTL-MSE, and NTL-WAS losses.
""")

# Bar chart of the current losses. The y-axis is at least 20 high during a demo
# so consecutive steps share a stable scale, and at least 0.5 otherwise.
chart = (
    alt.Chart(loss_df)
    .mark_bar()
    .encode(
        x=alt.X("Loss:N", sort=loss_df["Loss"].tolist()),
        y=alt.Y(
            "Value:Q",
            scale=alt.Scale(
                domain=[
                    0,
                    max(
                        loss_df["Value"].max() * 1.2,
                        20 if st.session_state.running_demo else 0.5,
                    ),
                ]
            ),
        ),
        color=alt.Color(
            "Loss:N",
            scale=alt.Scale(
                domain=["Cross Entropy", "NTL-WAS", "NTL-MSE"],
                range=["steelblue", "red", "forestgreen"],
            ),
        ),
        tooltip=["Loss", "Value"],
    )
    .properties(height=300)
)
# --- Manual probability sliders ---
# Shown only while no demo is running, because a running demo writes the
# slider state itself. Each slider is bound to st.session_state[f"slider_{i}"]
# through its key, so moving it updates the state the charts read on rerun.
if not st.session_state.running_demo:
    st.markdown("#### Predicted Token Probabilities")
    for i, (col, label) in enumerate(zip(st.columns(len(options)), options)):
        with col:
            svs.vertical_slider(
                label=label,  # token name ("0".."9" or "Text")
                min_value=0.0,
                max_value=1.0,
                step=0.01,
                height=50,
                key=f"slider_{i}",
                slider_color="green",
                track_color="lightgray",
                thumb_color="black",
            )

# Add value labels on top of the loss comparison bars.
text = chart.mark_text(align="center", baseline="bottom", dy=-5, fontSize=14).encode(
    text=alt.Text("Value:Q", format=".3f")
)
final_chart = chart + text
st.altair_chart(final_chart, use_container_width=True)

# Add explanation of the demonstration
st.markdown("""
### What Does This Demo Show?

- **Cross Entropy Loss**: Only cares if the prediction is exactly right or wrong - it doesn't consider how "close" a numerical prediction is.
- **Number Token Loss (NTL)**: Considers numerical proximity - predicting "7" when the true value is "8" is better than predicting "2".
""")

# References / resources section with links (common to both modes)
st.markdown("### Resources")
st.markdown("""
- [Paper: Number Token Loss (ArXiv)](https://arxiv.org/abs/2411.02083)
- [GitHub: Number Token Loss](https://github.com/tum-ai/number-token-loss)
""")

# --- Polling Rerun for Demo Mode ---
# Placed last on purpose: st.rerun() halts the current run immediately, so the
# explanation and resources above would never be drawn during a demo if this
# ran earlier. Reaching this point means the time-based step advance near the
# top did not fire in this run, so sleep briefly and poll again.
if st.session_state.running_demo:
    time.sleep(0.1)
    st.rerun()