import io

import numpy as np
import ot
import pandas as pd
import streamlit as st
from bokeh.layouts import column
from bokeh.models import Button, ColumnDataSource, CustomJS, DataTable, Select, TableColumn
from bokeh.palettes import Blues9, Oranges9, Purples9, Reds9
from bokeh.plotting import figure
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

# Bokeh hover template: "@img" and "@label" are substituted with the values of
# the "img" and "label" columns of the hovered point's ColumnDataSource.
TOOLTIPS = """
@img
@label
"""

# Glyph size for each supported marker shape (triangles are drawn a bit larger).
_MARKER_SIZES = {"circle": 10, "square": 10, "triangle": 12}

# Synthetic sources drawn in the plot, in drawing order, with their marker shape.
_SYNTHETIC_SOURCE_MARKERS = (
    ("es-digital-seq", "square"),
    ("es-digital-line-degradation-seq", "triangle"),
)


def config_style():
    """Inject page-level styling and render the app title."""
    # NOTE(review): both payloads look like their HTML markup (presumably a
    # <style> block and a styled heading) was stripped from this copy of the
    # file; only the visible text is preserved here. Restore it if available.
    st.markdown(""" """, unsafe_allow_html=True)
    st.markdown("\n\nMerit Embeddings 🎒📃🏆\n\n", unsafe_allow_html=True)


def load_embeddings(model):
    """Load the real and synthetic embedding tables for *model*.

    Every frame gets a ``version`` column ("real" / "synthetic"); synthetic
    frames also get a ``source`` column naming the synthetic dataset, which
    ``create_figure`` uses to choose a marker.

    Returns:
        ``{"real": DataFrame, "synthetic": DataFrame}``, or ``None`` (after
        showing an error in the app) when *model* is not recognised.
    """
    if model == "Donut":
        df_real = pd.read_csv("data/donut_de_Rodrigo_merit_secret_all_embeddings.csv")
        df_seq = pd.read_csv("data/donut_de_Rodrigo_merit_es-digital-seq_embeddings.csv")
        # NOTE(review): this reads the *Idefics2* es-digital-seq embeddings but
        # labels them as Donut's "es-digital-line-degradation-seq" source. It
        # looks like a copy-paste slip (a Donut line-degradation file was
        # probably intended). If the two models' embeddings have different
        # dim_* columns, concatenation leaves NaNs in them. Confirm the path.
        df_line = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-digital-seq_embeddings.csv")
        df_real["version"] = "real"
        df_seq["version"] = "synthetic"
        df_line["version"] = "synthetic"
        # The 'source' column tells the two synthetic datasets apart.
        df_seq["source"] = "es-digital-seq"
        df_line["source"] = "es-digital-line-degradation-seq"
        return {"real": df_real, "synthetic": pd.concat([df_seq, df_line], ignore_index=True)}
    elif model == "Idefics2":
        df_real = pd.read_csv("data/idefics2_de_Rodrigo_merit_secret_britanico_embeddings.csv")
        df_seq = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-digital-seq_embeddings.csv")
        df_real["version"] = "real"
        df_seq["version"] = "synthetic"
        df_seq["source"] = "es-digital-seq"
        return {"real": df_real, "synthetic": df_seq}
    else:
        st.error("Modelo no reconocido")
        return None


def _build_reducer(method, n_samples):
    """Return an unfitted 2-D reducer: PCA for "PCA", otherwise t-SNE.

    scikit-learn requires t-SNE's perplexity to be strictly below the number
    of samples, so the default of 30 is clamped for small datasets. For more
    than 30 samples the result is identical to a fixed perplexity of 30.
    """
    if method == "PCA":
        return PCA(n_components=2)
    perplexity = min(30, max(n_samples - 1, 1))
    return TSNE(n_components=2, random_state=42, perplexity=perplexity, learning_rate=200)


def reducer_selector(df_combined, embedding_cols):
    """Ask the user for a reduction method and project the embeddings to 2-D.

    Returns:
        Array of shape (n_rows, 2) with the projected coordinates.
    """
    reduction_method = st.selectbox("Select Dimensionality Reduction Method:", options=["PCA", "t-SNE"])
    all_embeddings = df_combined[embedding_cols].values
    return _build_reducer(reduction_method, len(all_embeddings)).fit_transform(all_embeddings)


