import streamlit as st
import pandas as pd
import numpy as np
from bokeh.plotting import figure
from bokeh.models import ColumnDataSource, DataTable, TableColumn, CustomJS, Select, Button
from bokeh.layouts import column
from bokeh.palettes import Reds9, Blues9, Oranges9, Purples9
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import io
import ot
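# Overview: this Streamlit app loads precomputed embeddings for the Merit dataset
# (real vs. synthetic documents), projects them to 2D with PCA or t-SNE, plots the
# clusters with Bokeh, and reports Wasserstein distances from each synthetic
# cluster to every real cluster, with the resulting table exportable to Excel.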
TOOLTIPS = """
<div>
    <div>
        <img src="@img{safe}" style="width:128px; height:auto; float: left; margin: 0px 15px 15px 0px;" alt="@img" border="2"></img>
    </div>
    <div>
        <span style="font-size: 17px; font-weight: bold;">@label</span>
    </div>
</div>
"""
def config_style():
    st.markdown("""
        <style>
        .main-title { font-size: 50px; color: #4CAF50; text-align: center; }
        .sub-title { font-size: 30px; color: #555; }
        .custom-text { font-size: 18px; line-height: 1.5; }
        </style>
    """, unsafe_allow_html=True)
    st.markdown('<h1 class="main-title">Merit Embeddings 🎒📃🏆</h1>', unsafe_allow_html=True)
# Loads the data and assigns versions uniformly
def load_embeddings(model):
    if model == "Donut":
        df_real = pd.read_csv("data/donut_de_Rodrigo_merit_secret_all_embeddings.csv")
        df_seq = pd.read_csv("data/donut_de_Rodrigo_merit_es-digital-seq_embeddings.csv")
        df_line = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-digital-seq_embeddings.csv")
        df_real["version"] = "real"
        df_seq["version"] = "synthetic"
        df_line["version"] = "synthetic"
        # Use an identifier in the 'source' column to tell the synthetic sets apart
        df_seq["source"] = "es-digital-seq"
        df_line["source"] = "es-digital-line-degradation-seq"
        return {"real": df_real, "synthetic": pd.concat([df_seq, df_line], ignore_index=True)}
    elif model == "Idefics2":
        df_real = pd.read_csv("data/idefics2_de_Rodrigo_merit_secret_britanico_embeddings.csv")
        df_seq = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-digital-seq_embeddings.csv")
        df_real["version"] = "real"
        df_seq["version"] = "synthetic"
        df_seq["source"] = "es-digital-seq"
        return {"real": df_real, "synthetic": df_seq}
    else:
        st.error("Model not recognized")
        return None
# Dimensionality-reduction selector (helper; run_model applies the same logic inline)
def reducer_selector(df_combined, embedding_cols):
    reduction_method = st.selectbox("Select Dimensionality Reduction Method:", options=["PCA", "t-SNE"])
    all_embeddings = df_combined[embedding_cols].values
    if reduction_method == "PCA":
        reducer = PCA(n_components=2)
    else:
        reducer = TSNE(n_components=2, random_state=42, perplexity=30, learning_rate=200)
    return reducer.fit_transform(all_embeddings)
# Generic helper to add one group of points to the figure
def add_dataset_to_fig(fig, df, selected_labels, marker, color_mapping, group_label):
    renderers = {}
    for label in selected_labels:
        subset = df[df['label'] == label]
        if subset.empty:
            continue
        source = ColumnDataSource(data=dict(
            x=subset['x'],
            y=subset['y'],
            label=subset['label'],
            # Fall back to empty strings if there is no 'img' column, so all
            # ColumnDataSource columns keep the same length
            img=subset['img'] if 'img' in subset.columns else [""] * len(subset)
        ))
        color = color_mapping[label]
        # Include the source identifier in the legend entry
        legend_label = f"{label} ({group_label})"
        if marker == "circle":
            r = fig.circle('x', 'y', size=10, source=source,
                           fill_color=color, line_color=color,
                           legend_label=legend_label)
        elif marker == "square":
            r = fig.square('x', 'y', size=10, source=source,
                           fill_color=color, line_color=color,
                           legend_label=legend_label)
        elif marker == "triangle":
            r = fig.triangle('x', 'y', size=12, source=source,
                             fill_color=color, line_color=color,
                             legend_label=legend_label)
        renderers[f"{label} ({group_label})"] = r
    return renderers
# Assigns color palettes generically for each group (real and each synthetic source)
def get_color_maps(unique_subsets):
    color_map = {}
    # Real: Reds9 is a 9-color Bokeh palette; repeat it if there are more labels
    num_real = len(unique_subsets["real"])
    red_palette = Reds9[:num_real] if num_real <= 9 else (Reds9 * ((num_real // 9) + 1))[:num_real]
    color_map["real"] = {label: red_palette[i] for i, label in enumerate(sorted(unique_subsets["real"]))}
    # Synthetic: the 'source' column holds the per-source identifiers, but a single
    # palette is used for all synthetic labels; split by source here if needed
    synthetic_labels = sorted(unique_subsets["synthetic"])
    blue_palette = Blues9[:len(synthetic_labels)] if len(synthetic_labels) <= 9 else (Blues9 * ((len(synthetic_labels) // 9) + 1))[:len(synthetic_labels)]
    color_map["synthetic"] = {label: blue_palette[i] for i, label in enumerate(synthetic_labels)}
    return color_map
# Splits the reduced data into "real" and "synthetic" and extracts the subsets (clusters)
def split_versions(df_combined, reduced):
    df_combined['x'] = reduced[:, 0]
    df_combined['y'] = reduced[:, 1]
    df_real = df_combined[df_combined["version"] == "real"].copy()
    df_synth = df_combined[df_combined["version"] == "synthetic"].copy()
    # Extract the clusters (subsets) from the 'label' column
    unique_real = sorted(df_real['label'].unique().tolist())
    unique_synth = sorted(df_synth['label'].unique().tolist())
    df_dict = {"real": df_real, "synthetic": df_synth}
    unique_subsets = {"real": unique_real, "synthetic": unique_synth}
    return df_dict, unique_subsets
# Builds the figure; both synthetic sets are handled uniformly
def create_figure(dfs, unique_subsets, color_maps):
    fig = figure(width=400, height=400, tooltips=TOOLTIPS, title="")
    real_renderers = add_dataset_to_fig(fig, dfs["real"], unique_subsets["real"],
                                        marker="circle", color_mapping=color_maps["real"],
                                        group_label="Real")
    # Split the synthetic points by source so each one gets its own marker
    synth_df = dfs["synthetic"]
    df_seq = synth_df[synth_df["source"] == "es-digital-seq"]
    df_line = synth_df[synth_df["source"] == "es-digital-line-degradation-seq"]
    # Extract the clusters for each source (if present)
    unique_seq = sorted(df_seq['label'].unique().tolist())
    unique_line = sorted(df_line['label'].unique().tolist())
    seq_renderers = add_dataset_to_fig(fig, df_seq, unique_seq,
                                       marker="square", color_mapping=color_maps["synthetic"],
                                       group_label="es-digital-seq")
    line_renderers = add_dataset_to_fig(fig, df_line, unique_line,
                                        marker="triangle", color_mapping=color_maps["synthetic"],
                                        group_label="es-digital-line-degradation-seq")
    # Merge both synthetic renderer dicts
    synthetic_renderers = {**seq_renderers, **line_renderers}
    fig.legend.location = "top_right"
    fig.legend.click_policy = "hide"
    return fig, real_renderers, synthetic_renderers
# Computes the center of each cluster (per group)
def calculate_cluster_centers(df, labels):
    centers = {}
    for label in labels:
        subset = df[df['label'] == label]
        if not subset.empty:
            centers[label] = (subset['x'].mean(), subset['y'].mean())
    return centers
# Computes the Wasserstein distance from each synthetic subset to each real cluster
# (per synthetic cluster and for the synthetic set as a whole)
def compute_wasserstein_distances_all_synthetics(df_synth, df_real, labels_real):
    distances = {}
    # One row per synthetic cluster (the table will show every label)
    synth_labels = sorted(df_synth['label'].unique().tolist())
    for label in synth_labels:
        key = f"{label}"
        distances[key] = {}
        cluster = df_synth[df_synth['label'] == label][['x', 'y']].values
        n = cluster.shape[0]
        weights = np.ones(n) / n
        for real_label in labels_real:
            cluster_real = df_real[df_real['label'] == real_label][['x', 'y']].values
            m = cluster_real.shape[0]
            weights_real = np.ones(m) / m
            M = ot.dist(cluster, cluster_real, metric='euclidean')
            distances[key][real_label] = ot.emd2(weights, weights_real, M)
    # Distance from the whole synthetic set to each real cluster
    key = "Global synthetic"
    distances[key] = {}
    global_synth = df_synth[['x', 'y']].values
    n_global = global_synth.shape[0]
    weights_global = np.ones(n_global) / n_global
    for real_label in labels_real:
        cluster_real = df_real[df_real['label'] == real_label][['x', 'y']].values
        m = cluster_real.shape[0]
        weights_real = np.ones(m) / m
        M = ot.dist(global_synth, cluster_real, metric='euclidean')
        distances[key][real_label] = ot.emd2(weights_global, weights_real, M)
    return pd.DataFrame(distances).T
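# The function above relies on POT (Python Optimal Transport): ot.dist builds the
# pairwise Euclidean cost matrix between the two point clouds, and ot.emd2 solves
# the discrete optimal-transport problem with uniform weights, returning the earth
# mover's (1-Wasserstein) distance under that cost. A minimal sketch of the same
# call pattern on toy data (hypothetical arrays, not used by the app):
#
#   a = np.random.rand(5, 2)          # 5 synthetic points in 2D
#   b = np.random.rand(7, 2)          # 7 real points in 2D
#   wa, wb = np.ones(5) / 5, np.ones(7) / 7
#   d = ot.emd2(wa, wb, ot.dist(a, b, metric='euclidean'))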
def create_table(df_distances):
    df_table = df_distances.copy()
    df_table.reset_index(inplace=True)
    df_table.rename(columns={'index': 'Synthetic'}, inplace=True)
    min_row = {"Synthetic": "Min."}
    mean_row = {"Synthetic": "Mean"}
    max_row = {"Synthetic": "Max."}
    for col in df_table.columns:
        if col != "Synthetic":
            min_row[col] = df_table[col].min()
            mean_row[col] = df_table[col].mean()
            max_row[col] = df_table[col].max()
    df_table = pd.concat([df_table, pd.DataFrame([min_row, mean_row, max_row])], ignore_index=True)
    source_table = ColumnDataSource(df_table)
    columns = [TableColumn(field='Synthetic', title='Synthetic')]
    for col in df_table.columns:
        if col != 'Synthetic':
            columns.append(TableColumn(field=col, title=col))
    total_height = 30 + len(df_table) * 28
    data_table = DataTable(source=source_table, columns=columns, sizing_mode='stretch_width', height=total_height)
    return data_table, df_table, source_table
def run_model(model_name):
    embeddings = load_embeddings(model_name)
    if embeddings is None:
        return
    embedding_cols = [col for col in embeddings["real"].columns if col.startswith("dim_")]
    # Combine all DataFrames
    df_combined = pd.concat(list(embeddings.values()), ignore_index=True)
    st.markdown('<h6 class="sub-title">Select Dimensionality Reduction Method</h6>', unsafe_allow_html=True)
    reduction_method = st.selectbox("", options=["t-SNE", "PCA"], key=model_name)
    if reduction_method == "PCA":
        reducer = PCA(n_components=2)
    else:
        reducer = TSNE(n_components=2, random_state=42, perplexity=30, learning_rate=200)
    reduced = reducer.fit_transform(df_combined[embedding_cols].values)
    dfs_reduced, unique_subsets = split_versions(df_combined, reduced)
    # unique_subsets is expected to have the keys "real" and "synthetic"
    color_maps = get_color_maps(unique_subsets)
    fig, real_renderers, synthetic_renderers = create_figure(dfs_reduced, unique_subsets, color_maps)
    centers_real = calculate_cluster_centers(dfs_reduced["real"], unique_subsets["real"])
    df_distances = compute_wasserstein_distances_all_synthetics(dfs_reduced["synthetic"],
                                                                dfs_reduced["real"],
                                                                unique_subsets["real"])
    data_table, df_table, source_table = create_table(df_distances)
    real_subset_names = list(df_table.columns[1:])
    real_select = Select(title="", value=real_subset_names[0], options=real_subset_names)
    reset_button = Button(label="Reset Colors", button_type="primary")
    line_source = ColumnDataSource(data={'x': [], 'y': []})
    fig.line('x', 'y', source=line_source, line_width=2, line_color='black')
    # Prepare the centers for the callback (to draw lines between centers)
    real_centers_js = {k: [v[0], v[1]] for k, v in centers_real.items()}
    # Synthetic centers could be prepared the same way if needed
    synthetic_centers = {}
    synth_labels = sorted(dfs_reduced["synthetic"]['label'].unique().tolist())
    for label in synth_labels:
        subset = dfs_reduced["synthetic"][dfs_reduced["synthetic"]['label'] == label]
        synthetic_centers[label] = [subset['x'].mean(), subset['y'].mean()]
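    # Interaction: selecting a row in the distance table plus a real cluster in the
    # Select widget draws a line between that synthetic cluster's center and the
    # chosen real cluster's center; the Reset button clears the line.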
    callback = CustomJS(args=dict(source=source_table, line_source=line_source,
                                  synthetic_centers=synthetic_centers,
                                  real_centers=real_centers_js,
                                  real_select=real_select),
                        code="""
        var selected = source.selected.indices;
        if (selected.length > 0) {
            var idx = selected[0];
            var data = source.data;
            var synth_label = data['Synthetic'][idx];
            var real_label = real_select.value;
            var syn_coords = synthetic_centers[synth_label];
            var real_coords = real_centers[real_label];
            line_source.data = {'x': [syn_coords[0], real_coords[0]], 'y': [syn_coords[1], real_coords[1]]};
            line_source.change.emit();
        } else {
            line_source.data = {'x': [], 'y': []};
            line_source.change.emit();
        }
    """)
    source_table.selected.js_on_change('indices', callback)
    real_select.js_on_change('value', callback)
    reset_callback = CustomJS(args=dict(line_source=line_source),
                              code="""
        line_source.data = {'x': [], 'y': []};
        line_source.change.emit();
    """)
    reset_button.js_on_event("button_click", reset_callback)
    buffer = io.BytesIO()
    df_table.to_excel(buffer, index=False)
    buffer.seek(0)
    layout = column(fig, column(real_select, reset_button, data_table))
    st.bokeh_chart(layout, use_container_width=True)
    st.download_button(
        label="Export Table",
        data=buffer,
        file_name=f"cluster_distances_{model_name}.xlsx",
        mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        key=f"download_button_excel_{model_name}"
    )
def main():
    config_style()
    tabs = st.tabs(["Donut", "Idefics2"])
    with tabs[0]:
        st.markdown('<h2 class="sub-title">Donut 🤗</h2>', unsafe_allow_html=True)
        run_model("Donut")
    with tabs[1]:
        st.markdown('<h2 class="sub-title">Idefics2 🤗</h2>', unsafe_allow_html=True)
        run_model("Idefics2")

if __name__ == "__main__":
    main()