"""Gradio app for exploring pairwise language distances.

Loads precomputed distance matrices from the ``mshamrai/language-metric-data``
dataset and exposes three views: a table of the closest languages, a 2-D
distance plot (t-SNE / UMAP / MST) with optional clustering, and an MST plot
restricted to selected language families.
"""

import os
from functools import partial

import datasets
import gradio as gr
import numpy as np
import pandas as pd

from utils import (
    plot_distances_tsne,
    plot_distances_umap,
    cluster_languages_hdbscan,
    cluster_languages_kmeans,
    plot_mst,
    cluster_languages_by_families,
    cluster_languages_by_subfamilies,
    filter_languages_by_families,
)

dataset = datasets.load_dataset(
    "mshamrai/language-metric-data", split="train", trust_remote_code=True
)

languages = dataset["languages_list"][0]
average_distances_matrix = np.array(dataset["average_distances_matrix"][0])

# Look the nested record up once instead of once per matrix.
_matrices_record = dataset["distances_matrices"][0]
DATASETS = _matrices_record["dataset_name"]
# NOTE(review): model names are read from the first dataset only; this assumes
# every dataset lists the same models in the same order -- confirm.
MODELS = _matrices_record["models"][0]["model_name"]

# distance_matrices[dataset_name][model_name] -> square distance matrix.
distance_matrices = {
    dataset_name: {
        model_name: np.array(_matrices_record["models"][i]["matrix"][j])
        for j, model_name in enumerate(MODELS)
    }
    for i, dataset_name in enumerate(DATASETS)
}


def _select_matrix(model, dataset, use_average):
    """Return the averaged matrix, or the one for ``dataset``/``model``."""
    if use_average:
        return average_distances_matrix
    return distance_matrices[dataset][model]


def _valid_language_mask(matrix):
    """Return a boolean mask of languages whose distance in row 0 is not NaN.

    NOTE(review): validity is judged from the first row only, so this presumes
    the first language is itself present in the matrix -- confirm.
    """
    return ~np.isnan(matrix[0])


def filter_languages_nan(model, dataset, use_average):
    """Drop languages that have no distance data in the selected matrix.

    Args:
        model: Model name (ignored when ``use_average`` is true).
        dataset: Dataset name (ignored when ``use_average`` is true).
        use_average: Use the averaged matrix instead of a per-model one.

    Returns:
        ``(matrix, languages)`` where both the rows and columns of ``matrix``
        and the ``languages`` array are restricted to the valid languages.
    """
    matrix = _select_matrix(model, dataset, use_average)
    keep = _valid_language_mask(matrix)
    updated_languages = np.array(languages)[keep]
    updated_matrix = matrix[keep, :][:, keep]
    return updated_matrix, updated_languages


def get_similar_languages(model, dataset, selected_language, use_average, n):
    """Return the ``n`` languages closest to ``selected_language``.

    The selected language itself and any language without a distance (NaN)
    are excluded. Rows are sorted by ascending distance and distances are
    rounded to four decimals.

    Args:
        model: Model name (ignored when ``use_average`` is true).
        dataset: Dataset name (ignored when ``use_average`` is true).
        selected_language: Language to compare against.
        use_average: Use the averaged matrix instead of a per-model one.
        n: Maximum number of rows to return.

    Returns:
        A DataFrame with columns ``index`` (0-based rank), ``Language`` and
        ``Distance``. It is empty when ``selected_language`` is not a known
        language (for example when the dropdown has been cleared).
    """
    if selected_language not in languages:
        return pd.DataFrame(columns=["index", "Language", "Distance"])

    matrix = _select_matrix(model, dataset, use_average)
    selected_language_index = languages.index(selected_language)
    df = pd.DataFrame(
        {"Language": languages, "Distance": matrix[selected_language_index]}
    )

    ranked = (
        df.sort_values(by="Distance")
        # sort_values keeps the original labels, so this removes the
        # selected language's own row.
        .drop(index=selected_language_index)
        # A NaN distance means "no data", not "similar".
        .dropna(subset=["Distance"])
        .reset_index(drop=True)
        # Second reset exposes the rank as a visible "index" column.
        .reset_index()
    )
    ranked["Distance"] = ranked["Distance"].round(4)
    # Slider values may arrive as floats; head() needs an integer.
    return ranked.head(int(n))


def update_languages(model, dataset):
    """Return the list of languages that have data for ``model``/``dataset``."""
    matrix = distance_matrices[dataset][model]
    return list(np.array(languages)[_valid_language_mask(matrix)])


def update_language_options(model, dataset, language, use_average):
    """Rebuild the language dropdown for the current model/dataset selection.

    The current ``language`` is kept when it is still available; otherwise the
    first available language is selected (or nothing if none is available).
    """
    if use_average:
        updated_languages = languages
    else:
        updated_languages = update_languages(model, dataset)

    if language not in updated_languages:
        language = updated_languages[0] if len(updated_languages) else None

    return gr.Dropdown(label="Language", choices=updated_languages, value=language)


