Spaces:
Sleeping
Sleeping
| import logging | |
| import time | |
| import altair as alt | |
| import numpy as np | |
| import pandas as pd | |
| import streamlit as st | |
| import streamlit_vertical_slider as svs | |
| import torch | |
| from scenarios import dirac, gauss, make_bimodal_scenarios | |
logging.getLogger("streamlit.watcher.local_sources_watcher").setLevel(logging.ERROR)
DEMO_INTERVAL = 0.75  # seconds between automatic demo steps (also the polling sleep)
CE_SCALING = 0.25  # multiplier applied to the cross-entropy value in compute_losses
MAX_LOSS_PLOT = 6  # upper end of the y-axis of the loss-history chart
LAST_STEP = -1  # NOTE(review): not referenced in the visible code
# Token vocabulary of the demo: the digits "0".."9" followed by "Text" (index 10).
# Defined globally because it is used in initialization, the UI and the losses.
options = [str(i) for i in range(10)] + ["Text"]


def compute_losses(probs: torch.Tensor, gt_token: str) -> tuple[float, float, float]:
    """Compute the CE, NTL-MAE and NTL-WAS losses for one predicted distribution.

    Args:
        probs: 1-D probability vector aligned with ``options`` (positions 0-9
            are the digit tokens, position 10 is "Text").
        gt_token: Ground-truth token; must be an element of ``options``.

    Returns:
        ``(ce, ntl_mae, ntl_was)`` as Python floats, each rounded to 3 decimals.
        Cross entropy is scaled by ``CE_SCALING``, and the ground-truth
        probability is clamped to ``1e-9`` to keep the log finite. Both NTL
        terms are 0.0 when the ground truth is "Text" or when less than ``1e-6``
        of the probability mass lies on the digit tokens.

    Raises:
        ValueError: if ``gt_token`` is not in ``options``.
    """
    ce_loss = CE_SCALING * -torch.log(
        torch.clamp(probs[options.index(gt_token)], min=1e-9)
    )
    digit_probs = probs[:10]
    numeric_mass = digit_probs.sum()
    if gt_token == "Text" or numeric_mass < 1e-6:
        # No numeric target (or no numeric prediction): NTL terms do not apply.
        # Round like the main path so callers always get 3-decimal values.
        return round(ce_loss.item(), 3), 0.0, 0.0
    gt_numeric = int(gt_token)
    # Follow the input dtype so torch.dot also works for e.g. float64 probs.
    token_vals = torch.arange(10, dtype=probs.dtype)
    # NTL-MAE: |probability-weighted digit value - ground truth|, scaled by digit mass.
    mae = numeric_mass * torch.abs(torch.dot(token_vals, digit_probs) - gt_numeric)
    # NTL-WAS: probability-weighted distance of each digit to the ground truth,
    # scaled by digit mass.
    was = numeric_mass * torch.dot(digit_probs, torch.abs(token_vals - gt_numeric))
    return round(ce_loss.item(), 3), round(mae.item(), 3), round(was.item(), 3)
