# Spaces: Sleeping (hosting-page status text captured along with this file; not code)
# app.py
import itertools
import random

import numpy as np
import plotly.graph_objects as go
import streamlit as st
import torch
from sklearn.decomposition import PCA
from transformers import AutoModel, AutoTokenizer
# Must run before any other Streamlit command in the script.
st.set_page_config(layout="wide", page_title="Embedding Visualizer")
# -----------------------------
# Base datasets (dataset names stay lowercase)
# -----------------------------
# Six themed word lists of 25 entries each, keyed by the lowercase dataset name.
BASE_SETS = {
    "countries": [
        "Germany", "France", "Italy", "Spain", "Portugal",
        "Poland", "Netherlands", "Belgium", "Austria", "Switzerland",
        "Greece", "Norway", "Sweden", "Finland", "Denmark",
        "Ireland", "Hungary", "Czechia", "Slovakia", "Slovenia",
        "Romania", "Bulgaria", "Croatia", "Estonia", "Latvia",
    ],
    "animals": [
        "cat", "dog", "lion", "tiger", "bear",
        "wolf", "fox", "eagle", "shark", "whale",
        "zebra", "giraffe", "elephant", "hippopotamus", "rhinoceros",
        "kangaroo", "panda", "otter", "seal", "dolphin",
        "chimpanzee", "gorilla", "leopard", "cheetah", "lynx",
    ],
    "furniture": [
        "armchair", "sofa", "dining table", "coffee table", "bookshelf",
        "bed", "wardrobe", "desk", "office chair", "dresser",
        "nightstand", "side table", "tv stand", "loveseat", "chaise lounge",
        "bench", "hutch", "kitchen island", "futon", "recliner",
        "ottoman", "console table", "vanity", "buffet", "sectional sofa",
    ],
    "actors": [
        "Brad Pitt", "Angelina Jolie", "Meryl Streep", "Leonardo DiCaprio", "Tom Hanks",
        "Scarlett Johansson", "Robert De Niro", "Natalie Portman", "Matt Damon", "Cate Blanchett",
        "Johnny Depp", "Keanu Reeves", "Hugh Jackman", "Emma Stone", "Ryan Gosling",
        "Jennifer Lawrence", "Christian Bale", "Charlize Theron", "Will Smith", "Anne Hathaway",
        "Denzel Washington", "Morgan Freeman", "Julia Roberts", "George Clooney", "Kate Winslet",
    ],
    "rock groups": [
        "The Beatles", "Rolling Stones", "Pink Floyd", "Queen", "Led Zeppelin",
        "U2", "AC/DC", "Nirvana", "Radiohead", "Metallica",
        "Guns N' Roses", "Red Hot Chili Peppers", "Coldplay", "Pearl Jam", "The Police",
        "Aerosmith", "Green Day", "Foo Fighters", "The Doors", "Bon Jovi",
        "Deep Purple", "The Who", "The Kinks", "Fleetwood Mac", "The Beach Boys",
    ],
    "sports": [
        "soccer", "basketball", "tennis", "baseball", "golf",
        "swimming", "cycling", "running", "volleyball", "rugby",
        "boxing", "skiing", "snowboarding", "surfing", "skateboarding",
        "karate", "judo", "fencing", "rowing", "badminton",
        "cricket", "table tennis", "gymnastics", "hockey", "climbing",
    ],
}
# -----------------------------
# Build datasets once per session (base + 3 random mixed)
# -----------------------------
def make_random_mixed_sets(base: dict, n: int = 3, k: int = 3, per_source: int = 7) -> dict:
    """Build up to *n* "mixed" word lists, each drawn from *k* distinct base sets.

    Each mixed set samples up to *per_source* items (without replacement) from
    each of its *k* source lists. It is keyed by the source names joined with
    "/" in random order, e.g. ``"animals/sports/actors"``.

    Every mixed set uses a different combination of sources, so no result can
    overwrite another under the same name. Fewer than *n* sets are returned when
    *base* offers fewer distinct combinations. ``{}`` is returned when *base* has
    fewer than *k* keys or when *n* <= 0.

    Args:
        base: Mapping of dataset name -> sequence of items.
        n: Number of mixed sets requested.
        k: Number of base sets combined into each mixed set.
        per_source: Maximum number of items taken from each source set.

    Returns:
        Dict mapping "src1/src2/..." -> list of sampled items.
    """
    # Distinct unordered source combinations. Sampling these (rather than
    # independent ordered key samples) guarantees unique output keys.
    combos = list(itertools.combinations(base.keys(), k))
    out = {}
    for combo in random.sample(combos, max(0, min(n, len(combos)))):
        src = random.sample(combo, len(combo))  # random order for the display name
        items = []
        for s in src:
            items.extend(random.sample(base[s], min(per_source, len(base[s]))))
        out["/".join(src)] = items
    return out
# Build the dataset catalogue (base sets + three random mixes) only on the first
# run of a session, so the random mixes stay stable across reruns.
if "datasets" not in st.session_state:
    st.session_state["datasets"] = {**BASE_SETS, **make_random_mixed_sets(BASE_SETS, 3)}