def add_dataset_to_fig(fig, df, selected_labels, marker, color_mapping, group_label):
    """Draw one scatter glyph per label of *df* onto *fig*.

    Args:
        fig: Bokeh figure to draw on.
        df: Frame with 'x', 'y', 'label' and optionally 'img' columns.
        selected_labels: Labels to draw; labels with no rows are skipped.
        marker: One of "circle", "square", "triangle".
        color_mapping: label -> colour.
        group_label: Suffix appended to each legend entry, e.g. "Real".

    Returns:
        Dict mapping legend label ("<label> (<group_label>)") to its renderer.

    Raises:
        ValueError: if *marker* is not a supported shape.
    """
    if marker not in _MARKER_SIZES:
        raise ValueError(f"Unsupported marker {marker!r}; expected one of {sorted(_MARKER_SIZES)}")
    size = _MARKER_SIZES[marker]

    renderers = {}
    for label in selected_labels:
        subset = df[df["label"] == label]
        if subset.empty:
            continue
        # Every ColumnDataSource column must have the same length, so a missing
        # 'img' column becomes a list of empty strings rather than a scalar "".
        img = subset["img"] if "img" in subset.columns else [""] * len(subset)
        source = ColumnDataSource(data=dict(x=subset["x"], y=subset["y"], label=subset["label"], img=img))
        color = color_mapping[label]
        # The legend entry carries the group/source so clusters with the same
        # label in different datasets stay distinguishable.
        legend_label = f"{label} ({group_label})"
        renderers[legend_label] = fig.scatter(
            "x", "y", marker=marker, size=size, source=source,
            fill_color=color, line_color=color, legend_label=legend_label,
        )
    return renderers


