de-Rodrigo committed on
Commit
89ffe36
1 Parent(s): d966a8e

Cleaner Layout and Tabs for Different Models

Browse files
Files changed (1) hide show
  1. app.py +63 -41
app.py CHANGED
@@ -28,13 +28,21 @@ def config_style():
28
  </style>
29
  """, unsafe_allow_html=True)
30
  st.markdown('<h1 class="main-title">Merit Embeddings 🎒📃🏆</h1>', unsafe_allow_html=True)
31
- st.markdown('<h2 class="sub-title">Donut 馃</h2>', unsafe_allow_html=True)
32
 
33
- def load_embeddings():
34
- df_real = pd.read_csv("data/donut_de_Rodrigo_merit_secret_all_embeddings.csv")
35
- df_es_digital_seq = pd.read_csv("data/donut_de_Rodrigo_merit_es-digital-seq_embeddings.csv")
 
 
 
 
 
 
 
 
36
  return {"real": df_real, "es-digital-seq": df_es_digital_seq}
37
 
 
38
  def reducer_selector(df_combined, embedding_cols):
39
  reduction_method = st.selectbox("Select Dimensionality Reduction Method:", options=["PCA", "t-SNE"])
40
  all_embeddings = df_combined[embedding_cols].values
@@ -88,11 +96,6 @@ def split_versions(df_combined, reduced):
88
  unique_es = sorted(df_es['label'].unique().tolist())
89
  return {"real": df_real, "es-digital-seq": df_es}, {"real": unique_real, "es-digital-seq": unique_es}
90
 
91
- def subset_selectors(unique_subsets: dict):
92
- selected_real = st.multiselect("Select Real Subsets:", options=unique_subsets["real"], default=unique_subsets["real"])
93
- selected_es = st.multiselect("Select Synthetic Subsets:", options=unique_subsets["es-digital-seq"], default=unique_subsets["es-digital-seq"])
94
- return {"real": selected_real, "es-digital-seq": selected_es}
95
-
96
  def create_figure(dfs_reduced, selected_subsets: dict, color_maps: dict):
97
  fig = figure(width=400, height=400, tooltips=TOOLTIPS, title="")
98
  real_renderers = add_dataset_to_fig(fig, dfs_reduced["real"], selected_subsets["real"],
@@ -119,52 +122,61 @@ def compute_distances(centers_es: dict, centers_real: dict) -> pd.DataFrame:
119
  distances[es_label][real_label] = np.sqrt((x_es - x_real)**2 + (y_es - y_real)**2)
120
  return pd.DataFrame(distances).T
121
 
122
- def main():
123
- config_style()
124
- embeddings = load_embeddings()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  embeddings["real"]["version"] = "real"
126
  embeddings["es-digital-seq"]["version"] = "es_digital_seq"
127
  embedding_cols = [col for col in embeddings["real"].columns if col.startswith("dim_")]
128
-
129
  df_combined = pd.concat([embeddings["real"], embeddings["es-digital-seq"]], ignore_index=True)
130
- reduced = reducer_selector(df_combined, embedding_cols)
 
 
 
 
 
 
 
131
 
132
  dfs_reduced, unique_subsets = split_versions(df_combined, reduced)
133
- selected_subsets = subset_selectors(unique_subsets)
134
  color_maps = get_color_maps(selected_subsets)
135
- fig, real_renderers, synthetic_renderers = create_figure(dfs_reduced, selected_subsets, color_maps)
136
 
 
137
  centers_real = calculate_cluster_centers(dfs_reduced["real"], selected_subsets["real"])
138
  centers_es = calculate_cluster_centers(dfs_reduced["es-digital-seq"], selected_subsets["es-digital-seq"])
139
  df_distances = compute_distances(centers_es, centers_real)
140
-
141
- # Tabla de distancias: se muestran todas las combinaciones
142
- df_table = df_distances.copy()
143
- df_table.reset_index(inplace=True)
144
- df_table.rename(columns={'index': 'Synthetic'}, inplace=True)
145
- source_table = ColumnDataSource(df_table)
146
- columns = [TableColumn(field='Synthetic', title='Synthetic')]
147
- for col in df_table.columns:
148
- if col != 'Synthetic':
149
- columns.append(TableColumn(field=col, title=col))
150
- data_table = DataTable(source=source_table, columns=columns, width=400, height=300)
151
-
152
- # Widget Select para elegir el subset real (columnas de la tabla)
153
- real_subset_names = list(df_table.columns[1:]) # todas las columnas excepto 'Synthetic'
154
- real_select = Select(title="Select Real Subset:", value=real_subset_names[0], options=real_subset_names)
155
-
156
- # Botón para resetear la visualización a colores originales
157
  reset_button = Button(label="Reset Colors", button_type="primary")
158
-
159
- # Fuente para la línea que conecta los centros
160
  line_source = ColumnDataSource(data={'x': [], 'y': []})
161
  fig.line('x', 'y', source=line_source, line_width=2, line_color='black')
162
 
163
- # Preparar centros para el callback
164
  synthetic_centers_js = {k: [v[0], v[1]] for k, v in centers_es.items()}
165
  real_centers_js = {k: [v[0], v[1]] for k, v in centers_real.items()}
166
 
167
- # Callback para actualizar la visualización según la selección de la tabla y el dropdown
168
  callback = CustomJS(args=dict(source=source_table, line_source=line_source,
169
  synthetic_centers=synthetic_centers_js,
170
  real_centers=real_centers_js,
@@ -228,11 +240,9 @@ def main():
228
  }
229
  }
230
  """)
