import marimo

__generated_with = "0.12.8"
app = marimo.App()


@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        ## Face Embeddings of World Leaders

        This notebook explores face embeddings using a subset of the **Labeled Faces in the Wild** dataset, focused on public figures. We'll use scikit-learn to load the data, a pretrained FaceNet model to embed the images, UMAP (or t-SNE) to reduce dimensionality, and Altair to visualize the clustering behavior.

        This example builds on a demo from the Marimo gallery that uses the MNIST dataset. Here, we adapt it to a facial recognition dataset of public figures. While facial recognition has limited responsible use cases, this curated subset includes only world leaders, a group I feel comfortable experimenting with in a technical context.

        We'll start with our imports:
        """
    )
    return


@app.cell
def _():
    from sklearn.datasets import fetch_lfw_people
    return (fetch_lfw_people,)


@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""We're using `fetch_lfw_people` from `sklearn.datasets` to load a curated subset of the LFW dataset β€” restricted to individuals with at least 70 images, resulting in 7 distinct people and just over 1,200 samples. These happen to be mostly world leaders, which makes the demo both manageable and fun to explore.""")
    return


@app.cell
def _(fetch_lfw_people):
    lfw_people = fetch_lfw_people(min_faces_per_person=70, resize=0.4)

    # introspect the images array to find the shapes (for plotting)
    n_samples, h, w = lfw_people.images.shape

    # for machine learning we use the flattened pixel data directly
    # (relative pixel position info is ignored by this model)
    X = lfw_people.data
    n_features = X.shape[1]

    # the label to predict is the id of the person
    Y = lfw_people.target
    target_names = lfw_people.target_names
    n_classes = target_names.shape[0]

    print("Total dataset size:")
    print("n_samples: %d" % n_samples)
    print("n_features: %d" % n_features)
    print("n_classes: %d" % n_classes)
    return (
        X,
        Y,
        h,
        lfw_people,
        n_classes,
        n_features,
        n_samples,
        target_names,
        w,
    )

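
@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""As a quick sanity check (a small sketch, not part of the original demo), we can count how many images each person contributes. The `min_faces_per_person=70` filter guarantees a floor, not balance:""")
    return


@app.cell
def _(Y, target_names):
    from collections import Counter

    # Images per person; expect at least 70 each, but far from uniform
    class_counts = {target_names[i]: int(c) for i, c in sorted(Counter(Y).items())}
    class_counts
    return (class_counts,)
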

@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""Next, we embed each face image using a pre-trained FaceNet model (`InceptionResnetV1` trained on `vggface2`). This converts each image into a 512-dimensional vector. Since the original data is grayscale and flattened, we reshape, normalize, and convert it to RGB before feeding it through the model.""")
    return


@app.cell
def _(X, h, w):
    from facenet_pytorch import InceptionResnetV1
    from torchvision import transforms
    from PIL import Image
    import torch
    import numpy as np

    # Load FaceNet model
    model = InceptionResnetV1(pretrained='vggface2').eval()

    # Transform pipeline: resize → tensor → grayscale-to-RGB → normalize
    transform = transforms.Compose([
        transforms.Resize((160, 160)),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.shape[0] == 1 else x),
        transforms.Normalize([0.5], [0.5])
    ])

    # Embed a single flattened row from X
    def embed_flat_row(flat):
        img = flat.reshape(h, w)
        img = (img * 255).astype(np.uint8)
        pil = Image.fromarray(img).convert("L")  # grayscale
        tensor = transform(pil).unsqueeze(0)
        with torch.no_grad():
            return model(tensor).squeeze().numpy()  # 512-dim

    # Generate embeddings for all samples
    embeddings = np.array([embed_flat_row(row) for row in X])
    return (
        Image,
        InceptionResnetV1,
        embed_flat_row,
        embeddings,
        model,
        np,
        torch,
        transform,
        transforms,
    )

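
@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""Embedding roughly 1,300 images one at a time is slow but simple; batching rows through the model would be faster. As a rough sanity check (a sketch, not part of the original demo), we can L2-normalize the embeddings and compare average cosine similarity within and between identities; if the embeddings are meaningful, within-class similarity should be noticeably higher:""")
    return


@app.cell
def _(Y, embeddings, np):
    # L2-normalize so dot products equal cosine similarity
    _normed = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    _sim = _normed @ _normed.T

    # Same-identity pairs (excluding each image paired with itself) vs. the rest
    _same = Y[:, None] == Y[None, :]
    _self = np.eye(len(Y), dtype=bool)
    print("mean within-class cosine: %.3f" % _sim[_same & ~_self].mean())
    print("mean between-class cosine: %.3f" % _sim[~_same].mean())
    return
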

@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""Now that we have 512-dimensional embeddings, we reduce them to 2D for visualization. Both t-SNE and UMAP are imported here; UMAP is active by default, but you can switch to t-SNE by uncommenting the alternate line. This step lets us inspect the structure of the embedding space:""")
    return


@app.cell
def _(embeddings):
    from sklearn.manifold import TSNE
    import umap.umap_ as umap

    # X_embedded = TSNE(n_components=2, perplexity=30, random_state=42).fit_transform(embeddings)
    X_embedded = umap.UMAP(n_components=2, random_state=42).fit_transform(embeddings)
    return TSNE, X_embedded, umap

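
@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""UMAP's layout is sensitive to its two main knobs: `n_neighbors` (how much global structure to preserve) and `min_dist` (how tightly points pack). The values below are illustrative rather than tuned; computing an alternate projection for comparison is a one-liner:""")
    return


@app.cell
def _(embeddings, umap):
    # Illustrative only: larger n_neighbors emphasizes global structure,
    # smaller min_dist packs clusters more tightly
    X_embedded_alt = umap.UMAP(
        n_components=2, n_neighbors=30, min_dist=0.05, random_state=42
    ).fit_transform(embeddings)
    return (X_embedded_alt,)
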

@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""We wrap the 2D embeddings into a Pandas DataFrame for easier manipulation and plotting. Each row includes x/y coordinates and the associated person ID, which we map to names. We then define a simple Altair scatterplot function to visualize the clustered embeddings by identity.""")
    return


@app.cell
def _(X_embedded, Y, target_names):
    import pandas as pd

    embedding_df = pd.DataFrame({
        "x": X_embedded[:, 0],
        "y": X_embedded[:, 1],
        "person": Y
    }).reset_index()
    embedding_df["name"] = embedding_df["person"].map(lambda i: target_names[i])
    return embedding_df, pd


@app.cell
def _():
    import altair as alt
    def scatter(df):
        return (
            alt.Chart(df)
            .mark_circle()
            .encode(
                x=alt.X("x:Q"),
                y=alt.Y("y:Q"),
                color=alt.Color("name:N"),
            )
            .properties(width=500, height=300)
        )
    return alt, scatter

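
@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""As an aside (illustrative, not part of the original demo), Altair makes it cheap to add hover tooltips with one more encoding channel; the `index` column exists thanks to the `reset_index()` call above:""")
    return


@app.cell
def _(alt, embedding_df):
    # Same scatterplot, plus hover tooltips showing name and row index
    tooltip_chart = (
        alt.Chart(embedding_df)
        .mark_circle()
        .encode(
            x=alt.X("x:Q"),
            y=alt.Y("y:Q"),
            color=alt.Color("name:N"),
            tooltip=["name:N", "index:Q"],
        )
        .properties(width=500, height=300)
    )
    return (tooltip_chart,)
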

@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""Here's our 2D embedding space of world leader faces! Each point is a facial embedding projected with UMAP and colored by identity. Try selecting a cluster β€” the notebook will automatically reveal the associated images so you can explore what the model β€œthinks” belongs together.""")
    return


@app.cell
def _(embedding_df, mo, scatter):
    chart = mo.ui.altair_chart(scatter(embedding_df))
    return (chart,)


@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""When you select points in the scatterplot, Marimo automatically passes those indices into this cell. Here, we render a preview of the corresponding face images using `matplotlib`, along with a table of all selected metadata β€” making it easy to inspect clustering quality or outliers at a glance.""")
    return


@app.cell
def _(chart, mo):
    table = mo.ui.table(chart.value)
    table
    return (table,)


@app.cell
def _(X, chart, h, mo, table, w):
    def show_images(indices, max_images=6):
        import matplotlib.pyplot as plt

        indices = indices[:max_images]
        images = X.reshape((-1, h, w))[indices]
        fig, axes = plt.subplots(1, len(indices))
        fig.set_size_inches(12.5, 1.5)
        if len(indices) > 1:
            for im, ax in zip(images, axes.flat):
                ax.imshow(im, cmap="gray")
                ax.set_yticks([])
                ax.set_xticks([])
        else:
            axes.imshow(images[0], cmap="gray")
            axes.set_yticks([])
            axes.set_xticks([])
        plt.tight_layout()
        return fig

    def show_selected():
        return (
            show_images(list(chart.value["index"]))
            if not len(table.value)
            else show_images(list(table.value["index"]))
        )

    mo.hstack([chart, show_selected() if len(chart.value) else ""])
    return show_images, show_selected

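
@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""Because the selection is just a DataFrame, you can compute on it like any other data. For example (a sketch, not part of the original demo), here's a count of identities in the current selection:""")
    return


@app.cell
def _(chart):
    # Who is in the current selection? Empty until points are selected.
    chart.value["name"].value_counts() if len(chart.value) else None
    return
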

@app.cell
def _():
    import marimo as mo
    return (mo,)


if __name__ == "__main__":
    app.run()