awacke1 committed
Commit f06395c · 1 Parent(s): 2332bdf

Create new file

Files changed (1)
  1. app.py +292 -0
app.py ADDED
"""Gradio demo for different clustering techniques.
Derived from https://scikit-learn.org/stable/auto_examples/cluster/plot_cluster_comparison.html
"""

import math
from functools import partial

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import (
    AgglomerativeClustering, Birch, DBSCAN, KMeans, MeanShift, OPTICS, SpectralClustering, estimate_bandwidth
)
from sklearn.datasets import make_blobs, make_circles, make_moons
from sklearn.mixture import GaussianMixture
from sklearn.neighbors import kneighbors_graph
from sklearn.preprocessing import StandardScaler


plt.style.use('seaborn')  # note: newer matplotlib releases rename this style to 'seaborn-v0_8'


SEED = 0
MAX_CLUSTERS = 10
N_SAMPLES = 1000
N_COLS = 3
FIGSIZE = 7, 7  # does not affect size in webpage
COLORS = [
    'blue', 'orange', 'green', 'red', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan'
]
assert len(COLORS) >= MAX_CLUSTERS, "Not enough different colors for all clusters"
np.random.seed(SEED)


def normalize(X):
    return StandardScaler().fit_transform(X)


def get_regular(n_clusters):
    # spiral pattern
    centers = [
        [0, 0],
        [1, 0],
        [1, 1],
        [0, 1],
        [-1, 1],
        [-1, 0],
        [-1, -1],
        [0, -1],
        [1, -1],
        [2, -1],
    ][:n_clusters]
    assert len(centers) == n_clusters
    X, labels = make_blobs(n_samples=N_SAMPLES, centers=centers, cluster_std=0.25, random_state=SEED)
    return normalize(X), labels


def get_circles(n_clusters):
    X, labels = make_circles(n_samples=N_SAMPLES, factor=0.5, noise=0.05, random_state=SEED)
    return normalize(X), labels


def get_moons(n_clusters):
    X, labels = make_moons(n_samples=N_SAMPLES, noise=0.05, random_state=SEED)
    return normalize(X), labels


def get_noise(n_clusters):
    np.random.seed(SEED)
    X, labels = np.random.rand(N_SAMPLES, 2), np.random.randint(0, n_clusters, size=(N_SAMPLES,))
    return normalize(X), labels


def get_anisotropic(n_clusters):
    X, labels = make_blobs(n_samples=N_SAMPLES, centers=n_clusters, random_state=170)
    transformation = [[0.6, -0.6], [-0.4, 0.8]]
    X = np.dot(X, transformation)
    return X, labels


def get_varied(n_clusters):
    cluster_std = [1.0, 2.5, 0.5, 1.0, 2.5, 0.5, 1.0, 2.5, 0.5, 1.0][:n_clusters]
    assert len(cluster_std) == n_clusters
    X, labels = make_blobs(
        n_samples=N_SAMPLES, centers=n_clusters, cluster_std=cluster_std, random_state=SEED
    )
    return normalize(X), labels


def get_spiral(n_clusters):
    # from https://scikit-learn.org/stable/auto_examples/cluster/plot_agglomerative_clustering.html
    np.random.seed(SEED)
    t = 1.5 * np.pi * (1 + 3 * np.random.rand(1, N_SAMPLES))
    x = t * np.cos(t)
    y = t * np.sin(t)
    X = np.concatenate((x, y))
    X += 0.7 * np.random.randn(2, N_SAMPLES)
    X = np.ascontiguousarray(X.T)

    labels = np.zeros(N_SAMPLES, dtype=int)
    return normalize(X), labels


DATA_MAPPING = {
    'regular': get_regular,
    'circles': get_circles,
    'moons': get_moons,
    'spiral': get_spiral,
    'noise': get_noise,
    'anisotropic': get_anisotropic,
    'varied': get_varied,
}
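
# Every generator above shares the same contract: it takes the requested number of
# clusters and returns a (points, ground-truth labels) pair. Datasets whose shape
# fixes the grouping (circles, moons, spiral) simply ignore n_clusters.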


def get_groundtruth_model(X, labels, n_clusters, **kwargs):
    # dummy model to show true label distribution
    class Dummy:
        def __init__(self, y):
            self.labels_ = y

    return Dummy(labels)


def get_kmeans(X, labels, n_clusters, **kwargs):
    model = KMeans(init="k-means++", n_clusters=n_clusters, n_init=10, random_state=SEED)
    model.set_params(**kwargs)
    return model.fit(X)


def get_dbscan(X, labels, n_clusters, **kwargs):
    model = DBSCAN(eps=0.3)
    model.set_params(**kwargs)
    return model.fit(X)


def get_agglomerative(X, labels, n_clusters, **kwargs):
    connectivity = kneighbors_graph(
        X, n_neighbors=n_clusters, include_self=False
    )
    # make connectivity symmetric
    connectivity = 0.5 * (connectivity + connectivity.T)
    model = AgglomerativeClustering(
        n_clusters=n_clusters, linkage="ward", connectivity=connectivity
    )
    model.set_params(**kwargs)
    return model.fit(X)


def get_meanshift(X, labels, n_clusters, **kwargs):
    bandwidth = estimate_bandwidth(X, quantile=0.25)
    model = MeanShift(bandwidth=bandwidth, bin_seeding=True)
    model.set_params(**kwargs)
    return model.fit(X)


def get_spectral(X, labels, n_clusters, **kwargs):
    model = SpectralClustering(
        n_clusters=n_clusters,
        eigen_solver="arpack",
        affinity="nearest_neighbors",
    )
    model.set_params(**kwargs)
    return model.fit(X)


def get_optics(X, labels, n_clusters, **kwargs):
    model = OPTICS(
        min_samples=7,
        xi=0.05,
        min_cluster_size=0.1,
    )
    model.set_params(**kwargs)
    return model.fit(X)


def get_birch(X, labels, n_clusters, **kwargs):
    model = Birch(n_clusters=n_clusters)
    model.set_params(**kwargs)
    return model.fit(X)


def get_gaussianmixture(X, labels, n_clusters, **kwargs):
    model = GaussianMixture(
        n_components=n_clusters, covariance_type="full", random_state=SEED,
    )
    model.set_params(**kwargs)
    return model.fit(X)


MODEL_MAPPING = {
    'True labels': get_groundtruth_model,
    'KMeans': get_kmeans,
    'DBSCAN': get_dbscan,
    'MeanShift': get_meanshift,
    'SpectralClustering': get_spectral,
    'OPTICS': get_optics,
    'Birch': get_birch,
    'GaussianMixture': get_gaussianmixture,
    'AgglomerativeClustering': get_agglomerative,
}
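
# Every factory above takes (X, labels, n_clusters, **kwargs) and returns a fitted
# estimator, so the UI can treat them interchangeably. 'True labels' is a pass-through
# that simply exposes the ground-truth labels for comparison.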


def plot_clusters(ax, X, labels):
    set_clusters = set(labels)
    set_clusters.discard(-1)  # -1 signifies outliers, which we plot separately
    for label, color in zip(sorted(set_clusters), COLORS):
        # zip stops at len(COLORS), so only the first len(COLORS) label groups are drawn
        idx = labels == label
        if not sum(idx):
            continue
        ax.scatter(X[idx, 0], X[idx, 1], color=color)

    # show outliers (if any)
    idx = labels == -1
    if sum(idx):
        ax.scatter(X[idx, 0], X[idx, 1], c='k', marker='x')

    ax.grid(None)
    ax.set_xticks([])
    ax.set_yticks([])
    return ax


def cluster(dataset: str, n_clusters: int, clustering_algorithm: str):
    if isinstance(n_clusters, dict):
        n_clusters = n_clusters['value']
    else:
        n_clusters = int(n_clusters)

    X, labels = DATA_MAPPING[dataset](n_clusters)
    model = MODEL_MAPPING[clustering_algorithm](X, labels, n_clusters=n_clusters)
    if hasattr(model, "labels_"):
        y_pred = model.labels_.astype(int)
    else:
        # e.g. GaussianMixture has no labels_ attribute and must predict instead
        y_pred = model.predict(X)

    fig, ax = plt.subplots(figsize=FIGSIZE)

    plot_clusters(ax, X, y_pred)
    ax.set_title(clustering_algorithm, fontsize=16)

    return fig
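
# cluster() is the single callback shared by every plot in the UI below; each plot
# binds its own algorithm name via functools.partial.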


title = "Clustering with Scikit-learn"
description = (
    "This example shows how different clustering algorithms behave. Pick a dataset "
    "and a number of clusters to see how each algorithm groups the points. "
    "Colored circles are (predicted) cluster labels and black x's mark outliers."
)


def iter_grid(n_rows, n_cols):
    # create a grid using gradio Block
    for _ in range(n_rows):
        with gr.Row():
            for _ in range(n_cols):
                with gr.Column():
                    yield
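
# UI layout: a dataset selector and a cluster-count slider on top, followed by a grid
# holding one plot per clustering algorithm. Changing either input re-renders every plot.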


with gr.Blocks(title=title) as demo:
    gr.HTML(f"<b>{title}</b>")
    gr.Markdown(description)

    input_models = list(MODEL_MAPPING)
    input_data = gr.Radio(
        list(DATA_MAPPING),
        value="regular",
        label="dataset"
    )
    input_n_clusters = gr.Slider(
        minimum=1,
        maximum=MAX_CLUSTERS,
        value=4,
        step=1,
        label='Number of clusters'
    )
    n_rows = int(math.ceil(len(input_models) / N_COLS))
    counter = 0
    for _ in iter_grid(n_rows, N_COLS):
        if counter >= len(input_models):
            break

        input_model = input_models[counter]
        plot = gr.Plot(label=input_model)
        fn = partial(cluster, clustering_algorithm=input_model)
        input_data.change(fn=fn, inputs=[input_data, input_n_clusters], outputs=plot)
        input_n_clusters.change(fn=fn, inputs=[input_data, input_n_clusters], outputs=plot)
        counter += 1


demo.launch()