berndf commited on
Commit
c4fa76e
·
verified ·
1 Parent(s): 7f68d76
Files changed (1) hide show
  1. app.py +314 -0
app.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import random
3
+ import numpy as np
4
+ import streamlit as st
5
+ import plotly.graph_objects as go
6
+ from sklearn.decomposition import PCA
7
+ import torch
8
+ from transformers import AutoTokenizer, AutoModel
9
+
10
+ st.set_page_config(page_title="Embedding Visualizer", layout="wide")
11
+
12
+ # -----------------------------
13
+ # Base datasets (dataset names stay lowercase)
14
+ # -----------------------------
15
+ BASE_SETS = {
16
+ "countries": [
17
+ "Germany","France","Italy","Spain","Portugal","Poland","Netherlands","Belgium","Austria","Switzerland",
18
+ "Greece","Norway","Sweden","Finland","Denmark","Ireland","Hungary","Czechia","Slovakia","Slovenia",
19
+ "Romania","Bulgaria","Croatia","Estonia","Latvia"
20
+ ],
21
+ "animals": [
22
+ "cat","dog","lion","tiger","bear","wolf","fox","eagle","shark","whale",
23
+ "zebra","giraffe","elephant","hippopotamus","rhinoceros","kangaroo","panda","otter","seal","dolphin",
24
+ "chimpanzee","gorilla","leopard","cheetah","lynx"
25
+ ],
26
+ "furniture": [
27
+ "armchair","sofa","dining table","coffee table","bookshelf","bed","wardrobe","desk","office chair","dresser",
28
+ "nightstand","side table","tv stand","loveseat","chaise lounge","bench","hutch","kitchen island","futon","recliner",
29
+ "ottoman","console table","vanity","buffet","sectional sofa"
30
+ ],
31
+ "actors": [
32
+ "Brad Pitt","Angelina Jolie","Meryl Streep","Leonardo DiCaprio","Tom Hanks","Scarlett Johansson","Robert De Niro",
33
+ "Natalie Portman","Matt Damon","Cate Blanchett","Johnny Depp","Keanu Reeves","Hugh Jackman","Emma Stone","Ryan Gosling",
34
+ "Jennifer Lawrence","Christian Bale","Charlize Theron","Will Smith","Anne Hathaway","Denzel Washington","Morgan Freeman",
35
+ "Julia Roberts","George Clooney","Kate Winslet"
36
+ ],
37
+ "rock groups": [
38
+ "The Beatles","Rolling Stones","Pink Floyd","Queen","Led Zeppelin","U2","AC/DC","Nirvana","Radiohead","Metallica",
39
+ "Guns N' Roses","Red Hot Chili Peppers","Coldplay","Pearl Jam","The Police","Aerosmith","Green Day","Foo Fighters",
40
+ "The Doors","Bon Jovi","Deep Purple","The Who","The Kinks","Fleetwood Mac","The Beach Boys"
41
+ ],
42
+ "sports": [
43
+ "soccer","basketball","tennis","baseball","golf","swimming","cycling","running","volleyball","rugby",
44
+ "boxing","skiing","snowboarding","surfing","skateboarding","karate","judo","fencing","rowing","badminton",
45
+ "cricket","table tennis","gymnastics","hockey","climbing"
46
+ ],
47
+ }
48
+
49
+ # -----------------------------
50
+ # Build datasets once per session (base + 3 random mixed)
51
+ # -----------------------------
52
def make_random_mixed_sets(
    base: dict,
    n: int = 3,
    sources_per_set: int = 3,
    items_per_source: int = 7,
) -> dict:
    """Build up to ``n`` mixed datasets by sampling from several base datasets.

    Each mixed dataset draws ``items_per_source`` random items (or all of
    them, if a source is smaller) from ``sources_per_set`` randomly chosen
    base datasets, and is named by joining the source names with ``"/"``.

    Args:
        base: Mapping of dataset name to a list of items.
        n: Number of mixed datasets to create.
        sources_per_set: How many base datasets to combine per mixed set;
            capped at ``len(base)`` so small inputs do not raise.
        items_per_source: How many items to draw from each chosen source.

    Returns:
        Mapping of ``"a/b/c"``-style names to the sampled item lists. It may
        hold fewer than ``n`` entries when not enough distinct names can be
        generated (for example an empty ``base``).
    """
    keys = list(base)
    per_set = min(sources_per_set, len(keys))
    out = {}
    if n <= 0 or per_set <= 0:
        return out

    # Retry on name collisions so a repeated draw does not silently overwrite
    # an earlier mixed set; the attempt cap keeps the loop finite when fewer
    # than n distinct names exist.
    attempts_left = n * 20
    while len(out) < n and attempts_left > 0:
        attempts_left -= 1
        src = random.sample(keys, per_set)
        name = "/".join(src)
        # Also skip names that would shadow a base dataset once merged.
        if name in out or name in base:
            continue
        items = []
        for s in src:
            pool = base[s]
            items.extend(random.sample(pool, min(items_per_source, len(pool))))
        out[name] = items
    return out
