de-Rodrigo commited on
Commit
757102e
·
1 Parent(s): 3465900

Donut Ready

Browse files
Files changed (1) hide show
  1. app.py +159 -72
app.py CHANGED
@@ -4,7 +4,7 @@ import numpy as np
4
  from bokeh.plotting import figure
5
  from bokeh.models import ColumnDataSource, DataTable, TableColumn, CustomJS, Select, Button
6
  from bokeh.layouts import column
7
- from bokeh.palettes import Reds9, Blues9, Oranges9, Purples9
8
  from sklearn.decomposition import PCA
9
  from sklearn.manifold import TSNE
10
  import io
@@ -27,6 +27,10 @@ def config_style():
27
  .main-title { font-size: 50px; color: #4CAF50; text-align: center; }
28
  .sub-title { font-size: 30px; color: #555; }
29
  .custom-text { font-size: 18px; line-height: 1.5; }
 
 
 
 
30
  </style>
31
  """, unsafe_allow_html=True)
32
  st.markdown('<h1 class="main-title">Merit Embeddings 🎒📃🏆</h1>', unsafe_allow_html=True)
@@ -35,15 +39,29 @@ def config_style():
35
  def load_embeddings(model):
36
  if model == "Donut":
37
  df_real = pd.read_csv("data/donut_de_Rodrigo_merit_secret_all_embeddings.csv")
 
 
38
  df_seq = pd.read_csv("data/donut_de_Rodrigo_merit_es-digital-seq_embeddings.csv")
39
- df_line = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-digital-seq_embeddings.csv")
 
 
40
  df_real["version"] = "real"
41
- df_seq["version"] = "synthetic"
42
  df_line["version"] = "synthetic"
43
- # Usamos un identificador en la columna 'source' para diferenciarlos
44
- df_seq["source"] = "es-digital-seq"
 
 
 
 
 
45
  df_line["source"] = "es-digital-line-degradation-seq"
46
- return {"real": df_real, "synthetic": pd.concat([df_seq, df_line], ignore_index=True)}
 
 
 
 
 
47
  elif model == "Idefics2":
48
  df_real = pd.read_csv("data/idefics2_de_Rodrigo_merit_secret_britanico_embeddings.csv")
49
  df_seq = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-digital-seq_embeddings.csv")
@@ -51,6 +69,7 @@ def load_embeddings(model):
51
  df_seq["version"] = "synthetic"
52
  df_seq["source"] = "es-digital-seq"
53
  return {"real": df_real, "synthetic": df_seq}
 
54
  else:
55
  st.error("Modelo no reconocido")
56
  return None
@@ -65,7 +84,7 @@ def reducer_selector(df_combined, embedding_cols):
65
  reducer = TSNE(n_components=2, random_state=42, perplexity=30, learning_rate=200)
66
  return reducer.fit_transform(all_embeddings)
67
 
68
- # Función genérica para agregar datos al gráfico
69
  def add_dataset_to_fig(fig, df, selected_labels, marker, color_mapping, group_label):
70
  renderers = {}
71
  for label in selected_labels:
@@ -79,7 +98,6 @@ def add_dataset_to_fig(fig, df, selected_labels, marker, color_mapping, group_la
79
  img=subset.get('img', "")
80
  ))
81
  color = color_mapping[label]
82
- # Se añade el identificador de la fuente en la leyenda
83
  legend_label = f"{label} ({group_label})"
84
  if marker == "circle":
85
  r = fig.circle('x', 'y', size=10, source=source,
@@ -96,64 +114,138 @@ def add_dataset_to_fig(fig, df, selected_labels, marker, color_mapping, group_la
96
  renderers[label + f" ({group_label})"] = r
97
  return renderers
98
 
99
- # Asigna paletas de colores de forma genérica para cada grupo (real y para cada fuente sintética)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  def get_color_maps(unique_subsets):
101
  color_map = {}
102
- # Real
103
  num_real = len(unique_subsets["real"])
104
  red_palette = Reds9[:num_real] if num_real <= 9 else (Reds9 * ((num_real // 9) + 1))[:num_real]
105
  color_map["real"] = {label: red_palette[i] for i, label in enumerate(sorted(unique_subsets["real"]))}
106
 
107
- # Synthetic: vamos a separar por fuente (source) basándonos en la lista completa de etiquetas
108
- # Suponemos que en la columna "source" se encuentran los identificadores
109
- synthetic_labels = sorted(unique_subsets["synthetic"])
110
- # Aquí usamos una sola paleta para todos, pero se podría distinguir según la fuente si se quiere
111
- blue_palette = Blues9[:len(synthetic_labels)] if len(synthetic_labels) <= 9 else (Blues9 * ((len(synthetic_labels) // 9) + 1))[:len(synthetic_labels)]
112
- color_map["synthetic"] = {label: blue_palette[i] for i, label in enumerate(synthetic_labels)}
 
 
 
 
 
 
 
 
 
 
 
 
113
  return color_map
114
 
115
- # Separa los datos reducidos en "real" y "synthetic" y extrae los subsets (clusters)
116
  def split_versions(df_combined, reduced):
117
  df_combined['x'] = reduced[:, 0]
118
  df_combined['y'] = reduced[:, 1]
119
  df_real = df_combined[df_combined["version"] == "real"].copy()
120
  df_synth = df_combined[df_combined["version"] == "synthetic"].copy()
121
- # Extraemos los clusters (subset) usando la columna 'label'
122
  unique_real = sorted(df_real['label'].unique().tolist())
123
- unique_synth = sorted(df_synth['label'].unique().tolist())
 
 
 
124
  df_dict = {"real": df_real, "synthetic": df_synth}
 
125
  unique_subsets = {"real": unique_real, "synthetic": unique_synth}
126
  return df_dict, unique_subsets
127
 
128
- # Crea el gráfico; se tratan de forma uniforme ambos conjuntos sintéticos
129
- def create_figure(dfs, unique_subsets, color_maps):
130
- fig = figure(width=400, height=400, tooltips=TOOLTIPS, title="")
131
  real_renderers = add_dataset_to_fig(fig, dfs["real"], unique_subsets["real"],
132
  marker="circle", color_mapping=color_maps["real"],
133
  group_label="Real")
134
- # Aquí separamos los puntos sintéticos según su fuente para asignar diferentes marcadores
 
 
 
 
 
 
 
 
 
 
 
135
  synth_df = dfs["synthetic"]
136
- # Dividimos por 'source'
137
- df_seq = synth_df[synth_df["source"] == "es-digital-seq"]
138
- df_line = synth_df[synth_df["source"] == "es-digital-line-degradation-seq"]
139
-
140
- # Extraemos los clusters para cada fuente (si existen)
141
- unique_seq = sorted(df_seq['label'].unique().tolist())
142
- unique_line = sorted(df_line['label'].unique().tolist())
143
-
144
- seq_renderers = add_dataset_to_fig(fig, df_seq, unique_seq,
145
- marker="square", color_mapping=color_maps["synthetic"],
146
- group_label="es-digital-seq")
147
- line_renderers = add_dataset_to_fig(fig, df_line, unique_line,
148
- marker="triangle", color_mapping=color_maps["synthetic"],
149
- group_label="es-digital-line-degradation-seq")
150
- # Combina ambos renderers sintéticos
151
- synthetic_renderers = {**seq_renderers, **line_renderers}
152
 
153
  fig.legend.location = "top_right"
154
  fig.legend.click_policy = "hide"
 
 
155
  return fig, real_renderers, synthetic_renderers
156
 
 
157
  # Calcula los centros de cada cluster (por grupo)
158
  def calculate_cluster_centers(df, labels):
159
  centers = {}
@@ -164,34 +256,35 @@ def calculate_cluster_centers(df, labels):
164
  return centers
165
 
166
  # Calcula la distancia Wasserstein de cada subset sintético respecto a cada cluster real (por cluster y global)
167
- def compute_wasserstein_distances_all_synthetics(df_synth, df_real, labels_real):
168
  distances = {}
169
- # Para cada cluster en el conjunto sintético (la tabla mostrará todas las etiquetas)
170
- synth_labels = sorted(df_synth['label'].unique().tolist())
171
- for label in synth_labels:
172
- key = f"{label}"
 
 
173
  distances[key] = {}
174
- cluster = df_synth[df_synth['label'] == label][['x','y']].values
175
- n = cluster.shape[0]
 
 
 
 
 
 
 
 
 
 
176
  weights = np.ones(n) / n
177
- for real_label in labels_real:
178
- cluster_real = df_real[df_real['label'] == real_label][['x','y']].values
179
- m = cluster_real.shape[0]
 
180
  weights_real = np.ones(m) / m
181
- M = ot.dist(cluster, cluster_real, metric='euclidean')
182
  distances[key][real_label] = ot.emd2(weights, weights_real, M)
183
- # Distancia global del conjunto sintético a cada cluster real
184
- key = "Global synthetic"
185
- distances[key] = {}
186
- global_synth = df_synth[['x','y']].values
187
- n_global = global_synth.shape[0]
188
- weights_global = np.ones(n_global) / n_global
189
- for real_label in labels_real:
190
- cluster_real = df_real[df_real['label'] == real_label][['x','y']].values
191
- m = cluster_real.shape[0]
192
- weights_real = np.ones(m) / m
193
- M = ot.dist(global_synth, cluster_real, metric='euclidean')
194
- distances[key][real_label] = ot.emd2(weights_global, weights_real, M)
195
  return pd.DataFrame(distances).T
196
 
197
  def create_table(df_distances):
@@ -220,11 +313,11 @@ def run_model(model_name):
220
  embeddings = load_embeddings(model_name)
221
  if embeddings is None:
222
  return
 
223
  embedding_cols = [col for col in embeddings["real"].columns if col.startswith("dim_")]
224
- # Combina todos los DataFrames
225
  df_combined = pd.concat(list(embeddings.values()), ignore_index=True)
226
  st.markdown('<h6 class="sub-title">Select Dimensionality Reduction Method</h6>', unsafe_allow_html=True)
227
- reduction_method = st.selectbox("", options=["t-SNE", "PCA"], key=model_name)
228
  if reduction_method == "PCA":
229
  reducer = PCA(n_components=2)
230
  else:
@@ -232,15 +325,12 @@ def run_model(model_name):
232
  reduced = reducer.fit_transform(df_combined[embedding_cols].values)
233
  dfs_reduced, unique_subsets = split_versions(df_combined, reduced)
234
 
235
- # Se espera que unique_subsets tenga claves "real" y "synthetic"
236
  color_maps = get_color_maps(unique_subsets)
237
- fig, real_renderers, synthetic_renderers = create_figure(dfs_reduced, unique_subsets, color_maps)
238
 
239
  centers_real = calculate_cluster_centers(dfs_reduced["real"], unique_subsets["real"])
240
 
241
- df_distances = compute_wasserstein_distances_all_synthetics(dfs_reduced["synthetic"],
242
- dfs_reduced["real"],
243
- unique_subsets["real"])
244
  data_table, df_table, source_table = create_table(df_distances)
245
 
246
  real_subset_names = list(df_table.columns[1:])
@@ -249,10 +339,7 @@ def run_model(model_name):
249
  line_source = ColumnDataSource(data={'x': [], 'y': []})
250
  fig.line('x', 'y', source=line_source, line_width=2, line_color='black')
251
 
252
- # Preparar centros para callback (para trazar líneas entre centros)
253
  real_centers_js = {k: [v[0], v[1]] for k, v in centers_real.items()}
254
-
255
- # Se podría preparar también los centros sintéticos si se requiere
256
  synthetic_centers = {}
257
  synth_labels = sorted(dfs_reduced["synthetic"]['label'].unique().tolist())
258
  for label in synth_labels:
 
4
  from bokeh.plotting import figure
5
  from bokeh.models import ColumnDataSource, DataTable, TableColumn, CustomJS, Select, Button
6
  from bokeh.layouts import column
7
+ from bokeh.palettes import Reds9, Blues9, Oranges9, Purples9, Greys9, BuGn9, Greens9
8
  from sklearn.decomposition import PCA
9
  from sklearn.manifold import TSNE
10
  import io
 
27
  .main-title { font-size: 50px; color: #4CAF50; text-align: center; }
28
  .sub-title { font-size: 30px; color: #555; }
29
  .custom-text { font-size: 18px; line-height: 1.5; }
30
+ .bk-legend {
31
+ max-height: 200px;
32
+ overflow-y: auto;
33
+ }
34
  </style>
35
  """, unsafe_allow_html=True)
36
  st.markdown('<h1 class="main-title">Merit Embeddings 🎒📃🏆</h1>', unsafe_allow_html=True)
 
39
  def load_embeddings(model):
40
  if model == "Donut":
41
  df_real = pd.read_csv("data/donut_de_Rodrigo_merit_secret_all_embeddings.csv")
42
+ df_par = pd.read_csv("data/donut_de_Rodrigo_merit_es-digital-paragraph-degradation-seq_embeddings.csv")
43
+ df_line = pd.read_csv("data/donut_de_Rodrigo_merit_es-digital-line-degradation-seq_embeddings.csv")
44
  df_seq = pd.read_csv("data/donut_de_Rodrigo_merit_es-digital-seq_embeddings.csv")
45
+ df_rot = pd.read_csv("data/donut_de_Rodrigo_merit_es-digital-rotation-degradation-seq_embeddings.csv")
46
+ df_zoom = pd.read_csv("data/donut_de_Rodrigo_merit_es-digital-zoom-degradation-seq_embeddings.csv")
47
+ df_render = pd.read_csv("data/donut_de_Rodrigo_merit_es-render-seq_embeddings.csv")
48
  df_real["version"] = "real"
49
+ df_par["version"] = "synthetic"
50
  df_line["version"] = "synthetic"
51
+ df_seq["version"] = "synthetic"
52
+ df_rot["version"] = "synthetic"
53
+ df_zoom["version"] = "synthetic"
54
+ df_render["version"] = "synthetic"
55
+
56
+ # Se asigna la fuente
57
+ df_par["source"] = "es-digital-paragraph-degradation-seq"
58
  df_line["source"] = "es-digital-line-degradation-seq"
59
+ df_seq["source"] = "es-digital-seq"
60
+ df_rot["source"] = "es-digital-rotation-degradation-seq"
61
+ df_zoom["source"] = "es-digital-zoom-degradation-seq"
62
+ df_render["source"] = "es-render-seq"
63
+ return {"real": df_real, "synthetic": pd.concat([df_seq, df_line, df_par, df_rot, df_zoom, df_render], ignore_index=True)}
64
+
65
  elif model == "Idefics2":
66
  df_real = pd.read_csv("data/idefics2_de_Rodrigo_merit_secret_britanico_embeddings.csv")
67
  df_seq = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-digital-seq_embeddings.csv")
 
69
  df_seq["version"] = "synthetic"
70
  df_seq["source"] = "es-digital-seq"
71
  return {"real": df_real, "synthetic": df_seq}
72
+
73
  else:
74
  st.error("Modelo no reconocido")
75
  return None
 
84
  reducer = TSNE(n_components=2, random_state=42, perplexity=30, learning_rate=200)
85
  return reducer.fit_transform(all_embeddings)
86
 
87
+ # Función para agregar datos reales (por cada etiqueta)
88
  def add_dataset_to_fig(fig, df, selected_labels, marker, color_mapping, group_label):
89
  renderers = {}
90
  for label in selected_labels:
 
98
  img=subset.get('img', "")
99
  ))
