de-Rodrigo committed on
Commit
a448d0f
1 Parent(s): fe51656

Input Floats to TSNE and Refactor

Files changed (1)
  1. app.py +263 -268
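A note on the commit's core fix: st.number_input infers its return type from its numeric arguments, so the integer min_value/value/step used before produced ints, while scikit-learn documents TSNE's perplexity and learning_rate as floats. A minimal sketch of the corrected pattern (hypothetical standalone values, outside Streamlit):

import numpy as np
from sklearn.manifold import TSNE

X = np.random.RandomState(0).rand(100, 16)  # hypothetical embedding matrix
# Floats, as the reworked number_input widgets now return:
reducer = TSNE(n_components=2, random_state=42, perplexity=30.0, learning_rate=200.0)
reduced = reducer.fit_transform(X)
print(reduced.shape)  # (100, 2)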
app.py CHANGED
@@ -36,7 +36,10 @@ def config_style():
    """, unsafe_allow_html=True)
    st.markdown('<h1 class="main-title">Merit Embeddings 🎓📄🏆</h1>', unsafe_allow_html=True)

-# Load the data and assign versions uniformly
def load_embeddings(model):
    if model == "Donut":
        df_real = pd.read_csv("data/donut_de_Rodrigo_merit_secret_all_embeddings.csv")
@@ -54,7 +57,6 @@ def load_embeddings(model):
        df_zoom["version"] = "synthetic"
        df_render["version"] = "synthetic"

-        # Assign the source of each subset
        df_par["source"] = "es-digital-paragraph-degradation-seq"
        df_line["source"] = "es-digital-line-degradation-seq"
        df_seq["source"] = "es-digital-seq"
@@ -65,29 +67,127 @@ def load_embeddings(model):

    elif model == "Idefics2":
        df_real = pd.read_csv("data/idefics2_de_Rodrigo_merit_secret_britanico_embeddings.csv")
        df_seq = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-digital-seq_embeddings.csv")
        df_real["version"] = "real"
        df_seq["version"] = "synthetic"
        df_seq["source"] = "es-digital-seq"
-        return {"real": df_real, "synthetic": df_seq}

    else:
        st.error("Model not recognized")
        return None

-# Dimensionality-reduction selection
-def reducer_selector(df_combined, embedding_cols):
-    reduction_method = st.selectbox("Select Dimensionality Reduction Method:", options=["PCA", "t-SNE"])
-    all_embeddings = df_combined[embedding_cols].values
-    if reduction_method == "PCA":
-        reducer = PCA(n_components=2)
-    else:
-        perplexity_val = st.number_input("Perplexity", min_value=5, max_value=50, value=30, step=1)
-        learning_rate_val = st.number_input("Learning Rate", min_value=10, max_value=1000, value=200, step=10)
-        reducer = TSNE(n_components=2, random_state=42, perplexity=perplexity_val, learning_rate=learning_rate_val)
-    return reducer.fit_transform(all_embeddings)

-# Function that adds the real data (one renderer per label)
def add_dataset_to_fig(fig, df, selected_labels, marker, color_mapping, group_label):
    renderers = {}
    for label in selected_labels:
@@ -117,7 +217,6 @@ def add_dataset_to_fig(fig, df, selected_labels, marker, color_mapping, group_label):
        renderers[label + f" ({group_label})"] = r
    return renderers

-# New function: plot synthetics granularly, but group the legend by source
def add_synthetic_dataset_to_fig(fig, df, labels, marker, color_mapping, group_label):
    renderers = {}
    for label in labels:
@@ -130,11 +229,8 @@ def add_synthetic_dataset_to_fig(fig, df, labels, marker, color_mapping, group_label):
            label=subset['label'],
            img=subset.get('img', "")
        ))
-        # Use the granular color assigned to each label
        color = color_mapping[label]
-        # Name the legend after the source so entries are grouped
        legend_label = group_label
-
        if marker == "square":
            r = fig.square('x', 'y', size=10, source=source_obj,
                           fill_color=color, line_color=color,
@@ -171,6 +267,7 @@ def add_synthetic_dataset_to_fig(fig, df, labels, marker, color_mapping, group_label):
    return renderers


def get_color_maps(unique_subsets):
    color_map = {}
    # For real data, assign a color to each label
@@ -197,59 +294,8 @@ def get_color_maps(unique_subsets):
        palette = Blues9[:len(labels)] if len(labels) <= 9 else (Blues9 * ((len(labels)//9)+1))[:len(labels)]
        color_map["synthetic"][source] = {label: palette[i] for i, label in enumerate(sorted(labels))}
    return color_map
-
-def split_versions(df_combined, reduced):
-    df_combined['x'] = reduced[:, 0]
-    df_combined['y'] = reduced[:, 1]
-    df_real = df_combined[df_combined["version"] == "real"].copy()
-    df_synth = df_combined[df_combined["version"] == "synthetic"].copy()
-    # Extract the unique labels for the real data
-    unique_real = sorted(df_real['label'].unique().tolist())
-    # For synthetics, group the labels by source
-    unique_synth = {}
-    for source in df_synth["source"].unique():
-        unique_synth[source] = sorted(df_synth[df_synth["source"] == source]['label'].unique().tolist())
-    df_dict = {"real": df_real, "synthetic": df_synth}
-    # Keep a list for the real data and a dict for the synthetics
-    unique_subsets = {"real": unique_real, "synthetic": unique_synth}
-    return df_dict, unique_subsets
-
-def create_figure(dfs, unique_subsets, color_maps, model_name):
-    fig = figure(width=600, height=600, tools="wheel_zoom,pan,reset,save", active_scroll="wheel_zoom", tooltips=TOOLTIPS, title="")
-    # Real data: kept granular in both the plot and the legend
-    real_renderers = add_dataset_to_fig(fig, dfs["real"], unique_subsets["real"],
-                                        marker="circle", color_mapping=color_maps["real"],
-                                        group_label="Real")
-    # Marker assignment for synthetics, per source
-    marker_mapping = {
-        "es-digital-paragraph-degradation-seq": "x",
-        "es-digital-line-degradation-seq": "cross",
-        "es-digital-seq": "triangle",
-        "es-digital-rotation-degradation-seq": "diamond",
-        "es-digital-zoom-degradation-seq": "asterisk",
-        "es-render-seq": "inverted_triangle"
-    }
-
-    # Synthetic data: plotted per label, but the legend is grouped by source
-    synthetic_renderers = {}
-    synth_df = dfs["synthetic"]
-    for source in unique_subsets["synthetic"]:
-        df_source = synth_df[synth_df["source"] == source]
-        marker = marker_mapping.get(source, "square")  # default to "square" if not found
-        renderers = add_synthetic_dataset_to_fig(fig, df_source, unique_subsets["synthetic"][source],
-                                                 marker=marker,
-                                                 color_mapping=color_maps["synthetic"][source],
-                                                 group_label=source)
-        synthetic_renderers.update(renderers)
-
-    fig.legend.location = "top_right"
-    fig.legend.click_policy = "hide"
-    show_legend = st.checkbox("Show Legend", value=False, key=f"legend_{model_name}")
-    fig.legend.visible = show_legend
-    return fig, real_renderers, synthetic_renderers
-
-
-# Compute the centers of each cluster (per group)
def calculate_cluster_centers(df, labels):
    centers = {}
    for label in labels:
@@ -258,189 +304,60 @@ def calculate_cluster_centers(df, labels):
        centers[label] = (subset['x'].mean(), subset['y'].mean())
    return centers

-# Compute the Wasserstein distance of each synthetic subset to every real cluster (per cluster and globally)
-def compute_wasserstein_distances_synthetic_individual(synthetic_df: pd.DataFrame, df_real: pd.DataFrame, real_labels: list) -> pd.DataFrame:
-    distances = {}
-    groups = synthetic_df.groupby(['source', 'label'])
-    for (source, label), group in groups:
-        key = f"{label} ({source})"
-        data = group[['x', 'y']].values
-        n = data.shape[0]
-        weights = np.ones(n) / n
-        distances[key] = {}
-        for real_label in real_labels:
-            real_data = df_real[df_real['label'] == real_label][['x','y']].values
-            m = real_data.shape[0]
-            weights_real = np.ones(m) / m
-            M = ot.dist(data, real_data, metric='euclidean')
-            distances[key][real_label] = ot.emd2(weights, weights_real, M)
-
-    # Global distance per source
-    for source, group in synthetic_df.groupby('source'):
-        key = f"Global ({source})"
-        data = group[['x','y']].values
-        n = data.shape[0]
-        weights = np.ones(n) / n
-        distances[key] = {}
-        for real_label in real_labels:
-            real_data = df_real[df_real['label'] == real_label][['x','y']].values
-            m = real_data.shape[0]
-            weights_real = np.ones(m) / m
-            M = ot.dist(data, real_data, metric='euclidean')
-            distances[key][real_label] = ot.emd2(weights, weights_real, M)
-    return pd.DataFrame(distances).T
-
-def create_table(df_distances):
-    df_table = df_distances.copy()
-    df_table.reset_index(inplace=True)
-    df_table.rename(columns={'index': 'Synthetic'}, inplace=True)
-    min_row = {"Synthetic": "Min."}
-    mean_row = {"Synthetic": "Mean"}
-    max_row = {"Synthetic": "Max."}
-    for col in df_table.columns:
-        if col != "Synthetic":
-            min_row[col] = df_table[col].min()
-            mean_row[col] = df_table[col].mean()
-            max_row[col] = df_table[col].max()
-    df_table = pd.concat([df_table, pd.DataFrame([min_row, mean_row, max_row])], ignore_index=True)
-    source_table = ColumnDataSource(df_table)
-    columns = [TableColumn(field='Synthetic', title='Synthetic')]
-    for col in df_table.columns:
-        if col != 'Synthetic':
-            columns.append(TableColumn(field=col, title=col))
-    total_height = 30 + len(df_table)*28
-    data_table = DataTable(source=source_table, columns=columns, sizing_mode='stretch_width', height=total_height)
-    return data_table, df_table, source_table
-
-def optimize_tsne_params(df_combined, embedding_cols, df_f1):
-    # Search ranges (these limits and steps can be adjusted)
-    perplexity_range = np.linspace(30, 50, 10)
-    learning_rate_range = np.linspace(200, 1000, 20)
-
-    best_R2 = -np.inf
-    best_params = None
-    total_steps = len(perplexity_range) * len(learning_rate_range)
-    step = 0
-
-    # Use a Streamlit placeholder to update progress messages
-    progress_text = st.empty()
-
-    for p in perplexity_range:
-        for lr in learning_rate_range:
-            step += 1
-            # Update the progress message
-            progress_text.text(f"Evaluating: Perplexity={p:.2f}, Learning Rate={lr:.2f} (Step: {step}/{total_steps})")
-
-            # Compute the t-SNE reduction
-            reducer_temp = TSNE(n_components=2, random_state=42, perplexity=p, learning_rate=lr)
-            reduced_temp = reducer_temp.fit_transform(df_combined[embedding_cols].values)
-            dfs_reduced_temp, unique_subsets_temp = split_versions(df_combined, reduced_temp)
-
-            # Compute the Wasserstein distances
-            df_distances_temp = compute_wasserstein_distances_synthetic_individual(
-                dfs_reduced_temp["synthetic"],
-                dfs_reduced_temp["real"],
-                unique_subsets_temp["real"]
-            )
-            # Extract the global values (we assume 10 per source)
-            global_distances_temp = {}
-            for idx in df_distances_temp.index:
-                if idx.startswith("Global"):
-                    source = idx.split("(")[1].rstrip(")")
-                    global_distances_temp[source] = df_distances_temp.loc[idx].values
-
-            # Accumulate the data for the global regression
-            all_x_temp = []
-            all_y_temp = []
-            for source in df_f1.columns:
-                if source in global_distances_temp:
-                    x_vals_temp = global_distances_temp[source]
-                    y_vals_temp = df_f1[source].values
-                    all_x_temp.extend(x_vals_temp)
-                    all_y_temp.extend(y_vals_temp)
-            if len(all_x_temp) == 0:
-                continue
-            all_x_temp_arr = np.array(all_x_temp).reshape(-1, 1)
-            all_y_temp_arr = np.array(all_y_temp)
-
-            model_temp = LinearRegression().fit(all_x_temp_arr, all_y_temp_arr)
-            r2_temp = model_temp.score(all_x_temp_arr, all_y_temp_arr)
-
-            # Report the evaluated tuple and the R² it achieved
-            st.write(f"Parameters: Perplexity={p:.2f}, Learning Rate={lr:.2f} -> R²={r2_temp:.4f}")
-
-            if r2_temp > best_R2:
-                best_R2 = r2_temp
-                best_params = (p, lr)
-
-    progress_text.text("Optimization completed!")
-    return best_params, best_R2



-def run_model(model_name):
-    embeddings = load_embeddings(model_name)
-    if embeddings is None:
-        return
-
-    embedding_cols = [col for col in embeddings["real"].columns if col.startswith("dim_")]
-    df_combined = pd.concat(list(embeddings.values()), ignore_index=True)
-
-    # Read the f1-donut CSV (used to evaluate the regression)
-    try:
-        df_f1 = pd.read_csv("data/f1-donut.csv", sep=';', index_col=0)
-    except Exception as e:
-        st.error(f"Error loading f1-donut.csv: {e}")
-        return
-
-    st.markdown('<h6 class="sub-title">Select Dimensionality Reduction Method</h6>', unsafe_allow_html=True)
-    reduction_method = st.selectbox("", options=["t-SNE", "PCA"], key=f"reduction_{model_name}")
-
-    # Option to optimize the t-SNE parameters
-    if reduction_method == "t-SNE":
-        if st.button("Optimize TSNE parameters", key=f"optimize_tnse_{model_name}"):
-            st.info("Running optimization, this can take a while...")
-            best_params, best_R2 = optimize_tsne_params(df_combined, embedding_cols, df_f1)
-            st.success(f"Best parameters: Perplexity = {best_params[0]:.2f}, Learning Rate = {best_params[1]:.2f} with R² = {best_R2:.4f}")
-
-    # Let the user enter the values manually (these could also be replaced by the optimized ones)
    if reduction_method == "PCA":
        reducer = PCA(n_components=2)
    else:
-        perplexity_val = st.number_input("Perplexity", min_value=5, max_value=50, value=30, step=1, key=f"perplexity_{model_name}")
-        learning_rate_val = st.number_input("Learning Rate", min_value=10, max_value=1000, value=200, step=10, key=f"learning_rate_{model_name}")
-        reducer = TSNE(n_components=2, random_state=42, perplexity=perplexity_val, learning_rate=learning_rate_val)

    reduced = reducer.fit_transform(df_combined[embedding_cols].values)
    dfs_reduced, unique_subsets = split_versions(df_combined, reduced)

-    color_maps = get_color_maps(unique_subsets)
-    fig, real_renderers, synthetic_renderers = create_figure(dfs_reduced, unique_subsets, color_maps, model_name)
-
-    centers_real = calculate_cluster_centers(dfs_reduced["real"], unique_subsets["real"])
-
    df_distances = compute_wasserstein_distances_synthetic_individual(
        dfs_reduced["synthetic"],
        dfs_reduced["real"],
        unique_subsets["real"]
    )

-    # --- Scatter plot using f1-donut.csv ---
-    try:
-        df_f1 = pd.read_csv("data/f1-donut.csv", sep=';', index_col=0)
-    except Exception as e:
-        st.error(f"Error loading f1-donut.csv: {e}")
-        return
-
-    # Extract the global values for each source (no averaging: 10 values per source)
    global_distances = {}
    for idx in df_distances.index:
        if idx.startswith("Global"):
-            # Example: "Global (es-digital-seq)"
            source = idx.split("(")[1].rstrip(")")
            global_distances[source] = df_distances.loc[idx].values

-    # Reuse the color codes
    source_colors = {
        "es-digital-paragraph-degradation-seq": "blue",
        "es-digital-line-degradation-seq": "green",
@@ -450,68 +367,146 @@ def run_model(model_name):
        "es-digital-rotation-zoom-degradation-seq": "brown",
        "es-render-seq": "cyan"
    }
-
-    scatter_fig = figure(width=600, height=600, tools="pan,wheel_zoom,reset,save", title="Scatter Plot: Wasserstein vs F1")
-    # Variables for the global regression
-    all_x = []
-    all_y = []
-
-    # Plot each source and accumulate the data for the global regression
    for source in df_f1.columns:
        if source in global_distances:
-            x_vals = global_distances[source]  # 10 values (one per school)
-            y_vals = df_f1[source].values      # 10 f1 values, in the same order
-            data = {"x": x_vals, "y": y_vals, "Fuente": [source] * len(x_vals)}
            cds = ColumnDataSource(data=data)
            scatter_fig.circle('x', 'y', size=8, alpha=0.7, source=cds,
                               fill_color=source_colors.get(source, "gray"),
                               line_color=source_colors.get(source, "gray"),
                               legend_label=source)
-            all_x.extend(x_vals)
-            all_y.extend(y_vals)
-
    scatter_fig.xaxis.axis_label = "Wasserstein Distance (Global, per School)"
    scatter_fig.yaxis.axis_label = "F1 Score"
    scatter_fig.legend.location = "top_right"
-
-    # Add a HoverTool to show x, y and the source on hover
    hover_tool = HoverTool(tooltips=[("Wass. Distance", "@x"), ("f1", "@y"), ("Subset", "@Fuente")])
    scatter_fig.add_tools(hover_tool)
-    # --- End scatter plot ---

-    # --- Global regression ---
-    all_x_arr = np.array(all_x).reshape(-1, 1)
-    all_y_arr = np.array(all_y)
-    model_global = LinearRegression().fit(all_x_arr, all_y_arr)
-    slope = model_global.coef_[0]
-    intercept = model_global.intercept_
-    r2 = model_global.score(all_x_arr, all_y_arr)
-
-    # Add the global regression line to the scatter plot
    x_line = np.linspace(all_x_arr.min(), all_x_arr.max(), 100)
    y_line = model_global.predict(x_line.reshape(-1, 1))
    scatter_fig.line(x_line, y_line, line_width=2, line_color="black", legend_label="Global Regression")

-    # Show the regression metrics after the scatter plot
-    regression_metrics = {"Slope": [slope], "Intercept": [intercept], "R2": [r2]}
-    reg_df = pd.DataFrame(regression_metrics)
-    st.table(reg_df)

-    # --- End global regression ---

-    data_table, df_table, source_table = create_table(df_distances)

    real_subset_names = list(df_table.columns[1:])
    real_select = Select(title="", value=real_subset_names[0], options=real_subset_names)
    reset_button = Button(label="Reset Colors", button_type="primary")
    line_source = ColumnDataSource(data={'x': [], 'y': []})
    fig.line('x', 'y', source=line_source, line_width=2, line_color='black')
-
    real_centers_js = {k: [v[0], v[1]] for k, v in centers_real.items()}
    synthetic_centers = {}
-    synth_labels = sorted(dfs_reduced["synthetic"]['label'].unique().tolist())
    for label in synth_labels:
-        subset = dfs_reduced["synthetic"][dfs_reduced["synthetic"]['label'] == label]
        synthetic_centers[label] = [subset['x'].mean(), subset['y'].mean()]

    callback = CustomJS(args=dict(source=source_table, line_source=line_source,
@@ -548,7 +543,8 @@ def run_model(model_name):
    df_table.to_excel(buffer, index=False)
    buffer.seek(0)

-    layout = column(fig, scatter_fig, column(real_select, reset_button, data_table))
    st.bokeh_chart(layout, use_container_width=True)

    st.download_button(
@@ -559,7 +555,6 @@ def run_model(model_name):
        key=f"download_button_excel_{model_name}"
    )

-
def main():
    config_style()
    tabs = st.tabs(["Donut", "Idefics2"])
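The new side of the diff follows. Its centerpiece is compute_global_regression, which folds reduction, Wasserstein distances, and the global fit into one pipeline shared by the grid search and the manual path. A condensed, self-contained sketch of just the regression step (global_regression is a hypothetical helper name; df_f1 is assumed to hold one F1 column per source, one row per school):

import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression

def global_regression(global_distances, df_f1):
    # Pool every (distance, F1) pair across sources, then fit a single line
    xs, ys = [], []
    for source in df_f1.columns:
        if source in global_distances:
            xs.extend(global_distances[source])
            ys.extend(df_f1[source].values)
    X = np.array(xs).reshape(-1, 1)
    y = np.array(ys)
    model = LinearRegression().fit(X, y)
    return {"R2": model.score(X, y), "slope": model.coef_[0], "intercept": model.intercept_}

# Hypothetical values for illustration:
df_f1 = pd.DataFrame({"es-digital-seq": [0.71, 0.68, 0.75]})
dists = {"es-digital-seq": np.array([3.2, 4.1, 2.8])}
print(global_regression(dists, df_f1))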
 
    """, unsafe_allow_html=True)
    st.markdown('<h1 class="main-title">Merit Embeddings 🎓📄🏆</h1>', unsafe_allow_html=True)

+# =============================================================================
+# Data loading, plotting, and distance-computation functions (unchanged)
+# =============================================================================
+
def load_embeddings(model):
    if model == "Donut":
        df_real = pd.read_csv("data/donut_de_Rodrigo_merit_secret_all_embeddings.csv")
 
        df_zoom["version"] = "synthetic"
        df_render["version"] = "synthetic"

        df_par["source"] = "es-digital-paragraph-degradation-seq"
        df_line["source"] = "es-digital-line-degradation-seq"
        df_seq["source"] = "es-digital-seq"
 

    elif model == "Idefics2":
        df_real = pd.read_csv("data/idefics2_de_Rodrigo_merit_secret_britanico_embeddings.csv")
+        df_par = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-digital-paragraph-degradation-seq_embeddings.csv")
+        df_line = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-digital-line-degradation-seq_embeddings.csv")
        df_seq = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-digital-seq_embeddings.csv")
+        df_rot = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-digital-rotation-degradation-seq_embeddings.csv")
+        df_zoom = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-digital-zoom-degradation-seq_embeddings.csv")
+        df_render = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-render-seq_embeddings.csv")
        df_real["version"] = "real"
+        df_par["version"] = "synthetic"
+        df_line["version"] = "synthetic"
        df_seq["version"] = "synthetic"
+        df_rot["version"] = "synthetic"
+        df_zoom["version"] = "synthetic"
+        df_render["version"] = "synthetic"
+
+        df_par["source"] = "es-digital-paragraph-degradation-seq"
+        df_line["source"] = "es-digital-line-degradation-seq"
        df_seq["source"] = "es-digital-seq"
+        df_rot["source"] = "es-digital-rotation-degradation-seq"
+        df_zoom["source"] = "es-digital-zoom-degradation-seq"
+        df_render["source"] = "es-render-seq"
+        return {"real": df_real, "synthetic": pd.concat([df_seq, df_line, df_par, df_rot, df_zoom, df_render], ignore_index=True)}

    else:
        st.error("Model not recognized")
        return None

+def split_versions(df_combined, reduced):
+    df_combined['x'] = reduced[:, 0]
+    df_combined['y'] = reduced[:, 1]
+    df_real = df_combined[df_combined["version"] == "real"].copy()
+    df_synth = df_combined[df_combined["version"] == "synthetic"].copy()
+    unique_real = sorted(df_real['label'].unique().tolist())
+    unique_synth = {}
+    for source in df_synth["source"].unique():
+        unique_synth[source] = sorted(df_synth[df_synth["source"] == source]['label'].unique().tolist())
+    df_dict = {"real": df_real, "synthetic": df_synth}
+    unique_subsets = {"real": unique_real, "synthetic": unique_synth}
+    return df_dict, unique_subsets
+
+def compute_wasserstein_distances_synthetic_individual(synthetic_df: pd.DataFrame, df_real: pd.DataFrame, real_labels: list) -> pd.DataFrame:
+    distances = {}
+    groups = synthetic_df.groupby(['source', 'label'])
+    for (source, label), group in groups:
+        key = f"{label} ({source})"
+        data = group[['x', 'y']].values
+        n = data.shape[0]
+        weights = np.ones(n) / n
+        distances[key] = {}
+        for real_label in real_labels:
+            real_data = df_real[df_real['label'] == real_label][['x','y']].values
+            m = real_data.shape[0]
+            weights_real = np.ones(m) / m
+            M = ot.dist(data, real_data, metric='euclidean')
+            distances[key][real_label] = ot.emd2(weights, weights_real, M)
+
+    for source, group in synthetic_df.groupby('source'):
+        key = f"Global ({source})"
+        data = group[['x','y']].values
+        n = data.shape[0]
+        weights = np.ones(n) / n
+        distances[key] = {}
+        for real_label in real_labels:
+            real_data = df_real[df_real['label'] == real_label][['x','y']].values
+            m = real_data.shape[0]
+            weights_real = np.ones(m) / m
+            M = ot.dist(data, real_data, metric='euclidean')
+            distances[key][real_label] = ot.emd2(weights, weights_real, M)
+    return pd.DataFrame(distances).T
+
+def create_table(df_distances):
+    df_table = df_distances.copy()
+    df_table.reset_index(inplace=True)
+    df_table.rename(columns={'index': 'Synthetic'}, inplace=True)
+    min_row = {"Synthetic": "Min."}
+    mean_row = {"Synthetic": "Mean"}
+    max_row = {"Synthetic": "Max."}
+    for col in df_table.columns:
+        if col != "Synthetic":
+            min_row[col] = df_table[col].min()
+            mean_row[col] = df_table[col].mean()
+            max_row[col] = df_table[col].max()
+    df_table = pd.concat([df_table, pd.DataFrame([min_row, mean_row, max_row])], ignore_index=True)
+    source_table = ColumnDataSource(df_table)
+    columns = [TableColumn(field='Synthetic', title='Synthetic')]
+    for col in df_table.columns:
+        if col != 'Synthetic':
+            columns.append(TableColumn(field=col, title=col))
+    total_height = 30 + len(df_table)*28
+    data_table = DataTable(source=source_table, columns=columns, sizing_mode='stretch_width', height=total_height)
+    return data_table, df_table, source_table
+
+def create_figure(dfs, unique_subsets, color_maps, model_name):
+    fig = figure(width=600, height=600, tools="wheel_zoom,pan,reset,save", active_scroll="wheel_zoom", tooltips=TOOLTIPS, title="")
+    real_renderers = add_dataset_to_fig(fig, dfs["real"], unique_subsets["real"],
+                                        marker="circle", color_mapping=color_maps["real"],
+                                        group_label="Real")
+    marker_mapping = {
+        "es-digital-paragraph-degradation-seq": "x",
+        "es-digital-line-degradation-seq": "cross",
+        "es-digital-seq": "triangle",
+        "es-digital-rotation-degradation-seq": "diamond",
+        "es-digital-zoom-degradation-seq": "asterisk",
+        "es-render-seq": "inverted_triangle"
+    }
+    synthetic_renderers = {}
+    synth_df = dfs["synthetic"]
+    for source in unique_subsets["synthetic"]:
+        df_source = synth_df[synth_df["source"] == source]
+        marker = marker_mapping.get(source, "square")
+        renderers = add_synthetic_dataset_to_fig(fig, df_source, unique_subsets["synthetic"][source],
+                                                 marker=marker,
+                                                 color_mapping=color_maps["synthetic"][source],
+                                                 group_label=source)
+        synthetic_renderers.update(renderers)
+
+    fig.legend.location = "top_right"
+    fig.legend.click_policy = "hide"
+    show_legend = st.checkbox("Show Legend", value=False, key=f"legend_{model_name}")
+    fig.legend.visible = show_legend
+    return fig, real_renderers, synthetic_renderers

def add_dataset_to_fig(fig, df, selected_labels, marker, color_mapping, group_label):
    renderers = {}
    for label in selected_labels:
 
        renderers[label + f" ({group_label})"] = r
    return renderers

def add_synthetic_dataset_to_fig(fig, df, labels, marker, color_mapping, group_label):
    renderers = {}
    for label in labels:
 
            label=subset['label'],
            img=subset.get('img', "")
        ))
        color = color_mapping[label]
        legend_label = group_label
        if marker == "square":
            r = fig.square('x', 'y', size=10, source=source_obj,
                           fill_color=color, line_color=color,
 
    return renderers


+
def get_color_maps(unique_subsets):
    color_map = {}
    # For real data, assign a color to each label
 
        palette = Blues9[:len(labels)] if len(labels) <= 9 else (Blues9 * ((len(labels)//9)+1))[:len(labels)]
        color_map["synthetic"][source] = {label: palette[i] for i, label in enumerate(sorted(labels))}
    return color_map

+
def calculate_cluster_centers(df, labels):
    centers = {}
    for label in labels:
 
        centers[label] = (subset['x'].mean(), subset['y'].mean())
    return centers



+# =============================================================================
+# Centralized pipeline function: reduction, distances, and global regression
+# =============================================================================

+def compute_global_regression(df_combined, embedding_cols, tsne_params, df_f1, reduction_method="t-SNE"):
+    # Select the reducer according to the chosen method
    if reduction_method == "PCA":
        reducer = PCA(n_components=2)
    else:
+        reducer = TSNE(n_components=2, random_state=42,
+                       perplexity=tsne_params["perplexity"],
+                       learning_rate=tsne_params["learning_rate"])

+    # Apply the dimensionality reduction
    reduced = reducer.fit_transform(df_combined[embedding_cols].values)
    dfs_reduced, unique_subsets = split_versions(df_combined, reduced)

+    # Compute the Wasserstein distances
    df_distances = compute_wasserstein_distances_synthetic_individual(
        dfs_reduced["synthetic"],
        dfs_reduced["real"],
        unique_subsets["real"]
    )

+    # Extract the global values for each source (10 expected per source)
    global_distances = {}
    for idx in df_distances.index:
        if idx.startswith("Global"):
            source = idx.split("(")[1].rstrip(")")
            global_distances[source] = df_distances.loc[idx].values
+
+    # Accumulate all the (global) points and the corresponding f1 of each school
+    all_x = []
+    all_y = []
+    for source in df_f1.columns:
+        if source in global_distances:
+            x_vals = global_distances[source]
+            y_vals = df_f1[source].values
+            all_x.extend(x_vals)
+            all_y.extend(y_vals)
+    all_x_arr = np.array(all_x).reshape(-1, 1)
+    all_y_arr = np.array(all_y)

+    # Fit the global linear regression
+    model_global = LinearRegression().fit(all_x_arr, all_y_arr)
+    r2 = model_global.score(all_x_arr, all_y_arr)
+    slope = model_global.coef_[0]
+    intercept = model_global.intercept_
+
+    # Create a scatter plot to visualize the relationship
+    scatter_fig = figure(width=600, height=600, tools="pan,wheel_zoom,reset,save",
+                         title="Scatter Plot: Wasserstein vs F1")
    source_colors = {
        "es-digital-paragraph-degradation-seq": "blue",
        "es-digital-line-degradation-seq": "green",
 
        "es-digital-rotation-zoom-degradation-seq": "brown",
        "es-render-seq": "cyan"
    }
    for source in df_f1.columns:
        if source in global_distances:
+            x_vals = global_distances[source]
+            y_vals = df_f1[source].values
+            data = {"x": x_vals, "y": y_vals, "Fuente": [source]*len(x_vals)}
            cds = ColumnDataSource(data=data)
            scatter_fig.circle('x', 'y', size=8, alpha=0.7, source=cds,
                               fill_color=source_colors.get(source, "gray"),
                               line_color=source_colors.get(source, "gray"),
                               legend_label=source)
    scatter_fig.xaxis.axis_label = "Wasserstein Distance (Global, per School)"
    scatter_fig.yaxis.axis_label = "F1 Score"
    scatter_fig.legend.location = "top_right"
    hover_tool = HoverTool(tooltips=[("Wass. Distance", "@x"), ("f1", "@y"), ("Subset", "@Fuente")])
    scatter_fig.add_tools(hover_tool)

+    # Global regression line
    x_line = np.linspace(all_x_arr.min(), all_x_arr.max(), 100)
    y_line = model_global.predict(x_line.reshape(-1, 1))
    scatter_fig.line(x_line, y_line, line_width=2, line_color="black", legend_label="Global Regression")

+    return {
+        "R2": r2,
+        "slope": slope,
+        "intercept": intercept,
+        "scatter_fig": scatter_fig,
+        "dfs_reduced": dfs_reduced,
+        "unique_subsets": unique_subsets,
+        "df_distances": df_distances
+    }
+
+# =============================================================================
+# Optimization (grid search) for t-SNE, now that it shares the same pipeline
+# =============================================================================
+
+def optimize_tsne_params(df_combined, embedding_cols, df_f1):
+    # Search ranges
+    perplexity_range = np.linspace(30, 50, 10)
+    learning_rate_range = np.linspace(200, 1000, 20)
+
+    best_R2 = -np.inf
+    best_params = None
+    total_steps = len(perplexity_range) * len(learning_rate_range)
+    step = 0
+
+    progress_text = st.empty()

+    for p in perplexity_range:
+        for lr in learning_rate_range:
+            step += 1
+            progress_text.text(f"Evaluating: Perplexity={p:.2f}, Learning Rate={lr:.2f} (Step {step}/{total_steps})")
+
+            tsne_params = {"perplexity": p, "learning_rate": lr}
+            result = compute_global_regression(df_combined, embedding_cols, tsne_params, df_f1, reduction_method="t-SNE")
+            r2_temp = result["R2"]
+            st.write(f"Parameters: Perplexity={p:.2f}, Learning Rate={lr:.2f} -> R²={r2_temp:.4f}")
+
+            if r2_temp > best_R2:
+                best_R2 = r2_temp
+                best_params = (p, lr)

+    progress_text.text("Optimization completed!")
+    return best_params, best_R2
+
+# =============================================================================
+# Main run_model function, integrating optimization and manual execution
+# =============================================================================
+
+def run_model(model_name):
+    embeddings = load_embeddings(model_name)
+    if embeddings is None:
+        return
+    embedding_cols = [col for col in embeddings["real"].columns if col.startswith("dim_")]
+    df_combined = pd.concat(list(embeddings.values()), ignore_index=True)

+    # Load the f1-donut CSV
+    try:
+        df_f1 = pd.read_csv("data/f1-donut.csv", sep=';', index_col=0)
+    except Exception as e:
+        st.error(f"Error loading f1-donut.csv: {e}")
+        return
+
+    st.markdown('<h6 class="sub-title">Select Dimensionality Reduction Method</h6>', unsafe_allow_html=True)
+    reduction_method = st.selectbox("", options=["t-SNE", "PCA"], key=f"reduction_{model_name}")
+
+    tsne_params = {}
+    if reduction_method == "t-SNE":
+        if st.button("Optimize TSNE parameters", key=f"optimize_tsne_{model_name}"):
+            st.info("Running optimization, this can take a while...")
+            best_params, best_R2 = optimize_tsne_params(df_combined, embedding_cols, df_f1)
+            st.success(f"Best parameters: Perplexity = {best_params[0]:.2f}, Learning Rate = {best_params[1]:.2f} with R² = {best_R2:.4f}")
+            tsne_params = {"perplexity": best_params[0], "learning_rate": best_params[1]}
+        else:
+            perplexity_val = st.number_input(
+                "Perplexity",
+                min_value=5.0,
+                max_value=50.0,
+                value=30.0,
+                step=1.0,
+                format="%.2f",
+                key=f"perplexity_{model_name}"
+            )
+            learning_rate_val = st.number_input(
+                "Learning Rate",
+                min_value=10.0,
+                max_value=1000.0,
+                value=200.0,
+                step=10.0,
+                format="%.2f",
+                key=f"learning_rate_{model_name}"
+            )
+            tsne_params = {"perplexity": perplexity_val, "learning_rate": learning_rate_val}
+    # If PCA is selected, tsne_params is not used.
+
+    # Use the centralized function to obtain the global regression and the scatter plot
+    result = compute_global_regression(df_combined, embedding_cols, tsne_params, df_f1, reduction_method=reduction_method)
+
+    reg_metrics = pd.DataFrame({
+        "Slope": [result["slope"]],
+        "Intercept": [result["intercept"]],
+        "R2": [result["R2"]]
+    })
+    st.table(reg_metrics)
+
+    # We do not call st.bokeh_chart(result["scatter_fig"], ...) here;
+    # instead, everything is combined into a single layout:
+    data_table, df_table, source_table = create_table(result["df_distances"])
    real_subset_names = list(df_table.columns[1:])
    real_select = Select(title="", value=real_subset_names[0], options=real_subset_names)
    reset_button = Button(label="Reset Colors", button_type="primary")
    line_source = ColumnDataSource(data={'x': [], 'y': []})
+    # Assuming a base figure 'fig' for the clusters:
+    fig, real_renderers, synthetic_renderers = create_figure(result["dfs_reduced"], result["unique_subsets"], get_color_maps(result["unique_subsets"]), model_name)
    fig.line('x', 'y', source=line_source, line_width=2, line_color='black')
+    centers_real = calculate_cluster_centers(result["dfs_reduced"]["real"], result["unique_subsets"]["real"])
    real_centers_js = {k: [v[0], v[1]] for k, v in centers_real.items()}
    synthetic_centers = {}
+    synth_labels = sorted(result["dfs_reduced"]["synthetic"]['label'].unique().tolist())
    for label in synth_labels:
+        subset = result["dfs_reduced"]["synthetic"][result["dfs_reduced"]["synthetic"]['label'] == label]
        synthetic_centers[label] = [subset['x'].mean(), subset['y'].mean()]

    callback = CustomJS(args=dict(source=source_table, line_source=line_source,
 
    df_table.to_excel(buffer, index=False)
    buffer.seek(0)

+    # Combine all the plots into a single layout:
+    layout = column(fig, result["scatter_fig"], column(real_select, reset_button, data_table))
    st.bokeh_chart(layout, use_container_width=True)

    st.download_button(
 
        key=f"download_button_excel_{model_name}"
    )

def main():
    config_style()
    tabs = st.tabs(["Donut", "Idefics2"])
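For reference, the distance driving both the table and the regression is the exact 1-Wasserstein (earth mover's) distance computed with POT, as in compute_wasserstein_distances_synthetic_individual: uniform weights over each point cloud and a Euclidean cost matrix. A self-contained sketch with hypothetical 2-D clusters:

import numpy as np
import ot  # POT: Python Optimal Transport

rng = np.random.default_rng(0)
synth = rng.normal(0.0, 1.0, size=(40, 2))  # hypothetical synthetic cluster (x, y)
real = rng.normal(0.5, 1.0, size=(60, 2))   # hypothetical real cluster (x, y)

weights_synth = np.ones(len(synth)) / len(synth)  # uniform weights, as in the app
weights_real = np.ones(len(real)) / len(real)

M = ot.dist(synth, real, metric='euclidean')  # pairwise Euclidean cost matrix
w1 = ot.emd2(weights_synth, weights_real, M)  # exact EMD (1-Wasserstein) value
print(f"W1 = {w1:.4f}")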