63
+
64
+ if "datasets" not in st.session_state:
65
+ mixed = make_random_mixed_sets(BASE_SETS, 3)
66
+ st.session_state.datasets = {**BASE_SETS, **mixed}
67
+
68
+ DATASETS = st.session_state.datasets # shorthand
69
+
70
+ # -----------------------------
71
+ # Models (transformers)
72
+ # -----------------------------
73
+ MODELS = {
74
+ "all-MiniLM-L6-v2 (384d)": "sentence-transformers/all-MiniLM-L6-v2",
75
+ "all-mpnet-base-v2 (768d)": "sentence-transformers/all-mpnet-base-v2",
76
+ "all-roberta-large-v1 (1024d)": "sentence-transformers/all-roberta-large-v1",
77
+ }
78
+
79
@st.cache_resource(show_spinner=False)
def load_model(model_name: str):
    """Load the tokenizer and encoder for ``model_name``.

    The result is cached by ``st.cache_resource``, so each model name is
    loaded once and then reused.

    Returns:
        A ``(tokenizer, model)`` tuple; the model has been switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    encoder = AutoModel.from_pretrained(model_name)
    encoder.eval()
    return tokenizer, encoder
85
+
86
@st.cache_data(show_spinner=False)
def embed_texts(model_name: str, texts_tuple: tuple):
    """Embed each text with ``model_name`` using attention-masked mean pooling.

    Args:
        model_name: Model identifier passed to ``load_model``.
        texts_tuple: Texts to embed, given as a tuple so the call is
            hashable for ``st.cache_data``.

    Returns:
        A NumPy array with one pooled embedding row per input text.
    """
    tokenizer, model = load_model(model_name)
    with torch.no_grad():
        batch = tokenizer(
            list(texts_tuple),
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        hidden = model(**batch).last_hidden_state
        # Zero out padding positions so they do not contribute to the mean.
        weights = batch["attention_mask"].unsqueeze(-1).type_as(hidden)
        # Clamp guards against division by zero for an all-padding row.
        token_counts = weights.sum(dim=1).clamp(min=1e-9)
        pooled = (hidden * weights).sum(dim=1) / token_counts  # mean pooling
    return pooled.cpu().numpy()
99
+
100
+ # -----------------------------
101
+ # Info page (local) via st.query_params
102
+ # -----------------------------
103
def goto(page: str):
    """Navigate to ``page`` by setting the ``page`` query parameter and rerunning."""
    # The script reads this parameter on the next run to decide what to render.
    st.query_params["page"] = page
    st.rerun()
106
+
107
+ page = st.query_params.get("page", "demo")
108
+
109
+ if page == "info":
110
+ st.title("about this demo")
111
+ st.write("""
112
+ # 🧠 Embedding Visualizer – About
113
+
114
+ This demo shows how **vector embeddings** can capture the meaning of words and place them in a **numerical space** where related items appear close together.
115
+
116
+ You can:
117
+ - Choose from predefined or mixed datasets (e.g., countries, animals, actors, sports)
118
+ - Select different embedding models to compare results
119
+ - Switch between 2D and 3D visualizations
120
+ - Edit the list of words directly and see the updated projection instantly
121
+
122
+ ---
123
+
124
+ ## 📌 What are Vector Embeddings?
125
+ A **vector embedding** is a way of representing text (words, sentences, or documents) as a list of numbers — a point in a high-dimensional space.
126
+ These numbers are produced by a trained **language model** that captures semantic meaning.
127
+
128
+ In this space:
129
+ - Words with **similar meanings** end up **near each other**
130
+ - Dissimilar words are placed **far apart**
131
+ - The model can detect relationships and groupings that aren’t obvious from spelling or grammar alone
132
+
133
+ Example:
134
+ `"cat"` and `"dog"` will likely be closer to each other than to `"table"`, because the model “knows” they are both animals.
135
+
136
+ ---
137
+
138
+ ## 🔍 How the Demo Works
139
+ 1. **Embedding step** – Each word is converted into a high-dimensional vector (e.g., 384, 768, or 1024 dimensions depending on the model).
140
+ 2. **Dimensionality reduction** – Since humans can’t visualize hundreds of dimensions, the vectors are projected to 2D or 3D using **PCA** (Principal Component Analysis).
141
+ 3. **Visualization** – The projected points are plotted, with labels showing the original words.
142
+ You can rotate the 3D view to explore groupings.
143
+
144
+ ---
145
+
146
+ ## 💡 Typical Applications of Embeddings
147
+ - **Semantic search** – Find relevant results even if exact keywords don’t match
148
+ - **Clustering & topic discovery** – Group related items automatically
149
+ - **Recommendations** – Suggest similar products, movies, or articles
150
+ - **Deduplication** – Detect near-duplicate content
151
+ - **Analogies** – Explore relationships like *"king" – "man" + "woman" ≈ "queen"*
152
+
153
+ ---
154
+
155
+ ## 🚀 Try it Yourself
156
+ - Pick a dataset or create your own by editing the list
157
+ - Switch models to compare how the embedding space changes
158
+ - Toggle between 2D and 3D to explore patterns
159
+
160
+ """.strip())
161
+ if st.button("⬅ back to demo"):
162
+ goto("demo")
163
+ st.stop()
164
+
165
+ # -----------------------------
166
+ # Top compact bar
167
+ # -----------------------------
168
# Top bar: dataset picker, model picker, projection mode and info link.
c1, c2, c3, c4 = st.columns([2, 2, 1, 1])

with c1:
    dataset_options = list(DATASETS)
    # Seed the widget state on first render, and repair it if it no longer
    # names an available dataset (otherwise the selectbox could not resolve it).
    if st.session_state.get("dataset_name") not in dataset_options:
        st.session_state.dataset_name = (
            "furniture" if "furniture" in DATASETS else dataset_options[0]
        )
    # No `index=` here: the value already comes from Session State via `key`.
    # Passing both makes Streamlit warn that the widget has a default value
    # and a Session State value at the same time.
    dataset_name = st.selectbox("dataset", dataset_options, key="dataset_name")

with c2:
    labels = list(MODELS)
    if "model_name" not in st.session_state:
        st.session_state.model_name = MODELS[labels[0]]
    # Session state stores the model id; the selectbox shows the friendly label.
    rev = {model_id: label for label, model_id in MODELS.items()}
    current_label = rev.get(st.session_state.model_name, labels[0])
    chosen_label = st.selectbox(
        "embedding model", labels, index=labels.index(current_label)
    )
    st.session_state.model_name = MODELS[chosen_label]

with c3:
    # Single-click fix: stable key and only set index on first render
    radio_kwargs = dict(options=["2D", "3D"], horizontal=True, key="proj_mode")
    if "proj_mode" not in st.session_state:
        radio_kwargs["index"] = 1  # default to 3D initially
    st.radio("projection", **radio_kwargs)

with c4:
    if st.button("ℹ info"):
        goto("info")
196
+
197
+ # -----------------------------
198
+ # Two-column layout (left = textarea, right = plot)
199
+ # -----------------------------
200
left, right = st.columns([1, 2], gap="large")

# Keep the textarea in sync with the dataset selection. The text lives in
# session state under the widget key, so it is written before the widget is
# created in this run.
if "dataset_text" not in st.session_state:
    st.session_state.dataset_text = "\n".join(DATASETS[st.session_state.dataset_name])

if "prev_dataset_name" not in st.session_state:
    st.session_state.prev_dataset_name = st.session_state.dataset_name

# A different dataset was picked: replace the (possibly edited) text with it.
if st.session_state.dataset_name != st.session_state.prev_dataset_name:
    st.session_state.dataset_text = "\n".join(DATASETS[st.session_state.dataset_name])
    st.session_state.prev_dataset_name = st.session_state.dataset_name

with left:
    st.text_area(
        # Streamlit discourages an empty label (accessibility), so use a
        # short one in the same lowercase style as the other widgets.
        label="words",
        key="dataset_text",
        height=420,
        help="edit words (one per line). changing dataset above refreshes this box.",
    )
    # One entry per non-blank line, with surrounding whitespace removed.
    words = [w.strip() for w in st.session_state.dataset_text.split("\n") if w.strip()]
221
+
222
with right:
    # PCA to 3 components needs at least three samples.
    if len(words) < 3:
        st.info("enter at least three lines to project.")
        st.stop()

    X = embed_texts(st.session_state.model_name, tuple(words))

    # Capitalized dataset name for the chart title (dataset keys remain lowercase in the UI)
    chart_title = st.session_state.dataset_name.title()
    plot_margin = dict(l=0, r=0, b=0, t=40)

    if st.session_state.proj_mode == "2D":
        coords = PCA(n_components=2).fit_transform(X)
        scatter = go.Scatter(
            x=coords[:, 0],
            y=coords[:, 1],
            mode="markers+text",
            text=words,
            textposition="top center",
            marker=dict(size=9),
        )
        fig = go.Figure(
            data=[scatter],
            layout=go.Layout(
                xaxis=dict(title="PC1"),
                # Anchor y to x so both principal components share one scale.
                yaxis=dict(title="PC2", scaleanchor="x", scaleratio=1),
                margin=plot_margin,
            ),
        )
    else:
        coords = PCA(n_components=3).fit_transform(X)
        scatter = go.Scatter3d(
            x=coords[:, 0],
            y=coords[:, 1],
            z=coords[:, 2],
            mode="markers+text",
            text=words,
            textposition="top center",
            marker=dict(size=6),
        )
        fig = go.Figure(
            data=[scatter],
            layout=go.Layout(
                scene=dict(
                    xaxis=dict(showbackground=True, backgroundcolor="rgba(255, 230, 230, 1)"),
                    yaxis=dict(showbackground=True, backgroundcolor="rgba(230, 255, 230, 1)"),
                    zaxis=dict(showbackground=True, backgroundcolor="rgba(230, 230, 255, 1)"),
                ),
                margin=plot_margin,
            ),
        )

        # Simple Plotly rotation: one frame per 4 degrees, each moving the
        # camera eye around a circle of fixed radius and height.
        radius = 1.7
        z_eye = 1.0
        fig.frames = [
            go.Frame(layout=dict(
                scene_camera=dict(
                    eye=dict(x=radius * np.cos(rad), y=radius * np.sin(rad), z=z_eye),
                    projection=dict(type="perspective"),
                )
            ))
            for rad in (np.deg2rad(ang) for ang in range(0, 360, 4))
        ]

        rotate_button = dict(
            label="▶ Rotate",
            method="animate",
            args=[None, dict(frame=dict(duration=40, redraw=True),
                             transition=dict(duration=0),
                             fromcurrent=True, mode="immediate")],
        )
        stop_button = dict(
            label="⏹ Stop",
            method="animate",
            args=[[None], dict(frame=dict(duration=0, redraw=False),
                               transition=dict(duration=0),
                               mode="immediate")],
        )
        fig.update_layout(
            updatemenus=[dict(
                type="buttons", showactive=False, x=0.02, y=0.98,
                buttons=[rotate_button, stop_button],
            )]
        )

    # Both projections share the same centered title.
    fig.update_layout(
        title=dict(
            text=chart_title,
            x=0.5, xanchor='center', yanchor='top',
            font=dict(size=20),
        )
    )

    st.plotly_chart(fig, use_container_width=True)