# berndf's picture
# removed diagram ref
# 0b13d14 verified
# app.py
import random
import numpy as np
import streamlit as st
import plotly.graph_objects as go
from sklearn.decomposition import PCA
import torch
from transformers import AutoTokenizer, AutoModel
st.set_page_config(page_title="Embedding Visualizer", layout="wide")
# -----------------------------
# Base datasets (dataset names stay lowercase)
# -----------------------------
# Built-in word lists keyed by lowercase dataset name; each list holds 25 entries.
BASE_SETS = {
    "countries": [
        "Germany", "France", "Italy", "Spain", "Portugal",
        "Poland", "Netherlands", "Belgium", "Austria", "Switzerland",
        "Greece", "Norway", "Sweden", "Finland", "Denmark",
        "Ireland", "Hungary", "Czechia", "Slovakia", "Slovenia",
        "Romania", "Bulgaria", "Croatia", "Estonia", "Latvia",
    ],
    "animals": [
        "cat", "dog", "lion", "tiger", "bear",
        "wolf", "fox", "eagle", "shark", "whale",
        "zebra", "giraffe", "elephant", "hippopotamus", "rhinoceros",
        "kangaroo", "panda", "otter", "seal", "dolphin",
        "chimpanzee", "gorilla", "leopard", "cheetah", "lynx",
    ],
    "furniture": [
        "armchair", "sofa", "dining table", "coffee table", "bookshelf",
        "bed", "wardrobe", "desk", "office chair", "dresser",
        "nightstand", "side table", "tv stand", "loveseat", "chaise lounge",
        "bench", "hutch", "kitchen island", "futon", "recliner",
        "ottoman", "console table", "vanity", "buffet", "sectional sofa",
    ],
    "actors": [
        "Brad Pitt", "Angelina Jolie", "Meryl Streep", "Leonardo DiCaprio", "Tom Hanks",
        "Scarlett Johansson", "Robert De Niro", "Natalie Portman", "Matt Damon", "Cate Blanchett",
        "Johnny Depp", "Keanu Reeves", "Hugh Jackman", "Emma Stone", "Ryan Gosling",
        "Jennifer Lawrence", "Christian Bale", "Charlize Theron", "Will Smith", "Anne Hathaway",
        "Denzel Washington", "Morgan Freeman", "Julia Roberts", "George Clooney", "Kate Winslet",
    ],
    "rock groups": [
        "The Beatles", "Rolling Stones", "Pink Floyd", "Queen", "Led Zeppelin",
        "U2", "AC/DC", "Nirvana", "Radiohead", "Metallica",
        "Guns N' Roses", "Red Hot Chili Peppers", "Coldplay", "Pearl Jam", "The Police",
        "Aerosmith", "Green Day", "Foo Fighters", "The Doors", "Bon Jovi",
        "Deep Purple", "The Who", "The Kinks", "Fleetwood Mac", "The Beach Boys",
    ],
    "sports": [
        "soccer", "basketball", "tennis", "baseball", "golf",
        "swimming", "cycling", "running", "volleyball", "rugby",
        "boxing", "skiing", "snowboarding", "surfing", "skateboarding",
        "karate", "judo", "fencing", "rowing", "badminton",
        "cricket", "table tennis", "gymnastics", "hockey", "climbing",
    ],
}
# -----------------------------
# Build datasets once per session (base + 3 random mixed)
# -----------------------------
def make_random_mixed_sets(
    base: dict, n: int = 3, groups: int = 3, per_group: int = 7
) -> dict:
    """Build up to ``n`` random datasets that each blend several base categories.

    Each mixed dataset draws ``groups`` distinct categories from ``base`` and
    up to ``per_group`` distinct words from each of them. The result key is
    the chosen category names joined with ``"/"``, in the order drawn.

    Args:
        base: Mapping of category name to a list of words.
        n: Number of mixed datasets to generate.
        groups: Categories combined per mixed dataset. Clamped to the number
            of categories available, so a small ``base`` no longer raises.
        per_group: Maximum number of words taken from each category.

    Returns:
        A dict mapping ``"a/b/c"``-style names to word lists. It can hold
        fewer than ``n`` entries when two draws pick the same categories in
        the same order (the later draw overwrites the earlier one). It is
        empty when ``base`` is empty or ``groups``/``n`` is not positive.
    """
    keys = list(base)
    # random.sample raises ValueError if asked for more items than exist.
    groups = max(0, min(groups, len(keys)))
    if groups == 0:
        return {}

    out = {}
    for _ in range(n):
        src = random.sample(keys, groups)
        items = []
        for name in src:
            pool = base[name]
            take = max(0, min(per_group, len(pool)))
            items.extend(random.sample(pool, take))
        out["/".join(src)] = items
    return out
# The mixed sets are random, so they are generated only when the session has
# no "datasets" entry yet and are then reused from session state.
if "datasets" not in st.session_state:
    st.session_state.datasets = {**BASE_SETS, **make_random_mixed_sets(BASE_SETS, 3)}
