de-Rodrigo committed on
Commit
71eb50f
·
1 Parent(s): b79fb2d

Add Dimensions

Browse files
Files changed (1) hide show
  1. app.py +86 -90
app.py CHANGED
@@ -12,6 +12,10 @@ import io
12
  import ot
13
  from sklearn.linear_model import LinearRegression
14
 
 
 
 
 
15
  TOOLTIPS = """
16
  <div>
17
  <div>
@@ -37,10 +41,6 @@ def config_style():
37
  """, unsafe_allow_html=True)
38
  st.markdown('<h1 class="main-title">Merit Embeddings 🎒📃🏆</h1>', unsafe_allow_html=True)
39
 
40
- # =============================================================================
41
- # Funciones de carga de datos y procesamiento (sin cambios en su mayoría)
42
- # =============================================================================
43
-
44
  def load_embeddings(model, version):
45
  if model == "Donut":
46
  df_real = pd.read_csv(f"data/donut_{version}_de_Rodrigo_merit_secret_all_embeddings.csv")
@@ -95,8 +95,10 @@ def load_embeddings(model, version):
95
  return None
96
 
97
  def split_versions(df_combined, reduced):
98
- df_combined['x'] = reduced[:, 0]
99
- df_combined['y'] = reduced[:, 1]
 
 
100
  df_real = df_combined[df_combined["version"] == "real"].copy()
101
  df_synth = df_combined[df_combined["version"] == "synthetic"].copy()
102
  unique_real = sorted(df_real['label'].unique().tolist())
@@ -107,10 +109,14 @@ def split_versions(df_combined, reduced):
107
  unique_subsets = {"real": unique_real, "synthetic": unique_synth}
108
  return df_dict, unique_subsets
109
 
110
- # =============================================================================
111
- # Funciones para calcular distancias entre clusters según la métrica seleccionada
112
- # (Wasserstein, Euclidean o KL)
113
- # =============================================================================
 
 
 
 
114
 
115
  def compute_cluster_distance(synthetic_points, real_points, metric="wasserstein", bins=20):
116
  if metric.lower() == "wasserstein":