# --- Session State Initialization ---
# Seed every session-state key that later code (and keyed widgets) read.
# A default is applied only when the key is missing, so values from earlier
# runs of the same session are never overwritten. Defaults are factories so
# that mutable objects (list, DataFrame) are only built when actually needed.
_SESSION_DEFAULTS = {
    "running_demo": lambda: False,
    "demo_step": lambda: 0,
    "last_update_time": lambda: 0,
    "loss_container": lambda: None,
    "previous_chart_html": lambda: "",
    # Dirac scenarios are pre-selected so something loads on first show.
    "active_scenarios": lambda: dirac,
    "loss_history": lambda: [],
    # Empty long-format table for the loss-history chart.
    "df_loss_plot": lambda: pd.DataFrame(
        columns=["step", "x_val", "Loss Type", "Loss Value"]
    ),
}
# One slider per entry of `options` (digits 0-9 plus "Text"), all starting at 0.
_SESSION_DEFAULTS.update(
    {f"slider_{slider_no}": (lambda: 0) for slider_no in range(len(options))}
)
# Ground-truth selectors for both demos, plus the active demo's display name.
_SESSION_DEFAULTS.update(
    {
        "ground_truth": lambda: options[5],
        "manual_ground_truth": lambda: options[5],
        "demo_name": lambda: "Dirac",
    }
)
for _state_key, _make_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
# --- Page header and Demo 1 introduction ---
# Emoji and punctuation were restored from a mis-decoded copy. Fragments whose
# original character could not be identified were dropped.
st.title("NTL -- The Number Token Loss")
st.markdown(
    """This is the interactive demo for our [ICML 2025](https://arxiv.org/abs/2411.02083) paper!

NTL augments cross-entropy to help LMs reason better with numbers 🧠
"""
)
st.subheader("Demo 1 — NTL vs. Cross Entropy in 3 Scenarios")
st.markdown("""
1️⃣ Pick a ground truth token: a digit (0–9) or "Text" (simulates generic text tokens).

2️⃣ Choose a demo:
- **Dirac** ⚡: All probability mass on one token.
- **Gaussian**: Soft bell-curve around the true number.
- **Bimodal** 🎯: Two peaks moving away from the target.

Watch how losses evolve as predictions get worse — and see how NTL shines compared to CE!
""")
# "ground_truth" is seeded during session-state initialization, so the
# selectbox always starts from an existing value; no extra fallback needed.
gt = st.selectbox("Ground Truth Token", options=options, key="ground_truth")
def apply_scenario(step_idx):
    """Load scenario number ``step_idx`` of the active demo into the sliders.

    Each entry of the scenario's ``"values"`` list is written to the matching
    ``slider_<i>`` session-state key, in order.
    """
    chosen = st.session_state.active_scenarios[step_idx]
    slider_values = chosen["values"]
    for slider_no, slider_val in enumerate(slider_values):
        st.session_state["slider_" + str(slider_no)] = slider_val
def _start_demo(scenarios, demo_name):
    """Reset the loss history and start auto-playing ``scenarios`` from step 0."""
    st.session_state.loss_history = []
    st.session_state.active_scenarios = scenarios
    st.session_state.demo_name = demo_name
    st.session_state.running_demo = True
    st.session_state.demo_step = 0
    st.session_state.last_update_time = time.time()
    # Push the first scenario into the sliders before any widget is drawn.
    apply_scenario(0)


def start_dirac_demo():
    """Start the Dirac demo (all probability mass on one token)."""
    _start_demo(dirac, "Dirac")


def start_gauss_demo():
    """Start the Gaussian demo (soft bell curve of probability mass)."""
    _start_demo(gauss, "Gauss")


def start_bimodal_demo():
    """Start the bimodal demo, built around the currently selected ground truth."""
    gt = st.session_state["ground_truth"]
    _start_demo(make_bimodal_scenarios(gt, options), f"Bimodal (GT={gt})")
def stop_demo():
    """Mark the auto-advancing demo as no longer running."""
    st.session_state["running_demo"] = False
# --- Demo State Advancement Logic ---
# While a demo is running, move to the next scenario once more than
# DEMO_INTERVAL seconds have passed since the last step. The new slider values
# are written to session state here, before the slider widgets are created in
# this run, so the widgets and charts below pick them up without an extra rerun.
if st.session_state.running_demo:
    scenario = st.session_state.active_scenarios  # the full list of demo steps
    current_time = time.time()
    if current_time - st.session_state.last_update_time > DEMO_INTERVAL:
        # If we haven't shown the last scenario yet, advance one step.
        if st.session_state.demo_step < len(scenario) - 1:
            st.session_state.demo_step += 1
            apply_scenario(st.session_state.demo_step)
            st.session_state.last_update_time = current_time
            # st.rerun() # not needed, leading to too many reruns
        else:
            # The final scenario has already been displayed, so stop the demo.
            st.session_state.running_demo = False
# --- UI Rendering ---
# Demo controls. While a demo runs, show its progress and a Stop button;
# otherwise offer one Run button per demo. Each button updates session state
# and then calls st.rerun() so the new state is rendered immediately.
if st.session_state.running_demo:
    step_count = len(st.session_state.active_scenarios)
    shown = st.session_state.active_scenarios[st.session_state.demo_step]
    st.info(
        f"Showing scenario {st.session_state.demo_step + 1}"
        f"/{step_count}: "
        f"{shown['name']}"
    )
    if st.button("Stop Demo"):
        stop_demo()
        st.rerun()
