mshamrai commited on
Commit
f129974
·
1 Parent(s): 4a784da

chore: add download plot button

Browse files
Files changed (1) hide show
  1. app.py +27 -22
app.py CHANGED
@@ -92,15 +92,15 @@ def toggle_inputs(use_average):
92
  else:
93
  return gr.update(interactive=True, visible=True), gr.update(interactive=True, visible=True)
94
 
95
- i = 0
 
 
96
 
97
  def plot_distances(model, dataset, use_average, cluster_method, cluster_method_param, plot_fn):
98
  """
99
  Plots all languages from the distances matrix using t-SNE.
100
  """
101
 
102
- global i
103
-
104
  updated_matrix, updated_languages = filter_languages_nan(model, dataset, use_average)
105
 
106
  if cluster_method == "HDBSCAN":
@@ -126,9 +126,19 @@ def plot_distances(model, dataset, use_average, cluster_method, cluster_method_p
126
 
127
  fig = plot_fn(model, dataset, use_average, filtered_matrix, filtered_languages, clusters, legends)
128
  fig.tight_layout()
129
- fig.savefig(f"plots/plot_{i}.pdf", format="pdf")
130
- i += 1
131
- return fig
 
 
 
 
 
 
 
 
 
 
132
 
133
 
134
  with gr.Blocks() as demo:
@@ -185,19 +195,22 @@ with gr.Blocks() as demo:
185
  plot_tsne_button = gr.Button("Plot t-SNE")
186
  plot_umap_button = gr.Button("Plot UMAP")
187
  plot_mst_button = gr.Button("Plot MST")
 
 
 
188
 
189
  with gr.Row():
190
  plot_output = gr.Plot(label="Distance Plot")
191
 
192
  plot_tsne_button.click(fn=partial(plot_distances, plot_fn=plot_distances_tsne),
193
  inputs=[model_input, dataset_input, average_checkbox, cluster_method_input, clusters_input],
194
- outputs=plot_output)
195
  plot_umap_button.click(fn=partial(plot_distances, plot_fn=plot_distances_umap),
196
  inputs=[model_input, dataset_input, average_checkbox, cluster_method_input, clusters_input],
197
- outputs=plot_output)
198
  plot_mst_button.click(fn=partial(plot_distances, plot_fn=plot_mst),
199
  inputs=[model_input, dataset_input, average_checkbox, cluster_method_input, clusters_input],
200
- outputs=plot_output)
201
 
202
  with gr.Tab(label="Language Families Subplot"):
203
 
@@ -227,23 +240,15 @@ with gr.Blocks() as demo:
227
  plot_family_button = gr.Button("Plot Families")
228
  plot_figsize_h_input = gr.Slider(label="Figure Height", minimum=5, maximum=30, step=1, value=15)
229
  plot_figsize_w_input = gr.Slider(label="Figure Width", minimum=5, maximum=30, step=1, value=15)
 
 
 
 
230
  plot_family_output = gr.Plot(label="Families Plot")
231
- def plot_families_subfamilies(families, model, dataset, use_average, figsize_h, figsize_w):
232
- global i
233
-
234
- updated_matrix, updated_languages = filter_languages_nan(model, dataset, use_average)
235
- updated_matrix, updated_languages = filter_languages_by_families(updated_matrix, updated_languages, families)
236
-
237
- clusters, legends = cluster_languages_by_subfamilies(updated_languages)
238
- fig = plot_mst(model, dataset, use_average, updated_matrix, updated_languages, clusters, legends, fig_size=(figsize_w, figsize_h))
239
- fig.tight_layout()
240
- fig.savefig(f"plots/plot_{i}.pdf", format="pdf")
241
- i += 1
242
- return fig
243
 
244
  plot_family_button.click(fn=plot_families_subfamilies,
245
  inputs=[checked_families_input, model_input, dataset_input, average_checkbox, plot_figsize_h_input, plot_figsize_w_input],
246
- outputs=plot_family_output)
247
 
248
 
249
  demo.launch(share=True)
 
92
  else:
93
  return gr.update(interactive=True, visible=True), gr.update(interactive=True, visible=True)
94
 
95
+
96
+ plot_path = "plots/last_plot.pdf"
97
+
98
 
99
  def plot_distances(model, dataset, use_average, cluster_method, cluster_method_param, plot_fn):
100
  """
101
  Plots all languages from the distances matrix using t-SNE.
102
  """
103
 
 
 
104
  updated_matrix, updated_languages = filter_languages_nan(model, dataset, use_average)
105
 
106
  if cluster_method == "HDBSCAN":
 
126
 
127
  fig = plot_fn(model, dataset, use_average, filtered_matrix, filtered_languages, clusters, legends)
128
  fig.tight_layout()
129
+ fig.savefig(plot_path, format="pdf")
130
+ return fig, gr.DownloadButton(label="Download Plot", value=plot_path)
131
+
132
+
133
+ def plot_families_subfamilies(families, model, dataset, use_average, figsize_h, figsize_w):
134
+ updated_matrix, updated_languages = filter_languages_nan(model, dataset, use_average)
135
+ updated_matrix, updated_languages = filter_languages_by_families(updated_matrix, updated_languages, families)
136
+
137
+ clusters, legends = cluster_languages_by_subfamilies(updated_languages)
138
+ fig = plot_mst(model, dataset, use_average, updated_matrix, updated_languages, clusters, legends, fig_size=(figsize_w, figsize_h))
139
+ fig.tight_layout()
140
+ fig.savefig(plot_path, format="pdf")
141
+ return fig, gr.DownloadButton(label="Download Plot", value=plot_path)
142
 
143
 
144
  with gr.Blocks() as demo:
 
195
  plot_tsne_button = gr.Button("Plot t-SNE")
196
  plot_umap_button = gr.Button("Plot UMAP")
197
  plot_mst_button = gr.Button("Plot MST")
198
+
199
+ with gr.Row():
200
+ download_plot_button = gr.DownloadButton("Download Plot")
201
 
202
  with gr.Row():
203
  plot_output = gr.Plot(label="Distance Plot")
204
 
205
  plot_tsne_button.click(fn=partial(plot_distances, plot_fn=plot_distances_tsne),
206
  inputs=[model_input, dataset_input, average_checkbox, cluster_method_input, clusters_input],
207
+ outputs=[plot_output, download_plot_button])
208
  plot_umap_button.click(fn=partial(plot_distances, plot_fn=plot_distances_umap),
209
  inputs=[model_input, dataset_input, average_checkbox, cluster_method_input, clusters_input],
210
+ outputs=[plot_output, download_plot_button])
211
  plot_mst_button.click(fn=partial(plot_distances, plot_fn=plot_mst),
212
  inputs=[model_input, dataset_input, average_checkbox, cluster_method_input, clusters_input],
213
+ outputs=[plot_output, download_plot_button])
214
 
215
  with gr.Tab(label="Language Families Subplot"):
216
 
 
240
  plot_family_button = gr.Button("Plot Families")
241
  plot_figsize_h_input = gr.Slider(label="Figure Height", minimum=5, maximum=30, step=1, value=15)
242
  plot_figsize_w_input = gr.Slider(label="Figure Width", minimum=5, maximum=30, step=1, value=15)
243
+
244
+ with gr.Row():
245
+ download_families_plot_button = gr.DownloadButton("Download Plot", value=plot_path)
246
+
247
  plot_family_output = gr.Plot(label="Families Plot")
 
 
 
 
 
 
 
 
 
 
 
 
248
 
249
  plot_family_button.click(fn=plot_families_subfamilies,
250
  inputs=[checked_families_input, model_input, dataset_input, average_checkbox, plot_figsize_h_input, plot_figsize_w_input],
251
+ outputs=[plot_family_output, download_families_plot_button])
252
 
253
 
254
  demo.launch(share=True)