berndf's picture
fix
b27e307 verified
raw
history blame
10.8 kB
# app.py
import random
import numpy as np
import streamlit as st
import plotly.graph_objects as go
from sklearn.decomposition import PCA
import torch
from transformers import AutoTokenizer, AutoModel
# Page-level settings: browser tab title and a wide (full-width) layout.
st.set_page_config(page_title="Embedding Visualizer", layout="wide")
# -----------------------------
# Base datasets (dataset names stay lowercase)
# -----------------------------
# Built-in word lists, keyed by the lowercase dataset name offered in the
# "dataset" selectbox. Each category holds 25 entries; multi-word entries
# (e.g. "dining table") are embedded as a single text.
# make_random_mixed_sets() samples from these lists to build the mixed sets.
BASE_SETS = {
# European countries
"countries": [
"Germany","France","Italy","Spain","Portugal","Poland","Netherlands","Belgium","Austria","Switzerland",
"Greece","Norway","Sweden","Finland","Denmark","Ireland","Hungary","Czechia","Slovakia","Slovenia",
"Romania","Bulgaria","Croatia","Estonia","Latvia"
],
# Animal names (all lowercase)
"animals": [
"cat","dog","lion","tiger","bear","wolf","fox","eagle","shark","whale",
"zebra","giraffe","elephant","hippopotamus","rhinoceros","kangaroo","panda","otter","seal","dolphin",
"chimpanzee","gorilla","leopard","cheetah","lynx"
],
# Furniture items; this is the dataset preselected on first load
"furniture": [
"armchair","sofa","dining table","coffee table","bookshelf","bed","wardrobe","desk","office chair","dresser",
"nightstand","side table","tv stand","loveseat","chaise lounge","bench","hutch","kitchen island","futon","recliner",
"ottoman","console table","vanity","buffet","sectional sofa"
],
# Film actors (full names)
"actors": [
"Brad Pitt","Angelina Jolie","Meryl Streep","Leonardo DiCaprio","Tom Hanks","Scarlett Johansson","Robert De Niro",
"Natalie Portman","Matt Damon","Cate Blanchett","Johnny Depp","Keanu Reeves","Hugh Jackman","Emma Stone","Ryan Gosling",
"Jennifer Lawrence","Christian Bale","Charlize Theron","Will Smith","Anne Hathaway","Denzel Washington","Morgan Freeman",
"Julia Roberts","George Clooney","Kate Winslet"
],
# Rock bands
"rock groups": [
"The Beatles","Rolling Stones","Pink Floyd","Queen","Led Zeppelin","U2","AC/DC","Nirvana","Radiohead","Metallica",
"Guns N' Roses","Red Hot Chili Peppers","Coldplay","Pearl Jam","The Police","Aerosmith","Green Day","Foo Fighters",
"The Doors","Bon Jovi","Deep Purple","The Who","The Kinks","Fleetwood Mac","The Beach Boys"
],
# Sports and physical activities
"sports": [
"soccer","basketball","tennis","baseball","golf","swimming","cycling","running","volleyball","rugby",
"boxing","skiing","snowboarding","surfing","skateboarding","karate","judo","fencing","rowing","badminton",
"cricket","table tennis","gymnastics","hockey","climbing"
],
}
# -----------------------------
# Build datasets once per session (base + 3 random mixed)
# -----------------------------
def make_random_mixed_sets(base: dict, n: int = 3) -> dict:
    """Build up to ``n`` mixed datasets by sampling words from random base sets.

    Each mixed set draws from three distinct base sets (fewer when ``base``
    holds fewer than three), takes up to 7 random items from each, and is
    keyed by the source names joined with "/" (e.g. "animals/sports/actors").

    Args:
        base: Mapping of dataset name to its list of words.
        n: Number of mixed sets to build.

    Returns:
        Dict mapping the joined source names to lists of at most 21 words.
        It holds fewer than ``n`` entries only when ``base`` is empty or
        cannot supply enough distinct name combinations.
    """
    keys = list(base)
    # random.sample raises ValueError when asked for more items than exist,
    # so shrink the group size for small inputs instead of crashing.
    group_size = min(3, len(keys))
    out = {}
    if group_size == 0:
        return out

    # Two draws can yield the same joined name, which would silently overwrite
    # the earlier entry. Redraw on collision; the cap guarantees termination
    # when there are fewer possible names than requested sets.
    attempts_left = 20 * max(n, 0)
    while len(out) < n and attempts_left > 0:
        attempts_left -= 1
        src = random.sample(keys, group_size)
        name = "/".join(src)
        if name in out:
            continue
        items = []
        for s in src:
            take = min(7, len(base[s]))
            items.extend(random.sample(base[s], take))
        out[name] = items[:21]
    return out
# Generate the random mixed sets only once per browser session so the dataset
# list stays stable across Streamlit reruns.
if "datasets" not in st.session_state:
    st.session_state.datasets = {
        **BASE_SETS,
        **make_random_mixed_sets(BASE_SETS, 3),
    }

