# app.py
import random

import numpy as np
import streamlit as st
import plotly.graph_objects as go
from sklearn.decomposition import PCA
import torch
from transformers import AutoTokenizer, AutoModel

st.set_page_config(page_title="Embedding Visualizer", layout="wide")

# -----------------------------
# Base datasets (dataset names stay lowercase)
# -----------------------------
BASE_SETS = {
    "countries": [
        "Germany","France","Italy","Spain","Portugal","Poland","Netherlands","Belgium","Austria","Switzerland",
        "Greece","Norway","Sweden","Finland","Denmark","Ireland","Hungary","Czechia","Slovakia","Slovenia",
        "Romania","Bulgaria","Croatia","Estonia","Latvia"
    ],
    "animals": [
        "cat","dog","lion","tiger","bear","wolf","fox","eagle","shark","whale",
        "zebra","giraffe","elephant","hippopotamus","rhinoceros","kangaroo","panda","otter","seal","dolphin",
        "chimpanzee","gorilla","leopard","cheetah","lynx"
    ],
    "furniture": [
        "armchair","sofa","dining table","coffee table","bookshelf","bed","wardrobe","desk","office chair","dresser",
        "nightstand","side table","tv stand","loveseat","chaise lounge","bench","hutch","kitchen island","futon","recliner",
        "ottoman","console table","vanity","buffet","sectional sofa"
    ],
    "actors": [
        "Brad Pitt","Angelina Jolie","Meryl Streep","Leonardo DiCaprio","Tom Hanks","Scarlett Johansson","Robert De Niro",
        "Natalie Portman","Matt Damon","Cate Blanchett","Johnny Depp","Keanu Reeves","Hugh Jackman","Emma Stone","Ryan Gosling",
        "Jennifer Lawrence","Christian Bale","Charlize Theron","Will Smith","Anne Hathaway","Denzel Washington","Morgan Freeman",
        "Julia Roberts","George Clooney","Kate Winslet"
    ],
    "rock groups": [
        "The Beatles","Rolling Stones","Pink Floyd","Queen","Led Zeppelin","U2","AC/DC","Nirvana","Radiohead","Metallica",
        "Guns N' Roses","Red Hot Chili Peppers","Coldplay","Pearl Jam","The Police","Aerosmith","Green Day","Foo Fighters",
        "The Doors","Bon Jovi","Deep Purple","The Who","The Kinks","Fleetwood Mac","The Beach Boys"
    ],
    "sports": [
        "soccer","basketball","tennis","baseball","golf","swimming","cycling","running","volleyball","rugby",
        "boxing","skiing","snowboarding","surfing","skateboarding","karate","judo","fencing","rowing","badminton",
        "cricket","table tennis","gymnastics","hockey","climbing"
    ],
}

# -----------------------------
# Build datasets once per session (base + 3 random mixed)
# -----------------------------
def make_random_mixed_sets(base: dict, n: int = 3) -> dict:
    keys = list(base.keys())
    out = {}
    for _ in range(n):
        src = random.sample(keys, 3)
        items = []
        for s in src:
            take = min(7, len(base[s]))
            items.extend(random.sample(base[s], take))
        out["/".join(src)] = items[:21]
    return out
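
# Example (illustrative): one generated key might look like "animals/sports/actors",
# holding 7 randomly sampled items from each of its three source sets (21 items total).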
if "datasets" not in st.session_state: | |
mixed = make_random_mixed_sets(BASE_SETS, 3) | |
st.session_state.datasets = {**BASE_SETS, **mixed} | |
DATASETS = st.session_state.datasets # shorthand | |

# -----------------------------
# Models (transformers)
# -----------------------------
MODELS = {
    "all-MiniLM-L6-v2 (384d)": "sentence-transformers/all-MiniLM-L6-v2",
    "all-mpnet-base-v2 (768d)": "sentence-transformers/all-mpnet-base-v2",
    "all-roberta-large-v1 (1024d)": "sentence-transformers/all-roberta-large-v1",
}


@st.cache_resource  # keep the tokenizer/model in memory instead of reloading on every rerun
def load_model(model_name: str):
    tok = AutoTokenizer.from_pretrained(model_name)
    mdl = AutoModel.from_pretrained(model_name)
    mdl.eval()
    return tok, mdl


@st.cache_data  # texts are passed as a tuple so the arguments stay hashable for caching
def embed_texts(model_name: str, texts_tuple: tuple):
    tokenizer, model = load_model(model_name)
    texts = list(texts_tuple)
    with torch.no_grad():
        inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
        outputs = model(**inputs)
        token_embeddings = outputs.last_hidden_state  # (batch, tokens, dim)
        # Masked mean pooling: average token vectors while ignoring padding positions
        mask = inputs["attention_mask"].unsqueeze(-1).type_as(token_embeddings)
        summed = (token_embeddings * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)
        embeddings = summed / counts
    return embeddings.cpu().numpy()
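
# Example (hypothetical call): three words embedded with MiniLM yield a (3, 384) array,
# since all-MiniLM-L6-v2 produces 384-dimensional vectors:
#   vecs = embed_texts("sentence-transformers/all-MiniLM-L6-v2", ("cat", "dog", "table"))
#   vecs.shape  # -> (3, 384)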

# -----------------------------
# Info page (local) via st.query_params
# -----------------------------
def goto(page: str):
    st.query_params["page"] = page
    st.rerun()


page = st.query_params.get("page", "demo")

if page == "info":
st.write(""" | |
# 🧠 Embedding Visualizer – About | |
This demo shows how **vector embeddings** can capture the meaning of words and place them in a **numerical space** where related items appear close together. | |
You can: | |
- Choose from predefined or mixed datasets (e.g., countries, animals, actors, sports) | |
- Select different embedding models to compare results | |
- Switch between 2D and 3D visualizations | |
- Edit the list of words directly and see the updated projection instantly | |
--- | |
## 📌 What are Vector Embeddings? | |
A **vector embedding** is a way of representing text (words, sentences, or documents) as a list of numbers — a point in a high-dimensional space. | |
These numbers are produced by a trained **language model** that captures semantic meaning. | |
In this space: | |
- Words with **similar meanings** end up **near each other** | |
- Dissimilar words are placed **far apart** | |
- The model can detect relationships and groupings that aren’t obvious from spelling or grammar alone | |
Example: | |
`"cat"` and `"dog"` will likely be closer to each other than to `"table"`, because the model “knows” they are both animals. | |

---

## 🔍 How the Demo Works

1. **Embedding step** – Each word is converted into a high-dimensional vector (e.g., 384, 768, or 1024 dimensions depending on the model).
2. **Dimensionality reduction** – Since humans can't visualize hundreds of dimensions, the vectors are projected to 2D or 3D using **PCA** (Principal Component Analysis); a short sketch of this step follows below.
3. **Visualization** – The projected points are plotted, with labels showing the original words.
   You can rotate the 3D view to explore groupings.
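
A minimal sketch of step 2, assuming the embeddings are already computed (random numbers stand in for real vectors here):

```python
import numpy as np
from sklearn.decomposition import PCA

vectors = np.random.rand(25, 768)   # placeholder for 25 embeddings of 768 dimensions each

coords_2d = PCA(n_components=2).fit_transform(vectors)  # shape (25, 2)
coords_3d = PCA(n_components=3).fit_transform(vectors)  # shape (25, 3)
```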

---

## 💡 Typical Applications of Embeddings

- **Semantic search** – Find relevant results even if exact keywords don't match
- **Clustering & topic discovery** – Group related items automatically
- **Recommendations** – Suggest similar products, movies, or articles
- **Deduplication** – Detect near-duplicate content
- **Analogies** – Explore relationships like *"king" – "man" + "woman" ≈ "queen"* (sketched below)
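
A minimal sketch of that analogy check, assuming a hypothetical `embed(word)` helper that returns one vector per word (any of the models in this demo could play that role):

```python
import numpy as np

def most_similar(query, vocab):
    # Return the vocabulary word whose vector is most cosine-similar to `query`
    def cos(a, b):
        return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    return max(vocab, key=lambda w: cos(query, vocab[w]))

# vocab = {w: embed(w) for w in ["king", "queen", "man", "woman"]}
# target = vocab["king"] - vocab["man"] + vocab["woman"]
# most_similar(target, vocab)  # often lands closest to "queen"
```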

---

## 🚀 Try it Yourself

- Pick a dataset or create your own by editing the list
- Switch models to compare how the embedding space changes
- Toggle between 2D and 3D to explore patterns
""".strip())

    if st.button("⬅ back to demo"):
        goto("demo")
    st.stop()

# -----------------------------
# Top compact bar
# -----------------------------
c1, c2, c3, c4 = st.columns([2, 2, 1, 1])

with c1:
    if "dataset_name" not in st.session_state:
        st.session_state.dataset_name = "actors" if "actors" in DATASETS else list(DATASETS.keys())[0]
    dataset_name = st.selectbox(
        "dataset",
        list(DATASETS.keys()),
        index=list(DATASETS.keys()).index(st.session_state.dataset_name),
        key="dataset_name",
    )

with c2:
    if "model_name" not in st.session_state:
        st.session_state.model_name = list(MODELS.values())[1]
    labels = list(MODELS.keys())
    rev = {v: k for k, v in MODELS.items()}
    current_label = rev.get(st.session_state.model_name, labels[0])
    chosen_label = st.selectbox("embedding model", labels, index=labels.index(current_label))
    st.session_state.model_name = MODELS[chosen_label]

with c3:
    # Default to 3D on first render; single-click thereafter
    radio_kwargs = dict(options=["2D", "3D"], horizontal=True, key="proj_mode")
    if "proj_mode" not in st.session_state:
        radio_kwargs["index"] = 1  # 3D default
    st.radio("projection", **radio_kwargs)

with c4:
    if st.button("ℹ info"):
        goto("info")

# -----------------------------
# Two-column layout (left = textarea, right = plot)
# -----------------------------
left, right = st.columns([1, 2], gap="large")

# Keep textarea synced with dataset selection
if "dataset_text" not in st.session_state:
    st.session_state.dataset_text = "\n".join(DATASETS[st.session_state.dataset_name])
if "prev_dataset_name" not in st.session_state:
    st.session_state.prev_dataset_name = st.session_state.dataset_name
if st.session_state.dataset_name != st.session_state.prev_dataset_name:
    st.session_state.dataset_text = "\n".join(DATASETS[st.session_state.dataset_name])
    st.session_state.prev_dataset_name = st.session_state.dataset_name

with left:
    st.text_area(
        label="words",
        label_visibility="collapsed",
        key="dataset_text",
        height=420,
        help="edit words (one per line). changing dataset above refreshes this box.",
    )

words = [w.strip() for w in st.session_state.dataset_text.split("\n") if w.strip()]

with right:
    if len(words) < 3:
        st.info("enter at least three lines to project.")
        st.stop()

    X = embed_texts(st.session_state.model_name, tuple(words))

    # Capitalized dataset name for the chart title (dataset keys remain lowercase in the UI)
    chart_title = st.session_state.dataset_name.title()

    if st.session_state.proj_mode == "2D":
        coords = PCA(n_components=2).fit_transform(X)
        fig = go.Figure(
            data=[go.Scatter(
                x=coords[:, 0], y=coords[:, 1],
                mode="markers+text",
                text=words, textposition="top center",
                marker=dict(size=9),
            )],
            layout=go.Layout(
                xaxis=dict(title="PC1"),
                yaxis=dict(title="PC2", scaleanchor="x", scaleratio=1),
                margin=dict(l=0, r=0, b=0, t=40),
            ),
        )
        fig.update_layout(
            title=dict(
                text=chart_title,
                x=0.5, xanchor='center', yanchor='top',
                font=dict(size=20)
            )
        )
    else:
        coords = PCA(n_components=3).fit_transform(X)
        fig = go.Figure(
            data=[go.Scatter3d(
                x=coords[:, 0], y=coords[:, 1], z=coords[:, 2],
                mode="markers+text",
                text=words, textposition="top center",
                marker=dict(size=6),
            )],
            layout=go.Layout(
                scene=dict(
                    xaxis=dict(showbackground=True, backgroundcolor="rgba(255, 230, 230, 1)"),
                    yaxis=dict(showbackground=True, backgroundcolor="rgba(230, 255, 230, 1)"),
                    zaxis=dict(showbackground=True, backgroundcolor="rgba(230, 230, 255, 1)"),
                ),
                margin=dict(l=0, r=0, b=0, t=40),
            ),
        )
        fig.update_layout(
            title=dict(
                text=chart_title,
                x=0.5, xanchor='center', yanchor='top',
                font=dict(size=20)
            )
        )

        # Simple Plotly rotation: frames + Rotate/Stop buttons
        frames = []
        radius = 1.7
        z_eye = 1.0
        for ang in range(0, 360, 4):
            rad = np.deg2rad(ang)
            frames.append(go.Frame(layout=dict(
                scene_camera=dict(eye=dict(x=radius * np.cos(rad), y=radius * np.sin(rad), z=z_eye),
                                  projection=dict(type="perspective"))
            )))
        fig.frames = frames

        fig.update_layout(
            updatemenus=[dict(
                type="buttons", showactive=False, x=0.02, y=0.98,
                buttons=[
                    dict(
                        label="▶ Rotate",
                        method="animate",
                        args=[None, dict(frame=dict(duration=40, redraw=True),
                                         transition=dict(duration=0),
                                         fromcurrent=True, mode="immediate")]
                    ),
                    dict(
                        label="⏹ Stop",
                        method="animate",
                        args=[[None], dict(frame=dict(duration=0, redraw=False),
                                           transition=dict(duration=0),
                                           mode="immediate")]
                    )
                ]
            )]
        )

    st.plotly_chart(fig, use_container_width=True)