# Author: mshamrai — last commit 661c42a ("chore: rm axis")
import gradio as gr
import pandas as pd
import numpy as np
import os
from utils import (
plot_distances_tsne,
plot_distances_umap,
cluster_languages_hdbscan,
cluster_languages_kmeans,
plot_mst,
cluster_languages_by_families,
cluster_languages_by_subfamilies,
filter_languages_by_families,
)
from functools import partial
import datasets
# Load the precomputed language-distance data. Every column's payload is read
# from row 0 of the "train" split.
# NOTE: trust_remote_code=True lets `datasets` run the dataset repo's own
# loading code.
dataset = datasets.load_dataset(
    "mshamrai/language-metric-data", split="train", trust_remote_code=True
)
languages = dataset["languages_list"][0]
average_distances_matrix = np.array(dataset["average_distances_matrix"][0])

# One entry per dataset; each holds one matrix per model, in MODELS order.
_matrices_entry = dataset["distances_matrices"][0]
DATASETS = _matrices_entry["dataset_name"]
MODELS = _matrices_entry["models"][0]["model_name"]

# distance_matrices[dataset_name][model_name] -> NumPy distance matrix.
distance_matrices = {
    dataset_name: {
        model_name: np.array(_matrices_entry["models"][ds_idx]["matrix"][m_idx])
        for m_idx, model_name in enumerate(MODELS)
    }
    for ds_idx, dataset_name in enumerate(DATASETS)
}
def filter_languages_nan(model, dataset, use_average):
    """Drop languages that have no distances in the selected matrix.

    Availability is judged from the first row of the matrix: every language
    whose entry in that row is NaN is removed from both axes.

    Args:
        model: Model name; selects the matrix when ``use_average`` is false.
        dataset: Dataset name; selects the matrix when ``use_average`` is false.
        use_average: If true, use ``average_distances_matrix`` instead.

    Returns:
        ``(matrix, languages)`` restricted to the available languages, with
        the language list returned as a NumPy array.
    """
    source = (
        average_distances_matrix
        if use_average
        else distance_matrices[dataset][model]
    )
    keep = ~np.isnan(source[0])
    # np.ix_ selects the kept rows and columns in one step.
    return source[np.ix_(keep, keep)], np.array(languages)[keep]
def get_similar_languages(model, dataset, selected_language, use_average, n):
    """Return the ``n`` languages closest to ``selected_language``.

    Distances come from ``average_distances_matrix`` when ``use_average`` is
    true, otherwise from ``distance_matrices[dataset][model]``. The selected
    language itself is excluded, as is any language whose distance is NaN
    (i.e. not available in that matrix). The rest are sorted by ascending
    distance.

    Args:
        model: Model name (ignored when ``use_average`` is true).
        dataset: Dataset name (ignored when ``use_average`` is true).
        selected_language: Language to compare against; must be in ``languages``.
        use_average: Whether to use the averaged distance matrix.
        n: Maximum number of rows to return. Coerced to ``int`` because it
            comes from a UI slider.

    Returns:
        DataFrame with columns ``index`` (0-based rank), ``Language`` and
        ``Distance`` (rounded to 4 decimals).

    Raises:
        ValueError: if ``selected_language`` is not in ``languages``.
    """
    if use_average:
        matrix = average_distances_matrix
    else:
        matrix = distance_matrices[dataset][model]
    selected_language_index = languages.index(selected_language)
    df = pd.DataFrame(
        {"Language": languages, "Distance": matrix[selected_language_index]}
    )
    sorted_distances = (
        df.sort_values(by="Distance")
        # sort_values keeps the original labels, so this drops by position
        # in ``languages``.
        .drop(index=selected_language_index)
        # NaN means "no distance available", not "similar". Without this,
        # missing languages were listed at the bottom of the table.
        .dropna(subset=["Distance"])
        .reset_index(drop=True)
        # Second reset exposes the 0-based rank as an "index" column.
        .reset_index()
    )
    sorted_distances["Distance"] = sorted_distances["Distance"].round(4)
    return sorted_distances.head(int(n))
def update_languages(model, dataset):
    """List the languages that have distances for ``model`` on ``dataset``.

    A language counts as available when its entry in the first row of the
    model/dataset distance matrix is not NaN.

    Returns:
        A list of the available language names. Elements are NumPy string
        scalars because they are taken from a NumPy array.
    """
    first_row = distance_matrices[dataset][model][0]
    has_distance = np.logical_not(np.isnan(first_row))
    return list(np.asarray(languages)[has_distance])
def update_language_options(model, dataset, language, use_average):
    """Rebuild the language dropdown for the current model/dataset selection.

    When ``use_average`` is true every language is offered. Otherwise only
    the languages available for ``model``/``dataset`` are offered (see
    ``update_languages``). If the current ``language`` is not among the new
    choices, the first choice is selected instead, or ``None`` when there
    are no choices at all (this previously raised ``IndexError``).

    Returns:
        A ``gr.Dropdown`` carrying the new choices and value.
    """
    if use_average:
        updated_languages = languages
    else:
        updated_languages = update_languages(model, dataset)
    if language not in updated_languages:
        # Guard against an empty choice list instead of indexing blindly.
        language = updated_languages[0] if len(updated_languages) > 0 else None
    return gr.Dropdown(label="Language", choices=updated_languages, value=language)
def toggle_inputs(use_average):
    """Show and enable the model/dataset selectors only when not averaging.

    Returns:
        Two ``gr.update`` payloads, the first for the model dropdown and the
        second for the dataset dropdown. Both are hidden and non-interactive
        when ``use_average`` is truthy, and visible and interactive otherwise.
    """
    enabled = not use_average
    model_update = gr.update(interactive=enabled, visible=enabled)
    dataset_update = gr.update(interactive=enabled, visible=enabled)
    return model_update, dataset_update
# Plot handlers export their figure as a PDF to this single fixed path.
# The target directory is created once, at import time.
os.makedirs("plots", exist_ok=True)
plot_path = "plots/last_plot.pdf"
def plot_distances(
    model,
    dataset,
    use_average,
    cluster_method,
    cluster_method_param,
    figsize_h,
    figsize_w,
    plot_fn,
):
    """Cluster the available languages and draw them with ``plot_fn``.

    Languages with no distances in the chosen matrix are dropped first (see
    ``filter_languages_nan``). The remaining languages are grouped by
    ``cluster_method`` and passed to ``plot_fn``; in this app that is one of
    ``plot_distances_tsne``, ``plot_distances_umap`` or ``plot_mst``. The
    resulting figure is also exported as a PDF to ``plot_path``.

    Args:
        model: Model name (ignored when ``use_average`` is true).
        dataset: Dataset name (ignored when ``use_average`` is true).
        use_average: Use the averaged distance matrix instead of the
            per-model/dataset one.
        cluster_method: One of "HDBSCAN", "KMeans", "Family", "Subfamily".
        cluster_method_param: ``min_cluster_size`` for HDBSCAN and
            ``n_clusters`` for KMeans; unused for the other methods. Coerced
            to ``int`` because it comes from a UI slider.
        figsize_h: Figure height.
        figsize_w: Figure width.
        plot_fn: Callable taking ``(matrix, languages, clusters, legends,
            fig_size=(width, height))`` and returning a figure.

    Returns:
        ``(figure, gr.DownloadButton)`` where the button serves the PDF.

    Raises:
        ValueError: if ``cluster_method`` is not recognised.
    """
    updated_matrix, updated_languages = filter_languages_nan(
        model, dataset, use_average
    )
    # The HDBSCAN/KMeans helpers return their own (presumably filtered)
    # matrix and language list along with the labels. The family-based
    # helpers return only labels and legends.
    if cluster_method == "HDBSCAN":
        filtered_matrix, filtered_languages, clusters = cluster_languages_hdbscan(
            updated_matrix,
            updated_languages,
            min_cluster_size=int(cluster_method_param),
        )
        legends = None
    elif cluster_method == "KMeans":
        filtered_matrix, filtered_languages, clusters = cluster_languages_kmeans(
            updated_matrix,
            updated_languages,
            n_clusters=int(cluster_method_param),
        )
        legends = None
    elif cluster_method in ("Family", "Subfamily"):
        cluster_by = (
            cluster_languages_by_families
            if cluster_method == "Family"
            else cluster_languages_by_subfamilies
        )
        clusters, legends = cluster_by(updated_languages)
        filtered_matrix, filtered_languages = updated_matrix, updated_languages
    else:
        raise ValueError("Invalid cluster method")

    fig = plot_fn(
        filtered_matrix,
        filtered_languages,
        clusters,
        legends,
        fig_size=(figsize_w, figsize_h),
    )
    fig.tight_layout()
    fig.savefig(plot_path, format="pdf")
    return fig, gr.DownloadButton(label="Download Plot", value=plot_path)
def plot_families_subfamilies(
    families, model, dataset, use_average, figsize_h, figsize_w
):
    """Draw an MST of the languages belonging to the selected ``families``.

    Languages without distances are dropped first. The rest are restricted
    to ``families``, and the subfamily clusters and legends from
    ``cluster_languages_by_subfamilies`` are passed to ``plot_mst``. The
    figure is also exported as a PDF to ``plot_path``.

    Returns:
        ``(figure, gr.DownloadButton)`` where the button serves the PDF.
    """
    matrix, langs = filter_languages_nan(model, dataset, use_average)
    matrix, langs = filter_languages_by_families(matrix, langs, families)
    subfamily_clusters, subfamily_legends = cluster_languages_by_subfamilies(
        langs
    )
    # fig_size is passed as (width, height).
    figure = plot_mst(
        matrix,
        langs,
        subfamily_clusters,
        subfamily_legends,
        fig_size=(figsize_w, figsize_h),
    )
    figure.tight_layout()
    figure.savefig(plot_path, format="pdf")
    return figure, gr.DownloadButton(label="Download Plot", value=plot_path)
# Gradio UI. Components are laid out in the order they are created. The
# global controls (average checkbox, model and dataset dropdowns) sit above
# three tabs and are used as inputs by listeners in every tab.
with gr.Blocks() as demo:
    gr.Markdown("## Language Distance Explorer")
    # When ticked, handlers use average_distances_matrix and the model/dataset
    # dropdowns are hidden (see the toggle_inputs listener below).
    average_checkbox = gr.Checkbox(label="Use Average Distances", value=False)
    with gr.Row():
        model_input = gr.Dropdown(label="Model", choices=MODELS, value=MODELS[0])
        dataset_input = gr.Dropdown(
            label="Dataset", choices=DATASETS, value=DATASETS[0]
        )

    # --- Tab 1: table of the N languages closest to a selected language ---
    with gr.Tab(label="Closest Languages Table"):
        with gr.Row():
            language_input = gr.Dropdown(
                label="Language", choices=languages, value=languages[0]
            )
            top_n_input = gr.Slider(
                label="Top N", minimum=1, maximum=30, step=1, value=10
            )
        output_table = gr.Dataframe(label="Similar Languages")

        # A different model/dataset can have a different set of available
        # languages, so rebuild the language dropdown's choices.
        model_input.change(
            fn=update_language_options,
            inputs=[model_input, dataset_input, language_input, average_checkbox],
            outputs=language_input,
        )
        dataset_input.change(
            fn=update_language_options,
            inputs=[model_input, dataset_input, language_input, average_checkbox],
            outputs=language_input,
        )

        # Recompute the table whenever one of its inputs changes.
        # NOTE(review): the model/dataset listeners below are independent of
        # the update_language_options listeners above. They may therefore
        # read the language value from before the dropdown is rebuilt; the
        # language_input.change listener presumably refreshes the table
        # afterwards — confirm.
        language_input.change(
            fn=get_similar_languages,
            inputs=[
                model_input,
                dataset_input,
                language_input,
                average_checkbox,
                top_n_input,
            ],
            outputs=output_table,
        )
        model_input.change(
            fn=get_similar_languages,
            inputs=[
                model_input,
                dataset_input,
                language_input,
                average_checkbox,
                top_n_input,
            ],
            outputs=output_table,
        )
        dataset_input.change(
            fn=get_similar_languages,
            inputs=[
                model_input,
                dataset_input,
                language_input,
                average_checkbox,
                top_n_input,
            ],
            outputs=output_table,
        )
        top_n_input.change(
            fn=get_similar_languages,
            inputs=[
                model_input,
                dataset_input,
                language_input,
                average_checkbox,
                top_n_input,
            ],
            outputs=output_table,
        )

        # Toggling "Use Average Distances" does three things: hides or shows
        # the model/dataset dropdowns, refreshes the language choices, and
        # recomputes the table.
        average_checkbox.change(
            fn=toggle_inputs,
            inputs=[average_checkbox],
            outputs=[model_input, dataset_input],
        )
        average_checkbox.change(
            fn=update_language_options,
            inputs=[model_input, dataset_input, language_input, average_checkbox],
            outputs=language_input,
        )
        average_checkbox.change(
            fn=get_similar_languages,
            inputs=[
                model_input,
                dataset_input,
                language_input,
                average_checkbox,
                top_n_input,
            ],
            outputs=output_table,
        )

    # --- Tab 2: t-SNE / UMAP / MST plots of all available languages ---
    with gr.Tab(label="Distance Plot"):
        with gr.Row():
            cluster_method_input = gr.Dropdown(
                label="Cluster Method",
                choices=["HDBSCAN", "KMeans", "Family", "Subfamily"],
                value="HDBSCAN",
            )
            # Meaning depends on the cluster method: min_cluster_size for
            # HDBSCAN, n_clusters for KMeans, unused otherwise.
            clusters_input = gr.Slider(
                label="Minimum Elements in a Cluster",
                minimum=2,
                maximum=10,
                step=1,
                value=2,
            )

        def update_clusters_input_option(cluster_method):
            """Reconfigure the cluster-parameter slider for ``cluster_method``.

            HDBSCAN and KMeans each get a relabelled slider with their own
            range, with the value reset to 2. For "Family"/"Subfamily" the
            slider is hidden, since plot_distances does not use the
            parameter for those methods.
            """
            if cluster_method == "HDBSCAN":
                return gr.Slider(
                    label="Minimum Elements in a Cluster",
                    minimum=2,
                    maximum=10,
                    step=1,
                    value=2,
                    visible=True,
                    interactive=True,
                )
            elif cluster_method == "KMeans":
                return gr.Slider(
                    label="Number of Clusters",
                    minimum=2,
                    maximum=20,
                    step=1,
                    value=2,
                    visible=True,
                    interactive=True,
                )
            else:
                return gr.update(interactive=False, visible=False)

        cluster_method_input.change(
            fn=update_clusters_input_option,
            inputs=[cluster_method_input],
            outputs=clusters_input,
        )
        with gr.Row():
            plot_tsne_button = gr.Button("Plot t-SNE")
            plot_umap_button = gr.Button("Plot UMAP")
            plot_mst_button = gr.Button("Plot MST")
        with gr.Row():
            plot_figsize_dist_h_input = gr.Slider(
                label="Figure Height", minimum=5, maximum=30, step=1, value=15
            )
            plot_figsize_dist_w_input = gr.Slider(
                label="Figure Width", minimum=5, maximum=30, step=1, value=15
            )
        with gr.Row():
            download_plot_button = gr.DownloadButton("Download Plot")
        with gr.Row():
            plot_output = gr.Plot(label="Distance Plot")

        # All three buttons call plot_distances; functools.partial binds the
        # drawing function each button uses.
        plot_tsne_button.click(
            fn=partial(plot_distances, plot_fn=plot_distances_tsne),
            inputs=[
                model_input,
                dataset_input,
                average_checkbox,
                cluster_method_input,
                clusters_input,
                plot_figsize_dist_h_input,
                plot_figsize_dist_w_input,
            ],
            outputs=[plot_output, download_plot_button],
        )
        plot_umap_button.click(
            fn=partial(plot_distances, plot_fn=plot_distances_umap),
            inputs=[
                model_input,
                dataset_input,
                average_checkbox,
                cluster_method_input,
                clusters_input,
                plot_figsize_dist_h_input,
                plot_figsize_dist_w_input,
            ],
            outputs=[plot_output, download_plot_button],
        )
        plot_mst_button.click(
            fn=partial(plot_distances, plot_fn=plot_mst),
            inputs=[
                model_input,
                dataset_input,
                average_checkbox,
                cluster_method_input,
                clusters_input,
                plot_figsize_dist_h_input,
                plot_figsize_dist_w_input,
            ],
            outputs=[plot_output, download_plot_button],
        )

    # --- Tab 3: MST restricted to the selected language families ---
    with gr.Tab(label="Language Families Subplot"):
        checked_families_input = gr.CheckboxGroup(
            label="Language Families",
            choices=[
                "Afroasiatic",
                "Austroasiatic",
                "Austronesian",
                "Constructed",
                "Creole",
                "Dravidian",
                "Germanic",
                "Indo-European",
                "Japonic",
                "Kartvelian",
                "Koreanic",
                "Language Isolate",
                "Niger-Congo",
                "Northeast Caucasian",
                "Romance",
                "Sino-Tibetan",
                "Turkic",
                "Uralic",
            ],
            value=["Indo-European"],
        )
        with gr.Row():
            plot_family_button = gr.Button("Plot Families")
            plot_figsize_h_input = gr.Slider(
                label="Figure Height", minimum=5, maximum=30, step=1, value=15
            )
            plot_figsize_w_input = gr.Slider(
                label="Figure Width", minimum=5, maximum=30, step=1, value=15
            )
        with gr.Row():
            # NOTE(review): only the "plots" directory is created at startup,
            # so plot_path may not exist yet. The Distance Plot tab's button
            # also has no initial value here. Confirm Gradio accepts a
            # missing file as the initial value.
            download_families_plot_button = gr.DownloadButton(
                "Download Plot", value=plot_path
            )
        plot_family_output = gr.Plot(label="Families Plot")
        plot_family_button.click(
            fn=plot_families_subfamilies,
            inputs=[
                checked_families_input,
                model_input,
                dataset_input,
                average_checkbox,
                plot_figsize_h_input,
                plot_figsize_w_input,
            ],
            outputs=[plot_family_output, download_families_plot_button],
        )
# Launch only when run as a script (e.g. `python app.py`). Importing the
# module, as Gradio's reload mode does, then does not start a second
# blocking server.
if __name__ == "__main__":
    demo.launch(share=True)