de-Rodrigo commited on
Commit
09537d6
1 Parent(s): c85d8a8

Include Pretraining Datasets

Browse files
Files changed (1) hide show
  1. app.py +73 -25
app.py CHANGED
@@ -40,15 +40,18 @@ def config_style():
40
  """, unsafe_allow_html=True)
41
  st.markdown('<h1 class="main-title">Merit Embeddings 馃帓馃搩馃弳</h1>', unsafe_allow_html=True)
42
 
43
- def load_embeddings(model, version):
44
  if model == "Donut":
45
  df_real = pd.read_csv(f"data/donut_{version}_de_Rodrigo_merit_secret_all_embeddings.csv")
46
- df_par = pd.read_csv(f"data/donut_{version}_de_Rodrigo_merit_es-digital-paragraph-degradation-seq_embeddings.csv")
47
- df_line = pd.read_csv(f"data/donut_{version}_de_Rodrigo_merit_es-digital-line-degradation-seq_embeddings.csv")
48
- df_seq = pd.read_csv(f"data/donut_{version}_de_Rodrigo_merit_es-digital-seq_embeddings.csv")
49
- df_rot = pd.read_csv(f"data/donut_{version}_de_Rodrigo_merit_es-digital-rotation-degradation-seq_embeddings.csv")
50
- df_zoom = pd.read_csv(f"data/donut_{version}_de_Rodrigo_merit_es-digital-zoom-degradation-seq_embeddings.csv")
51
- df_render = pd.read_csv(f"data/donut_{version}_de_Rodrigo_merit_es-render-seq_embeddings.csv")
 
 
 
52
  df_real["version"] = "real"
53
  df_par["version"] = "synthetic"
54
  df_line["version"] = "synthetic"
@@ -56,23 +59,32 @@ def load_embeddings(model, version):
56
  df_rot["version"] = "synthetic"
57
  df_zoom["version"] = "synthetic"
58
  df_render["version"] = "synthetic"
 
59
 
 
60
  df_par["source"] = "es-digital-paragraph-degradation-seq"
61
  df_line["source"] = "es-digital-line-degradation-seq"
62
  df_seq["source"] = "es-digital-seq"
63
  df_rot["source"] = "es-digital-rotation-degradation-seq"
64
  df_zoom["source"] = "es-digital-zoom-degradation-seq"
65
  df_render["source"] = "es-render-seq"
66
- return {"real": df_real, "synthetic": pd.concat([df_seq, df_line, df_par, df_rot, df_zoom, df_render], ignore_index=True)}
 
 
 
 
 
67
 
68
  elif model == "Idefics2":
69
  df_real = pd.read_csv(f"data/idefics2_{version}_de_Rodrigo_merit_secret_britanico_embeddings.csv")
70
- df_par = pd.read_csv(f"data/idefics2_{version}_de_Rodrigo_merit_es-digital-paragraph-degradation-seq_embeddings.csv")
71
- df_line = pd.read_csv(f"data/idefics2_{version}_de_Rodrigo_merit_es-digital-line-degradation-seq_embeddings.csv")
72
- df_seq = pd.read_csv(f"data/idefics2_{version}_de_Rodrigo_merit_es-digital-seq_embeddings.csv")
73
- df_rot = pd.read_csv(f"data/idefics2_{version}_de_Rodrigo_merit_es-digital-rotation-degradation-seq_embeddings.csv")
74
- df_zoom = pd.read_csv(f"data/idefics2_{version}_de_Rodrigo_merit_es-digital-zoom-degradation-seq_embeddings.csv")
75
- df_render = pd.read_csv(f"data/idefics2_{version}_de_Rodrigo_merit_es-render-seq_embeddings.csv")
 
 
76
  df_real["version"] = "real"
77
  df_par["version"] = "synthetic"
78
  df_line["version"] = "synthetic"
@@ -80,6 +92,7 @@ def load_embeddings(model, version):
80
  df_rot["version"] = "synthetic"
81
  df_zoom["version"] = "synthetic"
82
  df_render["version"] = "synthetic"
 
83
 
84
  df_par["source"] = "es-digital-paragraph-degradation-seq"
85
  df_line["source"] = "es-digital-line-degradation-seq"
@@ -87,27 +100,38 @@ def load_embeddings(model, version):
87
  df_rot["source"] = "es-digital-rotation-degradation-seq"
88
  df_zoom["source"] = "es-digital-zoom-degradation-seq"
89
  df_render["source"] = "es-render-seq"
90
- return {"real": df_real, "synthetic": pd.concat([df_seq, df_line, df_par, df_rot, df_zoom, df_render], ignore_index=True)}
 
 
 
 
91
 
92
  else:
93
  st.error("Modelo no reconocido")
94
  return None
95
 
 
 
96
  def split_versions(df_combined, reduced):
97
- # Si el embedding es 2D se asignan las columnas x e y para visualizaci贸n.
98
  if reduced.shape[1] == 2:
99
  df_combined['x'] = reduced[:, 0]
100
  df_combined['y'] = reduced[:, 1]
101
  df_real = df_combined[df_combined["version"] == "real"].copy()
102
  df_synth = df_combined[df_combined["version"] == "synthetic"].copy()
 
 
103
  unique_real = sorted(df_real['label'].unique().tolist())
104
  unique_synth = {}
105
  for source in df_synth["source"].unique():
106
  unique_synth[source] = sorted(df_synth[df_synth["source"] == source]['label'].unique().tolist())
107
- df_dict = {"real": df_real, "synthetic": df_synth}
108
- unique_subsets = {"real": unique_real, "synthetic": unique_synth}
 
 
109
  return df_dict, unique_subsets
110
 
 
111
  def get_embedding_from_df(df):
112
  # Retorna el embedding completo (4 dimensiones en este caso) guardado en la columna 'embedding'
113
  if 'embedding' in df.columns:
@@ -212,11 +236,15 @@ def create_table(df_distances):
212
  return data_table, df_table, source_table
213
 
214
  def create_figure(dfs, unique_subsets, color_maps, model_name):
215
- # Se crea solo si el embedding es 2D (ya que se usan 'x' y 'y' para visualizar)
216
  fig = figure(width=600, height=600, tools="wheel_zoom,pan,reset,save", active_scroll="wheel_zoom", tooltips=TOOLTIPS, title="")
 
 
217
  real_renderers = add_dataset_to_fig(fig, dfs["real"], unique_subsets["real"],
218
  marker="circle", color_mapping=color_maps["real"],
219
  group_label="Real")
 
 
220
  marker_mapping = {
221
  "es-digital-paragraph-degradation-seq": "x",
222
  "es-digital-line-degradation-seq": "cross",
@@ -236,11 +264,17 @@ def create_figure(dfs, unique_subsets, color_maps, model_name):
236
  group_label=source)
237
  synthetic_renderers.update(renderers)
238
 
 
 
 
 
 
239
  fig.legend.location = "top_right"
240
  fig.legend.click_policy = "hide"
241
  show_legend = st.checkbox("Show Legend", value=False, key=f"legend_{model_name}")
242
  fig.legend.visible = show_legend
243
- return fig, real_renderers, synthetic_renderers
 
244
 
245
  def add_dataset_to_fig(fig, df, selected_labels, marker, color_mapping, group_label):
246
  renderers = {}
@@ -343,8 +377,15 @@ def get_color_maps(unique_subsets):
343
  else:
344
  palette = Blues9[:len(labels)] if len(labels) <= 9 else (Blues9 * ((len(labels)//9)+1))[:len(labels)]
345
  color_map["synthetic"][source] = {label: palette[i] for i, label in enumerate(sorted(labels))}
 
 
 
 
 
 
346
  return color_map
347
 
 
348
  def calculate_cluster_centers(df, labels):
349
  centers = {}
350
  for label in labels:
@@ -485,8 +526,12 @@ def optimize_tsne_params(df_combined, embedding_cols, df_f1, distance_metric):
485
 
486
  def run_model(model_name):
487
  version = st.selectbox("Select Model Version:", options=["vanilla", "finetuned_real"], key=f"version_{model_name}")
 
 
 
 
488
 
489
- embeddings = load_embeddings(model_name, version)
490
  if embeddings is None:
491
  return
492
  embedding_cols = [col for col in embeddings["real"].columns if col.startswith("dim_")]
@@ -562,10 +607,13 @@ def run_model(model_name):
562
  reset_button = Button(label="Reset Colors", button_type="primary")
563
  line_source = ColumnDataSource(data={'x': [], 'y': []})
564
 
565
- # Si el embedding es 2D se crea el scatter plot de embeddings;
566
- # dado que con PCA ahora usamos 4 dimensiones, este bloque se omite para PCA
567
  if (reduction_method == "t-SNE" and N_COMPONENTS == 2) or (reduction_method == "PCA" and N_COMPONENTS == 2):
568
- fig, real_renderers, synthetic_renderers = create_figure(result["dfs_reduced"], result["unique_subsets"], get_color_maps(result["unique_subsets"]), model_name)
 
 
 
 
 
569
  fig.line('x', 'y', source=line_source, line_width=2, line_color='black')
570
  centers_real = calculate_cluster_centers(result["dfs_reduced"]["real"], result["unique_subsets"]["real"])
571
  real_centers_js = {k: [v[0], v[1]] for k, v in centers_real.items()}
@@ -633,4 +681,4 @@ def main():
633
  run_model("Idefics2")
634
 
635
  if __name__ == "__main__":
636
- main()
 
40
  """, unsafe_allow_html=True)
