Spaces:
Running
Running
import streamlit as st | |
import pandas as pd | |
import numpy as np | |
from bokeh.plotting import figure | |
from bokeh.models import ColumnDataSource, DataTable, TableColumn, CustomJS, Select, Button | |
from bokeh.layouts import row, column | |
from bokeh.palettes import Reds9, Blues9 | |
from sklearn.decomposition import PCA | |
from sklearn.manifold import TSNE | |
import io | |
# HTML template for the figure's hover tooltip (passed as ``tooltips=`` in
# create_figure). The ``@img`` and ``@label`` placeholders refer to the "img"
# and "label" columns that add_dataset_to_fig puts in every ColumnDataSource.
TOOLTIPS = """
<div>
<div>
<img src="@img{safe}" style="width:128px; height:auto; float: left; margin: 0px 15px 15px 0px;" alt="@img" border="2"></img>
</div>
<div>
<span style="font-size: 17px; font-weight: bold;">@label</span>
</div>
</div>
"""
def config_style():
    """Inject the page's custom CSS classes and render the main title.

    Both calls pass ``unsafe_allow_html=True`` so the raw ``<style>`` and
    ``<h1>`` markup is rendered instead of being escaped. The CSS defines the
    ``main-title``, ``sub-title`` and ``custom-text`` classes used by the
    headings emitted elsewhere in this file.
    """
    st.markdown("""
<style>
.main-title { font-size: 50px; color: #4CAF50; text-align: center; }
.sub-title { font-size: 30px; color: #555; }
.custom-text { font-size: 18px; line-height: 1.5; }
</style>
""", unsafe_allow_html=True)
    st.markdown('<h1 class="main-title">Merit Embeddings 🎒📃🏆</h1>', unsafe_allow_html=True)
# load_embeddings takes the name of the model whose embeddings should be loaded
def load_embeddings(model):
    """Read the real and synthetic embedding CSVs for ``model``.

    Args:
        model: Either ``"Donut"`` or ``"Idefics2"``.

    Returns:
        A dict with the keys ``"real"`` and ``"es-digital-seq"``, each holding
        the DataFrame read from the corresponding CSV, or ``None`` (after
        showing a Streamlit error) when ``model`` is not one of the two names.
    """
    # First pick the pair of files, then read them in one place.
    if model == "Donut":
        real_csv = "data/donut_de_Rodrigo_merit_secret_all_embeddings.csv"
        synthetic_csv = "data/donut_de_Rodrigo_merit_es-digital-seq_embeddings.csv"
    elif model == "Idefics2":
        real_csv = "data/idefics2_de_Rodrigo_merit_secret_britanico_embeddings.csv"
        synthetic_csv = "data/idefics2_de_Rodrigo_merit_es-digital-seq_embeddings.csv"
    else:
        st.error("Modelo no reconocido")
        return None

    real_frame = pd.read_csv(real_csv)
    synthetic_frame = pd.read_csv(synthetic_csv)
    return {"real": real_frame, "es-digital-seq": synthetic_frame}
# Helper functions (identical to the ones in your code)
def reducer_selector(df_combined, embedding_cols, key=None):
    """Let the user pick a reduction method and project the embeddings to 2-D.

    Args:
        df_combined: DataFrame containing the embedding columns.
        embedding_cols: Names of the columns that form the embedding vectors.
        key: Optional Streamlit widget key. Pass a distinct value per call
            when this helper is rendered more than once on the same page, so
            the select boxes do not share a widget identity.

    Returns:
        The array returned by the chosen reducer's ``fit_transform`` (two
        components per row of ``df_combined``).
    """
    reduction_method = st.selectbox(
        "Select Dimensionality Reduction Method:",
        options=["PCA", "t-SNE"],
        key=key,
    )
    all_embeddings = df_combined[embedding_cols].values
    if reduction_method == "PCA":
        reducer = PCA(n_components=2)
    else:
        # scikit-learn requires perplexity < n_samples; keep the usual 30 but
        # shrink it for small inputs instead of failing inside fit_transform.
        perplexity = max(1, min(30, len(all_embeddings) - 1))
        reducer = TSNE(n_components=2, random_state=42, perplexity=perplexity, learning_rate=200)
    return reducer.fit_transform(all_embeddings)
