# app.py
#
# Streamlit demo: embed a list of short texts with a sentence-transformers
# checkpoint, project the vectors to 2D/3D with PCA and plot them with Plotly.
# Streamlit re-executes this script top to bottom on every interaction, so
# everything that must survive a rerun lives in st.session_state or a cache.
import random

import numpy as np
import plotly.graph_objects as go
import streamlit as st
import torch
from sklearn.decomposition import PCA
from transformers import AutoModel, AutoTokenizer

# Must be the first Streamlit command executed in the script.
st.set_page_config(page_title="Embedding Visualizer", layout="wide")

# -----------------------------
# Base datasets (dataset names stay lowercase)
# -----------------------------
BASE_SETS = {
    "countries": [
        "Germany", "France", "Italy", "Spain", "Portugal",
        "Poland", "Netherlands", "Belgium", "Austria", "Switzerland",
        "Greece", "Norway", "Sweden", "Finland", "Denmark",
        "Ireland", "Hungary", "Czechia", "Slovakia", "Slovenia",
        "Romania", "Bulgaria", "Croatia", "Estonia", "Latvia",
    ],
    "animals": [
        "cat", "dog", "lion", "tiger", "bear",
        "wolf", "fox", "eagle", "shark", "whale",
        "zebra", "giraffe", "elephant", "hippopotamus", "rhinoceros",
        "kangaroo", "panda", "otter", "seal", "dolphin",
        "chimpanzee", "gorilla", "leopard", "cheetah", "lynx",
    ],
    "furniture": [
        "armchair", "sofa", "dining table", "coffee table", "bookshelf",
        "bed", "wardrobe", "desk", "office chair", "dresser",
        "nightstand", "side table", "tv stand", "loveseat", "chaise lounge",
        "bench", "hutch", "kitchen island", "futon", "recliner",
        "ottoman", "console table", "vanity", "buffet", "sectional sofa",
    ],
    "actors": [
        "Brad Pitt", "Angelina Jolie", "Meryl Streep", "Leonardo DiCaprio", "Tom Hanks",
        "Scarlett Johansson", "Robert De Niro", "Natalie Portman", "Matt Damon", "Cate Blanchett",
        "Johnny Depp", "Keanu Reeves", "Hugh Jackman", "Emma Stone", "Ryan Gosling",
        "Jennifer Lawrence", "Christian Bale", "Charlize Theron", "Will Smith", "Anne Hathaway",
        "Denzel Washington", "Morgan Freeman", "Julia Roberts", "George Clooney", "Kate Winslet",
    ],
    "rock groups": [
        "The Beatles", "Rolling Stones", "Pink Floyd", "Queen", "Led Zeppelin",
        "U2", "AC/DC", "Nirvana", "Radiohead", "Metallica",
        "Guns N' Roses", "Red Hot Chili Peppers", "Coldplay", "Pearl Jam", "The Police",
        "Aerosmith", "Green Day", "Foo Fighters", "The Doors", "Bon Jovi",
        "Deep Purple", "The Who", "The Kinks", "Fleetwood Mac", "The Beach Boys",
    ],
    "sports": [
        "soccer", "basketball", "tennis", "baseball", "golf",
        "swimming", "cycling", "running", "volleyball", "rugby",
        "boxing", "skiing", "snowboarding", "surfing", "skateboarding",
        "karate", "judo", "fencing", "rowing", "badminton",
        "cricket", "table tennis", "gymnastics", "hockey", "climbing",
    ],
}


# -----------------------------
# Build datasets once per session (base + 3 random mixed)
# -----------------------------
def make_random_mixed_sets(
    base: dict,
    n: int = 3,
    sources_per_set: int = 3,
    items_per_source: int = 7,
) -> dict:
    """Build up to ``n`` datasets that each mix items from several base sets.

    Args:
        base: Mapping of dataset name to its list of items.
        n: Number of mixed datasets to create.
        sources_per_set: How many base sets are combined into one mixed set
            (capped at the number of base sets available).
        items_per_source: How many items are drawn, without replacement, from
            each chosen base set (capped at that set's size).

    Returns:
        A dict keyed by the chosen source names joined with ``"/"``; each
        value is the concatenated random picks. Empty if ``base`` is empty.
        May hold fewer than ``n`` entries when not enough distinct source
        combinations turn up within the retry budget.
    """
    names = list(base)
    n_sources = min(sources_per_set, len(names))
    if n_sources == 0:
        return {}

    mixed = {}
    # Two draws can pick the same sources in the same order, which would
    # silently overwrite an earlier entry. Retry on collision, with a bounded
    # budget so an unsatisfiable request cannot loop forever.
    for _ in range(max(n, 0) * 20):
        if len(mixed) >= n:
            break
        sources = random.sample(names, n_sources)
        mixed_name = "/".join(sources)
        if mixed_name in mixed:
            continue
        picks = []
        for source in sources:
            pool = base[source]
            picks.extend(random.sample(pool, min(items_per_source, len(pool))))
        mixed[mixed_name] = picks
    return mixed