DATASETS = st.session_state["datasets"]  # shorthand
# -----------------------------
# Models (transformers)
# -----------------------------
# UI label -> Hugging Face model id. Each label ends with the output width
# (e.g. "384d"), so users can see which dimensionality they are projecting from.
MODELS = {
    "all-MiniLM-L6-v2 (384d)":
        "sentence-transformers/all-MiniLM-L6-v2",
    "all-mpnet-base-v2 (768d)":
        "sentence-transformers/all-mpnet-base-v2",
    "all-roberta-large-v1 (1024d)":
        "sentence-transformers/all-roberta-large-v1",
}
@st.cache_resource(show_spinner="loading model…")
def load_model(model_name: str):
    """Load and return ``(tokenizer, model)`` for a Hugging Face model id.

    Streamlit re-executes this script on every widget interaction. Without
    ``st.cache_resource``, ``from_pretrained`` would reload the tokenizer and
    weights on every rerun. The cached pair is shared by all sessions, so
    callers must treat it as read-only.

    Args:
        model_name: Hugging Face model id, e.g. "sentence-transformers/all-MiniLM-L6-v2".

    Returns:
        Tuple ``(tokenizer, model)``, with the model switched to eval mode.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    mdl = AutoModel.from_pretrained(model_name)
    mdl.eval()  # inference only: disable dropout etc.
    return tok, mdl
@st.cache_data(show_spinner="computing embeddings…", max_entries=256)
def embed_texts(model_name: str, texts_tuple: tuple):
    """Embed each text and return one mean-pooled vector per text.

    The texts arrive as a tuple so the arguments are hashable and can key
    Streamlit's data cache. Reruns with the same model and word list then skip
    the forward pass. ``max_entries`` bounds the cache, because the word list is
    user-editable (unbounded key space).

    Args:
        model_name: Hugging Face model id passed to ``load_model``.
        texts_tuple: Texts to embed.

    Returns:
        NumPy array with one row (embedding) per input text.
    """
    tokenizer, model = load_model(model_name)
    texts = list(texts_tuple)
    with torch.no_grad():
        inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
        outputs = model(**inputs)
        token_embeddings = outputs.last_hidden_state
        # Zero out padding positions so they do not contribute to the average.
        mask = inputs["attention_mask"].unsqueeze(-1).type_as(token_embeddings)
        summed = (token_embeddings * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)  # guard against division by zero
        embeddings = summed / counts  # mean pooling
    return embeddings.cpu().numpy()
# -----------------------------
# Info page (local) via st.query_params
# -----------------------------
def goto(page: str):
    """Navigate to *page* by setting the ``page`` query parameter and rerunning.

    ``st.rerun()`` ends the current script run, so code after a ``goto(...)``
    call does not execute.
    """
    params = st.query_params
    params["page"] = page
    st.rerun()
# Route on the ?page= query parameter; anything other than "info" shows the demo.
page = st.query_params.get("page", "demo")
if page == "info":
    st.title("ℹ about this demo")
    # Blank lines matter in markdown: without one before the final paragraph,
    # it would be rendered as a lazy continuation of the last bullet.
    st.write("""
**embeddings** turn words (or longer text) into numerical vectors.
in this vector space, **semantically related** items end up **near** each other.

use cases:

- semantic search & retrieval
- clustering & topic discovery
- recommendations & deduplication
- measuring similarity and analogies

