# Embeddings / app.py
# de-Rodrigo's picture
# Cleaner Layout and Tabs for Different Models
# 89ffe36
# raw
# history blame
# 13.4 kB
import streamlit as st
import pandas as pd
import numpy as np
from bokeh.plotting import figure
from bokeh.models import ColumnDataSource, DataTable, TableColumn, CustomJS, Select, Button
from bokeh.layouts import row, column
from bokeh.palettes import Reds9, Blues9
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
# Bokeh hover-tooltip HTML template: shows the hovered point's image (the
# "img" column, inserted unescaped via {safe}) next to its "label" in bold.
TOOLTIPS = """
<div>
<div>
<img src="@img{safe}" style="width:128px; height:auto; float: left; margin: 0px 15px 15px 0px;" alt="@img" border="2"></img>
</div>
<div>
<span style="font-size: 17px; font-weight: bold;">@label</span>
</div>
</div>
"""
def config_style():
    """Inject the page's custom CSS classes and render the main title.

    Defines the ``main-title``, ``sub-title`` and ``custom-text`` classes used
    by the HTML snippets rendered elsewhere in the app.
    """
    css_lines = [
        "",
        "<style>",
        ".main-title { font-size: 50px; color: #4CAF50; text-align: center; }",
        ".sub-title { font-size: 30px; color: #555; }",
        ".custom-text { font-size: 18px; line-height: 1.5; }",
        "</style>",
        "",
    ]
    title_html = '<h1 class="main-title">Merit Embeddings 🎒📃🏆</h1>'
    # Stylesheet first, then the title that uses it.
    for html in ("\n".join(css_lines), title_html):
        st.markdown(html, unsafe_allow_html=True)
def load_embeddings(model):
    """Load the real and synthetic ("es-digital-seq") embedding tables for a model.

    Args:
        model: model name; only "Donut" and "Idefics2" are recognised.

    Returns:
        A dict with keys "real" and "es-digital-seq" holding the DataFrames
        read from the model's CSV files. For an unrecognised model a Streamlit
        error is shown and None is returned.
    """
    # (real CSV, es-digital-seq CSV) for each supported model.
    csv_paths = {
        "Donut": (
            "data/donut_de_Rodrigo_merit_secret_all_embeddings.csv",
            "data/donut_de_Rodrigo_merit_es-digital-seq_embeddings.csv",
        ),
        # NOTE(review): both Idefics2 entries point at the same file, so the
        # "synthetic" set is an exact copy of the real one. The Donut naming
        # suggests an "..._es-digital-seq_embeddings.csv" file was intended;
        # confirm the correct path.
        "Idefics2": (
            "data/idefics2_de_Rodrigo_merit_secret_britanico_embeddings.csv",
            "data/idefics2_de_Rodrigo_merit_secret_britanico_embeddings.csv",
        ),
    }
    paths = csv_paths.get(model)
    if paths is None:
        st.error("Modelo no reconocido")
        return None
    real_path, synthetic_path = paths
    return {"real": pd.read_csv(real_path), "es-digital-seq": pd.read_csv(synthetic_path)}
# Helper functions.
def reducer_selector(df_combined, embedding_cols, *, key=None):
    """Let the user pick PCA or t-SNE and project the embeddings to 2-D.

    NOTE(review): run_model builds its own selector and reducer inline, so
    this helper is not called anywhere in the visible code.

    Args:
        df_combined: DataFrame containing the embedding columns.
        embedding_cols: names of the columns to project.
        key: optional Streamlit widget key. It is needed when the selector is
            rendered more than once on a page (e.g. once per tab); without
            it, Streamlit rejects the second identical widget.

    Returns:
        The ``fit_transform`` result of the chosen reducer with
        ``n_components=2``, i.e. one 2-D point per row of *df_combined*.
    """
    reduction_method = st.selectbox("Select Dimensionality Reduction Method:",
                                    options=["PCA", "t-SNE"], key=key)
    all_embeddings = df_combined[embedding_cols].values
    if reduction_method == "PCA":
        reducer = PCA(n_components=2)
    else:
        # sklearn's TSNE requires perplexity < n_samples; clamp it so small
        # datasets are still plotted instead of raising.
        perplexity = min(30, max(1, len(all_embeddings) - 1))
        reducer = TSNE(n_components=2, random_state=42, perplexity=perplexity, learning_rate=200)
    return reducer.fit_transform(all_embeddings)
def add_dataset_to_fig(fig, df, selected_labels, marker, color_mapping):
    """Add one glyph renderer per selected label to a Bokeh figure.

    Args:
        fig: Bokeh figure to draw on.
        df: DataFrame with 'label', 'x', 'y' and 'img' columns.
        selected_labels: labels to plot; labels with no rows in *df* are skipped.
        marker: "circle" (size 10, legend suffix "Real") or "square"
            (size 6, legend suffix "Synthetic").
        color_mapping: label -> colour, used for both fill and outline.

    Returns:
        dict mapping each plotted label to its glyph renderer.

    Raises:
        ValueError: if *marker* is neither "circle" nor "square".
    """
    # (figure method name, glyph size, legend suffix) for each supported marker.
    marker_styles = {
        "circle": ("circle", 10, "Real"),
        "square": ("square", 6, "Synthetic"),
    }
    if marker not in marker_styles:
        # Fail fast: an unknown marker has no glyph method to draw with.
        raise ValueError(f"Unsupported marker {marker!r}; expected 'circle' or 'square'")
    method_name, size, suffix = marker_styles[marker]

    renderers = {}
    for label in selected_labels:
        subset = df[df['label'] == label]
        if subset.empty:
            continue
        source = ColumnDataSource(data=dict(
            x=subset['x'],
            y=subset['y'],
            label=subset['label'],
            img=subset['img'],
        ))
        color = color_mapping[label]
        renderers[label] = getattr(fig, method_name)(
            'x', 'y', size=size, source=source,
            fill_color=color, line_color=color,
            legend_label=f"{label} ({suffix})",
        )
    return renderers
def get_color_maps(selected_subsets: dict):
    """Give every subset label a colour: reds for real data, blues for synthetic.

    Labels are sorted before colours are assigned. The nine-colour palettes
    repeat from the start when there are more than nine labels.

    Returns:
        {"real": {label: colour}, "es-digital-seq": {label: colour}}
    """
    def assign(labels, palette):
        # Cycle through the palette in sorted-label order.
        return {label: palette[idx % len(palette)] for idx, label in enumerate(sorted(labels))}

    return {
        "real": assign(selected_subsets["real"], Reds9),
        "es-digital-seq": assign(selected_subsets["es-digital-seq"], Blues9),
    }
def split_versions(df_combined, reduced):
    """Attach the 2-D coordinates and split the rows back into real / synthetic.

    Side effect: adds (or overwrites) the 'x' and 'y' columns of
    *df_combined* in place, taken from columns 0 and 1 of *reduced*.

    Returns:
        (frames, labels), where frames maps "real" / "es-digital-seq" to a
        copy of the matching rows (selected via the 'version' column values
        "real" / "es_digital_seq"), and labels maps the same keys to the
        sorted unique 'label' values of each frame.
    """
    df_combined['x'], df_combined['y'] = reduced[:, 0], reduced[:, 1]
    version_tags = {"real": "real", "es-digital-seq": "es_digital_seq"}
    frames = {
        key: df_combined[df_combined["version"] == tag].copy()
        for key, tag in version_tags.items()
    }
    labels = {key: sorted(frame['label'].unique().tolist()) for key, frame in frames.items()}
    return frames, labels
def create_figure(dfs_reduced, selected_subsets: dict, color_maps: dict):
    """Build the scatter plot: real subsets as circles, synthetic ones as squares.

    Returns:
        (fig, real_renderers, synthetic_renderers), where each renderer dict
        maps a subset label to its glyph renderer.
    """
    plot = figure(width=400, height=400, tooltips=TOOLTIPS, title="")
    renderers_by_version = {}
    for version, marker in (("real", "circle"), ("es-digital-seq", "square")):
        renderers_by_version[version] = add_dataset_to_fig(
            plot, dfs_reduced[version], selected_subsets[version],
            marker=marker, color_mapping=color_maps[version],
        )
    # Legend entries act as toggles: clicking one hides that subset.
    plot.legend.location = "top_right"
    plot.legend.click_policy = "hide"
    return plot, renderers_by_version["real"], renderers_by_version["es-digital-seq"]
def calculate_cluster_centers(df: pd.DataFrame, selected_labels: list) -> dict:
    """Mean (x, y) position of each selected label's points.

    Labels with no rows in *df* are omitted. The result follows the order
    of *selected_labels*.
    """
    subsets = ((lbl, df[df['label'] == lbl]) for lbl in selected_labels)
    return {lbl: (sub['x'].mean(), sub['y'].mean()) for lbl, sub in subsets if not sub.empty}