DATASETS = st.session_state.datasets  # shorthand
# -----------------------------
# Models (transformers)
# -----------------------------
# UI label -> Hugging Face model id. The label shows the embedding dimension.
MODELS = {
    f"{short_name} ({dim}d)": f"sentence-transformers/{short_name}"
    for short_name, dim in (
        ("all-MiniLM-L6-v2", 384),
        ("all-mpnet-base-v2", 768),
        ("all-roberta-large-v1", 1024),
    )
}
@st.cache_resource(show_spinner=False)
def load_model(model_name: str):
    """Load the tokenizer and model for ``model_name``.

    Wrapped in ``st.cache_resource``. The model is switched to eval mode
    before it is returned.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # nn.Module.eval() returns the module itself, so the call can be chained.
    encoder = AutoModel.from_pretrained(model_name).eval()
    return tokenizer, encoder
@st.cache_data(show_spinner=False)
def embed_texts(model_name: str, texts_tuple: tuple):
    """Embed each text with ``model_name`` using attention-masked mean pooling.

    Args:
        model_name: Model id passed through to ``load_model``.
        texts_tuple: The texts to embed.

    Returns:
        A NumPy array holding one pooled embedding per input text.
    """
    tokenizer, model = load_model(model_name)
    with torch.no_grad():
        batch = tokenizer(
            list(texts_tuple), padding=True, truncation=True, return_tensors="pt"
        )
        hidden = model(**batch).last_hidden_state
        # Give the attention mask a trailing axis and the hidden-state dtype so
        # it multiplies against the token embeddings; masked positions
        # contribute zero to the sum.
        weights = batch["attention_mask"].unsqueeze(-1).type_as(hidden)
        token_total = (hidden * weights).sum(dim=1)
        # The clamp keeps the divisor away from zero.
        token_count = weights.sum(dim=1).clamp(min=1e-9)
        pooled = token_total / token_count  # mean pooling
    return pooled.cpu().numpy()
# -----------------------------
# Info page (local) via st.query_params
# -----------------------------
def goto(page: str):
    """Store ``page`` in the URL query parameters and rerun the script."""
    params = st.query_params
    params["page"] = page
    st.rerun()
# Route on the "page" query parameter; anything other than "info" falls
# through to the demo below.
page = st.query_params.get("page", "demo")
if page == "info":
    info_markdown = """
