"""Streamlit app to explore real vs. synthetic ("es-digital-seq") document embeddings.

For each model the embeddings of both versions are projected jointly to 2-D
(PCA or t-SNE), plotted with Bokeh, and the distances between the per-label
cluster centers are shown in an interactive table that can be exported to Excel.
"""
import io

import numpy as np
import pandas as pd
import streamlit as st
from bokeh.layouts import column, row
from bokeh.models import Button, ColumnDataSource, CustomJS, DataTable, Select, TableColumn
from bokeh.palettes import Blues9, Reds9
from bokeh.plotting import figure
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

# Hover template: Bokeh fills @img and @label from each point's data source.
# NOTE(review): this looks like an HTML template whose markup (e.g. an <img>
# tag around @img) was stripped; as written, @img renders as plain text.
TOOLTIPS = """
@img
@label
"""

# CSV files holding the embeddings of each supported model, keyed by version.
_EMBEDDING_FILES = {
    "Donut": {
        "real": "data/donut_de_Rodrigo_merit_secret_all_embeddings.csv",
        "es-digital-seq": "data/donut_de_Rodrigo_merit_es-digital-seq_embeddings.csv",
    },
    "Idefics2": {
        "real": "data/idefics2_de_Rodrigo_merit_secret_britanico_embeddings.csv",
        "es-digital-seq": "data/idefics2_de_Rodrigo_merit_es-digital-seq_embeddings.csv",
    },
}

# Default t-SNE perplexity; clamped for small datasets in _build_reducer.
_TSNE_PERPLEXITY = 30

# Glyph size and legend suffix for each marker accepted by add_dataset_to_fig.
_MARKER_STYLES = {"circle": (10, "Real"), "square": (6, "Synthetic")}


def config_style():
    """Inject page-level styling and render the app title."""
    # NOTE(review): this style block is empty; the CSS it presumably carried
    # appears to have been lost. Restore it if custom styling is needed.
    st.markdown(""" """, unsafe_allow_html=True)
    st.markdown("Merit Embeddings 🎒📃🏆", unsafe_allow_html=True)


def load_embeddings(model):
    """Load the real and synthetic embedding tables for ``model``.

    Args:
        model: model name, ``"Donut"`` or ``"Idefics2"``.

    Returns:
        dict with keys ``"real"`` and ``"es-digital-seq"`` mapping to the
        DataFrames read from CSV, or ``None`` (after showing a Streamlit
        error) when ``model`` is not recognised.
    """
    paths = _EMBEDDING_FILES.get(model)
    if paths is None:
        st.error("Modelo no reconocido")
        return None
    return {version: pd.read_csv(path) for version, path in paths.items()}


def _build_reducer(method, n_samples):
    """Return an unfitted 2-component reducer: PCA for ``"PCA"``, t-SNE otherwise."""
    if method == "PCA":
        return PCA(n_components=2)
    # scikit-learn's t-SNE requires perplexity < n_samples; clamp it so small
    # datasets still render instead of raising.
    perplexity = min(_TSNE_PERPLEXITY, max(n_samples - 1, 1))
    return TSNE(n_components=2, random_state=42, perplexity=perplexity, learning_rate=200)


def reducer_selector(df_combined, embedding_cols):
    """Let the user choose PCA or t-SNE and return the 2-D projection.

    Args:
        df_combined: DataFrame containing the embedding columns.
        embedding_cols: names of the columns to reduce.

    Returns:
        The array returned by the reducer's ``fit_transform``.
    """
    reduction_method = st.selectbox("Select Dimensionality Reduction Method:", options=["PCA", "t-SNE"])
    all_embeddings = df_combined[embedding_cols].values
    return _build_reducer(reduction_method, len(all_embeddings)).fit_transform(all_embeddings)