def compute_distances(centers_es: dict, centers_real: dict) -> pd.DataFrame:
    """Euclidean distance from every synthetic centre to every real centre.

    Returns:
        A DataFrame indexed by synthetic label, with one column per real label.
    """
    table = {
        es_label: {
            real_label: np.sqrt((sx - rx) ** 2 + (sy - ry) ** 2)
            for real_label, (rx, ry) in centers_real.items()
        }
        for es_label, (sx, sy) in centers_es.items()
    }
    return pd.DataFrame(table).T
def create_table(df_distances):
    """Wrap the distance matrix in a Bokeh DataTable.

    Returns:
        (data_table, df_table, source_table): the widget; the DataFrame it
        shows, with the distance index moved into a leading 'Synthetic'
        column; and the ColumnDataSource behind it.
    """
    df_table = df_distances.reset_index().rename(columns={'index': 'Synthetic'})
    source_table = ColumnDataSource(df_table)
    columns = [TableColumn(field='Synthetic', title='Synthetic')]
    columns += [TableColumn(field=name, title=name) for name in df_table.columns if name != 'Synthetic']
    # Height budget: a 30 px header plus 28 px per row.
    header_px, row_px = 30, 28
    data_table = DataTable(source=source_table, columns=columns, sizing_mode='stretch_width',
                           height=header_px + row_px * len(df_table))
    return data_table, df_table, source_table
# JavaScript shared by both callbacks below: clear the connecting line and
# restore every renderer to its palette colour.
_RESTORE_COLORS_JS = """
line_source.data = { 'x': [], 'y': [] };
line_source.change.emit();
for (var key in synthetic_renderers) {
    if (synthetic_renderers.hasOwnProperty(key)) {
        var renderer = synthetic_renderers[key];
        renderer.glyph.fill_color = synthetic_colors[key];
        renderer.glyph.line_color = synthetic_colors[key];
    }
}
for (var key in real_renderers) {
    if (real_renderers.hasOwnProperty(key)) {
        var renderer = real_renderers[key];
        renderer.glyph.fill_color = real_colors[key];
        renderer.glyph.line_color = real_colors[key];
    }
}
"""

# JavaScript run when the table selection or the real-subset dropdown changes.
# It draws a line between the selected synthetic centre and the chosen real
# centre and greys out every other subset. With no row selected it restores
# the original colours.
_HIGHLIGHT_JS = """
var selected = source.selected.indices;
if (selected.length > 0) {
    var row = selected[0];
    var data = source.data;
    var synthetic_label = data['Synthetic'][row];
    var real_label = real_select.value;
    var syn_coords = synthetic_centers[synthetic_label];
    var real_coords = real_centers[real_label];
    line_source.data = { 'x': [syn_coords[0], real_coords[0]], 'y': [syn_coords[1], real_coords[1]] };
    line_source.change.emit();
    for (var key in synthetic_renderers) {
        if (synthetic_renderers.hasOwnProperty(key)) {
            var renderer = synthetic_renderers[key];
            if (key === synthetic_label) {
                renderer.glyph.fill_color = synthetic_colors[key];
                renderer.glyph.line_color = synthetic_colors[key];
            } else {
                renderer.glyph.fill_color = "lightgray";
                renderer.glyph.line_color = "lightgray";
            }
        }
    }
    for (var key in real_renderers) {
        if (real_renderers.hasOwnProperty(key)) {
            var renderer = real_renderers[key];
            if (key === real_label) {
                renderer.glyph.fill_color = real_colors[key];
                renderer.glyph.line_color = real_colors[key];
            } else {
                renderer.glyph.fill_color = "lightgray";
                renderer.glyph.line_color = "lightgray";
            }
        }
    }
} else {
""" + _RESTORE_COLORS_JS + """
}
"""

# Reset also clears the table selection. Otherwise the previously selected
# row stays selected, and clicking it again would not change `indices`, so
# the highlight callback would never fire.
_RESET_JS = """
source.selected.indices = [];
""" + _RESTORE_COLORS_JS


def _reduce_to_2d(df_combined, embedding_cols, method):
    """Project the embedding columns of *df_combined* to 2-D with PCA or t-SNE."""
    values = df_combined[embedding_cols].values
    if method == "PCA":
        reducer = PCA(n_components=2)
    else:
        # sklearn's TSNE requires perplexity < n_samples; clamp it so small
        # datasets are still plotted instead of raising.
        perplexity = min(30, max(1, len(values) - 1))
        reducer = TSNE(n_components=2, random_state=42, perplexity=perplexity, learning_rate=200)
    return reducer.fit_transform(values)