else:
    demo_buttons = (
        ("Run: Dirac", start_dirac_demo),
        ("Run: Gauss", start_gauss_demo),
        ("Run: Bimodal", start_bimodal_demo),
    )
    for column, (label, launch) in zip(st.columns(len(demo_buttons)), demo_buttons):
        with column:
            if st.button(label):
                launch()
                st.rerun()
# Current predicted distribution, read from the slider values in session state
# ("slider_0" .. "slider_10"). The raw values are normalized to sum to 1; if
# every slider is 0, a uniform distribution over all tokens is used instead.
current_prob_values_from_state = [
    st.session_state.get(f"slider_{j}", 0)
    for j in range(len(options))
]
total_from_state = sum(current_prob_values_from_state)
probs_for_charts = (
    torch.ones(len(options)) / len(options)
    if total_from_state == 0
    else torch.tensor([v / total_from_state for v in current_prob_values_from_state])
)
# Use the manual (Demo 2) GT token when no demo is running, else the Demo 1 GT.
gt_choice_for_charts = (
    st.session_state["manual_ground_truth"]
    if not st.session_state.running_demo
    else st.session_state["ground_truth"]
)
# NOTE(review): gt_index_for_charts, gt_numeric_for_charts and demo_name below
# are not read anywhere in the visible code; confirm before removing.
if gt_choice_for_charts == "Text":
    gt_index_for_charts = 10  # "Text" is the last entry of options (index 10)
    gt_numeric_for_charts = None
else:
    gt_index_for_charts = int(gt_choice_for_charts)
    gt_numeric_for_charts = gt_index_for_charts
# The distribution chart is labelled with the Demo 1 ground truth, even when no
# demo is running and the losses are computed against the Demo 2 ground truth.
gt = st.session_state["ground_truth"]
demo_name = st.session_state["demo_name"]
# --- Predicted distribution chart ---
# One bar per token (sorted in `options` order), showing the normalized
# probability rounded to 2 decimals; the ground-truth token is drawn in green.
st.markdown(f'#### Predicted distribution (<span style="color:darkgreen;">ground truth: {gt}</span>)', unsafe_allow_html=True)
df_dist = pd.DataFrame(
    {"token": options, "probability": probs_for_charts.numpy().round(2)}
)
# Boolean column used by the conditional color encoding below.
df_dist["is_gt"] = df_dist["token"] == gt
bars = (
    alt.Chart(df_dist)
    .mark_bar(color="dodgerblue", size=40)
    .encode(
        x=alt.X(
            "token:N",
            title="Token",
            sort=options,
            axis=alt.Axis(
                labelAngle=0,
                labelFontSize=14,
                titleFontSize=16,
                labelAlign="center",
                labelFlush=False,
            ),
        ),
        color=alt.condition(
            "datum.is_gt",
            alt.value("darkgreen"),  # color for ground truth
            alt.value("dodgerblue")  # color for others
        ),
        y=alt.Y(
            "probability:Q",
            title="Probability",
            scale=alt.Scale(domain=[0, 1]),  # fixed axis: probabilities lie in [0, 1]
            axis=alt.Axis(format=".2f", labelFontSize=14, titleFontSize=16),
        ),
        tooltip=[
            alt.Tooltip("token:N", title="Token"),
            alt.Tooltip("probability:Q", title="Predicted Prob.", format=".2f"),
            alt.Tooltip("is_gt:N", title="Ground Truth")
        ]
    )
)
st.altair_chart(bars.properties(height=200), use_container_width=True, theme="streamlit")
# Losses for the current distribution; reused by the demo history and by the
# manual loss chart, so compute_losses is called only once per run.
ce_val, mae_val, was_val = compute_losses(probs_for_charts, gt_choice_for_charts)