41
  st.markdown('<h1 class="main-title">Merit Embeddings 馃帓馃搩馃弳</h1>', unsafe_allow_html=True)
42
 
43
+ def load_embeddings(model, version, embedding_prefix):
44
  if model == "Donut":
45
  df_real = pd.read_csv(f"data/donut_{version}_de_Rodrigo_merit_secret_all_embeddings.csv")
46
+ df_par = pd.read_csv(f"data/donut_{version}_de_Rodrigo_merit_es-digital-paragraph-degradation-seq_{embedding_prefix}embeddings.csv")
47
+ df_line = pd.read_csv(f"data/donut_{version}_de_Rodrigo_merit_es-digital-line-degradation-seq_{embedding_prefix}embeddings.csv")
48
+ df_seq = pd.read_csv(f"data/donut_{version}_de_Rodrigo_merit_es-digital-seq_{embedding_prefix}embeddings.csv")
49
+ df_rot = pd.read_csv(f"data/donut_{version}_de_Rodrigo_merit_es-digital-rotation-degradation-seq_{embedding_prefix}embeddings.csv")
50
+ df_zoom = pd.read_csv(f"data/donut_{version}_de_Rodrigo_merit_es-digital-zoom-degradation-seq_{embedding_prefix}embeddings.csv")
51
+ df_render = pd.read_csv(f"data/donut_{version}_de_Rodrigo_merit_es-render-seq_{embedding_prefix}embeddings.csv")
52
+ df_pretratrained = pd.read_csv(f"data/donut_{version}_de_Rodrigo_merit_aux_IIT-CDIP_{embedding_prefix}embeddings.csv")
53
+
54
+ # Asignar etiquetas de versi贸n
55
  df_real["version"] = "real"
