import io

import numpy as np
import ot
import pandas as pd
import streamlit as st
from bokeh.layouts import column
from bokeh.models import (ColumnDataSource, DataTable, TableColumn, CustomJS,
                          Select, Button, HoverTool)
from bokeh.palettes import Reds9, Blues9, Oranges9, Purples9, Greys9, BuGn9, Greens9
from bokeh.plotting import figure
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression
from sklearn.manifold import TSNE, trustworthiness
from sklearn.metrics import pairwise_distances

N_COMPONENTS = 2
TSNE_NEIGHBOURS = 150
WEIGHT_FACTOR = 0.25
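# N_COMPONENTS drives the PCA reduction and the 2-D plotting checks below.
# TSNE_NEIGHBOURS is the neighbourhood size for the trustworthiness and
# continuity quality metrics. WEIGHT_FACTOR is interpolated into the Donut
# embedding filenames (e.g. "..._0.25_embeddings.csv") when the "weighted"
# computation method is selected.
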
TOOLTIPS = """
<div>
    <div>
        <img src="@img{safe}" style="width:128px; height:auto; float: left; margin: 0px 15px 15px 0px;" alt="@img" border="2">
    </div>
    <div>
        <span style="font-size: 17px; font-weight: bold;">@label</span>
    </div>
</div>
"""
def config_style():
    st.markdown("""
        <style>
        .main-title { font-size: 50px; color: #4CAF50; text-align: center; }
        .sub-title { font-size: 30px; color: #555; }
        .custom-text { font-size: 18px; line-height: 1.5; }
        .bk-legend {
            max-height: 200px;
            overflow-y: auto;
        }
        </style>
    """, unsafe_allow_html=True)
    st.markdown('<h1 class="main-title">Merit Embeddings 🎓📄🏆</h1>', unsafe_allow_html=True)

def load_embeddings(model, version, embedding_prefix, weight_factor):
    if model == "Donut":
        df_real = pd.read_csv(f"data/donut/{version}/{embedding_prefix}/de_Rodrigo_merit_secret_all_{weight_factor}embeddings.csv")
        df_par = pd.read_csv(f"data/donut/{version}/{embedding_prefix}/de_Rodrigo_merit_es-digital-paragraph-degradation-seq_{weight_factor}embeddings.csv")
        df_line = pd.read_csv(f"data/donut/{version}/{embedding_prefix}/de_Rodrigo_merit_es-digital-line-degradation-seq_{weight_factor}embeddings.csv")
        df_seq = pd.read_csv(f"data/donut/{version}/{embedding_prefix}/de_Rodrigo_merit_es-digital-seq_{weight_factor}embeddings.csv")
        df_rot = pd.read_csv(f"data/donut/{version}/{embedding_prefix}/de_Rodrigo_merit_es-digital-rotation-degradation-seq_{weight_factor}embeddings.csv")
        df_zoom = pd.read_csv(f"data/donut/{version}/{embedding_prefix}/de_Rodrigo_merit_es-digital-zoom-degradation-seq_{weight_factor}embeddings.csv")
        df_render = pd.read_csv(f"data/donut/{version}/{embedding_prefix}/de_Rodrigo_merit_es-render-seq_{weight_factor}embeddings.csv")
        df_pretrained = pd.read_csv(f"data/donut/{version}/{embedding_prefix}/de_Rodrigo_merit_aux_IIT-CDIP_{weight_factor}embeddings.csv")

        # Assign version labels
        df_real["version"] = "real"
        df_par["version"] = "synthetic"
        df_line["version"] = "synthetic"
        df_seq["version"] = "synthetic"
        df_rot["version"] = "synthetic"
        df_zoom["version"] = "synthetic"
        df_render["version"] = "synthetic"
        df_pretrained["version"] = "pretrained"

        # Assign the source dataset
        df_par["source"] = "es-digital-paragraph-degradation-seq"
        df_line["source"] = "es-digital-line-degradation-seq"
        df_seq["source"] = "es-digital-seq"
        df_rot["source"] = "es-digital-rotation-degradation-seq"
        df_zoom["source"] = "es-digital-zoom-degradation-seq"
        df_render["source"] = "es-render-seq"
        df_pretrained["source"] = "pretrained"

        return {"real": df_real,
                "synthetic": pd.concat([df_seq, df_line, df_par, df_rot, df_zoom, df_render], ignore_index=True),
                "pretrained": df_pretrained}
    elif model == "Idefics2":
        df_real = pd.read_csv(f"data/idefics2_{version}_de_Rodrigo_merit_secret_britanico_{embedding_prefix}embeddings.csv")
        df_par = pd.read_csv(f"data/idefics2_{version}_de_Rodrigo_merit_es-digital-paragraph-degradation-seq_{embedding_prefix}embeddings.csv")
        df_line = pd.read_csv(f"data/idefics2_{version}_de_Rodrigo_merit_es-digital-line-degradation-seq_{embedding_prefix}embeddings.csv")
        df_seq = pd.read_csv(f"data/idefics2_{version}_de_Rodrigo_merit_es-digital-seq_{embedding_prefix}embeddings.csv")
        df_rot = pd.read_csv(f"data/idefics2_{version}_de_Rodrigo_merit_es-digital-rotation-degradation-seq_{embedding_prefix}embeddings.csv")
        df_zoom = pd.read_csv(f"data/idefics2_{version}_de_Rodrigo_merit_es-digital-zoom-degradation-seq_{embedding_prefix}embeddings.csv")
        df_render = pd.read_csv(f"data/idefics2_{version}_de_Rodrigo_merit_es-render-seq_{embedding_prefix}embeddings.csv")

        # Load both pretrained subsets and combine them
        df_pretrained_PDFA = pd.read_csv(f"data/idefics2_{version}_de_Rodrigo_merit_aux_PDFA_{embedding_prefix}embeddings.csv")
        df_pretrained_IDL = pd.read_csv(f"data/idefics2_{version}_de_Rodrigo_merit_aux_IDL_{embedding_prefix}embeddings.csv")
        df_pretrained = pd.concat([df_pretrained_PDFA, df_pretrained_IDL], ignore_index=True)

        # Assign version labels
        df_real["version"] = "real"
        df_par["version"] = "synthetic"
        df_line["version"] = "synthetic"
        df_seq["version"] = "synthetic"
        df_rot["version"] = "synthetic"
        df_zoom["version"] = "synthetic"
        df_render["version"] = "synthetic"
        df_pretrained["version"] = "pretrained"

        # Assign the source dataset
        df_par["source"] = "es-digital-paragraph-degradation-seq"
        df_line["source"] = "es-digital-line-degradation-seq"
        df_seq["source"] = "es-digital-seq"
        df_rot["source"] = "es-digital-rotation-degradation-seq"
        df_zoom["source"] = "es-digital-zoom-degradation-seq"
        df_render["source"] = "es-render-seq"
        df_pretrained["source"] = "pretrained"

        return {"real": df_real,
                "synthetic": pd.concat([df_seq, df_line, df_par, df_rot, df_zoom, df_render], ignore_index=True),
                "pretrained": df_pretrained}
    else:
        st.error("Unrecognized model")
        return None

def split_versions(df_combined, reduced):
    # Assign plot coordinates when the reduction is 2-D
    if reduced.shape[1] == 2:
        df_combined['x'] = reduced[:, 0]
        df_combined['y'] = reduced[:, 1]
    df_real = df_combined[df_combined["version"] == "real"].copy()
    df_synth = df_combined[df_combined["version"] == "synthetic"].copy()
    df_pretrained = df_combined[df_combined["version"] == "pretrained"].copy()
    unique_real = sorted(df_real['label'].unique().tolist())
    unique_synth = {}
    for source in df_synth["source"].unique():
        unique_synth[source] = sorted(df_synth[df_synth["source"] == source]['label'].unique().tolist())
    unique_pretrained = sorted(df_pretrained['label'].unique().tolist())
    df_dict = {"real": df_real, "synthetic": df_synth, "pretrained": df_pretrained}
    unique_subsets = {"real": unique_real, "synthetic": unique_synth, "pretrained": unique_pretrained}
    return df_dict, unique_subsets

def get_embedding_from_df(df):
    # Return the full reduced embedding (4 dimensions in this case) stored in
    # the 'embedding' column, falling back to the 2-D x/y coordinates
    if 'embedding' in df.columns:
        return np.stack(df['embedding'].to_numpy())
    elif 'x' in df.columns and 'y' in df.columns:
        return df[['x', 'y']].values
    else:
        raise ValueError("No embedding or x,y coordinates found in the DataFrame.")

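# Cluster-distance metrics between two point clouds:
#   - "wasserstein": exact optimal-transport cost with uniform weights,
#     computed with POT (ot.dist + ot.emd2)
#   - "euclidean": distance between the two cluster centroids
#   - "kl": KL divergence between smoothed multidimensional histograms that
#     share global bin edges in every dimension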
def compute_cluster_distance(synthetic_points, real_points, metric="wasserstein", bins=20):
    if metric.lower() == "wasserstein":
        n = synthetic_points.shape[0]
        m = real_points.shape[0]
        weights = np.ones(n) / n
        weights_real = np.ones(m) / m
        M = ot.dist(synthetic_points, real_points, metric='euclidean')
        return ot.emd2(weights, weights_real, M)
    elif metric.lower() == "euclidean":
        center_syn = np.mean(synthetic_points, axis=0)
        center_real = np.mean(real_points, axis=0)
        return np.linalg.norm(center_syn - center_real)
    elif metric.lower() == "kl":
        # For KL we use multidimensional histograms with global bin edges in each dimension
        all_points = np.vstack([synthetic_points, real_points])
        edges = [
            np.linspace(np.min(all_points[:, i]), np.max(all_points[:, i]), bins + 1)
            for i in range(all_points.shape[1])
        ]
        H_syn, _ = np.histogramdd(synthetic_points, bins=edges)
        H_real, _ = np.histogramdd(real_points, bins=edges)
        eps = 1e-10
        P = H_syn + eps
        Q = H_real + eps
        P = P / P.sum()
        Q = Q / Q.sum()
        kl = np.sum(P * np.log(P / Q))
        return kl
    else:
        raise ValueError("Unknown metric. Use 'wasserstein', 'euclidean' or 'kl'.")

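# A minimal usage sketch (hypothetical 2-D clouds, not the app's data):
#   rng = np.random.default_rng(0)
#   d = compute_cluster_distance(rng.normal(0, 1, (100, 2)),
#                                rng.normal(1, 1, (120, 2)),
#                                metric="wasserstein")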
def compute_cluster_distances_synthetic_individual(synthetic_df: pd.DataFrame, df_real: pd.DataFrame, real_labels: list, metric="wasserstein", bins=20) -> pd.DataFrame:
    distances = {}
    groups = synthetic_df.groupby(['source', 'label'])
    for (source, label), group in groups:
        key = f"{label} ({source})"
        data = get_embedding_from_df(group)
        distances[key] = {}
        for real_label in real_labels:
            real_data = get_embedding_from_df(df_real[df_real['label'] == real_label])
            d = compute_cluster_distance(data, real_data, metric=metric, bins=bins)
            distances[key][real_label] = d
    for source, group in synthetic_df.groupby('source'):
        key = f"Global ({source})"
        data = get_embedding_from_df(group)
        distances[key] = {}
        for real_label in real_labels:
            real_data = get_embedding_from_df(df_real[df_real['label'] == real_label])
            d = compute_cluster_distance(data, real_data, metric=metric, bins=bins)
            distances[key][real_label] = d
    return pd.DataFrame(distances).T

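# Continuity (as defined by Venna & Kaski): for each point, penalise
# high-dimensional neighbours that drop out of the low-dimensional
# neighbourhood, weighted by how far down the low-dimensional ranking they
# fall; 1.0 means no neighbours are lost. It complements sklearn's
# trustworthiness, which instead penalises intruding points.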
def compute_continuity(X, X_embedded, n_neighbors=5):
    n = X.shape[0]
    D_high = pairwise_distances(X, metric='euclidean')
    D_low = pairwise_distances(X_embedded, metric='euclidean')
    indices_high = np.argsort(D_high, axis=1)
    indices_low = np.argsort(D_low, axis=1)
    k_high = indices_high[:, 1:n_neighbors + 1]
    k_low = indices_low[:, 1:n_neighbors + 1]
    total = 0.0
    for i in range(n):
        set_high = set(k_high[i])
        set_low = set(k_low[i])
        missing = set_high - set_low
        for j in missing:
            rank = np.where(indices_low[i] == j)[0][0]
            total += (rank - n_neighbors)
    norm = 2.0 / (n * n_neighbors * (2 * n - 3 * n_neighbors - 1))
    continuity_value = 1 - norm * total
    return continuity_value

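# The distance table below gets Min./Mean/Max summary rows appended and is
# rendered as a Bokeh DataTable sized to show every row without scrolling.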
def create_table(df_distances):
    df_table = df_distances.copy()
    df_table.reset_index(inplace=True)
    df_table.rename(columns={'index': 'Synthetic'}, inplace=True)
    min_row = {"Synthetic": "Min."}
    mean_row = {"Synthetic": "Mean"}
    max_row = {"Synthetic": "Max."}
    for col in df_table.columns:
        if col != "Synthetic":
            min_row[col] = df_table[col].min()
            mean_row[col] = df_table[col].mean()
            max_row[col] = df_table[col].max()
    df_table = pd.concat([df_table, pd.DataFrame([min_row, mean_row, max_row])], ignore_index=True)
    source_table = ColumnDataSource(df_table)
    columns = [TableColumn(field='Synthetic', title='Synthetic')]
    for col in df_table.columns:
        if col != 'Synthetic':
            columns.append(TableColumn(field=col, title=col))
    total_height = 30 + len(df_table) * 28
    data_table = DataTable(source=source_table, columns=columns, sizing_mode='stretch_width', height=total_height)
    return data_table, df_table, source_table

def create_figure(dfs, unique_subsets, color_maps, model_name):
    # Create the plot for the reduced embedding (assumed to be 2-D)
    fig = figure(width=600, height=600, tools="wheel_zoom,pan,reset,save", active_scroll="wheel_zoom", tooltips=TOOLTIPS, title="")
    # Render the real data
    real_renderers = add_dataset_to_fig(fig, dfs["real"], unique_subsets["real"],
                                        marker="circle", color_mapping=color_maps["real"],
                                        group_label="Real")
    # Render the synthetic data (one marker per source)
    marker_mapping = {
        "es-digital-paragraph-degradation-seq": "x",
        "es-digital-line-degradation-seq": "cross",
        "es-digital-seq": "triangle",
        "es-digital-rotation-degradation-seq": "diamond",
        "es-digital-zoom-degradation-seq": "asterisk",
        "es-render-seq": "inverted_triangle"
    }
    synthetic_renderers = {}
    synth_df = dfs["synthetic"]
    for source in unique_subsets["synthetic"]:
        df_source = synth_df[synth_df["source"] == source]
        marker = marker_mapping.get(source, "square")
        renderers = add_synthetic_dataset_to_fig(fig, df_source, unique_subsets["synthetic"][source],
                                                 marker=marker,
                                                 color_mapping=color_maps["synthetic"][source],
                                                 group_label=source)
        synthetic_renderers.update(renderers)
    # Add the pretrained subset (a distinct marker can be used, e.g. "triangle")
    pretrained_renderers = add_dataset_to_fig(fig, dfs["pretrained"], unique_subsets["pretrained"],
                                              marker="triangle", color_mapping=color_maps["pretrained"],
                                              group_label="Pretrained")
    fig.legend.location = "top_right"
    fig.legend.click_policy = "hide"
    show_legend = st.checkbox("Show Legend", value=False, key=f"legend_{model_name}")
    fig.legend.visible = show_legend
    return fig, real_renderers, synthetic_renderers, pretrained_renderers

def add_dataset_to_fig(fig, df, selected_labels, marker, color_mapping, group_label):
    renderers = {}
    for label in selected_labels:
        subset = df[df['label'] == label]
        if subset.empty:
            continue
        source = ColumnDataSource(data=dict(
            x=subset['x'],
            y=subset['y'],
            label=subset['label'],
            img=subset.get('img', "")
        ))
        color = color_mapping[label]
        legend_label = f"{label} ({group_label})"
        if marker == "square":
            r = fig.square('x', 'y', size=10, source=source,
                           fill_color=color, line_color=color,
                           legend_label=legend_label)
        elif marker == "triangle":
            r = fig.triangle('x', 'y', size=12, source=source,
                             fill_color=color, line_color=color,
                             legend_label=legend_label)
        else:
            # Default to circles so an unexpected marker cannot leave r unbound
            r = fig.circle('x', 'y', size=10, source=source,
                           fill_color=color, line_color=color,
                           legend_label=legend_label)
        renderers[f"{label} ({group_label})"] = r
    return renderers

def add_synthetic_dataset_to_fig(fig, df, labels, marker, color_mapping, group_label):
    renderers = {}
    for label in labels:
        subset = df[df['label'] == label]
        if subset.empty:
            continue
        source_obj = ColumnDataSource(data=dict(
            x=subset['x'],
            y=subset['y'],
            label=subset['label'],
            img=subset.get('img', "")
        ))
        color = color_mapping[label]
        legend_label = group_label
        if marker == "square":
            r = fig.square('x', 'y', size=10, source=source_obj,
                           fill_color=color, line_color=color,
                           legend_label=legend_label)
        elif marker == "triangle":
            r = fig.triangle('x', 'y', size=12, source=source_obj,
                             fill_color=color, line_color=color,
                             legend_label=legend_label)
        elif marker == "inverted_triangle":
            r = fig.inverted_triangle('x', 'y', size=12, source=source_obj,
                                      fill_color=color, line_color=color,
                                      legend_label=legend_label)
        elif marker == "diamond":
            r = fig.diamond('x', 'y', size=10, source=source_obj,
                            fill_color=color, line_color=color,
                            legend_label=legend_label)
        elif marker == "cross":
            r = fig.cross('x', 'y', size=12, source=source_obj,
                          fill_color=color, line_color=color,
                          legend_label=legend_label)
        elif marker == "x":
            r = fig.x('x', 'y', size=12, source=source_obj,
                      fill_color=color, line_color=color,
                      legend_label=legend_label)
        elif marker == "asterisk":
            r = fig.asterisk('x', 'y', size=12, source=source_obj,
                             fill_color=color, line_color=color,
                             legend_label=legend_label)
        else:
            r = fig.circle('x', 'y', size=10, source=source_obj,
                           fill_color=color, line_color=color,
                           legend_label=legend_label)
        renderers[f"{label} ({group_label})"] = r
    return renderers

def get_color_maps(unique_subsets):
    def cycle_palette(palette, n):
        # Tile the 9-color palette so it covers n labels
        return palette[:n] if n <= 9 else (palette * ((n // 9) + 1))[:n]

    synthetic_palettes = {
        "es-digital-seq": Blues9,
        "es-digital-line-degradation-seq": Purples9,
        "es-digital-paragraph-degradation-seq": BuGn9,
        "es-digital-rotation-degradation-seq": Greys9,
        "es-digital-zoom-degradation-seq": Oranges9,
        "es-render-seq": Greens9,
    }

    color_map = {}
    red_palette = cycle_palette(Reds9, len(unique_subsets["real"]))
    color_map["real"] = {label: red_palette[i] for i, label in enumerate(sorted(unique_subsets["real"]))}
    color_map["synthetic"] = {}
    for source, labels in unique_subsets["synthetic"].items():
        palette = cycle_palette(synthetic_palettes.get(source, Blues9), len(labels))
        color_map["synthetic"][source] = {label: palette[i] for i, label in enumerate(sorted(labels))}
    # Assign colors to the pretrained subset using, for example, the Purples9 palette
    purple_palette = cycle_palette(Purples9, len(unique_subsets["pretrained"]))
    color_map["pretrained"] = {label: purple_palette[i] for i, label in enumerate(sorted(unique_subsets["pretrained"]))}
    return color_map

def calculate_cluster_centers(df, labels):
    centers = {}
    for label in labels:
        subset = df[df['label'] == label]
        if not subset.empty and 'x' in subset.columns and 'y' in subset.columns:
            centers[label] = (subset['x'].mean(), subset['y'].mean())
    return centers

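# End-to-end pipeline: reduce the combined embeddings (PCA or t-SNE), compute
# the per-source "Global" cluster distances to each real-school cluster, then
# fit a single linear regression of F1 score on distance across all sources
# and report its R².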
def compute_global_regression(df_combined, embedding_cols, tsne_params, df_f1, reduction_method="t-SNE", distance_metric="wasserstein"):
    if reduction_method == "PCA":
        reducer = PCA(n_components=N_COMPONENTS)
    else:
        reducer = TSNE(n_components=2, random_state=42,
                       perplexity=tsne_params["perplexity"],
                       learning_rate=tsne_params["learning_rate"])
    reduced = reducer.fit_transform(df_combined[embedding_cols].values)
    # Store the full reduced embedding (e.g. 4 dimensions for PCA)
    df_combined['embedding'] = list(reduced)
    # If the embedding is 2-D, assign x and y for visualization
    if reduced.shape[1] == 2:
        df_combined['x'] = reduced[:, 0]
        df_combined['y'] = reduced[:, 1]
    explained_variance = None
    if reduction_method == "PCA":
        explained_variance = reducer.explained_variance_ratio_
    trust = None
    cont = None
    if reduction_method == "t-SNE":
        X = df_combined[embedding_cols].values
        trust = trustworthiness(X, reduced, n_neighbors=TSNE_NEIGHBOURS)
        cont = compute_continuity(X, reduced, n_neighbors=TSNE_NEIGHBOURS)
    dfs_reduced, unique_subsets = split_versions(df_combined, reduced)
    df_distances = compute_cluster_distances_synthetic_individual(
        dfs_reduced["synthetic"],
        dfs_reduced["real"],
        unique_subsets["real"],
        metric=distance_metric
    )
    global_distances = {}
    for idx in df_distances.index:
        if idx.startswith("Global"):
            source = idx.split("(")[1].rstrip(")")
            global_distances[source] = df_distances.loc[idx].values
    all_x = []
    all_y = []
    for source in df_f1.columns:
        if source in global_distances:
            x_vals = global_distances[source]
            y_vals = df_f1[source].values
            all_x.extend(x_vals)
            all_y.extend(y_vals)
    all_x_arr = np.array(all_x).reshape(-1, 1)
    all_y_arr = np.array(all_y)
    model_global = LinearRegression().fit(all_x_arr, all_y_arr)
    r2 = model_global.score(all_x_arr, all_y_arr)
    slope = model_global.coef_[0]
    intercept = model_global.intercept_
    scatter_fig = figure(width=600, height=600, tools="pan,wheel_zoom,reset,save",
                         title="Scatter Plot: Distance vs F1")
    source_colors = {
        "es-digital-paragraph-degradation-seq": "blue",
        "es-digital-line-degradation-seq": "green",
        "es-digital-seq": "red",
        "es-digital-zoom-degradation-seq": "orange",
        "es-digital-rotation-degradation-seq": "purple",
        "es-digital-rotation-zoom-degradation-seq": "brown",
        "es-render-seq": "cyan"
    }
    for source in df_f1.columns:
        if source in global_distances:
            x_vals = global_distances[source]
            y_vals = df_f1[source].values
            data = {"x": x_vals, "y": y_vals, "subset": [source] * len(x_vals)}
            cds = ColumnDataSource(data=data)
            scatter_fig.circle('x', 'y', size=8, alpha=0.7, source=cds,
                               fill_color=source_colors.get(source, "gray"),
                               line_color=source_colors.get(source, "gray"),
                               legend_label=source)
    scatter_fig.xaxis.axis_label = "Distance (global, per school)"
    scatter_fig.yaxis.axis_label = "F1 Score"
    scatter_fig.legend.location = "top_right"
    hover_tool = HoverTool(tooltips=[("Distance", "@x"), ("F1", "@y"), ("Subset", "@subset")])
    scatter_fig.add_tools(hover_tool)
    x_line = np.linspace(all_x_arr.min(), all_x_arr.max(), 100)
    y_line = model_global.predict(x_line.reshape(-1, 1))
    scatter_fig.line(x_line, y_line, line_width=2, line_color="black", legend_label="Global Regression")
    results = {
        "R2": r2,
        "slope": slope,
        "intercept": intercept,
        "scatter_fig": scatter_fig,
        "dfs_reduced": dfs_reduced,
        "unique_subsets": unique_subsets,
        "df_distances": df_distances,
        "explained_variance": explained_variance,
        "trustworthiness": trust,
        "continuity": cont
    }
    if reduction_method == "PCA":
        results["pca_model"] = reducer  # Keep the PCA object for the loadings plots below
    return results

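# Grid search over 10 perplexities x 20 learning rates = 200 full t-SNE fits
# (plus the distance and regression computations), so this is slow; progress
# is streamed to the Streamlit UI.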
def optimize_tsne_params(df_combined, embedding_cols, df_f1, distance_metric):
    perplexity_range = np.linspace(30, 50, 10)
    learning_rate_range = np.linspace(200, 1000, 20)
    best_R2 = -np.inf
    best_params = None
    total_steps = len(perplexity_range) * len(learning_rate_range)
    step = 0
    progress_text = st.empty()
    for p in perplexity_range:
        for lr in learning_rate_range:
            step += 1
            progress_text.text(f"Evaluating: Perplexity={p:.2f}, Learning Rate={lr:.2f} (Step {step}/{total_steps})")
            tsne_params = {"perplexity": p, "learning_rate": lr}
            result = compute_global_regression(df_combined, embedding_cols, tsne_params, df_f1, reduction_method="t-SNE", distance_metric=distance_metric)
            r2_temp = result["R2"]
            st.write(f"Parameters: Perplexity={p:.2f}, Learning Rate={lr:.2f} -> R²={r2_temp:.4f}")
            if r2_temp > best_R2:
                best_R2 = r2_temp
                best_params = (p, lr)
    progress_text.text("Optimization completed!")
    return best_params, best_R2

def run_model(model_name):
    version = st.selectbox("Select Model Version:", options=["vanilla", "finetuned_real"], key=f"version_{model_name}")
    # Selector for how the embedding is computed; "weighted" maps to the
    # WEIGHT_FACTOR filename prefix, "averaged" to no prefix
    embedding_computation = st.selectbox("How is the embedding computed?", options=["weighted", "averaged"], key=f"embedding_method_{model_name}")
    if embedding_computation == "weighted":
        weight_factor = f"{WEIGHT_FACTOR}_"
    else:
        weight_factor = ""
    embeddings = load_embeddings(model_name, version, embedding_computation, weight_factor)
    if embeddings is None:
        return
    # Selector to include or exclude the pretrained dataset
    include_pretrained = st.checkbox("Include pretrained dataset", value=True)
    if not include_pretrained:
        # Remove the pretrained entry from the dict, if present
        embeddings.pop("pretrained", None)
    # Extract the embedding columns from the "real" data
    embedding_cols = [col for col in embeddings["real"].columns if col.startswith("dim_")]
    # Concatenate the available datasets (without pretrained if it was deselected)
    df_combined = pd.concat(list(embeddings.values()), ignore_index=True)
    try:
        # F1 scores per source (note: the Donut file is loaded for both models)
        df_f1 = pd.read_csv("data/f1-donut.csv", sep=';', index_col=0)
    except Exception as e:
        st.error(f"Error loading f1-donut.csv: {e}")
        return
    st.markdown('<h6 class="sub-title">Select Dimensionality Reduction Method</h6>', unsafe_allow_html=True)
    reduction_method = st.selectbox("", options=["t-SNE", "PCA"], key=f"reduction_{model_name}")
    distance_metric = st.selectbox("Select Distance Metric:",
                                   options=["Wasserstein", "Euclidean", "KL"],
                                   key=f"distance_metric_{model_name}")
    tsne_params = {}
    if reduction_method == "t-SNE":
        if st.button("Optimize TSNE parameters", key=f"optimize_tsne_{model_name}"):
            st.info("Running optimization, this can take a while...")
            best_params, best_R2 = optimize_tsne_params(df_combined, embedding_cols, df_f1, distance_metric.lower())
            st.success(f"Best parameters: Perplexity = {best_params[0]:.2f}, Learning Rate = {best_params[1]:.2f} with R² = {best_R2:.4f}")
            tsne_params = {"perplexity": best_params[0], "learning_rate": best_params[1]}
        else:
            perplexity_val = st.number_input(
                "Perplexity",
                min_value=5.0,
                max_value=50.0,
                value=30.0,
                step=1.0,
                format="%.2f",
                key=f"perplexity_{model_name}"
            )
            learning_rate_val = st.number_input(
                "Learning Rate",
                min_value=10.0,
                max_value=1000.0,
                value=200.0,
                step=10.0,
                format="%.2f",
                key=f"learning_rate_{model_name}"
            )
            tsne_params = {"perplexity": perplexity_val, "learning_rate": learning_rate_val}
    result = compute_global_regression(df_combined, embedding_cols, tsne_params, df_f1, reduction_method=reduction_method, distance_metric=distance_metric.lower())
    reg_metrics = pd.DataFrame({
        "Slope": [result["slope"]],
        "Intercept": [result["intercept"]],
        "R2": [result["R2"]]
    })
    st.table(reg_metrics)
    if reduction_method == "PCA" and result["explained_variance"] is not None:
        st.subheader("Explained Variance Ratio")
        component_names = [f"PC{i+1}" for i in range(len(result["explained_variance"]))]
        variance_df = pd.DataFrame({
            "Component": component_names,
            "Explained Variance": result["explained_variance"]
        })
        st.table(variance_df)
    elif reduction_method == "t-SNE":
        st.subheader("t-SNE Quality Metrics")
        st.write(f"Trustworthiness: {result['trustworthiness']:.4f}")
        st.write(f"Continuity: {result['continuity']:.4f}")
    # If PCA was used, show the loadings plots with Bokeh (with hover to inspect each dimension)
    if reduction_method == "PCA" and result.get("pca_model") is not None:
        pca_model = result["pca_model"]
        components = pca_model.components_  # Shape: (n_components, n_features)
        st.subheader("Principal Component Weights (Loadings)")
        # One bar plot per component
        for i, comp in enumerate(components):
            source = ColumnDataSource(data=dict(
                dimensions=embedding_cols,
                weight=comp
            ))
            p = figure(x_range=embedding_cols, title=f"Principal Component {i+1}",
                       height=400, width=600,
                       toolbar_location=None, tools="")
            p.vbar(x='dimensions', top='weight', width=0.8, source=source)
            # Hide the x-axis tick labels for a cleaner look
            p.xaxis.major_label_text_font_size = '0pt'
            # Add a HoverTool showing each dimension and its weight
            hover = HoverTool(tooltips=[("Dimension", "@dimensions"), ("Weight", "@weight")])
            p.add_tools(hover)
            p.xaxis.axis_label = "Original dimensions"
            p.yaxis.axis_label = "Weight"
            st.bokeh_chart(p)
    data_table, df_table, source_table = create_table(result["df_distances"])
    real_subset_names = list(df_table.columns[1:])
    real_select = Select(title="", value=real_subset_names[0], options=real_subset_names)
    reset_button = Button(label="Reset Colors", button_type="primary")
    line_source = ColumnDataSource(data={'x': [], 'y': []})
    # Both reducers are 2-D in this case (t-SNE is fixed to 2 components)
    if N_COMPONENTS == 2:
        fig, real_renderers, synthetic_renderers, pretrained_renderers = create_figure(
            result["dfs_reduced"],
            result["unique_subsets"],
            get_color_maps(result["unique_subsets"]),
            model_name
        )
        fig.line('x', 'y', source=line_source, line_width=2, line_color='black')
        centers_real = calculate_cluster_centers(result["dfs_reduced"]["real"], result["unique_subsets"]["real"])
        real_centers_js = {k: [v[0], v[1]] for k, v in centers_real.items()}
        synthetic_centers = {}
        synth_labels = sorted(result["dfs_reduced"]["synthetic"]['label'].unique().tolist())
        for label in synth_labels:
            subset = result["dfs_reduced"]["synthetic"][result["dfs_reduced"]["synthetic"]['label'] == label]
            if 'x' in subset.columns and 'y' in subset.columns:
                synthetic_centers[label] = [subset['x'].mean(), subset['y'].mean()]
        callback = CustomJS(args=dict(source=source_table, line_source=line_source,
                                      synthetic_centers=synthetic_centers,
                                      real_centers=real_centers_js,
                                      real_select=real_select),
                            code="""
            var selected = source.selected.indices;
            if (selected.length > 0) {
                var idx = selected[0];
                var data = source.data;
                var synth_label = data['Synthetic'][idx];
                var real_label = real_select.value;
                // Table rows are keyed "label (source)" while centers are keyed
                // by bare label, so fall back to the bare label when needed
                var syn_coords = synthetic_centers[synth_label] || synthetic_centers[synth_label.split(' (')[0]];
                var real_coords = real_centers[real_label];
                if (syn_coords && real_coords) {
                    line_source.data = {'x': [syn_coords[0], real_coords[0]], 'y': [syn_coords[1], real_coords[1]]};
                } else {
                    line_source.data = {'x': [], 'y': []};
                }
                line_source.change.emit();
            } else {
                line_source.data = {'x': [], 'y': []};
                line_source.change.emit();
            }
        """)
        source_table.selected.js_on_change('indices', callback)
        real_select.js_on_change('value', callback)
        reset_callback = CustomJS(args=dict(line_source=line_source),
                                  code="""
            line_source.data = {'x': [], 'y': []};
            line_source.change.emit();
        """)
        reset_button.js_on_event("button_click", reset_callback)
        layout = column(fig, result["scatter_fig"], column(real_select, reset_button, data_table))
    else:
        layout = column(result["scatter_fig"], column(real_select, reset_button, data_table))
    st.bokeh_chart(layout, use_container_width=True)
    buffer = io.BytesIO()
    df_table.to_excel(buffer, index=False)
    buffer.seek(0)
    st.download_button(
        label="Export Table",
        data=buffer,
        file_name=f"cluster_distances_{model_name}.xlsx",
        mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        key=f"download_button_excel_{model_name}"
    )

def main():
    config_style()
    tabs = st.tabs(["Donut", "Idefics2"])
    with tabs[0]:
        st.markdown('<h2 class="sub-title">Donut 🍩</h2>', unsafe_allow_html=True)
        run_model("Donut")
    with tabs[1]:
        st.markdown('<h2 class="sub-title">Idefics2 🤗</h2>', unsafe_allow_html=True)
        run_model("Idefics2")

if __name__ == "__main__":
    main()
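
# To launch the app locally (assuming this file is saved as app.py):
#   streamlit run app.py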