schoginitoys commited on
Commit
a53f381
·
verified ·
1 Parent(s): c325b73

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +141 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,143 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
 
 
 
 
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ import os
2
+ # turn off Streamlit’s automatic file-watching
3
+ os.environ["STREAMLIT_SERVER_ENABLE_FILE_WATCHER"] = "false"
4
+
5
+ import sys
6
+ import types
7
+ import torch # now safe to import
8
  import streamlit as st
9
+ import numpy as np
10
+
11
+ # Prevent Streamlit from trying to walk torch.classes' non-standard __path__
12
+ if isinstance(getattr(sys.modules.get("torch"), "classes", None), types.ModuleType):
13
+ torch.classes.__path__ = []
14
+
15
+ # pip install tiktoken transformers
16
+ import tiktoken
17
+ from transformers import GPT2TokenizerFast
18
+
19
+ st.set_page_config(page_title="Embedding Dimension Visualizer", layout="wide")
20
+ st.title("🔍 Embedding Dimension Visualizer")
21
+
22
+ # ---- THEORY EXPANDER ----
23
+ with st.expander("📖 Theory: Tokenization, BPE & Positional Encoding"):
24
+ st.markdown("""
25
+ **1️⃣ Tokenization**
26
+ Splits raw text into atomic units (“tokens”).
27
+
28
+ **2️⃣ Byte-Pair Encoding (BPE)**
29
+ Iteratively merges the most frequent pair of symbols to build a subword vocabulary.
30
+ E.g. "embedding" → ["em", "bed", "ding"]
31
+
32
+ **3️⃣ Positional Encoding**
33
+ We add a deterministic sinusoidal vector to each token embedding so the model knows position.
34
+ """)
35
+ st.markdown("For embedding dimension \(d\), position \(pos\) and channel index \(i\):")
36
+ st.latex(r"""\mathrm{PE}_{(pos,\,2i)} = \sin\!\Bigl(\frac{pos}{10000^{2i/d}}\Bigr)""")
37
+ st.latex(r"""\mathrm{PE}_{(pos,\,2i+1)} = \cos\!\Bigl(\frac{pos}{10000^{2i/d}}\Bigr)""")
38
+ st.markdown("""
39
+ - \(pos\) starts at 0 for the first token
40
+ - Even channels use \(\sin\), odd channels use \(\cos\)
41
+ - This injects unique, smoothly varying positional signals into each embedding
42
+ """)
43
+
44
+
45
+ # ---- Sidebar ----
46
+ with st.sidebar:
47
+ st.header("Settings")
48
+ input_text = st.text_input("Enter text to embed", value="Hello world!")
49
+ dim = st.number_input(
50
+ "Embedding dimensions",
51
+ min_value=2,
52
+ max_value=1536,
53
+ value=3,
54
+ step=1,
55
+ help="Choose 2, 3, 512, 768, 1536, etc."
56
+ )
57
+ tokenizer_choice = st.selectbox(
58
+ "Choose tokenizer",
59
+ ["tiktoken", "openai", "huggingface"],
60
+ help="Which tokenization scheme to demo."
61
+ )
62
+ generate = st.button("Generate / Reset Embedding")
63
+
64
+ if not generate:
65
+ st.info("Adjust the settings in the sidebar and click **Generate / Reset Embedding** to see the tokens and sliders.")
66
+ st.stop()
67
+
68
+ # ---- Tokenize ----
69
+ if tokenizer_choice in ("tiktoken", "openai"):
70
+ model_name = "gpt2" if tokenizer_choice=="tiktoken" else "gpt-3.5-turbo"
71
+ enc = tiktoken.encoding_for_model(model_name)
72
+ token_ids = enc.encode(input_text)
73
+ token_strs = [enc.decode([tid]) for tid in token_ids]
74
+ else:
75
+ hf_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
76
+ token_ids = hf_tokenizer.encode(input_text)
77
+ token_strs = hf_tokenizer.convert_ids_to_tokens(token_ids)
78
+
79
+ st.subheader("🪶 Tokens and IDs")
80
+ for i, (tok, tid) in enumerate(zip(token_strs, token_ids), start=1):
81
+ st.write(f"**{i}.** `{tok}` → ID **{tid}**")
82
+
83
+ st.write("---")
84
+ st.subheader("📊 Embedding + Positional Encoding per Token")
85
+ st.write(f"Input: `{input_text}` | Tokenizer: **{tokenizer_choice}** | Dims per token: **{dim}**")
86
+ if dim > 20:
87
+ st.warning("Showing >20 sliders per block may be unwieldy; consider smaller dims for teaching.")
88
+
89
+ # helper for sinusoidal positional encoding
90
+ def get_positional_encoding(position: int, d_model: int) -> np.ndarray:
91
+ pe = np.zeros(d_model, dtype=float)
92
+ for i in range(d_model):
93
+ angle = position / np.power(10000, (2 * (i // 2)) / d_model)
94
+ pe[i] = np.sin(angle) if (i % 2 == 0) else np.cos(angle)
95
+ return pe
96
+
97
+ # ---- For each token, three slider‐blocks ----
98
+ for t_idx, tok in enumerate(token_strs, start=1):
99
+ emb = np.random.uniform(-1.0, 1.0, size=dim)
100
+ pe = get_positional_encoding(t_idx - 1, dim)
101
+ combined = emb + pe
102
+
103
+ with st.expander(f"Token {t_idx}: `{tok}`"):
104
+ st.markdown("**1️⃣ Embedding**")
105
+ for d in range(dim):
106
+ st.slider(
107
+ label=f"Emb Dim {d+1}",
108
+ min_value=-1.0, max_value=1.0,
109
+ value=float(emb[d]),
110
+ key=f"t{t_idx}_emb{d+1}",
111
+ disabled=True
112
+ )
113
+
114
+ st.markdown("**2️⃣ Positional Encoding (sin / cos)**")
115
+ for d in range(dim):
116
+ st.slider(
117
+ label=f"PE Dim {d+1}",
118
+ min_value=-1.0, max_value=1.0,
119
+ value=float(pe[d]),
120
+ key=f"t{t_idx}_pe{d+1}",
121
+ disabled=True
122
+ )
123
+
124
+ st.markdown("**3️⃣ Embedding + Positional Encoding**")
125
+ for d in range(dim):
126
+ st.slider(
127
+ label=f"Sum Dim {d+1}",
128
+ min_value=-2.0, max_value=2.0,
129
+ value=float(combined[d]),
130
+ key=f"t{t_idx}_sum{d+1}",
131
+ disabled=True
132
+ )
133
+
134
+ # ---- NEW FINAL SECTION ----
135
+ st.write("---")
136
+ st.subheader("Final Input Embedding Plus Positional Encoding Ready to Send to ATtention Heads")
137
 
138
+ for t_idx, tid in enumerate(token_ids, start=1):
139
+ with st.expander(f"Token ID {tid}"):
140
+ for d in range(1, dim+1):
141
+ # pull the “sum” value out of session state
142
+ val = st.session_state.get(f"t{t_idx}_sum{d}", None)
143
+ st.write(f"Dim {d}: {val:.4f}" if val is not None else f"Dim {d}: N/A")