def add_dataset_to_fig(fig, df, selected_labels, marker, color_mapping):
    """Draw one scatter renderer per label of ``df`` on ``fig``.

    Args:
        fig: Bokeh figure to draw on.
        df: DataFrame with at least the columns ``label``, ``x``, ``y`` and
            ``img``.
        selected_labels: Labels to draw; labels with no rows in ``df`` are
            skipped.
        marker: ``"circle"`` (size 10, legend suffix "(Real)") or ``"square"``
            (size 6, legend suffix "(Synthetic)").
        color_mapping: Maps each label to its fill/line colour.

    Returns:
        Dict mapping every drawn label to the renderer created for it.

    Raises:
        ValueError: If ``marker`` is not one of the supported names.
    """
    # figure method name, glyph size and legend suffix for each marker kind.
    glyph_styles = {
        "circle": ("circle", 10, "Real"),
        "square": ("square", 6, "Synthetic"),
    }
    if marker not in glyph_styles:
        # Fail fast: previously an unknown marker left the renderer variable
        # unbound (or silently reused the one from the previous label).
        raise ValueError(f"Unsupported marker {marker!r}; expected 'circle' or 'square'")
    method_name, size, suffix = glyph_styles[marker]

    renderers = {}
    for label in selected_labels:
        subset = df[df['label'] == label]
        if subset.empty:
            continue
        source = ColumnDataSource(data=dict(
            x=subset['x'],
            y=subset['y'],
            label=subset['label'],
            img=subset['img'],
        ))
        color = color_mapping[label]
        renderers[label] = getattr(fig, method_name)(
            'x', 'y', size=size, source=source,
            fill_color=color, line_color=color,
            legend_label=f"{label} ({suffix})",
        )
    return renderers
def get_color_maps(selected_subsets: dict):
    """Assign a colour to every selected label, reds for real and blues for synthetic.

    Labels are sorted and then given consecutive entries of the palette
    (``Reds9`` for ``"real"``, ``Blues9`` for ``"es-digital-seq"``); once the
    palette is exhausted the colours repeat from its start.

    Returns:
        ``{"real": {label: colour}, "es-digital-seq": {label: colour}}``.
    """
    def assign(labels, palette):
        # Modulo indexing wraps around when there are more labels than colours.
        return {
            label: palette[position % len(palette)]
            for position, label in enumerate(sorted(labels))
        }

    real_colors = assign(selected_subsets["real"], Reds9)
    synthetic_colors = assign(selected_subsets["es-digital-seq"], Blues9)
    return {"real": real_colors, "es-digital-seq": synthetic_colors}
def split_versions(df_combined, reduced):
    """Attach the 2-D coordinates and split the rows back into real/synthetic.

    Note that ``df_combined`` is modified in place: it gains the columns ``x``
    and ``y`` taken from the first two columns of ``reduced``.

    Returns:
        A pair ``(frames, labels)``. ``frames`` maps ``"real"`` and
        ``"es-digital-seq"`` to copies of the rows whose ``version`` column is
        ``"real"`` / ``"es_digital_seq"``; ``labels`` maps the same keys to the
        sorted list of distinct ``label`` values in each frame.
    """
    for axis, column in enumerate(("x", "y")):
        df_combined[column] = reduced[:, axis]

    frames = {}
    labels = {}
    # Result key (hyphenated) -> value stored in the "version" column (underscored).
    for key, version in (("real", "real"), ("es-digital-seq", "es_digital_seq")):
        frame = df_combined[df_combined["version"] == version].copy()
        distinct = frame['label'].unique().tolist()
        distinct.sort()
        frames[key] = frame
        labels[key] = distinct
    return frames, labels