231
-
232
  source_table.selected.js_on_change('indices', callback)
233
  real_select.js_on_change('value', callback)
234
 
235
- # Callback para el botón de resetear: se reinician la línea y los colores a su estado original.
236
  reset_callback = CustomJS(args=dict(line_source=line_source,
237
  synthetic_renderers=synthetic_renderers,
238
  real_renderers=real_renderers,
@@ -258,9 +268,21 @@ def main():
258
  """)
259
  reset_button.js_on_event("button_click", reset_callback)
260
 
261
- # Organizar el layout: gráfico, dropdown, botón de reset y tabla
262
  layout = column(fig, column(real_select, reset_button, data_table))
263
- st.bokeh_chart(layout)
 
 
 
 
 
 
 
 
 
 
 
 
 
264
 
265
  if __name__ == "__main__":
266
  main()
 
28
  </style>
29
  """, unsafe_allow_html=True)
30
  st.markdown('<h1 class="main-title">Merit Embeddings 🎒📃🏆</h1>', unsafe_allow_html=True)
 
31
 
32
+ # Modificamos load_embeddings para aceptar el modelo a cargar
33
+ def load_embeddings(model):
34
+ if model == "Donut":
35
+ df_real = pd.read_csv("data/donut_de_Rodrigo_merit_secret_all_embeddings.csv")
36
+ df_es_digital_seq = pd.read_csv("data/donut_de_Rodrigo_merit_es-digital-seq_embeddings.csv")
37
+ elif model == "Idefics2":
38
+ df_real = pd.read_csv("data/idefics2_de_Rodrigo_merit_secret_britanico_embeddings.csv")
39
+ df_es_digital_seq = pd.read_csv("data/idefics2_de_Rodrigo_merit_secret_britanico_embeddings.csv")
40
+ else:
41
+ st.error("Modelo no reconocido")
42
+ return None
43
  return {"real": df_real, "es-digital-seq": df_es_digital_seq}
44
 
45
+ # Funciones auxiliares (idénticas a las de tu código)
46
  def reducer_selector(df_combined, embedding_cols):
47
  reduction_method = st.selectbox("Select Dimensionality Reduction Method:", options=["PCA", "t-SNE"])
48
  all_embeddings = df_combined[embedding_cols].values
 
96
  unique_es = sorted(df_es['label'].unique().tolist())
97
  return {"real": df_real, "es-digital-seq": df_es}, {"real": unique_real, "es-digital-seq": unique_es}
98
 
 
 
 
 
 
99
  def create_figure(dfs_reduced, selected_subsets: dict, color_maps: dict):
100
  fig = figure(width=400, height=400, tooltips=TOOLTIPS, title="")
101
  real_renderers = add_dataset_to_fig(fig, dfs_reduced["real"], selected_subsets["real"],
 
122
  distances[es_label][real_label] = np.sqrt((x_es - x_real)**2 + (y_es - y_real)**2)
123
  return pd.DataFrame(distances).T
124
 
125
+ def create_table(df_distances):
126
+ df_table = df_distances.copy()
127
+ df_table.reset_index(inplace=True)
128
+ df_table.rename(columns={'index': 'Synthetic'}, inplace=True)
129
+ source_table = ColumnDataSource(df_table)
130
+ columns = [TableColumn(field='Synthetic', title='Synthetic')]
131
+ for col in df_table.columns:
132
+ if col != 'Synthetic':
133
+ columns.append(TableColumn(field=col, title=col))
134
+ row_height = 28
135
+ header_height = 30
136
+ total_height = header_height + len(df_table) * row_height
137
+
138
+ data_table = DataTable(source=source_table, columns=columns, sizing_mode='stretch_width', height=total_height)
139
+ return data_table, df_table, source_table
140
+
141
+ # Función que ejecuta todo el proceso para un modelo determinado
142
+ def run_model(model_name):
143
+ embeddings = load_embeddings(model_name)
144
+ if embeddings is None:
145
+ return
146
+
147
+ # Asignamos la versión para distinguir en el split
148
  embeddings["real"]["version"] = "real"
149
  embeddings["es-digital-seq"]["version"] = "es_digital_seq"
150
  embedding_cols = [col for col in embeddings["real"].columns if col.startswith("dim_")]
 
151
  df_combined = pd.concat([embeddings["real"], embeddings["es-digital-seq"]], ignore_index=True)
152
+
153
+ st.markdown('<h6 class="sub-title">Select Dimensionality Reduction Method</h6>', unsafe_allow_html=True)
154
+ reduction_method = st.selectbox("", options=["t-SNE", "PCA"], key=model_name)
155
+ if reduction_method == "PCA":
156
+ reducer = PCA(n_components=2)
157
+ else:
158
+ reducer = TSNE(n_components=2, random_state=42, perplexity=30, learning_rate=200)
159
+ reduced = reducer.fit_transform(df_combined[embedding_cols].values)
160
 
161
  dfs_reduced, unique_subsets = split_versions(df_combined, reduced)
162
+ selected_subsets = {"real": unique_subsets["real"], "es-digital-seq": unique_subsets["es-digital-seq"]}
163
  color_maps = get_color_maps(selected_subsets)
 
164
 
165
+ fig, real_renderers, synthetic_renderers = create_figure(dfs_reduced, selected_subsets, color_maps)
166
  centers_real = calculate_cluster_centers(dfs_reduced["real"], selected_subsets["real"])
167
  centers_es = calculate_cluster_centers(dfs_reduced["es-digital-seq"], selected_subsets["es-digital-seq"])
168
  df_distances = compute_distances(centers_es, centers_real)
169
+ data_table, df_table, source_table = create_table(df_distances)
170
+ real_subset_names = list(df_table.columns[1:])
171
+ real_select = Select(title="", value=real_subset_names[0], options=real_subset_names)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  reset_button = Button(label="Reset Colors", button_type="primary")
 
 
173
  line_source = ColumnDataSource(data={'x': [], 'y': []})
174
  fig.line('x', 'y', source=line_source, line_width=2, line_color='black')
175
 
 
176
  synthetic_centers_js = {k: [v[0], v[1]] for k, v in centers_es.items()}
177
  real_centers_js = {k: [v[0], v[1]] for k, v in centers_real.items()}
178
 
179
+ # Callback para actualizar el gráfico
180
  callback = CustomJS(args=dict(source=source_table, line_source=line_source,
181
  synthetic_centers=synthetic_centers_js,
182
  real_centers=real_centers_js,
 
240
  }
241
  }
242
  """)
 
243
  source_table.selected.js_on_change('indices', callback)
244
  real_select.js_on_change('value', callback)
245
 
 
246
  reset_callback = CustomJS(args=dict(line_source=line_source,
247
  synthetic_renderers=synthetic_renderers,
248
  real_renderers=real_renderers,
 
268
  """)
269
  reset_button.js_on_event("button_click", reset_callback)
270
 
 
271
  layout = column(fig, column(real_select, reset_button, data_table))
272
+ st.bokeh_chart(layout, use_container_width=True)
273
+
274
+ # Función principal con tabs para cambiar de modelo
275
+ def main():
276
+ config_style()
277
+ tabs = st.tabs(["Donut", "Idefics2"])
278
+
279
+ with tabs[0]:
280
+ st.markdown('<h2 class="sub-title">Modelo Donut 馃</h2>', unsafe_allow_html=True)
281
+ run_model("Donut")
282
+
283
+ with tabs[1]:
284
+ st.markdown('<h2 class="sub-title">Modelo Idefics2 馃</h2>', unsafe_allow_html=True)
285
+ run_model("Idefics2")
286
 
287
  if __name__ == "__main__":
288
  main()