de-Rodrigo commited on
Commit
6b1f66d
·
1 Parent(s): b961047

Multiple Dataset Versions

Browse files
Files changed (1) hide show
  1. app.py +139 -143
app.py CHANGED
@@ -3,11 +3,12 @@ import pandas as pd
3
  import numpy as np
4
  from bokeh.plotting import figure
5
  from bokeh.models import ColumnDataSource, DataTable, TableColumn, CustomJS, Select, Button
6
- from bokeh.layouts import row, column
7
- from bokeh.palettes import Reds9, Blues9
8
  from sklearn.decomposition import PCA
9
  from sklearn.manifold import TSNE
10
  import io
 
11
 
12
  TOOLTIPS = """
13
  <div>
@@ -30,20 +31,31 @@ def config_style():
30
  """, unsafe_allow_html=True)
31
  st.markdown('<h1 class="main-title">Merit Embeddings 🎒📃🏆</h1>', unsafe_allow_html=True)
32
 
33
- # Modificamos load_embeddings para aceptar el modelo a cargar
34
  def load_embeddings(model):
35
  if model == "Donut":
36
  df_real = pd.read_csv("data/donut_de_Rodrigo_merit_secret_all_embeddings.csv")
37
- df_es_digital_seq = pd.read_csv("data/donut_de_Rodrigo_merit_es-digital-seq_embeddings.csv")
 
 
 
 
 
 
 
 
38
  elif model == "Idefics2":
39
  df_real = pd.read_csv("data/idefics2_de_Rodrigo_merit_secret_britanico_embeddings.csv")
40
- df_es_digital_seq = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-digital-seq_embeddings.csv")
 
 
 
 
41
  else:
42
  st.error("Modelo no reconocido")
43
  return None
44
- return {"real": df_real, "es-digital-seq": df_es_digital_seq}
45
 
46
- # Funciones auxiliares (idénticas a las de tu código)
47
  def reducer_selector(df_combined, embedding_cols):
48
  reduction_method = st.selectbox("Select Dimensionality Reduction Method:", options=["PCA", "t-SNE"])
49
  all_embeddings = df_combined[embedding_cols].values
@@ -53,7 +65,8 @@ def reducer_selector(df_combined, embedding_cols):
53
  reducer = TSNE(n_components=2, random_state=42, perplexity=30, learning_rate=200)
54
  return reducer.fit_transform(all_embeddings)
55
 
56
- def add_dataset_to_fig(fig, df, selected_labels, marker, color_mapping):
 
57
  renderers = {}
58
  for label in selected_labels:
59
  subset = df[df['label'] == label]
@@ -63,112 +76,153 @@ def add_dataset_to_fig(fig, df, selected_labels, marker, color_mapping):
63
  x=subset['x'],
64
  y=subset['y'],
65
  label=subset['label'],
66
- img=subset['img']
67
  ))
68
  color = color_mapping[label]
 
 
69
  if marker == "circle":
70
  r = fig.circle('x', 'y', size=10, source=source,
71
  fill_color=color, line_color=color,
72
- legend_label=f"{label} (Real)")
73
  elif marker == "square":
74
- r = fig.square('x', 'y', size=6, source=source,
75
  fill_color=color, line_color=color,
76
- legend_label=f"{label} (Synthetic)")
77
- renderers[label] = r
 
 
 
 
78
  return renderers
79
 
80
- def get_color_maps(selected_subsets: dict):
81
- num_real = len(selected_subsets["real"])
 
 
 