this demo embeds single words with a selectable model, reduces to 2d/3d with pca,
and shows how related words appear near each other in the projected space.
""".strip())
    if st.button("⬅ back to demo"):
        goto("demo")
    st.stop()  # don't render the demo below the info page
# -----------------------------
# Top compact bar
# -----------------------------
c1, c2, c3, c4 = st.columns([2, 2, 1, 1])
with c1:
    # The keyed widget owns st.session_state.dataset_name, and its default comes
    # only from `index`. Writing the same key through the Session State API
    # beforehand while also passing a non-zero index makes Streamlit show a
    # "created with a default value but also had its value set via the Session
    # State API" warning on first render.
    dataset_options = list(DATASETS.keys())
    default_dataset = "furniture" if "furniture" in DATASETS else dataset_options[0]
    dataset_name = st.selectbox(
        "dataset",
        dataset_options,
        index=dataset_options.index(default_dataset),
        key="dataset_name",
    )
with c2:
    # The chosen model id lives in a plain (non-widget) session-state key, so it
    # persists across runs where this selectbox is not rendered (the info page
    # calls st.stop() before reaching it).
    if "model_name" not in st.session_state:
        st.session_state.model_name = list(MODELS.values())[0]
    labels = list(MODELS.keys())
    rev = {v: k for k, v in MODELS.items()}
    current_label = rev.get(st.session_state.model_name, labels[0])
    chosen_label = st.selectbox("embedding model", labels, index=labels.index(current_label))
    st.session_state.model_name = MODELS[chosen_label]
with c3:
    # 3D by default. Passing the same `index` on every run (instead of only on
    # the first one) keeps the widget's parameters unchanged between reruns, so
    # user clicks apply directly.
    st.radio("projection", ["2D", "3D"], index=1, horizontal=True, key="proj_mode")
with c4:
    if st.button("ℹ info"):
        goto("info")
# -----------------------------
# Two-column layout (left = textarea, right = plot)
# -----------------------------
left, right = st.columns([1, 2], gap="large")
# Keep the textarea synced with the dataset selection. Its contents are reloaded
# only when the selected dataset differs from the previous run's, so manual edits
# survive unrelated reruns. The writes happen before the text_area is created,
# which Streamlit allows for a widget key.
selected_dataset = st.session_state.dataset_name
if "dataset_text" not in st.session_state:
    st.session_state.dataset_text = "\n".join(DATASETS[selected_dataset])
if "prev_dataset_name" not in st.session_state:
    st.session_state.prev_dataset_name = selected_dataset
if selected_dataset != st.session_state.prev_dataset_name:
    st.session_state.dataset_text = "\n".join(DATASETS[selected_dataset])
    st.session_state.prev_dataset_name = selected_dataset
with left:
    # A non-empty label avoids Streamlit's empty-label accessibility warning.
    # Lowercase matches the other widget labels.
    st.text_area(
        label="words",
        key="dataset_text",
        height=420,
        help="edit words (one per line). changing dataset above refreshes this box.",
    )
# One word/phrase per non-blank line, surrounding whitespace removed.
words = [w.strip() for w in st.session_state.dataset_text.splitlines() if w.strip()]
def _scatter_2d_figure(coords, labels):
    """Build a labelled 2-D scatter of the first two PCA columns of *coords*."""
    return go.Figure(
        data=[go.Scatter(
            x=coords[:, 0], y=coords[:, 1],
            mode="markers+text",
            text=labels, textposition="top center",
            marker=dict(size=9),
        )],
        layout=go.Layout(
            xaxis=dict(title="PC1"),
            # Equal units on both axes so projected distances are not distorted.
            yaxis=dict(title="PC2", scaleanchor="x", scaleratio=1),
            margin=dict(l=0, r=0, b=0, t=40),
        ),
    )


def _scatter_3d_figure(coords, labels, radius=1.7, z_eye=1.0, step_deg=4):
    """Build a labelled 3-D scatter with ▶ Rotate / ⏹ Stop camera-orbit buttons."""
    fig = go.Figure(
        data=[go.Scatter3d(
            x=coords[:, 0], y=coords[:, 1], z=coords[:, 2],
            mode="markers+text",
            text=labels, textposition="top center",
            marker=dict(size=6),
        )],
        layout=go.Layout(
            scene=dict(
                xaxis=dict(showbackground=True, backgroundcolor="rgba(255, 230, 230, 1)"),
                yaxis=dict(showbackground=True, backgroundcolor="rgba(230, 255, 230, 1)"),
                zaxis=dict(showbackground=True, backgroundcolor="rgba(230, 230, 255, 1)"),
            ),
            margin=dict(l=0, r=0, b=0, t=40),
        ),
    )
    # One animation frame every `step_deg` degrees. The camera eye circles the
    # z axis at distance `radius` and fixed height `z_eye`.
    fig.frames = [
        go.Frame(layout=dict(
            scene_camera=dict(
                eye=dict(x=radius * np.cos(rad), y=radius * np.sin(rad), z=z_eye),
                projection=dict(type="perspective"),
            )
        ))
        for rad in np.deg2rad(np.arange(0, 360, step_deg))
    ]
    fig.update_layout(
        updatemenus=[dict(
            type="buttons", showactive=False, x=0.02, y=0.98,
            buttons=[
                dict(
                    label="▶ Rotate",
                    method="animate",
                    # redraw=True is needed for 3-D scenes to update per frame.
                    args=[None, dict(frame=dict(duration=40, redraw=True),
                                     transition=dict(duration=0),
                                     fromcurrent=True, mode="immediate")],
                ),
                dict(
                    label="⏹ Stop",
                    method="animate",
                    # Animating to [None] is Plotly's idiom for pausing.
                    args=[[None], dict(frame=dict(duration=0, redraw=False),
                                       transition=dict(duration=0),
                                       mode="immediate")],
                ),
            ],
        )]
    )
    return fig


with right:
    # PCA with 3 components needs at least 3 samples.
    if len(words) < 3:
        st.info("enter at least three lines to project.")
        st.stop()
    X = embed_texts(st.session_state.model_name, tuple(words))
    # Capitalized dataset name for the chart title (dataset keys remain lowercase in the UI)
    chart_title = st.session_state.dataset_name.title()
    if st.session_state.proj_mode == "2D":
        coords = PCA(n_components=2).fit_transform(X)
        fig = _scatter_2d_figure(coords, words)
    else:
        coords = PCA(n_components=3).fit_transform(X)
        fig = _scatter_3d_figure(coords, words)
    # Same centred title for both projections.
    fig.update_layout(
        title=dict(
            text=chart_title,
            x=0.5, xanchor="center", yanchor="top",
            font=dict(size=20),
        )
    )
    st.plotly_chart(fig, use_container_width=True)