# The random mixes are generated once and kept in session state so the list of
# datasets stays stable across reruns of the script.
if "datasets" not in st.session_state:
    st.session_state.datasets = {**BASE_SETS, **make_random_mixed_sets(BASE_SETS, 3)}
DATASETS = st.session_state.datasets  # shorthand

# -----------------------------
# Models (transformers)
# -----------------------------
# UI label -> Hugging Face model id.
MODELS = {
    "all-MiniLM-L6-v2 (384d)": "sentence-transformers/all-MiniLM-L6-v2",
    "all-mpnet-base-v2 (768d)": "sentence-transformers/all-mpnet-base-v2",
    "all-roberta-large-v1 (1024d)": "sentence-transformers/all-roberta-large-v1",
}


@st.cache_resource(show_spinner=False)
def load_model(model_name: str):
    """Load and cache the tokenizer and model for ``model_name``.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    model.eval()
    return tokenizer, model


@st.cache_data(show_spinner=False)
def embed_texts(model_name: str, texts_tuple: tuple) -> np.ndarray:
    """Embed each text by mean-pooling the model's last hidden state.

    Args:
        model_name: Hugging Face model id understood by ``load_model``.
        texts_tuple: The texts to embed, one embedding row per entry.

    Returns:
        A NumPy array with one row per input text.
    """
    tokenizer, model = load_model(model_name)
    texts = list(texts_tuple)
    with torch.no_grad():
        inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
        token_embeddings = model(**inputs).last_hidden_state
        # Zero out padding positions so they do not contribute to the mean.
        mask = inputs["attention_mask"].unsqueeze(-1).type_as(token_embeddings)
        summed = (token_embeddings * mask).sum(dim=1)
        # Clamp guards against division by zero for an all-padding row.
        counts = mask.sum(dim=1).clamp(min=1e-9)
        embeddings = summed / counts  # mean pooling
    return embeddings.cpu().numpy()


# -----------------------------
# Info page (local) via st.query_params
# -----------------------------
INFO_MARKDOWN = """
# 🧠 Embedding Visualizer – About

This demo shows how **vector embeddings** can capture the meaning of words and place them in a **numerical space** where related items appear close together.

You can:
- Choose from predefined or mixed datasets (e.g., countries, animals, actors, sports)
- Select different embedding models to compare results
- Switch between 2D and 3D visualizations
- Edit the list of words directly and see the updated projection instantly

---

## 📌 What are Vector Embeddings?

A **vector embedding** is a way of representing text (words, sentences, or documents) as a list of numbers — a point in a high-dimensional space.
These numbers are produced by a trained **language model** that captures semantic meaning.

In this space:
- Words with **similar meanings** end up **near each other**
- Dissimilar words are placed **far apart**
- The model can detect relationships and groupings that aren’t obvious from spelling or grammar alone

Example:
`"cat"` and `"dog"` will likely be closer to each other than to `"table"`, because the model “knows” they are both animals.

---

## 🔍 How the Demo Works

1. **Embedding step** – Each word is converted into a high-dimensional vector (e.g., 384, 768, or 1024 dimensions depending on the model).
2. **Dimensionality reduction** – Since humans can’t visualize hundreds of dimensions, the vectors are projected to 2D or 3D using **PCA** (Principal Component Analysis).
3. **Visualization** – The projected points are plotted, with labels showing the original words. You can rotate the 3D view to explore groupings.

---

## 💡 Typical Applications of Embeddings

- **Semantic search** – Find relevant results even if exact keywords don’t match
- **Clustering & topic discovery** – Group related items automatically
- **Recommendations** – Suggest similar products, movies, or articles
- **Deduplication** – Detect near-duplicate content
- **Analogies** – Explore relationships like *"king" – "man" + "woman" ≈ "queen"*

---

## 🚀 Try it Yourself

- Pick a dataset or create your own by editing the list
- Switch models to compare how the embedding space changes
- Toggle between 2D and 3D to explore patterns
""".strip()


def goto(page: str):
    """Switch to ``page`` by setting the ``page`` query parameter and rerunning."""
    st.query_params["page"] = page
    st.rerun()


page = st.query_params.get("page", "demo")

if page == "info":
    st.write(INFO_MARKDOWN)
    if st.button("⬅ back to demo"):
        goto("demo")
    # Nothing below belongs to the info page.
    st.stop()

# -----------------------------
# Top compact bar
# -----------------------------
c1, c2, c3, c4 = st.columns([2, 2, 1, 1])

dataset_names = list(DATASETS)
with c1:
    if "dataset_name" not in st.session_state:
        st.session_state.dataset_name = (
            "actors" if "actors" in DATASETS else dataset_names[0]
        )
    # The initial value comes from session state via `key`; passing `index=`
    # as well makes Streamlit warn that the widget has two sources of truth.
    st.selectbox("dataset", dataset_names, key="dataset_name")

with c2:
    if "model_name" not in st.session_state:
        st.session_state.model_name = list(MODELS.values())[1]
    labels = list(MODELS)
    label_by_model = {model_id: label for label, model_id in MODELS.items()}
    current_label = label_by_model.get(st.session_state.model_name, labels[0])
    chosen_label = st.selectbox(
        "embedding model", labels, index=labels.index(current_label)
    )
    st.session_state.model_name = MODELS[chosen_label]

with c3:
    # Default to 3D on first render; single-click thereafter
    radio_kwargs = dict(options=["2D", "3D"], horizontal=True, key="proj_mode")
    if "proj_mode" not in st.session_state:
        radio_kwargs["index"] = 1  # 3D default
    st.radio("projection", **radio_kwargs)

with c4:
    if st.button("ℹ info"):
        goto("info")

# -----------------------------
# Two-column layout (left = textarea, right = plot)
# -----------------------------
left, right = st.columns([1, 2], gap="large")

# Keep textarea synced with dataset selection: the text is (re)filled from the
# dataset on first render and whenever the selected dataset changes; otherwise
# the user's edits are left alone.
if "dataset_text" not in st.session_state:
    st.session_state.dataset_text = "\n".join(DATASETS[st.session_state.dataset_name])
if "prev_dataset_name" not in st.session_state:
    st.session_state.prev_dataset_name = st.session_state.dataset_name
if st.session_state.dataset_name != st.session_state.prev_dataset_name:
    st.session_state.dataset_text = "\n".join(DATASETS[st.session_state.dataset_name])
    st.session_state.prev_dataset_name = st.session_state.dataset_name

with left:
    # Streamlit discourages empty labels (accessibility), so use a real one.
    st.text_area(
        label="words",
        key="dataset_text",
        height=420,
        help="edit words (one per line). changing dataset above refreshes this box.",
    )

# One entry per non-blank line, surrounding whitespace removed.
words = [
    line.strip()
    for line in st.session_state.dataset_text.split("\n")
    if line.strip()
]

with right:
    if len(words) < 3:
        st.info("enter at least three lines to project.")
        st.stop()

    X = embed_texts(st.session_state.model_name, tuple(words))

    # Capitalized dataset name for the chart title (dataset keys remain lowercase in the UI)
    chart_title = st.session_state.dataset_name.title()
    title_layout = dict(
        text=chart_title,
        x=0.5,
        xanchor="center",
        yanchor="top",
        font=dict(size=20),
    )

    if st.session_state.proj_mode == "2D":
        coords = PCA(n_components=2).fit_transform(X)
        fig = go.Figure(
            data=[
                go.Scatter(
                    x=coords[:, 0],
                    y=coords[:, 1],
                    mode="markers+text",
                    text=words,
                    textposition="top center",
                    marker=dict(size=9),
                )
            ],
            layout=go.Layout(
                xaxis=dict(title="PC1"),
                # Lock the aspect ratio so distances look the same on both axes.
                yaxis=dict(title="PC2", scaleanchor="x", scaleratio=1),
                margin=dict(l=0, r=0, b=0, t=40),
            ),
        )
        fig.update_layout(title=title_layout)
    else:
        coords = PCA(n_components=3).fit_transform(X)
        fig = go.Figure(
            data=[
                go.Scatter3d(
                    x=coords[:, 0],
                    y=coords[:, 1],
                    z=coords[:, 2],
                    mode="markers+text",
                    text=words,
                    textposition="top center",
                    marker=dict(size=6),
                )
            ],
            layout=go.Layout(
                scene=dict(
                    xaxis=dict(showbackground=True, backgroundcolor="rgba(255, 230, 230, 1)"),
                    yaxis=dict(showbackground=True, backgroundcolor="rgba(230, 255, 230, 1)"),
                    zaxis=dict(showbackground=True, backgroundcolor="rgba(230, 230, 255, 1)"),
                ),
                margin=dict(l=0, r=0, b=0, t=40),
            ),
        )
        fig.update_layout(title=title_layout)

        # Simple Plotly rotation: frames + Rotate/Stop buttons.
        # Each frame moves the camera 4 degrees further around the z axis at a
        # fixed radius and height.
        radius = 1.7
        z_eye = 1.0
        frames = []
        for ang in range(0, 360, 4):
            rad = np.deg2rad(ang)
            frames.append(
                go.Frame(
                    layout=dict(
                        scene_camera=dict(
                            eye=dict(
                                x=radius * np.cos(rad),
                                y=radius * np.sin(rad),
                                z=z_eye,
                            ),
                            projection=dict(type="perspective"),
                        )
                    )
                )
            )
        fig.frames = frames

        fig.update_layout(
            updatemenus=[
                dict(
                    type="buttons",
                    showactive=False,
                    x=0.02,
                    y=0.98,
                    buttons=[
                        dict(
                            label="▶ Rotate",
                            method="animate",
                            args=[
                                None,
                                dict(
                                    frame=dict(duration=40, redraw=True),
                                    transition=dict(duration=0),
                                    fromcurrent=True,
                                    mode="immediate",
                                ),
                            ],
                        ),
                        dict(
                            label="⏹ Stop",
                            method="animate",
                            args=[
                                [None],
                                dict(
                                    frame=dict(duration=0, redraw=False),
                                    transition=dict(duration=0),
                                    mode="immediate",
                                ),
                            ],
                        ),
                    ],
                )
            ]
        )

    st.plotly_chart(fig, use_container_width=True)