@@ -125,13 +131,14 @@ def compute_cluster_distance(synthetic_points, real_points, metric="wasserstein"
125
  center_real = np.mean(real_points, axis=0)
126
  return np.linalg.norm(center_syn - center_real)
127
  elif metric.lower() == "kl":
 
128
  all_points = np.vstack([synthetic_points, real_points])
129
- x_min, y_min = np.min(all_points, axis=0)
130
- x_max, y_max = np.max(all_points, axis=0)
131
- x_bins = np.linspace(x_min, x_max, bins+1)
132
- y_bins = np.linspace(y_min, y_max, bins+1)
133
- H_syn, _, _ = np.histogram2d(synthetic_points[:,0], synthetic_points[:,1], bins=[x_bins, y_bins])
134
- H_real, _, _ = np.histogram2d(real_points[:,0], real_points[:,1], bins=[x_bins, y_bins])
135
  eps = 1e-10
136
  P = H_syn + eps
137
  Q = H_real + eps
@@ -147,26 +154,22 @@ def compute_cluster_distances_synthetic_individual(synthetic_df: pd.DataFrame, d
147
  groups = synthetic_df.groupby(['source', 'label'])
148
  for (source, label), group in groups:
149
  key = f"{label} ({source})"
150
- data = group[['x', 'y']].values
151
  distances[key] = {}
152
  for real_label in real_labels:
153
- real_data = df_real[df_real['label'] == real_label][['x','y']].values
154
  d = compute_cluster_distance(data, real_data, metric=metric, bins=bins)
155
  distances[key][real_label] = d
156
  for source, group in synthetic_df.groupby('source'):
157
  key = f"Global ({source})"
158
- data = group[['x','y']].values
159
  distances[key] = {}
160
  for real_label in real_labels:
161
- real_data = df_real[df_real['label'] == real_label][['x','y']].values
162
  d = compute_cluster_distance(data, real_data, metric=metric, bins=bins)
163
  distances[key][real_label] = d
164
  return pd.DataFrame(distances).T
165
 
166
- # =============================================================================
167
- # Función para calcular continuidad (mide la preservación de la vecindad original en el embedding)
168
- # =============================================================================
169
-
170
  def compute_continuity(X, X_embedded, n_neighbors=5):
171
  n = X.shape[0]
172
  D_high = pairwise_distances(X, metric='euclidean')
@@ -187,10 +190,6 @@ def compute_continuity(X, X_embedded, n_neighbors=5):
187
  continuity_value = 1 - norm * total
188
  return continuity_value
189
 
190
- # =============================================================================
191
- # Funciones de visualización (sin cambios)
192
- # =============================================================================
193
-
194
  def create_table(df_distances):
195
  df_table = df_distances.copy()
196
  df_table.reset_index(inplace=True)
@@ -214,6 +213,7 @@ def create_table(df_distances):
214
  return data_table, df_table, source_table
215
 
216
  def create_figure(dfs, unique_subsets, color_maps, model_name):
 
217
  fig = figure(width=600, height=600, tools="wheel_zoom,pan,reset,save", active_scroll="wheel_zoom", tooltips=TOOLTIPS, title="")
218
  real_renderers = add_dataset_to_fig(fig, dfs["real"], unique_subsets["real"],
219
  marker="circle", color_mapping=color_maps["real"],
@@ -350,38 +350,36 @@ def calculate_cluster_centers(df, labels):
350
  centers = {}
351
  for label in labels:
352
  subset = df[df['label'] == label]
353
- if not subset.empty:
354
  centers[label] = (subset['x'].mean(), subset['y'].mean())
355
  return centers
356
 
357
- # =============================================================================
358
- # Pipeline central: reducción, cálculo de distancias y regresión global.
359
- # Se agrega el parámetro distance_metric.
360
- # Además, si se utiliza t-SNE, se calculan trustworthiness y continuity.
361
- # =============================================================================
362
-
363
  def compute_global_regression(df_combined, embedding_cols, tsne_params, df_f1, reduction_method="t-SNE", distance_metric="wasserstein"):
364
  if reduction_method == "PCA":
365
- reducer = PCA(n_components=2)
366
  else:
367
- reducer = TSNE(n_components=2, random_state=42,
368
  perplexity=tsne_params["perplexity"],
369
  learning_rate=tsne_params["learning_rate"])
370
 
371
  reduced = reducer.fit_transform(df_combined[embedding_cols].values)
 
 
 
 
 
 
372
 
373
- # Para PCA se captura la explained variance ratio
374
  explained_variance = None
375
  if reduction_method == "PCA":
376
  explained_variance = reducer.explained_variance_ratio_
377
 
378
- # Si se usa t-SNE, calculamos trustworthiness y continuity
379
  trust = None
380
  cont = None
381
  if reduction_method == "t-SNE":
382
  X = df_combined[embedding_cols].values
383
- trust = trustworthiness(X, reduced, n_neighbors=5)
384
- cont = compute_continuity(X, reduced, n_neighbors=5)
385
 
386
  dfs_reduced, unique_subsets = split_versions(df_combined, reduced)
387
 
@@ -453,15 +451,11 @@ def compute_global_regression(df_combined, embedding_cols, tsne_params, df_f1, r
453
  "dfs_reduced": dfs_reduced,
454
  "unique_subsets": unique_subsets,
455
  "df_distances": df_distances,
456
- "explained_variance": explained_variance, # Solo para PCA
457
- "trustworthiness": trust, # Solo para t-SNE
458
- "continuity": cont # Solo para t-SNE
459
  }
460
 
461
- # =============================================================================
462
- # Optimización de parámetros para TSNE (se propaga también la métrica de distancia)
463
- # =============================================================================
464
-
465
  def optimize_tsne_params(df_combined, embedding_cols, df_f1, distance_metric):
466
  perplexity_range = np.linspace(30, 50, 10)
467
  learning_rate_range = np.linspace(200, 1000, 20)
@@ -490,11 +484,6 @@ def optimize_tsne_params(df_combined, embedding_cols, df_f1, distance_metric):
490
  progress_text.text("Optimization completed!")
491
  return best_params, best_R2
492
 
493
- # =============================================================================
494
- # Función principal run_model: incluye selector de versión, método de reducción, métrica de distancia,
495
- # y, si se usa t-SNE, muestra trustworthiness y continuity.
496
- # =============================================================================
497
-
498
  def run_model(model_name):
499
  version = st.selectbox("Select Model Version:", options=["vanilla", "finetuned_real"], key=f"version_{model_name}")
500
 
@@ -556,8 +545,9 @@ def run_model(model_name):
556
 
557
  if reduction_method == "PCA" and result["explained_variance"] is not None:
558
  st.subheader("Explained Variance Ratio")
 
559
  variance_df = pd.DataFrame({
560
- "Component": ["PC1", "PC2"],
561
  "Explained Variance": result["explained_variance"]
562
  })
563
  st.table(variance_df)
@@ -565,6 +555,7 @@ def run_model(model_name):
565
  st.subheader("t-SNE Quality Metrics")
566
  st.write(f"Trustworthiness: {result['trustworthiness']:.4f}")
567
  st.write(f"Continuity: {result['continuity']:.4f}")
 
568
 
569
  data_table, df_table, source_table = create_table(result["df_distances"])
570
  real_subset_names = list(df_table.columns[1:])
@@ -572,53 +563,58 @@ def run_model(model_name):
572
  reset_button = Button(label="Reset Colors", button_type="primary")
573
  line_source = ColumnDataSource(data={'x': [], 'y': []})
574
 
575
- fig, real_renderers, synthetic_renderers = create_figure(result["dfs_reduced"], result["unique_subsets"], get_color_maps(result["unique_subsets"]), model_name)
576
- fig.line('x', 'y', source=line_source, line_width=2, line_color='black')
577
- centers_real = calculate_cluster_centers(result["dfs_reduced"]["real"], result["unique_subsets"]["real"])
578
- real_centers_js = {k: [v[0], v[1]] for k, v in centers_real.items()}
579
- synthetic_centers = {}
580
- synth_labels = sorted(result["dfs_reduced"]["synthetic"]['label'].unique().tolist())
581
- for label in synth_labels:
582
- subset = result["dfs_reduced"]["synthetic"][result["dfs_reduced"]["synthetic"]['label'] == label]
583
- synthetic_centers[label] = [subset['x'].mean(), subset['y'].mean()]
584
-
585
- callback = CustomJS(args=dict(source=source_table, line_source=line_source,
 
 
 
586
  synthetic_centers=synthetic_centers,
587
  real_centers=real_centers_js,
588
  real_select=real_select),
589
- code="""
590
- var selected = source.selected.indices;
591
- if (selected.length > 0) {
592
- var idx = selected[0];
593
- var data = source.data;
594
- var synth_label = data['Synthetic'][idx];
595
- var real_label = real_select.value;
596
- var syn_coords = synthetic_centers[synth_label];
597
- var real_coords = real_centers[real_label];
598
- line_source.data = {'x': [syn_coords[0], real_coords[0]], 'y': [syn_coords[1], real_coords[1]]};
599
- line_source.change.emit();
600
- } else {
 
 
 
 
 
 
 
 
 
601
  line_source.data = {'x': [], 'y': []};
602
  line_source.change.emit();
603
- }
604
- """)
605
- source_table.selected.js_on_change('indices', callback)
606
- real_select.js_on_change('value', callback)
 
607
 
608
- reset_callback = CustomJS(args=dict(line_source=line_source),
609
- code="""
610
- line_source.data = {'x': [], 'y': []};
611
- line_source.change.emit();
612
- """)
613
- reset_button.js_on_event("button_click", reset_callback)
614
 
615
  buffer = io.BytesIO()
616
  df_table.to_excel(buffer, index=False)
617
  buffer.seek(0)
618
 
619
- layout = column(fig, result["scatter_fig"], column(real_select, reset_button, data_table))
620
- st.bokeh_chart(layout, use_container_width=True)
621
-
622
  st.download_button(
623
  label="Export Table",
624
  data=buffer,
 
12
  import ot
13
  from sklearn.linear_model import LinearRegression
14
 
15
+ # Number of components used for the embedding (see N_COMPONENTS below)
16
+ N_COMPONENTS = 100
17
+ TSNE_NEIGHBOURS = 150
18
+
19
  TOOLTIPS = """
20
  <div>
21
  <div>
 
41
  """, unsafe_allow_html=True)
42
  st.markdown('<h1 class="main-title">Merit Embeddings 🎒📃🏆</h1>', unsafe_allow_html=True)
43
 
 
 
 
 
44
  def load_embeddings(model, version):
45
  if model == "Donut":
46
  df_real = pd.read_csv(f"data/donut_{version}_de_Rodrigo_merit_secret_all_embeddings.csv")
 
95
  return None
96
 
97
  def split_versions(df_combined, reduced):
98
+ # Si el embedding es 2D se asignan las columnas x e y para visualización.
99
+ if reduced.shape[1] == 2:
100
+ df_combined['x'] = reduced[:, 0]
101
+ df_combined['y'] = reduced[:, 1]
102
  df_real = df_combined[df_combined["version"] == "real"].copy()
103
  df_synth = df_combined[df_combined["version"] == "synthetic"].copy()
104
  unique_real = sorted(df_real['label'].unique().tolist())
 
109
  unique_subsets = {"real": unique_real, "synthetic": unique_synth}
110
  return df_dict, unique_subsets
111
 
112
+ def get_embedding_from_df(df):
113
+ # Retorna el embedding completo (4 dimensiones en este caso) guardado en la columna 'embedding'
114
+ if 'embedding' in df.columns:
115
+ return np.stack(df['embedding'].to_numpy())
116
+ elif 'x' in df.columns and 'y' in df.columns:
117
+ return df[['x', 'y']].values
118
+ else:
119
+ raise ValueError("No se encontró embedding o coordenadas x,y en el DataFrame.")
120
 
121
  def compute_cluster_distance(synthetic_points, real_points, metric="wasserstein", bins=20):
122
  if metric.lower() == "wasserstein":
 
131
  center_real = np.mean(real_points, axis=0)
132
  return np.linalg.norm(center_syn - center_real)
133
  elif metric.lower() == "kl":
134
+ # Para KL usamos histogramas multidimensionales con límites globales en cada dimensión
135
  all_points = np.vstack([synthetic_points, real_points])
136
+ edges = [
137
+ np.linspace(np.min(all_points[:, i]), np.max(all_points[:, i]), bins+1)
138
+ for i in range(all_points.shape[1])
139
+ ]
140
+ H_syn, _ = np.histogramdd(synthetic_points, bins=edges)
141
+ H_real, _ = np.histogramdd(real_points, bins=edges)
142
  eps = 1e-10
143
  P = H_syn + eps
144
  Q = H_real + eps
 
154
  groups = synthetic_df.groupby(['source', 'label'])
155
  for (source, label), group in groups:
156
  key = f"{label} ({source})"
157
+ data = get_embedding_from_df(group)
158
  distances[key] = {}
159
  for real_label in real_labels:
160
+ real_data = get_embedding_from_df(df_real[df_real['label'] == real_label])
161
  d = compute_cluster_distance(data, real_data, metric=metric, bins=bins)
162
  distances[key][real_label] = d
163
  for source, group in synthetic_df.groupby('source'):
164
  key = f"Global ({source})"
165
+ data = get_embedding_from_df(group)
166
  distances[key] = {}
167
  for real_label in real_labels:
168
+ real_data = get_embedding_from_df(df_real[df_real['label'] == real_label])
169
  d = compute_cluster_distance(data, real_data, metric=metric, bins=bins)
170
  distances[key][real_label] = d
171
  return pd.DataFrame(distances).T
172
 
 
 
 
 
173
  def compute_continuity(X, X_embedded, n_neighbors=5):
174
  n = X.shape[0]
175
  D_high = pairwise_distances(X, metric='euclidean')
 
190
  continuity_value = 1 - norm * total
191
  return continuity_value
192
 
 
 
 
 
193
  def create_table(df_distances):
194
  df_table = df_distances.copy()
195
  df_table.reset_index(inplace=True)
 
213
  return data_table, df_table, source_table
214
 
215
  def create_figure(dfs, unique_subsets, color_maps, model_name):
216
+ # Se crea solo si el embedding es 2D (ya que se usan 'x' y 'y' para visualizar)
217
  fig = figure(width=600, height=600, tools="wheel_zoom,pan,reset,save", active_scroll="wheel_zoom", tooltips=TOOLTIPS, title="")
218
  real_renderers = add_dataset_to_fig(fig, dfs["real"], unique_subsets["real"],
219
  marker="circle", color_mapping=color_maps["real"],
 
350
  centers = {}
351
  for label in labels:
352
  subset = df[df['label'] == label]
353
+ if not subset.empty and 'x' in subset.columns and 'y' in subset.columns:
354
  centers[label] = (subset['x'].mean(), subset['y'].mean())
355
  return centers
356
 
 
 
 
 
 
 
357
  def compute_global_regression(df_combined, embedding_cols, tsne_params, df_f1, reduction_method="t-SNE", distance_metric="wasserstein"):
358
  if reduction_method == "PCA":
359
+ reducer = PCA(n_components=N_COMPONENTS)
360
  else:
361
+ reducer = TSNE(n_components=3, random_state=42,
362
  perplexity=tsne_params["perplexity"],
363
  learning_rate=tsne_params["learning_rate"])
364
 
365
  reduced = reducer.fit_transform(df_combined[embedding_cols].values)
366
+ # Guardamos el embedding completo (N_COMPONENTS dimensiones para PCA)
367
+ df_combined['embedding'] = list(reduced)
368
+ # Si el embedding es 2D (por t-SNE o PCA con 2 componentes) asignamos x e y para visualización
369
+ if reduced.shape[1] == 2:
370
+ df_combined['x'] = reduced[:, 0]
371
+ df_combined['y'] = reduced[:, 1]
372
 
 
373
  explained_variance = None
374
  if reduction_method == "PCA":
375
  explained_variance = reducer.explained_variance_ratio_
376
 
 
377
  trust = None
378
  cont = None
379
  if reduction_method == "t-SNE":
380
  X = df_combined[embedding_cols].values
381
+ trust = trustworthiness(X, reduced, n_neighbors=TSNE_NEIGHBOURS)
382
+ cont = compute_continuity(X, reduced, n_neighbors=TSNE_NEIGHBOURS)
383
 
384
  dfs_reduced, unique_subsets = split_versions(df_combined, reduced)
385
 
 
451
  "dfs_reduced": dfs_reduced,
452
  "unique_subsets": unique_subsets,
453
  "df_distances": df_distances,
454
+ "explained_variance": explained_variance,
455
+ "trustworthiness": trust,
456
+ "continuity": cont
457
  }
458
 
 
 
 
 
459
  def optimize_tsne_params(df_combined, embedding_cols, df_f1, distance_metric):
460
  perplexity_range = np.linspace(30, 50, 10)
461
  learning_rate_range = np.linspace(200, 1000, 20)
 
484
  progress_text.text("Optimization completed!")
485
  return best_params, best_R2
486
 
 
 
 
 
 
487
  def run_model(model_name):
488
  version = st.selectbox("Select Model Version:", options=["vanilla", "finetuned_real"], key=f"version_{model_name}")
489
 
 
545
 
546
  if reduction_method == "PCA" and result["explained_variance"] is not None:
547
  st.subheader("Explained Variance Ratio")
548
+ component_names = [f"PC{i+1}" for i in range(len(result["explained_variance"]))]
549
  variance_df = pd.DataFrame({
550
+ "Component": component_names,
551
  "Explained Variance": result["explained_variance"]
552
  })
553
  st.table(variance_df)
 
555
  st.subheader("t-SNE Quality Metrics")
556
  st.write(f"Trustworthiness: {result['trustworthiness']:.4f}")
557
  st.write(f"Continuity: {result['continuity']:.4f}")
558
+
559
 
560
  data_table, df_table, source_table = create_table(result["df_distances"])
561
  real_subset_names = list(df_table.columns[1:])
 
563
  reset_button = Button(label="Reset Colors", button_type="primary")
564
  line_source = ColumnDataSource(data={'x': [], 'y': []})
565
 
566
+ # Si el embedding es 2D se crea el scatter plot de embeddings;
567
+ # dado que con PCA ahora usamos N_COMPONENTS dimensiones (>2), este bloque se omite para PCA
568
+ if (reduction_method == "t-SNE" and N_COMPONENTS == 2) or (reduction_method == "PCA" and N_COMPONENTS == 2):
569
+ fig, real_renderers, synthetic_renderers = create_figure(result["dfs_reduced"], result["unique_subsets"], get_color_maps(result["unique_subsets"]), model_name)
570
+ fig.line('x', 'y', source=line_source, line_width=2, line_color='black')
571
+ centers_real = calculate_cluster_centers(result["dfs_reduced"]["real"], result["unique_subsets"]["real"])
572
+ real_centers_js = {k: [v[0], v[1]] for k, v in centers_real.items()}
573
+ synthetic_centers = {}
574
+ synth_labels = sorted(result["dfs_reduced"]["synthetic"]['label'].unique().tolist())
575
+ for label in synth_labels:
576
+ subset = result["dfs_reduced"]["synthetic"][result["dfs_reduced"]["synthetic"]['label'] == label]
577
+ if 'x' in subset.columns and 'y' in subset.columns:
578
+ synthetic_centers[label] = [subset['x'].mean(), subset['y'].mean()]
579
+ callback = CustomJS(args=dict(source=source_table, line_source=line_source,
580
  synthetic_centers=synthetic_centers,
581
  real_centers=real_centers_js,
582
  real_select=real_select),
583
+ code="""
584
+ var selected = source.selected.indices;
585
+ if (selected.length > 0) {
586
+ var idx = selected[0];
587
+ var data = source.data;
588
+ var synth_label = data['Synthetic'][idx];
589
+ var real_label = real_select.value;
590
+ var syn_coords = synthetic_centers[synth_label];
591
+ var real_coords = real_centers[real_label];
592
+ line_source.data = {'x': [syn_coords[0], real_coords[0]], 'y': [syn_coords[1], real_coords[1]]};
593
+ line_source.change.emit();
594
+ } else {
595
+ line_source.data = {'x': [], 'y': []};
596
+ line_source.change.emit();
597
+ }
598
+ """)
599
+ source_table.selected.js_on_change('indices', callback)
600
+ real_select.js_on_change('value', callback)
601
+
602
+ reset_callback = CustomJS(args=dict(line_source=line_source),
603
+ code="""
604
  line_source.data = {'x': [], 'y': []};
605
  line_source.change.emit();
606
+ """)
607
+ reset_button.js_on_event("button_click", reset_callback)
608
+ layout = column(fig, result["scatter_fig"], column(real_select, reset_button, data_table))
609
+ else:
610
+ layout = column(result["scatter_fig"], column(real_select, reset_button, data_table))
611
 
612
+ st.bokeh_chart(layout, use_container_width=True)
 
 
 
 
 
613
 
614
  buffer = io.BytesIO()
615
  df_table.to_excel(buffer, index=False)
616
  buffer.seek(0)
617
 
 
 
 
618
  st.download_button(
619
  label="Export Table",
620
  data=buffer,