def _cycle_palette(palette, n):
    """Return *n* colours from *palette*, repeating it when n exceeds its length."""
    colors = list(palette)
    return (colors * (n // len(colors) + 1))[:n]


def get_color_maps(unique_subsets):
    """Assign colours to cluster labels: reds for real data, blues for synthetic.

    All synthetic sources share the blue palette. Each palette is cycled if
    there are more labels than colours.

    Returns:
        ``{"real": {label: colour}, "synthetic": {label: colour}}``.
    """
    real_labels = sorted(unique_subsets["real"])
    synthetic_labels = sorted(unique_subsets["synthetic"])
    red_palette = _cycle_palette(Reds9, len(real_labels))
    blue_palette = _cycle_palette(Blues9, len(synthetic_labels))
    return {
        "real": dict(zip(real_labels, red_palette)),
        "synthetic": dict(zip(synthetic_labels, blue_palette)),
    }


def split_versions(df_combined, reduced):
    """Attach 2-D coordinates and split the data into real and synthetic parts.

    Side effect: adds/overwrites 'x' and 'y' columns on *df_combined* in place.

    Returns:
        (frames, labels): ``{"real": df, "synthetic": df}`` copies, and the
        sorted unique 'label' values of each part under the same keys.
    """
    df_combined["x"] = reduced[:, 0]
    df_combined["y"] = reduced[:, 1]
    df_real = df_combined[df_combined["version"] == "real"].copy()
    df_synth = df_combined[df_combined["version"] == "synthetic"].copy()
    unique_real = sorted(df_real["label"].unique().tolist())
    unique_synth = sorted(df_synth["label"].unique().tolist())
    df_dict = {"real": df_real, "synthetic": df_synth}
    unique_subsets = {"real": unique_real, "synthetic": unique_synth}
    return df_dict, unique_subsets


def create_figure(dfs, unique_subsets, color_maps):
    """Build the scatter plot.

    Real points are drawn as circles. Each synthetic source gets its own
    marker (see ``_SYNTHETIC_SOURCE_MARKERS``).

    Returns:
        (figure, real_renderers, synthetic_renderers); the renderer dicts map
        legend label to glyph renderer.
    """
    fig = figure(width=400, height=400, tooltips=TOOLTIPS, title="")
    real_renderers = add_dataset_to_fig(fig, dfs["real"], unique_subsets["real"], marker="circle",
                                        color_mapping=color_maps["real"], group_label="Real")

    synth_df = dfs["synthetic"]
    synthetic_renderers = {}
    for source_name, marker in _SYNTHETIC_SOURCE_MARKERS:
        df_source = synth_df[synth_df["source"] == source_name]
        source_labels = sorted(df_source["label"].unique().tolist())
        synthetic_renderers.update(add_dataset_to_fig(fig, df_source, source_labels, marker=marker,
                                                      color_mapping=color_maps["synthetic"],
                                                      group_label=source_name))

    fig.legend.location = "top_right"
    fig.legend.click_policy = "hide"
    return fig, real_renderers, synthetic_renderers


def calculate_cluster_centers(df, labels):
    """Return ``{label: (mean_x, mean_y)}`` for each label with at least one row in *df*."""
    centers = {}
    for label in labels:
        subset = df[df["label"] == label]
        if not subset.empty:
            centers[label] = (subset["x"].mean(), subset["y"].mean())
    return centers


def compute_wasserstein_distances_all_synthetics(df_synth, df_real, labels_real):
    """Compute Wasserstein distances between synthetic and real clusters.

    Distances are exact optimal-transport costs (``ot.emd2``) with a Euclidean
    ground cost and uniform weights. They are measured in the 2-D reduced
    space (the 'x'/'y' columns), not in the original embedding space.

    Returns:
        DataFrame with one row per synthetic label (as a string) plus a final
        "Global synthetic" row (all synthetic points together), and one column
        per real label.
    """
    # Real clusters and their uniform weights are loop-invariant; build them once.
    real_clusters = {}
    for real_label in labels_real:
        points = df_real[df_real["label"] == real_label][["x", "y"]].values
        real_clusters[real_label] = (points, np.ones(points.shape[0]) / points.shape[0])

    def _distances_to_real(points):
        """OT cost from *points* (uniformly weighted) to every real cluster."""
        weights = np.ones(points.shape[0]) / points.shape[0]
        return {
            real_label: ot.emd2(weights, real_weights, ot.dist(points, real_points, metric="euclidean"))
            for real_label, (real_points, real_weights) in real_clusters.items()
        }

    distances = {}
    for label in sorted(df_synth["label"].unique().tolist()):
        cluster = df_synth[df_synth["label"] == label][["x", "y"]].values
        distances[f"{label}"] = _distances_to_real(cluster)
    distances["Global synthetic"] = _distances_to_real(df_synth[["x", "y"]].values)
    return pd.DataFrame(distances).T


def create_table(df_distances):
    """Build a Bokeh DataTable from the distance matrix with Min./Mean/Max. rows.

    Returns:
        (data_table, df_table, source_table): the Bokeh widget, the flat
        DataFrame it displays (first column "Synthetic"), and its source.
    """
    df_table = df_distances.reset_index().rename(columns={"index": "Synthetic"})
    value_cols = [col for col in df_table.columns if col != "Synthetic"]
    values = df_table[value_cols]
    # NOTE(review): the summary statistics cover every row above, including the
    # "Global synthetic" aggregate row. Confirm that is intended.
    summary = pd.DataFrame([
        {"Synthetic": "Min.", **values.min().to_dict()},
        {"Synthetic": "Mean", **values.mean().to_dict()},
        {"Synthetic": "Max.", **values.max().to_dict()},
    ])
    df_table = pd.concat([df_table, summary], ignore_index=True)

    source_table = ColumnDataSource(df_table)
    columns = [TableColumn(field="Synthetic", title="Synthetic")]
    columns += [TableColumn(field=col, title=col) for col in value_cols]
    # Height grows with the row count (30 px base + 28 px per row) so every row shows.
    total_height = 30 + len(df_table) * 28
    data_table = DataTable(source=source_table, columns=columns, sizing_mode="stretch_width", height=total_height)
    return data_table, df_table, source_table


# Draws a line from the centre of the selected table row's synthetic cluster to
# the centre of the real cluster chosen in the Select. Rows without a centre
# ("Global synthetic", "Min.", "Mean", "Max.") clear the line instead of
# throwing on undefined coordinates.
_CENTER_LINE_JS = """
var selected = source.selected.indices;
var syn_coords = null;
var real_coords = null;
if (selected.length > 0) {
    var synth_label = source.data['Synthetic'][selected[0]];
    syn_coords = synthetic_centers[synth_label];
    real_coords = real_centers[real_select.value];
}
if (syn_coords && real_coords) {
    line_source.data = {'x': [syn_coords[0], real_coords[0]], 'y': [syn_coords[1], real_coords[1]]};
} else {
    line_source.data = {'x': [], 'y': []};
}
line_source.change.emit();
"""

# Clearing the table selection as well as the line means re-clicking the same
# row afterwards fires the 'indices' change again and redraws the line.
_RESET_LINE_JS = """
source.selected.indices = [];
line_source.data = {'x': [], 'y': []};
line_source.change.emit();
"""


def _centers_for_js(centers):
    """Convert ``{label: (x, y)}`` to JSON-friendly ``{str(label): [x, y]}``."""
    return {str(label): [float(x), float(y)] for label, (x, y) in centers.items()}


def _wire_center_line(fig, source_table, real_select, reset_button, real_centers, synthetic_centers):
    """Add the centre-to-centre line to *fig* and hook up its JS callbacks."""
    line_source = ColumnDataSource(data={"x": [], "y": []})
    fig.line("x", "y", source=line_source, line_width=2, line_color="black")

    callback = CustomJS(
        args=dict(source=source_table, line_source=line_source, synthetic_centers=synthetic_centers,
                  real_centers=real_centers, real_select=real_select),
        code=_CENTER_LINE_JS,
    )
    source_table.selected.js_on_change("indices", callback)
    real_select.js_on_change("value", callback)

    reset_callback = CustomJS(args=dict(source=source_table, line_source=line_source), code=_RESET_LINE_JS)
    reset_button.js_on_event("button_click", reset_callback)


def _table_to_excel(df_table):
    """Serialise *df_table* to an in-memory .xlsx file, rewound for reading."""
    buffer = io.BytesIO()
    df_table.to_excel(buffer, index=False)
    buffer.seek(0)
    return buffer


def run_model(model_name):
    """Render one model's tab: reduced-embedding plot, distance table and export button."""
    embeddings = load_embeddings(model_name)
    if embeddings is None:
        return

    # Embedding columns are taken from the real table.
    # NOTE(review): assumes the synthetic tables have the same dim_* columns.
    embedding_cols = [col for col in embeddings["real"].columns if col.startswith("dim_")]
    df_combined = pd.concat(list(embeddings.values()), ignore_index=True)

    # NOTE(review): HTML markup around this heading appears to have been stripped.
    st.markdown("\nSelect Dimensionality Reduction Method\n", unsafe_allow_html=True)
    reduction_method = st.selectbox("", options=["t-SNE", "PCA"], key=model_name)
    all_embeddings = df_combined[embedding_cols].values
    reduced = _build_reducer(reduction_method, len(all_embeddings)).fit_transform(all_embeddings)

    dfs_reduced, unique_subsets = split_versions(df_combined, reduced)
    color_maps = get_color_maps(unique_subsets)
    fig, _real_renderers, _synthetic_renderers = create_figure(dfs_reduced, unique_subsets, color_maps)

    df_distances = compute_wasserstein_distances_all_synthetics(
        dfs_reduced["synthetic"], dfs_reduced["real"], unique_subsets["real"])
    data_table, df_table, source_table = create_table(df_distances)

    real_subset_names = list(df_table.columns[1:])
    real_select = Select(title="", value=real_subset_names[0], options=real_subset_names)
    reset_button = Button(label="Reset Colors", button_type="primary")

    # Centres are keyed by str(label) to match the table's "Synthetic" column,
    # which compute_wasserstein_distances_all_synthetics fills with f"{label}".
    synth_df = dfs_reduced["synthetic"]
    real_centers = calculate_cluster_centers(dfs_reduced["real"], unique_subsets["real"])
    synthetic_centers = calculate_cluster_centers(synth_df, sorted(synth_df["label"].unique().tolist()))
    _wire_center_line(fig, source_table, real_select, reset_button,
                      _centers_for_js(real_centers), _centers_for_js(synthetic_centers))

    buffer = _table_to_excel(df_table)

    layout = column(fig, column(real_select, reset_button, data_table))
    st.bokeh_chart(layout, use_container_width=True)
    st.download_button(
        label="Export Table",
        data=buffer,
        file_name=f"cluster_distances_{model_name}.xlsx",
        mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        key=f"download_button_excel_{model_name}",
    )


def main():
    """Entry point: apply styling and render one tab per model."""
    config_style()
    tabs = st.tabs(["Donut", "Idefics2"])
    with tabs[0]:
        st.markdown("\n\nDonut 🤗\n\n", unsafe_allow_html=True)
        run_model("Donut")
    with tabs[1]:
        st.markdown("\n\nIdefics2 🤗\n\n", unsafe_allow_html=True)
        run_model("Idefics2")


if __name__ == "__main__":
    main()