def create_figure(dfs_reduced, selected_subsets: dict, color_maps: dict):
    """Build the 400x400 scatter figure with real (circles) and synthetic (squares) points.

    Returns:
        ``(fig, real_renderers, synthetic_renderers)`` where the two dicts are
        the label-to-renderer maps produced by ``add_dataset_to_fig``.
    """
    fig = figure(width=400, height=400, tooltips=TOOLTIPS, title="")

    # Real data first, then synthetic, so the legend entries keep that order.
    renderers_by_version = {}
    for version, marker in (("real", "circle"), ("es-digital-seq", "square")):
        renderers_by_version[version] = add_dataset_to_fig(
            fig,
            dfs_reduced[version],
            selected_subsets[version],
            marker=marker,
            color_mapping=color_maps[version],
        )

    fig.legend.location = "top_right"
    # Clicking a legend entry toggles the visibility of its glyphs.
    fig.legend.click_policy = "hide"
    return fig, renderers_by_version["real"], renderers_by_version["es-digital-seq"]
def calculate_cluster_centers(df: pd.DataFrame, selected_labels: list) -> dict:
    """Return the mean ``(x, y)`` position of every selected label.

    Labels that have no rows in ``df`` are left out of the result.
    """
    centers = {}
    for label in selected_labels:
        rows = df['label'] == label
        if not rows.any():
            continue
        centers[label] = (df.loc[rows, 'x'].mean(), df.loc[rows, 'y'].mean())
    return centers
def compute_distances(centers_es: dict, centers_real: dict) -> pd.DataFrame:
    """Tabulate the Euclidean distance between every synthetic and real centre.

    Returns:
        DataFrame with one row per key of ``centers_es`` and one column per
        key of ``centers_real``.
    """
    by_synthetic = {
        es_label: {
            real_label: np.sqrt((x_es - x_real)**2 + (y_es - y_real)**2)
            for real_label, (x_real, y_real) in centers_real.items()
        }
        for es_label, (x_es, y_es) in centers_es.items()
    }
    # The outer keys become columns, so transpose to get synthetic labels as rows.
    return pd.DataFrame(by_synthetic).T
def create_table(df_distances):
    """Wrap the distance matrix in a Bokeh ``DataTable``.

    The index of ``df_distances`` is moved into a leading ``Synthetic`` column;
    every other column gets a table column titled with its own name.

    Returns:
        ``(data_table, df_table, source_table)``: the widget, the flattened
        DataFrame behind it and the ``ColumnDataSource`` that feeds the widget.
    """
    df_table = df_distances.reset_index().rename(columns={'index': 'Synthetic'})
    source_table = ColumnDataSource(df_table)

    other_fields = [name for name in df_table.columns if name != 'Synthetic']
    columns = [TableColumn(field=name, title=name) for name in ['Synthetic', *other_fields]]

    # Size the widget to its content: a 30 px header plus 28 px per row.
    total_height = 30 + 28 * len(df_table)
    data_table = DataTable(
        source=source_table,
        columns=columns,
        sizing_mode='stretch_width',
        height=total_height,
    )
    return data_table, df_table, source_table
