import marimo
__generated_with = "0.12.8"
app = marimo.App()
@app.cell(hide_code=True)
def _(mo):
mo.md(
r"""
## Face Embeddings of World Leaders
This notebook explores face embeddings using a subset of the **Labeled Faces in the Wild** dataset, focused on public figures. We'll use standard Python and scikit-learn libraries to load the data, embed images, reduce dimensionality, and visualize clustering behavior.
This example builds on a demo from the Marimo gallery using the MNIST dataset. Here, we adapt it to work with a facial recognition dataset of public figures. While facial recognition has limited responsible use cases, this curated subset includes only world leaders, a group I feel comfortable experimenting with in a technical context.
We'll start with our imports:
"""
)
return
@app.cell
def _():
from time import time
import matplotlib.pyplot as plt
from scipy.stats import loguniform
from sklearn.datasets import fetch_lfw_people
from sklearn.decomposition import PCA
from sklearn.metrics import ConfusionMatrixDisplay, classification_report
from sklearn.model_selection import RandomizedSearchCV, train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
return (
ConfusionMatrixDisplay,
PCA,
RandomizedSearchCV,
SVC,
StandardScaler,
classification_report,
fetch_lfw_people,
loguniform,
plt,
time,
train_test_split,
)
@app.cell(hide_code=True)
def _(mo):
mo.md(r"""We're using `fetch_lfw_people` from `sklearn.datasets` to load a curated subset of the LFW dataset β restricted to individuals with at least 70 images, resulting in 7 distinct people and just over 1,200 samples. These happen to be mostly world leaders, which makes the demo both manageable and fun to explore.""")
return
@app.cell
def _(fetch_lfw_people):
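# Download (or load from cache) the curated LFW subset; resize=0.4 scales each face to roughly 50x37 pixels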
lfw_people = fetch_lfw_people(min_faces_per_person=70, resize=0.4)
# introspect the images array to find the shapes (for plotting)
n_samples, h, w = lfw_people.images.shape
# the flattened pixel data; rows are reshaped back to (h, w) before embedding
X = lfw_people.data
n_features = X.shape[1]
# the label to predict is the id of the person
Y = lfw_people.target
target_names = lfw_people.target_names
n_classes = target_names.shape[0]
print("Total dataset size:")
print("n_samples: %d" % n_samples)
print("n_features: %d" % n_features)
print("n_classes: %d" % n_classes)
return (
X,
Y,
h,
lfw_people,
n_classes,
n_features,
n_samples,
target_names,
w,
)
@app.cell(hide_code=True)
def _(mo):
mo.md(r"""Next, we embed each face image using a pre-trained FaceNet model (`InceptionResnetV1` trained on `vggface2`). This converts each image into a 512-dimensional vector. Since the original data is grayscale and flattened, we reshape, normalize, and convert it to RGB before feeding it through the model.""")
return
@app.cell
def _(X, h, w):
from facenet_pytorch import InceptionResnetV1
from torchvision import transforms
from PIL import Image
import torch
import numpy as np
# Load FaceNet model
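# .eval() puts the network in inference mode (fixed batch-norm statistics, dropout disabled)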
model = InceptionResnetV1(pretrained='vggface2').eval()
# Transform pipeline: resize -> tensor -> grayscale-to-RGB -> normalize
transform = transforms.Compose([
transforms.Resize((160, 160)),
transforms.ToTensor(),
transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.shape[0] == 1 else x),
transforms.Normalize([0.5], [0.5])
])
# Embed a single flattened row from X
def embed_flat_row(flat):
img = flat.reshape(h, w)
img = (img * 255).astype(np.uint8)
pil = Image.fromarray(img).convert("L") # grayscale
tensor = transform(pil).unsqueeze(0)
with torch.no_grad():
return model(tensor).squeeze().numpy() # 512-dim
# Generate embeddings for all samples
embeddings = np.array([embed_flat_row(row) for row in X])
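# embeddings has shape (n_samples, 512): one FaceNet vector per face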
return (
Image,
InceptionResnetV1,
embed_flat_row,
embeddings,
model,
np,
torch,
transform,
transforms,
)
@app.cell
def _(mo):
mo.md(r"""Now that we have 512-dimensional embeddings, we reduce them to 2D for visualization. Both t-SNE and UMAP are available here β UMAP is active by default, but you can switch to t-SNE by uncommenting the alternate line. This step lets us inspect the structure of the embedding space:""")
return
@app.cell
def _(embeddings):
from sklearn.manifold import TSNE
import umap.umap_ as umap
# X_embedded = TSNE(n_components=2, perplexity=30, random_state=42).fit_transform(embeddings)
X_embedded = umap.UMAP(n_components=2, random_state=42).fit_transform(embeddings)
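# fixing random_state keeps the 2D layout reproducible across runs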
return TSNE, X_embedded, umap
@app.cell
def _(mo):
mo.md(r"""We wrap the 2D embeddings into a Pandas DataFrame for easier manipulation and plotting. Each row includes x/y coordinates and the associated person ID, which we map to names. We then define a simple Altair scatterplot function to visualize the clustered embeddings by identity.""")
return
@app.cell
def _(X_embedded, Y, target_names):
import pandas as pd
embedding_df = pd.DataFrame({
"x": X_embedded[:, 0],
"y": X_embedded[:, 1],
"person": Y
}).reset_index()
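# reset_index() keeps each row's original position in an 'index' column, used later to look up the raw images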
embedding_df["name"] = embedding_df["person"].map(lambda i: target_names[i])
return embedding_df, pd
@app.cell
def _():
import altair as alt
def scatter(df):
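# scatter of the 2D embedding, one point per face, colored by identity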
return (alt.Chart(df)
.mark_circle()
.encode(
x=alt.X("x:Q"),
y=alt.Y("y:Q"),
color=alt.Color("name:N"),
).properties(width=500, height=300))
return alt, scatter
@app.cell(hide_code=True)
def _(mo):
mo.md(r"""Here's our 2D embedding space of world leader faces! Each point is a facial embedding projected with UMAP and colored by identity. Try selecting a cluster β the notebook will automatically reveal the associated images so you can explore what the model βthinksβ belongs together.""")
return
@app.cell
def _(embedding_df, scatter):
import marimo as mo
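# mo.ui.altair_chart makes the chart's point/brush selection available reactively as chart.value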
chart = mo.ui.altair_chart(scatter(embedding_df))
return chart, mo
@app.cell(hide_code=True)
def _(mo):
mo.md(r"""When you select points in the scatterplot, Marimo automatically passes those indices into this cell. Here, we render a preview of the corresponding face images using `matplotlib`, along with a table of all selected metadata β making it easy to inspect clustering quality or outliers at a glance.""")
return
@app.cell
def _(chart, mo):
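# table of the currently selected points; selecting rows here narrows the image preview further (see show_selected below)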
table = mo.ui.table(chart.value)
return (table,)
@app.cell
def _(X, chart, h, mo, table, w):
def show_images(indices, max_images=6):
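# render up to max_images of the selected faces as small grayscale thumbnails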
import matplotlib.pyplot as plt
indices = indices[:max_images]
images = X.reshape((-1, h, w))[indices]
fig, axes = plt.subplots(1, len(indices))
fig.set_size_inches(12.5, 1.5)
if len(indices) > 1:
for im, ax in zip(images, axes.flat):
ax.imshow(im, cmap="gray")
ax.set_yticks([])
ax.set_xticks([])
else:
axes.imshow(images[0], cmap="gray")
axes.set_yticks([])
axes.set_xticks([])
plt.tight_layout()
return fig
def show_selected():
return (
show_images(list(chart.value["index"]))
if not len(table.value)
else show_images(list(table.value["index"]))
)
mo.hstack([chart, show_selected() if len(chart.value) else ""])
return show_images, show_selected
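@app.cell
def _(SVC, Y, classification_report, embeddings, target_names, train_test_split):
    # A minimal sketch, not part of the original walkthrough: the classifier imports
    # above (SVC, train_test_split, classification_report) are otherwise unused, so
    # this shows one way they could be applied: checking how separable the FaceNet
    # embeddings are with a linear SVM. The split size and C value are assumptions.
    _X_train, _X_test, _y_train, _y_test = train_test_split(
        embeddings, Y, test_size=0.25, random_state=42, stratify=Y
    )
    _clf = SVC(kernel="linear", C=1.0).fit(_X_train, _y_train)
    print(classification_report(_y_test, _clf.predict(_X_test), target_names=target_names))
    return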
@app.cell
def _():
return
if __name__ == "__main__":
app.run()