mshamrai committed on
Commit
8b0cc53
·
1 Parent(s): e5394a1

chore: black

Browse files
Files changed (3) hide show
  1. app.py +253 -83
  2. constants.py +2 -2
  3. utils.py +91 -31
app.py CHANGED
@@ -2,19 +2,23 @@ import gradio as gr
2
  import pandas as pd
3
  import numpy as np
4
  import os
5
- from utils import (plot_distances_tsne,
6
- plot_distances_umap,
7
- cluster_languages_hdbscan,
8
- cluster_languages_kmeans,
9
- plot_mst,
10
- cluster_languages_by_families,
11
- cluster_languages_by_subfamilies,
12
- filter_languages_by_families)
 
 
13
  from functools import partial
14
  import datasets
15
 
16
 
17
- dataset = datasets.load_dataset("mshamrai/language-metric-data", split="train", trust_remote_code=True)
 
 
18
 
19
  languages = dataset["languages_list"][0]
20
  average_distances_matrix = np.array(dataset["average_distances_matrix"][0])
@@ -27,7 +31,7 @@ distance_matrices = {
27
  MODELS[j]: np.array(dataset["distances_matrices"][0]["models"][i]["matrix"][j])
28
  for j in range(len(MODELS))
29
  }
30
- for i in range(len(DATASETS))
31
  }
32
 
33
 
@@ -63,6 +67,7 @@ def get_similar_languages(model, dataset, selected_language, use_average, n):
63
  sorted_distances["Distance"] = sorted_distances["Distance"].round(4)
64
  return sorted_distances.head(n)
65
 
 
66
  def update_languages(model, dataset):
67
  """
68
  Returns the language list based on the given model and dataset.
@@ -85,21 +90,29 @@ def update_language_options(model, dataset, language, use_average):
85
 
86
  def toggle_inputs(use_average):
87
  if use_average:
88
- return gr.update(interactive=False, visible=False), gr.update(interactive=False, visible=False)
 
 
89
  else:
90
- return gr.update(interactive=True, visible=True), gr.update(interactive=True, visible=True)
 
 
91
 
92
 
93
  plot_path = "plots/last_plot.pdf"
94
  os.makedirs("plots", exist_ok=True)
95
 
96
 
97
- def plot_distances(model, dataset, use_average, cluster_method, cluster_method_param, plot_fn):
 
 
98
  """
99
  Plots all languages from the distances matrix using t-SNE.