# Record at most one history entry per demo step: the list only grows while it
# is shorter than demo_step + 1, so reruns of the same step add nothing.
if (
    st.session_state.running_demo
    and len(st.session_state.loss_history) < st.session_state.demo_step + 1
):
    step = st.session_state.demo_step
    scenario = st.session_state.active_scenarios[step]
    # Pick the x-axis label differently for bimodal vs. the other demos.
    if st.session_state.demo_name.startswith("Bimodal"):
        x_val = scenario["name"]  # e.g. "(4,4)", "(3,5)", ...
    else:
        # Token at which the scenario's slider values peak.
        best_idx = np.argmax(scenario["values"])
        x_val = options[best_idx]  # "0", "1", ..., or "Text"
    st.session_state.loss_history.append(
        {
            "step": step,
            "x_val": x_val,
            "Cross Entropy": ce_val,
            "NTL-MAE": mae_val,
            "NTL-WAS": was_val,
        }
    )
    # Long format (one row per step and loss type) for the grouped bar chart.
    st.session_state.df_loss_plot = pd.DataFrame(st.session_state.loss_history).melt(
        id_vars=["step", "x_val"],
        value_vars=["Cross Entropy", "NTL-MAE", "NTL-WAS"],
        var_name="Loss Type",
        value_name="Loss Value",
    )

# compute_losses always returns three floats, so all three losses are listed.
loss_df = pd.DataFrame(
    {
        "Loss": ["Cross Entropy", "NTL-WAS", "NTL-MAE"],
        "Value": [ce_val, was_val, mae_val],
    }
)

# X-axis categories and title for the loss-history chart.
if st.session_state.demo_name.startswith("Bimodal"):
    domain = [sc["name"] for sc in st.session_state.active_scenarios]
    x_title = f"Offset from GT {st.session_state['ground_truth']}"
else:
    domain = options
    x_title = f"Maximum of predicted {st.session_state['demo_name']} distribution"
# ============== Chart Display ==============
# Grouped bar chart of the loss history recorded so far: one group per x-axis
# category (token or bimodal offset), one colored bar per loss type.
st.markdown("#### Loss as a function of predicted distribution")
grouped_chart = (
    alt.Chart(st.session_state.df_loss_plot)
    .mark_bar()
    .encode(
        x=alt.X(
            "x_val:N",
            title=x_title,
            sort=domain,
            # A fixed domain keeps every category on the axis, including
            # categories whose demo step has not been recorded yet.
            scale=alt.Scale(domain=domain),
            axis=alt.Axis(labelAngle=0, labelFontSize=14, titleFontSize=16),
        ),
        y=alt.Y(
            "Loss Value:Q",
            title="Loss Value",
            # Fixed range [0, MAX_LOSS_PLOT]; values outside it are clamped.
            scale=alt.Scale(domain=[0, MAX_LOSS_PLOT], nice=False, clamp=True),
            axis=alt.Axis(labelFontSize=14, titleFontSize=16),
        ),
        color=alt.Color(
            "Loss Type:N",
            scale=alt.Scale(
                domain=["Cross Entropy", "NTL-WAS", "NTL-MAE"],
                range=["red", "limegreen", "blueviolet"],
            ),
            legend=alt.Legend(
                title="",
                orient="top",
                direction="horizontal",
                columns=3,
            ),
        ),
        xOffset="Loss Type:N",  # grouped bars
        tooltip=[
            alt.Tooltip("x_val:N", title="Scenario"),
            alt.Tooltip("Loss Type:N", title="Loss Type"),
            alt.Tooltip("Loss Value:Q", title="Value", format=".3f"),
        ],
    )
    .properties(height=250)
)
st.altair_chart(grouped_chart, use_container_width=True, theme="streamlit")
# --- Demo 2: manual loss comparison ---
# Outside of a running demo, reset every slider to 0 and rewind the demo step.
# NOTE(review): this runs on every rerun without an active demo, i.e. before
# the manual sliders below are created; confirm this reset is intended.
if not st.session_state.running_demo:
    for i in range(len(options)):
        st.session_state[f"slider_{i}"] = 0.0
    st.session_state.demo_step = 0
