mshamrai commited on
Commit
69dd8d4
·
1 Parent(s): 8b0cc53

chore: add fisize

Browse files
Files changed (2) hide show
  1. app.py +19 -1
  2. utils.py +18 -4
app.py CHANGED
@@ -104,7 +104,14 @@ os.makedirs("plots", exist_ok=True)
104
 
105
 
106
  def plot_distances(
107
- model, dataset, use_average, cluster_method, cluster_method_param, plot_fn
 
 
 
 
 
 
 
108
  ):
109
  """
110
  Plots all languages from the distances matrix using t-SNE.
@@ -143,6 +150,7 @@ def plot_distances(
143
  filtered_languages,
144
  clusters,
145
  legends,
 
146
  )
147
  fig.tight_layout()
148
  fig.savefig(plot_path, format="pdf")
@@ -323,6 +331,14 @@ with gr.Blocks() as demo:
323
  plot_umap_button = gr.Button("Plot UMAP")
324
  plot_mst_button = gr.Button("Plot MST")
325
 
 
 
 
 
 
 
 
 
326
  with gr.Row():
327
  download_plot_button = gr.DownloadButton("Download Plot")
328
 
@@ -359,6 +375,8 @@ with gr.Blocks() as demo:
359
  average_checkbox,
360
  cluster_method_input,
361
  clusters_input,
 
 
362
  ],
363
  outputs=[plot_output, download_plot_button],
364
  )
 
104
 
105
 
106
  def plot_distances(
107
+ model,
108
+ dataset,
109
+ use_average,
110
+ cluster_method,
111
+ cluster_method_param,
112
+ plot_fn,
113
+ figsize_h,
114
+ figsize_w,
115
  ):
116
  """
117
  Plots all languages from the distances matrix using t-SNE.
 
150
  filtered_languages,
151
  clusters,
152
  legends,
153
+ fig_size=(figsize_w, figsize_h),
154
  )
155
  fig.tight_layout()
156
  fig.savefig(plot_path, format="pdf")
 
331
  plot_umap_button = gr.Button("Plot UMAP")
332
  plot_mst_button = gr.Button("Plot MST")
333
 
334
+ with gr.Row():
335
+ plot_figsize_dist_h_input = gr.Slider(
336
+ label="Figure Height", minimum=5, maximum=30, step=1, value=15
337
+ )
338
+ plot_figsize_dist_w_input = gr.Slider(
339
+ label="Figure Width", minimum=5, maximum=30, step=1, value=15
340
+ )
341
+
342
  with gr.Row():
343
  download_plot_button = gr.DownloadButton("Download Plot")
344
 
 
375
  average_checkbox,
376
  cluster_method_input,
377
  clusters_input,
378
+ plot_figsize_dist_h_input,
379
+ plot_figsize_dist_w_input,
380
  ],
381
  outputs=[plot_output, download_plot_button],
382
  )
utils.py CHANGED
@@ -212,7 +212,14 @@ def cluster_languages_hdbscan(dist_matrix, languages, min_cluster_size=2):
212
 
213
 
214
  def plot_distances_tsne(
215
- model, dataset, use_average, matrix, languages, clusters, legend=None
 
 
 
 
 
 
 
216
  ):
217
  """
218
  Plots all languages from the distances matrix using t-SNE and colors them by clusters.
@@ -225,7 +232,7 @@ def plot_distances_tsne(
225
  cmap = get_dynamic_color_map(len(unique_clusters))
226
  cluster_colors = {cluster: cmap[i] for i, cluster in enumerate(unique_clusters)}
227
 
228
- fig, ax = plt.subplots(figsize=(16, 12))
229
  scatter = ax.scatter(
230
  tsne_results[:, 0],
231
  tsne_results[:, 1],
@@ -272,7 +279,14 @@ def plot_distances_tsne(
272
 
273
 
274
  def plot_distances_umap(
275
- model, dataset, use_average, matrix, languages, clusters, legend=None
 
 
 
 
 
 
 
276
  ):
277
  """
278
  Plots all languages from the distances matrix using UMAP and colors them by clusters.
@@ -286,7 +300,7 @@ def plot_distances_umap(
286
  cmap = get_dynamic_color_map(len(unique_clusters))
287
  cluster_colors = {cluster: cmap[i] for i, cluster in enumerate(unique_clusters)}
288
 
289
- fig, ax = plt.subplots(figsize=(16, 12))
290
  scatter = ax.scatter(
291
  umap_results[:, 0],
292
  umap_results[:, 1],
 
212
 
213
 
214
  def plot_distances_tsne(
215
+ model,
216
+ dataset,
217
+ use_average,
218
+ matrix,
219
+ languages,
220
+ clusters,
221
+ legend=None,
222
+ fig_size=(16, 12),
223
  ):
224
  """
225
  Plots all languages from the distances matrix using t-SNE and colors them by clusters.
 
232
  cmap = get_dynamic_color_map(len(unique_clusters))
233
  cluster_colors = {cluster: cmap[i] for i, cluster in enumerate(unique_clusters)}
234
 
235
+ fig, ax = plt.subplots(figsize=fig_size)
236
  scatter = ax.scatter(
237
  tsne_results[:, 0],
238
  tsne_results[:, 1],
 
279
 
280
 
281
  def plot_distances_umap(
282
+ model,
283
+ dataset,
284
+ use_average,
285
+ matrix,
286
+ languages,
287
+ clusters,
288
+ legend=None,
289
+ fig_size=(16, 12),
290
  ):
291
  """
292
  Plots all languages from the distances matrix using UMAP and colors them by clusters.
 
300
  cmap = get_dynamic_color_map(len(unique_clusters))
301
  cluster_colors = {cluster: cmap[i] for i, cluster in enumerate(unique_clusters)}
302
 
303
+ fig, ax = plt.subplots(figsize=fig_size)
304
  scatter = ax.scatter(
305
  umap_results[:, 0],
306
  umap_results[:, 1],