100
  """
101
 
102
- updated_matrix, updated_languages = filter_languages_nan(model, dataset, use_average)
 
 
103
 
104
  if cluster_method == "HDBSCAN":
105
  filtered_matrix, filtered_languages, clusters = cluster_languages_hdbscan(
@@ -122,18 +135,41 @@ def plot_distances(model, dataset, use_average, cluster_method, cluster_method_p
122
  else:
123
  raise ValueError("Invalid cluster method")
124
 
125
- fig = plot_fn(model, dataset, use_average, filtered_matrix, filtered_languages, clusters, legends)
 
 
 
 
 
 
 
 
126
  fig.tight_layout()
127
  fig.savefig(plot_path, format="pdf")
128
  return fig, gr.DownloadButton(label="Download Plot", value=plot_path)
129
 
130
 
131
- def plot_families_subfamilies(families, model, dataset, use_average, figsize_h, figsize_w):
132
- updated_matrix, updated_languages = filter_languages_nan(model, dataset, use_average)
133
- updated_matrix, updated_languages = filter_languages_by_families(updated_matrix, updated_languages, families)
 
 
 
 
 
 
134
 
135
  clusters, legends = cluster_languages_by_subfamilies(updated_languages)
136
- fig = plot_mst(model, dataset, use_average, updated_matrix, updated_languages, clusters, legends, fig_size=(figsize_w, figsize_h))
 
 
 
 
 
 
 
 
 
137
  fig.tight_layout()
138
  fig.savefig(plot_path, format="pdf")
139
  return fig, gr.DownloadButton(label="Download Plot", value=plot_path)
@@ -145,108 +181,242 @@ with gr.Blocks() as demo:
145
  with gr.Row():
146
  model_input = gr.Dropdown(label="Model", choices=MODELS, value=MODELS[0])
147
  dataset_input = gr.Dropdown(
148
- label="Dataset",
149
- choices=DATASETS,
150
- value=DATASETS[0]
151
  )
152
 
153
  with gr.Tab(label="Closest Languages Table"):
154
  with gr.Row():
155
- language_input = gr.Dropdown(label="Language", choices=languages, value=languages[0])
156
- top_n_input = gr.Slider(label="Top N", minimum=1, maximum=30, step=1, value=10)
157
-
 
 
 
 
158
  output_table = gr.Dataframe(label="Similar Languages")
159
-
160
- model_input.change(fn=update_language_options, inputs=[model_input, dataset_input, language_input, average_checkbox], outputs=language_input)
161
- dataset_input.change(fn=update_language_options, inputs=[model_input, dataset_input, language_input, average_checkbox], outputs=language_input)
162
- language_input.change(fn=get_similar_languages, inputs=[model_input, dataset_input, language_input, average_checkbox, top_n_input], outputs=output_table)
163
- model_input.change(fn=get_similar_languages, inputs=[model_input, dataset_input, language_input, average_checkbox, top_n_input], outputs=output_table)
164
- dataset_input.change(fn=get_similar_languages, inputs=[model_input, dataset_input, language_input, average_checkbox, top_n_input], outputs=output_table)
165
- top_n_input.change(fn=get_similar_languages, inputs=[model_input, dataset_input, language_input, average_checkbox, top_n_input], outputs=output_table)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
  average_checkbox.change(
168
  fn=toggle_inputs,
169
  inputs=[average_checkbox],
170
- outputs=[model_input, dataset_input]
171
  )
172
 
173
- average_checkbox.change(fn=update_language_options, inputs=[model_input, dataset_input, language_input, average_checkbox], outputs=language_input)
174
- average_checkbox.change(fn=get_similar_languages, inputs=[model_input, dataset_input, language_input, average_checkbox, top_n_input], outputs=output_table)
175
-
 
 
 
 
 
 
 
 
 
 
 
 
 
176
 
177
  with gr.Tab(label="Distance Plot"):
178
  with gr.Row():
179
- cluster_method_input = gr.Dropdown(label="Cluster Method", choices=["HDBSCAN", "KMeans", "Family", "Subfamily"], value="HDBSCAN")
180
- clusters_input = gr.Slider(label="Minimum Elements in a Cluster", minimum=2, maximum=10, step=1, value=2)
 
 
 
 
 
 
 
 
 
 
181
 
182
  def update_clusters_input_option(cluster_method):
183
  if cluster_method == "HDBSCAN":
184
- return gr.Slider(label="Minimum Elements in a Cluster", minimum=2, maximum=10, step=1, value=2, visible=True, interactive=True)
 
 
 
 
 
 
 
 
185
  elif cluster_method == "KMeans":
186
- return gr.Slider(label="Number of Clusters", minimum=2, maximum=20, step=1, value=2, visible=True, interactive=True)
 
 
 
 
 
 
 
 
187
  else:
188
  return gr.update(interactive=False, visible=False)
189
 
190
- cluster_method_input.change(fn=update_clusters_input_option, inputs=[cluster_method_input], outputs=clusters_input)
191
-
 
 
 
 
192
  with gr.Row():
193
  plot_tsne_button = gr.Button("Plot t-SNE")
194
  plot_umap_button = gr.Button("Plot UMAP")
195
  plot_mst_button = gr.Button("Plot MST")
196
-
197
  with gr.Row():
198
  download_plot_button = gr.DownloadButton("Download Plot")
199
 
200
  with gr.Row():
201
  plot_output = gr.Plot(label="Distance Plot")
202
 
203
- plot_tsne_button.click(fn=partial(plot_distances, plot_fn=plot_distances_tsne),
204
- inputs=[model_input, dataset_input, average_checkbox, cluster_method_input, clusters_input],
205
- outputs=[plot_output, download_plot_button])
206
- plot_umap_button.click(fn=partial(plot_distances, plot_fn=plot_distances_umap),
207
- inputs=[model_input, dataset_input, average_checkbox, cluster_method_input, clusters_input],
208
- outputs=[plot_output, download_plot_button])
209
- plot_mst_button.click(fn=partial(plot_distances, plot_fn=plot_mst),
210
- inputs=[model_input, dataset_input, average_checkbox, cluster_method_input, clusters_input],
211
- outputs=[plot_output, download_plot_button])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
  with gr.Tab(label="Language Families Subplot"):
214
-
215
- checked_families_input = gr.CheckboxGroup(label="Language Families",
216
- choices=[
217
- 'Afroasiatic',
218
- 'Austroasiatic',
219
- 'Austronesian',
220
- 'Constructed',
221
- 'Creole',
222
- 'Dravidian',
223
- 'Germanic',
224
- 'Indo-European',
225
- 'Japonic',
226
- 'Kartvelian',
227
- 'Koreanic',
228
- 'Language Isolate',
229
- 'Niger-Congo',
230
- 'Northeast Caucasian',
231
- 'Romance',
232
- 'Sino-Tibetan',
233
- 'Turkic',
234
- 'Uralic'
235
- ],
236
- value=["Indo-European"])
 
 
237
  with gr.Row():
238
  plot_family_button = gr.Button("Plot Families")
239
- plot_figsize_h_input = gr.Slider(label="Figure Height", minimum=5, maximum=30, step=1, value=15)
240
- plot_figsize_w_input = gr.Slider(label="Figure Width", minimum=5, maximum=30, step=1, value=15)
 
 
 
 
241
 
242
  with gr.Row():
243
- download_families_plot_button = gr.DownloadButton("Download Plot", value=plot_path)
 
 
244
 
245
  plot_family_output = gr.Plot(label="Families Plot")
246
-
247
- plot_family_button.click(fn=plot_families_subfamilies,
248
- inputs=[checked_families_input, model_input, dataset_input, average_checkbox, plot_figsize_h_input, plot_figsize_w_input],
249
- outputs=[plot_family_output, download_families_plot_button])
250
-
 
 
 
 
 
 
 
 
 
251
 
252
  demo.launch(share=True)
 
2
  import pandas as pd
3
  import numpy as np
4
  import os
5
+ from utils import (
6
+ plot_distances_tsne,
7
+ plot_distances_umap,
8
+ cluster_languages_hdbscan,
9
+ cluster_languages_kmeans,
10
+ plot_mst,
11
+ cluster_languages_by_families,
12
+ cluster_languages_by_subfamilies,
13
+ filter_languages_by_families,
14
+ )
15
  from functools import partial
16
  import datasets
17
 
18
 
19
+ dataset = datasets.load_dataset(
20
+ "mshamrai/language-metric-data", split="train", trust_remote_code=True
21
+ )
22
 
23
  languages = dataset["languages_list"][0]
24
  average_distances_matrix = np.array(dataset["average_distances_matrix"][0])
 
31
  MODELS[j]: np.array(dataset["distances_matrices"][0]["models"][i]["matrix"][j])
32
  for j in range(len(MODELS))
33
  }
34
+ for i in range(len(DATASETS))
35
  }
36
 
37
 
 
67
  sorted_distances["Distance"] = sorted_distances["Distance"].round(4)
68
  return sorted_distances.head(n)
69
 
70
+
71
  def update_languages(model, dataset):
72
  """
73
  Returns the language list based on the given model and dataset.
 
90
 
91
  def toggle_inputs(use_average):
92
  if use_average:
93
+ return gr.update(interactive=False, visible=False), gr.update(
94
+ interactive=False, visible=False
95
+ )
96
  else:
97
+ return gr.update(interactive=True, visible=True), gr.update(
98
+ interactive=True, visible=True
99
+ )
100
 
101
 
102
  plot_path = "plots/last_plot.pdf"
103
  os.makedirs("plots", exist_ok=True)
104
 
105
 
106
+ def plot_distances(
107
+ model, dataset, use_average, cluster_method, cluster_method_param, plot_fn
108
+ ):
109
  """
110
  Plots all languages from the distances matrix using t-SNE.
111
  """
112
 
113
+ updated_matrix, updated_languages = filter_languages_nan(
114
+ model, dataset, use_average
115
+ )
116
 
117
  if cluster_method == "HDBSCAN":
118
  filtered_matrix, filtered_languages, clusters = cluster_languages_hdbscan(
 
135
  else:
136
  raise ValueError("Invalid cluster method")
137
 
138
+ fig = plot_fn(
139
+ model,
140
+ dataset,
141
+ use_average,
142
+ filtered_matrix,
143
+ filtered_languages,
144
+ clusters,
145
+ legends,
146
+ )
147
  fig.tight_layout()
148
  fig.savefig(plot_path, format="pdf")
149
  return fig, gr.DownloadButton(label="Download Plot", value=plot_path)
150
 
151
 
152
+ def plot_families_subfamilies(
153
+ families, model, dataset, use_average, figsize_h, figsize_w
154
+ ):
155
+ updated_matrix, updated_languages = filter_languages_nan(
156
+ model, dataset, use_average
157
+ )
158
+ updated_matrix, updated_languages = filter_languages_by_families(
159
+ updated_matrix, updated_languages, families
160
+ )
161
 
162
  clusters, legends = cluster_languages_by_subfamilies(updated_languages)
163
+ fig = plot_mst(
164
+ model,
165
+ dataset,
166
+ use_average,
167
+ updated_matrix,
168
+ updated_languages,
169
+ clusters,
170
+ legends,
171
+ fig_size=(figsize_w, figsize_h),
172
+ )
173
  fig.tight_layout()
174
  fig.savefig(plot_path, format="pdf")
175
  return fig, gr.DownloadButton(label="Download Plot", value=plot_path)
 
181
  with gr.Row():
182
  model_input = gr.Dropdown(label="Model", choices=MODELS, value=MODELS[0])
183
  dataset_input = gr.Dropdown(
184
+ label="Dataset", choices=DATASETS, value=DATASETS[0]
 
 
185
  )
186
 
187
  with gr.Tab(label="Closest Languages Table"):
188
  with gr.Row():
189
+ language_input = gr.Dropdown(
190
+ label="Language", choices=languages, value=languages[0]
191
+ )
192
+ top_n_input = gr.Slider(
193
+ label="Top N", minimum=1, maximum=30, step=1, value=10
194
+ )
195
+
196
  output_table = gr.Dataframe(label="Similar Languages")
197
+
198
+ model_input.change(
199
+ fn=update_language_options,
200
+ inputs=[model_input, dataset_input, language_input, average_checkbox],
201
+ outputs=language_input,
202
+ )
203
+ dataset_input.change(
204
+ fn=update_language_options,
205
+ inputs=[model_input, dataset_input, language_input, average_checkbox],
206
+ outputs=language_input,
207
+ )
208
+ language_input.change(
209
+ fn=get_similar_languages,
210
+ inputs=[
211
+ model_input,
212
+ dataset_input,
213
+ language_input,
214
+ average_checkbox,
215
+ top_n_input,
216
+ ],
217
+ outputs=output_table,
218
+ )
219
+ model_input.change(
220
+ fn=get_similar_languages,
221
+ inputs=[
222
+ model_input,
223
+ dataset_input,
224
+ language_input,
225
+ average_checkbox,
226
+ top_n_input,
227
+ ],
228
+ outputs=output_table,
229
+ )
230
+ dataset_input.change(
231
+ fn=get_similar_languages,
232
+ inputs=[
233
+ model_input,
234
+ dataset_input,
235
+ language_input,
236
+ average_checkbox,
237
+ top_n_input,
238
+ ],
239
+ outputs=output_table,
240
+ )
241
+ top_n_input.change(
242
+ fn=get_similar_languages,
243
+ inputs=[
244
+ model_input,
245
+ dataset_input,
246
+ language_input,
247
+ average_checkbox,
248
+ top_n_input,
249
+ ],
250
+ outputs=output_table,
251
+ )
252
 
253
  average_checkbox.change(
254
  fn=toggle_inputs,
255
  inputs=[average_checkbox],
256
+ outputs=[model_input, dataset_input],
257
  )
258
 
259
+ average_checkbox.change(
260
+ fn=update_language_options,
261
+ inputs=[model_input, dataset_input, language_input, average_checkbox],
262
+ outputs=language_input,
263
+ )
264
+ average_checkbox.change(
265
+ fn=get_similar_languages,
266
+ inputs=[
267
+ model_input,
268
+ dataset_input,
269
+ language_input,
270
+ average_checkbox,
271
+ top_n_input,
272
+ ],
273
+ outputs=output_table,
274
+ )
275
 
276
  with gr.Tab(label="Distance Plot"):
277
  with gr.Row():
278
+ cluster_method_input = gr.Dropdown(
279
+ label="Cluster Method",
280
+ choices=["HDBSCAN", "KMeans", "Family", "Subfamily"],
281
+ value="HDBSCAN",
282
+ )
283
+ clusters_input = gr.Slider(
284
+ label="Minimum Elements in a Cluster",
285
+ minimum=2,
286
+ maximum=10,
287
+ step=1,
288
+ value=2,
289
+ )
290
 
291
  def update_clusters_input_option(cluster_method):
292
  if cluster_method == "HDBSCAN":
293
+ return gr.Slider(
294
+ label="Minimum Elements in a Cluster",
295
+ minimum=2,
296
+ maximum=10,
297
+ step=1,
298
+ value=2,
299
+ visible=True,
300
+ interactive=True,
301
+ )
302
  elif cluster_method == "KMeans":
303
+ return gr.Slider(
304
+ label="Number of Clusters",
305
+ minimum=2,
306
+ maximum=20,
307
+ step=1,
308
+ value=2,
309
+ visible=True,
310
+ interactive=True,
311
+ )
312
  else:
313
  return gr.update(interactive=False, visible=False)
314
 
315
+ cluster_method_input.change(
316
+ fn=update_clusters_input_option,
317
+ inputs=[cluster_method_input],
318
+ outputs=clusters_input,
319
+ )
320
+
321
  with gr.Row():
322
  plot_tsne_button = gr.Button("Plot t-SNE")
323
  plot_umap_button = gr.Button("Plot UMAP")
324
  plot_mst_button = gr.Button("Plot MST")
325
+
326
  with gr.Row():
327
  download_plot_button = gr.DownloadButton("Download Plot")
328
 
329
  with gr.Row():
330
  plot_output = gr.Plot(label="Distance Plot")
331
 
332
+ plot_tsne_button.click(
333
+ fn=partial(plot_distances, plot_fn=plot_distances_tsne),
334
+ inputs=[
335
+ model_input,
336
+ dataset_input,
337
+ average_checkbox,
338
+ cluster_method_input,
339
+ clusters_input,
340
+ ],
341
+ outputs=[plot_output, download_plot_button],
342
+ )
343
+ plot_umap_button.click(
344
+ fn=partial(plot_distances, plot_fn=plot_distances_umap),
345
+ inputs=[
346
+ model_input,
347
+ dataset_input,
348
+ average_checkbox,
349
+ cluster_method_input,
350
+ clusters_input,
351
+ ],
352
+ outputs=[plot_output, download_plot_button],
353
+ )
354
+ plot_mst_button.click(
355
+ fn=partial(plot_distances, plot_fn=plot_mst),
356
+ inputs=[
357
+ model_input,
358
+ dataset_input,
359
+ average_checkbox,
360
+ cluster_method_input,
361
+ clusters_input,
362
+ ],
363
+ outputs=[plot_output, download_plot_button],
364
+ )
365
 
366
  with gr.Tab(label="Language Families Subplot"):
367
+
368
+ checked_families_input = gr.CheckboxGroup(
369
+ label="Language Families",
370
+ choices=[
371
+ "Afroasiatic",
372
+ "Austroasiatic",
373
+ "Austronesian",
374
+ "Constructed",
375
+ "Creole",
376
+ "Dravidian",
377
+ "Germanic",
378
+ "Indo-European",
379
+ "Japonic",
380
+ "Kartvelian",
381
+ "Koreanic",
382
+ "Language Isolate",
383
+ "Niger-Congo",
384
+ "Northeast Caucasian",
385
+ "Romance",
386
+ "Sino-Tibetan",
387
+ "Turkic",
388
+ "Uralic",
389
+ ],
390
+ value=["Indo-European"],
391
+ )
392
  with gr.Row():
393
  plot_family_button = gr.Button("Plot Families")
394
+ plot_figsize_h_input = gr.Slider(
395
+ label="Figure Height", minimum=5, maximum=30, step=1, value=15
396
+ )
397
+ plot_figsize_w_input = gr.Slider(
398
+ label="Figure Width", minimum=5, maximum=30, step=1, value=15
399
+ )
400
 
401
  with gr.Row():
402
+ download_families_plot_button = gr.DownloadButton(
403
+ "Download Plot", value=plot_path
404
+ )
405
 
406
  plot_family_output = gr.Plot(label="Families Plot")
407
+
408
+ plot_family_button.click(
409
+ fn=plot_families_subfamilies,
410
+ inputs=[
411
+ checked_families_input,
412
+ model_input,
413
+ dataset_input,
414
+ average_checkbox,
415
+ plot_figsize_h_input,
416
+ plot_figsize_w_input,
417
+ ],
418
+ outputs=[plot_family_output, download_families_plot_button],
419
+ )
420
+
421
 
422
  demo.launch(share=True)
constants.py CHANGED
@@ -104,7 +104,7 @@ language_subfamilies = {
104
  "Western Punjabi": "Punjabi",
105
  "Yoruba": "Yoruboid",
106
  "Esperanto": "Constructed",
107
- "Crimean Tatar": "Kypchak"
108
  }
109
 
110
  language_families = {
@@ -213,5 +213,5 @@ language_families = {
213
  "Western Punjabi": "Indo-European",
214
  "Yoruba": "Niger-Congo",
215
  "Esperanto": "Constructed",
216
- "Crimean Tatar": "Turkic"
217
  }
 
104
  "Western Punjabi": "Punjabi",
105
  "Yoruba": "Yoruboid",
106
  "Esperanto": "Constructed",
107
+ "Crimean Tatar": "Kypchak",
108
  }
109
 
110
  language_families = {
 
213
  "Western Punjabi": "Indo-European",
214
  "Yoruba": "Niger-Congo",
215
  "Esperanto": "Constructed",
216
+ "Crimean Tatar": "Turkic",
217
  }
utils.py CHANGED
@@ -21,7 +21,11 @@ def filter_languages_by_families(matrix, languages, families):
21
  Returns:
22
  - filtered_languages: list of languages that belong to the specified families.
23
  """
24
- filtered_languages = [(i, lang) for i, lang in enumerate(languages) if language_families[lang] in families]
 
 
 
 
25
  filtered_indices = [i for i, lang in filtered_languages]
26
  filtered_languages = [lang for i, lang in filtered_languages]
27
  filtered_matrix = matrix[np.ix_(filtered_indices, filtered_indices)]
@@ -51,13 +55,25 @@ def cluster_languages_by_families(languages):
51
 
52
 
53
  def cluster_languages_by_subfamilies(languages):
54
- labels = [language_families[lang] + f" ({language_subfamilies[lang]})" for lang in languages]
 
 
 
55
  legend = sorted(set(labels))
56
  clusters = [legend.index(family) for family in labels]
57
  return clusters, legend
58
 
59
 
60
- def plot_mst(model, dataset, use_average, matrix, languages, clusters, legend=None, fig_size=(20,20)):
 
 
 
 
 
 
 
 
 
61
  """
62
  Plots a Minimum Spanning Tree (MST) from a given distance matrix, node labels, and cluster assignments.
63
 
@@ -68,21 +84,21 @@ def plot_mst(model, dataset, use_average, matrix, languages, clusters, legend=No
68
  """
69
  # Create an empty undirected graph
70
  G = nx.Graph()
71
-
72
  # Number of nodes
73
  N = len(languages)
74
-
75
  # Add edges to the graph from the distance matrix.
76
  # Only iterate over the upper triangle of the matrix (i < j)
77
  for i in range(N):
78
  for j in range(i + 1, N):
79
  G.add_edge(i, j, weight=matrix[i, j])
80
-
81
  # Compute the Minimum Spanning Tree using NetworkX's built-in function.
82
  mst = nx.minimum_spanning_tree(G)
83
-
84
  # Choose a layout for the MST. Here we use Kamada-Kawai layout which considers edge weights.
85
- pos = nx.kamada_kawai_layout(mst, weight='weight')
86
 
87
  # Map each cluster to a color
88
  unique_clusters = sorted(set(clusters))
@@ -90,22 +106,24 @@ def plot_mst(model, dataset, use_average, matrix, languages, clusters, legend=No
90
  cluster_colors = {cluster: cmap[i] for i, cluster in enumerate(unique_clusters)}
91
 
92
  node_colors = [cluster_colors.get(cluster) for cluster in clusters]
93
-
94
  # Create a figure for plotting.
95
  fig, ax = plt.subplots(figsize=fig_size)
96
-
97
  # Draw the MST edges.
98
- nx.draw_networkx_edges(mst, pos, edge_color='gray', ax=ax)
99
-
100
  # Draw the nodes with colors corresponding to their clusters.
101
- nx.draw_networkx_nodes(mst, pos, node_color=node_colors, node_size=100, ax=ax, alpha=0.7)
 
 
102
 
103
  # Instead of directly drawing labels, we create text objects to adjust them later
104
  texts = []
105
  for i, label in enumerate(languages):
106
  x, y = pos[i]
107
  texts.append(ax.text(x, y, label, fontsize=10))
108
-
109
  # Adjust text labels to minimize overlap.
110
  # The arrowprops argument can draw arrows from labels to nodes if desired.
111
  adjust_text(texts, expand_text=(1.05, 1.2))
@@ -114,17 +132,27 @@ def plot_mst(model, dataset, use_average, matrix, languages, clusters, legend=No
114
  if legend is None:
115
  legend = {cluster: str(cluster) for cluster in unique_clusters}
116
  legend_handles = [
117
- plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=cluster_colors[cluster], markersize=10, alpha=0.7, label=legend[cluster])
 
 
 
 
 
 
 
 
 
118
  for cluster in unique_clusters
119
  ]
120
  ax.legend(handles=legend_handles, title="Clusters", loc="best")
121
-
122
  # Remove axis for clarity.
123
- ax.axis('off')
124
  # ax.set_title(f"Minimum Spanning Tree of Languages ({'Average' if use_average else f'{model}, {dataset}'})")
125
 
126
  return fig
127
 
 
128
  def cluster_languages_kmeans(dist_matrix, languages, n_clusters=5):
129
  """
130
  Clusters languages using a distance matrix and KMeans.
@@ -172,9 +200,7 @@ def cluster_languages_hdbscan(dist_matrix, languages, min_cluster_size=2):
172
  - clusters: list of length N containing the cluster assignment (or ID) for each language.
173
  """
174
  # Perform clustering using HDBSCAN with the precomputed distance matrix
175
- clustering_model = HDBSCAN(
176
- metric='precomputed', min_cluster_size=min_cluster_size
177
- )
178
  clusters = clustering_model.fit_predict(dist_matrix)
179
 
180
  # Filter out points belonging to cluster -1 using NumPy
@@ -185,7 +211,9 @@ def cluster_languages_hdbscan(dist_matrix, languages, min_cluster_size=2):
185
  return filtered_matrix, filtered_languages, filtered_clusters
186
 
187
 
188
- def plot_distances_tsne(model, dataset, use_average, matrix, languages, clusters, legend=None):
 
 
189
  """
190
  Plots all languages from the distances matrix using t-SNE and colors them by clusters.
191
  """
@@ -198,7 +226,12 @@ def plot_distances_tsne(model, dataset, use_average, matrix, languages, clusters
198
  cluster_colors = {cluster: cmap[i] for i, cluster in enumerate(unique_clusters)}
199
 
200
  fig, ax = plt.subplots(figsize=(16, 12))
201
- scatter = ax.scatter(tsne_results[:, 0], tsne_results[:, 1], c=[cluster_colors[cluster] for cluster in clusters], alpha=0.7)
 
 
 
 
 
202
 
203
  # for i, lang in enumerate(languages):
204
  # ax.text(tsne_results[i, 0], tsne_results[i, 1], lang, fontsize=8, alpha=0.8)
@@ -208,7 +241,7 @@ def plot_distances_tsne(model, dataset, use_average, matrix, languages, clusters
208
  for i, label in enumerate(languages):
209
  x, y = tsne_results[i, 0], tsne_results[i, 1]
210
  texts.append(ax.text(x, y, label, fontsize=10))
211
-
212
  # Adjust text labels to minimize overlap.
213
  # The arrowprops argument can draw arrows from labels to nodes if desired.
214
  adjust_text(texts, expand_text=(1.05, 1.2))
@@ -217,18 +250,30 @@ def plot_distances_tsne(model, dataset, use_average, matrix, languages, clusters
217
  if legend is None:
218
  legend = {cluster: str(cluster) for cluster in unique_clusters}
219
  legend_handles = [
220
- plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=cluster_colors[cluster], markersize=10, label=legend[cluster])
 
 
 
 
 
 
 
 
221
  for cluster in unique_clusters
222
  ]
223
  ax.legend(handles=legend_handles, title="Clusters", loc="best")
224
 
225
- ax.set_title(f"t-SNE Visualization of Language Distances ({'Average' if use_average else f'{model}, {dataset}'})")
 
 
226
  ax.set_xlabel("t-SNE Dimension 1")
227
  ax.set_ylabel("t-SNE Dimension 2")
228
  return fig
229
 
230
 
231
- def plot_distances_umap(model, dataset, use_average, matrix, languages, clusters, legend=None):
 
 
232
  """
233
  Plots all languages from the distances matrix using UMAP and colors them by clusters.
234
  """
@@ -242,7 +287,12 @@ def plot_distances_umap(model, dataset, use_average, matrix, languages, clusters
242
  cluster_colors = {cluster: cmap[i] for i, cluster in enumerate(unique_clusters)}
243
 
244
  fig, ax = plt.subplots(figsize=(16, 12))
245
- scatter = ax.scatter(umap_results[:, 0], umap_results[:, 1], c=[cluster_colors[cluster] for cluster in clusters], alpha=0.7)
 
 
 
 
 
246
 
247
  # for i, lang in enumerate(languages):
248
  # ax.text(umap_results[i, 0], umap_results[i, 1], lang, fontsize=8, alpha=0.8)
@@ -252,7 +302,7 @@ def plot_distances_umap(model, dataset, use_average, matrix, languages, clusters
252
  for i, label in enumerate(languages):
253
  x, y = umap_results[i, 0], umap_results[i, 1]
254
  texts.append(ax.text(x, y, label, fontsize=10))
255
-
256
  # Adjust text labels to minimize overlap.
257
  # The arrowprops argument can draw arrows from labels to nodes if desired.
258
  adjust_text(texts, expand_text=(1.05, 1.2))
@@ -261,12 +311,22 @@ def plot_distances_umap(model, dataset, use_average, matrix, languages, clusters
261
  if legend is None:
262
  legend = {cluster: str(cluster) for cluster in unique_clusters}
263
  legend_handles = [
264
- plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=cluster_colors[cluster], markersize=10, label=legend[cluster])
 
 
 
 
 
 
 
 
265
  for cluster in unique_clusters
266
  ]
267
  ax.legend(handles=legend_handles, title="Clusters", loc="best")
268
 
269
- ax.set_title(f"UMAP Visualization of Language Distances ({'Average' if use_average else f'{model}, {dataset}'})")
 
 
270
  ax.set_xlabel("UMAP Dimension 1")
271
  ax.set_ylabel("UMAP Dimension 2")
272
- return fig
 
21
  Returns:
22
  - filtered_languages: list of languages that belong to the specified families.
23
  """
24
+ filtered_languages = [
25
+ (i, lang)
26
+ for i, lang in enumerate(languages)
27
+ if language_families[lang] in families
28
+ ]
29
  filtered_indices = [i for i, lang in filtered_languages]
30
  filtered_languages = [lang for i, lang in filtered_languages]
31
  filtered_matrix = matrix[np.ix_(filtered_indices, filtered_indices)]
 
55
 
56
 
57
  def cluster_languages_by_subfamilies(languages):
58
+ labels = [
59
+ language_families[lang] + f" ({language_subfamilies[lang]})"
60
+ for lang in languages
61
+ ]
62
  legend = sorted(set(labels))
63
  clusters = [legend.index(family) for family in labels]
64
  return clusters, legend
65
 
66
 
67
+ def plot_mst(
68
+ model,
69
+ dataset,
70
+ use_average,
71
+ matrix,
72
+ languages,
73
+ clusters,
74
+ legend=None,
75
+ fig_size=(20, 20),
76
+ ):
77
  """
78
  Plots a Minimum Spanning Tree (MST) from a given distance matrix, node labels, and cluster assignments.
79
 
 
84
  """
85
  # Create an empty undirected graph
86
  G = nx.Graph()
87
+
88
  # Number of nodes
89
  N = len(languages)
90
+
91
  # Add edges to the graph from the distance matrix.
92
  # Only iterate over the upper triangle of the matrix (i < j)
93
  for i in range(N):
94
  for j in range(i + 1, N):
95
  G.add_edge(i, j, weight=matrix[i, j])
96
+
97
  # Compute the Minimum Spanning Tree using NetworkX's built-in function.
98
  mst = nx.minimum_spanning_tree(G)
99
+
100
  # Choose a layout for the MST. Here we use Kamada-Kawai layout which considers edge weights.
101
+ pos = nx.kamada_kawai_layout(mst, weight="weight")
102
 
103
  # Map each cluster to a color
104
  unique_clusters = sorted(set(clusters))
 
106
  cluster_colors = {cluster: cmap[i] for i, cluster in enumerate(unique_clusters)}
107
 
108
  node_colors = [cluster_colors.get(cluster) for cluster in clusters]
109
+
110
  # Create a figure for plotting.
111
  fig, ax = plt.subplots(figsize=fig_size)
112
+
113
  # Draw the MST edges.
114
+ nx.draw_networkx_edges(mst, pos, edge_color="gray", ax=ax)
115
+
116
  # Draw the nodes with colors corresponding to their clusters.
117
+ nx.draw_networkx_nodes(
118
+ mst, pos, node_color=node_colors, node_size=100, ax=ax, alpha=0.7
119
+ )
120
 
121
  # Instead of directly drawing labels, we create text objects to adjust them later
122
  texts = []
123
  for i, label in enumerate(languages):
124
  x, y = pos[i]
125
  texts.append(ax.text(x, y, label, fontsize=10))
126
+
127
  # Adjust text labels to minimize overlap.
128
  # The arrowprops argument can draw arrows from labels to nodes if desired.
129
  adjust_text(texts, expand_text=(1.05, 1.2))
 
132
  if legend is None:
133
  legend = {cluster: str(cluster) for cluster in unique_clusters}
134
  legend_handles = [
135
+ plt.Line2D(
136
+ [0],
137
+ [0],
138
+ marker="o",
139
+ color="w",
140
+ markerfacecolor=cluster_colors[cluster],
141
+ markersize=10,
142
+ alpha=0.7,
143
+ label=legend[cluster],
144
+ )
145
  for cluster in unique_clusters
146
  ]
147
  ax.legend(handles=legend_handles, title="Clusters", loc="best")
148
+
149
  # Remove axis for clarity.
150
+ ax.axis("off")
151
  # ax.set_title(f"Minimum Spanning Tree of Languages ({'Average' if use_average else f'{model}, {dataset}'})")
152
 
153
  return fig
154
 
155
+
156
  def cluster_languages_kmeans(dist_matrix, languages, n_clusters=5):
157
  """
158
  Clusters languages using a distance matrix and KMeans.
 
200
  - clusters: list of length N containing the cluster assignment (or ID) for each language.
201
  """
202
  # Perform clustering using HDBSCAN with the precomputed distance matrix
203
+ clustering_model = HDBSCAN(metric="precomputed", min_cluster_size=min_cluster_size)
 
 
204
  clusters = clustering_model.fit_predict(dist_matrix)
205
 
206
  # Filter out points belonging to cluster -1 using NumPy
 
211
  return filtered_matrix, filtered_languages, filtered_clusters
212
 
213
 
214
+ def plot_distances_tsne(
215
+ model, dataset, use_average, matrix, languages, clusters, legend=None
216
+ ):
217
  """
218
  Plots all languages from the distances matrix using t-SNE and colors them by clusters.
219
  """
 
226
  cluster_colors = {cluster: cmap[i] for i, cluster in enumerate(unique_clusters)}
227
 
228
  fig, ax = plt.subplots(figsize=(16, 12))
229
+ scatter = ax.scatter(
230
+ tsne_results[:, 0],
231
+ tsne_results[:, 1],
232
+ c=[cluster_colors[cluster] for cluster in clusters],
233
+ alpha=0.7,
234
+ )
235
 
236
  # for i, lang in enumerate(languages):
237
  # ax.text(tsne_results[i, 0], tsne_results[i, 1], lang, fontsize=8, alpha=0.8)
 
241
  for i, label in enumerate(languages):
242
  x, y = tsne_results[i, 0], tsne_results[i, 1]
243
  texts.append(ax.text(x, y, label, fontsize=10))
244
+
245
  # Adjust text labels to minimize overlap.
246
  # The arrowprops argument can draw arrows from labels to nodes if desired.
247
  adjust_text(texts, expand_text=(1.05, 1.2))
 
250
  if legend is None:
251
  legend = {cluster: str(cluster) for cluster in unique_clusters}
252
  legend_handles = [
253
+ plt.Line2D(
254
+ [0],
255
+ [0],
256
+ marker="o",
257
+ color="w",
258
+ markerfacecolor=cluster_colors[cluster],
259
+ markersize=10,
260
+ label=legend[cluster],
261
+ )
262
  for cluster in unique_clusters
263
  ]
264
  ax.legend(handles=legend_handles, title="Clusters", loc="best")
265
 
266
+ ax.set_title(
267
+ f"t-SNE Visualization of Language Distances ({'Average' if use_average else f'{model}, {dataset}'})"
268
+ )
269
  ax.set_xlabel("t-SNE Dimension 1")
270
  ax.set_ylabel("t-SNE Dimension 2")
271
  return fig
272
 
273
 
274
+ def plot_distances_umap(
275
+ model, dataset, use_average, matrix, languages, clusters, legend=None
276
+ ):
277
  """
278
  Plots all languages from the distances matrix using UMAP and colors them by clusters.
279
  """
 
287
  cluster_colors = {cluster: cmap[i] for i, cluster in enumerate(unique_clusters)}
288
 
289
  fig, ax = plt.subplots(figsize=(16, 12))
290
+ scatter = ax.scatter(
291
+ umap_results[:, 0],
292
+ umap_results[:, 1],
293
+ c=[cluster_colors[cluster] for cluster in clusters],
294
+ alpha=0.7,
295
+ )
296
 
297
  # for i, lang in enumerate(languages):
298
  # ax.text(umap_results[i, 0], umap_results[i, 1], lang, fontsize=8, alpha=0.8)
 
302
  for i, label in enumerate(languages):
303
  x, y = umap_results[i, 0], umap_results[i, 1]
304
  texts.append(ax.text(x, y, label, fontsize=10))
305
+
306
  # Adjust text labels to minimize overlap.
307
  # The arrowprops argument can draw arrows from labels to nodes if desired.
308
  adjust_text(texts, expand_text=(1.05, 1.2))
 
311
  if legend is None:
312
  legend = {cluster: str(cluster) for cluster in unique_clusters}
313
  legend_handles = [
314
+ plt.Line2D(
315
+ [0],
316
+ [0],
317
+ marker="o",
318
+ color="w",
319
+ markerfacecolor=cluster_colors[cluster],
320
+ markersize=10,
321
+ label=legend[cluster],
322
+ )
323
  for cluster in unique_clusters
324
  ]
325
  ax.legend(handles=legend_handles, title="Clusters", loc="best")
326
 
327
+ ax.set_title(
328
+ f"UMAP Visualization of Language Distances ({'Average' if use_average else f'{model}, {dataset}'})"
329
+ )
330
  ax.set_xlabel("UMAP Dimension 1")
331
  ax.set_ylabel("UMAP Dimension 2")
332
+ return fig