Commit 757102e · 1 Parent(s): 3465900
Donut Ready
app.py CHANGED
@@ -4,7 +4,7 @@ import numpy as np
 from bokeh.plotting import figure
 from bokeh.models import ColumnDataSource, DataTable, TableColumn, CustomJS, Select, Button
 from bokeh.layouts import column
-from bokeh.palettes import Reds9, Blues9, Oranges9, Purples9
 from sklearn.decomposition import PCA
 from sklearn.manifold import TSNE
 import io
@@ -27,6 +27,10 @@ def config_style():
         .main-title { font-size: 50px; color: #4CAF50; text-align: center; }
         .sub-title { font-size: 30px; color: #555; }
         .custom-text { font-size: 18px; line-height: 1.5; }
     </style>
     """, unsafe_allow_html=True)
     st.markdown('<h1 class="main-title">Merit Embeddings 🎒📃🏆</h1>', unsafe_allow_html=True)
@@ -35,15 +39,29 @@ def config_style():
 def load_embeddings(model):
     if model == "Donut":
         df_real = pd.read_csv("data/donut_de_Rodrigo_merit_secret_all_embeddings.csv")
         df_seq = pd.read_csv("data/donut_de_Rodrigo_merit_es-digital-seq_embeddings.csv")
-
         df_real["version"] = "real"
-
         df_line["version"] = "synthetic"
-
-
         df_line["source"] = "es-digital-line-degradation-seq"
-
     elif model == "Idefics2":
         df_real = pd.read_csv("data/idefics2_de_Rodrigo_merit_secret_britanico_embeddings.csv")
         df_seq = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-digital-seq_embeddings.csv")
@@ -51,6 +69,7 @@ def load_embeddings(model):
         df_seq["version"] = "synthetic"
         df_seq["source"] = "es-digital-seq"
         return {"real": df_real, "synthetic": df_seq}
     else:
         st.error("Modelo no reconocido")
         return None
@@ -65,7 +84,7 @@ def reducer_selector(df_combined, embedding_cols):
     reducer = TSNE(n_components=2, random_state=42, perplexity=30, learning_rate=200)
     return reducer.fit_transform(all_embeddings)

-# Función
 def add_dataset_to_fig(fig, df, selected_labels, marker, color_mapping, group_label):
     renderers = {}
     for label in selected_labels:
@@ -79,7 +98,6 @@ def add_dataset_to_fig(fig, df, selected_labels, marker, color_mapping, group_label):
             img=subset.get('img', "")
         ))
         color = color_mapping[label]
-        # Add the source identifier to the legend
         legend_label = f"{label} ({group_label})"
         if marker == "circle":
             r = fig.circle('x', 'y', size=10, source=source,
@@ -96,64 +114,138 @@ def add_dataset_to_fig(fig, df, selected_labels, marker, color_mapping, group_label):
         renderers[label + f" ({group_label})"] = r
     return renderers

-#
 def get_color_maps(unique_subsets):
     color_map = {}
-#
     num_real = len(unique_subsets["real"])
     red_palette = Reds9[:num_real] if num_real <= 9 else (Reds9 * ((num_real // 9) + 1))[:num_real]
     color_map["real"] = {label: red_palette[i] for i, label in enumerate(sorted(unique_subsets["real"]))}

-#
-
-
-
-
-
     return color_map

-# Splits the reduced data into "real" and "synthetic" and extracts the subsets (clusters)
 def split_versions(df_combined, reduced):
     df_combined['x'] = reduced[:, 0]
     df_combined['y'] = reduced[:, 1]
     df_real = df_combined[df_combined["version"] == "real"].copy()
     df_synth = df_combined[df_combined["version"] == "synthetic"].copy()
-#
     unique_real = sorted(df_real['label'].unique().tolist())
-
     df_dict = {"real": df_real, "synthetic": df_synth}
     unique_subsets = {"real": unique_real, "synthetic": unique_synth}
     return df_dict, unique_subsets

-
-
-
     real_renderers = add_dataset_to_fig(fig, dfs["real"], unique_subsets["real"],
                                         marker="circle", color_mapping=color_maps["real"],
                                         group_label="Real")
-#
     synth_df = dfs["synthetic"]
-
-
-
-
-
-
-
-
-    seq_renderers = add_dataset_to_fig(fig, df_seq, unique_seq,
-                                       marker="square", color_mapping=color_maps["synthetic"],
-                                       group_label="es-digital-seq")
-    line_renderers = add_dataset_to_fig(fig, df_line, unique_line,
-                                        marker="triangle", color_mapping=color_maps["synthetic"],
-                                        group_label="es-digital-line-degradation-seq")
-    # Combine both synthetic renderers
-    synthetic_renderers = {**seq_renderers, **line_renderers}

     fig.legend.location = "top_right"
     fig.legend.click_policy = "hide"
     return fig, real_renderers, synthetic_renderers

 # Computes the centers of each cluster (per group)
 def calculate_cluster_centers(df, labels):
     centers = {}
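The red_palette expression kept above reuses Bokeh's fixed nine-colour Reds9 palette whenever there are more than nine real labels. A minimal sketch of the same cycling trick, with a made-up label count:

from bokeh.palettes import Reds9

num_real = 13  # hypothetical number of real labels
# Repeat the 9-colour palette enough times, then trim to the exact length needed
red_palette = Reds9[:num_real] if num_real <= 9 else (Reds9 * ((num_real // 9) + 1))[:num_real]
assert len(red_palette) == num_real  # colours start repeating after the ninth label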
@@ -164,34 +256,35 @@ def calculate_cluster_centers(df, labels):
     return centers

 # Computes the Wasserstein distance of each synthetic subset to each real cluster (per cluster and globally)
-def
     distances = {}
-
-
-
-
         distances[key] = {}
-
-
         weights = np.ones(n) / n
-
-
-
         weights_real = np.ones(m) / m
-        M = ot.dist(
         distances[key][real_label] = ot.emd2(weights, weights_real, M)
-    # Global distance of the synthetic set to each real cluster
-    key = "Global synthetic"
-    distances[key] = {}
-    global_synth = df_synth[['x','y']].values
-    n_global = global_synth.shape[0]
-    weights_global = np.ones(n_global) / n_global
-    for real_label in labels_real:
-        cluster_real = df_real[df_real['label'] == real_label][['x','y']].values
-        m = cluster_real.shape[0]
-        weights_real = np.ones(m) / m
-        M = ot.dist(global_synth, cluster_real, metric='euclidean')
-        distances[key][real_label] = ot.emd2(weights_global, weights_real, M)
     return pd.DataFrame(distances).T

 def create_table(df_distances):
@@ -220,11 +313,11 @@ def run_model(model_name):
     embeddings = load_embeddings(model_name)
     if embeddings is None:
         return
     embedding_cols = [col for col in embeddings["real"].columns if col.startswith("dim_")]
-    # Combine all DataFrames
     df_combined = pd.concat(list(embeddings.values()), ignore_index=True)
     st.markdown('<h6 class="sub-title">Select Dimensionality Reduction Method</h6>', unsafe_allow_html=True)
-    reduction_method = st.selectbox("", options=["t-SNE", "PCA"], key=model_name)
     if reduction_method == "PCA":
         reducer = PCA(n_components=2)
     else:
@@ -232,15 +325,12 @@ def run_model(model_name):
     reduced = reducer.fit_transform(df_combined[embedding_cols].values)
     dfs_reduced, unique_subsets = split_versions(df_combined, reduced)

-    # unique_subsets is expected to have the keys "real" and "synthetic"
     color_maps = get_color_maps(unique_subsets)
-    fig, real_renderers, synthetic_renderers = create_figure(dfs_reduced, unique_subsets, color_maps)

     centers_real = calculate_cluster_centers(dfs_reduced["real"], unique_subsets["real"])

-    df_distances =
-        dfs_reduced["real"],
-        unique_subsets["real"])
     data_table, df_table, source_table = create_table(df_distances)

     real_subset_names = list(df_table.columns[1:])
@@ -249,10 +339,7 @@ def run_model(model_name):
     line_source = ColumnDataSource(data={'x': [], 'y': []})
     fig.line('x', 'y', source=line_source, line_width=2, line_color='black')

-    # Prepare centers for the callback (to draw lines between centers)
     real_centers_js = {k: [v[0], v[1]] for k, v in centers_real.items()}
-
-    # Synthetic centers could also be prepared if needed
     synthetic_centers = {}
     synth_labels = sorted(dfs_reduced["synthetic"]['label'].unique().tolist())
     for label in synth_labels:
 from bokeh.plotting import figure
 from bokeh.models import ColumnDataSource, DataTable, TableColumn, CustomJS, Select, Button
 from bokeh.layouts import column
+from bokeh.palettes import Reds9, Blues9, Oranges9, Purples9, Greys9, BuGn9, Greens9
 from sklearn.decomposition import PCA
 from sklearn.manifold import TSNE
 import io
         .main-title { font-size: 50px; color: #4CAF50; text-align: center; }
         .sub-title { font-size: 30px; color: #555; }
         .custom-text { font-size: 18px; line-height: 1.5; }
+        .bk-legend {
+            max-height: 200px;
+            overflow-y: auto;
+        }
     </style>
     """, unsafe_allow_html=True)
     st.markdown('<h1 class="main-title">Merit Embeddings 🎒📃🏆</h1>', unsafe_allow_html=True)
 def load_embeddings(model):
     if model == "Donut":
         df_real = pd.read_csv("data/donut_de_Rodrigo_merit_secret_all_embeddings.csv")
+        df_par = pd.read_csv("data/donut_de_Rodrigo_merit_es-digital-paragraph-degradation-seq_embeddings.csv")
+        df_line = pd.read_csv("data/donut_de_Rodrigo_merit_es-digital-line-degradation-seq_embeddings.csv")
         df_seq = pd.read_csv("data/donut_de_Rodrigo_merit_es-digital-seq_embeddings.csv")
+        df_rot = pd.read_csv("data/donut_de_Rodrigo_merit_es-digital-rotation-degradation-seq_embeddings.csv")
+        df_zoom = pd.read_csv("data/donut_de_Rodrigo_merit_es-digital-zoom-degradation-seq_embeddings.csv")
+        df_render = pd.read_csv("data/donut_de_Rodrigo_merit_es-render-seq_embeddings.csv")
         df_real["version"] = "real"
+        df_par["version"] = "synthetic"
         df_line["version"] = "synthetic"
+        df_seq["version"] = "synthetic"
+        df_rot["version"] = "synthetic"
+        df_zoom["version"] = "synthetic"
+        df_render["version"] = "synthetic"
+
+        # Assign the source of each synthetic subset
+        df_par["source"] = "es-digital-paragraph-degradation-seq"
         df_line["source"] = "es-digital-line-degradation-seq"
+        df_seq["source"] = "es-digital-seq"
+        df_rot["source"] = "es-digital-rotation-degradation-seq"
+        df_zoom["source"] = "es-digital-zoom-degradation-seq"
+        df_render["source"] = "es-render-seq"
+        return {"real": df_real, "synthetic": pd.concat([df_seq, df_line, df_par, df_rot, df_zoom, df_render], ignore_index=True)}
+
     elif model == "Idefics2":
         df_real = pd.read_csv("data/idefics2_de_Rodrigo_merit_secret_britanico_embeddings.csv")
         df_seq = pd.read_csv("data/idefics2_de_Rodrigo_merit_es-digital-seq_embeddings.csv")

         df_seq["version"] = "synthetic"
         df_seq["source"] = "es-digital-seq"
         return {"real": df_real, "synthetic": df_seq}
+
     else:
         st.error("Modelo no reconocido")
         return None
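For orientation, the dictionary returned in the Donut branch holds one frame of real embeddings and one concatenated frame of all synthetic subsets, where every row carries the embedding dimensions plus label, version and source. A rough sketch of that layout with toy data (the labels and values below are invented; only the column names follow the code above):

import pandas as pd

# Hypothetical miniature of what load_embeddings("Donut") returns
df_real = pd.DataFrame({"dim_0": [0.1, 0.2], "dim_1": [0.3, 0.4],
                        "label": ["school_a", "school_b"], "version": "real"})
df_seq = pd.DataFrame({"dim_0": [0.5], "dim_1": [0.6], "label": ["school_a"],
                       "version": "synthetic", "source": "es-digital-seq"})
df_render = pd.DataFrame({"dim_0": [0.7], "dim_1": [0.8], "label": ["school_b"],
                          "version": "synthetic", "source": "es-render-seq"})
embeddings = {"real": df_real,
              "synthetic": pd.concat([df_seq, df_render], ignore_index=True)}
# Later steps group the synthetic frame by source and label
print(embeddings["synthetic"].groupby(["source", "label"]).size())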
     reducer = TSNE(n_components=2, random_state=42, perplexity=30, learning_rate=200)
     return reducer.fit_transform(all_embeddings)

+# Function to add real data (one renderer per label)
 def add_dataset_to_fig(fig, df, selected_labels, marker, color_mapping, group_label):
     renderers = {}
     for label in selected_labels:
             img=subset.get('img', "")
         ))
         color = color_mapping[label]
         legend_label = f"{label} ({group_label})"
         if marker == "circle":
             r = fig.circle('x', 'y', size=10, source=source,
         renderers[label + f" ({group_label})"] = r
     return renderers

+# New function to plot synthetic data at a granular level while grouping the legend by source
+def add_synthetic_dataset_to_fig(fig, df, labels, marker, color_mapping, group_label):
+    renderers = {}
+    for label in labels:
+        subset = df[df['label'] == label]
+        if subset.empty:
+            continue
+        source_obj = ColumnDataSource(data=dict(
+            x=subset['x'],
+            y=subset['y'],
+            label=subset['label'],
+            img=subset.get('img', "")
+        ))
+        # Use the granular colour assigned to each label
+        color = color_mapping[label]
+        # The legend entry is the source name, so markers are grouped by source
+        legend_label = group_label
+
+        if marker == "square":
+            r = fig.square('x', 'y', size=10, source=source_obj,
+                           fill_color=color, line_color=color,
+                           legend_label=legend_label)
+        elif marker == "triangle":
+            r = fig.triangle('x', 'y', size=12, source=source_obj,
+                             fill_color=color, line_color=color,
+                             legend_label=legend_label)
+        elif marker == "inverted_triangle":
+            r = fig.inverted_triangle('x', 'y', size=12, source=source_obj,
+                                      fill_color=color, line_color=color,
+                                      legend_label=legend_label)
+        elif marker == "diamond":
+            r = fig.diamond('x', 'y', size=10, source=source_obj,
+                            fill_color=color, line_color=color,
+                            legend_label=legend_label)
+        elif marker == "cross":
+            r = fig.cross('x', 'y', size=12, source=source_obj,
+                          fill_color=color, line_color=color,
+                          legend_label=legend_label)
+        elif marker == "x":
+            r = fig.x('x', 'y', size=12, source=source_obj,
+                      fill_color=color, line_color=color,
+                      legend_label=legend_label)
+        elif marker == "asterisk":
+            r = fig.asterisk('x', 'y', size=12, source=source_obj,
+                             fill_color=color, line_color=color,
+                             legend_label=legend_label)
+        else:
+            r = fig.circle('x', 'y', size=10, source=source_obj,
+                           fill_color=color, line_color=color,
+                           legend_label=legend_label)
+        renderers[label + f" ({group_label})"] = r
+    return renderers
+
+
 def get_color_maps(unique_subsets):
     color_map = {}
+    # For real data, assign one colour per label
     num_real = len(unique_subsets["real"])
     red_palette = Reds9[:num_real] if num_real <= 9 else (Reds9 * ((num_real // 9) + 1))[:num_real]
     color_map["real"] = {label: red_palette[i] for i, label in enumerate(sorted(unique_subsets["real"]))}

+    # For synthetic data, assign colours granularly: each label is mapped within its source
+    color_map["synthetic"] = {}
+    for source, labels in unique_subsets["synthetic"].items():
+        if source == "es-digital-seq":
+            palette = Blues9[:len(labels)] if len(labels) <= 9 else (Blues9 * ((len(labels)//9)+1))[:len(labels)]
+        elif source == "es-digital-line-degradation-seq":
+            palette = Purples9[:len(labels)] if len(labels) <= 9 else (Purples9 * ((len(labels)//9)+1))[:len(labels)]
+        elif source == "es-digital-paragraph-degradation-seq":
+            palette = BuGn9[:len(labels)] if len(labels) <= 9 else (BuGn9 * ((len(labels)//9)+1))[:len(labels)]
+        elif source == "es-digital-rotation-degradation-seq":
+            palette = Greys9[:len(labels)] if len(labels) <= 9 else (Greys9 * ((len(labels)//9)+1))[:len(labels)]
+        elif source == "es-digital-zoom-degradation-seq":
+            palette = Oranges9[:len(labels)] if len(labels) <= 9 else (Oranges9 * ((len(labels)//9)+1))[:len(labels)]
+        elif source == "es-render-seq":
+            palette = Greens9[:len(labels)] if len(labels) <= 9 else (Greens9 * ((len(labels)//9)+1))[:len(labels)]
+        else:
+            palette = Blues9[:len(labels)] if len(labels) <= 9 else (Blues9 * ((len(labels)//9)+1))[:len(labels)]
+        color_map["synthetic"][source] = {label: palette[i] for i, label in enumerate(sorted(labels))}
     return color_map

 def split_versions(df_combined, reduced):
     df_combined['x'] = reduced[:, 0]
     df_combined['y'] = reduced[:, 1]
     df_real = df_combined[df_combined["version"] == "real"].copy()
     df_synth = df_combined[df_combined["version"] == "synthetic"].copy()
+    # Extract the unique labels for the real data
     unique_real = sorted(df_real['label'].unique().tolist())
+    # For synthetic data, group the labels by source
+    unique_synth = {}
+    for source in df_synth["source"].unique():
+        unique_synth[source] = sorted(df_synth[df_synth["source"] == source]['label'].unique().tolist())
     df_dict = {"real": df_real, "synthetic": df_synth}
+    # Real data keeps a list of labels; synthetic data keeps the per-source dictionary
     unique_subsets = {"real": unique_real, "synthetic": unique_synth}
     return df_dict, unique_subsets

+def create_figure(dfs, unique_subsets, color_maps, model_name):
+    fig = figure(width=600, height=600, tools="wheel_zoom,pan,reset,save", active_scroll="wheel_zoom", tooltips=TOOLTIPS, title="")
+    # Real data: kept granular both in the plot and in the legend
     real_renderers = add_dataset_to_fig(fig, dfs["real"], unique_subsets["real"],
                                         marker="circle", color_mapping=color_maps["real"],
                                         group_label="Real")
+    # Marker assignment for each synthetic source
+    marker_mapping = {
+        "es-digital-paragraph-degradation-seq": "x",
+        "es-digital-line-degradation-seq": "cross",
+        "es-digital-seq": "triangle",
+        "es-digital-rotation-degradation-seq": "diamond",
+        "es-digital-zoom-degradation-seq": "asterisk",
+        "es-render-seq": "inverted_triangle"
+    }
+
+    # Synthetic data: plotted per label, but the legend is grouped by source
+    synthetic_renderers = {}
     synth_df = dfs["synthetic"]
+    for source in unique_subsets["synthetic"]:
+        df_source = synth_df[synth_df["source"] == source]
+        marker = marker_mapping.get(source, "square")  # Default to "square" if the source is not listed
+        renderers = add_synthetic_dataset_to_fig(fig, df_source, unique_subsets["synthetic"][source],
+                                                 marker=marker,
+                                                 color_mapping=color_maps["synthetic"][source],
+                                                 group_label=source)
+        synthetic_renderers.update(renderers)

     fig.legend.location = "top_right"
     fig.legend.click_policy = "hide"
+    show_legend = st.checkbox("Show Legend", value=False, key=f"legend_{model_name}")
+    fig.legend.visible = show_legend
     return fig, real_renderers, synthetic_renderers

+
 # Computes the centers of each cluster (per group)
 def calculate_cluster_centers(df, labels):
     centers = {}
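The synthetic renderers lean on Bokeh's legend behaviour: glyphs that share the same legend_label are merged into a single legend entry, so with click_policy set to "hide" one click on a source name should toggle every label plotted for that source. A minimal standalone sketch of that grouping (toy data, not the app's figure):

from bokeh.plotting import figure, show

fig = figure(width=300, height=300)
# Two renderers that share one legend entry
fig.square([1, 2], [1, 2], size=10, color="navy", legend_label="es-digital-seq")
fig.square([3, 4], [1, 2], size=10, color="skyblue", legend_label="es-digital-seq")
fig.circle([1, 2], [3, 4], size=10, color="firebrick", legend_label="Real")
fig.legend.click_policy = "hide"  # clicking "es-digital-seq" hides both square renderers
show(fig)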
     return centers

 # Computes the Wasserstein distance of each synthetic subset to each real cluster (per cluster and globally)
+def compute_wasserstein_distances_synthetic_individual(synthetic_df: pd.DataFrame, df_real: pd.DataFrame, real_labels: list) -> pd.DataFrame:
     distances = {}
+    groups = synthetic_df.groupby(['source', 'label'])
+    for (source, label), group in groups:
+        key = f"{label} ({source})"
+        data = group[['x', 'y']].values
+        n = data.shape[0]
+        weights = np.ones(n) / n
         distances[key] = {}
+        for real_label in real_labels:
+            real_data = df_real[df_real['label'] == real_label][['x','y']].values
+            m = real_data.shape[0]
+            weights_real = np.ones(m) / m
+            M = ot.dist(data, real_data, metric='euclidean')
+            distances[key][real_label] = ot.emd2(weights, weights_real, M)
+
+    # Global distance per source
+    for source, group in synthetic_df.groupby('source'):
+        key = f"Global ({source})"
+        data = group[['x','y']].values
+        n = data.shape[0]
         weights = np.ones(n) / n
+        distances[key] = {}
+        for real_label in real_labels:
+            real_data = df_real[df_real['label'] == real_label][['x','y']].values
+            m = real_data.shape[0]
             weights_real = np.ones(m) / m
+            M = ot.dist(data, real_data, metric='euclidean')
             distances[key][real_label] = ot.emd2(weights, weights_real, M)
     return pd.DataFrame(distances).T

 def create_table(df_distances):
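For reference, ot.dist builds the pairwise Euclidean cost matrix and ot.emd2 solves the discrete optimal-transport problem and returns its cost, so with uniform weights each cell of the table holds the Wasserstein distance between a synthetic subset and a real cluster in the reduced 2-D space. A self-contained sketch with random point clouds (sizes are arbitrary):

import numpy as np
import ot  # POT: Python Optimal Transport

rng = np.random.default_rng(0)
synthetic_points = rng.normal(loc=0.0, size=(40, 2))  # e.g. one synthetic subset after t-SNE/PCA
real_points = rng.normal(loc=1.0, size=(60, 2))       # e.g. one real cluster

weights = np.ones(40) / 40        # uniform weights over the synthetic points
weights_real = np.ones(60) / 60   # uniform weights over the real points
M = ot.dist(synthetic_points, real_points, metric='euclidean')  # 40x60 cost matrix
distance = ot.emd2(weights, weights_real, M)  # optimal transport cost = Wasserstein distance
print(distance)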
     embeddings = load_embeddings(model_name)
     if embeddings is None:
         return
+
     embedding_cols = [col for col in embeddings["real"].columns if col.startswith("dim_")]
     df_combined = pd.concat(list(embeddings.values()), ignore_index=True)
     st.markdown('<h6 class="sub-title">Select Dimensionality Reduction Method</h6>', unsafe_allow_html=True)
+    reduction_method = st.selectbox("", options=["t-SNE", "PCA"], key=f"reduction_{model_name}")
     if reduction_method == "PCA":
         reducer = PCA(n_components=2)
     else:
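The selectbox key changes from key=model_name to key=f"reduction_{model_name}", which looks like per-model namespacing: Streamlit requires widget keys to be unique across the app, and the new legend checkbox already uses key=f"legend_{model_name}". A small sketch of the pattern (labels are illustrative):

import streamlit as st

def model_controls(model_name: str):
    # One unique key per widget and per model avoids duplicate-key errors
    reduction_method = st.selectbox("Reduction", options=["t-SNE", "PCA"],
                                    key=f"reduction_{model_name}")
    show_legend = st.checkbox("Show Legend", value=False, key=f"legend_{model_name}")
    return reduction_method, show_legend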
     reduced = reducer.fit_transform(df_combined[embedding_cols].values)
     dfs_reduced, unique_subsets = split_versions(df_combined, reduced)

     color_maps = get_color_maps(unique_subsets)
+    fig, real_renderers, synthetic_renderers = create_figure(dfs_reduced, unique_subsets, color_maps, model_name)

     centers_real = calculate_cluster_centers(dfs_reduced["real"], unique_subsets["real"])

+    df_distances = compute_wasserstein_distances_synthetic_individual(dfs_reduced["synthetic"], dfs_reduced["real"], unique_subsets["real"])
     data_table, df_table, source_table = create_table(df_distances)

     real_subset_names = list(df_table.columns[1:])
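create_table itself is not shown in this diff; the sketch below only illustrates the generic DataFrame-to-DataTable conversion it presumably performs, using the Bokeh classes already imported above (the distance values and row names are invented):

import pandas as pd
from bokeh.models import ColumnDataSource, DataTable, TableColumn

# Hypothetical distances frame: one row per synthetic subset, one column per real cluster
df_distances = pd.DataFrame({"school_a": [1.2, 0.8], "school_b": [2.1, 1.5]},
                            index=["Global (es-digital-seq)", "Global (es-render-seq)"])
df_table = df_distances.reset_index().rename(columns={"index": "synthetic subset"})
source_table = ColumnDataSource(df_table)
columns = [TableColumn(field=col, title=col) for col in df_table.columns]
data_table = DataTable(source=source_table, columns=columns, width=600, height=200)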
     line_source = ColumnDataSource(data={'x': [], 'y': []})
     fig.line('x', 'y', source=line_source, line_width=2, line_color='black')

     real_centers_js = {k: [v[0], v[1]] for k, v in centers_real.items()}
     synthetic_centers = {}
     synth_labels = sorted(dfs_reduced["synthetic"]['label'].unique().tolist())
     for label in synth_labels: