# Hugging Face Spaces page banner captured with the source: "Spaces: Sleeping".
import gradio as gr | |
import pandas as pd | |
import numpy as np | |
import os | |
from utils import ( | |
plot_distances_tsne, | |
plot_distances_umap, | |
cluster_languages_hdbscan, | |
cluster_languages_kmeans, | |
plot_mst, | |
cluster_languages_by_families, | |
cluster_languages_by_subfamilies, | |
filter_languages_by_families, | |
) | |
from functools import partial | |
import datasets | |
# Load the pre-computed language distance data from the Hugging Face Hub.
# NOTE(review): trust_remote_code=True allows the dataset repository's own
# loading script to execute locally -- confirm the repository is trusted.
dataset = datasets.load_dataset(
    "mshamrai/language-metric-data", split="train", trust_remote_code=True
)

# Everything is read from the first record of the "train" split.
languages = dataset["languages_list"][0]
average_distances_matrix = np.array(dataset["average_distances_matrix"][0])

# Fetch the nested record once: the original re-read
# dataset["distances_matrices"][0] for every (dataset, model) pair.
_distances_record = dataset["distances_matrices"][0]
_models_per_dataset = _distances_record["models"]

DATASETS = _distances_record["dataset_name"]
# Model names are taken from the first dataset's entry and reused for all
# datasets (presumably every dataset lists the same models -- TODO confirm).
MODELS = _models_per_dataset[0]["model_name"]

# distance_matrices[dataset_name][model_name] -> numpy distance matrix.
distance_matrices = {
    dataset_name: {
        model_name: np.array(_models_per_dataset[i]["matrix"][j])
        for j, model_name in enumerate(MODELS)
    }
    for i, dataset_name in enumerate(DATASETS)
}
def filter_languages_nan(model, dataset, use_average):
    """
    Return the chosen distance matrix restricted to languages that have data.

    The averaged matrix is used when ``use_average`` is truthy; otherwise the
    matrix stored for ``dataset`` / ``model`` is looked up. A language is kept
    only when its entry in the first matrix row is not NaN; the same selection
    is applied to rows and columns.

    Returns a ``(matrix, languages)`` pair where ``languages`` is a numpy
    array of the kept language names.
    """
    matrix = (
        average_distances_matrix if use_average else distance_matrices[dataset][model]
    )
    # Row 0 alone decides which languages are present
    # (presumably NaN marks a language missing from the dataset -- TODO confirm).
    keep = ~np.isnan(matrix[0])
    kept_matrix = matrix[keep, :][:, keep]
    kept_languages = np.array(languages)[keep]
    return kept_matrix, kept_languages
def get_similar_languages(model, dataset, selected_language, use_average, n):
    """
    Return the ``n`` languages closest to ``selected_language`` as a DataFrame.

    The distances are read from the averaged matrix when ``use_average`` is
    truthy, otherwise from the matrix stored for ``dataset`` / ``model``.
    The result is sorted by ascending distance and has three columns:
    ``index`` (0-based rank), ``Language`` and ``Distance`` (rounded to four
    decimals).

    The selected language itself is excluded, and so is every language whose
    distance is NaN: such rows carry no similarity information and previously
    leaked into the table whenever ``n`` exceeded the number of measured
    neighbours (or the selected language had no data in the chosen matrix).

    Raises ``ValueError`` if ``selected_language`` is not in ``languages`` and
    ``KeyError`` if ``dataset`` / ``model`` are unknown.
    """
    if use_average:
        matrix = average_distances_matrix
    else:
        matrix = distance_matrices[dataset][model]

    selected_language_index = languages.index(selected_language)
    distances = matrix[selected_language_index]

    df = pd.DataFrame({"Language": languages, "Distance": distances})
    sorted_distances = df.sort_values(by="Distance")
    # The frame still carries the original positional labels here, so the
    # selected language can be removed by its index in ``languages``.
    sorted_distances = sorted_distances.drop(index=selected_language_index)
    # Languages without a measured distance are not "similar" -- drop them.
    sorted_distances = sorted_distances.dropna(subset=["Distance"])
    # First reset discards the old labels; the second one materialises the
    # new 0-based rank as an "index" column shown in the table.
    sorted_distances = sorted_distances.reset_index(drop=True).reset_index()
    sorted_distances["Distance"] = sorted_distances["Distance"].round(4)
    # The value comes from a UI slider; make sure head() gets an integer.
    return sorted_distances.head(int(n))
def update_languages(model, dataset):
    """
    List the languages that have data for the given model and dataset.

    A language is included when its entry in the first row of the
    ``dataset`` / ``model`` distance matrix is not NaN.
    """
    first_row = distance_matrices[dataset][model][0]
    has_data = ~np.isnan(first_row)
    return list(np.array(languages)[has_data])
def update_language_options(model, dataset, language, use_average):
    """
    Rebuild the "Language" dropdown for the current model/dataset selection.

    With ``use_average`` the full language list is offered; otherwise only the
    languages returned by ``update_languages(model, dataset)``. The current
    ``language`` stays selected when it is still among the choices, else the
    first available choice is selected.
    """
    choices = languages if use_average else update_languages(model, dataset)
    selected = language if language in choices else choices[0]
    return gr.Dropdown(label="Language", choices=choices, value=selected)
def toggle_inputs(use_average):
    """
    Build the two component updates that follow the "use average" checkbox.

    Both updates make their target hidden and non-interactive when
    ``use_average`` is truthy, and visible and interactive otherwise.
    """
    enabled = not use_average
    first_update = gr.update(interactive=enabled, visible=enabled)
    second_update = gr.update(interactive=enabled, visible=enabled)
    return first_update, second_update
# Single PDF file that every plotting callback writes to (and overwrites).
plot_path = "plots/last_plot.pdf"
# Make sure the target directory exists before the first plot is saved.
os.makedirs(os.path.dirname(plot_path), exist_ok=True)
def plot_distances(
    model,
    dataset,
    use_average,
    cluster_method,
    cluster_method_param,
    figsize_h,
    figsize_w,
    plot_fn,
):
    """
    Cluster the available languages and draw them with ``plot_fn``.

    ``plot_fn`` is the drawing routine supplied by the caller; it is called as
    ``plot_fn(matrix, languages, clusters, legends, fig_size=(width, height))``.

    ``cluster_method`` selects how clusters are obtained:
    * ``"HDBSCAN"`` -- ``cluster_method_param`` is passed as ``min_cluster_size``;
    * ``"KMeans"`` -- ``cluster_method_param`` is passed as ``n_clusters``;
    * ``"Family"`` / ``"Subfamily"`` -- clusters and legends come from the
      family helpers and ``cluster_method_param`` is ignored.
    Any other value raises ``ValueError``.

    The figure is also saved as a PDF to ``plot_path``. Returns the figure and
    a download button pointing at that file.
    """
    matrix, names = filter_languages_nan(model, dataset, use_average)

    # Only the family-based methods provide legends.
    legends = None
    if cluster_method == "HDBSCAN":
        # The clustering helpers return their own matrix/language pair.
        matrix, names, clusters = cluster_languages_hdbscan(
            matrix, names, min_cluster_size=cluster_method_param
        )
    elif cluster_method == "KMeans":
        matrix, names, clusters = cluster_languages_kmeans(
            matrix, names, n_clusters=cluster_method_param
        )
    elif cluster_method == "Family":
        clusters, legends = cluster_languages_by_families(names)
    elif cluster_method == "Subfamily":
        clusters, legends = cluster_languages_by_subfamilies(names)
    else:
        raise ValueError("Invalid cluster method")

    figure = plot_fn(
        matrix, names, clusters, legends, fig_size=(figsize_w, figsize_h)
    )
    figure.tight_layout()
    figure.savefig(plot_path, format="pdf")

    download_button = gr.DownloadButton(label="Download Plot", value=plot_path)
    return figure, download_button
def plot_families_subfamilies(
    families, model, dataset, use_average, figsize_h, figsize_w
):
    """
    Draw the MST plot for the selected language families, grouped by subfamily.

    Languages without data are removed first, the remainder is narrowed to
    ``families`` via ``filter_languages_by_families``, and the subfamily
    clusters and legends are passed to ``plot_mst``.

    The figure is also saved as a PDF to ``plot_path``. Returns the figure and
    a download button pointing at that file.
    """
    matrix, names = filter_languages_nan(model, dataset, use_average)
    matrix, names = filter_languages_by_families(matrix, names, families)
    clusters, legends = cluster_languages_by_subfamilies(names)

    figure = plot_mst(
        matrix, names, clusters, legends, fig_size=(figsize_w, figsize_h)
    )
    figure.tight_layout()
    figure.savefig(plot_path, format="pdf")

    download_button = gr.DownloadButton(label="Download Plot", value=plot_path)
    return figure, download_button
# NOTE(review): the nesting below was reconstructed from syntax because the
# captured source lost its indentation -- confirm the layout of the families tab.
with gr.Blocks() as demo:
    gr.Markdown("## Language Distance Explorer")

    # Controls shared by every tab.
    average_checkbox = gr.Checkbox(label="Use Average Distances", value=False)
    with gr.Row():
        model_input = gr.Dropdown(label="Model", choices=MODELS, value=MODELS[0])
        dataset_input = gr.Dropdown(
            label="Dataset", choices=DATASETS, value=DATASETS[0]
        )

    # Keyword arguments common to all four figure-size sliders.
    figure_size_options = dict(minimum=5, maximum=30, step=1, value=15)

    with gr.Tab(label="Closest Languages Table"):
        with gr.Row():
            language_input = gr.Dropdown(
                label="Language", choices=languages, value=languages[0]
            )
            top_n_input = gr.Slider(
                label="Top N", minimum=1, maximum=30, step=1, value=10
            )
        output_table = gr.Dataframe(label="Similar Languages")

        language_option_inputs = [
            model_input,
            dataset_input,
            language_input,
            average_checkbox,
        ]
        similar_languages_inputs = [
            model_input,
            dataset_input,
            language_input,
            average_checkbox,
            top_n_input,
        ]

        # Changing the model or dataset refreshes the language choices ...
        for trigger in (model_input, dataset_input):
            trigger.change(
                fn=update_language_options,
                inputs=language_option_inputs,
                outputs=language_input,
            )
        # ... and any of these four controls refreshes the table.
        for trigger in (language_input, model_input, dataset_input, top_n_input):
            trigger.change(
                fn=get_similar_languages,
                inputs=similar_languages_inputs,
                outputs=output_table,
            )

        # The checkbox toggles the model/dataset inputs, then refreshes the
        # language choices and the table.
        average_checkbox.change(
            fn=toggle_inputs,
            inputs=[average_checkbox],
            outputs=[model_input, dataset_input],
        )
        average_checkbox.change(
            fn=update_language_options,
            inputs=language_option_inputs,
            outputs=language_input,
        )
        average_checkbox.change(
            fn=get_similar_languages,
            inputs=similar_languages_inputs,
            outputs=output_table,
        )

    with gr.Tab(label="Distance Plot"):
        with gr.Row():
            cluster_method_input = gr.Dropdown(
                label="Cluster Method",
                choices=["HDBSCAN", "KMeans", "Family", "Subfamily"],
                value="HDBSCAN",
            )
            clusters_input = gr.Slider(
                label="Minimum Elements in a Cluster",
                minimum=2,
                maximum=10,
                step=1,
                value=2,
            )

        def update_clusters_input_option(cluster_method):
            """
            Reconfigure the cluster-parameter slider for ``cluster_method``.

            HDBSCAN and KMeans get a visible slider with their own label and
            maximum; every other method hides and disables the slider.
            """
            # cluster method -> (slider label, slider maximum)
            slider_specs = {
                "HDBSCAN": ("Minimum Elements in a Cluster", 10),
                "KMeans": ("Number of Clusters", 20),
            }
            if cluster_method not in slider_specs:
                return gr.update(interactive=False, visible=False)
            slider_label, slider_maximum = slider_specs[cluster_method]
            return gr.Slider(
                label=slider_label,
                minimum=2,
                maximum=slider_maximum,
                step=1,
                value=2,
                visible=True,
                interactive=True,
            )

        cluster_method_input.change(
            fn=update_clusters_input_option,
            inputs=[cluster_method_input],
            outputs=clusters_input,
        )

        with gr.Row():
            plot_tsne_button = gr.Button("Plot t-SNE")
            plot_umap_button = gr.Button("Plot UMAP")
            plot_mst_button = gr.Button("Plot MST")
        with gr.Row():
            plot_figsize_dist_h_input = gr.Slider(
                label="Figure Height", **figure_size_options
            )
            plot_figsize_dist_w_input = gr.Slider(
                label="Figure Width", **figure_size_options
            )
        with gr.Row():
            download_plot_button = gr.DownloadButton("Download Plot")
        with gr.Row():
            plot_output = gr.Plot(label="Distance Plot")

        distance_plot_inputs = [
            model_input,
            dataset_input,
            average_checkbox,
            cluster_method_input,
            clusters_input,
            plot_figsize_dist_h_input,
            plot_figsize_dist_w_input,
        ]
        # Each button runs plot_distances with its own drawing routine bound.
        plot_buttons = (
            (plot_tsne_button, plot_distances_tsne),
            (plot_umap_button, plot_distances_umap),
            (plot_mst_button, plot_mst),
        )
        for plot_button, drawing_fn in plot_buttons:
            plot_button.click(
                fn=partial(plot_distances, plot_fn=drawing_fn),
                inputs=distance_plot_inputs,
                outputs=[plot_output, download_plot_button],
            )

    with gr.Tab(label="Language Families Subplot"):
        checked_families_input = gr.CheckboxGroup(
            label="Language Families",
            choices=[
                "Afroasiatic",
                "Austroasiatic",
                "Austronesian",
                "Constructed",
                "Creole",
                "Dravidian",
                "Germanic",
                "Indo-European",
                "Japonic",
                "Kartvelian",
                "Koreanic",
                "Language Isolate",
                "Niger-Congo",
                "Northeast Caucasian",
                "Romance",
                "Sino-Tibetan",
                "Turkic",
                "Uralic",
            ],
            value=["Indo-European"],
        )
        with gr.Row():
            plot_family_button = gr.Button("Plot Families")
            plot_figsize_h_input = gr.Slider(
                label="Figure Height", **figure_size_options
            )
            plot_figsize_w_input = gr.Slider(
                label="Figure Width", **figure_size_options
            )
        with gr.Row():
            download_families_plot_button = gr.DownloadButton(
                "Download Plot", value=plot_path
            )
        plot_family_output = gr.Plot(label="Families Plot")

        plot_family_button.click(
            fn=plot_families_subfamilies,
            inputs=[
                checked_families_input,
                model_input,
                dataset_input,
                average_checkbox,
                plot_figsize_h_input,
                plot_figsize_w_input,
            ],
            outputs=[plot_family_output, download_families_plot_button],
        )

demo.launch(share=True)