# Runs the whole pipeline for a given model
def run_model(model_name):
    """Render the embedding explorer for ``model_name`` in the current Streamlit container.

    Steps: load the real and synthetic embeddings, reduce the ``dim_*`` columns
    to 2-D with the method picked in a select box (t-SNE or PCA), plot both
    datasets, compute the distances between synthetic and real cluster centres,
    and show them in a table that is wired to the plot through JavaScript
    callbacks. A download button exports the distance table to Excel.

    Returns early (rendering nothing further) when ``load_embeddings`` does not
    recognise the model.

    Args:
        model_name: Model to display. It also serves as the Streamlit key of
            the select box and as part of the download button's key, so
            several models can be rendered on the same page.
    """
    embeddings = load_embeddings(model_name)
    if embeddings is None:
        return

    # Tag each frame with its origin so split_versions can separate them again.
    embeddings["real"]["version"] = "real"
    embeddings["es-digital-seq"]["version"] = "es_digital_seq"
    embedding_cols = [col for col in embeddings["real"].columns if col.startswith("dim_")]
    df_combined = pd.concat([embeddings["real"], embeddings["es-digital-seq"]], ignore_index=True)

    st.markdown('<h6 class="sub-title">Select Dimensionality Reduction Method</h6>', unsafe_allow_html=True)
    reduction_method = st.selectbox("", options=["t-SNE", "PCA"], key=model_name)
    if reduction_method == "PCA":
        reducer = PCA(n_components=2)
    else:
        # scikit-learn requires perplexity < n_samples; keep the usual 30 but
        # shrink it for small inputs instead of failing inside fit_transform.
        perplexity = max(1, min(30, len(df_combined) - 1))
        reducer = TSNE(n_components=2, random_state=42, perplexity=perplexity, learning_rate=200)
    reduced = reducer.fit_transform(df_combined[embedding_cols].values)

    dfs_reduced, unique_subsets = split_versions(df_combined, reduced)
    # Every label found in the data is selected.
    selected_subsets = {"real": unique_subsets["real"], "es-digital-seq": unique_subsets["es-digital-seq"]}
    color_maps = get_color_maps(selected_subsets)
    fig, real_renderers, synthetic_renderers = create_figure(dfs_reduced, selected_subsets, color_maps)

    centers_real = calculate_cluster_centers(dfs_reduced["real"], selected_subsets["real"])
    centers_es = calculate_cluster_centers(dfs_reduced["es-digital-seq"], selected_subsets["es-digital-seq"])
    df_distances = compute_distances(centers_es, centers_real)
    data_table, df_table, source_table = create_table(df_distances)

    # The first column of df_table is "Synthetic"; the rest are the real labels.
    real_subset_names = list(df_table.columns[1:])
    real_select = Select(title="", value=real_subset_names[0], options=real_subset_names)
    reset_button = Button(label="Reset Colors", button_type="primary")

    # Initially empty line; the callback fills it with the segment joining the
    # selected synthetic centre and the chosen real centre.
    line_source = ColumnDataSource(data={'x': [], 'y': []})
    fig.line('x', 'y', source=line_source, line_width=2, line_color='black')

    # Plain lists so the centres can be handed to the JavaScript callbacks.
    synthetic_centers_js = {k: [v[0], v[1]] for k, v in centers_es.items()}
    real_centers_js = {k: [v[0], v[1]] for k, v in centers_real.items()}

    # Callback that updates the plot: with a table row selected it draws the
    # connecting line and greys out every other cluster; with no selection it
    # clears the line and restores the original colours.
    callback = CustomJS(args=dict(source=source_table, line_source=line_source,
                                  synthetic_centers=synthetic_centers_js,
                                  real_centers=real_centers_js,
                                  synthetic_renderers=synthetic_renderers,
                                  real_renderers=real_renderers,
                                  synthetic_colors=color_maps["es-digital-seq"],
                                  real_colors=color_maps["real"],
                                  real_select=real_select),
                        code="""
        var selected = source.selected.indices;
        if (selected.length > 0) {
            var row = selected[0];
            var data = source.data;
            var synthetic_label = data['Synthetic'][row];
            var real_label = real_select.value;
            var syn_coords = synthetic_centers[synthetic_label];
            var real_coords = real_centers[real_label];
            line_source.data = { 'x': [syn_coords[0], real_coords[0]], 'y': [syn_coords[1], real_coords[1]] };
            line_source.change.emit();
            for (var key in synthetic_renderers) {
                if (synthetic_renderers.hasOwnProperty(key)) {
                    var renderer = synthetic_renderers[key];
                    if (key === synthetic_label) {
                        renderer.glyph.fill_color = synthetic_colors[key];
                        renderer.glyph.line_color = synthetic_colors[key];
                    } else {
                        renderer.glyph.fill_color = "lightgray";
                        renderer.glyph.line_color = "lightgray";
                    }
                }
            }
            for (var key in real_renderers) {
                if (real_renderers.hasOwnProperty(key)) {
                    var renderer = real_renderers[key];
                    if (key === real_label) {
                        renderer.glyph.fill_color = real_colors[key];
                        renderer.glyph.line_color = real_colors[key];
                    } else {
                        renderer.glyph.fill_color = "lightgray";
                        renderer.glyph.line_color = "lightgray";
                    }
                }
            }
        } else {
            line_source.data = { 'x': [], 'y': [] };
            line_source.change.emit();
            for (var key in synthetic_renderers) {
                if (synthetic_renderers.hasOwnProperty(key)) {
                    var renderer = synthetic_renderers[key];
                    renderer.glyph.fill_color = synthetic_colors[key];
                    renderer.glyph.line_color = synthetic_colors[key];
                }
            }
            for (var key in real_renderers) {
                if (real_renderers.hasOwnProperty(key)) {
                    var renderer = real_renderers[key];
                    renderer.glyph.fill_color = real_colors[key];
                    renderer.glyph.line_color = real_colors[key];
                }
            }
        }
        """)
    # Re-run the callback when the row selection or the chosen real label changes.
    source_table.selected.js_on_change('indices', callback)
    real_select.js_on_change('value', callback)

    # The reset button clears the line and restores every original colour.
    reset_callback = CustomJS(args=dict(line_source=line_source,
                                        synthetic_renderers=synthetic_renderers,
                                        real_renderers=real_renderers,
                                        synthetic_colors=color_maps["es-digital-seq"],
                                        real_colors=color_maps["real"]),
                              code="""
        line_source.data = { 'x': [], 'y': [] };
        line_source.change.emit();
        for (var key in synthetic_renderers) {
            if (synthetic_renderers.hasOwnProperty(key)) {
                var renderer = synthetic_renderers[key];
                renderer.glyph.fill_color = synthetic_colors[key];
                renderer.glyph.line_color = synthetic_colors[key];
            }
        }
        for (var key in real_renderers) {
            if (real_renderers.hasOwnProperty(key)) {
                var renderer = real_renderers[key];
                renderer.glyph.fill_color = real_colors[key];
                renderer.glyph.line_color = real_colors[key];
            }
        }
        """)
    reset_button.js_on_event("button_click", reset_callback)

    # Serialise the distance table to an in-memory Excel file.
    buffer = io.BytesIO()
    df_table.to_excel(buffer, index=False)
    buffer.seek(0)

    # Streamlit download button. This function runs once per tab with the same
    # label, file name and MIME type, so an explicit per-model key keeps the
    # buttons from colliding on their element ID.
    st.download_button(
        label="Exportar tabla a Excel",
        data=buffer,
        file_name="tabla.xlsx",
        mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        key=f"download_{model_name}",
    )

    layout = column(fig, column(real_select, reset_button, data_table))
    st.bokeh_chart(layout, use_container_width=True)
# Main entry point, with tabs to switch between models
def main():
    """Render the page: global styling first, then one tab per model."""
    config_style()
    model_names = ["Donut", "Idefics2"]
    for tab, model_name in zip(st.tabs(model_names), model_names):
        with tab:
            # NOTE(review): the glyph after the model name looks like a
            # mis-encoded emoji; confirm the intended character.
            st.markdown(f'<h2 class="sub-title">Modelo {model_name} 馃</h2>', unsafe_allow_html=True)
            run_model(model_name)
# Build the page only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    main()