# Shorthand for the per-session dataset mapping (base sets + mixed sets).
DATASETS = st.session_state.datasets
# -----------------------------
# Models (transformers)
# -----------------------------
# Selectable embedding models: display label shown in the "embedding model"
# selectbox -> model identifier that load_model() hands to
# AutoTokenizer/AutoModel.from_pretrained.
# NOTE(review): the "(384d)"-style suffix presumably states the embedding width — confirm.
MODELS = {
"all-MiniLM-L6-v2 (384d)": "sentence-transformers/all-MiniLM-L6-v2",
"all-mpnet-base-v2 (768d)": "sentence-transformers/all-mpnet-base-v2",
"all-roberta-large-v1 (1024d)": "sentence-transformers/all-roberta-large-v1",
}
@st.cache_resource(show_spinner=False)
def load_model(model_name: str):
    """Load the tokenizer and model for ``model_name`` and return them as a pair.

    The result is cached with ``st.cache_resource``. The model is switched to
    eval mode before being returned.

    Returns:
        ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    model.eval()
    return tokenizer, model
@st.cache_data(show_spinner=False)
def embed_texts(model_name: str, texts_tuple: tuple):
    """Embed every text in ``texts_tuple`` with the given model.

    Embeddings are the attention-mask-weighted mean of the model's last hidden
    state, so padding positions do not contribute. Results are cached with
    ``st.cache_data``; the texts arrive as a tuple (the caller converts its
    word list before calling).

    Args:
        model_name: Model identifier forwarded to ``load_model``.
        texts_tuple: The texts to embed.

    Returns:
        NumPy array with one embedding per input text.
    """
    tokenizer, model = load_model(model_name)
    with torch.no_grad():
        batch = tokenizer(
            list(texts_tuple),
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        hidden = model(**batch).last_hidden_state
        # Broadcast the attention mask over the hidden axis so padded tokens
        # are zeroed out before summing.
        # NOTE(review): assumes hidden is (batch, tokens, hidden) — confirm.
        weights = batch["attention_mask"].unsqueeze(-1).type_as(hidden)
        totals = (hidden * weights).sum(dim=1)
        # Clamp guards against division by zero for an all-padding row.
        token_counts = weights.sum(dim=1).clamp(min=1e-9)
        pooled = totals / token_counts  # mean pooling
    return pooled.cpu().numpy()
# -----------------------------
# Info page (local) via st.query_params
# -----------------------------
def goto(page: str):
    """Navigate to ``page`` by storing it in the ``page`` query parameter and rerunning the script."""
    st.query_params["page"] = page
    st.rerun()
# The view to render is taken from the URL query string; "demo" is the default.
page = st.query_params.get("page", "demo")

if page == "info":
    # Static explanatory text; surrounding newlines are stripped before display.
    about_markdown = """