def add_dataset_to_fig(fig, df, selected_labels, marker, color_mapping):
    """Draw one scatter renderer per label of ``selected_labels``.

    Args:
        fig: Bokeh figure to draw on.
        df: DataFrame with ``x``, ``y``, ``label`` and ``img`` columns.
        selected_labels: labels to draw; labels with no rows are skipped.
        marker: ``"circle"`` (real data) or ``"square"`` (synthetic data).
        color_mapping: label -> color.

    Returns:
        dict mapping each drawn label to its glyph renderer.

    Raises:
        ValueError: if ``marker`` is not a supported marker.
    """
    if marker not in _MARKER_STYLES:
        raise ValueError(f"Unsupported marker: {marker!r}")
    size, legend_suffix = _MARKER_STYLES[marker]

    renderers = {}
    for label in selected_labels:
        subset = df[df['label'] == label]
        if subset.empty:
            continue
        source = ColumnDataSource(data=dict(
            x=subset['x'],
            y=subset['y'],
            label=subset['label'],
            img=subset['img'],
        ))
        color = color_mapping[label]
        # scatter(marker=...) replaces circle()/square(), which recent Bokeh
        # releases deprecate for size-based markers.
        renderers[label] = fig.scatter(
            'x', 'y', marker=marker, size=size, source=source,
            fill_color=color, line_color=color,
            legend_label=f"{label} ({legend_suffix})",
        )
    return renderers


def _cycle_palette(palette, labels):
    """Map the sorted ``labels`` to palette colors, wrapping around when labels outnumber colors."""
    return {label: palette[i % len(palette)] for i, label in enumerate(sorted(labels))}


def get_color_maps(selected_subsets: dict):
    """Build label -> color maps: reds for real subsets, blues for synthetic ones."""
    return {
        "real": _cycle_palette(Reds9, selected_subsets["real"]),
        "es-digital-seq": _cycle_palette(Blues9, selected_subsets["es-digital-seq"]),
    }


def split_versions(df_combined, reduced):
    """Attach 2-D coordinates to ``df_combined`` and split it back by version.

    ``df_combined`` is modified in place (``x``/``y`` columns are added).

    Returns:
        A pair of dicts keyed by ``"real"`` / ``"es-digital-seq"``: the
        per-version DataFrames, and their sorted unique labels.
    """
    df_combined['x'] = reduced[:, 0]
    df_combined['y'] = reduced[:, 1]
    df_real = df_combined[df_combined["version"] == "real"].copy()
    df_es = df_combined[df_combined["version"] == "es_digital_seq"].copy()
    unique_real = sorted(df_real['label'].unique().tolist())
    unique_es = sorted(df_es['label'].unique().tolist())
    return (
        {"real": df_real, "es-digital-seq": df_es},
        {"real": unique_real, "es-digital-seq": unique_es},
    )


def create_figure(dfs_reduced, selected_subsets: dict, color_maps: dict):
    """Create the scatter figure: real data as circles, synthetic data as squares.

    Returns:
        ``(fig, real_renderers, synthetic_renderers)``; the renderer dicts map
        label -> renderer.
    """
    fig = figure(width=400, height=400, tooltips=TOOLTIPS, title="")
    real_renderers = add_dataset_to_fig(fig, dfs_reduced["real"], selected_subsets["real"],
                                        marker="circle", color_mapping=color_maps["real"])
    synthetic_renderers = add_dataset_to_fig(fig, dfs_reduced["es-digital-seq"],
                                             selected_subsets["es-digital-seq"],
                                             marker="square", color_mapping=color_maps["es-digital-seq"])
    fig.legend.location = "top_right"
    # Clicking a legend entry hides that label's points.
    fig.legend.click_policy = "hide"
    return fig, real_renderers, synthetic_renderers


def calculate_cluster_centers(df: pd.DataFrame, selected_labels: list) -> dict:
    """Return ``{label: (mean_x, mean_y)}`` for every label that has rows in ``df``."""
    centers = {}
    for label in selected_labels:
        subset = df[df['label'] == label]
        if not subset.empty:
            centers[label] = (subset['x'].mean(), subset['y'].mean())
    return centers