def run_model(model_name):
    """Render the embedding explorer for one model in the current Streamlit container.

    Steps:
      1. Load the model's real and "es-digital-seq" embeddings.
      2. Project them to 2-D with the method picked in a selectbox.
      3. Show a Bokeh scatter plot and a table of centre-to-centre distances.
      4. Selecting a table row (synthetic subset) and a real subset in the
         dropdown draws a line between their cluster centres and greys out
         the other subsets.

    Args:
        model_name: model passed to load_embeddings. It is also the selectbox
            widget key, so each tab gets its own widget.
    """
    embeddings = load_embeddings(model_name)
    if embeddings is None:
        return
    # Tag each row with its origin so split_versions can separate them after the reduction.
    embeddings["real"]["version"] = "real"
    embeddings["es-digital-seq"]["version"] = "es_digital_seq"
    embedding_cols = [col for col in embeddings["real"].columns if col.startswith("dim_")]
    if embeddings["real"].empty or embeddings["es-digital-seq"].empty or not embedding_cols:
        st.warning(f"No embeddings available to plot for model {model_name}.")
        return
    df_combined = pd.concat([embeddings["real"], embeddings["es-digital-seq"]], ignore_index=True)

    st.markdown('<h6 class="sub-title">Select Dimensionality Reduction Method</h6>', unsafe_allow_html=True)
    reduction_method = st.selectbox("", options=["t-SNE", "PCA"], key=model_name)
    reduced = _reduce_to_2d(df_combined, embedding_cols, reduction_method)

    dfs_reduced, unique_subsets = split_versions(df_combined, reduced)
    # Every subset is plotted; highlighting is done interactively in the browser.
    selected_subsets = {"real": unique_subsets["real"], "es-digital-seq": unique_subsets["es-digital-seq"]}
    color_maps = get_color_maps(selected_subsets)
    fig, real_renderers, synthetic_renderers = create_figure(dfs_reduced, selected_subsets, color_maps)

    centers_real = calculate_cluster_centers(dfs_reduced["real"], selected_subsets["real"])
    centers_es = calculate_cluster_centers(dfs_reduced["es-digital-seq"], selected_subsets["es-digital-seq"])
    df_distances = compute_distances(centers_es, centers_real)
    data_table, df_table, source_table = create_table(df_distances)

    # Column 0 is 'Synthetic'; the remaining columns are the real subset names.
    real_subset_names = list(df_table.columns[1:])
    if not real_subset_names:
        # Without any real subset the dropdown has no default value to show.
        st.warning(f"No real subsets to compare against for model {model_name}.")
        return
    real_select = Select(title="", value=real_subset_names[0], options=real_subset_names)
    reset_button = Button(label="Reset Colors", button_type="primary")

    # Empty at first; the highlight callback fills it with the two centres to join.
    line_source = ColumnDataSource(data={'x': [], 'y': []})
    fig.line('x', 'y', source=line_source, line_width=2, line_color='black')

    # Centres as plain [x, y] lists for the JS callback arguments.
    synthetic_centers_js = {k: [v[0], v[1]] for k, v in centers_es.items()}
    real_centers_js = {k: [v[0], v[1]] for k, v in centers_real.items()}

    shared_args = dict(line_source=line_source,
                       synthetic_renderers=synthetic_renderers,
                       real_renderers=real_renderers,
                       synthetic_colors=color_maps["es-digital-seq"],
                       real_colors=color_maps["real"])
    callback = CustomJS(args=dict(shared_args, source=source_table,
                                  synthetic_centers=synthetic_centers_js,
                                  real_centers=real_centers_js,
                                  real_select=real_select),
                        code=_HIGHLIGHT_JS)
    source_table.selected.js_on_change('indices', callback)
    real_select.js_on_change('value', callback)

    reset_callback = CustomJS(args=dict(shared_args, source=source_table), code=_RESET_JS)
    reset_button.js_on_event("button_click", reset_callback)

    layout = column(fig, column(real_select, reset_button, data_table))
    st.bokeh_chart(layout, use_container_width=True)
def main():
    """App entry point: apply the page styling, then show one tab per model."""
    config_style()
    model_names = ["Donut", "Idefics2"]
    for tab, name in zip(st.tabs(model_names), model_names):
        with tab:
            st.markdown(f'<h2 class="sub-title">Modelo {name} 🤗</h2>', unsafe_allow_html=True)
            run_model(name)


if __name__ == "__main__":
    main()