56
  df_par["version"] = "synthetic"
57
  df_line["version"] = "synthetic"
 
59
  df_rot["version"] = "synthetic"
60
  df_zoom["version"] = "synthetic"
61
  df_render["version"] = "synthetic"
62
+ df_pretratrained["version"] = "pretrained"
63
 
64
+ # Asignar fuente (source)
65
  df_par["source"] = "es-digital-paragraph-degradation-seq"
66
  df_line["source"] = "es-digital-line-degradation-seq"
67
  df_seq["source"] = "es-digital-seq"
68
  df_rot["source"] = "es-digital-rotation-degradation-seq"
69
  df_zoom["source"] = "es-digital-zoom-degradation-seq"
70
  df_render["source"] = "es-render-seq"
71
+ # Si lo requieres, puedes asignar tambi茅n una fuente para pretrained
72
+ df_pretratrained["source"] = "pretrained"
73
+
74
+ return {"real": df_real,
75
+ "synthetic": pd.concat([df_seq, df_line, df_par, df_rot, df_zoom, df_render], ignore_index=True),
76
+ "pretrained": df_pretratrained}
77
 
78
  elif model == "Idefics2":
79
  df_real = pd.read_csv(f"data/idefics2_{version}_de_Rodrigo_merit_secret_britanico_embeddings.csv")
