import streamlit as st
import pandas as pd
import numpy as np
from bokeh.plotting import figure
from bokeh.models import ColumnDataSource, DataTable, TableColumn, CustomJS, Select, Button, HoverTool
from bokeh.layouts import column
from bokeh.palettes import Reds9, Blues9, Oranges9, Purples9, Greys9, BuGn9, Greens9
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import io
import ot
from sklearn.linear_model import LinearRegression
TOOLTIPS = """ | |
<div> | |
<div> | |
<img src="@img{safe}" style="width:128px; height:auto; float: left; margin: 0px 15px 15px 0px;" alt="@img" border="2"></img> | |
</div> | |
<div> | |
<span style="font-size: 17px; font-weight: bold;">@label</span> | |
</div> | |
</div> | |
""" | |
def config_style():
    st.markdown("""
        <style>
        .main-title { font-size: 50px; color: #4CAF50; text-align: center; }
        .sub-title { font-size: 30px; color: #555; }
        .custom-text { font-size: 18px; line-height: 1.5; }
        .bk-legend {
            max-height: 200px;
            overflow-y: auto;
        }
        </style>
    """, unsafe_allow_html=True)
    st.markdown('<h1 class="main-title">Merit Embeddings 🎒📃🏆</h1>', unsafe_allow_html=True)
# =============================================================================
# Data loading, plotting and distance computation functions (unchanged)
# =============================================================================
def load_embeddings(model):
    """Load the real and synthetic embedding CSVs for the given model.

    Each synthetic frame gets a 'source' column naming its dataset and every
    frame gets a 'version' column ('real' or 'synthetic') used later to split
    the reduced points.
    """
    model_files = {
        "Donut": ("donut", "data/donut_de_Rodrigo_merit_secret_all_embeddings.csv"),
        "Idefics2": ("idefics2", "data/idefics2_de_Rodrigo_merit_secret_britanico_embeddings.csv"),
    }
    if model not in model_files:
        st.error("Unrecognized model")
        return None
    prefix, real_path = model_files[model]
    synthetic_sources = [
        "es-digital-seq",
        "es-digital-line-degradation-seq",
        "es-digital-paragraph-degradation-seq",
        "es-digital-rotation-degradation-seq",
        "es-digital-zoom-degradation-seq",
        "es-render-seq",
    ]
    df_real = pd.read_csv(real_path)
    df_real["version"] = "real"
    synth_frames = []
    for source in synthetic_sources:
        df = pd.read_csv(f"data/{prefix}_de_Rodrigo_merit_{source}_embeddings.csv")
        df["version"] = "synthetic"
        df["source"] = source
        synth_frames.append(df)
    return {"real": df_real, "synthetic": pd.concat(synth_frames, ignore_index=True)}
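# Each embeddings CSV is assumed to provide the embedding dimensions in columns
# named dim_0, dim_1, ... plus a 'label' column (the school) and, optionally,
# an 'img' column with a thumbnail URL for the hover tooltip.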
def split_versions(df_combined, reduced):
    df_combined['x'] = reduced[:, 0]
    df_combined['y'] = reduced[:, 1]
    df_real = df_combined[df_combined["version"] == "real"].copy()
    df_synth = df_combined[df_combined["version"] == "synthetic"].copy()
    unique_real = sorted(df_real['label'].unique().tolist())
    unique_synth = {}
    for source in df_synth["source"].unique():
        unique_synth[source] = sorted(df_synth[df_synth["source"] == source]['label'].unique().tolist())
    df_dict = {"real": df_real, "synthetic": df_synth}
    unique_subsets = {"real": unique_real, "synthetic": unique_synth}
    return df_dict, unique_subsets
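# split_versions returns, e.g.:
#   df_dict        = {"real": <DataFrame>, "synthetic": <DataFrame>}
#   unique_subsets = {"real": ["school_a", ...],
#                     "synthetic": {"es-digital-seq": ["school_a", ...], ...}}
# (label names here are illustrative; the real ones come from the CSVs.)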
def compute_wasserstein_distances_synthetic_individual(synthetic_df: pd.DataFrame, df_real: pd.DataFrame, real_labels: list) -> pd.DataFrame:
    def distances_to_real(data):
        # Exact EMD between uniformly weighted empirical distributions
        weights = np.ones(data.shape[0]) / data.shape[0]
        row = {}
        for real_label in real_labels:
            real_data = df_real[df_real['label'] == real_label][['x', 'y']].values
            weights_real = np.ones(real_data.shape[0]) / real_data.shape[0]
            M = ot.dist(data, real_data, metric='euclidean')
            row[real_label] = ot.emd2(weights, weights_real, M)
        return row

    distances = {}
    # One row per (source, label) synthetic cluster
    for (source, label), group in synthetic_df.groupby(['source', 'label']):
        distances[f"{label} ({source})"] = distances_to_real(group[['x', 'y']].values)
    # One "Global" row per source, pooling all of its labels
    for source, group in synthetic_df.groupby('source'):
        distances[f"Global ({source})"] = distances_to_real(group[['x', 'y']].values)
    return pd.DataFrame(distances).T
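# Minimal sketch of the distance computation above on toy data, assuming only
# numpy and POT (both already imported). Two synthetic points sit one unit
# below two real points, so the exact EMD cost is 1.0:
#
#   synth = np.array([[0.0, 0.0], [1.0, 0.0]])
#   real = np.array([[0.0, 1.0], [1.0, 1.0]])
#   w_s = np.ones(len(synth)) / len(synth)
#   w_r = np.ones(len(real)) / len(real)
#   M = ot.dist(synth, real, metric='euclidean')
#   assert abs(ot.emd2(w_s, w_r, M) - 1.0) < 1e-9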
def create_table(df_distances):
    df_table = df_distances.copy()
    df_table.reset_index(inplace=True)
    df_table.rename(columns={'index': 'Synthetic'}, inplace=True)
    # Append per-column min / mean / max summary rows
    min_row = {"Synthetic": "Min."}
    mean_row = {"Synthetic": "Mean"}
    max_row = {"Synthetic": "Max."}
    for col in df_table.columns:
        if col != "Synthetic":
            min_row[col] = df_table[col].min()
            mean_row[col] = df_table[col].mean()
            max_row[col] = df_table[col].max()
    df_table = pd.concat([df_table, pd.DataFrame([min_row, mean_row, max_row])], ignore_index=True)
    source_table = ColumnDataSource(df_table)
    columns = [TableColumn(field='Synthetic', title='Synthetic')]
    for col in df_table.columns:
        if col != 'Synthetic':
            columns.append(TableColumn(field=col, title=col))
    total_height = 30 + len(df_table) * 28  # header row plus ~28 px per data row
    data_table = DataTable(source=source_table, columns=columns, sizing_mode='stretch_width', height=total_height)
    return data_table, df_table, source_table
def create_figure(dfs, unique_subsets, color_maps, model_name):
    fig = figure(width=600, height=600, tools="wheel_zoom,pan,reset,save", active_scroll="wheel_zoom", tooltips=TOOLTIPS, title="")
    real_renderers = add_dataset_to_fig(fig, dfs["real"], unique_subsets["real"],
                                        marker="circle", color_mapping=color_maps["real"],
                                        group_label="Real")
    marker_mapping = {
        "es-digital-paragraph-degradation-seq": "x",
        "es-digital-line-degradation-seq": "cross",
        "es-digital-seq": "triangle",
        "es-digital-rotation-degradation-seq": "diamond",
        "es-digital-zoom-degradation-seq": "asterisk",
        "es-render-seq": "inverted_triangle"
    }
    synthetic_renderers = {}
    synth_df = dfs["synthetic"]
    for source in unique_subsets["synthetic"]:
        df_source = synth_df[synth_df["source"] == source]
        marker = marker_mapping.get(source, "square")
        renderers = add_synthetic_dataset_to_fig(fig, df_source, unique_subsets["synthetic"][source],
                                                 marker=marker,
                                                 color_mapping=color_maps["synthetic"][source],
                                                 group_label=source)
        synthetic_renderers.update(renderers)
    fig.legend.location = "top_right"
    fig.legend.click_policy = "hide"
    show_legend = st.checkbox("Show Legend", value=False, key=f"legend_{model_name}")
    fig.legend.visible = show_legend
    return fig, real_renderers, synthetic_renderers
def add_dataset_to_fig(fig, df, selected_labels, marker, color_mapping, group_label):
    renderers = {}
    for label in selected_labels:
        subset = df[df['label'] == label]
        if subset.empty:
            continue
        source = ColumnDataSource(data=dict(
            x=subset['x'],
            y=subset['y'],
            label=subset['label'],
            img=subset.get('img', "")
        ))
        color = color_mapping[label]
        legend_label = f"{label} ({group_label})"
        if marker == "square":
            r = fig.square('x', 'y', size=10, source=source,
                           fill_color=color, line_color=color,
                           legend_label=legend_label)
        elif marker == "triangle":
            r = fig.triangle('x', 'y', size=12, source=source,
                             fill_color=color, line_color=color,
                             legend_label=legend_label)
        else:
            # Default to circles so an unrecognized marker cannot leave r unbound
            r = fig.circle('x', 'y', size=10, source=source,
                           fill_color=color, line_color=color,
                           legend_label=legend_label)
        renderers[f"{label} ({group_label})"] = r
    return renderers
def add_synthetic_dataset_to_fig(fig, df, labels, marker, color_mapping, group_label):
    renderers = {}
    # Marker names match the bokeh.plotting glyph method names
    # (square, triangle, inverted_triangle, diamond, cross, x, asterisk);
    # anything unrecognized falls back to circles.
    if not hasattr(fig, marker):
        marker = "circle"
    glyph_method = getattr(fig, marker)
    size = 10 if marker in ("circle", "square", "diamond") else 12
    for label in labels:
        subset = df[df['label'] == label]
        if subset.empty:
            continue
        source_obj = ColumnDataSource(data=dict(
            x=subset['x'],
            y=subset['y'],
            label=subset['label'],
            img=subset.get('img', "")
        ))
        color = color_mapping[label]
        r = glyph_method('x', 'y', size=size, source=source_obj,
                         fill_color=color, line_color=color,
                         legend_label=group_label)
        renderers[f"{label} ({group_label})"] = r
    return renderers
def get_color_maps(unique_subsets):
    def cycle_palette(palette, n):
        # Repeat the 9-color palette as many times as needed to cover n labels
        return palette[:n] if n <= 9 else (palette * ((n // 9) + 1))[:n]

    color_map = {}
    # Real data: one red shade per label
    real_labels = sorted(unique_subsets["real"])
    red_palette = cycle_palette(Reds9, len(real_labels))
    color_map["real"] = {label: red_palette[i] for i, label in enumerate(real_labels)}
    # Synthetic data: a distinct palette per source, then one shade per label
    source_palettes = {
        "es-digital-seq": Blues9,
        "es-digital-line-degradation-seq": Purples9,
        "es-digital-paragraph-degradation-seq": BuGn9,
        "es-digital-rotation-degradation-seq": Greys9,
        "es-digital-zoom-degradation-seq": Oranges9,
        "es-render-seq": Greens9,
    }
    color_map["synthetic"] = {}
    for source, labels in unique_subsets["synthetic"].items():
        palette = cycle_palette(source_palettes.get(source, Blues9), len(labels))
        color_map["synthetic"][source] = {label: palette[i] for i, label in enumerate(sorted(labels))}
    return color_map
def calculate_cluster_centers(df, labels):
    centers = {}
    for label in labels:
        subset = df[df['label'] == label]
        if not subset.empty:
            centers[label] = (subset['x'].mean(), subset['y'].mean())
    return centers
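# The (x, y) centers computed here feed the CustomJS callback in run_model,
# which draws a line between a selected synthetic cluster center and the
# currently selected real cluster center.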
# =============================================================================
# Central pipeline function: reduction, distances and global regression
# =============================================================================
def compute_global_regression(df_combined, embedding_cols, tsne_params, df_f1, reduction_method="t-SNE"): | |
# Seleccionar el reductor según el método | |
if reduction_method == "PCA": | |
reducer = PCA(n_components=2) | |
else: | |
reducer = TSNE(n_components=2, random_state=42, | |
perplexity=tsne_params["perplexity"], | |
learning_rate=tsne_params["learning_rate"]) | |
# Aplicar reducción dimensional | |
reduced = reducer.fit_transform(df_combined[embedding_cols].values) | |
dfs_reduced, unique_subsets = split_versions(df_combined, reduced) | |
# Calcular distancias Wasserstein | |
df_distances = compute_wasserstein_distances_synthetic_individual( | |
dfs_reduced["synthetic"], | |
dfs_reduced["real"], | |
unique_subsets["real"] | |
) | |
# Extraer valores globales para cada fuente (se esperan 10 por fuente) | |
global_distances = {} | |
for idx in df_distances.index: | |
if idx.startswith("Global"): | |
source = idx.split("(")[1].rstrip(")") | |
global_distances[source] = df_distances.loc[idx].values | |
# Acumular todos los puntos (globales) y sus correspondientes f1 de cada colegio | |
all_x = [] | |
all_y = [] | |
for source in df_f1.columns: | |
if source in global_distances: | |
x_vals = global_distances[source] | |
y_vals = df_f1[source].values | |
all_x.extend(x_vals) | |
all_y.extend(y_vals) | |
all_x_arr = np.array(all_x).reshape(-1, 1) | |
all_y_arr = np.array(all_y) | |
# Realizar regresión lineal global | |
model_global = LinearRegression().fit(all_x_arr, all_y_arr) | |
r2 = model_global.score(all_x_arr, all_y_arr) | |
slope = model_global.coef_[0] | |
intercept = model_global.intercept_ | |
# Crear scatter plot para visualizar la relación | |
scatter_fig = figure(width=600, height=600, tools="pan,wheel_zoom,reset,save", | |
title="Scatter Plot: Wasserstein vs F1") | |
source_colors = { | |
"es-digital-paragraph-degradation-seq": "blue", | |
"es-digital-line-degradation-seq": "green", | |
"es-digital-seq": "red", | |
"es-digital-zoom-degradation-seq": "orange", | |
"es-digital-rotation-degradation-seq": "purple", | |
"es-digital-rotation-zoom-degradation-seq": "brown", | |
"es-render-seq": "cyan" | |
} | |
for source in df_f1.columns: | |
if source in global_distances: | |
x_vals = global_distances[source] | |
y_vals = df_f1[source].values | |
data = {"x": x_vals, "y": y_vals, "Fuente": [source]*len(x_vals)} | |
cds = ColumnDataSource(data=data) | |
scatter_fig.circle('x', 'y', size=8, alpha=0.7, source=cds, | |
fill_color=source_colors.get(source, "gray"), | |
line_color=source_colors.get(source, "gray"), | |
legend_label=source) | |
scatter_fig.xaxis.axis_label = "Wasserstein Distance (Global, por Colegio)" | |
scatter_fig.yaxis.axis_label = "F1 Score" | |
scatter_fig.legend.location = "top_right" | |
hover_tool = HoverTool(tooltips=[("Wass. Distance", "@x"), ("f1", "@y"), ("Subset", "@Fuente")]) | |
scatter_fig.add_tools(hover_tool) | |
# Línea de regresión global | |
x_line = np.linspace(all_x_arr.min(), all_x_arr.max(), 100) | |
y_line = model_global.predict(x_line.reshape(-1, 1)) | |
scatter_fig.line(x_line, y_line, line_width=2, line_color="black", legend_label="Global Regression") | |
return { | |
"R2": r2, | |
"slope": slope, | |
"intercept": intercept, | |
"scatter_fig": scatter_fig, | |
"dfs_reduced": dfs_reduced, | |
"unique_subsets": unique_subsets, | |
"df_distances": df_distances | |
} | |
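# Usage sketch (assumed data layout): df_f1 is expected to hold one row per
# real school and one column per synthetic source, so each column aligns with
# the per-school global distance vector of that source, e.g.
#
#   result = compute_global_regression(df_combined, embedding_cols,
#                                      {"perplexity": 30.0, "learning_rate": 200.0},
#                                      df_f1, reduction_method="t-SNE")
#   print(result["R2"], result["slope"], result["intercept"])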
# =============================================================================
# Grid-search optimization for t-SNE, reusing the same pipeline
# =============================================================================
def optimize_tsne_params(df_combined, embedding_cols, df_f1):
    # Search ranges
    perplexity_range = np.linspace(30, 50, 10)
    learning_rate_range = np.linspace(200, 1000, 20)
    best_R2 = -np.inf
    best_params = None
    total_steps = len(perplexity_range) * len(learning_rate_range)
    step = 0
    progress_text = st.empty()
    for p in perplexity_range:
        for lr in learning_rate_range:
            step += 1
            progress_text.text(f"Evaluating: Perplexity={p:.2f}, Learning Rate={lr:.2f} (Step {step}/{total_steps})")
            tsne_params = {"perplexity": p, "learning_rate": lr}
            result = compute_global_regression(df_combined, embedding_cols, tsne_params, df_f1, reduction_method="t-SNE")
            r2_temp = result["R2"]
            st.write(f"Parameters: Perplexity={p:.2f}, Learning Rate={lr:.2f} -> R²={r2_temp:.4f}")
            if r2_temp > best_R2:
                best_R2 = r2_temp
                best_params = (p, lr)
    progress_text.text("Optimization completed!")
    return best_params, best_R2
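# Note: the grid above evaluates 10 x 20 = 200 parameter pairs, and each
# evaluation refits t-SNE and recomputes the exact EMD distances, so a full
# run can take a long time on large embedding sets.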
# =============================================================================
# Main run_model function integrating the optimization and manual execution
# =============================================================================
def run_model(model_name):
    embeddings = load_embeddings(model_name)
    if embeddings is None:
        return
    embedding_cols = [col for col in embeddings["real"].columns if col.startswith("dim_")]
    df_combined = pd.concat(list(embeddings.values()), ignore_index=True)
    # Load the f1-donut CSV
    try:
        df_f1 = pd.read_csv("data/f1-donut.csv", sep=';', index_col=0)
    except Exception as e:
        st.error(f"Error loading f1-donut.csv: {e}")
        return
    st.markdown('<h6 class="sub-title">Select Dimensionality Reduction Method</h6>', unsafe_allow_html=True)
    reduction_method = st.selectbox("", options=["t-SNE", "PCA"], key=f"reduction_{model_name}")
    tsne_params = {}
    if reduction_method == "t-SNE":
        if st.button("Optimize TSNE parameters", key=f"optimize_tsne_{model_name}"):
            st.info("Running optimization, this can take a while...")
            best_params, best_R2 = optimize_tsne_params(df_combined, embedding_cols, df_f1)
            st.success(f"Best parameters: Perplexity = {best_params[0]:.2f}, Learning Rate = {best_params[1]:.2f} with R² = {best_R2:.4f}")
            tsne_params = {"perplexity": best_params[0], "learning_rate": best_params[1]}
        else:
            perplexity_val = st.number_input(
                "Perplexity",
                min_value=5.0,
                max_value=50.0,
                value=30.0,
                step=1.0,
                format="%.2f",
                key=f"perplexity_{model_name}"
            )
            learning_rate_val = st.number_input(
                "Learning Rate",
                min_value=10.0,
                max_value=1000.0,
                value=200.0,
                step=10.0,
                format="%.2f",
                key=f"learning_rate_{model_name}"
            )
            tsne_params = {"perplexity": perplexity_val, "learning_rate": learning_rate_val}
    # When PCA is selected, tsne_params is simply ignored.
    # Run the central pipeline to get the global regression and the scatter plot
    result = compute_global_regression(df_combined, embedding_cols, tsne_params, df_f1, reduction_method=reduction_method)
    reg_metrics = pd.DataFrame({
        "Slope": [result["slope"]],
        "Intercept": [result["intercept"]],
        "R2": [result["R2"]]
    })
    st.table(reg_metrics)
    # Rather than calling st.bokeh_chart(result["scatter_fig"], ...) here,
    # everything is combined into a single Bokeh layout below.
    data_table, df_table, source_table = create_table(result["df_distances"])
    real_subset_names = list(df_table.columns[1:])
    real_select = Select(title="", value=real_subset_names[0], options=real_subset_names)
    reset_button = Button(label="Reset Colors", button_type="primary")
    line_source = ColumnDataSource(data={'x': [], 'y': []})
    # Base cluster figure
    fig, real_renderers, synthetic_renderers = create_figure(result["dfs_reduced"], result["unique_subsets"], get_color_maps(result["unique_subsets"]), model_name)
    fig.line('x', 'y', source=line_source, line_width=2, line_color='black')
    centers_real = calculate_cluster_centers(result["dfs_reduced"]["real"], result["unique_subsets"]["real"])
    real_centers_js = {k: [v[0], v[1]] for k, v in centers_real.items()}
    synthetic_centers = {}
    synth_labels = sorted(result["dfs_reduced"]["synthetic"]['label'].unique().tolist())
    for label in synth_labels:
        subset = result["dfs_reduced"]["synthetic"][result["dfs_reduced"]["synthetic"]['label'] == label]
        synthetic_centers[label] = [subset['x'].mean(), subset['y'].mean()]
    # Draw a line from the selected synthetic cluster center to the selected real cluster center
    callback = CustomJS(args=dict(source=source_table, line_source=line_source,
                                  synthetic_centers=synthetic_centers,
                                  real_centers=real_centers_js,
                                  real_select=real_select),
                        code="""
        var selected = source.selected.indices;
        if (selected.length > 0) {
            var idx = selected[0];
            var data = source.data;
            // Table rows are keyed "label (source)" or "Global (source)";
            // strip the source suffix to match the synthetic_centers keys.
            var synth_label = data['Synthetic'][idx].split(' (')[0];
            var real_label = real_select.value;
            var syn_coords = synthetic_centers[synth_label];
            var real_coords = real_centers[real_label];
            if (syn_coords && real_coords) {
                line_source.data = {'x': [syn_coords[0], real_coords[0]], 'y': [syn_coords[1], real_coords[1]]};
            } else {
                line_source.data = {'x': [], 'y': []};
            }
            line_source.change.emit();
        } else {
            line_source.data = {'x': [], 'y': []};
            line_source.change.emit();
        }
    """)
    source_table.selected.js_on_change('indices', callback)
    real_select.js_on_change('value', callback)
    reset_callback = CustomJS(args=dict(line_source=line_source),
                              code="""
        line_source.data = {'x': [], 'y': []};
        line_source.change.emit();
    """)
    reset_button.js_on_event("button_click", reset_callback)
    buffer = io.BytesIO()
    df_table.to_excel(buffer, index=False)
    buffer.seek(0)
    # Combine all plots and widgets into a single layout
    layout = column(fig, result["scatter_fig"], column(real_select, reset_button, data_table))
    st.bokeh_chart(layout, use_container_width=True)
    st.download_button(
        label="Export Table",
        data=buffer,
        file_name=f"cluster_distances_{model_name}.xlsx",
        mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        key=f"download_button_excel_{model_name}"
    )
def main():
    config_style()
    tabs = st.tabs(["Donut", "Idefics2"])
    with tabs[0]:
        st.markdown('<h2 class="sub-title">Donut 🤗</h2>', unsafe_allow_html=True)
        run_model("Donut")
    with tabs[1]:
        st.markdown('<h2 class="sub-title">Idefics2 🤗</h2>', unsafe_allow_html=True)
        run_model("Idefics2")

if __name__ == "__main__":
    main()