de-Rodrigo committed on
Commit
ed8f744
1 Parent(s): 14bdc44

Draw Lines between Cluster Centers

Browse files
Files changed (1) hide show
  1. app.py +162 -120
app.py CHANGED
@@ -1,7 +1,9 @@
1
  import streamlit as st
2
  import pandas as pd
 
3
  from bokeh.plotting import figure
4
- from bokeh.models import ColumnDataSource
 
5
  from bokeh.palettes import Reds9, Blues9
6
  from sklearn.decomposition import PCA
7
  from sklearn.manifold import TSNE
@@ -17,7 +19,6 @@ TOOLTIPS = """
17
  </div>
18
  """
19
 
20
-
21
  def config_style():
22
  st.markdown("""
23
  <style>
@@ -28,171 +29,212 @@ def config_style():
28
  """, unsafe_allow_html=True)
29
  st.markdown('<h1 class="main-title">Merit Embeddings 🎒📃🏆</h1>', unsafe_allow_html=True)
30
  st.markdown('<h2 class="sub-title">Donut 馃</h2>', unsafe_allow_html=True)
31
- st.markdown(
32
- """
33
- <p class="custom-text">
34
- Se cargan ambas versiones de los embeddings y se aplica una reducción dimensional sobre el conjunto combinado.
35
- Los puntos de la versión real se muestran como <strong>círculos</strong> (tonos de rojo)
36
- y los de la es_digital_seq como <strong>cuadrados</strong> (tonos de azul).
37
- </p>
38
- """, unsafe_allow_html=True)
39
-
40
 
41
  def load_embeddings():
42
  df_real = pd.read_csv("data/donut_de_Rodrigo_merit_secret_all_embeddings.csv")
43
  df_es_digital_seq = pd.read_csv("data/donut_de_Rodrigo_merit_es-digital-seq_embeddings.csv")
44
-
45
- embeddings = {
46
- "real": df_real,
47
- "es-digital-seq": df_es_digital_seq
48
- }
49
-
50
- return embeddings
51
-
52
 
53
  def reducer_selector(df_combined, embedding_cols):
54
-
55
- reduction_method = st.selectbox("Seleccione método de reducción:", options=["PCA", "t-SNE"])
56
  all_embeddings = df_combined[embedding_cols].values
57
  if reduction_method == "PCA":
58
  reducer = PCA(n_components=2)
59
  else:
60
  reducer = TSNE(n_components=2, random_state=42, perplexity=30, learning_rate=200)
61
- reduced = reducer.fit_transform(all_embeddings)
62
-
63
- return reduced
64
-
65
 
66
  def add_dataset_to_fig(fig, df, selected_labels, marker, color_mapping):
 
67
  for label in selected_labels:
68
  subset = df[df['label'] == label]
69
  if subset.empty:
70
  continue
71
  source = ColumnDataSource(data=dict(
72
- x = subset['x'],
73
- y = subset['y'],
74
- label = subset['label'],
75
- img = subset['img']
76
  ))
77
  color = color_mapping[label]
78
  if marker == "circle":
79
- fig.circle('x', 'y', size=10, source=source,
80
- fill_color=color, line_color=color,
81
- legend_label=f"{label} (Real)")
82
  elif marker == "square":
83
- fig.square('x', 'y', size=4, source=source, fill_color=color, line_color=color,
84
- legend_label=f"{label} (Sintético)")
85
-
 
 
86
 
87
  def get_color_maps(selected_subsets: dict):
88
-
89
- # real
90
  num_real = len(selected_subsets["real"])
