Embeddings / app.py
de-Rodrigo's picture
Downloadable Table
af68571
raw
history blame
13.7 kB
import streamlit as st
import pandas as pd
import numpy as np
from bokeh.plotting import figure
from bokeh.models import ColumnDataSource, DataTable, TableColumn, CustomJS, Select, Button
from bokeh.layouts import row, column
from bokeh.palettes import Reds9, Blues9
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import io
# HTML template for the Bokeh hover tooltip. Bokeh replaces "@img" and
# "@label" with the hovered point's values from the ColumnDataSource columns
# of the same name; the "{safe}" formatter inserts the image URL unescaped.
TOOLTIPS = (
    "\n"
    "<div>\n"
    "<div>\n"
    '<img src="@img{safe}" style="width:128px; height:auto; float: left; margin: 0px 15px 15px 0px;" alt="@img" border="2"></img>\n'
    "</div>\n"
    "<div>\n"
    '<span style="font-size: 17px; font-weight: bold;">@label</span>\n'
    "</div>\n"
    "</div>\n"
)
def config_style():
    """Inject the page CSS and render the app's main title.

    Defines the ``main-title``, ``sub-title`` and ``custom-text`` CSS classes,
    then renders the main title as raw HTML.
    """
    css = (
        "\n"
        "<style>\n"
        ".main-title { font-size: 50px; color: #4CAF50; text-align: center; }\n"
        ".sub-title { font-size: 30px; color: #555; }\n"
        ".custom-text { font-size: 18px; line-height: 1.5; }\n"
        "</style>\n"
    )
    # NOTE(review): the characters after "Merit Embeddings" appear to be emoji
    # that were mis-decoded (UTF-8 read as a CJK code page); confirm the
    # intended text before fixing it.
    title = '<h1 class="main-title">Merit Embeddings 馃帓馃搩馃弳</h1>'
    for snippet in (css, title):
        st.markdown(snippet, unsafe_allow_html=True)
# load_embeddings takes the name of the model whose embeddings should be loaded.
def load_embeddings(model):
    """Read the real and synthetic ("es-digital-seq") embedding CSVs for ``model``.

    Supported models are "Donut" and "Idefics2"; the CSV paths are relative to
    the working directory (under ``data/``).

    Returns ``{"real": DataFrame, "es-digital-seq": DataFrame}``, or ``None``
    after showing a Streamlit error when ``model`` is not recognised.
    """
    # (model name, real-data CSV, synthetic es-digital-seq CSV)
    # NOTE(review): Donut's real set is "secret_all" while Idefics2's is
    # "secret_britanico" -- confirm the two models are meant to use
    # different real data.
    sources = (
        ("Donut",
         "data/donut_de_Rodrigo_merit_secret_all_embeddings.csv",
         "data/donut_de_Rodrigo_merit_es-digital-seq_embeddings.csv"),
        ("Idefics2",
         "data/idefics2_de_Rodrigo_merit_secret_britanico_embeddings.csv",
         "data/idefics2_de_Rodrigo_merit_es-digital-seq_embeddings.csv"),
    )
    for name, real_csv, synthetic_csv in sources:
        if model == name:
            frames = {"real": pd.read_csv(real_csv)}
            frames["es-digital-seq"] = pd.read_csv(synthetic_csv)
            return frames
    st.error("Modelo no reconocido")
    return None
# Helper functions (identical to the ones in your code).
def reducer_selector(df_combined, embedding_cols, *, key=None):
    """Let the user pick PCA or t-SNE and project the embeddings to 2-D.

    Parameters
    ----------
    df_combined : DataFrame holding the embedding columns.
    embedding_cols : names of the columns to reduce.
    key : optional Streamlit widget key. Pass a distinct value (e.g. the model
        name) when this selector is rendered more than once on a page;
        Streamlit rejects identical un-keyed widgets as duplicates.

    Returns
    -------
    The ``fit_transform`` output of the chosen reducer (2 components per row).
    """
    reduction_method = st.selectbox(
        "Select Dimensionality Reduction Method:",
        options=["PCA", "t-SNE"],
        key=key,
    )
    all_embeddings = df_combined[embedding_cols].values
    if reduction_method == "PCA":
        reducer = PCA(n_components=2)
    else:
        reducer = TSNE(n_components=2, random_state=42, perplexity=30, learning_rate=200)
    return reducer.fit_transform(all_embeddings)
