Commit a448d0f · Parent: fe51656
Input Floats to TSNE and Refactor
app.py
CHANGED
@@ -36,7 +36,10 @@ def config_style():
     """, unsafe_allow_html=True)
     st.markdown('<h1 class="main-title">Merit Embeddings 🎓📄🏆</h1>', unsafe_allow_html=True)
 
-#
+# =============================================================================
+# Data loading, plot generation and distance computation functions (unchanged)
+# =============================================================================
+
 def load_embeddings(model):
     if model == "Donut":
         df_real = pd.read_csv("data/donut_de_Rodrigo_merit_secret_all_embeddings.csv")
@@ -54,7 +57,6 @@ def load_embeddings(model):
         df_zoom["version"] = "synthetic"
         df_render["version"] = "synthetic"
 
-        # Assign the source
         df_par["source"] = "es-digital-paragraph-degradation-seq"
         df_line["source"] = "es-digital-line-degradation-seq"
         df_seq["source"] = "es-digital-seq"
@@ -65,29 +67,127 @@ def load_embeddings(model):
 
     elif model == "Idefics2":
         df_real = pd.read_csv("data/idefics2_de_Rodrigo_merit_secret_britanico_embeddings.csv")
+        df_par = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-digital-paragraph-degradation-seq_embeddings.csv")
+        df_line = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-digital-line-degradation-seq_embeddings.csv")
         df_seq = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-digital-seq_embeddings.csv")
+        df_rot = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-digital-rotation-degradation-seq_embeddings.csv")
+        df_zoom = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-digital-zoom-degradation-seq_embeddings.csv")
+        df_render = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-render-seq_embeddings.csv")
         df_real["version"] = "real"
+        df_par["version"] = "synthetic"
+        df_line["version"] = "synthetic"
         df_seq["version"] = "synthetic"
+        df_rot["version"] = "synthetic"
+        df_zoom["version"] = "synthetic"
+        df_render["version"] = "synthetic"
+
+        df_par["source"] = "es-digital-paragraph-degradation-seq"
+        df_line["source"] = "es-digital-line-degradation-seq"
         df_seq["source"] = "es-digital-seq"
-
+        df_rot["source"] = "es-digital-rotation-degradation-seq"
+        df_zoom["source"] = "es-digital-zoom-degradation-seq"
+        df_render["source"] = "es-render-seq"
+        return {"real": df_real, "synthetic": pd.concat([df_seq, df_line, df_par, df_rot, df_zoom, df_render], ignore_index=True)}
 
     else:
         st.error("Modelo no reconocido")
         return None
 
-
-
-
-
-
-
-
-
-
-
-
+def split_versions(df_combined, reduced):
+    df_combined['x'] = reduced[:, 0]
+    df_combined['y'] = reduced[:, 1]
+    df_real = df_combined[df_combined["version"] == "real"].copy()
+    df_synth = df_combined[df_combined["version"] == "synthetic"].copy()
+    unique_real = sorted(df_real['label'].unique().tolist())
+    unique_synth = {}
+    for source in df_synth["source"].unique():
+        unique_synth[source] = sorted(df_synth[df_synth["source"] == source]['label'].unique().tolist())
+    df_dict = {"real": df_real, "synthetic": df_synth}
+    unique_subsets = {"real": unique_real, "synthetic": unique_synth}
+    return df_dict, unique_subsets
+
+def compute_wasserstein_distances_synthetic_individual(synthetic_df: pd.DataFrame, df_real: pd.DataFrame, real_labels: list) -> pd.DataFrame:
+    distances = {}
+    groups = synthetic_df.groupby(['source', 'label'])
+    for (source, label), group in groups:
+        key = f"{label} ({source})"
+        data = group[['x', 'y']].values
+        n = data.shape[0]
+        weights = np.ones(n) / n
+        distances[key] = {}
+        for real_label in real_labels:
+            real_data = df_real[df_real['label'] == real_label][['x','y']].values
+            m = real_data.shape[0]
+            weights_real = np.ones(m) / m
+            M = ot.dist(data, real_data, metric='euclidean')
+            distances[key][real_label] = ot.emd2(weights, weights_real, M)
+
+    for source, group in synthetic_df.groupby('source'):
+        key = f"Global ({source})"
+        data = group[['x','y']].values
+        n = data.shape[0]
+        weights = np.ones(n) / n
+        distances[key] = {}
+        for real_label in real_labels:
+            real_data = df_real[df_real['label'] == real_label][['x','y']].values
+            m = real_data.shape[0]
+            weights_real = np.ones(m) / m
+            M = ot.dist(data, real_data, metric='euclidean')
+            distances[key][real_label] = ot.emd2(weights, weights_real, M)
+    return pd.DataFrame(distances).T
+
+def create_table(df_distances):
+    df_table = df_distances.copy()
+    df_table.reset_index(inplace=True)
+    df_table.rename(columns={'index': 'Synthetic'}, inplace=True)
+    min_row = {"Synthetic": "Min."}
+    mean_row = {"Synthetic": "Mean"}
+    max_row = {"Synthetic": "Max."}
+    for col in df_table.columns:
+        if col != "Synthetic":
+            min_row[col] = df_table[col].min()
+            mean_row[col] = df_table[col].mean()
+            max_row[col] = df_table[col].max()
+    df_table = pd.concat([df_table, pd.DataFrame([min_row, mean_row, max_row])], ignore_index=True)
+    source_table = ColumnDataSource(df_table)
+    columns = [TableColumn(field='Synthetic', title='Synthetic')]
+    for col in df_table.columns:
+        if col != 'Synthetic':
+            columns.append(TableColumn(field=col, title=col))
+    total_height = 30 + len(df_table)*28
+    data_table = DataTable(source=source_table, columns=columns, sizing_mode='stretch_width', height=total_height)
+    return data_table, df_table, source_table
+
+def create_figure(dfs, unique_subsets, color_maps, model_name):
+    fig = figure(width=600, height=600, tools="wheel_zoom,pan,reset,save", active_scroll="wheel_zoom", tooltips=TOOLTIPS, title="")
+    real_renderers = add_dataset_to_fig(fig, dfs["real"], unique_subsets["real"],
+                                        marker="circle", color_mapping=color_maps["real"],
+                                        group_label="Real")
+    marker_mapping = {
+        "es-digital-paragraph-degradation-seq": "x",
+        "es-digital-line-degradation-seq": "cross",
+        "es-digital-seq": "triangle",
+        "es-digital-rotation-degradation-seq": "diamond",
+        "es-digital-zoom-degradation-seq": "asterisk",
+        "es-render-seq": "inverted_triangle"
+    }
+    synthetic_renderers = {}
+    synth_df = dfs["synthetic"]
+    for source in unique_subsets["synthetic"]:
+        df_source = synth_df[synth_df["source"] == source]
+        marker = marker_mapping.get(source, "square")
+        renderers = add_synthetic_dataset_to_fig(fig, df_source, unique_subsets["synthetic"][source],
+                                                 marker=marker,
+                                                 color_mapping=color_maps["synthetic"][source],
+                                                 group_label=source)
+        synthetic_renderers.update(renderers)
+
+    fig.legend.location = "top_right"
+    fig.legend.click_policy = "hide"
+    show_legend = st.checkbox("Show Legend", value=False, key=f"legend_{model_name}")
+    fig.legend.visible = show_legend
+    return fig, real_renderers, synthetic_renderers
 
-# Function to add real data (one renderer per label)
 def add_dataset_to_fig(fig, df, selected_labels, marker, color_mapping, group_label):
     renderers = {}
     for label in selected_labels:
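The CSVs loaded above are assumed to carry a "label" column, an optional "img" column, and embedding columns named dim_0 … dim_N; that schema is implied further down the diff, where run_model selects columns via col.startswith("dim_"). A minimal, self-contained stand-in for the frames returned by load_embeddings (all values made up):

import pandas as pd

# Hypothetical miniature of the "real" frame
df_real = pd.DataFrame({
    "label": ["school_a", "school_a", "school_b"],
    "img":   ["", "", ""],
    "dim_0": [0.10, 0.20, 0.90],
    "dim_1": [0.30, 0.10, 0.80],
})
df_real["version"] = "real"

# Hypothetical miniature of one synthetic subset
df_seq = pd.DataFrame({
    "label": ["school_a", "school_b"],
    "img":   ["", ""],
    "dim_0": [0.15, 0.85],
    "dim_1": [0.25, 0.75],
})
df_seq["version"] = "synthetic"
df_seq["source"] = "es-digital-seq"

embeddings = {"real": df_real, "synthetic": df_seq}
embedding_cols = [c for c in embeddings["real"].columns if c.startswith("dim_")]
print(embedding_cols)  # ['dim_0', 'dim_1']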
@@ -117,7 +217,6 @@ def add_dataset_to_fig(fig, df, selected_labels, marker, color_mapping, group_label):
         renderers[label + f" ({group_label})"] = r
     return renderers
 
-# New function: plots synthetics per label but groups the legend by source
 def add_synthetic_dataset_to_fig(fig, df, labels, marker, color_mapping, group_label):
     renderers = {}
     for label in labels:
@@ -130,11 +229,8 @@ def add_synthetic_dataset_to_fig(fig, df, labels, marker, color_mapping, group_label):
             label=subset['label'],
             img=subset.get('img', "")
         ))
-        # Use the granular color assigned to each label
         color = color_mapping[label]
-        # Assign the legend entry to the source name so entries are grouped
         legend_label = group_label
-
         if marker == "square":
             r = fig.square('x', 'y', size=10, source=source_obj,
                            fill_color=color, line_color=color,
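The legend_label = group_label assignment above leans on a Bokeh behavior worth spelling out: renderers added with the same legend_label share a single legend entry, and with click_policy = "hide" clicking that entry toggles all of them together. A minimal standalone sketch (plain Bokeh, outside the Streamlit app):

from bokeh.plotting import figure, show

p = figure(width=300, height=300)
# Two granular renderers, one shared legend entry named after the source
p.square([1, 2], [1, 2], size=10, color="navy", legend_label="es-digital-seq")
p.square([3, 4], [1, 2], size=10, color="skyblue", legend_label="es-digital-seq")
p.legend.click_policy = "hide"  # clicking the entry hides both renderers at once
show(p)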
@@ -171,6 +267,7 @@ def add_synthetic_dataset_to_fig(fig, df, labels, marker, color_mapping, group_label):
     return renderers
 
 
+
 def get_color_maps(unique_subsets):
     color_map = {}
     # For real data, a color is assigned to each label
@@ -197,59 +294,8 @@ def get_color_maps(unique_subsets):
         palette = Blues9[:len(labels)] if len(labels) <= 9 else (Blues9 * ((len(labels)//9)+1))[:len(labels)]
         color_map["synthetic"][source] = {label: palette[i] for i, label in enumerate(sorted(labels))}
     return color_map
-
-def split_versions(df_combined, reduced):
-    df_combined['x'] = reduced[:, 0]
-    df_combined['y'] = reduced[:, 1]
-    df_real = df_combined[df_combined["version"] == "real"].copy()
-    df_synth = df_combined[df_combined["version"] == "synthetic"].copy()
-    # Extract the unique labels for real data
-    unique_real = sorted(df_real['label'].unique().tolist())
-    # For synthetics, group the labels by source
-    unique_synth = {}
-    for source in df_synth["source"].unique():
-        unique_synth[source] = sorted(df_synth[df_synth["source"] == source]['label'].unique().tolist())
-    df_dict = {"real": df_real, "synthetic": df_synth}
-    # For real data the list is kept; for synthetics, the dictionary
-    unique_subsets = {"real": unique_real, "synthetic": unique_synth}
-    return df_dict, unique_subsets
-
-def create_figure(dfs, unique_subsets, color_maps, model_name):
-    fig = figure(width=600, height=600, tools="wheel_zoom,pan,reset,save", active_scroll="wheel_zoom", tooltips=TOOLTIPS, title="")
-    # Real data: kept granular in the plot and in the legend
-    real_renderers = add_dataset_to_fig(fig, dfs["real"], unique_subsets["real"],
-                                        marker="circle", color_mapping=color_maps["real"],
-                                        group_label="Real")
-    # Marker assignment for synthetics, per source
-    marker_mapping = {
-        "es-digital-paragraph-degradation-seq": "x",
-        "es-digital-line-degradation-seq": "cross",
-        "es-digital-seq": "triangle",
-        "es-digital-rotation-degradation-seq": "diamond",
-        "es-digital-zoom-degradation-seq": "asterisk",
-        "es-render-seq": "inverted_triangle"
-    }
-
-    # Synthetic data: plotted per label but with the legend grouped by source
-    synthetic_renderers = {}
-    synth_df = dfs["synthetic"]
-    for source in unique_subsets["synthetic"]:
-        df_source = synth_df[synth_df["source"] == source]
-        marker = marker_mapping.get(source, "square")  # Defaults to "square" if not found
-        renderers = add_synthetic_dataset_to_fig(fig, df_source, unique_subsets["synthetic"][source],
-                                                 marker=marker,
-                                                 color_mapping=color_maps["synthetic"][source],
-                                                 group_label=source)
-        synthetic_renderers.update(renderers)
 
-    fig.legend.click_policy = "hide"
-    show_legend = st.checkbox("Show Legend", value=False, key=f"legend_{model_name}")
-    fig.legend.visible = show_legend
-    return fig, real_renderers, synthetic_renderers
-
-
-# Compute the centers of each cluster (per group)
+
 def calculate_cluster_centers(df, labels):
     centers = {}
     for label in labels:
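The palette expression kept as context above tiles Blues9 when a source has more than nine labels. A quick check of that logic in isolation:

from bokeh.palettes import Blues9

for n in (3, 12):
    # Same expression as in get_color_maps: slice when n <= 9, tile then truncate otherwise
    palette = Blues9[:n] if n <= 9 else (Blues9 * ((n // 9) + 1))[:n]
    print(n, len(palette))  # 3 -> 3, 12 -> 12 (colors repeat past the ninth label)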
@@ -258,189 +304,60 @@ def calculate_cluster_centers(df, labels):
         centers[label] = (subset['x'].mean(), subset['y'].mean())
     return centers
 
-# Compute the Wasserstein distance of each synthetic subset against each real cluster (per cluster and global)
-def compute_wasserstein_distances_synthetic_individual(synthetic_df: pd.DataFrame, df_real: pd.DataFrame, real_labels: list) -> pd.DataFrame:
-    distances = {}
-    groups = synthetic_df.groupby(['source', 'label'])
-    for (source, label), group in groups:
-        key = f"{label} ({source})"
-        data = group[['x', 'y']].values
-        n = data.shape[0]
-        weights = np.ones(n) / n
-        distances[key] = {}
-        for real_label in real_labels:
-            real_data = df_real[df_real['label'] == real_label][['x','y']].values
-            m = real_data.shape[0]
-            weights_real = np.ones(m) / m
-            M = ot.dist(data, real_data, metric='euclidean')
-            distances[key][real_label] = ot.emd2(weights, weights_real, M)
-
-    # Global distance per source
-    for source, group in synthetic_df.groupby('source'):
-        key = f"Global ({source})"
-        data = group[['x','y']].values
-        n = data.shape[0]
-        weights = np.ones(n) / n
-        distances[key] = {}
-        for real_label in real_labels:
-            real_data = df_real[df_real['label'] == real_label][['x','y']].values
-            m = real_data.shape[0]
-            weights_real = np.ones(m) / m
-            M = ot.dist(data, real_data, metric='euclidean')
-            distances[key][real_label] = ot.emd2(weights, weights_real, M)
-    return pd.DataFrame(distances).T
-
-def create_table(df_distances):
-    df_table = df_distances.copy()
-    df_table.reset_index(inplace=True)
-    df_table.rename(columns={'index': 'Synthetic'}, inplace=True)
-    min_row = {"Synthetic": "Min."}
-    mean_row = {"Synthetic": "Mean"}
-    max_row = {"Synthetic": "Max."}
-    for col in df_table.columns:
-        if col != "Synthetic":
-            min_row[col] = df_table[col].min()
-            mean_row[col] = df_table[col].mean()
-            max_row[col] = df_table[col].max()
-    df_table = pd.concat([df_table, pd.DataFrame([min_row, mean_row, max_row])], ignore_index=True)
-    source_table = ColumnDataSource(df_table)
-    columns = [TableColumn(field='Synthetic', title='Synthetic')]
-    for col in df_table.columns:
-        if col != 'Synthetic':
-            columns.append(TableColumn(field=col, title=col))
-    total_height = 30 + len(df_table)*28
-    data_table = DataTable(source=source_table, columns=columns, sizing_mode='stretch_width', height=total_height)
-    return data_table, df_table, source_table
-
-def optimize_tsne_params(df_combined, embedding_cols, df_f1):
-    # Search ranges (these limits and steps can be adjusted)
-    perplexity_range = np.linspace(30, 50, 10)
-    learning_rate_range = np.linspace(200, 1000, 20)
-
-    best_R2 = -np.inf
-    best_params = None
-    total_steps = len(perplexity_range) * len(learning_rate_range)
-    step = 0
-
-    # Use a Streamlit placeholder to update progress messages
-    progress_text = st.empty()
-
-    for p in perplexity_range:
-        for lr in learning_rate_range:
-            step += 1
-            # Update the progress message
-            progress_text.text(f"Evaluating: Perplexity={p:.2f}, Learning Rate={lr:.2f} (Step: {step}/{total_steps})")
-
-            # Compute the t-SNE reduction
-            reducer_temp = TSNE(n_components=2, random_state=42, perplexity=p, learning_rate=lr)
-            reduced_temp = reducer_temp.fit_transform(df_combined[embedding_cols].values)
-            dfs_reduced_temp, unique_subsets_temp = split_versions(df_combined, reduced_temp)
-
-            # Compute Wasserstein distances
-            df_distances_temp = compute_wasserstein_distances_synthetic_individual(
-                dfs_reduced_temp["synthetic"],
-                dfs_reduced_temp["real"],
-                unique_subsets_temp["real"]
-            )
-            # Extract the global values (we assume 10 per source)
-            global_distances_temp = {}
-            for idx in df_distances_temp.index:
-                if idx.startswith("Global"):
-                    source = idx.split("(")[1].rstrip(")")
-                    global_distances_temp[source] = df_distances_temp.loc[idx].values
-
-            # Accumulate data for the global regression
-            all_x_temp = []
-            all_y_temp = []
-            for source in df_f1.columns:
-                if source in global_distances_temp:
-                    x_vals_temp = global_distances_temp[source]
-                    y_vals_temp = df_f1[source].values
-                    all_x_temp.extend(x_vals_temp)
-                    all_y_temp.extend(y_vals_temp)
-            if len(all_x_temp) == 0:
-                continue
-            all_x_temp_arr = np.array(all_x_temp).reshape(-1, 1)
-            all_y_temp_arr = np.array(all_y_temp)
-
-            model_temp = LinearRegression().fit(all_x_temp_arr, all_y_temp_arr)
-            r2_temp = model_temp.score(all_x_temp_arr, all_y_temp_arr)
-
-            # Report the evaluated tuple and the R² obtained
-            st.write(f"Parameters: Perplexity={p:.2f}, Learning Rate={lr:.2f} -> R²={r2_temp:.4f}")
-
-            if r2_temp > best_R2:
-                best_R2 = r2_temp
-                best_params = (p, lr)
-
-    progress_text.text("Optimization completed!")
-    return best_params, best_R2
 
 
 
-def run_model(model_name):
-    embeddings = load_embeddings(model_name)
-    if embeddings is None:
-        return
-
-    embedding_cols = [col for col in embeddings["real"].columns if col.startswith("dim_")]
-    df_combined = pd.concat(list(embeddings.values()), ignore_index=True)
-
-    # Read the f1-donut CSV (used to evaluate the regression)
-    try:
-        df_f1 = pd.read_csv("data/f1-donut.csv", sep=';', index_col=0)
-    except Exception as e:
-        st.error(f"Error loading f1-donut.csv: {e}")
-        return
-
-    st.markdown('<h6 class="sub-title">Select Dimensionality Reduction Method</h6>', unsafe_allow_html=True)
-    reduction_method = st.selectbox("", options=["t-SNE", "PCA"], key=f"reduction_{model_name}")
-
-    # Option to optimize the t-SNE parameters
-    if reduction_method == "t-SNE":
-        if st.button("Optimize TSNE parameters", key=f"optimize_tnse_{model_name}"):
-            st.info("Running optimization, this can take a while...")
-            best_params, best_R2 = optimize_tsne_params(df_combined, embedding_cols, df_f1)
-            st.success(f"Mejores parámetros: Perplexity = {best_params[0]:.2f}, Learning Rate = {best_params[1]:.2f} con R² = {best_R2:.4f}")
-
-    # Let the user enter the values manually (or replace these with the optimized ones)
+# =============================================================================
+# Centralized pipeline function: reduction, distances and global regression
+# =============================================================================
+
+def compute_global_regression(df_combined, embedding_cols, tsne_params, df_f1, reduction_method="t-SNE"):
+    # Select the reducer according to the method
     if reduction_method == "PCA":
         reducer = PCA(n_components=2)
     else:
-
-
-
+        reducer = TSNE(n_components=2, random_state=42,
+                       perplexity=tsne_params["perplexity"],
+                       learning_rate=tsne_params["learning_rate"])
 
+    # Apply the dimensionality reduction
     reduced = reducer.fit_transform(df_combined[embedding_cols].values)
     dfs_reduced, unique_subsets = split_versions(df_combined, reduced)
 
-
-    fig, real_renderers, synthetic_renderers = create_figure(dfs_reduced, unique_subsets, color_maps, model_name)
-
-    centers_real = calculate_cluster_centers(dfs_reduced["real"], unique_subsets["real"])
-
+    # Compute Wasserstein distances
     df_distances = compute_wasserstein_distances_synthetic_individual(
         dfs_reduced["synthetic"],
         dfs_reduced["real"],
         unique_subsets["real"]
     )
 
-    #
-    try:
-        df_f1 = pd.read_csv("data/f1-donut.csv", sep=';', index_col=0)
-    except Exception as e:
-        st.error(f"Error loading f1-donut.csv: {e}")
-        return
-
-    # Extract the global values for each source (no averaging: 10 values per source)
+    # Extract the global values for each source (10 expected per source)
     global_distances = {}
     for idx in df_distances.index:
         if idx.startswith("Global"):
-            # Example: "Global (es-digital-seq)"
             source = idx.split("(")[1].rstrip(")")
             global_distances[source] = df_distances.loc[idx].values
+
+    # Accumulate all the (global) points and the corresponding f1 of each school
+    all_x = []
+    all_y = []
+    for source in df_f1.columns:
+        if source in global_distances:
+            x_vals = global_distances[source]
+            y_vals = df_f1[source].values
+            all_x.extend(x_vals)
+            all_y.extend(y_vals)
+    all_x_arr = np.array(all_x).reshape(-1, 1)
+    all_y_arr = np.array(all_y)
 
-    #
+    # Fit the global linear regression
+    model_global = LinearRegression().fit(all_x_arr, all_y_arr)
+    r2 = model_global.score(all_x_arr, all_y_arr)
+    slope = model_global.coef_[0]
+    intercept = model_global.intercept_
+
+    # Build the scatter plot to visualize the relationship
+    scatter_fig = figure(width=600, height=600, tools="pan,wheel_zoom,reset,save",
+                         title="Scatter Plot: Wasserstein vs F1")
     source_colors = {
         "es-digital-paragraph-degradation-seq": "blue",
         "es-digital-line-degradation-seq": "green",
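Both the deleted and the relocated versions of compute_wasserstein_distances_synthetic_individual rest on the same two POT calls: ot.dist builds the pairwise Euclidean cost matrix and ot.emd2 returns the exact transport cost under uniform weights. A toy run on two random 2-D clouds (illustrative only, not data from the app):

import numpy as np
import ot  # POT: Python Optimal Transport

rng = np.random.default_rng(42)
data = rng.normal(loc=0.0, scale=1.0, size=(30, 2))       # stand-in synthetic cluster
real_data = rng.normal(loc=3.0, scale=1.0, size=(50, 2))  # stand-in real cluster

weights = np.ones(len(data)) / len(data)                  # uniform weights, as in the app
weights_real = np.ones(len(real_data)) / len(real_data)

M = ot.dist(data, real_data, metric='euclidean')          # pairwise cost matrix
d = ot.emd2(weights, weights_real, M)                     # exact EMD (Wasserstein-1)
print(f"{d:.3f}")  # roughly the separation of the cluster means (~4.2 here)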
@@ -450,68 +367,146 @@ def run_model(model_name):
         "es-digital-rotation-zoom-degradation-seq": "brown",
         "es-render-seq": "cyan"
     }
-
-    scatter_fig = figure(width=600, height=600, tools="pan,wheel_zoom,reset,save", title="Scatter Plot: Wasserstein vs F1")
-    # Variables for the global regression
-    all_x = []
-    all_y = []
-
-    # Plot each source and accumulate the data for the global regression
     for source in df_f1.columns:
         if source in global_distances:
-            x_vals = global_distances[source]
-            y_vals = df_f1[source].values
-            data = {"x": x_vals, "y": y_vals, "Fuente": [source]
+            x_vals = global_distances[source]
+            y_vals = df_f1[source].values
+            data = {"x": x_vals, "y": y_vals, "Fuente": [source]*len(x_vals)}
             cds = ColumnDataSource(data=data)
             scatter_fig.circle('x', 'y', size=8, alpha=0.7, source=cds,
                                fill_color=source_colors.get(source, "gray"),
                                line_color=source_colors.get(source, "gray"),
                                legend_label=source)
-            all_x.extend(x_vals)
-            all_y.extend(y_vals)
-
     scatter_fig.xaxis.axis_label = "Wasserstein Distance (Global, por Colegio)"
     scatter_fig.yaxis.axis_label = "F1 Score"
     scatter_fig.legend.location = "top_right"
-
-    # Add a HoverTool to show x, y and the source on hover
     hover_tool = HoverTool(tooltips=[("Wass. Distance", "@x"), ("f1", "@y"), ("Subset", "@Fuente")])
     scatter_fig.add_tools(hover_tool)
-    # --- End scatter plot ---
 
-    #
-    all_x_arr = np.array(all_x).reshape(-1, 1)
-    all_y_arr = np.array(all_y)
-    model_global = LinearRegression().fit(all_x_arr, all_y_arr)
-    slope = model_global.coef_[0]
-    intercept = model_global.intercept_
-    r2 = model_global.score(all_x_arr, all_y_arr)
-
-    # Add the global regression line to the scatter plot
+    # Global regression line
     x_line = np.linspace(all_x_arr.min(), all_x_arr.max(), 100)
     y_line = model_global.predict(x_line.reshape(-1, 1))
     scatter_fig.line(x_line, y_line, line_width=2, line_color="black", legend_label="Global Regression")
 
-
-
-
-
-
-
+    return {
+        "R2": r2,
+        "slope": slope,
+        "intercept": intercept,
+        "scatter_fig": scatter_fig,
+        "dfs_reduced": dfs_reduced,
+        "unique_subsets": unique_subsets,
+        "df_distances": df_distances
+    }
+
+# =============================================================================
+# Optimization (grid search) function for t-SNE, now using the same pipeline
+# =============================================================================
+
+def optimize_tsne_params(df_combined, embedding_cols, df_f1):
+    # Search range
+    perplexity_range = np.linspace(30, 50, 10)
+    learning_rate_range = np.linspace(200, 1000, 20)
+
+    best_R2 = -np.inf
+    best_params = None
+    total_steps = len(perplexity_range) * len(learning_rate_range)
+    step = 0
+
+    progress_text = st.empty()
+
+    for p in perplexity_range:
+        for lr in learning_rate_range:
+            step += 1
+            progress_text.text(f"Evaluating: Perplexity={p:.2f}, Learning Rate={lr:.2f} (Step {step}/{total_steps})")
+
+            tsne_params = {"perplexity": p, "learning_rate": lr}
+            result = compute_global_regression(df_combined, embedding_cols, tsne_params, df_f1, reduction_method="t-SNE")
+            r2_temp = result["R2"]
+            st.write(f"Parameters: Perplexity={p:.2f}, Learning Rate={lr:.2f} -> R²={r2_temp:.4f}")
+
+            if r2_temp > best_R2:
+                best_R2 = r2_temp
+                best_params = (p, lr)
+
+    progress_text.text("Optimization completed!")
+    return best_params, best_R2
+
+# =============================================================================
+# Main run_model function, integrating optimization and manual execution
+# =============================================================================
+
+def run_model(model_name):
+    embeddings = load_embeddings(model_name)
+    if embeddings is None:
+        return
+    embedding_cols = [col for col in embeddings["real"].columns if col.startswith("dim_")]
+    df_combined = pd.concat(list(embeddings.values()), ignore_index=True)
+
+    # Load the f1-donut CSV
+    try:
+        df_f1 = pd.read_csv("data/f1-donut.csv", sep=';', index_col=0)
+    except Exception as e:
+        st.error(f"Error loading f1-donut.csv: {e}")
+        return
+
+    st.markdown('<h6 class="sub-title">Select Dimensionality Reduction Method</h6>', unsafe_allow_html=True)
+    reduction_method = st.selectbox("", options=["t-SNE", "PCA"], key=f"reduction_{model_name}")
+
+    tsne_params = {}
+    if reduction_method == "t-SNE":
+        if st.button("Optimize TSNE parameters", key=f"optimize_tsne_{model_name}"):
+            st.info("Running optimization, this can take a while...")
+            best_params, best_R2 = optimize_tsne_params(df_combined, embedding_cols, df_f1)
+            st.success(f"Mejores parámetros: Perplexity = {best_params[0]:.2f}, Learning Rate = {best_params[1]:.2f} con R² = {best_R2:.4f}")
+            tsne_params = {"perplexity": best_params[0], "learning_rate": best_params[1]}
+        else:
+            perplexity_val = st.number_input(
+                "Perplexity",
+                min_value=5.0,
+                max_value=50.0,
+                value=30.0,
+                step=1.0,
+                format="%.2f",
+                key=f"perplexity_{model_name}"
+            )
+            learning_rate_val = st.number_input(
+                "Learning Rate",
+                min_value=10.0,
+                max_value=1000.0,
+                value=200.0,
+                step=10.0,
+                format="%.2f",
+                key=f"learning_rate_{model_name}"
+            )
+            tsne_params = {"perplexity": perplexity_val, "learning_rate": learning_rate_val}
+    # If PCA is selected, tsne_params is not used.
+
+    # Use the centralized function to obtain the global regression and the scatter plot
+    result = compute_global_regression(df_combined, embedding_cols, tsne_params, df_f1, reduction_method=reduction_method)
+
+    reg_metrics = pd.DataFrame({
+        "Slope": [result["slope"]],
+        "Intercept": [result["intercept"]],
+        "R2": [result["R2"]]
+    })
+    st.table(reg_metrics)
+
+    # We do not call st.bokeh_chart(result["scatter_fig"], ...) here;
+    # instead, everything is combined into a single layout:
+    data_table, df_table, source_table = create_table(result["df_distances"])
     real_subset_names = list(df_table.columns[1:])
     real_select = Select(title="", value=real_subset_names[0], options=real_subset_names)
     reset_button = Button(label="Reset Colors", button_type="primary")
     line_source = ColumnDataSource(data={'x': [], 'y': []})
+    # Assuming you have a base figure 'fig' for the clusters:
+    fig, real_renderers, synthetic_renderers = create_figure(result["dfs_reduced"], result["unique_subsets"], get_color_maps(result["unique_subsets"]), model_name)
     fig.line('x', 'y', source=line_source, line_width=2, line_color='black')
-
+    centers_real = calculate_cluster_centers(result["dfs_reduced"]["real"], result["unique_subsets"]["real"])
     real_centers_js = {k: [v[0], v[1]] for k, v in centers_real.items()}
     synthetic_centers = {}
-    synth_labels = sorted(dfs_reduced["synthetic"]['label'].unique().tolist())
+    synth_labels = sorted(result["dfs_reduced"]["synthetic"]['label'].unique().tolist())
     for label in synth_labels:
-        subset = dfs_reduced["synthetic"][dfs_reduced["synthetic"]['label'] == label]
+        subset = result["dfs_reduced"]["synthetic"][result["dfs_reduced"]["synthetic"]['label'] == label]
         synthetic_centers[label] = [subset['x'].mean(), subset['y'].mean()]
 
     callback = CustomJS(args=dict(source=source_table, line_source=line_source,
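The number inputs above are what the commit title refers to: the perplexity and learning-rate widgets produce floats, which scikit-learn's TSNE accepts directly. A minimal sketch of the same call outside Streamlit, on random data:

import numpy as np
from sklearn.manifold import TSNE

X = np.random.rand(100, 64)  # stand-in for df_combined[embedding_cols].values
tsne_params = {"perplexity": 30.0, "learning_rate": 200.0}  # float inputs
reducer = TSNE(n_components=2, random_state=42,
               perplexity=tsne_params["perplexity"],
               learning_rate=tsne_params["learning_rate"])
reduced = reducer.fit_transform(X)
print(reduced.shape)  # (100, 2)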
@@ -548,7 +543,8 @@ def run_model(model_name):
     df_table.to_excel(buffer, index=False)
     buffer.seek(0)
 
-
+    # Combine all the plots into a single layout:
+    layout = column(fig, result["scatter_fig"], column(real_select, reset_button, data_table))
     st.bokeh_chart(layout, use_container_width=True)
 
     st.download_button(
@@ -559,7 +555,6 @@ def run_model(model_name):
         key=f"download_button_excel_{model_name}"
     )
 
-
 def main():
     config_style()
     tabs = st.tabs(["Donut", "Idefics2"])
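The diff ends before the body of main() past the tab creation; a plausible wiring, with the per-tab run_model calls being an assumption based on the model names load_embeddings recognizes:

def main():
    config_style()
    tabs = st.tabs(["Donut", "Idefics2"])
    with tabs[0]:
        run_model("Donut")      # assumed wiring
    with tabs[1]:
        run_model("Idefics2")   # assumed wiring

if __name__ == "__main__":
    main()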