def toggle_inputs(use_average):
    """Hide and disable the model/dataset dropdowns while averaging is on."""
    enabled = not use_average
    return (
        gr.update(interactive=enabled, visible=enabled),
        gr.update(interactive=enabled, visible=enabled),
    )


# NOTE(review): every plot is written to this single path, so concurrent
# sessions overwrite each other's download file.
plot_path = "plots/last_plot.pdf"
os.makedirs("plots", exist_ok=True)


def _finalize_plot(fig):
    """Save ``fig`` as a PDF at ``plot_path`` and return the Gradio outputs."""
    fig.tight_layout()
    fig.savefig(plot_path, format="pdf")
    return fig, gr.DownloadButton(label="Download Plot", value=plot_path)


def plot_distances(
    model,
    dataset,
    use_average,
    cluster_method,
    cluster_method_param,
    figsize_h,
    figsize_w,
    plot_fn,
):
    """Cluster the languages and draw them with ``plot_fn``.

    Args:
        model: Model name (ignored when ``use_average`` is true).
        dataset: Dataset name (ignored when ``use_average`` is true).
        use_average: Use the averaged matrix instead of a per-model one.
        cluster_method: One of ``"HDBSCAN"``, ``"KMeans"``, ``"Family"`` or
            ``"Subfamily"``.
        cluster_method_param: Minimum cluster size for HDBSCAN or number of
            clusters for KMeans; unused for the family-based methods.
        figsize_h: Figure height.
        figsize_w: Figure width.
        plot_fn: Plotting function (t-SNE, UMAP or MST) called as
            ``plot_fn(matrix, languages, clusters, legends, fig_size=...)``.

    Returns:
        ``(figure, download_button)`` for the plot and download components.

    Raises:
        ValueError: If ``cluster_method`` is not one of the supported values.
    """
    updated_matrix, updated_languages = filter_languages_nan(
        model, dataset, use_average
    )

    if cluster_method == "HDBSCAN":
        # The slider value may be a float; the clusterers expect a count.
        filtered_matrix, filtered_languages, clusters = cluster_languages_hdbscan(
            updated_matrix,
            updated_languages,
            min_cluster_size=int(cluster_method_param),
        )
        legends = None
    elif cluster_method == "KMeans":
        filtered_matrix, filtered_languages, clusters = cluster_languages_kmeans(
            updated_matrix,
            updated_languages,
            n_clusters=int(cluster_method_param),
        )
        legends = None
    elif cluster_method == "Family":
        clusters, legends = cluster_languages_by_families(updated_languages)
        filtered_matrix = updated_matrix
        filtered_languages = updated_languages
    elif cluster_method == "Subfamily":
        clusters, legends = cluster_languages_by_subfamilies(updated_languages)
        filtered_matrix = updated_matrix
        filtered_languages = updated_languages
    else:
        raise ValueError("Invalid cluster method")

    fig = plot_fn(
        filtered_matrix,
        filtered_languages,
        clusters,
        legends,
        fig_size=(figsize_w, figsize_h),
    )
    return _finalize_plot(fig)


def plot_families_subfamilies(
    families, model, dataset, use_average, figsize_h, figsize_w
):
    """Draw an MST of the languages in ``families``, coloured by subfamily.

    Args:
        families: Language family names to keep.
        model: Model name (ignored when ``use_average`` is true).
        dataset: Dataset name (ignored when ``use_average`` is true).
        use_average: Use the averaged matrix instead of a per-model one.
        figsize_h: Figure height.
        figsize_w: Figure width.

    Returns:
        ``(figure, download_button)`` for the plot and download components.
    """
    updated_matrix, updated_languages = filter_languages_nan(
        model, dataset, use_average
    )
    updated_matrix, updated_languages = filter_languages_by_families(
        updated_matrix, updated_languages, families
    )
    clusters, legends = cluster_languages_by_subfamilies(updated_languages)
    fig = plot_mst(
        updated_matrix,
        updated_languages,
        clusters,
        legends,
        fig_size=(figsize_w, figsize_h),
    )
    return _finalize_plot(fig)