80
+ df_par = pd.read_csv(f"data/idefics2_{version}_de_Rodrigo_merit_es-digital-paragraph-degradation-seq_{embedding_prefix}embeddings.csv")
81
+ df_line = pd.read_csv(f"data/idefics2_{version}_de_Rodrigo_merit_es-digital-line-degradation-seq_{embedding_prefix}embeddings.csv")
82
+ df_seq = pd.read_csv(f"data/idefics2_{version}_de_Rodrigo_merit_es-digital-seq_{embedding_prefix}embeddings.csv")
83
+ df_rot = pd.read_csv(f"data/idefics2_{version}_de_Rodrigo_merit_es-digital-rotation-degradation-seq_{embedding_prefix}embeddings.csv")
84
+ df_zoom = pd.read_csv(f"data/idefics2_{version}_de_Rodrigo_merit_es-digital-zoom-degradation-seq_{embedding_prefix}embeddings.csv")
85
+ df_render = pd.read_csv(f"data/idefics2_{version}_de_Rodrigo_merit_es-render-seq_{embedding_prefix}embeddings.csv")
86
+ df_pretratrained = pd.read_csv(f"data/idefics2_{version}_de_Rodrigo_merit_pretrained_{embedding_prefix}embeddings.csv")
87
+
88
  df_real["version"] = "real"
89
  df_par["version"] = "synthetic"
90
  df_line["version"] = "synthetic"
 
92
  df_rot["version"] = "synthetic"
93
  df_zoom["version"] = "synthetic"
94
  df_render["version"] = "synthetic"
95
+ df_pretratrained["version"] = "pretrained"
96
 
97
  df_par["source"] = "es-digital-paragraph-degradation-seq"
98
  df_line["source"] = "es-digital-line-degradation-seq"
 
100
  df_rot["source"] = "es-digital-rotation-degradation-seq"
101
  df_zoom["source"] = "es-digital-zoom-degradation-seq"
102
  df_render["source"] = "es-render-seq"
103
+ df_pretratrained["source"] = "pretrained"
104
+
105
+ return {"real": df_real,
106
+ "synthetic": pd.concat([df_seq, df_line, df_par, df_rot, df_zoom, df_render], ignore_index=True),
107
+ "pretrained": df_pretratrained}
108
 
109
  else:
110
  st.error("Modelo no reconocido")
111
  return None
112
 
113
+
114
+
115
  def split_versions(df_combined, reduced):
116
+ # Asignar las coordenadas si la reducci贸n es 2D
117
  if reduced.shape[1] == 2:
118
  df_combined['x'] = reduced[:, 0]
119
  df_combined['y'] = reduced[:, 1]
120
  df_real = df_combined[df_combined["version"] == "real"].copy()
121
  df_synth = df_combined[df_combined["version"] == "synthetic"].copy()
122
+ df_pretrained = df_combined[df_combined["version"] == "pretrained"].copy()
123
+
124
  unique_real = sorted(df_real['label'].unique().tolist())
125
  unique_synth = {}
126
  for source in df_synth["source"].unique():
127
  unique_synth[source] = sorted(df_synth[df_synth["source"] == source]['label'].unique().tolist())
128
+ unique_pretrained = sorted(df_pretrained['label'].unique().tolist())
129
+
130
+ df_dict = {"real": df_real, "synthetic": df_synth, "pretrained": df_pretrained}
131
+ unique_subsets = {"real": unique_real, "synthetic": unique_synth, "pretrained": unique_pretrained}
132
  return df_dict, unique_subsets
133
 
134
+
135
  def get_embedding_from_df(df):
136
  # Retorna el embedding completo (4 dimensiones en este caso) guardado en la columna 'embedding'
137
  if 'embedding' in df.columns:
 
236
  return data_table, df_table, source_table
237
 
238
  def create_figure(dfs, unique_subsets, color_maps, model_name):
239
+ # Se crea el plot para el embedding reducido (asumiendo que es 2D)
240
  fig = figure(width=600, height=600, tools="wheel_zoom,pan,reset,save", active_scroll="wheel_zoom", tooltips=TOOLTIPS, title="")