def compute_distances(centers_es: dict, centers_real: dict) -> pd.DataFrame:
    """Euclidean distance between every synthetic center and every real center.

    Returns:
        DataFrame with one row per synthetic label and one column per real label.
    """
    distances = {
        es_label: {
            real_label: np.sqrt((x_es - x_real) ** 2 + (y_es - y_real) ** 2)
            for real_label, (x_real, y_real) in centers_real.items()
        }
        for es_label, (x_es, y_es) in centers_es.items()
    }
    return pd.DataFrame(distances).T


def create_table(df_distances):
    """Wrap the distance matrix in a Bokeh DataTable.

    Returns:
        ``(data_table, df_table, source_table)`` where ``df_table`` is
        ``df_distances`` with its index moved into a leading ``Synthetic``
        column, and ``source_table`` is the ColumnDataSource behind the table.
    """
    df_table = df_distances.copy()
    df_table.reset_index(inplace=True)
    df_table.rename(columns={'index': 'Synthetic'}, inplace=True)
    source_table = ColumnDataSource(df_table)

    columns = [TableColumn(field='Synthetic', title='Synthetic')]
    columns += [TableColumn(field=col, title=col) for col in df_table.columns if col != 'Synthetic']

    # Height computed from assumed per-row / header pixel heights so that all
    # rows are intended to be visible at once.
    row_height = 28
    header_height = 30
    total_height = header_height + len(df_table) * row_height
    data_table = DataTable(source=source_table, columns=columns,
                           sizing_mode='stretch_width', height=total_height)
    return data_table, df_table, source_table


# JS helper shared by both callbacks: give each renderer its own color when
# its key equals `keep` (or always, when `keep` is null) and grey it otherwise.
_PAINT_JS = """
function paint(renderers, colors, keep) {
    for (var key in renderers) {
        if (renderers.hasOwnProperty(key)) {
            var color = (keep === null || key === keep) ? colors[key] : "lightgray";
            renderers[key].glyph.fill_color = color;
            renderers[key].glyph.line_color = color;
        }
    }
}
"""

# Highlight the selected synthetic row and the chosen real subset, and draw a
# segment between their centers; otherwise restore all colors.
_HIGHLIGHT_JS = _PAINT_JS + """
var selected = source.selected.indices;
var synthetic_label = null;
var real_label = null;
var syn_coords = null;
var real_coords = null;
if (selected.length > 0) {
    synthetic_label = source.data['Synthetic'][selected[0]];
    real_label = real_select.value;
    syn_coords = synthetic_centers[synthetic_label];
    real_coords = real_centers[real_label];
}
if (syn_coords && real_coords) {
    line_source.data = {
        'x': [syn_coords[0], real_coords[0]],
        'y': [syn_coords[1], real_coords[1]]
    };
    line_source.change.emit();
    paint(synthetic_renderers, synthetic_colors, synthetic_label);
    paint(real_renderers, real_colors, real_label);
} else {
    line_source.data = { 'x': [], 'y': [] };
    line_source.change.emit();
    paint(synthetic_renderers, synthetic_colors, null);
    paint(real_renderers, real_colors, null);
}
"""

# Reset: drop the table selection (so the same row can be re-selected later),
# remove the segment and restore every renderer's color.
_RESET_JS = _PAINT_JS + """
source.selected.indices = [];
line_source.data = { 'x': [], 'y': [] };
line_source.change.emit();
paint(synthetic_renderers, synthetic_colors, null);
paint(real_renderers, real_colors, null);
"""


def _render_excel_download(df_table, model_name):
    """Offer ``df_table`` as an .xlsx download button for ``model_name``'s tab."""
    buffer = io.BytesIO()
    df_table.to_excel(buffer, index=False)
    buffer.seek(0)
    st.download_button(
        label="Exportar tabla a Excel",
        data=buffer,
        file_name="tabla.xlsx",
        mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        # One button is rendered per tab; a per-model key keeps their widget IDs distinct.
        key=f"download_{model_name}",
    )


