Commit a448d0f · Parent: fe51656
Input Floats to TSNE and Refactor

app.py CHANGED
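The commit does two things, as the title says: TSNE hyperparameters now flow through the app as explicit floats (collected with st.number_input using format="%.2f") and the reduction/distance/regression steps are centralized in a new compute_global_regression function reused by both the manual path and the grid search. A minimal sketch of the float-input idea, with toy data (only NumPy and scikit-learn assumed, not part of the commit):

    import numpy as np
    from sklearn.manifold import TSNE

    # Toy embeddings standing in for the dim_* columns of the app's dataframes.
    X = np.random.RandomState(0).rand(100, 16).astype(np.float32)

    # Perplexity and learning rate are plain Python floats, exactly as the
    # st.number_input widgets in the diff below produce them.
    reducer = TSNE(n_components=2, random_state=42, perplexity=30.0, learning_rate=200.0)
    reduced = reducer.fit_transform(X)
    print(reduced.shape)  # (100, 2)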
@@ -36,7 +36,10 @@ def config_style():
     """, unsafe_allow_html=True)
     st.markdown('<h1 class="main-title">Merit Embeddings 🎓📄🏆</h1>', unsafe_allow_html=True)
 
-#
+# =============================================================================
+# Data loading, plotting, and distance computation functions (unchanged)
+# =============================================================================
+
 def load_embeddings(model):
     if model == "Donut":
         df_real = pd.read_csv("data/donut_de_Rodrigo_merit_secret_all_embeddings.csv")
@@ -54,7 +57,6 @@ def load_embeddings(model):
         df_zoom["version"] = "synthetic"
         df_render["version"] = "synthetic"
 
-        # Assign the source
         df_par["source"] = "es-digital-paragraph-degradation-seq"
         df_line["source"] = "es-digital-line-degradation-seq"
         df_seq["source"] = "es-digital-seq"
@@ -65,29 +67,127 @@ def load_embeddings(model):
 
     elif model == "Idefics2":
         df_real = pd.read_csv("data/idefics2_de_Rodrigo_merit_secret_britanico_embeddings.csv")
+        df_par = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-digital-paragraph-degradation-seq_embeddings.csv")
+        df_line = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-digital-line-degradation-seq_embeddings.csv")
         df_seq = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-digital-seq_embeddings.csv")
+        df_rot = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-digital-rotation-degradation-seq_embeddings.csv")
+        df_zoom = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-digital-zoom-degradation-seq_embeddings.csv")
+        df_render = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-render-seq_embeddings.csv")
         df_real["version"] = "real"
+        df_par["version"] = "synthetic"
+        df_line["version"] = "synthetic"
         df_seq["version"] = "synthetic"
+        df_rot["version"] = "synthetic"
+        df_zoom["version"] = "synthetic"
+        df_render["version"] = "synthetic"
+
+        df_par["source"] = "es-digital-paragraph-degradation-seq"
+        df_line["source"] = "es-digital-line-degradation-seq"
         df_seq["source"] = "es-digital-seq"
-
+        df_rot["source"] = "es-digital-rotation-degradation-seq"
+        df_zoom["source"] = "es-digital-zoom-degradation-seq"
+        df_render["source"] = "es-render-seq"
+        return {"real": df_real, "synthetic": pd.concat([df_seq, df_line, df_par, df_rot, df_zoom, df_render], ignore_index=True)}
 
     else:
         st.error("Modelo no reconocido")
         return None
 
-
-
-
-
-
-
-
-
-
-
-
+def split_versions(df_combined, reduced):
+    df_combined['x'] = reduced[:, 0]
+    df_combined['y'] = reduced[:, 1]
+    df_real = df_combined[df_combined["version"] == "real"].copy()
+    df_synth = df_combined[df_combined["version"] == "synthetic"].copy()
+    unique_real = sorted(df_real['label'].unique().tolist())
+    unique_synth = {}
+    for source in df_synth["source"].unique():
+        unique_synth[source] = sorted(df_synth[df_synth["source"] == source]['label'].unique().tolist())
+    df_dict = {"real": df_real, "synthetic": df_synth}
+    unique_subsets = {"real": unique_real, "synthetic": unique_synth}
+    return df_dict, unique_subsets
+
+def compute_wasserstein_distances_synthetic_individual(synthetic_df: pd.DataFrame, df_real: pd.DataFrame, real_labels: list) -> pd.DataFrame:
+    distances = {}
+    groups = synthetic_df.groupby(['source', 'label'])
+    for (source, label), group in groups:
+        key = f"{label} ({source})"
+        data = group[['x', 'y']].values
+        n = data.shape[0]
+        weights = np.ones(n) / n
+        distances[key] = {}
+        for real_label in real_labels:
+            real_data = df_real[df_real['label'] == real_label][['x','y']].values
+            m = real_data.shape[0]
+            weights_real = np.ones(m) / m
+            M = ot.dist(data, real_data, metric='euclidean')
+            distances[key][real_label] = ot.emd2(weights, weights_real, M)
+
+    for source, group in synthetic_df.groupby('source'):
+        key = f"Global ({source})"
+        data = group[['x','y']].values
+        n = data.shape[0]
+        weights = np.ones(n) / n
+        distances[key] = {}
+        for real_label in real_labels:
+            real_data = df_real[df_real['label'] == real_label][['x','y']].values
+            m = real_data.shape[0]
+            weights_real = np.ones(m) / m
+            M = ot.dist(data, real_data, metric='euclidean')
+            distances[key][real_label] = ot.emd2(weights, weights_real, M)
+    return pd.DataFrame(distances).T
+
+def create_table(df_distances):
+    df_table = df_distances.copy()
+    df_table.reset_index(inplace=True)
+    df_table.rename(columns={'index': 'Synthetic'}, inplace=True)
+    min_row = {"Synthetic": "Min."}
+    mean_row = {"Synthetic": "Mean"}
+    max_row = {"Synthetic": "Max."}
+    for col in df_table.columns:
+        if col != "Synthetic":
+            min_row[col] = df_table[col].min()
+            mean_row[col] = df_table[col].mean()
+            max_row[col] = df_table[col].max()
+    df_table = pd.concat([df_table, pd.DataFrame([min_row, mean_row, max_row])], ignore_index=True)
+    source_table = ColumnDataSource(df_table)
+    columns = [TableColumn(field='Synthetic', title='Synthetic')]
+    for col in df_table.columns:
+        if col != 'Synthetic':
+            columns.append(TableColumn(field=col, title=col))
+    total_height = 30 + len(df_table)*28
+    data_table = DataTable(source=source_table, columns=columns, sizing_mode='stretch_width', height=total_height)
+    return data_table, df_table, source_table
+
+def create_figure(dfs, unique_subsets, color_maps, model_name):
+    fig = figure(width=600, height=600, tools="wheel_zoom,pan,reset,save", active_scroll="wheel_zoom", tooltips=TOOLTIPS, title="")
+    real_renderers = add_dataset_to_fig(fig, dfs["real"], unique_subsets["real"],
+                                        marker="circle", color_mapping=color_maps["real"],
+                                        group_label="Real")
+    marker_mapping = {
+        "es-digital-paragraph-degradation-seq": "x",
+        "es-digital-line-degradation-seq": "cross",
+        "es-digital-seq": "triangle",
+        "es-digital-rotation-degradation-seq": "diamond",
+        "es-digital-zoom-degradation-seq": "asterisk",
+        "es-render-seq": "inverted_triangle"
+    }
+    synthetic_renderers = {}
+    synth_df = dfs["synthetic"]
+    for source in unique_subsets["synthetic"]:
+        df_source = synth_df[synth_df["source"] == source]
+        marker = marker_mapping.get(source, "square")
+        renderers = add_synthetic_dataset_to_fig(fig, df_source, unique_subsets["synthetic"][source],
+                                                 marker=marker,
+                                                 color_mapping=color_maps["synthetic"][source],
+                                                 group_label=source)
+        synthetic_renderers.update(renderers)
+
+    fig.legend.location = "top_right"
+    fig.legend.click_policy = "hide"
+    show_legend = st.checkbox("Show Legend", value=False, key=f"legend_{model_name}")
+    fig.legend.visible = show_legend
+    return fig, real_renderers, synthetic_renderers
 
-# Function to add the real data (one renderer per label)
 def add_dataset_to_fig(fig, df, selected_labels, marker, color_mapping, group_label):
     renderers = {}
     for label in selected_labels:
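For reference, the compute_wasserstein_distances_synthetic_individual function moved by this hunk builds a uniform-weight point cloud per synthetic subset and compares it to each real cluster with POT's exact earth mover's distance. A self-contained sketch of that core step, with made-up 2-D clusters (not part of the commit):

    import numpy as np
    import ot  # Python Optimal Transport (POT), as imported by app.py

    # Two toy point clouds standing in for a synthetic and a real cluster.
    synth = np.random.RandomState(0).normal(0.0, 1.0, size=(50, 2))
    real = np.random.RandomState(1).normal(0.5, 1.0, size=(80, 2))

    # Uniform weights over the points of each cloud.
    w_synth = np.ones(len(synth)) / len(synth)
    w_real = np.ones(len(real)) / len(real)

    # Pairwise Euclidean cost matrix, then the exact EMD (Wasserstein) value.
    M = ot.dist(synth, real, metric='euclidean')
    distance = ot.emd2(w_synth, w_real, M)
    print(f"Wasserstein distance: {distance:.4f}")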
@@ -117,7 +217,6 @@ def add_dataset_to_fig(fig, df, selected_labels, marker, color_mapping, group_label):
         renderers[label + f" ({group_label})"] = r
     return renderers
 
-# New function to plot synthetics granularly, but with the legend grouped by source
 def add_synthetic_dataset_to_fig(fig, df, labels, marker, color_mapping, group_label):
     renderers = {}
     for label in labels:
@@ -130,11 +229,8 @@ def add_synthetic_dataset_to_fig(fig, df, labels, marker, color_mapping, group_label):
             label=subset['label'],
             img=subset.get('img', "")
         ))
-        # Use the granular color assigned to each label
        color = color_mapping[label]
-        # The legend is assigned the source name so that entries group together
        legend_label = group_label
-
        if marker == "square":
            r = fig.square('x', 'y', size=10, source=source_obj,
                           fill_color=color, line_color=color,
@@ -171,6 +267,7 @@ def add_synthetic_dataset_to_fig(fig, df, labels, marker, color_mapping, group_label):
     return renderers
 
 
+
 def get_color_maps(unique_subsets):
     color_map = {}
     # For real data a color is assigned to each label
@@ -197,59 +294,8 @@ def get_color_maps(unique_subsets):
         palette = Blues9[:len(labels)] if len(labels) <= 9 else (Blues9 * ((len(labels)//9)+1))[:len(labels)]
         color_map["synthetic"][source] = {label: palette[i] for i, label in enumerate(sorted(labels))}
     return color_map
-
-def split_versions(df_combined, reduced):
-    df_combined['x'] = reduced[:, 0]
-    df_combined['y'] = reduced[:, 1]
-    df_real = df_combined[df_combined["version"] == "real"].copy()
-    df_synth = df_combined[df_combined["version"] == "synthetic"].copy()
-    # Extract unique labels for the real data
-    unique_real = sorted(df_real['label'].unique().tolist())
-    # For synthetics, labels are grouped by source
-    unique_synth = {}
-    for source in df_synth["source"].unique():
-        unique_synth[source] = sorted(df_synth[df_synth["source"] == source]['label'].unique().tolist())
-    df_dict = {"real": df_real, "synthetic": df_synth}
-    # For real data the list is kept; for synthetics, the dictionary
-    unique_subsets = {"real": unique_real, "synthetic": unique_synth}
-    return df_dict, unique_subsets
-
-def create_figure(dfs, unique_subsets, color_maps, model_name):
-    fig = figure(width=600, height=600, tools="wheel_zoom,pan,reset,save", active_scroll="wheel_zoom", tooltips=TOOLTIPS, title="")
-    # Real data: kept granular in both the plot and the legend
-    real_renderers = add_dataset_to_fig(fig, dfs["real"], unique_subsets["real"],
-                                        marker="circle", color_mapping=color_maps["real"],
-                                        group_label="Real")
-    # Marker assignment for synthetics, keyed by source
-    marker_mapping = {
-        "es-digital-paragraph-degradation-seq": "x",
-        "es-digital-line-degradation-seq": "cross",
-        "es-digital-seq": "triangle",
-        "es-digital-rotation-degradation-seq": "diamond",
-        "es-digital-zoom-degradation-seq": "asterisk",
-        "es-render-seq": "inverted_triangle"
-    }
-
-    # Synthetic data: plotted per label, but the legend is grouped by source
-    synthetic_renderers = {}
-    synth_df = dfs["synthetic"]
-    for source in unique_subsets["synthetic"]:
-        df_source = synth_df[synth_df["source"] == source]
-        marker = marker_mapping.get(source, "square")  # Default to "square" if not found
-        renderers = add_synthetic_dataset_to_fig(fig, df_source, unique_subsets["synthetic"][source],
-                                                 marker=marker,
-                                                 color_mapping=color_maps["synthetic"][source],
-                                                 group_label=source)
-        synthetic_renderers.update(renderers)
 
-
-    fig.legend.click_policy = "hide"
-    show_legend = st.checkbox("Show Legend", value=False, key=f"legend_{model_name}")
-    fig.legend.visible = show_legend
-    return fig, real_renderers, synthetic_renderers
-
-
-# Compute the centers of each cluster (per group)
+
 def calculate_cluster_centers(df, labels):
     centers = {}
     for label in labels:
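The Blues9 expression kept by this hunk cycles a 9-color Bokeh palette whenever a source has more than nine labels. A quick sketch of the same indexing with a hypothetical stand-in palette (the placeholder color strings are not real Bokeh values):

    # Stand-in for bokeh.palettes.Blues9: any list of 9 colors behaves the same.
    Blues9 = [f"#blue{i}" for i in range(9)]  # hypothetical placeholder values

    labels = [f"school_{i}" for i in range(12)]  # 12 > 9 forces the cycling branch
    palette = Blues9[:len(labels)] if len(labels) <= 9 else (Blues9 * ((len(labels)//9)+1))[:len(labels)]
    color_map = {label: palette[i] for i, label in enumerate(sorted(labels))}
    print(len(palette), palette[9] == Blues9[0])  # 12, True: the palette wraps around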
@@ -258,189 +304,60 @@ def calculate_cluster_centers(df, labels):
         centers[label] = (subset['x'].mean(), subset['y'].mean())
     return centers
 
-# Compute the Wasserstein distance of each synthetic subset to each real cluster (per cluster and global)
-def compute_wasserstein_distances_synthetic_individual(synthetic_df: pd.DataFrame, df_real: pd.DataFrame, real_labels: list) -> pd.DataFrame:
-    distances = {}
-    groups = synthetic_df.groupby(['source', 'label'])
-    for (source, label), group in groups:
-        key = f"{label} ({source})"
-        data = group[['x', 'y']].values
-        n = data.shape[0]
-        weights = np.ones(n) / n
-        distances[key] = {}
-        for real_label in real_labels:
-            real_data = df_real[df_real['label'] == real_label][['x','y']].values
-            m = real_data.shape[0]
-            weights_real = np.ones(m) / m
-            M = ot.dist(data, real_data, metric='euclidean')
-            distances[key][real_label] = ot.emd2(weights, weights_real, M)
-
-    # Global distance per source
-    for source, group in synthetic_df.groupby('source'):
-        key = f"Global ({source})"
-        data = group[['x','y']].values
-        n = data.shape[0]
-        weights = np.ones(n) / n
-        distances[key] = {}
-        for real_label in real_labels:
-            real_data = df_real[df_real['label'] == real_label][['x','y']].values
-            m = real_data.shape[0]
-            weights_real = np.ones(m) / m
-            M = ot.dist(data, real_data, metric='euclidean')
-            distances[key][real_label] = ot.emd2(weights, weights_real, M)
-    return pd.DataFrame(distances).T
-
-def create_table(df_distances):
-    df_table = df_distances.copy()
-    df_table.reset_index(inplace=True)
-    df_table.rename(columns={'index': 'Synthetic'}, inplace=True)
-    min_row = {"Synthetic": "Min."}
-    mean_row = {"Synthetic": "Mean"}
-    max_row = {"Synthetic": "Max."}
-    for col in df_table.columns:
-        if col != "Synthetic":
-            min_row[col] = df_table[col].min()
-            mean_row[col] = df_table[col].mean()
-            max_row[col] = df_table[col].max()
-    df_table = pd.concat([df_table, pd.DataFrame([min_row, mean_row, max_row])], ignore_index=True)
-    source_table = ColumnDataSource(df_table)
-    columns = [TableColumn(field='Synthetic', title='Synthetic')]
-    for col in df_table.columns:
-        if col != 'Synthetic':
-            columns.append(TableColumn(field=col, title=col))
-    total_height = 30 + len(df_table)*28
-    data_table = DataTable(source=source_table, columns=columns, sizing_mode='stretch_width', height=total_height)
-    return data_table, df_table, source_table
-
-def optimize_tsne_params(df_combined, embedding_cols, df_f1):
-    # Search ranges (these bounds and steps can be adjusted)
-    perplexity_range = np.linspace(30, 50, 10)
-    learning_rate_range = np.linspace(200, 1000, 20)
-
-    best_R2 = -np.inf
-    best_params = None
-    total_steps = len(perplexity_range) * len(learning_rate_range)
-    step = 0
-
-    # Use a Streamlit placeholder to update progress messages
-    progress_text = st.empty()
-
-    for p in perplexity_range:
-        for lr in learning_rate_range:
-            step += 1
-            # Update the progress message
-            progress_text.text(f"Evaluating: Perplexity={p:.2f}, Learning Rate={lr:.2f} (Step: {step}/{total_steps})")
-
-            # Compute the TSNE reduction
-            reducer_temp = TSNE(n_components=2, random_state=42, perplexity=p, learning_rate=lr)
-            reduced_temp = reducer_temp.fit_transform(df_combined[embedding_cols].values)
-            dfs_reduced_temp, unique_subsets_temp = split_versions(df_combined, reduced_temp)
-
-            # Compute Wasserstein distances
-            df_distances_temp = compute_wasserstein_distances_synthetic_individual(
-                dfs_reduced_temp["synthetic"],
-                dfs_reduced_temp["real"],
-                unique_subsets_temp["real"]
-            )
-            # Extract the global values (we assume 10 per source)
-            global_distances_temp = {}
-            for idx in df_distances_temp.index:
-                if idx.startswith("Global"):
-                    source = idx.split("(")[1].rstrip(")")
-                    global_distances_temp[source] = df_distances_temp.loc[idx].values
-
-            # Accumulate data for the global regression
-            all_x_temp = []
-            all_y_temp = []
-            for source in df_f1.columns:
-                if source in global_distances_temp:
-                    x_vals_temp = global_distances_temp[source]
-                    y_vals_temp = df_f1[source].values
-                    all_x_temp.extend(x_vals_temp)
-                    all_y_temp.extend(y_vals_temp)
-            if len(all_x_temp) == 0:
-                continue
-            all_x_temp_arr = np.array(all_x_temp).reshape(-1, 1)
-            all_y_temp_arr = np.array(all_y_temp)
-
-            model_temp = LinearRegression().fit(all_x_temp_arr, all_y_temp_arr)
-            r2_temp = model_temp.score(all_x_temp_arr, all_y_temp_arr)
-
-            # Show on screen (or log) the evaluated tuple and the R² obtained
-            st.write(f"Parameters: Perplexity={p:.2f}, Learning Rate={lr:.2f} -> R²={r2_temp:.4f}")
-
-            if r2_temp > best_R2:
-                best_R2 = r2_temp
-                best_params = (p, lr)
-
-    progress_text.text("Optimization completed!")
-    return best_params, best_R2
 
 
+# =============================================================================
+# Centralized pipeline function: reduction, distances, and global regression
+# =============================================================================
 
-def run_model(model_name):
-
-    if embeddings is None:
-        return
-
-    embedding_cols = [col for col in embeddings["real"].columns if col.startswith("dim_")]
-    df_combined = pd.concat(list(embeddings.values()), ignore_index=True)
-
-    # Read the f1-donut CSV (used to evaluate the regression)
-    try:
-        df_f1 = pd.read_csv("data/f1-donut.csv", sep=';', index_col=0)
-    except Exception as e:
-        st.error(f"Error loading f1-donut.csv: {e}")
-        return
-
-    st.markdown('<h6 class="sub-title">Select Dimensionality Reduction Method</h6>', unsafe_allow_html=True)
-    reduction_method = st.selectbox("", options=["t-SNE", "PCA"], key=f"reduction_{model_name}")
-
-    # Option to optimize the TSNE parameters
-    if reduction_method == "t-SNE":
-        if st.button("Optimize TSNE parameters", key=f"optimize_tnse_{model_name}"):
-            st.info("Running optimization, this can take a while...")
-            best_params, best_R2 = optimize_tsne_params(df_combined, embedding_cols, df_f1)
-            st.success(f"Mejores parámetros: Perplexity = {best_params[0]:.2f}, Learning Rate = {best_params[1]:.2f} con R² = {best_R2:.4f}")
-
-    # Allow the user to enter the values manually (or these could be replaced with the optimized ones)
+def compute_global_regression(df_combined, embedding_cols, tsne_params, df_f1, reduction_method="t-SNE"):
+    # Select the reducer according to the method
     if reduction_method == "PCA":
         reducer = PCA(n_components=2)
     else:
-
-
-
+        reducer = TSNE(n_components=2, random_state=42,
+                       perplexity=tsne_params["perplexity"],
+                       learning_rate=tsne_params["learning_rate"])
 
+    # Apply the dimensionality reduction
     reduced = reducer.fit_transform(df_combined[embedding_cols].values)
     dfs_reduced, unique_subsets = split_versions(df_combined, reduced)
 
-
-    fig, real_renderers, synthetic_renderers = create_figure(dfs_reduced, unique_subsets, color_maps, model_name)
-
-    centers_real = calculate_cluster_centers(dfs_reduced["real"], unique_subsets["real"])
-
+    # Compute Wasserstein distances
     df_distances = compute_wasserstein_distances_synthetic_individual(
         dfs_reduced["synthetic"],
         dfs_reduced["real"],
         unique_subsets["real"]
     )
 
-    #
-    try:
-        df_f1 = pd.read_csv("data/f1-donut.csv", sep=';', index_col=0)
-    except Exception as e:
-        st.error(f"Error loading f1-donut.csv: {e}")
-        return
-
-    # Extract the global values for each source (no averaging: 10 values per source)
+    # Extract global values for each source (10 expected per source)
     global_distances = {}
     for idx in df_distances.index:
         if idx.startswith("Global"):
-            # Example: "Global (es-digital-seq)"
             source = idx.split("(")[1].rstrip(")")
             global_distances[source] = df_distances.loc[idx].values
+
+    # Accumulate all the (global) points and their corresponding f1 values for each school
+    all_x = []
+    all_y = []
+    for source in df_f1.columns:
+        if source in global_distances:
+            x_vals = global_distances[source]
+            y_vals = df_f1[source].values
+            all_x.extend(x_vals)
+            all_y.extend(y_vals)
+    all_x_arr = np.array(all_x).reshape(-1, 1)
+    all_y_arr = np.array(all_y)
 
+    # Fit the global linear regression
+    model_global = LinearRegression().fit(all_x_arr, all_y_arr)
+    r2 = model_global.score(all_x_arr, all_y_arr)
+    slope = model_global.coef_[0]
+    intercept = model_global.intercept_
+
+    # Create a scatter plot to visualize the relationship
+    scatter_fig = figure(width=600, height=600, tools="pan,wheel_zoom,reset,save",
+                         title="Scatter Plot: Wasserstein vs F1")
-    #
     source_colors = {
         "es-digital-paragraph-degradation-seq": "blue",
         "es-digital-line-degradation-seq": "green",
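Both the removed optimize_tsne_params above and the new compute_global_regression score a parameter choice by an ordinary least-squares fit of F1 against the pooled global Wasserstein distances. A minimal sketch of that scoring step, with invented numbers (not data from the commit):

    import numpy as np
    from sklearn.linear_model import LinearRegression

    # Invented stand-ins: pooled global Wasserstein distances (x) vs. F1 scores (y).
    x = np.array([1.8, 2.4, 2.9, 3.1, 3.6, 4.0]).reshape(-1, 1)
    y = np.array([0.80, 0.71, 0.66, 0.62, 0.59, 0.55])

    model = LinearRegression().fit(x, y)
    print(f"slope={model.coef_[0]:.3f}, intercept={model.intercept_:.3f}, R2={model.score(x, y):.3f}")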
@@ -450,68 +367,146 @@ def run_model(model_name):
         "es-digital-rotation-zoom-degradation-seq": "brown",
         "es-render-seq": "cyan"
     }
-
-    scatter_fig = figure(width=600, height=600, tools="pan,wheel_zoom,reset,save", title="Scatter Plot: Wasserstein vs F1")
-    # Variables for the global regression
-    all_x = []
-    all_y = []
-
-    # Each source is plotted and the data accumulated for the global regression
     for source in df_f1.columns:
         if source in global_distances:
-            x_vals = global_distances[source]
-            y_vals = df_f1[source].values
-            data = {"x": x_vals, "y": y_vals, "Fuente": [source]
+            x_vals = global_distances[source]
+            y_vals = df_f1[source].values
+            data = {"x": x_vals, "y": y_vals, "Fuente": [source]*len(x_vals)}
             cds = ColumnDataSource(data=data)
             scatter_fig.circle('x', 'y', size=8, alpha=0.7, source=cds,
                                fill_color=source_colors.get(source, "gray"),
                                line_color=source_colors.get(source, "gray"),
                                legend_label=source)
-            all_x.extend(x_vals)
-            all_y.extend(y_vals)
-
     scatter_fig.xaxis.axis_label = "Wasserstein Distance (Global, por Colegio)"
     scatter_fig.yaxis.axis_label = "F1 Score"
     scatter_fig.legend.location = "top_right"
-
-    # Add a HoverTool to show x, y and the source on hover
     hover_tool = HoverTool(tooltips=[("Wass. Distance", "@x"), ("f1", "@y"), ("Subset", "@Fuente")])
     scatter_fig.add_tools(hover_tool)
-    # --- End scatter plot ---
 
-    #
-    all_x_arr = np.array(all_x).reshape(-1, 1)
-    all_y_arr = np.array(all_y)
-    model_global = LinearRegression().fit(all_x_arr, all_y_arr)
-    slope = model_global.coef_[0]
-    intercept = model_global.intercept_
-    r2 = model_global.score(all_x_arr, all_y_arr)
-
-    # Add the global regression line to the scatter plot
+    # Global regression line
     x_line = np.linspace(all_x_arr.min(), all_x_arr.max(), 100)
     y_line = model_global.predict(x_line.reshape(-1, 1))
     scatter_fig.line(x_line, y_line, line_width=2, line_color="black", legend_label="Global Regression")
 
-
-
-
-
+    return {
+        "R2": r2,
+        "slope": slope,
+        "intercept": intercept,
+        "scatter_fig": scatter_fig,
+        "dfs_reduced": dfs_reduced,
+        "unique_subsets": unique_subsets,
+        "df_distances": df_distances
+    }
+
+# =============================================================================
+# TSNE optimization (grid search), now that the same pipeline is used
+# =============================================================================
+
+def optimize_tsne_params(df_combined, embedding_cols, df_f1):
+    # Search range
+    perplexity_range = np.linspace(30, 50, 10)
+    learning_rate_range = np.linspace(200, 1000, 20)
+
+    best_R2 = -np.inf
+    best_params = None
+    total_steps = len(perplexity_range) * len(learning_rate_range)
+    step = 0
+
+    progress_text = st.empty()
 
-
+    for p in perplexity_range:
+        for lr in learning_rate_range:
+            step += 1
+            progress_text.text(f"Evaluating: Perplexity={p:.2f}, Learning Rate={lr:.2f} (Step {step}/{total_steps})")
+
+            tsne_params = {"perplexity": p, "learning_rate": lr}
+            result = compute_global_regression(df_combined, embedding_cols, tsne_params, df_f1, reduction_method="t-SNE")
+            r2_temp = result["R2"]
+            st.write(f"Parameters: Perplexity={p:.2f}, Learning Rate={lr:.2f} -> R²={r2_temp:.4f}")
+
+            if r2_temp > best_R2:
+                best_R2 = r2_temp
+                best_params = (p, lr)
 
-
+    progress_text.text("Optimization completed!")
+    return best_params, best_R2
+
+# =============================================================================
+# Main run_model function integrating optimization and manual execution
+# =============================================================================
+
+def run_model(model_name):
+    embeddings = load_embeddings(model_name)
+    if embeddings is None:
+        return
+    embedding_cols = [col for col in embeddings["real"].columns if col.startswith("dim_")]
+    df_combined = pd.concat(list(embeddings.values()), ignore_index=True)
 
+    # Load the f1-donut CSV
+    try:
+        df_f1 = pd.read_csv("data/f1-donut.csv", sep=';', index_col=0)
+    except Exception as e:
+        st.error(f"Error loading f1-donut.csv: {e}")
+        return
+
+    st.markdown('<h6 class="sub-title">Select Dimensionality Reduction Method</h6>', unsafe_allow_html=True)
+    reduction_method = st.selectbox("", options=["t-SNE", "PCA"], key=f"reduction_{model_name}")
+
+    tsne_params = {}
+    if reduction_method == "t-SNE":
+        if st.button("Optimize TSNE parameters", key=f"optimize_tsne_{model_name}"):
+            st.info("Running optimization, this can take a while...")
+            best_params, best_R2 = optimize_tsne_params(df_combined, embedding_cols, df_f1)
+            st.success(f"Mejores parámetros: Perplexity = {best_params[0]:.2f}, Learning Rate = {best_params[1]:.2f} con R² = {best_R2:.4f}")
+            tsne_params = {"perplexity": best_params[0], "learning_rate": best_params[1]}
+        else:
+            perplexity_val = st.number_input(
+                "Perplexity",
+                min_value=5.0,
+                max_value=50.0,
+                value=30.0,
+                step=1.0,
+                format="%.2f",
+                key=f"perplexity_{model_name}"
+            )
+            learning_rate_val = st.number_input(
+                "Learning Rate",
+                min_value=10.0,
+                max_value=1000.0,
+                value=200.0,
+                step=10.0,
+                format="%.2f",
+                key=f"learning_rate_{model_name}"
+            )
+            tsne_params = {"perplexity": perplexity_val, "learning_rate": learning_rate_val}
+    # If PCA is selected, tsne_params is not used.
+
+    # Use the centralized function to obtain the global regression and the scatter plot
+    result = compute_global_regression(df_combined, embedding_cols, tsne_params, df_f1, reduction_method=reduction_method)
+
+    reg_metrics = pd.DataFrame({
+        "Slope": [result["slope"]],
+        "Intercept": [result["intercept"]],
+        "R2": [result["R2"]]
+    })
+    st.table(reg_metrics)
+
+    # We do not call st.bokeh_chart(result["scatter_fig"], ...) here;
+    # instead everything is combined into a single layout:
+    data_table, df_table, source_table = create_table(result["df_distances"])
     real_subset_names = list(df_table.columns[1:])
     real_select = Select(title="", value=real_subset_names[0], options=real_subset_names)
     reset_button = Button(label="Reset Colors", button_type="primary")
     line_source = ColumnDataSource(data={'x': [], 'y': []})
+    # Assuming there is a base figure 'fig' for the clusters:
+    fig, real_renderers, synthetic_renderers = create_figure(result["dfs_reduced"], result["unique_subsets"], get_color_maps(result["unique_subsets"]), model_name)
     fig.line('x', 'y', source=line_source, line_width=2, line_color='black')
-
+    centers_real = calculate_cluster_centers(result["dfs_reduced"]["real"], result["unique_subsets"]["real"])
     real_centers_js = {k: [v[0], v[1]] for k, v in centers_real.items()}
     synthetic_centers = {}
-    synth_labels = sorted(dfs_reduced["synthetic"]['label'].unique().tolist())
+    synth_labels = sorted(result["dfs_reduced"]["synthetic"]['label'].unique().tolist())
     for label in synth_labels:
-        subset = dfs_reduced["synthetic"][dfs_reduced["synthetic"]['label'] == label]
+        subset = result["dfs_reduced"]["synthetic"][result["dfs_reduced"]["synthetic"]['label'] == label]
         synthetic_centers[label] = [subset['x'].mean(), subset['y'].mean()]
 
     callback = CustomJS(args=dict(source=source_table, line_source=line_source,
@@ -548,7 +543,8 @@ def run_model(model_name):
     df_table.to_excel(buffer, index=False)
     buffer.seek(0)
 
-
+    # Combine all the plots into a single layout:
+    layout = column(fig, result["scatter_fig"], column(real_select, reset_button, data_table))
     st.bokeh_chart(layout, use_container_width=True)
 
     st.download_button(
@@ -559,7 +555,6 @@ def run_model(model_name):
         key=f"download_button_excel_{model_name}"
     )
 
-
 def main():
     config_style()
     tabs = st.tabs(["Donut", "Idefics2"])
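Taken together, the refactored pieces compose into one call per model. A hypothetical driver, not part of the commit, assuming app.py's functions are importable and the data/ CSVs exist:

    import pandas as pd
    from app import load_embeddings, compute_global_regression  # hypothetical import

    embeddings = load_embeddings("Donut")
    df_combined = pd.concat(list(embeddings.values()), ignore_index=True)
    embedding_cols = [c for c in df_combined.columns if c.startswith("dim_")]
    df_f1 = pd.read_csv("data/f1-donut.csv", sep=';', index_col=0)

    result = compute_global_regression(
        df_combined, embedding_cols,
        {"perplexity": 30.0, "learning_rate": 200.0},  # floats, as the UI now supplies
        df_f1, reduction_method="t-SNE",
    )
    print(result["R2"], result["slope"], result["intercept"])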