Embeddings / app.py
de-Rodrigo's picture
Multiple Dataset Versions
6b1f66d
raw
history blame
14.9 kB
import streamlit as st
import pandas as pd
import numpy as np
from bokeh.plotting import figure
from bokeh.models import ColumnDataSource, DataTable, TableColumn, CustomJS, Select, Button
from bokeh.layouts import column
from bokeh.palettes import Reds9, Blues9, Oranges9, Purples9
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import io
import ot
# HTML template for the Bokeh hover tooltip passed to `figure(tooltips=...)`.
# `@img` and `@label` are filled from the same-named columns of the hovered
# glyph's data source; `{safe}` tells Bokeh to insert the value unescaped.
TOOLTIPS = """
<div>
<div>
<img src="@img{safe}" style="width:128px; height:auto; float: left; margin: 0px 15px 15px 0px;" alt="@img" border="2"></img>
</div>
<div>
<span style="font-size: 17px; font-weight: bold;">@label</span>
</div>
</div>
"""
def config_style():
    """Inject the page's custom CSS classes and render the main title."""
    stylesheet = """
<style>
.main-title { font-size: 50px; color: #4CAF50; text-align: center; }
.sub-title { font-size: 30px; color: #555; }
.custom-text { font-size: 18px; line-height: 1.5; }
</style>
"""
    title_html = '<h1 class="main-title">Merit Embeddings 🎒📃🏆</h1>'
    # Both snippets are raw HTML, so Streamlit must be told not to escape them.
    st.markdown(stylesheet, unsafe_allow_html=True)
    st.markdown(title_html, unsafe_allow_html=True)
def _read_tagged_csv(path, version, source=None):
    """Read one embeddings CSV and tag it with a ``version`` (and optional ``source``) column."""
    frame = pd.read_csv(path)
    frame["version"] = version
    if source is not None:
        frame["source"] = source
    return frame


def load_embeddings(model):
    """Load the embedding CSVs for *model* and tag every row with its version.

    Args:
        model: Either ``"Donut"`` or ``"Idefics2"``.

    Returns:
        A dict with keys ``"real"`` and ``"synthetic"`` mapping to DataFrames.
        Synthetic frames also carry a ``"source"`` column identifying the
        dataset they came from. For an unrecognised model an error is shown
        through Streamlit and ``None`` is returned.
    """
    if model == "Donut":
        real = _read_tagged_csv(
            "data/donut_de_Rodrigo_merit_secret_all_embeddings.csv", "real"
        )
        seq = _read_tagged_csv(
            "data/donut_de_Rodrigo_merit_es-digital-seq_embeddings.csv",
            "synthetic",
            source="es-digital-seq",
        )
        # NOTE(review): this frame is labelled "es-digital-line-degradation-seq"
        # but is read from the idefics2 es-digital-seq CSV (the same path the
        # Idefics2 branch uses). Looks like a copy/paste slip -- confirm the
        # intended file before relying on the Donut comparison.
        line = _read_tagged_csv(
            "data/idefics2_de_Rodrigo_merit_es-digital-seq_embeddings.csv",
            "synthetic",
            source="es-digital-line-degradation-seq",
        )
        synthetic = pd.concat([seq, line], ignore_index=True)
        return {"real": real, "synthetic": synthetic}

    if model == "Idefics2":
        real = _read_tagged_csv(
            "data/idefics2_de_Rodrigo_merit_secret_britanico_embeddings.csv", "real"
        )
        seq = _read_tagged_csv(
            "data/idefics2_de_Rodrigo_merit_es-digital-seq_embeddings.csv",
            "synthetic",
            source="es-digital-seq",
        )
        return {"real": real, "synthetic": seq}

    st.error("Modelo no reconocido")
    return None