241
+
242
+ # Renderizar datos reales
243
  real_renderers = add_dataset_to_fig(fig, dfs["real"], unique_subsets["real"],
244
  marker="circle", color_mapping=color_maps["real"],
245
  group_label="Real")
246
+
247
+ # Renderizar datos sint茅ticos (por fuente)
248
  marker_mapping = {
249
  "es-digital-paragraph-degradation-seq": "x",
250
  "es-digital-line-degradation-seq": "cross",
 
264
  group_label=source)
265
  synthetic_renderers.update(renderers)
266
 
267
+ # Agregar el subset pretrained (se puede usar un marcador distinto, por ejemplo, "triangle")
268
+ pretrained_renderers = add_dataset_to_fig(fig, dfs["pretrained"], unique_subsets["pretrained"],
269
+ marker="triangle", color_mapping=color_maps["pretrained"],
270
+ group_label="Pretrained")
271
+
272
  fig.legend.location = "top_right"
273
  fig.legend.click_policy = "hide"
274
  show_legend = st.checkbox("Show Legend", value=False, key=f"legend_{model_name}")
275
  fig.legend.visible = show_legend
276
+ return fig, real_renderers, synthetic_renderers, pretrained_renderers
277
+
278
 
279
  def add_dataset_to_fig(fig, df, selected_labels, marker, color_mapping, group_label):
280
  renderers = {}
 
377
  else:
378
  palette = Blues9[:len(labels)] if len(labels) <= 9 else (Blues9 * ((len(labels)//9)+1))[:len(labels)]
379
  color_map["synthetic"][source] = {label: palette[i] for i, label in enumerate(sorted(labels))}
380
+
381
+ # Asignar colores al subset pretrained usando, por ejemplo, la paleta Purples9
382
+ num_pretrained = len(unique_subsets["pretrained"])
383
+ purple_palette = Purples9[:num_pretrained] if num_pretrained <= 9 else (Purples9 * ((num_pretrained // 9) + 1))[:num_pretrained]
384
+ color_map["pretrained"] = {label: purple_palette[i] for i, label in enumerate(sorted(unique_subsets["pretrained"]))}
385
+
386
  return color_map
387
 
388
+
389
  def calculate_cluster_centers(df, labels):
390
  centers = {}
391
  for label in labels:
 
526
 
527
  def run_model(model_name):
528
  version = st.selectbox("Select Model Version:", options=["vanilla", "finetuned_real"], key=f"version_{model_name}")
529
+ # Nuevo selector para el c贸mputo del embedding
530
+ embedding_computation = st.selectbox("驴C贸mo se computa el embedding?", options=["weighted", "averaged"], key=f"embedding_method_{model_name}")
531
+ # Se asigna el prefijo correspondiente
532
+ prefijo_embedding = "weighted_" if embedding_computation == "weighted" else ""
533
 
534
+ embeddings = load_embeddings(model_name, version, prefijo_embedding)
535
  if embeddings is None:
536
  return
537
  embedding_cols = [col for col in embeddings["real"].columns if col.startswith("dim_")]
 
607
  reset_button = Button(label="Reset Colors", button_type="primary")
608
  line_source = ColumnDataSource(data={'x': [], 'y': []})
609
 
 
 
610
  if (reduction_method == "t-SNE" and N_COMPONENTS == 2) or (reduction_method == "PCA" and N_COMPONENTS == 2):
611
+ fig, real_renderers, synthetic_renderers, pretrained_renderers = create_figure(
612
+ result["dfs_reduced"],
613
+ result["unique_subsets"],
614
+ get_color_maps(result["unique_subsets"]),
615
+ model_name
616
+ )
617
  fig.line('x', 'y', source=line_source, line_width=2, line_color='black')
618
  centers_real = calculate_cluster_centers(result["dfs_reduced"]["real"], result["unique_subsets"]["real"])
619
  real_centers_js = {k: [v[0], v[1]] for k, v in centers_real.items()}
 
681
  run_model("Idefics2")
682
 
683
  if __name__ == "__main__":
684
+ main()