**embeddings** turn words (or longer text) into numerical vectors.
in this vector space, **semantically related** items end up **near** each other.
use cases:
- semantic search & retrieval
- clustering & topic discovery
- recommendations & deduplication
- measuring similarity and analogies
this demo embeds single words with a selectable model, reduces to 2d/3d with pca,
and shows how related words appear near each other in the projected space.
""".strip()

    st.title("ℹ about this demo")
    st.write(about_markdown)
    if st.button("⬅ back to demo"):
        goto("demo")
    # Halt here so none of the demo UI below is rendered on the info page.
    st.stop()
# -----------------------------
# Top compact bar
# -----------------------------
# Four columns: dataset picker, model picker, projection toggle, info button.
c1, c2, c3, c4 = st.columns([2, 2, 1, 1])

with c1:
    # Seed the widget's state once and let the key drive the selection.
    # No `index=` is passed: combining an explicit default with a key whose
    # value was set through Session State triggers Streamlit's "created with
    # a default value but also had its value set via the Session State API"
    # warning, and the seeded value already determines what is shown.
    if "dataset_name" not in st.session_state:
        st.session_state.dataset_name = (
            "furniture" if "furniture" in DATASETS else next(iter(DATASETS))
        )
    dataset_name = st.selectbox("dataset", list(DATASETS.keys()), key="dataset_name")

with c2:
    # session_state.model_name stores the model identifier (a MODELS value),
    # while the selectbox shows the human-readable label (a MODELS key).
    if "model_name" not in st.session_state:
        st.session_state.model_name = list(MODELS.values())[0]
    labels = list(MODELS.keys())
    rev = {v: k for k, v in MODELS.items()}
    current_label = rev.get(st.session_state.model_name, labels[0])
    chosen_label = st.selectbox("embedding model", labels, index=labels.index(current_label))
    st.session_state.model_name = MODELS[chosen_label]

with c3:
    # Default to 3D on first render; single-click thereafter.
    # `index` is only supplied while the key is not yet in Session State, so
    # the explicit default never competes with the stored widget value.
    radio_kwargs = dict(options=["2D", "3D"], horizontal=True, key="proj_mode")
    if "proj_mode" not in st.session_state:
        radio_kwargs["index"] = 1  # 3D default
    st.radio("projection", **radio_kwargs)

with c4:
    if st.button("ℹ info"):
        goto("info")
# -----------------------------
# Two-column layout (left = textarea, right = plot)
# -----------------------------
left, right = st.columns([1, 2], gap="large")

# Keep the textarea synced with the dataset selection. The textarea is keyed
# on "dataset_text", so its content is written to Session State *before* the
# widget is created on this run.
if "dataset_text" not in st.session_state:
    st.session_state.dataset_text = "\n".join(DATASETS[st.session_state.dataset_name])
if "prev_dataset_name" not in st.session_state:
    st.session_state.prev_dataset_name = st.session_state.dataset_name

# A changed selection replaces the box content (discarding manual edits) and
# records the new selection so the reset happens only once per change.
if st.session_state.dataset_name != st.session_state.prev_dataset_name:
    st.session_state.dataset_text = "\n".join(DATASETS[st.session_state.dataset_name])
    st.session_state.prev_dataset_name = st.session_state.dataset_name

with left:
    st.text_area(
        # Streamlit discourages empty labels for accessibility reasons and may
        # reject them in the future, so give the box a real label.
        label="words",
        key="dataset_text",
        height=420,
        help="edit words (one per line). changing dataset above refreshes this box.",
    )

# One entry per non-blank line, with surrounding whitespace removed.
words = [w.strip() for w in st.session_state.dataset_text.split("\n") if w.strip()]
with right:
    # The projection needs at least three points; otherwise show a hint and halt.
    if len(words) < 3:
        st.info("enter at least three lines to project.")
        st.stop()

    X = embed_texts(st.session_state.model_name, tuple(words))

    # Capitalized dataset name for the chart title (dataset keys remain lowercase in the UI)
    chart_title = st.session_state.dataset_name.title()

    # Pieces shared by the 2D and the 3D figure.
    title_layout = dict(
        text=chart_title,
        x=0.5,
        xanchor="center",
        yanchor="top",
        font=dict(size=20),
    )
    plot_margin = dict(l=0, r=0, b=0, t=40)

    if st.session_state.proj_mode == "2D":
        coords = PCA(n_components=2).fit_transform(X)
        scatter_2d = go.Scatter(
            x=coords[:, 0],
            y=coords[:, 1],
            mode="markers+text",
            text=words,
            textposition="top center",
            marker=dict(size=9),
        )
        layout_2d = go.Layout(
            xaxis=dict(title="PC1"),
            # Anchor y to x at ratio 1 so both axes share the same scale.
            yaxis=dict(title="PC2", scaleanchor="x", scaleratio=1),
            margin=plot_margin,
        )
        fig = go.Figure(data=[scatter_2d], layout=layout_2d)
    else:
        coords = PCA(n_components=3).fit_transform(X)
        scatter_3d = go.Scatter3d(
            x=coords[:, 0],
            y=coords[:, 1],
            z=coords[:, 2],
            mode="markers+text",
            text=words,
            textposition="top center",
            marker=dict(size=6),
        )
        # Each axis wall gets its own pale tint.
        layout_3d = go.Layout(
            scene=dict(
                xaxis=dict(showbackground=True, backgroundcolor="rgba(255, 230, 230, 1)"),
                yaxis=dict(showbackground=True, backgroundcolor="rgba(230, 255, 230, 1)"),
                zaxis=dict(showbackground=True, backgroundcolor="rgba(230, 230, 255, 1)"),
            ),
            margin=plot_margin,
        )
        fig = go.Figure(data=[scatter_3d], layout=layout_3d)

        # Simple Plotly rotation: frames + Rotate/Stop buttons.
        # The camera eye circles the scene at a fixed radius and height,
        # one frame every 4 degrees.
        radius = 1.7
        z_eye = 1.0

        def _orbit_frame(angle_deg):
            """Return the animation frame with the camera at ``angle_deg`` on the orbit."""
            rad = np.deg2rad(angle_deg)
            eye = dict(x=radius * np.cos(rad), y=radius * np.sin(rad), z=z_eye)
            camera = dict(eye=eye, projection=dict(type="perspective"))
            return go.Frame(layout=dict(scene_camera=camera))

        frames = [_orbit_frame(ang) for ang in range(0, 360, 4)]
        fig.frames = frames

        # Rotate: play all frames from the current one, 40 ms each.
        rotate_button = dict(
            label="▶ Rotate",
            method="animate",
            args=[
                None,
                dict(
                    frame=dict(duration=40, redraw=True),
                    transition=dict(duration=0),
                    fromcurrent=True,
                    mode="immediate",
                ),
            ],
        )
        # Stop: animate to a [None] frame list in immediate mode to halt playback.
        stop_button = dict(
            label="⏹ Stop",
            method="animate",
            args=[
                [None],
                dict(
                    frame=dict(duration=0, redraw=False),
                    transition=dict(duration=0),
                    mode="immediate",
                ),
            ],
        )
        fig.update_layout(
            updatemenus=[
                dict(
                    type="buttons",
                    showactive=False,
                    x=0.02,
                    y=0.98,
                    buttons=[rotate_button, stop_button],
                )
            ]
        )

    # Centered title, identical for both projections.
    fig.update_layout(title=title_layout)
    st.plotly_chart(fig, use_container_width=True)