def reducer_selector(df_combined, embedding_cols):
    """Let the user pick a reduction method and project the embeddings to 2-D.

    Args:
        df_combined: DataFrame holding the embedding columns.
        embedding_cols: Names of the columns to feed to the reducer.

    Returns:
        The result of ``fit_transform`` on the selected columns, using PCA or
        t-SNE with two components depending on the select-box choice.
    """
    chosen = st.selectbox(
        "Select Dimensionality Reduction Method:", options=["PCA", "t-SNE"]
    )
    matrix = df_combined[embedding_cols].values
    reducer = (
        PCA(n_components=2)
        if chosen == "PCA"
        else TSNE(n_components=2, random_state=42, perplexity=30, learning_rate=200)
    )
    return reducer.fit_transform(matrix)
# Marker name -> (figure glyph method, point size).
_MARKER_GLYPHS = {
    "circle": ("circle", 10),
    "square": ("square", 10),
    "triangle": ("triangle", 12),
}


def add_dataset_to_fig(fig, df, selected_labels, marker, color_mapping, group_label):
    """Draw one glyph renderer per label of *df* onto *fig*.

    Args:
        fig: Bokeh figure to draw on.
        df: DataFrame with ``label``, ``x`` and ``y`` columns and, optionally,
            an ``img`` column used by the hover tooltip.
        selected_labels: Labels to plot; labels with no rows in *df* are skipped.
        marker: ``"circle"``, ``"square"`` or ``"triangle"``.
        color_mapping: Maps each label to its fill/line colour.
        group_label: Dataset identifier appended to every legend entry.

    Returns:
        Dict mapping ``"<label> (<group_label>)"`` to the created renderer.

    Raises:
        ValueError: If there is something to draw and *marker* is not one of
            the supported names (previously this surfaced as an unbound local).
    """
    renderers = {}
    glyph_spec = _MARKER_GLYPHS.get(marker)
    for label in selected_labels:
        subset = df[df['label'] == label]
        if subset.empty:
            continue
        if glyph_spec is None:
            raise ValueError(
                f"Unsupported marker {marker!r}; expected one of {sorted(_MARKER_GLYPHS)}"
            )
        glyph_method, size = glyph_spec

        # Every data-source column must be a sequence with one entry per point;
        # when there is no 'img' column, pad with empty strings rather than
        # handing over a single scalar "".
        if 'img' in subset.columns:
            images = subset['img']
        else:
            images = [""] * len(subset)

        source = ColumnDataSource(data=dict(
            x=subset['x'],
            y=subset['y'],
            label=subset['label'],
            img=images,
        ))
        color = color_mapping[label]
        # The dataset identifier is part of the legend entry so that the same
        # label coming from different datasets stays distinguishable.
        legend_label = f"{label} ({group_label})"
        draw = getattr(fig, glyph_method)
        renderers[legend_label] = draw(
            'x', 'y', size=size, source=source,
            fill_color=color, line_color=color,
            legend_label=legend_label,
        )
    return renderers