# NOTE(review): the layout nesting below was reconstructed from a
# whitespace-collapsed source; verify row/tab membership against the UI.
with gr.Blocks() as demo:
    gr.Markdown("## Language Distance Explorer")

    average_checkbox = gr.Checkbox(label="Use Average Distances", value=False)

    with gr.Row():
        model_input = gr.Dropdown(label="Model", choices=MODELS, value=MODELS[0])
        dataset_input = gr.Dropdown(
            label="Dataset", choices=DATASETS, value=DATASETS[0]
        )

    with gr.Tab(label="Closest Languages Table"):
        with gr.Row():
            language_input = gr.Dropdown(
                label="Language", choices=languages, value=languages[0]
            )
            top_n_input = gr.Slider(
                label="Top N", minimum=1, maximum=30, step=1, value=10
            )

        output_table = gr.Dataframe(label="Similar Languages")

        language_option_inputs = [
            model_input,
            dataset_input,
            language_input,
            average_checkbox,
        ]
        similar_languages_inputs = [
            model_input,
            dataset_input,
            language_input,
            average_checkbox,
            top_n_input,
        ]

        # Changing the model or dataset may invalidate the current language.
        for trigger in (model_input, dataset_input):
            trigger.change(
                fn=update_language_options,
                inputs=language_option_inputs,
                outputs=language_input,
            )

        # Any input that affects the ranking refreshes the table.
        for trigger in (language_input, model_input, dataset_input, top_n_input):
            trigger.change(
                fn=get_similar_languages,
                inputs=similar_languages_inputs,
                outputs=output_table,
            )

        average_checkbox.change(
            fn=toggle_inputs,
            inputs=[average_checkbox],
            outputs=[model_input, dataset_input],
        )
        average_checkbox.change(
            fn=update_language_options,
            inputs=language_option_inputs,
            outputs=language_input,
        )
        average_checkbox.change(
            fn=get_similar_languages,
            inputs=similar_languages_inputs,
            outputs=output_table,
        )

    with gr.Tab(label="Distance Plot"):
        with gr.Row():
            cluster_method_input = gr.Dropdown(
                label="Cluster Method",
                choices=["HDBSCAN", "KMeans", "Family", "Subfamily"],
                value="HDBSCAN",
            )
            clusters_input = gr.Slider(
                label="Minimum Elements in a Cluster",
                minimum=2,
                maximum=10,
                step=1,
                value=2,
            )

        def update_clusters_input_option(cluster_method):
            """Reconfigure the cluster-parameter slider for ``cluster_method``.

            HDBSCAN and KMeans each get their own label and range; the
            family-based methods take no parameter, so the slider is hidden.
            """
            if cluster_method == "HDBSCAN":
                return gr.Slider(
                    label="Minimum Elements in a Cluster",
                    minimum=2,
                    maximum=10,
                    step=1,
                    value=2,
                    visible=True,
                    interactive=True,
                )
            elif cluster_method == "KMeans":
                return gr.Slider(
                    label="Number of Clusters",
                    minimum=2,
                    maximum=20,
                    step=1,
                    value=2,
                    visible=True,
                    interactive=True,
                )
            else:
                return gr.update(interactive=False, visible=False)

        cluster_method_input.change(
            fn=update_clusters_input_option,
            inputs=[cluster_method_input],
            outputs=clusters_input,
        )

        with gr.Row():
            plot_tsne_button = gr.Button("Plot t-SNE")
            plot_umap_button = gr.Button("Plot UMAP")
            plot_mst_button = gr.Button("Plot MST")

        with gr.Row():
            plot_figsize_dist_h_input = gr.Slider(
                label="Figure Height", minimum=5, maximum=30, step=1, value=15
            )
            plot_figsize_dist_w_input = gr.Slider(
                label="Figure Width", minimum=5, maximum=30, step=1, value=15
            )

        with gr.Row():
            download_plot_button = gr.DownloadButton("Download Plot")

        with gr.Row():
            plot_output = gr.Plot(label="Distance Plot")

        distance_plot_inputs = [
            model_input,
            dataset_input,
            average_checkbox,
            cluster_method_input,
            clusters_input,
            plot_figsize_dist_h_input,
            plot_figsize_dist_w_input,
        ]

        # The three buttons differ only in the plotting function they bind.
        for button, plot_fn in (
            (plot_tsne_button, plot_distances_tsne),
            (plot_umap_button, plot_distances_umap),
            (plot_mst_button, plot_mst),
        ):
            button.click(
                fn=partial(plot_distances, plot_fn=plot_fn),
                inputs=distance_plot_inputs,
                outputs=[plot_output, download_plot_button],
            )

    with gr.Tab(label="Language Families Subplot"):
        checked_families_input = gr.CheckboxGroup(
            label="Language Families",
            choices=[
                "Afroasiatic",
                "Austroasiatic",
                "Austronesian",
                "Constructed",
                "Creole",
                "Dravidian",
                "Germanic",
                "Indo-European",
                "Japonic",
                "Kartvelian",
                "Koreanic",
                "Language Isolate",
                "Niger-Congo",
                "Northeast Caucasian",
                "Romance",
                "Sino-Tibetan",
                "Turkic",
                "Uralic",
            ],
            value=["Indo-European"],
        )

        with gr.Row():
            plot_family_button = gr.Button("Plot Families")
            plot_figsize_h_input = gr.Slider(
                label="Figure Height", minimum=5, maximum=30, step=1, value=15
            )
            plot_figsize_w_input = gr.Slider(
                label="Figure Width", minimum=5, maximum=30, step=1, value=15
            )

        with gr.Row():
            # Only offer a file that actually exists; on a fresh start no plot
            # has been written yet.
            download_families_plot_button = gr.DownloadButton(
                "Download Plot",
                value=plot_path if os.path.exists(plot_path) else None,
            )

        plot_family_output = gr.Plot(label="Families Plot")

        plot_family_button.click(
            fn=plot_families_subfamilies,
            inputs=[
                checked_families_input,
                model_input,
                dataset_input,
                average_checkbox,
                plot_figsize_h_input,
                plot_figsize_w_input,
            ],
            outputs=[plot_family_output, download_families_plot_button],
        )

demo.launch(share=True)