de-Rodrigo commited on
Commit
f541667
1 Parent(s): b5f38e3

Vanilla or Overfitted Model Selection

Browse files
Files changed (1) hide show
  1. app.py +24 -42
app.py CHANGED
@@ -40,15 +40,15 @@ def config_style():
40
  # Funciones de carga de datos, generación de gráficos y cálculo de distancias (sin cambios)
41
  # =============================================================================
42
 
43
- def load_embeddings(model):
44
  if model == "Donut":
45
- df_real = pd.read_csv("data/donut_de_Rodrigo_merit_secret_all_embeddings.csv")
46
- df_par = pd.read_csv("data/donut_de_Rodrigo_merit_es-digital-paragraph-degradation-seq_embeddings.csv")
47
- df_line = pd.read_csv("data/donut_de_Rodrigo_merit_es-digital-line-degradation-seq_embeddings.csv")
48
- df_seq = pd.read_csv("data/donut_de_Rodrigo_merit_es-digital-seq_embeddings.csv")
49
- df_rot = pd.read_csv("data/donut_de_Rodrigo_merit_es-digital-rotation-degradation-seq_embeddings.csv")
50
- df_zoom = pd.read_csv("data/donut_de_Rodrigo_merit_es-digital-zoom-degradation-seq_embeddings.csv")
51
- df_render = pd.read_csv("data/donut_de_Rodrigo_merit_es-render-seq_embeddings.csv")
52
  df_real["version"] = "real"
53
  df_par["version"] = "synthetic"
54
  df_line["version"] = "synthetic"
@@ -66,13 +66,13 @@ def load_embeddings(model):
66
  return {"real": df_real, "synthetic": pd.concat([df_seq, df_line, df_par, df_rot, df_zoom, df_render], ignore_index=True)}
67
 
68
  elif model == "Idefics2":
69
- df_real = pd.read_csv("data/idefics2_de_Rodrigo_merit_secret_britanico_embeddings.csv")
70
- df_par = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-digital-paragraph-degradation-seq_embeddings.csv")
71
- df_line = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-digital-line-degradation-seq_embeddings.csv")
72
- df_seq = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-digital-seq_embeddings.csv")
73
- df_rot = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-digital-rotation-degradation-seq_embeddings.csv")
74
- df_zoom = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-digital-zoom-degradation-seq_embeddings.csv")
75
- df_render = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-render-seq_embeddings.csv")
76
  df_real["version"] = "real"
77
  df_par["version"] = "synthetic"
78
  df_line["version"] = "synthetic"