# 🧠 Embedding Visualizer – About
This demo shows how **vector embeddings** can capture the meaning of words and place them in a **numerical space** where related items appear close together.
You can:
- Choose from predefined or mixed datasets (e.g., countries, animals, actors, sports)
- Select different embedding models to compare results
- Switch between 2D and 3D visualizations
- Edit the list of words directly and see the updated projection instantly
---
## 📌 What are Vector Embeddings?
A **vector embedding** is a way of representing text (words, sentences, or documents) as a list of numbers — a point in a high-dimensional space.
These numbers are produced by a trained **language model** that captures semantic meaning.
In this space:
- Words with **similar meanings** end up **near each other**
- Dissimilar words are placed **far apart**
- The model can detect relationships and groupings that aren’t obvious from spelling or grammar alone
Example:
`"cat"` and `"dog"` will likely be closer to each other than to `"table"`, because the model “knows” they are both animals.
---
## 🔍 How the Demo Works
1. **Embedding step** – Each word is converted into a high-dimensional vector (e.g., 384, 768, or 1024 dimensions depending on the model).
2. **Dimensionality reduction** – Since humans can’t visualize hundreds of dimensions, the vectors are projected to 2D or 3D using **PCA** (Principal Component Analysis).
3. **Visualization** – The projected points are plotted, with labels showing the original words.
You can rotate the 3D view to explore groupings.
---
## 💡 Typical Applications of Embeddings
- **Semantic search** – Find relevant results even if exact keywords don’t match
- **Clustering & topic discovery** – Group related items automatically
- **Recommendations** – Suggest similar products, movies, or articles
- **Deduplication** – Detect near-duplicate content
- **Analogies** – Explore relationships like *"king" – "man" + "woman" ≈ "queen"*
---
## 🚀 Try it Yourself
- Pick a dataset or create your own by editing the list
- Switch models to compare how the embedding space changes
- Toggle between 2D and 3D to explore patterns
"""
    st.write(info_markdown.strip())
    if st.button("⬅ back to demo"):
        goto("demo")
    # Nothing below this point should render on the info page.
    st.stop()
# -----------------------------
# Top compact bar
# -----------------------------
# Top bar: dataset picker, model picker, projection toggle, info button.
c1, c2, c3, c4 = st.columns([2, 2, 1, 1])
with c1:
    if "dataset_name" not in st.session_state:
        st.session_state.dataset_name = (
            "actors" if "actors" in DATASETS else next(iter(DATASETS))
        )
    # The widget key restores the selection from session state. Passing a
    # non-default `index=` as well makes Streamlit warn that the widget was
    # given a default value while its value was also set via Session State.
    dataset_name = st.selectbox("dataset", list(DATASETS.keys()), key="dataset_name")
with c2:
    # Session state stores the model id; the select box shows the label.
    if "model_name" not in st.session_state:
        st.session_state.model_name = list(MODELS.values())[1]
    labels = list(MODELS.keys())
    rev = {v: k for k, v in MODELS.items()}
    current_label = rev.get(st.session_state.model_name, labels[0])
    chosen_label = st.selectbox(
        "embedding model", labels, index=labels.index(current_label)
    )
    st.session_state.model_name = MODELS[chosen_label]
with c3:
    # Default to 3D on first render; afterwards the key alone carries the
    # value, so `index` is only passed while the key is absent.
    radio_kwargs = dict(options=["2D", "3D"], horizontal=True, key="proj_mode")
    if "proj_mode" not in st.session_state:
        radio_kwargs["index"] = 1  # 3D default
    st.radio("projection", **radio_kwargs)
with c4:
    if st.button("ℹ info"):
        goto("info")
# -----------------------------
# Two-column layout (left = textarea, right = plot)
# -----------------------------
left, right = st.columns([1, 2], gap="large")

# Keep the textarea in sync with the dataset selection: fill it when it has
# no value yet, and overwrite it whenever the selected dataset changes.
state = st.session_state
if "prev_dataset_name" not in state:
    state.prev_dataset_name = state.dataset_name
dataset_changed = state.dataset_name != state.prev_dataset_name
if dataset_changed or "dataset_text" not in state:
    state.dataset_text = "\n".join(DATASETS[state.dataset_name])
    state.prev_dataset_name = state.dataset_name

with left:
    st.text_area(
        label="",
        key="dataset_text",
        height=420,
        help="edit words (one per line). changing dataset above refreshes this box.",
    )

# One word per line; whitespace is trimmed and blank lines are dropped.
words = [entry for entry in map(str.strip, state.dataset_text.split("\n")) if entry]
with right:
    # The 3D view asks PCA for three components, so require three entries.
    if len(words) < 3:
        st.info("enter at least three lines to project.")
        st.stop()

    X = embed_texts(st.session_state.model_name, tuple(words))

    # Capitalized dataset name for the chart title (dataset keys remain lowercase in the UI)
    chart_title = st.session_state.dataset_name.title()
    is_2d = st.session_state.proj_mode == "2D"
    plot_margin = dict(l=0, r=0, b=0, t=40)

    if is_2d:
        coords = PCA(n_components=2).fit_transform(X)
        trace = go.Scatter(
            x=coords[:, 0], y=coords[:, 1],
            mode="markers+text",
            text=words, textposition="top center",
            marker=dict(size=9),
        )
        layout = go.Layout(
            xaxis=dict(title="PC1"),
            # Anchor y to x at ratio 1 so both axes share one scale.
            yaxis=dict(title="PC2", scaleanchor="x", scaleratio=1),
            margin=plot_margin,
        )
    else:
        coords = PCA(n_components=3).fit_transform(X)
        trace = go.Scatter3d(
            x=coords[:, 0], y=coords[:, 1], z=coords[:, 2],
            mode="markers+text",
            text=words, textposition="top center",
            marker=dict(size=6),
        )
        layout = go.Layout(
            scene=dict(
                xaxis=dict(showbackground=True, backgroundcolor="rgba(255, 230, 230, 1)"),
                yaxis=dict(showbackground=True, backgroundcolor="rgba(230, 255, 230, 1)"),
                zaxis=dict(showbackground=True, backgroundcolor="rgba(230, 230, 255, 1)"),
            ),
            margin=plot_margin,
        )

    fig = go.Figure(data=[trace], layout=layout)
    fig.update_layout(
        title=dict(
            text=chart_title,
            x=0.5, xanchor='center', yanchor='top',
            font=dict(size=20),
        )
    )

    if not is_2d:
        # Simple Plotly rotation: one frame per 4 degrees moves the camera
        # eye around a circle at fixed height; Rotate/Stop buttons drive it.
        radius = 1.7
        z_eye = 1.0
        fig.frames = [
            go.Frame(layout=dict(
                scene_camera=dict(
                    eye=dict(
                        x=radius * np.cos(np.deg2rad(ang)),
                        y=radius * np.sin(np.deg2rad(ang)),
                        z=z_eye,
                    ),
                    projection=dict(type="perspective"),
                )
            ))
            for ang in range(0, 360, 4)
        ]
        rotate_button = dict(
            label="▶ Rotate",
            method="animate",
            args=[None, dict(frame=dict(duration=40, redraw=True),
                             transition=dict(duration=0),
                             fromcurrent=True, mode="immediate")],
        )
        stop_button = dict(
            label="⏹ Stop",
            method="animate",
            args=[[None], dict(frame=dict(duration=0, redraw=False),
                               transition=dict(duration=0),
                               mode="immediate")],
        )
        fig.update_layout(
            updatemenus=[dict(
                type="buttons", showactive=False, x=0.02, y=0.98,
                buttons=[rotate_button, stop_button],
            )]
        )

    st.plotly_chart(fig, use_container_width=True)