# Single section heading (a second, duplicate Demo 2 subheader was removed).
st.subheader("🧪 Demo 2 — Craft your own distribution")
st.markdown("""
This demo gives you more control but is harder to interpret. See it as a playground! 🎨
Manually adjust the sliders to change the predicted probabilities for each token.
The demo normalizes the values to form a valid probability distribution and calculates the losses.

👣 **Steps:**
- Use the **vertical sliders** to allocate probability to each token.
- Choose the correct **Ground Truth Token** (0–9 or "Text").
- Observe how each loss function reacts.

💡 **Tip:** Want to trick the loss? Try putting all mass on the wrong token or spread it wildly. See how NTL handles it!
""")
manual_gt = st.selectbox(
    "Ground Truth Token",
    options=options,
    key="manual_ground_truth",
)
# Current loss values for the manual-comparison bar chart.
loss_df = pd.DataFrame(
    {
        "Loss": ["Cross Entropy", "NTL-MAE", "NTL-WAS"],
        "Value": [ce_val, mae_val, was_val],
    }
)
# Sliders: one vertical slider per token. Each slider is keyed "slider_{i}", so
# its value lives in st.session_state and is read back (and normalized) near
# the top of the script on the next rerun.
st.markdown("#### Adjust the predicted token probability")
cols = st.columns(len(options))
for i, col in enumerate(cols):
    label = options[i]  # Use token name directly for label
    with col:
        svs.vertical_slider(
            label=label,
            min_value=0.0,
            max_value=1.0,
            step=0.01,
            height=50,
            key=f"slider_{i}",
            slider_color="green",
            track_color="lightgray",
            thumb_color="black",
        )
# Bar chart of the three losses for the current distribution. The y-axis top
# is at least 20 while a demo is running (0.5 otherwise) and grows to 1.2x the
# largest loss when that is bigger.
chart = (
    alt.Chart(loss_df)
    .mark_bar()
    .encode(
        x=alt.X("Loss:N", sort=loss_df["Loss"].tolist()),
        y=alt.Y(
            "Value:Q",
            scale=alt.Scale(
                domain=[
                    0,
                    max(
                        loss_df["Value"].max() * 1.2,
                        20 if st.session_state.running_demo else 0.5,
                    ),
                ]
            ),
        ),
        color=alt.Color(
            "Loss:N",
            scale=alt.Scale(
                domain=["Cross Entropy", "NTL-WAS", "NTL-MAE"],
                range=["orangered", "limegreen", "blueviolet"],
            ),
        ),
        tooltip=["Loss", "Value"],
    )
    .properties(height=300)
)
# Overlay each bar's value (3 decimals) just above the bar.
text = chart.mark_text(
    align="center", baseline="bottom", dy=-5, fontSize=14
).encode(text=alt.Text("Value:Q", format=".3f"))
final_chart = chart + text
st.altair_chart(final_chart, use_container_width=True)
# --- Polling Rerun for Demo Mode ---
# While a demo is running, sleep for DEMO_INTERVAL seconds and rerun the whole
# script. The advancement block near the top then moves to the next scenario
# once DEMO_INTERVAL has elapsed since the last step (it does not rerun by
# itself). st.rerun() ends this run here, so the sections below are only
# rendered when no demo is running.
if st.session_state.running_demo:
    time.sleep(DEMO_INTERVAL)
    st.rerun()
# --- TL;DR and further resources ---
# Emoji and punctuation were restored from a mis-decoded copy. Fragments whose
# original character could not be identified were dropped.
st.markdown("""
### 🤔 TL;DR — Why NTL?
Cross Entropy only cares if the prediction is exactly right or wrong ✅❌ — it doesn’t care *how close* a guess is!
That’s bad for LLMs doing math and numeric reasoning 🧮.

🔥 NTL fixes that: it behaves like a regression loss on the token head, rewarding predictions that are numerically close.
""")
st.markdown("#### Further Resources")
st.markdown("""
- [ICML 2025 Paper](https://arxiv.org/abs/2411.02083)
- [NTL Landing Page](https://tum-ai.github.io/number-token-loss/)
- 💻 [GitHub Code](https://github.com/tum-ai/number-token-loss)
""")