91
- if num_real <= 9:
92
- red_palette = Reds9[:num_real]
93
- else:
94
- red_palette = (Reds9 * ((num_real // 9) + 1))[:num_real]
95
  color_mapping_real = {label: red_palette[i] for i, label in enumerate(sorted(selected_subsets["real"]))}
96
-
97
- # es-digital-seq
98
- num_es_digital_seq = len(selected_subsets["es-digital-seq"])
99
- if num_es_digital_seq <= 9:
100
- blue_palette = Blues9[:num_es_digital_seq]
101
- else:
102
- blue_palette = (Blues9 * ((num_es_digital_seq // 9) + 1))[:num_es_digital_seq]
103
- color_mapping_es_digital_seq = {label: blue_palette[i] for i, label in enumerate(sorted(selected_subsets["es-digital-seq"]))}
104
-
105
- # Gather color maps
106
- color_maps = {
107
- "real": color_mapping_real,
108
- "es-digital-seq": color_mapping_es_digital_seq
109
- }
110
-
111
- return color_maps
112
-
113
 
114
  def split_versions(df_combined, reduced):
115
-
116
  df_combined['x'] = reduced[:, 0]
117
  df_combined['y'] = reduced[:, 1]
118
-
119
- df_real_reduced = df_combined[df_combined["version"] == "real"].copy()
120
- df_es_digital_seq_reduced = df_combined[df_combined["version"] == "es_digital_seq"].copy()
121
-
122
- # Obtener los subsets únicos de cada versión
123
- unique_subsets_real = sorted(df_real_reduced['label'].unique().tolist())
124
- unique_subsets_es_digital_seq = sorted(df_es_digital_seq_reduced['label'].unique().tolist())
125
-
126
- unique_subsets = {
127
- "real": unique_subsets_real,
128
- "es-digital-seq": unique_subsets_es_digital_seq,
129
- }
130
-
131
- dfs_reduced = {
132
- "real": df_real_reduced,
133
- "es-digital-seq": df_es_digital_seq_reduced,
134
- }
135
-
136
- return dfs_reduced, unique_subsets
137
-
138
 
139
  def subset_selectors(unique_subsets: dict):
140
-
141
- selected_subsets_real = st.multiselect("Seleccione subsets para visualizar (Real):",
142
- options=unique_subsets["real"],
143
- default=unique_subsets["real"])
144
- selected_subsets_es_digital_seq = st.multiselect("Seleccione subsets para visualizar (Sintético):",
145
- options=unique_subsets["es-digital-seq"],
146
- default=unique_subsets["es-digital-seq"])
147
-
148
- selected_subsets = {
149
- "real": selected_subsets_real,
150
- "es-digital-seq": selected_subsets_es_digital_seq
151
- }
152
-
153
- return selected_subsets
154
-
155
 
156
  def create_figure(dfs_reduced, selected_subsets: dict, color_maps: dict):
157
-
158
- fig = figure(width=600, height=600, tooltips=TOOLTIPS,
159
- title="")
160
-
161
- add_dataset_to_fig(fig, dfs_reduced["real"], selected_subsets["real"],
162
- marker="circle", color_mapping=color_maps["real"])
163
- add_dataset_to_fig(fig, dfs_reduced["es-digital-seq"], selected_subsets["es-digital-seq"],
164
- marker="square", color_mapping=color_maps["es-digital-seq"])
165
-
166
  fig.legend.location = "top_right"
167
  fig.legend.click_policy = "hide"
 
168
 
169
- return fig
170
-
 
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
  def main():
173
-
174
  config_style()
 
 
 
 
175
 
176
- embeddings_dfs = load_embeddings()
177
-
178
- embeddings_dfs["real"]["version"] = "real"
179
- embeddings_dfs["es-digital-seq"]["version"] = "es_digital_seq"
180
-
181
- embedding_cols = [col for col in embeddings_dfs["real"].columns if col.startswith("dim_")]
182
-
183
- # Combine dataframes to apply method reduction
184
- df_combined = pd.concat([embeddings_dfs["real"], embeddings_dfs["es-digital-seq"]], ignore_index=True)
185
-
186
  reduced = reducer_selector(df_combined, embedding_cols)
187
 
188
- # Split back the different versions
189
  dfs_reduced, unique_subsets = split_versions(df_combined, reduced)
190
-
191
  selected_subsets = subset_selectors(unique_subsets)
192
  color_maps = get_color_maps(selected_subsets)
193
- figure = create_figure(dfs_reduced, selected_subsets, color_maps)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
- st.bokeh_chart(figure)
 
 
196
 
197
  if __name__ == "__main__":
198
  main()
 
1
  import streamlit as st
2
  import pandas as pd
3
+ import numpy as np
4
  from bokeh.plotting import figure
5
+ from bokeh.models import ColumnDataSource, DataTable, TableColumn, CustomJS, Select
6
+ from bokeh.layouts import row, column
7
  from bokeh.palettes import Reds9, Blues9
8
  from sklearn.decomposition import PCA
9
  from sklearn.manifold import TSNE
 
19
  </div>
20
  """
21
 
 
22
  def config_style():
23
  st.markdown("""
24
  <style>
 
29
  """, unsafe_allow_html=True)
30
  st.markdown('<h1 class="main-title">Merit Embeddings 🎒📃🏆</h1>', unsafe_allow_html=True)
31
  st.markdown('<h2 class="sub-title">Donut 馃</h2>', unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
32
 
33
  def load_embeddings():
34
  df_real = pd.read_csv("data/donut_de_Rodrigo_merit_secret_all_embeddings.csv")
35
  df_es_digital_seq = pd.read_csv("data/donut_de_Rodrigo_merit_es-digital-seq_embeddings.csv")
36
+ return {"real": df_real, "es-digital-seq": df_es_digital_seq}
 
 
 
 
 
 
 
37
 
38
  def reducer_selector(df_combined, embedding_cols):
39
+ reduction_method = st.selectbox("Select Dimensionality Reduction Method:", options=["PCA", "t-SNE"])
 
40
  all_embeddings = df_combined[embedding_cols].values
41
  if reduction_method == "PCA":
42
  reducer = PCA(n_components=2)
43
  else:
44
  reducer = TSNE(n_components=2, random_state=42, perplexity=30, learning_rate=200)
45
+ return reducer.fit_transform(all_embeddings)
 
 
 
46
 
47
  def add_dataset_to_fig(fig, df, selected_labels, marker, color_mapping):
48
+ renderers = {}
49
  for label in selected_labels:
50
  subset = df[df['label'] == label]
51
  if subset.empty:
52
  continue
53
  source = ColumnDataSource(data=dict(
54
+ x=subset['x'],
55
+ y=subset['y'],
56
+ label=subset['label'],
57
+ img=subset['img']
58
  ))
59
  color = color_mapping[label]
60
  if marker == "circle":
61
+ r = fig.circle('x', 'y', size=10, source=source,
62
+ fill_color=color, line_color=color,
63
+ legend_label=f"{label} (Real)")
64
  elif marker == "square":
65
+ r = fig.square('x', 'y', size=10, source=source,
66
+ fill_color=color, line_color=color,
67
+ legend_label=f"{label} (Synthetic)")
68
+ renderers[label] = r
69
+ return renderers
70
 
71
  def get_color_maps(selected_subsets: dict):
72
+ # Para real
 
73
  num_real = len(selected_subsets["real"])
74
+ red_palette = Reds9[:num_real] if num_real <= 9 else (Reds9 * ((num_real // 9) + 1))[:num_real]
 
 
 
75
  color_mapping_real = {label: red_palette[i] for i, label in enumerate(sorted(selected_subsets["real"]))}
76
+ # Para es-digital-seq (sintéticos)
77
+ num_es = len(selected_subsets["es-digital-seq"])
78
+ blue_palette = Blues9[:num_es] if num_es <= 9 else (Blues9 * ((num_es // 9) + 1))[:num_es]
79
+ color_mapping_es = {label: blue_palette[i] for i, label in enumerate(sorted(selected_subsets["es-digital-seq"]))}
80
+ return {"real": color_mapping_real, "es-digital-seq": color_mapping_es}
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  def split_versions(df_combined, reduced):
 
83
  df_combined['x'] = reduced[:, 0]
84
  df_combined['y'] = reduced[:, 1]
85
+ df_real = df_combined[df_combined["version"] == "real"].copy()
86
+ df_es = df_combined[df_combined["version"] == "es_digital_seq"].copy()
87
+ unique_real = sorted(df_real['label'].unique().tolist())
88
+ unique_es = sorted(df_es['label'].unique().tolist())
89
+ return {"real": df_real, "es-digital-seq": df_es}, {"real": unique_real, "es-digital-seq": unique_es}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  def subset_selectors(unique_subsets: dict):
92
+ selected_real = st.multiselect("Select Real Subsets:", options=unique_subsets["real"], default=unique_subsets["real"])
93
+ selected_es = st.multiselect("Select Synthetic Subsets:", options=unique_subsets["es-digital-seq"], default=unique_subsets["es-digital-seq"])
94
+ return {"real": selected_real, "es-digital-seq": selected_es}
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
  def create_figure(dfs_reduced, selected_subsets: dict, color_maps: dict):
97
+ fig = figure(width=600, height=600, tooltips=TOOLTIPS, title="")
98
+ real_renderers = add_dataset_to_fig(fig, dfs_reduced["real"], selected_subsets["real"],
99
+ marker="circle", color_mapping=color_maps["real"])
100
+ synthetic_renderers = add_dataset_to_fig(fig, dfs_reduced["es-digital-seq"], selected_subsets["es-digital-seq"],
101
+ marker="square", color_mapping=color_maps["es-digital-seq"])
 
 
 
 
102
  fig.legend.location = "top_right"
103
  fig.legend.click_policy = "hide"
104
+ return fig, real_renderers, synthetic_renderers
105
 
106
+ def calculate_cluster_centers(df: pd.DataFrame, selected_labels: list) -> dict:
107
+ centers = {}
108
+ for label in selected_labels:
109
+ subset = df[df['label'] == label]
110
+ if not subset.empty:
111
+ centers[label] = (subset['x'].mean(), subset['y'].mean())
112
+ return centers
113
+
114
+ def compute_distances(centers_es: dict, centers_real: dict) -> pd.DataFrame:
115
+ distances = {}
116
+ for es_label, (x_es, y_es) in centers_es.items():
117
+ distances[es_label] = {}
118
+ for real_label, (x_real, y_real) in centers_real.items():
119
+ distances[es_label][real_label] = np.sqrt((x_es - x_real)**2 + (y_es - y_real)**2)
120
+ return pd.DataFrame(distances).T
121
 
122
  def main():
 
123
  config_style()
124
+ embeddings = load_embeddings()
125
+ embeddings["real"]["version"] = "real"
126
+ embeddings["es-digital-seq"]["version"] = "es_digital_seq"
127
+ embedding_cols = [col for col in embeddings["real"].columns if col.startswith("dim_")]
128
 
129
+ df_combined = pd.concat([embeddings["real"], embeddings["es-digital-seq"]], ignore_index=True)
 
 
 
 
 
 
 
 
 
130
  reduced = reducer_selector(df_combined, embedding_cols)
131
 
 
132
  dfs_reduced, unique_subsets = split_versions(df_combined, reduced)
 
133
  selected_subsets = subset_selectors(unique_subsets)
134
  color_maps = get_color_maps(selected_subsets)
135
+ fig, real_renderers, synthetic_renderers = create_figure(dfs_reduced, selected_subsets, color_maps)
136
+
137
+ centers_real = calculate_cluster_centers(dfs_reduced["real"], selected_subsets["real"])
138
+ centers_es = calculate_cluster_centers(dfs_reduced["es-digital-seq"], selected_subsets["es-digital-seq"])
139
+ df_distances = compute_distances(centers_es, centers_real)
140
+
141
+ # Creamos la tabla de distancias (se muestran todas las combinaciones)
142
+ df_table = df_distances.copy()
143
+ df_table.reset_index(inplace=True)
144
+ df_table.rename(columns={'index': 'Synthetic'}, inplace=True)
145
+ source_table = ColumnDataSource(df_table)
146
+ columns = [TableColumn(field='Synthetic', title='Synthetic')]
147
+ for col in df_table.columns:
148
+ if col != 'Synthetic':
149
+ columns.append(TableColumn(field=col, title=col))
150
+ data_table = DataTable(source=source_table, columns=columns, width=400, height=300) # Selección por fila
151
+
152
+ # Creamos un widget Select para elegir el subset real (columnas de la tabla)
153
+ real_subset_names = list(df_table.columns[1:]) # todas las columnas excepto 'Synthetic'
154
+ real_select = Select(title="Select Real Subset:", value=real_subset_names[0], options=real_subset_names)
155
+
156
+ # Fuente para la línea que conecta los centros
157
+ line_source = ColumnDataSource(data={'x': [], 'y': []})
158
+ fig.line('x', 'y', source=line_source, line_width=2, line_color='black')
159
+
160
+ # Preparar centros para el callback
161
+ synthetic_centers_js = {k: [v[0], v[1]] for k, v in centers_es.items()}
162
+ real_centers_js = {k: [v[0], v[1]] for k, v in centers_real.items()}
163
+
164
+ # Callback para actualizar la línea y colores en función de la fila seleccionada y el valor del dropdown
165
+ callback = CustomJS(args=dict(source=source_table, line_source=line_source,
166
+ synthetic_centers=synthetic_centers_js,
167
+ real_centers=real_centers_js,
168
+ synthetic_renderers=synthetic_renderers,
169
+ real_renderers=real_renderers,
170
+ synthetic_colors=color_maps["es-digital-seq"],
171
+ real_colors=color_maps["real"],
172
+ real_select=real_select),
173
+ code="""
174
+ var selected = source.selected.indices;
175
+ if (selected.length > 0) {
176
+ var row = selected[0];
177
+ var data = source.data;
178
+ var synthetic_label = data['Synthetic'][row];
179
+ var real_label = real_select.value;
180
+ var syn_coords = synthetic_centers[synthetic_label];
181
+ var real_coords = real_centers[real_label];
182
+ line_source.data = { 'x': [syn_coords[0], real_coords[0]], 'y': [syn_coords[1], real_coords[1]] };
183
+ line_source.change.emit();
184
+
185
+ // Actualizar colores: resaltar únicamente los puntos implicados
186
+ for (var key in synthetic_renderers) {
187
+ if (synthetic_renderers.hasOwnProperty(key)) {
188
+ var renderer = synthetic_renderers[key];
189
+ if (key === synthetic_label) {
190
+ renderer.glyph.fill_color = synthetic_colors[key];
191
+ renderer.glyph.line_color = synthetic_colors[key];
192
+ } else {
193
+ renderer.glyph.fill_color = "lightgray";
194
+ renderer.glyph.line_color = "lightgray";
195
+ }
196
+ }
197
+ }
198
+ for (var key in real_renderers) {
199
+ if (real_renderers.hasOwnProperty(key)) {
200
+ var renderer = real_renderers[key];
201
+ if (key === real_label) {
202
+ renderer.glyph.fill_color = real_colors[key];
203
+ renderer.glyph.line_color = real_colors[key];
204
+ } else {
205
+ renderer.glyph.fill_color = "lightgray";
206
+ renderer.glyph.line_color = "lightgray";
207
+ }
208
+ }
209
+ }
210
+ } else {
211
+ // Sin selección: reiniciar línea y colores
212
+ line_source.data = { 'x': [], 'y': [] };
213
+ line_source.change.emit();
214
+ for (var key in synthetic_renderers) {
215
+ if (synthetic_renderers.hasOwnProperty(key)) {
216
+ var renderer = synthetic_renderers[key];
217
+ renderer.glyph.fill_color = synthetic_colors[key];
218
+ renderer.glyph.line_color = synthetic_colors[key];
219
+ }
220
+ }
221
+ for (var key in real_renderers) {
222
+ if (real_renderers.hasOwnProperty(key)) {
223
+ var renderer = real_renderers[key];
224
+ renderer.glyph.fill_color = real_colors[key];
225
+ renderer.glyph.line_color = real_colors[key];
226
+ }
227
+ }
228
+ }
229
+ """)
230
+
231
+ # Asociamos el callback a los cambios en la selección de filas y en el dropdown
232
+ source_table.selected.js_on_change('indices', callback)
233
+ real_select.js_on_change('value', callback)
234
 
235
+ # Organizar layout: colocamos el gráfico, la tabla y el dropdown
236
+ layout = row(fig, column(real_select, data_table))
237
+ st.bokeh_chart(layout)
238
 
239
  if __name__ == "__main__":
240
  main()