def add_dataset_to_fig(fig, df, selected_labels, marker, color_mapping):
    """Draw one scatter renderer per label of ``df`` onto ``fig``.

    Parameters
    ----------
    fig : Bokeh figure to draw on.
    df : DataFrame with ``label``, ``x``, ``y`` and ``img`` columns.
    selected_labels : labels to plot; labels with no rows in ``df`` are skipped.
    marker : "circle" (size 10, legend suffix "Real") or
        "square" (size 6, legend suffix "Synthetic").
    color_mapping : label -> colour used for both fill and outline.

    Returns
    -------
    dict mapping each plotted label to its glyph renderer.

    Raises
    ------
    ValueError
        If ``marker`` is neither "circle" nor "square".
    """
    # marker -> (figure glyph method, glyph size, legend suffix)
    marker_styles = {
        "circle": ("circle", 10, "Real"),
        "square": ("square", 6, "Synthetic"),
    }
    if marker not in marker_styles:
        # Fail fast with a clear message instead of an UnboundLocalError on
        # the renderer variable once the first label is reached.
        raise ValueError(f"Unsupported marker {marker!r}; expected 'circle' or 'square'")
    method_name, size, suffix = marker_styles[marker]

    renderers = {}
    for label in selected_labels:
        subset = df[df['label'] == label]
        if subset.empty:
            continue
        # Column names must match the @-fields used in TOOLTIPS.
        source = ColumnDataSource(data=dict(
            x=subset['x'],
            y=subset['y'],
            label=subset['label'],
            img=subset['img'],
        ))
        color = color_mapping[label]
        draw = getattr(fig, method_name)
        renderers[label] = draw('x', 'y', size=size, source=source,
                                fill_color=color, line_color=color,
                                legend_label=f"{label} ({suffix})")
    return renderers
def get_color_maps(selected_subsets: dict):
    """Assign a colour to each label of the "real" and "es-digital-seq" groups.

    Real labels take colours from ``Reds9`` and synthetic labels from
    ``Blues9``. Labels are coloured in sorted order, and the palette is
    cycled when a group has more labels than the palette has colours.

    Returns ``{"real": {label: colour}, "es-digital-seq": {label: colour}}``.
    """
    def _cycle_palette(labels, palette):
        return {
            lbl: palette[idx % len(palette)]
            for idx, lbl in enumerate(sorted(labels))
        }

    return {
        "real": _cycle_palette(selected_subsets["real"], Reds9),
        "es-digital-seq": _cycle_palette(selected_subsets["es-digital-seq"], Blues9),
    }
def split_versions(df_combined, reduced):
    """Attach the 2-D coordinates and split the rows by dataset version.

    Parameters
    ----------
    df_combined : DataFrame with ``version`` and ``label`` columns, in the
        same row order as ``reduced``.
    reduced : array whose columns 0 and 1 become the ``x`` and ``y`` columns.

    Returns
    -------
    (frames, labels)
        ``frames`` maps "real" / "es-digital-seq" to copies of the rows whose
        ``version`` is "real" / "es_digital_seq", with ``x``/``y`` added.
        ``labels`` maps the same keys to the sorted unique labels of each
        subset.

    ``df_combined`` itself is not modified.
    """
    # assign() returns a new frame, so the caller's DataFrame is not mutated.
    with_coords = df_combined.assign(x=reduced[:, 0], y=reduced[:, 1])
    frames = {
        "real": with_coords[with_coords["version"] == "real"].copy(),
        "es-digital-seq": with_coords[with_coords["version"] == "es_digital_seq"].copy(),
    }
    labels = {
        key: sorted(frame["label"].unique().tolist())
        for key, frame in frames.items()
    }
    return frames, labels
def create_figure(dfs_reduced, selected_subsets: dict, color_maps: dict):
    """Build the scatter plot of real (circles) and synthetic (squares) points.

    Returns ``(fig, real_renderers, synthetic_renderers)``. The two renderer
    dicts map each plotted label to its glyph renderer, so callbacks can
    recolour them. Clicking a legend entry hides that label's points.
    """
    plot = figure(width=400, height=400, tooltips=TOOLTIPS, title="")
    renderers_by_group = {}
    for group, marker in (("real", "circle"), ("es-digital-seq", "square")):
        renderers_by_group[group] = add_dataset_to_fig(
            plot,
            dfs_reduced[group],
            selected_subsets[group],
            marker=marker,
            color_mapping=color_maps[group],
        )
    plot.legend.location = "top_right"
    plot.legend.click_policy = "hide"
    return plot, renderers_by_group["real"], renderers_by_group["es-digital-seq"]
def calculate_cluster_centers(df: pd.DataFrame, selected_labels: list) -> dict:
    """Return the centroid of each requested label in the 2-D projection.

    Maps every label in ``selected_labels`` that has at least one row in
    ``df`` to ``(mean x, mean y)``. Labels with no rows are omitted, and keys
    follow the order of ``selected_labels``.
    """
    def _centroid(points):
        return points['x'].mean(), points['y'].mean()

    members_by_label = ((lbl, df[df['label'] == lbl]) for lbl in selected_labels)
    return {lbl: _centroid(pts) for lbl, pts in members_by_label if not pts.empty}
def compute_distances(centers_es: dict, centers_real: dict) -> pd.DataFrame:
    """Euclidean distance between every synthetic and every real cluster centre.

    Both arguments map a label to an ``(x, y)`` centre. Returns a DataFrame
    indexed by synthetic label, with one column per real label.
    """
    nested = {
        es_label: {
            real_label: np.sqrt((x_es - x_real)**2 + (y_es - y_real)**2)
            for real_label, (x_real, y_real) in centers_real.items()
        }
        for es_label, (x_es, y_es) in centers_es.items()
    }
    # DataFrame(dict of dicts) puts the outer keys in the columns; transpose
    # so that synthetic labels become the rows.
    return pd.DataFrame(nested).T
def create_table(df_distances):
    """Turn the synthetic-vs-real distance matrix into a Bokeh DataTable.

    The index of ``df_distances`` becomes a leading "Synthetic" column; every
    other column is shown under its own name. The widget height is sized to
    fit all rows: a 30 px header plus 28 px per row.

    Returns ``(data_table, df_table, source_table)``: the widget, the
    flattened DataFrame behind it, and that frame's ColumnDataSource.
    """
    df_table = df_distances.reset_index().rename(columns={'index': 'Synthetic'})
    source_table = ColumnDataSource(df_table)

    columns = [TableColumn(field='Synthetic', title='Synthetic')]
    columns += [
        TableColumn(field=name, title=name)
        for name in df_table.columns
        if name != 'Synthetic'
    ]

    header_px, row_px = 30, 28
    data_table = DataTable(
        source=source_table,
        columns=columns,
        sizing_mode='stretch_width',
        height=header_px + row_px * len(df_table),
    )
    return data_table, df_table, source_table
# Function that runs the whole pipeline for a given model.

# JavaScript shared by the highlight and reset callbacks: clear the connector
# line and give every renderer back its palette colour.
_CLEAR_HIGHLIGHT_JS = """
line_source.data = { 'x': [], 'y': [] };
line_source.change.emit();
for (var key in synthetic_renderers) {
    if (synthetic_renderers.hasOwnProperty(key)) {
        var renderer = synthetic_renderers[key];
        renderer.glyph.fill_color = synthetic_colors[key];
        renderer.glyph.line_color = synthetic_colors[key];
    }
}
for (var key in real_renderers) {
    if (real_renderers.hasOwnProperty(key)) {
        var renderer = real_renderers[key];
        renderer.glyph.fill_color = real_colors[key];
        renderer.glyph.line_color = real_colors[key];
    }
}
"""