def run_model(model_name):
    """Render the full explorer (plot, distance table, export) for one model."""
    embeddings = load_embeddings(model_name)
    if embeddings is None:
        return

    # Tag each row with its version so split_versions can separate the rows
    # again after the joint reduction.
    embeddings["real"]["version"] = "real"
    embeddings["es-digital-seq"]["version"] = "es_digital_seq"
    embedding_cols = [col for col in embeddings["real"].columns if col.startswith("dim_")]
    df_combined = pd.concat([embeddings["real"], embeddings["es-digital-seq"]], ignore_index=True)

    st.markdown("Select Dimensionality Reduction Method", unsafe_allow_html=True)
    reduction_method = st.selectbox(
        "Select Dimensionality Reduction Method",
        options=["t-SNE", "PCA"],
        key=model_name,
        # The heading above already labels the widget; a non-empty but hidden
        # label avoids Streamlit's empty-label warning.
        label_visibility="collapsed",
    )
    # Both versions are reduced together so they share a single 2-D space.
    reducer = _build_reducer(reduction_method, len(df_combined))
    reduced = reducer.fit_transform(df_combined[embedding_cols].values)

    dfs_reduced, unique_subsets = split_versions(df_combined, reduced)
    selected_subsets = {"real": unique_subsets["real"], "es-digital-seq": unique_subsets["es-digital-seq"]}
    color_maps = get_color_maps(selected_subsets)
    fig, real_renderers, synthetic_renderers = create_figure(dfs_reduced, selected_subsets, color_maps)

    centers_real = calculate_cluster_centers(dfs_reduced["real"], selected_subsets["real"])
    centers_es = calculate_cluster_centers(dfs_reduced["es-digital-seq"], selected_subsets["es-digital-seq"])
    df_distances = compute_distances(centers_es, centers_real)
    data_table, df_table, source_table = create_table(df_distances)

    real_subset_names = list(df_table.columns[1:])
    if not real_subset_names:
        # Without both real and synthetic centers there is nothing to compare;
        # show the plot only instead of failing on an empty selector.
        st.warning("No real/synthetic subsets available to compare.")
        st.bokeh_chart(fig, use_container_width=True)
        return

    real_select = Select(title="", value=real_subset_names[0], options=real_subset_names)
    reset_button = Button(label="Reset Colors", button_type="primary")
    # Starts empty; the JS callback fills it with the segment joining the
    # selected synthetic center and the chosen real center.
    line_source = ColumnDataSource(data={'x': [], 'y': []})
    fig.line('x', 'y', source=line_source, line_width=2, line_color='black')

    # Plain [x, y] float lists so the centers serialize cleanly to JS.
    synthetic_centers_js = {k: [float(v[0]), float(v[1])] for k, v in centers_es.items()}
    real_centers_js = {k: [float(v[0]), float(v[1])] for k, v in centers_real.items()}

    paint_args = dict(
        source=source_table,
        line_source=line_source,
        synthetic_renderers=synthetic_renderers,
        real_renderers=real_renderers,
        synthetic_colors=color_maps["es-digital-seq"],
        real_colors=color_maps["real"],
    )
    callback = CustomJS(
        args=dict(paint_args, synthetic_centers=synthetic_centers_js,
                  real_centers=real_centers_js, real_select=real_select),
        code=_HIGHLIGHT_JS,
    )
    source_table.selected.js_on_change('indices', callback)
    real_select.js_on_change('value', callback)

    reset_callback = CustomJS(args=paint_args, code=_RESET_JS)
    reset_button.js_on_event("button_click", reset_callback)

    _render_excel_download(df_table, model_name)

    layout = column(fig, column(real_select, reset_button, data_table))
    st.bokeh_chart(layout, use_container_width=True)


def main():
    """Entry point: render the title and one tab per model."""
    config_style()
    model_names = ["Donut", "Idefics2"]
    for tab, model_name in zip(st.tabs(model_names), model_names):
        with tab:
            # NOTE(review): the original heading ended with an emoji that was
            # lost to a text-encoding error; restore it if desired.
            st.markdown(f"Modelo {model_name}", unsafe_allow_html=True)
            run_model(model_name)


if __name__ == "__main__":
    main()