"""Gradio app for exploring pairwise distances between languages.

Distance matrices are loaded once from the ``mshamrai/language-metric-data``
Hugging Face dataset. The UI offers a closest-languages table, clustered
t-SNE / UMAP / MST plots, and an MST plot restricted to selected families.
"""
import gradio as gr
import pandas as pd
import numpy as np
import pickle
import os
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from utils import (
    plot_distances_tsne,
    plot_distances_umap,
    cluster_languages_hdbscan,
    cluster_languages_kmeans,
    plot_mst,
    cluster_languages_by_families,
    cluster_languages_by_subfamilies,
    filter_languages_by_families,
)
from functools import partial
import datasets

# NOTE(review): trust_remote_code=True runs the dataset repository's loading
# code locally -- keep it only while that repository is trusted.
dataset = datasets.load_dataset(
    "mshamrai/language-metric-data", split="train", trust_remote_code=True
)

# All data lives in the first row of the split.
languages = dataset["languages_list"][0]
average_distances_matrix = np.array(dataset["average_distances_matrix"][0])

# Read the nested column once instead of re-fetching it from the dataset for
# every (dataset, model) pair below.
_raw_distances = dataset["distances_matrices"][0]
DATASETS = _raw_distances["dataset_name"]
MODELS = _raw_distances["models"][0]["model_name"]

# distance_matrices[dataset_name][model_name] -> matrix whose rows/columns are
# ordered like `languages`. Entry i of `models` belongs to DATASETS[i], and
# entry j of its `matrix` list belongs to MODELS[j].
distance_matrices = {
    dataset_name: {
        model_name: np.array(_raw_distances["models"][i]["matrix"][j])
        for j, model_name in enumerate(MODELS)
    }
    for i, dataset_name in enumerate(DATASETS)
}


def _select_matrix(model, dataset, use_average):
    """Return the averaged matrix, or the matrix for ``dataset``/``model``."""
    if use_average:
        return average_distances_matrix
    return distance_matrices[dataset][model]


def _available_mask(matrix):
    """Boolean mask of languages whose entry in the matrix's first row is not NaN."""
    # NOTE(review): only row 0 is inspected. This assumes languages[0] is present
    # (with a non-NaN diagonal) in every matrix -- confirm against the data.
    return ~np.isnan(matrix[0])


def filter_languages_nan(model, dataset, use_average):
    """
    Drops languages that have no distances for the chosen model/dataset.

    Returns:
        (matrix, languages): the square sub-matrix restricted to the available
        languages, and a numpy array of their names in the original order.
    """
    matrix = _select_matrix(model, dataset, use_average)
    mask = _available_mask(matrix)
    return matrix[mask, :][:, mask], np.array(languages)[mask]


def get_similar_languages(model, dataset, selected_language, use_average, n):
    """
    Retrieves the distances for the selected language from the chosen model and
    dataset, sorts them by similarity (lowest distance first), and returns a
    DataFrame with the ``n`` closest languages.

    Columns are ``index`` (0-based rank), ``Language`` and ``Distance``
    (rounded to 4 decimals). The selected language itself and languages
    without a distance (NaN) are excluded. If ``selected_language`` is not a
    known language (e.g. an empty dropdown), an empty table is returned.
    """
    if selected_language not in languages:
        return pd.DataFrame(columns=["index", "Language", "Distance"])

    matrix = _select_matrix(model, dataset, use_average)
    selected_language_index = languages.index(selected_language)
    df = pd.DataFrame(
        {"Language": languages, "Distance": matrix[selected_language_index]}
    )
    sorted_distances = (
        df.sort_values(by="Distance")
        # Index labels survive the sort, so the selected row is dropped by label.
        .drop(index=selected_language_index)
        # Languages missing from this model/dataset have NaN distances.
        .dropna(subset=["Distance"])
        .reset_index(drop=True)
        # Expose the 0-based rank as an "index" column.
        .reset_index()
    )
    sorted_distances["Distance"] = sorted_distances["Distance"].round(4)
    # Slider values may arrive as floats; DataFrame.head needs an int.
    return sorted_distances.head(int(n))


def update_languages(model, dataset):
    """
    Returns the language list based on the given model and dataset.
    """
    mask = _available_mask(distance_matrices[dataset][model])
    return list(np.array(languages)[mask])


def update_language_options(model, dataset, language, use_average):
    """
    Rebuilds the language dropdown for the current model/dataset selection.

    The current ``language`` is kept if it is still available. Otherwise the
    first available language is selected, or nothing if none is available.
    """
    updated_languages = languages if use_average else update_languages(model, dataset)
    if language not in updated_languages:
        language = updated_languages[0] if updated_languages else None
    return gr.Dropdown(label="Language", choices=updated_languages, value=language)


def toggle_inputs(use_average):
    """
    Hides and disables the model and dataset dropdowns while average distances are used.
    """
    enabled = not use_average
    return (
        gr.update(interactive=enabled, visible=enabled),
        gr.update(interactive=enabled, visible=enabled),
    )


# NOTE(review): every session writes to this same file, so concurrent users can
# overwrite each other's download -- consider a per-request temporary file.
plot_path = "plots/last_plot.pdf"


def _save_plot(fig):
    """Lay out ``fig``, save it as a PDF at ``plot_path``, and return it with a download button."""
    fig.tight_layout()
    # savefig does not create missing parent directories.
    os.makedirs(os.path.dirname(plot_path), exist_ok=True)
    fig.savefig(plot_path, format="pdf")
    return fig, gr.DownloadButton(label="Download Plot", value=plot_path)


def plot_distances(model, dataset, use_average, cluster_method, cluster_method_param, plot_fn):
    """
    Clusters the available languages and draws them with ``plot_fn``.

    Args:
        cluster_method: "HDBSCAN", "KMeans", "Family" or "Subfamily".
        cluster_method_param: minimum cluster size (HDBSCAN) or number of
            clusters (KMeans). Ignored by the family-based methods.
        plot_fn: plotter called as ``plot_fn(model, dataset, use_average,
            matrix, languages, clusters, legends)``. The UI uses
            plot_distances_tsne, plot_distances_umap and plot_mst.

    Returns:
        The figure and a download button pointing at the saved PDF.

    Raises:
        ValueError: if ``cluster_method`` is not recognised.
    """
    updated_matrix, updated_languages = filter_languages_nan(model, dataset, use_average)

    if cluster_method == "HDBSCAN":
        filtered_matrix, filtered_languages, clusters = cluster_languages_hdbscan(
            updated_matrix, updated_languages, min_cluster_size=int(cluster_method_param)
        )
        legends = None
    elif cluster_method == "KMeans":
        filtered_matrix, filtered_languages, clusters = cluster_languages_kmeans(
            updated_matrix, updated_languages, n_clusters=int(cluster_method_param)
        )
        legends = None
    elif cluster_method == "Family":
        clusters, legends = cluster_languages_by_families(updated_languages)
        filtered_matrix, filtered_languages = updated_matrix, updated_languages
    elif cluster_method == "Subfamily":
        clusters, legends = cluster_languages_by_subfamilies(updated_languages)
        filtered_matrix, filtered_languages = updated_matrix, updated_languages
    else:
        raise ValueError("Invalid cluster method")

    fig = plot_fn(model, dataset, use_average, filtered_matrix, filtered_languages, clusters, legends)
    return _save_plot(fig)


def plot_families_subfamilies(families, model, dataset, use_average, figsize_h, figsize_w):
    """
    Draws an MST of the languages in ``families``, coloured by subfamily.

    Returns:
        The figure (size ``(figsize_w, figsize_h)``) and a download button
        pointing at the saved PDF.
    """
    updated_matrix, updated_languages = filter_languages_nan(model, dataset, use_average)
    updated_matrix, updated_languages = filter_languages_by_families(
        updated_matrix, updated_languages, families
    )
    clusters, legends = cluster_languages_by_subfamilies(updated_languages)
    fig = plot_mst(
        model, dataset, use_average, updated_matrix, updated_languages, clusters, legends,
        fig_size=(figsize_w, figsize_h),
    )
    return _save_plot(fig)


def update_clusters_input_option(cluster_method):
    """Reconfigures the cluster-parameter slider for the chosen cluster method."""
    if cluster_method == "HDBSCAN":
        return gr.Slider(label="Minimum Elements in a Cluster", minimum=2, maximum=10, step=1,
                         value=2, visible=True, interactive=True)
    if cluster_method == "KMeans":
        return gr.Slider(label="Number of Clusters", minimum=2, maximum=20, step=1,
                         value=2, visible=True, interactive=True)
    # Family-based methods take no parameter.
    return gr.update(interactive=False, visible=False)


with gr.Blocks() as demo:
    gr.Markdown("## Language Distance Explorer")
    average_checkbox = gr.Checkbox(label="Use Average Distances", value=False)
    with gr.Row():
        model_input = gr.Dropdown(label="Model", choices=MODELS, value=MODELS[0])
        dataset_input = gr.Dropdown(label="Dataset", choices=DATASETS, value=DATASETS[0])

    with gr.Tab(label="Closest Languages Table"):
        with gr.Row():
            language_input = gr.Dropdown(label="Language", choices=languages, value=languages[0])
            top_n_input = gr.Slider(label="Top N", minimum=1, maximum=30, step=1, value=10)
        output_table = gr.Dataframe(label="Similar Languages")

        language_option_inputs = [model_input, dataset_input, language_input, average_checkbox]
        table_inputs = [model_input, dataset_input, language_input, average_checkbox, top_n_input]

        # Refresh the language options first and only then the table. Running
        # them as parallel listeners let the table be computed for a stale
        # language that the new model/dataset lacks.
        for selector in (model_input, dataset_input, average_checkbox):
            selector.change(
                fn=update_language_options, inputs=language_option_inputs, outputs=language_input
            ).then(fn=get_similar_languages, inputs=table_inputs, outputs=output_table)
        language_input.change(fn=get_similar_languages, inputs=table_inputs, outputs=output_table)
        top_n_input.change(fn=get_similar_languages, inputs=table_inputs, outputs=output_table)
        average_checkbox.change(
            fn=toggle_inputs, inputs=[average_checkbox], outputs=[model_input, dataset_input]
        )

    with gr.Tab(label="Distance Plot"):
        with gr.Row():
            cluster_method_input = gr.Dropdown(
                label="Cluster Method",
                choices=["HDBSCAN", "KMeans", "Family", "Subfamily"],
                value="HDBSCAN",
            )
            clusters_input = gr.Slider(
                label="Minimum Elements in a Cluster", minimum=2, maximum=10, step=1, value=2
            )
        cluster_method_input.change(
            fn=update_clusters_input_option, inputs=[cluster_method_input], outputs=clusters_input
        )

        with gr.Row():
            plot_tsne_button = gr.Button("Plot t-SNE")
            plot_umap_button = gr.Button("Plot UMAP")
            plot_mst_button = gr.Button("Plot MST")
        with gr.Row():
            download_plot_button = gr.DownloadButton("Download Plot")
        with gr.Row():
            plot_output = gr.Plot(label="Distance Plot")

        plot_inputs = [model_input, dataset_input, average_checkbox, cluster_method_input, clusters_input]
        plot_outputs = [plot_output, download_plot_button]
        plot_tsne_button.click(fn=partial(plot_distances, plot_fn=plot_distances_tsne),
                               inputs=plot_inputs, outputs=plot_outputs)
        plot_umap_button.click(fn=partial(plot_distances, plot_fn=plot_distances_umap),
                               inputs=plot_inputs, outputs=plot_outputs)
        plot_mst_button.click(fn=partial(plot_distances, plot_fn=plot_mst),
                              inputs=plot_inputs, outputs=plot_outputs)

    with gr.Tab(label="Language Families Subplot"):
        checked_families_input = gr.CheckboxGroup(
            label="Language Families",
            choices=[
                'Afroasiatic', 'Austroasiatic', 'Austronesian', 'Constructed', 'Creole',
                'Dravidian', 'Germanic', 'Indo-European', 'Japonic', 'Kartvelian', 'Koreanic',
                'Language Isolate', 'Niger-Congo', 'Northeast Caucasian', 'Romance',
                'Sino-Tibetan', 'Turkic', 'Uralic'
            ],
            value=["Indo-European"],
        )
        with gr.Row():
            plot_family_button = gr.Button("Plot Families")
            plot_figsize_h_input = gr.Slider(label="Figure Height", minimum=5, maximum=30, step=1, value=15)
            plot_figsize_w_input = gr.Slider(label="Figure Width", minimum=5, maximum=30, step=1, value=15)
        with gr.Row():
            download_families_plot_button = gr.DownloadButton("Download Plot", value=plot_path)
        plot_family_output = gr.Plot(label="Families Plot")
        plot_family_button.click(
            fn=plot_families_subfamilies,
            inputs=[checked_families_input, model_input, dataset_input, average_checkbox,
                    plot_figsize_h_input, plot_figsize_w_input],
            outputs=[plot_family_output, download_families_plot_button],
        )

demo.launch(share=True)