# JavaScript run when a table row is selected or the real-cluster Select
# changes. It draws a line between the selected synthetic centre and the
# chosen real centre, keeps those two clusters coloured, and greys out all
# others. With no row selected it behaves like the reset button.
_HIGHLIGHT_JS = """
var selected = source.selected.indices;
if (selected.length > 0) {
    var row = selected[0];
    var data = source.data;
    var synthetic_label = data['Synthetic'][row];
    var real_label = real_select.value;
    var syn_coords = synthetic_centers[synthetic_label];
    var real_coords = real_centers[real_label];
    line_source.data = { 'x': [syn_coords[0], real_coords[0]], 'y': [syn_coords[1], real_coords[1]] };
    line_source.change.emit();
    for (var key in synthetic_renderers) {
        if (synthetic_renderers.hasOwnProperty(key)) {
            var renderer = synthetic_renderers[key];
            if (key === synthetic_label) {
                renderer.glyph.fill_color = synthetic_colors[key];
                renderer.glyph.line_color = synthetic_colors[key];
            } else {
                renderer.glyph.fill_color = "lightgray";
                renderer.glyph.line_color = "lightgray";
            }
        }
    }
    for (var key in real_renderers) {
        if (real_renderers.hasOwnProperty(key)) {
            var renderer = real_renderers[key];
            if (key === real_label) {
                renderer.glyph.fill_color = real_colors[key];
                renderer.glyph.line_color = real_colors[key];
            } else {
                renderer.glyph.fill_color = "lightgray";
                renderer.glyph.line_color = "lightgray";
            }
        }
    }
} else {
""" + _CLEAR_HIGHLIGHT_JS + """}
"""


def _reduce_to_2d(values, method):
    """Project ``values`` (one row per sample) to 2 components with PCA or t-SNE."""
    if method == "PCA":
        reducer = PCA(n_components=2)
    else:
        reducer = TSNE(n_components=2, random_state=42, perplexity=30, learning_rate=200)
    return reducer.fit_transform(values)


def _offer_excel_download(df_table, model_name):
    """Render a Streamlit button that downloads ``df_table`` as an .xlsx file."""
    buffer = io.BytesIO()
    try:
        df_table.to_excel(buffer, index=False)
    except ImportError:
        # pandas needs an optional Excel writer engine (e.g. openpyxl). If it
        # is missing, skip the export instead of aborting the whole view.
        st.warning("Excel export unavailable: no Excel writer engine is installed.")
        return
    buffer.seek(0)
    st.download_button(
        label="Exportar tabla a Excel",
        data=buffer,
        # One export per model tab: keep file names and widget IDs distinct.
        file_name=f"tabla_{model_name}.xlsx",
        mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        key=f"download_{model_name}",
    )