def _take_colors(palette, count):
    """Return *count* colours from a 9-colour *palette*, repeating it when more are needed."""
    if count <= 9:
        return palette[:count]
    return (palette * ((count // 9) + 1))[:count]


def get_color_maps(unique_subsets):
    """Assign a colour to every label of the real and synthetic groups.

    Args:
        unique_subsets: Dict with keys ``"real"`` and ``"synthetic"``, each
            holding the labels of that group.

    Returns:
        Dict with the same two keys; each value maps a label to a colour.
        Real labels draw from ``Reds9`` and synthetic labels from ``Blues9``,
        both assigned in sorted label order. All synthetic sources share the
        single blue palette.
    """
    real_labels = sorted(unique_subsets["real"])
    real_colors = _take_colors(Reds9, len(real_labels))

    synthetic_labels = sorted(unique_subsets["synthetic"])
    synthetic_colors = _take_colors(Blues9, len(synthetic_labels))

    return {
        "real": dict(zip(real_labels, real_colors)),
        "synthetic": dict(zip(synthetic_labels, synthetic_colors)),
    }
def split_versions(df_combined, reduced):
    """Attach the 2-D coordinates and split the rows by their ``version``.

    Note that *df_combined* is modified in place: it gains ``x`` and ``y``
    columns taken from the first and second column of *reduced*.

    Args:
        df_combined: DataFrame with ``version`` and ``label`` columns.
        reduced: Array indexed as ``reduced[:, 0]`` / ``reduced[:, 1]``, with
            one row per row of *df_combined*.

    Returns:
        A pair ``(df_dict, unique_subsets)``, both keyed by ``"real"`` and
        ``"synthetic"``: the first holds an independent copy of each group's
        rows, the second the sorted unique labels found in that group.
    """
    df_combined['x'], df_combined['y'] = reduced[:, 0], reduced[:, 1]

    df_dict = {}
    unique_subsets = {}
    for version in ("real", "synthetic"):
        group = df_combined[df_combined["version"] == version].copy()
        df_dict[version] = group
        # The clusters ("subsets") are identified by the 'label' column.
        unique_subsets[version] = sorted(group['label'].unique().tolist())
    return df_dict, unique_subsets
def create_figure(dfs, unique_subsets, color_maps):
    """Build the scatter figure with real and synthetic points.

    Real points are drawn as circles. Synthetic points are split by their
    ``source`` column: ``"es-digital-seq"`` rows become squares and
    ``"es-digital-line-degradation-seq"`` rows become triangles; both use the
    ``"synthetic"`` colour map.

    Args:
        dfs: Dict with ``"real"`` and ``"synthetic"`` DataFrames.
        unique_subsets: Dict whose ``"real"`` entry lists the real labels.
        color_maps: Dict with ``"real"`` and ``"synthetic"`` label->colour maps.

    Returns:
        ``(fig, real_renderers, synthetic_renderers)`` where the renderer
        dicts are those returned by ``add_dataset_to_fig``.
    """
    fig = figure(width=400, height=400, tooltips=TOOLTIPS, title="")

    real_renderers = add_dataset_to_fig(
        fig, dfs["real"], unique_subsets["real"],
        marker="circle", color_mapping=color_maps["real"],
        group_label="Real",
    )

    # One marker shape per synthetic source so they can be told apart.
    marker_by_source = (
        ("es-digital-seq", "square"),
        ("es-digital-line-degradation-seq", "triangle"),
    )
    synth_df = dfs["synthetic"]
    synthetic_renderers = {}
    for source_name, marker in marker_by_source:
        source_df = synth_df[synth_df["source"] == source_name]
        source_labels = sorted(source_df['label'].unique().tolist())
        synthetic_renderers.update(
            add_dataset_to_fig(
                fig, source_df, source_labels,
                marker=marker, color_mapping=color_maps["synthetic"],
                group_label=source_name,
            )
        )

    fig.legend.location = "top_right"
    # Clicking a legend entry toggles the visibility of its glyphs.
    fig.legend.click_policy = "hide"
    return fig, real_renderers, synthetic_renderers
def calculate_cluster_centers(df, labels):
    """Return the mean ``(x, y)`` position of every requested cluster.

    Args:
        df: DataFrame with ``label``, ``x`` and ``y`` columns.
        labels: Labels whose centre should be computed.

    Returns:
        Dict mapping each label that has at least one row in *df* to an
        ``(x_mean, y_mean)`` tuple. Labels without rows are left out.
    """
    centers = {}
    for name in labels:
        points = df[df['label'] == name]
        if points.empty:
            continue
        centers[name] = (points['x'].mean(), points['y'].mean())
    return centers
def _uniform_emd(points, target):
    """Return ``ot.emd2`` between two point sets with uniform weights (NaN if either is empty)."""
    n, m = points.shape[0], target.shape[0]
    if n == 0 or m == 0:
        # No transport problem can be posed against an empty point set.
        return float("nan")
    cost = ot.dist(points, target, metric='euclidean')
    return ot.emd2(np.ones(n) / n, np.ones(m) / m, cost)


def compute_wasserstein_distances_all_synthetics(df_synth, df_real, labels_real):
    """Compute the distance from every synthetic cluster to every real cluster.

    Distances are obtained with ``ot.emd2`` on a Euclidean cost matrix over
    the ``x``/``y`` coordinates, giving each point of a cluster equal weight.

    Args:
        df_synth: Synthetic points with ``label``, ``x`` and ``y`` columns.
        df_real: Real points with ``label``, ``x`` and ``y`` columns.
        labels_real: Real cluster labels to measure against.

    Returns:
        DataFrame with one row per synthetic label (as a string, in sorted
        order) followed by a ``"Global synthetic"`` row that treats all
        synthetic points as a single cluster, and one column per real label.
        A cell is NaN when either side of the comparison has no points.
    """
    # The real clusters do not depend on the synthetic label, so extract them
    # once instead of re-filtering df_real for every synthetic cluster.
    real_clusters = {
        real_label: df_real[df_real['label'] == real_label][['x', 'y']].values
        for real_label in labels_real
    }

    def distances_to_real(points):
        return {
            real_label: _uniform_emd(points, cluster_real)
            for real_label, cluster_real in real_clusters.items()
        }

    distances = {}
    for label in sorted(df_synth['label'].unique().tolist()):
        cluster = df_synth[df_synth['label'] == label][['x', 'y']].values
        distances[f"{label}"] = distances_to_real(cluster)

    # Distance of the whole synthetic set to each real cluster.
    distances["Global synthetic"] = distances_to_real(df_synth[['x', 'y']].values)
    return pd.DataFrame(distances).T
def create_table(df_distances):
    """Turn the distance matrix into a Bokeh ``DataTable`` with summary rows.

    The index of *df_distances* becomes a ``"Synthetic"`` column, and three
    rows labelled ``"Min."``, ``"Mean"`` and ``"Max."`` are appended holding
    the per-column minimum, mean and maximum of the existing rows.

    Args:
        df_distances: DataFrame of distances (rows: synthetic, columns: real).

    Returns:
        ``(data_table, df_table, source_table)``: the Bokeh widget, the
        DataFrame it displays, and the ``ColumnDataSource`` backing it.
    """
    df_table = df_distances.reset_index().rename(columns={'index': 'Synthetic'})
    value_cols = [col for col in df_table.columns if col != "Synthetic"]

    summary_rows = [
        {"Synthetic": "Min.", **{col: df_table[col].min() for col in value_cols}},
        {"Synthetic": "Mean", **{col: df_table[col].mean() for col in value_cols}},
        {"Synthetic": "Max.", **{col: df_table[col].max() for col in value_cols}},
    ]
    df_table = pd.concat([df_table, pd.DataFrame(summary_rows)], ignore_index=True)

    source_table = ColumnDataSource(df_table)
    columns = [TableColumn(field='Synthetic', title='Synthetic')]
    columns.extend(TableColumn(field=col, title=col) for col in value_cols)

    # Size the widget from the row count so every row is visible.
    total_height = 30 + len(df_table) * 28
    data_table = DataTable(
        source=source_table,
        columns=columns,
        sizing_mode='stretch_width',
        height=total_height,
    )
    return data_table, df_table, source_table
def run_model(model_name):
    """Render the complete embedding view for one model.

    Loads the model's embeddings, projects them to 2-D with the method chosen
    in a select box, draws the scatter plot, computes the synthetic-to-real
    distance table and wires up the interactive widgets: selecting a table
    row draws a line between that synthetic cluster's centre and the centre
    of the real cluster chosen in the drop-down. The table can be downloaded
    as an Excel file.

    Args:
        model_name: Model identifier passed to ``load_embeddings``; also used
            to build unique Streamlit widget keys and the export file name.

    Returns:
        ``None``. Returns early when the embeddings cannot be loaded.
    """
    embeddings = load_embeddings(model_name)
    if embeddings is None:
        return

    embedding_cols = [col for col in embeddings["real"].columns if col.startswith("dim_")]
    # Stack every loaded frame so all points share a single 2-D projection.
    df_combined = pd.concat(list(embeddings.values()), ignore_index=True)

    st.markdown('<h6 class="sub-title">Select Dimensionality Reduction Method</h6>', unsafe_allow_html=True)
    # The widget key is the model name so each tab gets its own select box.
    reduction_method = st.selectbox("", options=["t-SNE", "PCA"], key=model_name)
    if reduction_method == "PCA":
        reducer = PCA(n_components=2)
    else:
        reducer = TSNE(n_components=2, random_state=42, perplexity=30, learning_rate=200)
    reduced = reducer.fit_transform(df_combined[embedding_cols].values)

    # unique_subsets has the keys "real" and "synthetic".
    dfs_reduced, unique_subsets = split_versions(df_combined, reduced)
    color_maps = get_color_maps(unique_subsets)
    fig, _, _ = create_figure(dfs_reduced, unique_subsets, color_maps)

    df_distances = compute_wasserstein_distances_all_synthetics(
        dfs_reduced["synthetic"], dfs_reduced["real"], unique_subsets["real"]
    )
    data_table, df_table, source_table = create_table(df_distances)

    # The first table column is "Synthetic"; the rest are the real clusters.
    real_subset_names = list(df_table.columns[1:])
    real_select = Select(title="", value=real_subset_names[0], options=real_subset_names)
    reset_button = Button(label="Reset Colors", button_type="primary")

    line_source = ColumnDataSource(data={'x': [], 'y': []})
    fig.line('x', 'y', source=line_source, line_width=2, line_color='black')

    # Cluster centres handed to the JS callback as [x, y] lists; both groups
    # go through the same helper.
    real_centers_js = {
        label: [center[0], center[1]]
        for label, center in calculate_cluster_centers(
            dfs_reduced["real"], unique_subsets["real"]
        ).items()
    }
    synthetic_centers = {
        label: [center[0], center[1]]
        for label, center in calculate_cluster_centers(
            dfs_reduced["synthetic"], unique_subsets["synthetic"]
        ).items()
    }

    callback = CustomJS(
        args=dict(
            source=source_table,
            line_source=line_source,
            synthetic_centers=synthetic_centers,
            real_centers=real_centers_js,
            real_select=real_select,
        ),
        code="""
        // Dict arguments may arrive as a plain object or as a Map depending
        // on the Bokeh version; support both.
        function lookup(table, key) {
            return (table instanceof Map) ? table.get(key) : table[key];
        }
        var xs = [];
        var ys = [];
        var selected = source.selected.indices;
        if (selected.length > 0) {
            var synth_label = source.data['Synthetic'][selected[0]];
            var syn_coords = lookup(synthetic_centers, synth_label);
            var real_coords = lookup(real_centers, real_select.value);
            // Rows such as "Global synthetic", "Min.", "Mean" and "Max."
            // have no cluster centre: clear the line instead of failing.
            if (syn_coords !== undefined && real_coords !== undefined) {
                xs = [syn_coords[0], real_coords[0]];
                ys = [syn_coords[1], real_coords[1]];
            }
        }
        line_source.data = {'x': xs, 'y': ys};
        line_source.change.emit();
        """,
    )
    source_table.selected.js_on_change('indices', callback)
    real_select.js_on_change('value', callback)

    reset_callback = CustomJS(
        args=dict(line_source=line_source),
        code="""
        line_source.data = {'x': [], 'y': []};
        line_source.change.emit();
        """,
    )
    reset_button.js_on_event("button_click", reset_callback)

    # Serialise the table to an in-memory workbook for the download button.
    buffer = io.BytesIO()
    df_table.to_excel(buffer, index=False)
    buffer.seek(0)

    layout = column(fig, column(real_select, reset_button, data_table))
    st.bokeh_chart(layout, use_container_width=True)

    st.download_button(
        label="Export Table",
        data=buffer,
        file_name=f"cluster_distances_{model_name}.xlsx",
        mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        key=f"download_button_excel_{model_name}",
    )
def main():
    """Render the page: global styling first, then one tab per model."""
    config_style()
    # (model name, heading HTML) for each tab, in display order.
    sections = (
        ("Donut", '<h2 class="sub-title">Donut 🤗</h2>'),
        ("Idefics2", '<h2 class="sub-title">Idefics2 🤗</h2>'),
    )
    tabs = st.tabs([model_name for model_name, _ in sections])
    for tab, (model_name, heading_html) in zip(tabs, sections):
        with tab:
            st.markdown(heading_html, unsafe_allow_html=True)
            run_model(model_name)


if __name__ == "__main__":
    main()