@@ -266,16 +266,12 @@ def add_synthetic_dataset_to_fig(fig, df, labels, marker, color_mapping, group_l
266
  renderers[label + f" ({group_label})"] = r
267
  return renderers
268
 
269
-
270
-
271
  def get_color_maps(unique_subsets):
272
  color_map = {}
273
- # Para reales se asigna color para cada etiqueta
274
  num_real = len(unique_subsets["real"])
275
  red_palette = Reds9[:num_real] if num_real <= 9 else (Reds9 * ((num_real // 9) + 1))[:num_real]
276
  color_map["real"] = {label: red_palette[i] for i, label in enumerate(sorted(unique_subsets["real"]))}
277
 
278
- # Para sintéticos se asigna color de forma granular: para cada source se mapea cada etiqueta
279
  color_map["synthetic"] = {}
280
  for source, labels in unique_subsets["synthetic"].items():
281
  if source == "es-digital-seq":
@@ -294,8 +290,7 @@ def get_color_maps(unique_subsets):
294
  palette = Blues9[:len(labels)] if len(labels) <= 9 else (Blues9 * ((len(labels)//9)+1))[:len(labels)]
295
  color_map["synthetic"][source] = {label: palette[i] for i, label in enumerate(sorted(labels))}
296
  return color_map
297
-
298
-
299
  def calculate_cluster_centers(df, labels):
300
  centers = {}
301
  for label in labels:
@@ -304,14 +299,11 @@ def calculate_cluster_centers(df, labels):
304
  centers[label] = (subset['x'].mean(), subset['y'].mean())
305
  return centers
306
 
307
-
308
-
309
  # =============================================================================
310
  # Función centralizada para la pipeline: reducción, distancias y regresión global
311
  # =============================================================================
312
 
313
  def compute_global_regression(df_combined, embedding_cols, tsne_params, df_f1, reduction_method="t-SNE"):
314
- # Seleccionar el reductor según el método
315
  if reduction_method == "PCA":
316
  reducer = PCA(n_components=2)
317
  else:
@@ -319,25 +311,21 @@ def compute_global_regression(df_combined, embedding_cols, tsne_params, df_f1, r
319
  perplexity=tsne_params["perplexity"],
320
  learning_rate=tsne_params["learning_rate"])
321
 
322
- # Aplicar reducción dimensional
323
  reduced = reducer.fit_transform(df_combined[embedding_cols].values)
324
  dfs_reduced, unique_subsets = split_versions(df_combined, reduced)
325
 
326
- # Calcular distancias Wasserstein
327
  df_distances = compute_wasserstein_distances_synthetic_individual(
328
  dfs_reduced["synthetic"],
329
  dfs_reduced["real"],
330
  unique_subsets["real"]
331
  )
332
 
333
- # Extraer valores globales para cada fuente (se esperan 10 por fuente)
334
  global_distances = {}
335
  for idx in df_distances.index:
336
  if idx.startswith("Global"):
337
  source = idx.split("(")[1].rstrip(")")
338
  global_distances[source] = df_distances.loc[idx].values
339
 
340
- # Acumular todos los puntos (globales) y sus correspondientes f1 de cada colegio
341
  all_x = []
342
  all_y = []
343
  for source in df_f1.columns:
@@ -349,13 +337,11 @@ def compute_global_regression(df_combined, embedding_cols, tsne_params, df_f1, r
349
  all_x_arr = np.array(all_x).reshape(-1, 1)
350
  all_y_arr = np.array(all_y)
351
 
352
- # Realizar regresión lineal global
353
  model_global = LinearRegression().fit(all_x_arr, all_y_arr)
354
  r2 = model_global.score(all_x_arr, all_y_arr)
355
  slope = model_global.coef_[0]
356
  intercept = model_global.intercept_
357
 
358
- # Crear scatter plot para visualizar la relación
359
  scatter_fig = figure(width=600, height=600, tools="pan,wheel_zoom,reset,save",
360
  title="Scatter Plot: Wasserstein vs F1")
361
  source_colors = {
@@ -383,7 +369,6 @@ def compute_global_regression(df_combined, embedding_cols, tsne_params, df_f1, r
383
  hover_tool = HoverTool(tooltips=[("Wass. Distance", "@x"), ("f1", "@y"), ("Subset", "@Fuente")])
384
  scatter_fig.add_tools(hover_tool)
385
 
386
- # Línea de regresión global
387
  x_line = np.linspace(all_x_arr.min(), all_x_arr.max(), 100)
388
  y_line = model_global.predict(x_line.reshape(-1, 1))
389
  scatter_fig.line(x_line, y_line, line_width=2, line_color="black", legend_label="Global Regression")
@@ -399,11 +384,10 @@ def compute_global_regression(df_combined, embedding_cols, tsne_params, df_f1, r
399
  }
400
 
401
  # =============================================================================
402
- # Función de optimización (grid search) para TSNE, ahora que se usa la misma pipeline
403
  # =============================================================================
404
 
405
  def optimize_tsne_params(df_combined, embedding_cols, df_f1):
406
- # Rango de búsqueda
407
  perplexity_range = np.linspace(30, 50, 10)
408
  learning_rate_range = np.linspace(200, 1000, 20)
409
 
@@ -432,17 +416,19 @@ def optimize_tsne_params(df_combined, embedding_cols, df_f1):
432
  return best_params, best_R2
433
 
434
  # =============================================================================
435
- # Función principal run_model que integra la optimización y la ejecución manual
436
  # =============================================================================
437
 
438
  def run_model(model_name):
439
- embeddings = load_embeddings(model_name)
 
 
 
440
  if embeddings is None:
441
  return
442
  embedding_cols = [col for col in embeddings["real"].columns if col.startswith("dim_")]
443
  df_combined = pd.concat(list(embeddings.values()), ignore_index=True)
444
 
445
- # Cargar CSV f1-donut
446
  try:
447
  df_f1 = pd.read_csv("data/f1-donut.csv", sep=';', index_col=0)
448
  except Exception as e:
@@ -457,7 +443,7 @@ def run_model(model_name):
457
  if st.button("Optimize TSNE parameters", key=f"optimize_tsne_{model_name}"):
458
  st.info("Running optimization, this can take a while...")
459
  best_params, best_R2 = optimize_tsne_params(df_combined, embedding_cols, df_f1)
460
- st.success(f"Mejores parámetros: Perplexity = {best_params[0]:.2f}, Learning Rate = {best_params[1]:.2f} con R² = {best_R2:.4f}")
461
  tsne_params = {"perplexity": best_params[0], "learning_rate": best_params[1]}
462
  else:
463
  perplexity_val = st.number_input(
@@ -481,7 +467,6 @@ def run_model(model_name):
481
  tsne_params = {"perplexity": perplexity_val, "learning_rate": learning_rate_val}
482
  # Si se selecciona PCA, tsne_params no se usa.
483
 
484
- # Usar la función centralizada para obtener la regresión global y el scatter plot
485
  result = compute_global_regression(df_combined, embedding_cols, tsne_params, df_f1, reduction_method=reduction_method)
486
 
487
  reg_metrics = pd.DataFrame({
@@ -491,14 +476,12 @@ def run_model(model_name):
491
  })
492
  st.table(reg_metrics)
493
 
494
- # No llamamos a st.bokeh_chart(result["scatter_fig"], ...) aquí
495
- # Sino que combinamos todo en un único layout:
496
  data_table, df_table, source_table = create_table(result["df_distances"])
497
  real_subset_names = list(df_table.columns[1:])
498
  real_select = Select(title="", value=real_subset_names[0], options=real_subset_names)
499
  reset_button = Button(label="Reset Colors", button_type="primary")
500
  line_source = ColumnDataSource(data={'x': [], 'y': []})
501
- # Suponiendo que tienes una figura base 'fig' para los clusters:
502
  fig, real_renderers, synthetic_renderers = create_figure(result["dfs_reduced"], result["unique_subsets"], get_color_maps(result["unique_subsets"]), model_name)
503
  fig.line('x', 'y', source=line_source, line_width=2, line_color='black')
504
  centers_real = calculate_cluster_centers(result["dfs_reduced"]["real"], result["unique_subsets"]["real"])
@@ -543,7 +526,6 @@ def run_model(model_name):
543
  df_table.to_excel(buffer, index=False)
544
  buffer.seek(0)
545
 
546
- # Combinar todos los gráficos en un único layout:
547
  layout = column(fig, result["scatter_fig"], column(real_select, reset_button, data_table))
548
  st.bokeh_chart(layout, use_container_width=True)
549
 
 
40
  # Funciones de carga de datos, generación de gráficos y cálculo de distancias (sin cambios)
41
  # =============================================================================
42
 
43
+ def load_embeddings(model, version):
44
  if model == "Donut":
45
+ df_real = pd.read_csv(f"data/donut_{version}_de_Rodrigo_merit_secret_all_embeddings.csv")
46
+ df_par = pd.read_csv(f"data/donut_{version}_de_Rodrigo_merit_es-digital-paragraph-degradation-seq_embeddings.csv")
47
+ df_line = pd.read_csv(f"data/donut_{version}_de_Rodrigo_merit_es-digital-line-degradation-seq_embeddings.csv")
48
+ df_seq = pd.read_csv(f"data/donut_{version}_de_Rodrigo_merit_es-digital-seq_embeddings.csv")
49
+ df_rot = pd.read_csv(f"data/donut_{version}_de_Rodrigo_merit_es-digital-rotation-degradation-seq_embeddings.csv")
50
+ df_zoom = pd.read_csv(f"data/donut_{version}_de_Rodrigo_merit_es-digital-zoom-degradation-seq_embeddings.csv")
51
+ df_render = pd.read_csv(f"data/donut_{version}_de_Rodrigo_merit_es-render-seq_embeddings.csv")
52
  df_real["version"] = "real"
53
  df_par["version"] = "synthetic"
54
  df_line["version"] = "synthetic"
 
66
  return {"real": df_real, "synthetic": pd.concat([df_seq, df_line, df_par, df_rot, df_zoom, df_render], ignore_index=True)}
67
 
68
  elif model == "Idefics2":
69
+ df_real = pd.read_csv(f"data/idefics2_{version}_de_Rodrigo_merit_secret_britanico_embeddings.csv")
70
+ df_par = pd.read_csv(f"data/idefics2_{version}_de_Rodrigo_merit_es-digital-paragraph-degradation-seq_embeddings.csv")
71
+ df_line = pd.read_csv(f"data/idefics2_{version}_de_Rodrigo_merit_es-digital-line-degradation-seq_embeddings.csv")
72
+ df_seq = pd.read_csv(f"data/idefics2_{version}_de_Rodrigo_merit_es-digital-seq_embeddings.csv")
73
+ df_rot = pd.read_csv(f"data/idefics2_{version}_de_Rodrigo_merit_es-digital-rotation-degradation-seq_embeddings.csv")
74
+ df_zoom = pd.read_csv(f"data/idefics2_{version}_de_Rodrigo_merit_es-digital-zoom-degradation-seq_embeddings.csv")
75
+ df_render = pd.read_csv(f"data/idefics2_{version}_de_Rodrigo_merit_es-render-seq_embeddings.csv")
76
  df_real["version"] = "real"
77
  df_par["version"] = "synthetic"
78
  df_line["version"] = "synthetic"
 
266
  renderers[label + f" ({group_label})"] = r
267
  return renderers
268
 
 
 
269
  def get_color_maps(unique_subsets):
270
  color_map = {}
 
271
  num_real = len(unique_subsets["real"])
272
  red_palette = Reds9[:num_real] if num_real <= 9 else (Reds9 * ((num_real // 9) + 1))[:num_real]
273
  color_map["real"] = {label: red_palette[i] for i, label in enumerate(sorted(unique_subsets["real"]))}
274
 
 
275
  color_map["synthetic"] = {}
276
  for source, labels in unique_subsets["synthetic"].items():
277
  if source == "es-digital-seq":
 
290
  palette = Blues9[:len(labels)] if len(labels) <= 9 else (Blues9 * ((len(labels)//9)+1))[:len(labels)]
291
  color_map["synthetic"][source] = {label: palette[i] for i, label in enumerate(sorted(labels))}
292
  return color_map
293
+
 
294
  def calculate_cluster_centers(df, labels):
295
  centers = {}
296
  for label in labels:
 
299
  centers[label] = (subset['x'].mean(), subset['y'].mean())
300
  return centers
301
 
 
 
302
  # =============================================================================
303
  # Funci贸n centralizada para la pipeline: reducci贸n, distancias y regresi贸n global
304
  # =============================================================================
305
 
306
  def compute_global_regression(df_combined, embedding_cols, tsne_params, df_f1, reduction_method="t-SNE"):
 
307
  if reduction_method == "PCA":
308
  reducer = PCA(n_components=2)
309
  else:
 
311
  perplexity=tsne_params["perplexity"],
312
  learning_rate=tsne_params["learning_rate"])
313
 
 
314
  reduced = reducer.fit_transform(df_combined[embedding_cols].values)
315
  dfs_reduced, unique_subsets = split_versions(df_combined, reduced)
316
 
 
317
  df_distances = compute_wasserstein_distances_synthetic_individual(
318
  dfs_reduced["synthetic"],
319
  dfs_reduced["real"],
320
  unique_subsets["real"]
321
  )
322
 
 
323
  global_distances = {}
324
  for idx in df_distances.index:
325
  if idx.startswith("Global"):
326
  source = idx.split("(")[1].rstrip(")")
327
  global_distances[source] = df_distances.loc[idx].values
328
 
 
329
  all_x = []
330
  all_y = []
331
  for source in df_f1.columns:
 
337
  all_x_arr = np.array(all_x).reshape(-1, 1)
338
  all_y_arr = np.array(all_y)
339
 
 
340
  model_global = LinearRegression().fit(all_x_arr, all_y_arr)
341
  r2 = model_global.score(all_x_arr, all_y_arr)
342
  slope = model_global.coef_[0]
343
  intercept = model_global.intercept_
344
 
 
345
  scatter_fig = figure(width=600, height=600, tools="pan,wheel_zoom,reset,save",
346
  title="Scatter Plot: Wasserstein vs F1")
347
  source_colors = {
 
369
  hover_tool = HoverTool(tooltips=[("Wass. Distance", "@x"), ("f1", "@y"), ("Subset", "@Fuente")])
370
  scatter_fig.add_tools(hover_tool)
371
 
 
372
  x_line = np.linspace(all_x_arr.min(), all_x_arr.max(), 100)
373
  y_line = model_global.predict(x_line.reshape(-1, 1))
374
  scatter_fig.line(x_line, y_line, line_width=2, line_color="black", legend_label="Global Regression")
 
384
  }
385
 
386
  # =============================================================================
387
+ # Función de optimización (grid search) para TSNE, usando la misma pipeline
388
  # =============================================================================
389
 
390
  def optimize_tsne_params(df_combined, embedding_cols, df_f1):
 
391
  perplexity_range = np.linspace(30, 50, 10)
392
  learning_rate_range = np.linspace(200, 1000, 20)
393
 
 
416
  return best_params, best_R2
417
 
418
  # =============================================================================
419
+ # Función principal run_model que integra optimización, selector de versión y ejecución manual
420
  # =============================================================================
421
 
422
  def run_model(model_name):
423
+ # Seleccionar la versión del modelo
424
+ version = st.selectbox("Select Model Version:", options=["vanilla", "finetuned_real"], key=f"version_{model_name}")
425
+
426
+ embeddings = load_embeddings(model_name, version)
427
  if embeddings is None:
428
  return
429
  embedding_cols = [col for col in embeddings["real"].columns if col.startswith("dim_")]
430
  df_combined = pd.concat(list(embeddings.values()), ignore_index=True)
431
 
 
432
  try:
433
  df_f1 = pd.read_csv("data/f1-donut.csv", sep=';', index_col=0)
434
  except Exception as e:
 
443
  if st.button("Optimize TSNE parameters", key=f"optimize_tsne_{model_name}"):
444
  st.info("Running optimization, this can take a while...")
445
  best_params, best_R2 = optimize_tsne_params(df_combined, embedding_cols, df_f1)
446
+ st.success(f"Best parameters: Perplexity = {best_params[0]:.2f}, Learning Rate = {best_params[1]:.2f} with R² = {best_R2:.4f}")
447
  tsne_params = {"perplexity": best_params[0], "learning_rate": best_params[1]}
448
  else:
449
  perplexity_val = st.number_input(
 
467
  tsne_params = {"perplexity": perplexity_val, "learning_rate": learning_rate_val}
468
  # Si se selecciona PCA, tsne_params no se usa.
469
 
 
470
  result = compute_global_regression(df_combined, embedding_cols, tsne_params, df_f1, reduction_method=reduction_method)
471
 
472
  reg_metrics = pd.DataFrame({
 
476
  })
477
  st.table(reg_metrics)
478
 
 
 
479
  data_table, df_table, source_table = create_table(result["df_distances"])
480
  real_subset_names = list(df_table.columns[1:])
481
  real_select = Select(title="", value=real_subset_names[0], options=real_subset_names)
482
  reset_button = Button(label="Reset Colors", button_type="primary")
483
  line_source = ColumnDataSource(data={'x': [], 'y': []})
484
+
485
  fig, real_renderers, synthetic_renderers = create_figure(result["dfs_reduced"], result["unique_subsets"], get_color_maps(result["unique_subsets"]), model_name)
486
  fig.line('x', 'y', source=line_source, line_width=2, line_color='black')
487
  centers_real = calculate_cluster_centers(result["dfs_reduced"]["real"], result["unique_subsets"]["real"])
 
526
  df_table.to_excel(buffer, index=False)
527
  buffer.seek(0)
528
 
 
529
  layout = column(fig, result["scatter_fig"], column(real_select, reset_button, data_table))
530
  st.bokeh_chart(layout, use_container_width=True)
531