82
  red_palette = Reds9[:num_real] if num_real <= 9 else (Reds9 * ((num_real // 9) + 1))[:num_real]
83
- color_mapping_real = {label: red_palette[i] for i, label in enumerate(sorted(selected_subsets["real"]))}
84
 
85
- num_es = len(selected_subsets["es-digital-seq"])
86
- blue_palette = Blues9[:num_es] if num_es <= 9 else (Blues9 * ((num_es // 9) + 1))[:num_es]
87
- color_mapping_es = {label: blue_palette[i] for i, label in enumerate(sorted(selected_subsets["es-digital-seq"]))}
88
-
89
- return {"real": color_mapping_real, "es-digital-seq": color_mapping_es}
 
 
90
 
 
91
  def split_versions(df_combined, reduced):
92
  df_combined['x'] = reduced[:, 0]
93
  df_combined['y'] = reduced[:, 1]
94
  df_real = df_combined[df_combined["version"] == "real"].copy()
95
- df_es = df_combined[df_combined["version"] == "es_digital_seq"].copy()
 
96
  unique_real = sorted(df_real['label'].unique().tolist())
97
- unique_es = sorted(df_es['label'].unique().tolist())
98
- return {"real": df_real, "es-digital-seq": df_es}, {"real": unique_real, "es-digital-seq": unique_es}
 
 
99
 
100
- def create_figure(dfs_reduced, selected_subsets: dict, color_maps: dict):
 
101
  fig = figure(width=400, height=400, tooltips=TOOLTIPS, title="")
102
- real_renderers = add_dataset_to_fig(fig, dfs_reduced["real"], selected_subsets["real"],
103
- marker="circle", color_mapping=color_maps["real"])
104
- synthetic_renderers = add_dataset_to_fig(fig, dfs_reduced["es-digital-seq"], selected_subsets["es-digital-seq"],
105
- marker="square", color_mapping=color_maps["es-digital-seq"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  fig.legend.location = "top_right"
107
  fig.legend.click_policy = "hide"
108
  return fig, real_renderers, synthetic_renderers
109
 
110
- def calculate_cluster_centers(df: pd.DataFrame, selected_labels: list) -> dict:
 
111
  centers = {}
112
- for label in selected_labels:
113
  subset = df[df['label'] == label]
114
  if not subset.empty:
115
  centers[label] = (subset['x'].mean(), subset['y'].mean())
116
  return centers
117
 
118
- def compute_distances(centers_es: dict, centers_real: dict) -> pd.DataFrame:
 
119
  distances = {}
120
- for es_label, (x_es, y_es) in centers_es.items():
121
- distances[es_label] = {}
122
- for real_label, (x_real, y_real) in centers_real.items():
123
- distances[es_label][real_label] = np.sqrt((x_es - x_real)**2 + (y_es - y_real)**2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  return pd.DataFrame(distances).T
125
 
126
  def create_table(df_distances):
127
  df_table = df_distances.copy()
128
  df_table.reset_index(inplace=True)
129
  df_table.rename(columns={'index': 'Synthetic'}, inplace=True)
130
-
131
- # Calcular las filas de medias, máximos y mínimos para cada columna numérica
132
  min_row = {"Synthetic": "Min."}
133
  mean_row = {"Synthetic": "Mean"}
134
  max_row = {"Synthetic": "Max."}
135
-
136
  for col in df_table.columns:
137
  if col != "Synthetic":
138
  min_row[col] = df_table[col].min()
139
  mean_row[col] = df_table[col].mean()
140
  max_row[col] = df_table[col].max()
141
-
142
- # Agregar las filas de medias, máximos y mínimos al final del DataFrame
143
  df_table = pd.concat([df_table, pd.DataFrame([min_row, mean_row, max_row])], ignore_index=True)
144
-
145
  source_table = ColumnDataSource(df_table)
146
  columns = [TableColumn(field='Synthetic', title='Synthetic')]
147
  for col in df_table.columns:
148
  if col != 'Synthetic':
149
  columns.append(TableColumn(field=col, title=col))
150
-
151
- row_height = 28
152
- header_height = 30
153
- total_height = header_height + len(df_table) * row_height
154
-
155
  data_table = DataTable(source=source_table, columns=columns, sizing_mode='stretch_width', height=total_height)
156
  return data_table, df_table, source_table
157
 
158
-
159
-
160
- # Función que ejecuta todo el proceso para un modelo determinado
161
  def run_model(model_name):
162
  embeddings = load_embeddings(model_name)
163
  if embeddings is None:
164
  return
165
-
166
- # Asignamos la versión para distinguir en el split
167
- embeddings["real"]["version"] = "real"
168
- embeddings["es-digital-seq"]["version"] = "es_digital_seq"
169
  embedding_cols = [col for col in embeddings["real"].columns if col.startswith("dim_")]
170
- df_combined = pd.concat([embeddings["real"], embeddings["es-digital-seq"]], ignore_index=True)
171
-
172
  st.markdown('<h6 class="sub-title">Select Dimensionality Reduction Method</h6>', unsafe_allow_html=True)
173
  reduction_method = st.selectbox("", options=["t-SNE", "PCA"], key=model_name)
174
  if reduction_method == "PCA":
@@ -176,125 +230,72 @@ def run_model(model_name):
176
  else:
177
  reducer = TSNE(n_components=2, random_state=42, perplexity=30, learning_rate=200)
178
  reduced = reducer.fit_transform(df_combined[embedding_cols].values)
179
-
180
  dfs_reduced, unique_subsets = split_versions(df_combined, reduced)
181
- selected_subsets = {"real": unique_subsets["real"], "es-digital-seq": unique_subsets["es-digital-seq"]}
182
- color_maps = get_color_maps(selected_subsets)
183
 
184
- fig, real_renderers, synthetic_renderers = create_figure(dfs_reduced, selected_subsets, color_maps)
185
- centers_real = calculate_cluster_centers(dfs_reduced["real"], selected_subsets["real"])
186
- centers_es = calculate_cluster_centers(dfs_reduced["es-digital-seq"], selected_subsets["es-digital-seq"])
187
- df_distances = compute_distances(centers_es, centers_real)
 
 
 
 
 
188
  data_table, df_table, source_table = create_table(df_distances)
 
189
  real_subset_names = list(df_table.columns[1:])
190
  real_select = Select(title="", value=real_subset_names[0], options=real_subset_names)
191
  reset_button = Button(label="Reset Colors", button_type="primary")
192
  line_source = ColumnDataSource(data={'x': [], 'y': []})
193
  fig.line('x', 'y', source=line_source, line_width=2, line_color='black')
194
 
195
- synthetic_centers_js = {k: [v[0], v[1]] for k, v in centers_es.items()}
196
  real_centers_js = {k: [v[0], v[1]] for k, v in centers_real.items()}
197
 
198
- # Callback para actualizar el gráfico
 
 
 
 
 
 
199
  callback = CustomJS(args=dict(source=source_table, line_source=line_source,
200
- synthetic_centers=synthetic_centers_js,
201
  real_centers=real_centers_js,
202
- synthetic_renderers=synthetic_renderers,
203
- real_renderers=real_renderers,
204
- synthetic_colors=color_maps["es-digital-seq"],
205
- real_colors=color_maps["real"],
206
  real_select=real_select),
207
  code="""
208
  var selected = source.selected.indices;
209
  if (selected.length > 0) {
210
- var row = selected[0];
211
  var data = source.data;
212
- var synthetic_label = data['Synthetic'][row];
213
  var real_label = real_select.value;
214
- var syn_coords = synthetic_centers[synthetic_label];
215
  var real_coords = real_centers[real_label];
216
- line_source.data = { 'x': [syn_coords[0], real_coords[0]], 'y': [syn_coords[1], real_coords[1]] };
217
  line_source.change.emit();
218
-
219
- for (var key in synthetic_renderers) {
220
- if (synthetic_renderers.hasOwnProperty(key)) {
221
- var renderer = synthetic_renderers[key];
222
- if (key === synthetic_label) {
223
- renderer.glyph.fill_color = synthetic_colors[key];
224
- renderer.glyph.line_color = synthetic_colors[key];
225
- } else {
226
- renderer.glyph.fill_color = "lightgray";
227
- renderer.glyph.line_color = "lightgray";
228
- }
229
- }
230
- }
231
- for (var key in real_renderers) {
232
- if (real_renderers.hasOwnProperty(key)) {
233
- var renderer = real_renderers[key];
234
- if (key === real_label) {
235
- renderer.glyph.fill_color = real_colors[key];
236
- renderer.glyph.line_color = real_colors[key];
237
- } else {
238
- renderer.glyph.fill_color = "lightgray";
239
- renderer.glyph.line_color = "lightgray";
240
- }
241
- }
242
- }
243
  } else {
244
- line_source.data = { 'x': [], 'y': [] };
245
  line_source.change.emit();
246
- for (var key in synthetic_renderers) {
247
- if (synthetic_renderers.hasOwnProperty(key)) {
248
- var renderer = synthetic_renderers[key];
249
- renderer.glyph.fill_color = synthetic_colors[key];
250
- renderer.glyph.line_color = synthetic_colors[key];
251
- }
252
- }
253
- for (var key in real_renderers) {
254
- if (real_renderers.hasOwnProperty(key)) {
255
- var renderer = real_renderers[key];
256
- renderer.glyph.fill_color = real_colors[key];
257
- renderer.glyph.line_color = real_colors[key];
258
- }
259
- }
260
  }
261
  """)
262
  source_table.selected.js_on_change('indices', callback)
263
  real_select.js_on_change('value', callback)
264
 
265
- reset_callback = CustomJS(args=dict(line_source=line_source,
266
- synthetic_renderers=synthetic_renderers,
267
- real_renderers=real_renderers,
268
- synthetic_colors=color_maps["es-digital-seq"],
269
- real_colors=color_maps["real"]),
270
  code="""
271
- line_source.data = { 'x': [], 'y': [] };
272
  line_source.change.emit();
273
- for (var key in synthetic_renderers) {
274
- if (synthetic_renderers.hasOwnProperty(key)) {
275
- var renderer = synthetic_renderers[key];
276
- renderer.glyph.fill_color = synthetic_colors[key];
277
- renderer.glyph.line_color = synthetic_colors[key];
278
- }
279
- }
280
- for (var key in real_renderers) {
281
- if (real_renderers.hasOwnProperty(key)) {
282
- var renderer = real_renderers[key];
283
- renderer.glyph.fill_color = real_colors[key];
284
- renderer.glyph.line_color = real_colors[key];
285
- }
286
- }
287
  """)
288
  reset_button.js_on_event("button_click", reset_callback)
289
-
290
  buffer = io.BytesIO()
291
  df_table.to_excel(buffer, index=False)
292
  buffer.seek(0)
293
-
294
  layout = column(fig, column(real_select, reset_button, data_table))
295
  st.bokeh_chart(layout, use_container_width=True)
296
-
297
- # Agregar un botón de descarga en Streamlit
298
  st.download_button(
299
  label="Export Table",
300
  data=buffer,
@@ -302,18 +303,13 @@ def run_model(model_name):
302
  mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
303
  key=f"download_button_excel_{model_name}"
304
  )
305
-
306
-
307
 
308
- # Función principal con tabs para cambiar de modelo
309
  def main():
310
  config_style()
311
  tabs = st.tabs(["Donut", "Idefics2"])
312
-
313
  with tabs[0]:
314
  st.markdown('<h2 class="sub-title">Donut 🤗</h2>', unsafe_allow_html=True)
315
  run_model("Donut")
316
-
317
  with tabs[1]:
318
  st.markdown('<h2 class="sub-title">Idefics2 🤗</h2>', unsafe_allow_html=True)
319
  run_model("Idefics2")
 
3
  import numpy as np
4
  from bokeh.plotting import figure
5
  from bokeh.models import ColumnDataSource, DataTable, TableColumn, CustomJS, Select, Button
6
+ from bokeh.layouts import column
7
+ from bokeh.palettes import Reds9, Blues9, Oranges9, Purples9
8
  from sklearn.decomposition import PCA
9
  from sklearn.manifold import TSNE
10
  import io
11
+ import ot
12
 
13
  TOOLTIPS = """
14
  <div>
 
31
  """, unsafe_allow_html=True)
32
  st.markdown('<h1 class="main-title">Merit Embeddings 🎒📃🏆</h1>', unsafe_allow_html=True)
33
 
34
+ # Carga los datos y asigna versiones de forma uniforme
35
  def load_embeddings(model):
36
  if model == "Donut":
37
  df_real = pd.read_csv("data/donut_de_Rodrigo_merit_secret_all_embeddings.csv")
38
+ df_seq = pd.read_csv("data/donut_de_Rodrigo_merit_es-digital-seq_embeddings.csv")
39
+ df_line = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-digital-seq_embeddings.csv")
40
+ df_real["version"] = "real"
41
+ df_seq["version"] = "synthetic"
42
+ df_line["version"] = "synthetic"
43
+ # Usamos un identificador en la columna 'source' para diferenciarlos
44
+ df_seq["source"] = "es-digital-seq"
45
+ df_line["source"] = "es-digital-line-degradation-seq"
46
+ return {"real": df_real, "synthetic": pd.concat([df_seq, df_line], ignore_index=True)}
47
  elif model == "Idefics2":
48
  df_real = pd.read_csv("data/idefics2_de_Rodrigo_merit_secret_britanico_embeddings.csv")
49
+ df_seq = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-digital-seq_embeddings.csv")
50
+ df_real["version"] = "real"
51
+ df_seq["version"] = "synthetic"
52
+ df_seq["source"] = "es-digital-seq"
53
+ return {"real": df_real, "synthetic": df_seq}
54
  else:
55
  st.error("Modelo no reconocido")
56
  return None
 
57
 
58
+ # Selección de reducción dimensional
59
  def reducer_selector(df_combined, embedding_cols):
60
  reduction_method = st.selectbox("Select Dimensionality Reduction Method:", options=["PCA", "t-SNE"])
61
  all_embeddings = df_combined[embedding_cols].values
 
65
  reducer = TSNE(n_components=2, random_state=42, perplexity=30, learning_rate=200)
66
  return reducer.fit_transform(all_embeddings)
67
 
68
+ # Función genérica para agregar datos al gráfico
69
+ def add_dataset_to_fig(fig, df, selected_labels, marker, color_mapping, group_label):
70
  renderers = {}
71
  for label in selected_labels:
72
  subset = df[df['label'] == label]
 
76
  x=subset['x'],
77
  y=subset['y'],
78
  label=subset['label'],
79
+ img=subset.get('img', "")
80
  ))
81
  color = color_mapping[label]
82
+ # Se añade el identificador de la fuente en la leyenda
83
+ legend_label = f"{label} ({group_label})"
84
  if marker == "circle":
85
  r = fig.circle('x', 'y', size=10, source=source,
86
  fill_color=color, line_color=color,
87
+ legend_label=legend_label)
88
  elif marker == "square":
89
+ r = fig.square('x', 'y', size=10, source=source,
90
  fill_color=color, line_color=color,
91
+ legend_label=legend_label)
92
+ elif marker == "triangle":
93
+ r = fig.triangle('x', 'y', size=12, source=source,
94
+ fill_color=color, line_color=color,
95
+ legend_label=legend_label)
96
+ renderers[label + f" ({group_label})"] = r
97
  return renderers
98
 
99
+ # Asigna paletas de colores de forma genérica para cada grupo (real y para cada fuente sintética)
100
+ def get_color_maps(unique_subsets):
101
+ color_map = {}
102
+ # Real
103
+ num_real = len(unique_subsets["real"])
104
  red_palette = Reds9[:num_real] if num_real <= 9 else (Reds9 * ((num_real // 9) + 1))[:num_real]
105
+ color_map["real"] = {label: red_palette[i] for i, label in enumerate(sorted(unique_subsets["real"]))}
106
 
107
+ # Synthetic: vamos a separar por fuente (source) basándonos en la lista completa de etiquetas
108
+ # Suponemos que en la columna "source" se encuentran los identificadores
109
+ synthetic_labels = sorted(unique_subsets["synthetic"])
110
+ # Aquí usamos una sola paleta para todos, pero se podría distinguir según la fuente si se quiere
111
+ blue_palette = Blues9[:len(synthetic_labels)] if len(synthetic_labels) <= 9 else (Blues9 * ((len(synthetic_labels) // 9) + 1))[:len(synthetic_labels)]
112
+ color_map["synthetic"] = {label: blue_palette[i] for i, label in enumerate(synthetic_labels)}
113
+ return color_map
114
 
115
+ # Separa los datos reducidos en "real" y "synthetic" y extrae los subsets (clusters)
116
  def split_versions(df_combined, reduced):
117
  df_combined['x'] = reduced[:, 0]
118
  df_combined['y'] = reduced[:, 1]
119
  df_real = df_combined[df_combined["version"] == "real"].copy()
120
+ df_synth = df_combined[df_combined["version"] == "synthetic"].copy()
121
+ # Extraemos los clusters (subset) usando la columna 'label'
122
  unique_real = sorted(df_real['label'].unique().tolist())
123
+ unique_synth = sorted(df_synth['label'].unique().tolist())
124
+ df_dict = {"real": df_real, "synthetic": df_synth}
125
+ unique_subsets = {"real": unique_real, "synthetic": unique_synth}
126
+ return df_dict, unique_subsets
127
 
128
+ # Crea el gráfico; se tratan de forma uniforme ambos conjuntos sintéticos
129
+ def create_figure(dfs, unique_subsets, color_maps):
130
  fig = figure(width=400, height=400, tooltips=TOOLTIPS, title="")
131
+ real_renderers = add_dataset_to_fig(fig, dfs["real"], unique_subsets["real"],
132
+ marker="circle", color_mapping=color_maps["real"],
133
+ group_label="Real")
134
+ # Aquí separamos los puntos sintéticos según su fuente para asignar diferentes marcadores
135
+ synth_df = dfs["synthetic"]
136
+ # Dividimos por 'source'
137
+ df_seq = synth_df[synth_df["source"] == "es-digital-seq"]
138
+ df_line = synth_df[synth_df["source"] == "es-digital-line-degradation-seq"]
139
+
140
+ # Extraemos los clusters para cada fuente (si existen)
141
+ unique_seq = sorted(df_seq['label'].unique().tolist())
142
+ unique_line = sorted(df_line['label'].unique().tolist())
143
+
144
+ seq_renderers = add_dataset_to_fig(fig, df_seq, unique_seq,
145
+ marker="square", color_mapping=color_maps["synthetic"],
146
+ group_label="es-digital-seq")
147
+ line_renderers = add_dataset_to_fig(fig, df_line, unique_line,
148
+ marker="triangle", color_mapping=color_maps["synthetic"],
149
+ group_label="es-digital-line-degradation-seq")
150
+ # Combina ambos renderers sintéticos
151
+ synthetic_renderers = {**seq_renderers, **line_renderers}
152
+
153
  fig.legend.location = "top_right"
154
  fig.legend.click_policy = "hide"
155
  return fig, real_renderers, synthetic_renderers
156
 
157
+ # Calcula los centros de cada cluster (por grupo)
158
+ def calculate_cluster_centers(df, labels):
159
  centers = {}
160
+ for label in labels:
161
  subset = df[df['label'] == label]
162
  if not subset.empty:
163
  centers[label] = (subset['x'].mean(), subset['y'].mean())
164
  return centers
165
 
166
+ # Calcula la distancia Wasserstein de cada subset sintético respecto a cada cluster real (por cluster y global)
167
+ def compute_wasserstein_distances_all_synthetics(df_synth, df_real, labels_real):
168
  distances = {}
169
+ # Para cada cluster en el conjunto sintético (la tabla mostrará todas las etiquetas)
170
+ synth_labels = sorted(df_synth['label'].unique().tolist())
171
+ for label in synth_labels:
172
+ key = f"{label}"
173
+ distances[key] = {}
174
+ cluster = df_synth[df_synth['label'] == label][['x','y']].values
175
+ n = cluster.shape[0]
176
+ weights = np.ones(n) / n
177
+ for real_label in labels_real:
178
+ cluster_real = df_real[df_real['label'] == real_label][['x','y']].values
179
+ m = cluster_real.shape[0]
180
+ weights_real = np.ones(m) / m
181
+ M = ot.dist(cluster, cluster_real, metric='euclidean')
182
+ distances[key][real_label] = ot.emd2(weights, weights_real, M)
183
+ # Distancia global del conjunto sintético a cada cluster real
184
+ key = "Global synthetic"
185
+ distances[key] = {}
186
+ global_synth = df_synth[['x','y']].values
187
+ n_global = global_synth.shape[0]
188
+ weights_global = np.ones(n_global) / n_global
189
+ for real_label in labels_real:
190
+ cluster_real = df_real[df_real['label'] == real_label][['x','y']].values
191
+ m = cluster_real.shape[0]
192
+ weights_real = np.ones(m) / m
193
+ M = ot.dist(global_synth, cluster_real, metric='euclidean')
194
+ distances[key][real_label] = ot.emd2(weights_global, weights_real, M)
195
  return pd.DataFrame(distances).T
196
 
197
  def create_table(df_distances):
198
  df_table = df_distances.copy()
199
  df_table.reset_index(inplace=True)
200
  df_table.rename(columns={'index': 'Synthetic'}, inplace=True)
 
 
201
  min_row = {"Synthetic": "Min."}
202
  mean_row = {"Synthetic": "Mean"}
203
  max_row = {"Synthetic": "Max."}
 
204
  for col in df_table.columns:
205
  if col != "Synthetic":
206
  min_row[col] = df_table[col].min()
207
  mean_row[col] = df_table[col].mean()
208
  max_row[col] = df_table[col].max()
 
 
209
  df_table = pd.concat([df_table, pd.DataFrame([min_row, mean_row, max_row])], ignore_index=True)
 
210
  source_table = ColumnDataSource(df_table)
211
  columns = [TableColumn(field='Synthetic', title='Synthetic')]
212
  for col in df_table.columns:
213
  if col != 'Synthetic':
214
  columns.append(TableColumn(field=col, title=col))
215
+ total_height = 30 + len(df_table)*28
 
 
 
 
216
  data_table = DataTable(source=source_table, columns=columns, sizing_mode='stretch_width', height=total_height)
217
  return data_table, df_table, source_table
218
 
 
 
 
219
  def run_model(model_name):
220
  embeddings = load_embeddings(model_name)
221
  if embeddings is None:
222
  return
 
 
 
 
223
  embedding_cols = [col for col in embeddings["real"].columns if col.startswith("dim_")]
224
+ # Combina todos los DataFrames
225
+ df_combined = pd.concat(list(embeddings.values()), ignore_index=True)
226
  st.markdown('<h6 class="sub-title">Select Dimensionality Reduction Method</h6>', unsafe_allow_html=True)
227
  reduction_method = st.selectbox("", options=["t-SNE", "PCA"], key=model_name)
228
  if reduction_method == "PCA":
 
230
  else:
231
  reducer = TSNE(n_components=2, random_state=42, perplexity=30, learning_rate=200)
232
  reduced = reducer.fit_transform(df_combined[embedding_cols].values)
 
233
  dfs_reduced, unique_subsets = split_versions(df_combined, reduced)
 
 
234
 
235
+ # Se espera que unique_subsets tenga claves "real" y "synthetic"
236
+ color_maps = get_color_maps(unique_subsets)
237
+ fig, real_renderers, synthetic_renderers = create_figure(dfs_reduced, unique_subsets, color_maps)
238
+
239
+ centers_real = calculate_cluster_centers(dfs_reduced["real"], unique_subsets["real"])
240
+
241
+ df_distances = compute_wasserstein_distances_all_synthetics(dfs_reduced["synthetic"],
242
+ dfs_reduced["real"],
243
+ unique_subsets["real"])
244
  data_table, df_table, source_table = create_table(df_distances)
245
+
246
  real_subset_names = list(df_table.columns[1:])
247
  real_select = Select(title="", value=real_subset_names[0], options=real_subset_names)
248
  reset_button = Button(label="Reset Colors", button_type="primary")
249
  line_source = ColumnDataSource(data={'x': [], 'y': []})
250
  fig.line('x', 'y', source=line_source, line_width=2, line_color='black')
251
 
252
+ # Preparar centros para callback (para trazar líneas entre centros)
253
  real_centers_js = {k: [v[0], v[1]] for k, v in centers_real.items()}
254
 
255
+ # Se podría preparar también los centros sintéticos si se requiere
256
+ synthetic_centers = {}
257
+ synth_labels = sorted(dfs_reduced["synthetic"]['label'].unique().tolist())
258
+ for label in synth_labels:
259
+ subset = dfs_reduced["synthetic"][dfs_reduced["synthetic"]['label'] == label]
260
+ synthetic_centers[label] = [subset['x'].mean(), subset['y'].mean()]
261
+
262
  callback = CustomJS(args=dict(source=source_table, line_source=line_source,
263
+ synthetic_centers=synthetic_centers,
264
  real_centers=real_centers_js,
 
 
 
 
265
  real_select=real_select),
266
  code="""
267
  var selected = source.selected.indices;
268
  if (selected.length > 0) {
269
+ var idx = selected[0];
270
  var data = source.data;
271
+ var synth_label = data['Synthetic'][idx];
272
  var real_label = real_select.value;
273
+ var syn_coords = synthetic_centers[synth_label];
274
  var real_coords = real_centers[real_label];
275
+ line_source.data = {'x': [syn_coords[0], real_coords[0]], 'y': [syn_coords[1], real_coords[1]]};
276
  line_source.change.emit();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
  } else {
278
+ line_source.data = {'x': [], 'y': []};
279
  line_source.change.emit();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
  }
281
  """)
282
  source_table.selected.js_on_change('indices', callback)
283
  real_select.js_on_change('value', callback)
284
 
285
+ reset_callback = CustomJS(args=dict(line_source=line_source),
 
 
 
 
286
  code="""
287
+ line_source.data = {'x': [], 'y': []};
288
  line_source.change.emit();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
  """)
290
  reset_button.js_on_event("button_click", reset_callback)
291
+
292
  buffer = io.BytesIO()
293
  df_table.to_excel(buffer, index=False)
294
  buffer.seek(0)
295
+
296
  layout = column(fig, column(real_select, reset_button, data_table))
297
  st.bokeh_chart(layout, use_container_width=True)
298
+
 
299
  st.download_button(
300
  label="Export Table",
301
  data=buffer,
 
303
  mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
304
  key=f"download_button_excel_{model_name}"
305
  )
 
 
306
 
 
307
  def main():
308
  config_style()
309
  tabs = st.tabs(["Donut", "Idefics2"])
 
310
  with tabs[0]:
311
  st.markdown('<h2 class="sub-title">Donut 🤗</h2>', unsafe_allow_html=True)
312
  run_model("Donut")
 
313
  with tabs[1]:
314
  st.markdown('<h2 class="sub-title">Idefics2 🤗</h2>', unsafe_allow_html=True)
315
  run_model("Idefics2")