def run_model(model_name):
    """Render the full embedding-comparison view for one model.

    Steps:
    1. Load the real and synthetic ("es-digital-seq") embeddings.
    2. Project all ``dim_*`` columns to 2-D with the method chosen in a
       select box.
    3. Plot both datasets.
    4. Tabulate the Euclidean distances between synthetic and real cluster
       centres.
    5. Wire Bokeh callbacks that highlight a chosen synthetic/real pair.
    6. Offer the distance table as an Excel download.

    Renders through Streamlit and returns ``None``. Returns early when the
    model is unknown, or (after a warning) when there are no real clusters
    to compare against.
    """
    embeddings = load_embeddings(model_name)
    if embeddings is None:
        return

    # Tag each row with its version so the datasets can be separated again
    # after the joint dimensionality reduction.
    embeddings["real"]["version"] = "real"
    embeddings["es-digital-seq"]["version"] = "es_digital_seq"
    embedding_cols = [col for col in embeddings["real"].columns if col.startswith("dim_")]
    df_combined = pd.concat([embeddings["real"], embeddings["es-digital-seq"]], ignore_index=True)

    st.markdown('<h6 class="sub-title">Select Dimensionality Reduction Method</h6>', unsafe_allow_html=True)
    # Keyed by model name so each tab gets its own, independent select box.
    reduction_method = st.selectbox("", options=["t-SNE", "PCA"], key=model_name)
    reduced = _reduce_to_2d(df_combined[embedding_cols].values, reduction_method)

    dfs_reduced, unique_subsets = split_versions(df_combined, reduced)
    # Every label present in each dataset is plotted.
    selected_subsets = {"real": unique_subsets["real"], "es-digital-seq": unique_subsets["es-digital-seq"]}
    color_maps = get_color_maps(selected_subsets)
    fig, real_renderers, synthetic_renderers = create_figure(dfs_reduced, selected_subsets, color_maps)

    centers_real = calculate_cluster_centers(dfs_reduced["real"], selected_subsets["real"])
    centers_es = calculate_cluster_centers(dfs_reduced["es-digital-seq"], selected_subsets["es-digital-seq"])
    df_distances = compute_distances(centers_es, centers_real)
    data_table, df_table, source_table = create_table(df_distances)

    # Column 0 is "Synthetic"; the remaining columns are the real cluster labels.
    real_subset_names = list(df_table.columns[1:])
    if not real_subset_names:
        # Without any pair to compare, Select(value=real_subset_names[0])
        # below would raise IndexError.
        st.warning("No real/synthetic clusters available to compare for this model.")
        return

    real_select = Select(title="", value=real_subset_names[0], options=real_subset_names)
    reset_button = Button(label="Reset Colors", button_type="primary")

    # Connector line between the selected centres; empty until a row is picked.
    line_source = ColumnDataSource(data={'x': [], 'y': []})
    fig.line('x', 'y', source=line_source, line_width=2, line_color='black')

    # Centres as plain [x, y] lists for the JS callback arguments.
    synthetic_centers_js = {k: [v[0], v[1]] for k, v in centers_es.items()}
    real_centers_js = {k: [v[0], v[1]] for k, v in centers_real.items()}

    # Arguments used by both the highlight and the reset callbacks.
    shared_args = dict(
        line_source=line_source,
        synthetic_renderers=synthetic_renderers,
        real_renderers=real_renderers,
        synthetic_colors=color_maps["es-digital-seq"],
        real_colors=color_maps["real"],
    )

    # Callback that updates the plot when the selection changes.
    highlight_callback = CustomJS(
        args=dict(
            shared_args,
            source=source_table,
            synthetic_centers=synthetic_centers_js,
            real_centers=real_centers_js,
            real_select=real_select,
        ),
        code=_HIGHLIGHT_JS,
    )
    source_table.selected.js_on_change('indices', highlight_callback)
    real_select.js_on_change('value', highlight_callback)

    reset_callback = CustomJS(args=dict(shared_args), code=_CLEAR_HIGHLIGHT_JS)
    reset_button.js_on_event("button_click", reset_callback)

    # Streamlit download button for the distance table.
    _offer_excel_download(df_table, model_name)

    layout = column(fig, column(real_select, reset_button, data_table))
    st.bokeh_chart(layout, use_container_width=True)
# Main function: one tab per model.
def main():
    """Render the app: global CSS and title, then one tab per supported model."""
    config_style()
    # (tab label / model name, tab heading HTML)
    # NOTE(review): the character after each model name looks like a
    # truncated, mis-encoded emoji; confirm the intended heading text.
    pages = (
        ("Donut", '<h2 class="sub-title">Modelo Donut 馃</h2>'),
        ("Idefics2", '<h2 class="sub-title">Modelo Idefics2 馃</h2>'),
    )
    tabs = st.tabs([name for name, _ in pages])
    for tab, (name, heading) in zip(tabs, pages):
        with tab:
            st.markdown(heading, unsafe_allow_html=True)
            run_model(name)


if __name__ == "__main__":
    main()