import logging
import time

import altair as alt
import numpy as np
import pandas as pd
import streamlit as st
import streamlit_vertical_slider as svs
import torch

from scenarios import dirac, gauss, make_bimodal_scenarios

logging.getLogger("streamlit.watcher.local_sources_watcher").setLevel(logging.ERROR)

DEMO_INTERVAL = 1.5
CE_SCALING = 0.25
MAX_LOSS_PLOT = 6
LAST_STEP = -1

# Define options globally as they are used in initialization and the UI
options = [str(i) for i in range(10)] + ["Text"]


def compute_losses(probs: torch.Tensor, gt_token: str) -> tuple[float, float, float]:
    """Compute CE, NTL-MAE, and NTL-WAS losses for the given probability vector and ground-truth token."""
    ce_loss = CE_SCALING * -torch.log(
        torch.clamp(probs[options.index(gt_token)], min=1e-9)
    )
    numeric_mass = probs[:10].sum()
    if gt_token == "Text" or numeric_mass < 1e-6:
        return round(ce_loss.item(), 3), 0.0, 0.0
    gt_numeric = int(gt_token)
    token_vals = torch.arange(10, dtype=torch.float32)
    mae = numeric_mass * abs(torch.dot(token_vals, probs[:10]) - gt_numeric)
    was = numeric_mass * torch.dot(probs[:10], torch.abs(token_vals - gt_numeric))
    return round(ce_loss.item(), 3), round(mae.item(), 3), round(was.item(), 3)


# --- Session State Initialization ---
# Ensure all session state variables are initialized before first use, especially by widgets.
if "running_demo" not in st.session_state:
    st.session_state.running_demo = False
if "demo_step" not in st.session_state:
    st.session_state.demo_step = 0
if "last_update_time" not in st.session_state:
    st.session_state.last_update_time = 0
if "loss_container" not in st.session_state:
    st.session_state.loss_container = None
if "previous_chart_html" not in st.session_state:
    st.session_state.previous_chart_html = ""
if "active_scenarios" not in st.session_state:
    # default scenario set loaded on first run
    st.session_state.active_scenarios = dirac
if "loss_history" not in st.session_state:
    st.session_state.loss_history = []

# Initialize states for sliders and the ground_truth selector.
# Using len(options) to correctly size for 0-9 + "Text".
for i in range(len(options)):
    if f"slider_{i}" not in st.session_state:
        st.session_state[f"slider_{i}"] = 0
if "ground_truth" not in st.session_state:
    st.session_state["ground_truth"] = options[5]
if "manual_ground_truth" not in st.session_state:
    st.session_state["manual_ground_truth"] = options[5]
if "demo_name" not in st.session_state:
    st.session_state["demo_name"] = "Dirac"

st.title("NTL -- The Number Token Loss 🚀")

st.markdown(
    """This is the interactive demo for our [ICML 2025](https://arxiv.org/abs/2411.02083) paper! 🎉

➡️ NTL augments cross-entropy to help LMs reason better with numbers 🧠
"""
)

st.subheader("Demo 1 -- NTL vs. Cross Entropy in 3 Scenarios")

st.markdown("""
1️⃣ Pick a ground-truth token: a digit (0–9) or "Text" 📝 (simulates generic text tokens).

2️⃣ Choose a demo:
- **Dirac** ⚡: All probability mass on one token.
- **Bimodal** 🎯: Two peaks moving away from the target.
- **Gaussian** 🌊: Soft bell curve around the true number.

Watch how the losses evolve as predictions get worse -- and see how NTL shines compared to CE!
🌟
""")

if "ground_truth" not in st.session_state:
    st.session_state["ground_truth"] = "4"

gt = st.selectbox("Ground Truth Token", options=options, key="ground_truth")


def apply_scenario(step_idx):
    scenario = st.session_state.active_scenarios[step_idx]
    for i, val in enumerate(scenario["values"]):
        st.session_state[f"slider_{i}"] = val


def start_dirac_demo():
    st.session_state.loss_history = []
    st.session_state.active_scenarios = dirac
    st.session_state.demo_name = "Dirac"
    st.session_state.running_demo = True
    st.session_state.demo_step = 0
    st.session_state.last_update_time = time.time()
    apply_scenario(0)


def start_gauss_demo():
    st.session_state.loss_history = []
    st.session_state.active_scenarios = gauss
    st.session_state.demo_name = "Gauss"
    st.session_state.running_demo = True
    st.session_state.demo_step = 0
    st.session_state.last_update_time = time.time()
    apply_scenario(0)


def start_bimodal_demo():
    st.session_state.loss_history = []
    gt = st.session_state["ground_truth"]
    st.session_state.active_scenarios = make_bimodal_scenarios(gt, options)
    st.session_state.demo_name = f"Bimodal (GT={gt})"
    st.session_state.running_demo = True
    st.session_state.demo_step = 0
    st.session_state.last_update_time = time.time()
    apply_scenario(0)


def stop_demo():
    st.session_state.running_demo = False


# --- Demo State Advancement Logic ---
# This block handles advancing the demo. If it advances, it updates session state
# and then reruns. This ensures widgets are drawn with the new state in the next run.
if st.session_state.running_demo:
    scenarios = st.session_state.active_scenarios
    current_time = time.time()
    if current_time - st.session_state.last_update_time > DEMO_INTERVAL:
        # if we haven't yet shown the last scenario, advance
        if st.session_state.demo_step < len(scenarios) - 1:
            st.session_state.demo_step += 1
            apply_scenario(st.session_state.demo_step)
            st.session_state.last_update_time = current_time
            st.rerun()
        else:
            # we just displayed the final case → stop
            st.session_state.running_demo = False
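
# NOTE (illustrative): the demo logic above assumes each scenario collection
# (`dirac`, `gauss`, and the list returned by `make_bimodal_scenarios`) is a
# sequence of dicts carrying a display "name" and one slider value per token in
# `options`. The constant below is a hypothetical example of that assumed shape,
# kept only as documentation; it is not read from scenarios.py and is never used.
_EXAMPLE_SCENARIO_STEP = {
    "name": "All mass on 0",  # shown in the st.info progress banner
    "values": [1.0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # one value per token "0".."9" + "Text"
}
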
# --- UI Rendering ---
# This section renders the main UI. It executes after any potential rerun from the block above.
if st.session_state.running_demo:
    st.info(
        f"Showing scenario {st.session_state.demo_step + 1}"
        f"/{len(st.session_state.active_scenarios)}: "
        f"{st.session_state.active_scenarios[st.session_state.demo_step]['name']}"
    )
    if st.button("Stop Demo"):
        st.session_state.running_demo = False
        st.rerun()
else:
    col1, col2, col3 = st.columns(3)
    with col1:
        if st.button("Run: Dirac"):
            start_dirac_demo()
            st.rerun()
    with col2:
        if st.button("Run: Gauss"):
            start_gauss_demo()
            st.rerun()
    with col3:
        if st.button("Run: Bimodal"):
            start_bimodal_demo()
            st.rerun()

current_prob_values_from_state = [
    st.session_state.get(f"slider_{j}", 0) for j in range(len(options))
]
total_from_state = sum(current_prob_values_from_state)
probs_for_charts = (
    torch.ones(len(options)) / len(options)
    if total_from_state == 0
    else torch.tensor([v / total_from_state for v in current_prob_values_from_state])
)

# Use the manual GT token when no demo is running
gt_choice_for_charts = (
    st.session_state["manual_ground_truth"]
    if not st.session_state.running_demo
    else st.session_state["ground_truth"]
)

if gt_choice_for_charts == "Text":
    gt_index_for_charts = 10  # "Text" is the 11th item (index 10)
    gt_numeric_for_charts = None
else:
    gt_index_for_charts = int(gt_choice_for_charts)
    gt_numeric_for_charts = gt_index_for_charts

gt = st.session_state["ground_truth"]
demo_name = st.session_state["demo_name"]

st.markdown(f"#### Predicted distribution -- ground truth: {gt}")

df_dist = pd.DataFrame(
    {"token": options, "probability": probs_for_charts.numpy().round(2)}
)
df_dist["type"] = [
    "Ground Truth" if token == gt_choice_for_charts else "Prediction"
    for token in options
]

bars = (
    alt.Chart(df_dist)
    .mark_bar(color="dodgerblue", size=40)
    .encode(
        x=alt.X(
            "token:N",
            title="Token",
            sort=options,
            axis=alt.Axis(
                labelAngle=0,
                labelFontSize=14,
                titleFontSize=16,
                labelAlign="center",
                labelFlush=False,
            ),
        ),
        y=alt.Y(
            "probability:Q",
            title="Probability",
            scale=alt.Scale(domain=[0, 1]),
            axis=alt.Axis(format=".2f", labelFontSize=14, titleFontSize=16),
        ),
        tooltip=[
            alt.Tooltip("token:N", title="Token"),
            alt.Tooltip("probability:Q", title="Predicted Prob.", format=".2f"),
        ],
    )
)

bg_bar = pd.DataFrame({"token": [gt], "height": [1.0]})
gt_bar = (
    alt.Chart(bg_bar)
    .mark_bar(
        color="darkgreen",
        size=20,
        opacity=0.3,
        stroke="gray",
        strokeWidth=2,
        strokeDash=[4, 4],
    )
    .encode(
        x=alt.X("token:N", sort=options),
        y=alt.Y("height:Q", scale=alt.Scale(domain=[0, 1])),
        tooltip=[
            alt.Tooltip("token:N", title="Ground Truth"),
            alt.Tooltip("height:Q", title="Desired mass", format=".2f"),
        ],
    )
)

annot1 = (
    alt.Chart(pd.DataFrame({"token": [gt]}))
    .mark_text(
        text="⬇ Ground",
        dy=-25,  # upper annotation line, above the bar
        dx=25,
        fontSize=14,
        fontWeight="bold",
        color="darkgreen",
    )
    .encode(x=alt.X("token:N", sort=options), y=alt.value(1))
)
annot2 = (
    alt.Chart(pd.DataFrame({"token": [gt]}))
    .mark_text(
        text=f"truth={gt}",
        dy=-10,  # lower annotation line, just below the first
        dx=35,
        fontSize=14,
        fontWeight="bold",
        color="darkgreen",
    )
    .encode(x=alt.X("token:N", sort=options), y=alt.value(1))
)

# Layer in order: ground-truth background bar, prediction bars, annotations
final_chart = (gt_bar + bars + annot1 + annot2).properties(height=200)
st.altair_chart(final_chart, use_container_width=True)

ce_val, mae_val, was_val = compute_losses(probs_for_charts, gt_choice_for_charts)

if (
    st.session_state.running_demo
    and len(st.session_state.loss_history) < st.session_state.demo_step + 1
):
    step = st.session_state.demo_step
    scenario = st.session_state.active_scenarios[step]
    ce, mae, was = compute_losses(probs_for_charts, gt_choice_for_charts)

    # pick x_val differently for bimodal vs. the other demos
    if st.session_state.demo_name.startswith("Bimodal"):
        x_val = scenario["name"]  # e.g. "(4,4)", "(3,5)", ...
    else:
        # use the token with the highest predicted probability in this scenario
        best_idx = np.argmax(scenario["values"])
        x_val = options[best_idx]  # "0", "1", ..., or "Text"

    st.session_state.loss_history.append(
        {
            "step": step,
            "x_val": x_val,
            "Cross Entropy": ce,
            "NTL-MAE": mae,
            "NTL-WAS": was,
        }
    )

# 1) build a raw DataFrame from the loss history
df = pd.DataFrame(st.session_state.loss_history)
if df.empty:
    # define an empty "melted" DataFrame with the right columns
    df_loss_plot = pd.DataFrame(columns=["step", "x_val", "Loss Type", "Loss Value"])
else:
    # now it's safe to melt
    df_loss_plot = df.melt(
        id_vars=["step", "x_val"],
        value_vars=["Cross Entropy", "NTL-MAE", "NTL-WAS"],
        var_name="Loss Type",
        value_name="Loss Value",
    )

# Note: this loss_df is rebuilt further below for the Demo 2 chart.
loss_data = {"Loss": ["Cross Entropy"], "Value": [ce_val]}
if was_val != "N/A":
    loss_data["Loss"].append("NTL-WAS")
    loss_data["Value"].append(was_val)
if mae_val != "N/A":
    loss_data["Loss"].append("NTL-MAE")
    loss_data["Value"].append(mae_val)
loss_df = pd.DataFrame(loss_data)

if st.session_state.demo_name.startswith("Bimodal"):
    domain = [sc["name"] for sc in st.session_state.active_scenarios]
    x_title = f"Offset from GT {st.session_state['ground_truth']}"
else:
    domain = options
    x_title = f"Maximum of predicted {st.session_state['demo_name']} distribution"

# ============== Chart Display ==============
st.markdown("#### Loss as a function of predicted distribution")

grouped_chart = (
    alt.Chart(df_loss_plot)
    .mark_bar()
    .encode(
        x=alt.X(
            "x_val:N",
            title=x_title,
            sort=domain,
            scale=alt.Scale(domain=domain),
            axis=alt.Axis(labelAngle=0, labelFontSize=14, titleFontSize=16),
        ),
        y=alt.Y(
            "Loss Value:Q",
            title="Loss Value",
            scale=alt.Scale(domain=[0, MAX_LOSS_PLOT], nice=False, clamp=True),
            axis=alt.Axis(labelFontSize=14, titleFontSize=16),
        ),
        color=alt.Color(
            "Loss Type:N",
            scale=alt.Scale(
                domain=["Cross Entropy", "NTL-WAS", "NTL-MAE"],
                range=["red", "limegreen", "blueviolet"],
            ),
            legend=alt.Legend(
                title="",
                orient="top",
                direction="horizontal",
                columns=3,
            ),
        ),
        xOffset="Loss Type:N",  # grouped bars
        tooltip=[
            alt.Tooltip("x_val:N", title="Scenario"),
            alt.Tooltip("Loss Type:N", title="Loss Type"),
            alt.Tooltip("Loss Value:Q", title="Value", format=".3f"),
        ],
    )
    .properties(height=250)
)
st.altair_chart(grouped_chart, use_container_width=True)

# When no demo is running, reset the sliders and the demo step
if not st.session_state.running_demo:
    for i in range(len(options)):
        st.session_state[f"slider_{i}"] = 0.0
    st.session_state.demo_step = 0

st.subheader("🧪 Demo 2 -- Craft your own distribution")
st.markdown("""
This demo gives you more control but is harder to interpret. See it as a playground! 🎨

Manually adjust the sliders to change the predicted probabilities for each token.
The demo normalizes the values to form a valid probability distribution and calculates the losses.

👣 **Steps:**
- Use the **vertical sliders** to allocate probability to each token.
- Choose the correct **Ground Truth Token** (0–9 or "Text" 📜).
- Observe how each loss function reacts.

💡 **Tip:** Want to trick the loss? Try putting all mass on the wrong token or spreading it wildly. See how NTL handles it!
😈
""")

manual_gt = st.selectbox(
    "Ground Truth Token",
    options=options,
    key="manual_ground_truth",
)

loss_df = pd.DataFrame(
    {
        "Loss": ["Cross Entropy", "NTL-MAE", "NTL-WAS"],
        "Value": [ce_val, mae_val, was_val],
    }
)

# Sliders and Ground Truth Selector
# These widgets read their initial values from st.session_state.
# User interactions update st.session_state directly via their keys.
st.markdown("#### Adjust the predicted token probability")
cols = st.columns(len(options))
for i, col in enumerate(cols):
    label = options[i]  # use the token name directly as the label
    with col:
        svs.vertical_slider(
            label=label,
            min_value=0.0,
            max_value=1.0,
            step=0.01,
            height=50,
            key=f"slider_{i}",
            slider_color="green",
            track_color="lightgray",
            thumb_color="black",
        )

chart = (
    alt.Chart(loss_df)
    .mark_bar()
    .encode(
        x=alt.X("Loss:N", sort=loss_df["Loss"].tolist()),
        y=alt.Y(
            "Value:Q",
            scale=alt.Scale(
                domain=[
                    0,
                    max(
                        loss_df["Value"].max() * 1.2,
                        20 if st.session_state.running_demo else 0.5,
                    ),
                ]
            ),
        ),
        color=alt.Color(
            "Loss:N",
            scale=alt.Scale(
                domain=["Cross Entropy", "NTL-WAS", "NTL-MAE"],
                range=["orangered", "limegreen", "blueviolet"],
            ),
        ),
        tooltip=["Loss", "Value"],
    )
    .properties(height=300)
)

# Add value labels on top of the bars and display the combined chart
text = chart.mark_text(
    align="center", baseline="bottom", dy=-5, fontSize=14
).encode(text=alt.Text("Value:Q", format=".3f"))
final_chart = chart + text
st.altair_chart(final_chart, use_container_width=True)

# --- Polling Rerun for Demo Mode ---
# If the demo is running and we haven't just advanced (which would have caused a rerun),
# do a short sleep and rerun to keep the polling loop alive.
if st.session_state.running_demo:
    # Reaching this point while the demo is running means the time-based advance
    # condition was NOT met in the block at the top.
    time.sleep(0.1)
    st.rerun()

st.markdown("""
### 🤔 TL;DR -- Why NTL?

Cross Entropy only cares whether the prediction is exactly right or wrong ❌✅ -- it doesn't care *how close* a guess is!
That's bad for LLMs doing math and numeric reasoning 🧮.

💥 NTL fixes that: it behaves like a regression loss on the token head, rewarding predictions that are numerically close.
""")

st.markdown("#### 📚 Further Resources")
st.markdown("""
- 📄 [ICML 2025 Paper](https://arxiv.org/abs/2411.02083)
- 🌐 [NTL Landing Page](https://tum-ai.github.io/number-token-loss/)
- 💻 [GitHub Code](https://github.com/tum-ai/number-token-loss)
""")
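
# --- Reference sketch (not executed by the demo) ---
# The TL;DR above describes NTL as behaving like a regression loss on the token head.
# The function below is a minimal, illustrative sketch of how an NTL-WAS term could be
# computed from LM-head logits during training. The name `ntl_was_sketch`, the
# `digit_token_ids` argument, and the uniform 0-9 value mapping are assumptions made
# for illustration here, not the paper's reference implementation.
def ntl_was_sketch(
    logits: torch.Tensor,  # (batch, vocab_size) raw LM-head scores at the target position
    targets: torch.Tensor,  # (batch,) ground-truth digit values in [0, 9]
    digit_token_ids: torch.Tensor,  # (10,) vocabulary ids of the tokens "0".."9"
) -> torch.Tensor:
    """Illustrative NTL-WAS term: probability-weighted absolute distance to the true digit."""
    probs = torch.softmax(logits, dim=-1)
    digit_probs = probs[:, digit_token_ids]  # (batch, 10) mass placed on digit tokens
    token_vals = torch.arange(10, dtype=probs.dtype, device=probs.device)
    # |value of each digit token - true value|, broadcast over the batch
    distances = (token_vals.unsqueeze(0) - targets.unsqueeze(1).to(probs.dtype)).abs()
    # Expected distance under the predicted digit distribution, averaged over the batch
    return (digit_probs * distances).sum(dim=-1).mean()
# In training, such a term would typically be added to cross entropy with a small
# weight, e.g. `loss = ce + 0.3 * ntl_was_sketch(...)` (the 0.3 is illustrative).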