"""Gradio demo: agglomerative clustering with different linkages on digits.

The digits data are projected with a spectral embedding (2D or 3D), clustered
with ``AgglomerativeClustering`` using the selected linkage, and drawn as a
Plotly scatter where each point is its true digit label colored by cluster.
"""
from functools import lru_cache
from time import time

import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
from sklearn import manifold, datasets
from sklearn.cluster import AgglomerativeClustering

# Non-interactive backend: matplotlib is only used here for its colormaps,
# never to open a window.
matplotlib.use('Agg')

SEED = 0
# Number of clusters requested from AgglomerativeClustering; also used to map
# cluster labels onto the colormap's [0, 1) input range.
N_CLUSTERS = 10

# NOTE: `load_digits` is scikit-learn's small bundled hand-written digits
# dataset, not MNIST itself, despite the wording of the UI title below.
digits = datasets.load_digits()
X, y = digits.data, digits.target
n_samples, n_features = X.shape

np.random.seed(SEED)


@lru_cache(maxsize=None)  # bounded in practice: only 2 or 3 is ever passed
def _embed(n_components):
    """Return the spectral embedding of the module-level digits data.

    Cached per dimensionality so the embedding is computed once instead of on
    every button click. ``random_state`` is fixed so repeated runs give the
    same projection. Callers must treat the returned array as read-only.
    """
    embedding = manifold.SpectralEmbedding(
        n_components=n_components, random_state=SEED
    )
    return embedding.fit_transform(X)


def _to_unit_range(points):
    """Min-max scale each column of ``points`` into [0, 1] for plotting.

    A column with zero range is mapped to 0 instead of producing NaN.
    """
    lo = points.min(axis=0)
    span = points.max(axis=0) - lo
    span[span == 0] = 1.0  # avoid 0/0 for a degenerate (constant) column
    return (points - lo) / span


def _cluster_colors(cluster_labels):
    """Map integer cluster labels to Plotly ``rgba(...)`` strings (alpha 0.8)."""
    rgbas = plt.cm.nipy_spectral(cluster_labels / N_CLUSTERS)
    # matplotlib colormaps return channels as floats in [0, 1], but CSS
    # rgba() expects red/green/blue in 0-255; unscaled values render ~black.
    return [
        f'rgba({r * 255:.0f}, {g * 255:.0f}, {b * 255:.0f}, 0.8)'
        for r, g, b, _alpha in rgbas
    ]


def plot_clustering(linkage, dim):
    """Cluster the embedded digits and return a Plotly figure of the result.

    Parameters
    ----------
    linkage : str
        Linkage criterion forwarded to ``AgglomerativeClustering``
        (the UI offers "ward", "average", "complete" and "single").
    dim : str
        '3D' selects a three-dimensional embedding and 3D scatter; any other
        value selects the 2D version.

    Returns
    -------
    plotly.graph_objects.Figure
        One text-only trace per true digit class. Each point is drawn as its
        digit label and colored by the cluster it was assigned to.
    """
    three_d = dim == '3D'
    X_red = _embed(3 if three_d else 2)

    clustering = AgglomerativeClustering(linkage=linkage, n_clusters=N_CLUSTERS)
    t0 = time()
    clustering.fit(X_red)
    print("%s :\t%.2fs" % (linkage, time() - t0))
    labels = clustering.labels_

    # Scale only for display; clustering above ran on the raw embedding.
    X_red = _to_unit_range(X_red)

    fig = go.Figure()
    for digit in digits.target_names:
        mask = y == digit
        subset = X_red[mask]
        textfont = {'size': 16, 'color': _cluster_colors(labels[mask])}
        if three_d:
            trace = go.Scatter3d(
                x=subset[:, 0], y=subset[:, 1], z=subset[:, 2],
                mode='text', text=str(digit), textfont=textfont,
            )
        else:
            trace = go.Scatter(
                x=subset[:, 0], y=subset[:, 1],
                mode='text', text=str(digit), textfont=textfont,
            )
        fig.add_trace(trace)
    fig.update_traces(showlegend=False)
    return fig


title = '# Agglomerative Clustering on MNIST'
description = """
An illustration of various linkage options for [agglomerative clustering](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.AgglomerativeClustering.html) on the digits dataset.
"""

author = '''
Created by [@Hnabil](https://huggingface.co/Hnabil) based on [scikit-learn docs](https://scikit-learn.org/stable/auto_examples/cluster/plot_digits_linkage.html)
'''

# The browser-tab title is plain text, so drop the markdown heading marker.
with gr.Blocks(analytics_enabled=False, title=title.lstrip('# ')) as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown(author)

    with gr.Row():
        with gr.Column():
            linkage = gr.Radio(
                ["ward", "average", "complete", "single"],
                value="average",
                interactive=True,
                label="Linkage Method",
            )
            dim = gr.Radio(
                ['2D', '3D'], label='Embedding Dimensionality', value='2D'
            )
            btn = gr.Button('Submit')

        with gr.Column():
            plot = gr.Plot(label='MNIST Embeddings')

    btn.click(plot_clustering, inputs=[linkage, dim], outputs=[plot])
    # Render the default selection as soon as the page loads.
    demo.load(plot_clustering, inputs=[linkage, dim], outputs=[plot])

demo.launch()