jannisborn committed on
Commit
d830963
·
unverified ·
1 Parent(s): ba0cb12

feat: Initial commit

Browse files
Files changed (2) hide show
  1. requirements.txt +2 -1
  2. src/streamlit_app.py +111 -29
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  altair
2
  pandas
3
- streamlit
 
 
1
  altair
2
  pandas
3
+ streamlit
4
+ streamlit_vertical_slider
src/streamlit_app.py CHANGED
@@ -1,40 +1,122 @@
1
  import altair as alt
2
- import numpy as np
3
  import pandas as pd
4
  import streamlit as st
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
 
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
 
 
 
12
 
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
 
 
 
 
 
18
 
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
 
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
 
 
 
 
 
 
25
 
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
  .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import altair as alt
 
2
  import pandas as pd
3
  import streamlit as st
4
+ import streamlit_vertical_slider as svs
5
+ import torch
6
+ from streamlit_vertical_slider import vertical_slider
7
 
8
+ st.title("Number Token Loss - Demo")
 
9
 
10
+ st.markdown("""
11
+ Adjust the sliders to set a predicted probability for each token (0-9 and "Text").
12
+ The sliders are vertical and compact. The app normalizes the slider values
13
+ to form a valid probability distribution, visualizes it, and computes the corresponding
14
+ Cross Entropy, NTL-MSE, and NTL-WAS losses.
15
+ """)
16
 
17
+ # Vertical sliders for predicted probabilities of tokens 0-9 and "Text"
18
+ st.markdown("#### Predicted Token Probabilities")
19
+ cols = st.columns(11)
20
+ prob_values = []
21
+ for i, col in enumerate(cols):
22
+ label = f"Token {i}" if i < 10 else "Text"
23
+ with col:
24
+ val = svs.vertical_slider(
25
+ label=label,
26
+ min_value=0.0,
27
+ max_value=1.0,
28
+ step=0.1,
29
+ height=50,
30
+ key=f"slider_{i}",
31
+ slider_color="green",
32
+ track_color="lightgray",
33
+ thumb_color="black",
34
+ )
35
+ prob_values.append(val)
36
 
37
+ # Normalize the probabilities to sum to 1
38
+ total = sum(prob_values)
39
+ probs = (
40
+ torch.ones(11) / 11.0
41
+ if total == 0
42
+ else torch.tensor([v / total for v in prob_values])
43
+ )
44
 
45
+ # Token labels
46
+ options = [str(i) for i in range(10)] + ["Text"]
 
47
 
48
+ # Ground truth token selection
49
+ gt_choice = st.selectbox("Ground Truth Token", options=options, index=0)
50
+ if gt_choice == "Text":
51
+ gt_index = 10
52
+ gt_numeric = None
53
+ else:
54
+ gt_index = int(gt_choice)
55
+ gt_numeric = gt_index
56
 
57
+ # Visualize the input distribution with highlighted ground truth bar
58
+ st.markdown("#### Input Probability Distribution")
59
+ df_dist = pd.DataFrame({"token": options, "probability": probs.numpy()})
60
+ chart = (
61
+ alt.Chart(df_dist)
62
+ .mark_bar()
 
 
 
63
  .encode(
64
+ x=alt.X("token:N", title="Token"),
65
+ y=alt.Y("probability:Q", title="Probability", scale=alt.Scale(domain=[0, 1])),
66
+ color=alt.condition(
67
+ alt.datum.token == gt_choice,
68
+ alt.value("green"), # Highlight ground truth token
69
+ alt.value("steelblue"), # Other tokens
70
+ ),
71
+ )
72
+ .properties(height=300)
73
+ )
74
+ st.altair_chart(chart, use_container_width=True)
75
+
76
+ # Compute Cross Entropy loss: -log(predicted probability of the ground truth)
77
+ ce_loss = -torch.log(torch.clamp(probs[gt_index], min=1e-9))
78
+
79
+ # Compute NTL-MSE loss
80
+ if gt_numeric is None:
81
+ ntl_mse_loss = torch.tensor(0.0)
82
+ else:
83
+ numeric_probs = probs[:10]
84
+ values = torch.arange(0, 10, dtype=torch.float32)
85
+ pred_value = torch.sum(numeric_probs * values)
86
+ ntl_mse_loss = (pred_value - float(gt_numeric)) ** 2
87
+
88
+ # Compute NTL-WAS loss
89
+ if gt_numeric is None:
90
+ ntl_was_loss = torch.tensor(0.0)
91
+ else:
92
+ numeric_probs = probs[:10]
93
+ values = torch.arange(0, 10, dtype=torch.float32)
94
+ abs_diff = torch.abs(values - float(gt_numeric))
95
+ ntl_was_loss = torch.sum(numeric_probs * abs_diff)
96
+
97
+ # Convert losses to Python floats and round to 3 decimals
98
+ ce_val = round(ce_loss.item(), 3)
99
+ mse_val = round(ntl_mse_loss.item(), 3)
100
+ was_val = round(ntl_was_loss.item(), 3)
101
+
102
+ # Display numeric values of the losses
103
+ st.subheader("Loss Values")
104
+ st.write(f"**Cross Entropy:** {ce_val:.3f}")
105
+ st.write(f"**NTL-MSE:** {mse_val:.3f}")
106
+ st.write(f"**NTL-WAS:** {was_val:.3f}")
107
+
108
+ # Bar chart comparison of the three losses
109
+ st.subheader("Loss Comparison Chart")
110
+ loss_df = pd.DataFrame(
111
+ {
112
+ "Loss": ["Cross Entropy", "NTL-MSE", "NTL-WAS"],
113
+ "Value": [ce_val, mse_val, was_val],
114
+ }
115
+ ).set_index("Loss")
116
+ st.bar_chart(loss_df)
117
+
118
+ # References / resources section with links
119
+ st.markdown("### Resources")
120
+ st.markdown(
121
+ "- **Paper:** [Regress, Don't Guess – A Regression-like Loss on Number Tokens for Language Models](https://arxiv.org/abs/2411.02083) \n- **Code:** [tum-ai/number-token-loss (GitHub)](https://github.com/tum-ai/number-token-loss)"
122
+ )