100
  color = color_mapping[label]
 
101
  legend_label = f"{label} ({group_label})"
102
  if marker == "circle":
103
  r = fig.circle('x', 'y', size=10, source=source,
 
114
  renderers[label + f" ({group_label})"] = r
115
  return renderers
116
 
117
+ # Nueva función para plotear sintéticos de forma granular pero con leyenda agrupada por source
118
+ def add_synthetic_dataset_to_fig(fig, df, labels, marker, color_mapping, group_label):
119
+ renderers = {}
120
+ for label in labels:
121
+ subset = df[df['label'] == label]
122
+ if subset.empty:
123
+ continue
124
+ source_obj = ColumnDataSource(data=dict(
125
+ x=subset['x'],
126
+ y=subset['y'],
127
+ label=subset['label'],
128
+ img=subset.get('img', "")
129
+ ))
130
+ # Se usa el color granular asignado a cada etiqueta
131
+ color = color_mapping[label]
132
+ # La leyenda se asigna al nombre del source para que se agrupe
133
+ legend_label = group_label
134
+
135
+ if marker == "square":
136
+ r = fig.square('x', 'y', size=10, source=source_obj,
137
+ fill_color=color, line_color=color,
138
+ legend_label=legend_label)
139
+ elif marker == "triangle":
140
+ r = fig.triangle('x', 'y', size=12, source=source_obj,
141
+ fill_color=color, line_color=color,
142
+ legend_label=legend_label)
143
+ elif marker == "inverted_triangle":
144
+ r = fig.inverted_triangle('x', 'y', size=12, source=source_obj,
145
+ fill_color=color, line_color=color,
146
+ legend_label=legend_label)
147
+ elif marker == "diamond":
148
+ r = fig.diamond('x', 'y', size=10, source=source_obj,
149
+ fill_color=color, line_color=color,
150
+ legend_label=legend_label)
151
+ elif marker == "cross":
152
+ r = fig.cross('x', 'y', size=12, source=source_obj,
153
+ fill_color=color, line_color=color,
154
+ legend_label=legend_label)
155
+ elif marker == "x":
156
+ r = fig.x('x', 'y', size=12, source=source_obj,
157
+ fill_color=color, line_color=color,
158
+ legend_label=legend_label)
159
+ elif marker == "asterisk":
160
+ r = fig.asterisk('x', 'y', size=12, source=source_obj,
161
+ fill_color=color, line_color=color,
162
+ legend_label=legend_label)
163
+ else:
164
+ r = fig.circle('x', 'y', size=10, source=source_obj,
165
+ fill_color=color, line_color=color,
166
+ legend_label=legend_label)
167
+ renderers[label + f" ({group_label})"] = r
168
+ return renderers
169
+
170
+
171
  def get_color_maps(unique_subsets):
172
  color_map = {}
173
+ # Para reales se asigna color para cada etiqueta
174
  num_real = len(unique_subsets["real"])
175
  red_palette = Reds9[:num_real] if num_real <= 9 else (Reds9 * ((num_real // 9) + 1))[:num_real]
176
  color_map["real"] = {label: red_palette[i] for i, label in enumerate(sorted(unique_subsets["real"]))}
177
 
178
+ # Para sintéticos se asigna color de forma granular: para cada source se mapea cada etiqueta
179
+ color_map["synthetic"] = {}
180
+ for source, labels in unique_subsets["synthetic"].items():
181
+ if source == "es-digital-seq":
182
+ palette = Blues9[:len(labels)] if len(labels) <= 9 else (Blues9 * ((len(labels)//9)+1))[:len(labels)]
183
+ elif source == "es-digital-line-degradation-seq":
184
+ palette = Purples9[:len(labels)] if len(labels) <= 9 else (Purples9 * ((len(labels)//9)+1))[:len(labels)]
185
+ elif source == "es-digital-paragraph-degradation-seq":
186
+ palette = BuGn9[:len(labels)] if len(labels) <= 9 else (BuGn9 * ((len(labels)//9)+1))[:len(labels)]
187
+ elif source == "es-digital-rotation-degradation-seq":
188
+ palette = Greys9[:len(labels)] if len(labels) <= 9 else (Greys9 * ((len(labels)//9)+1))[:len(labels)]
189
+ elif source == "es-digital-zoom-degradation-seq":
190
+ palette = Oranges9[:len(labels)] if len(labels) <= 9 else (Oranges9 * ((len(labels)//9)+1))[:len(labels)]
191
+ elif source == "es-render-seq":
192
+ palette = Greens9[:len(labels)] if len(labels) <= 9 else (Greens9 * ((len(labels)//9)+1))[:len(labels)]
193
+ else:
194
+ palette = Blues9[:len(labels)] if len(labels) <= 9 else (Blues9 * ((len(labels)//9)+1))[:len(labels)]
195
+ color_map["synthetic"][source] = {label: palette[i] for i, label in enumerate(sorted(labels))}
196
  return color_map
197
 
 
198
  def split_versions(df_combined, reduced):
199
  df_combined['x'] = reduced[:, 0]
200
  df_combined['y'] = reduced[:, 1]
201
  df_real = df_combined[df_combined["version"] == "real"].copy()
202
  df_synth = df_combined[df_combined["version"] == "synthetic"].copy()
203
+ # Extraer etiquetas únicas para reales
204
  unique_real = sorted(df_real['label'].unique().tolist())
205
+ # Para sintéticos, se agrupan las etiquetas por source
206
+ unique_synth = {}
207
+ for source in df_synth["source"].unique():
208
+ unique_synth[source] = sorted(df_synth[df_synth["source"] == source]['label'].unique().tolist())
209
  df_dict = {"real": df_real, "synthetic": df_synth}
210
+ # Para los reales se guarda la lista, y para sintéticos el diccionario
211
  unique_subsets = {"real": unique_real, "synthetic": unique_synth}
212
  return df_dict, unique_subsets
213
 
214
+ def create_figure(dfs, unique_subsets, color_maps, model_name):
215
+ fig = figure(width=600, height=600, tools="wheel_zoom,pan,reset,save", active_scroll="wheel_zoom", tooltips=TOOLTIPS, title="")
216
+ # Datos reales: se mantienen granulares en plot y en leyenda
217
  real_renderers = add_dataset_to_fig(fig, dfs["real"], unique_subsets["real"],
218
  marker="circle", color_mapping=color_maps["real"],
219
  group_label="Real")
220
+ # Diccionario de asignación de marcadores para sintéticos por source
221
+ marker_mapping = {
222
+ "es-digital-paragraph-degradation-seq": "x",
223
+ "es-digital-line-degradation-seq": "cross",
224
+ "es-digital-seq": "triangle",
225
+ "es-digital-rotation-degradation-seq": "diamond",
226
+ "es-digital-zoom-degradation-seq": "asterisk",
227
+ "es-render-seq": "inverted_triangle"
228
+ }
229
+
230
+ # Datos sintéticos: se plotean granularmente (por etiqueta) pero se agrupa la leyenda por source
231
+ synthetic_renderers = {}
232
  synth_df = dfs["synthetic"]
233
+ for source in unique_subsets["synthetic"]:
234
+ df_source = synth_df[synth_df["source"] == source]
235
+ marker = marker_mapping.get(source, "square") # Por defecto "square" si no se encuentra
236
+ renderers = add_synthetic_dataset_to_fig(fig, df_source, unique_subsets["synthetic"][source],
237
+ marker=marker,
238
+ color_mapping=color_maps["synthetic"][source],
239
+ group_label=source)
240
+ synthetic_renderers.update(renderers)
 
 
 
 
 
 
 
 
241
 
242
  fig.legend.location = "top_right"
243
  fig.legend.click_policy = "hide"
244
+ show_legend = st.checkbox("Show Legend", value=False, key=f"legend_{model_name}")
245
+ fig.legend.visible = show_legend
246
  return fig, real_renderers, synthetic_renderers
247
 
248
+
249
  # Calcula los centros de cada cluster (por grupo)
250
  def calculate_cluster_centers(df, labels):
251
  centers = {}
 
256
  return centers
257
 
258
  # Calcula la distancia Wasserstein de cada subset sintético respecto a cada cluster real (por cluster y global)
259
+ def compute_wasserstein_distances_synthetic_individual(synthetic_df: pd.DataFrame, df_real: pd.DataFrame, real_labels: list) -> pd.DataFrame:
260
  distances = {}
261
+ groups = synthetic_df.groupby(['source', 'label'])
262
+ for (source, label), group in groups:
263
+ key = f"{label} ({source})"
264
+ data = group[['x', 'y']].values
265
+ n = data.shape[0]
266
+ weights = np.ones(n) / n
267
  distances[key] = {}
268
+ for real_label in real_labels:
269
+ real_data = df_real[df_real['label'] == real_label][['x','y']].values
270
+ m = real_data.shape[0]
271
+ weights_real = np.ones(m) / m
272
+ M = ot.dist(data, real_data, metric='euclidean')
273
+ distances[key][real_label] = ot.emd2(weights, weights_real, M)
274
+
275
+ # Distancia global por fuente
276
+ for source, group in synthetic_df.groupby('source'):
277
+ key = f"Global ({source})"
278
+ data = group[['x','y']].values
279
+ n = data.shape[0]
280
  weights = np.ones(n) / n
281
+ distances[key] = {}
282
+ for real_label in real_labels:
283
+ real_data = df_real[df_real['label'] == real_label][['x','y']].values
284
+ m = real_data.shape[0]
285
  weights_real = np.ones(m) / m
286
+ M = ot.dist(data, real_data, metric='euclidean')
287
  distances[key][real_label] = ot.emd2(weights, weights_real, M)
 
 
 
 
 
 
 
 
 
 
 
 
288
  return pd.DataFrame(distances).T
289
 
290
  def create_table(df_distances):
 
313
  embeddings = load_embeddings(model_name)
314
  if embeddings is None:
315
  return
316
+
317
  embedding_cols = [col for col in embeddings["real"].columns if col.startswith("dim_")]
 
318
  df_combined = pd.concat(list(embeddings.values()), ignore_index=True)
319
  st.markdown('<h6 class="sub-title">Select Dimensionality Reduction Method</h6>', unsafe_allow_html=True)
320
+ reduction_method = st.selectbox("", options=["t-SNE", "PCA"], key=f"reduction_{model_name}")
321
  if reduction_method == "PCA":
322
  reducer = PCA(n_components=2)
323
  else:
 
325
  reduced = reducer.fit_transform(df_combined[embedding_cols].values)
326
  dfs_reduced, unique_subsets = split_versions(df_combined, reduced)
327
 
 
328
  color_maps = get_color_maps(unique_subsets)
329
+ fig, real_renderers, synthetic_renderers = create_figure(dfs_reduced, unique_subsets, color_maps, model_name)
330
 
331
  centers_real = calculate_cluster_centers(dfs_reduced["real"], unique_subsets["real"])
332
 
333
+ df_distances = compute_wasserstein_distances_synthetic_individual(dfs_reduced["synthetic"], dfs_reduced["real"], unique_subsets["real"])
 
 
334
  data_table, df_table, source_table = create_table(df_distances)
335
 
336
  real_subset_names = list(df_table.columns[1:])
 
339
  line_source = ColumnDataSource(data={'x': [], 'y': []})
340
  fig.line('x', 'y', source=line_source, line_width=2, line_color='black')
341
 
 
342
  real_centers_js = {k: [v[0], v[1]] for k, v in centers_real.items()}
 
 
343
  synthetic_centers = {}
344
  synth_labels = sorted(dfs_reduced["synthetic"]['label'].unique().tolist())
345
  for label in synth_labels: